In [1]:
# coding: utf-8
import torch
import torch.nn.functional as F
import numpy
import pandas
from nnsight import LanguageModel, apply

from sklearn.metrics.pairwise import cosine_distances
import matplotlib.pyplot as plt
import seaborn

In [2]:
# value-zeroing distance calculation
def calc_distances(x, y):
    # cosine similarity = normalize the vectors & multiply
    return torch.diagonal(
            1 - (
                    F.normalize(x.detach().clone().squeeze(0)[:-1], dim=-1) @ 
                    F.normalize(y.detach().clone().squeeze(0)[:-1],dim=-1).t()
                )
        )

def forwardpass (model, prompt, tokens):
    layers = model.backbone.blocks

    # run the text through the model to get the baseline to
    # calculate the differences against
    with model.trace() as tracer:
        hidden_state = []
        with tracer.invoke(prompt) as invoker:
            for layer_ix in range(32):
                layer = getattr(layers, str(layer_ix))
                if layer_ix == 0:
                    layer_input = layer.input.clone().save()
                    hidden_state.append(layer_input)
                layer_output = layer.output.save()
                hidden_state.append(layer_output)

    vz_results = []
    # value-zeroing
    with model.trace() as tracer:
        for token_ix in range(len(tokens)):
            token_vz_results = []
            with tracer.invoke(prompt) as invoker:
                for layer_ix in range(32):
                    layer = getattr(layers, str(layer_ix))
                    # set the layer input to the baseline
                    layer.input[0] = hidden_state[layer_ix][0]
                    layer_v = getattr(getattr(layer, "mlstm_layer"), "v")
                    v_output = layer_v.output
                    # set the value vector of token token_ix to 0
                    v_output[0,token_ix,:] = torch.zeros(1, 4096)
                    # get the layer output
                    layer_output = layer.output
                    # calculate the distance, and save the result
                    dists = apply(calc_distances, layer_output[0], hidden_state[layer_ix+1][0]).save()
                    token_vz_results.append(dists)
            vz_results.append(token_vz_results)
    return hidden_state, vz_results

In [3]:
model = LanguageModel("NX-AI/xLSTM-7b", device_map="cuda")


In [4]:
tokenizer = model.tokenizer

In [5]:
### prompt = "Sarah told the author about her book"
#prompt = "The clown is playing his show tonight in the circus"
#prompt = '"I am going home," he said'
#prompt = 'Elizabeth Warren (D-MA) says she disagrees with the policies'
#prompt = 'the sum of 5 and 8 is 13'
prompt = "Either you win the game, or you lose the game."


inputs = tokenizer(prompt, add_special_tokens=True)
hidden_state, vz_hidden_state = forwardpass(model, prompt, inputs['input_ids'])


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

You're using a GPTNeoXTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([1, 13, 4096]) torch.Size([1, 13, 4096])
torch.Size([

In [35]:
score_matrix = torch.stack([torch.stack([vv[:] for vv in v]) for v in vz_hidden_state]).transpose(0,1).transpose(1,2).detach().cpu()

In [36]:
score_matrix.shape

torch.Size([32, 12, 12])

In [37]:
seq_length = score_matrix.shape[-1]
print(seq_length)

12


In [38]:
vmax = numpy.round(score_matrix.max() + 0.05, decimals=1)
print(vmax)

tensor(0.2000)


In [39]:
input_words = tokenizer.convert_ids_to_tokens(inputs['input_ids'])
corr_ix = pandas.MultiIndex.from_product(
    [
        range(32),
        range(len(input_words)),
        range(len(input_words)),
    ], 
    names=['layer', 'token_hs', 'token_vz'])
corr_df = pandas.DataFrame(score_matrix.flatten(), columns=['activation_corr'], index=corr_ix).reset_index()


In [40]:
def draw_heatmap(columns, index, values, data=None, **kwargs):
    d = data.pivot(index=index, columns=columns, values=values)
    seaborn.heatmap(d, **kwargs)

# g = seaborn.FacetGrid(data=corr_df, col="layer", col_wrap=4, height=4, sharey=False, sharex=False)
cbar_ax = g.fig.add_axes([.92, .3, .02, .4])  # <-- Create a colorbar axes

g.map_dataframe(draw_heatmap, "token_hs", "token_vz", "activation_corr", 
        xticklabels=input_words,
        yticklabels=input_words,
        #annot=True, fmt=".3f",# annot_kws={'size': 'small'},
        cbar_ax=cbar_ax, vmin=0, vmax=vmax,
        cmap="Blues",
        square=True,
    )
g.fig.subplots_adjust(right=.9)  # <-- Add space so the colorbar doesn't overlap the plot