In [1]:
from common_methods import *
import torch.nn.functional as F
import torch

[nltk_data] Downloading package punkt to
[nltk_data]     /home/nikhilanand/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package wordnet to
[nltk_data]     /home/nikhilanand/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package stopwords to
[nltk_data]     /home/nikhilanand/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [2]:
# Load the MemoTrap data via the helper from common_methods. `inputs[i]` is a prompt
# string (see the `inputs[0]` cell below); the exact meaning of `plain_inputs`,
# `context_outputs` and `outputs` is defined by load_memotrap -- TODO confirm there.
inputs,plain_inputs,context_outputs,outputs = load_memotrap()

In [3]:
torch.cuda.is_available()

True

In [4]:
# Checkpoint under inspection and the device it will run on.
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
device = torch.device("cuda")

# Tokenizer and half-precision weights come from the same checkpoint id.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
model.to(device)

# Written to by the forward hooks: maps a layer number to that layer's captured output.
activations = {}

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
def get_hook(layer_num):
    """Build a forward hook that records a layer's output under ``layer_num``.

    The returned callable takes the ``(module, inputs, output)`` arguments that
    ``register_forward_hook`` passes positionally. Each time it fires it stores
    the detached first element of ``output`` in the global ``activations`` dict,
    replacing whatever was stored for ``layer_num`` before.
    """
    def hook(_module, _inputs, output):
        # Keep the whole tensor (every token position), not only the last token.
        captured = output[0]
        activations[layer_num] = captured.detach()
    return hook

# Decoder layers to probe, numbered from 1: layer k is `model.model.layers[k-1]`.
layer_list = [4,8,16,18,20,22,24,26,28,30,31,32]

# Re-running this cell must not stack a second set of hooks on the model, so
# first remove whatever an earlier run of the cell registered.
for handle in globals().get("hook_handles", []):
    handle.remove()

# `register_forward_hook` returns a removable handle; keep them so the hooks
# can be detached again (here on re-run, or manually once probing is done).
hook_handles = [
    model.model.layers[i-1].register_forward_hook(get_hook(i))
    for i in layer_list
]

In [6]:
inputs[0]

'Write a quote that ends in the word "fur": A rag and a bone and a hank of'

Across prompts, the activations at the first few token positions are constant (the prompts share the same opening tokens); from the position where the prompts start to differ, different prompts evoke different activations.

Pass any hidden state through the LM head to get the output logits, then apply a softmax to turn those logits into token probabilities.

- Each prompt has a variable number of input-token positions but a fixed number of layers. Make a 2D matrix for each prompt (input tokens x layers; the code below stores it as layers x input tokens).
- For each (input token, layer) coordinate, write the probability of the target token (context_output and default_output, one plot for each).
- Do a softmax across the rows of that matrix.
- Make heatmaps and save the plots in a folder for each prompt. The token strings go on the X-axis and the layer numbers on the Y-axis.

In [7]:
# Generate a single token (max_tokens=1) for the first prompt. Running the model here
# is presumably also what fires the forward hooks and fills `activations` --
# regular_decoding lives in common_methods, TODO confirm it does a plain forward pass.
out = regular_decoding(model,tokenizer,inputs[0],debug=False,max_tokens=1,show_tqdm=True, return_prob=False)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
  last_token_probs = F.softmax(last_token_logits)


In [8]:
inputs[0],out

('Write a quote that ends in the word "fur": A rag and a bone and a hank of',
 'hair')

In [22]:
activations[26]

tensor([[[-0.1450, -0.3086, -0.0419,  ..., -0.0010,  0.0010,  0.0802],
         [ 0.4111, -0.5464, -0.0333,  ...,  0.0527, -0.3110,  0.4009],
         [ 0.2130, -0.0696,  0.2612,  ...,  0.0239, -0.2042,  0.4705],
         ...,
         [ 0.1298,  0.0510,  0.1823,  ..., -0.1677,  0.1282, -0.1010],
         [-0.1219, -0.0835,  0.1846,  ...,  0.1719,  0.0908, -0.1587],
         [-0.0696,  0.1281, -0.0679,  ..., -0.0950, -0.2351,  0.2556]]],
       device='cuda:0', dtype=torch.float16)

In [23]:
model.lm_head

Linear(in_features=4096, out_features=32000, bias=False)

In [24]:
activations[26].shape

torch.Size([1, 22, 4096])

In [27]:
# Drop the batch axis of the layer-26 capture: every token position of the single
# batch element (the `.shape` cell above shows [1, 22, 4096], so this is 22 x 4096).
all_tokens_activns = activations[26][0,:,:]

In [29]:
# Tokenise the first prompt, then decode every id on its own so that there is
# one string per token position.
prompt_ids = tokenizer(inputs[0]).input_ids
input_token_list = tokenizer.batch_decode(prompt_ids)

In [36]:
len(input_token_list),input_token_list[5]

(22, 'ends')

In [78]:
# Target words whose probabilities are tracked: wordC is the word the prompt asks
# for ("ends in the word ..."), wordM is the word the model actually generated.
wordC = "fur"
wordM = "hair"

# Tokenise without special tokens so the lookup does not rely on a BOS token
# sitting at index 0. If a word splits into several tokens, only its first
# token is tracked.
wordC_tok = tokenizer(wordC, add_special_tokens=False).input_ids[0]
wordM_tok = tokenizer(wordM, add_special_tokens=False).input_ids[0]
print(wordC_tok,wordM_tok)

# Sanity check: decode the ids that were just looked up (not hard-coded numbers),
# so the check stays valid when the words above are changed.
tokenizer.batch_decode([wordC_tok, wordM_tok])

2982 3691


['fur', 'hair']

In [83]:
# matrixC[l, i] / matrixM[l, i]: probability of wordC / wordM when the hidden state
# captured at layer layer_list[l], token position i is pushed through the model's
# final norm and LM head.
matrixC = torch.zeros((len(layer_list),len(input_token_list)))
matrixM = torch.zeros((len(layer_list),len(input_token_list)))

num_tokens = len(input_token_list)

# Inspection only: without no_grad the LM-head weights would pull every entry
# into an autograd graph and keep it alive.
with torch.no_grad():
    for l, layer in enumerate(layer_list):
        # All token positions of the single batch element in one call.
        hidden = activations[layer][0, :num_tokens, :]
        logits = model.lm_head(model.model.norm(hidden))
        # Softmax over the vocabulary, in float32 so tiny probabilities are not
        # rounded to half precision.
        probs = F.softmax(logits.float(), dim=-1)

        matrixC[l] = probs[:, wordC_tok].cpu()
        matrixM[l] = probs[:, wordM_tok].cpu()

In [84]:
matrixC

tensor([[1.5259e-05, 5.0366e-05, 2.7716e-05, 7.6890e-06, 8.8274e-05, 1.5020e-05,
         2.2840e-04, 8.3148e-05, 3.2425e-05, 2.3735e-04, 8.2731e-04, 5.0449e-04,
         3.1829e-05, 1.1128e-04, 6.5327e-05, 1.0443e-04, 1.5080e-05, 1.3351e-05,
         8.6963e-05, 8.7798e-05, 3.4070e-04, 2.9981e-05],
        [1.5199e-05, 2.8396e-04, 3.0398e-05, 7.0930e-06, 7.4983e-05, 2.2352e-05,
         4.4942e-05, 1.7679e-04, 1.4222e-04, 5.6028e-05, 4.4899e-03, 2.0294e-03,
         7.0095e-05, 2.6836e-03, 1.0166e-03, 2.9421e-04, 2.7013e-04, 6.7294e-05,
         1.6522e-04, 1.7703e-04, 8.2970e-04, 2.5249e-04],
        [1.5080e-05, 1.3959e-04, 3.4988e-05, 1.9610e-05, 1.8597e-05, 5.1439e-05,
         5.5909e-05, 6.6817e-05, 6.4075e-05, 3.2365e-05, 3.0565e-04, 6.2180e-04,
         6.1274e-05, 5.5075e-04, 4.3821e-04, 9.3460e-04, 2.8133e-04, 1.6427e-04,
         2.5320e-04, 1.1456e-04, 4.8542e-04, 1.0033e-03],
        [1.5020e-05, 1.0663e-04, 3.0100e-05, 1.4424e-05, 1.8179e-05, 8.6427e-06,
         3.9160e

In [85]:
matrixM

tensor([[1.8895e-05, 1.8358e-05, 3.3379e-05, 2.7359e-05, 4.4107e-05, 3.7014e-05,
         2.9087e-05, 1.3173e-05, 1.0192e-05, 5.0068e-06, 4.2415e-04, 7.6413e-05,
         3.8207e-05, 7.5579e-05, 1.6749e-05, 4.5300e-05, 3.7074e-05, 1.3709e-06,
         3.8147e-05, 2.6298e-04, 1.8179e-05, 3.8683e-05],
        [1.8835e-05, 1.2279e-05, 1.7643e-05, 2.5511e-05, 4.6134e-05, 2.6166e-05,
         1.9550e-05, 1.2279e-05, 1.5438e-05, 5.5730e-05, 6.7902e-04, 6.0976e-05,
         3.5644e-05, 2.1684e-04, 1.0985e-04, 4.4823e-05, 2.6107e-05, 2.2829e-05,
         5.7101e-05, 1.3685e-04, 6.0940e-04, 1.2074e-03],
        [1.8835e-05, 3.8147e-06, 3.7789e-05, 2.9087e-05, 1.2934e-05, 1.9073e-05,
         1.6272e-05, 1.0848e-05, 1.9610e-05, 2.3127e-05, 6.0976e-05, 1.9133e-05,
         1.6451e-05, 5.9605e-05, 6.3372e-04, 2.6679e-04, 4.0293e-05, 4.7088e-05,
         1.1802e-04, 1.7536e-04, 6.9284e-04, 1.3374e-02],
        [1.8895e-05, 3.3975e-06, 3.6955e-05, 4.2081e-05, 4.5121e-05, 2.7478e-05,
         7.3910e

In [95]:
# Softmax along dim=1, i.e. over the token positions within each layer row; the
# results are detached from autograd before they are plotted.
# NOTE(review): the entries are already probabilities in [0, 1], so this softmax
# yields almost uniform rows -- confirm this normalisation is what is wanted.
sf_matC = matrixC.softmax(dim=1).detach()
sf_matM = matrixM.softmax(dim=1).detach()

In [109]:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path

# Reverse the layer axis so the deepest layer is drawn in the top row.
sf_matM_ = torch.flip(sf_matM, dims=[0])
layer_list_ = layer_list[::-1]
input_token_list_ = input_token_list[:]

# savefig does not create missing folders, so make sure the target one exists.
plot_path = Path(f'runs/plots/id{0}_{wordM}.png')
plot_path.parent.mkdir(parents=True, exist_ok=True)

# Explicit figure/axes instead of the pyplot state machine, so this exact
# figure is the one that gets saved and closed.
fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(sf_matM_.cpu().numpy(), cmap='coolwarm',
            xticklabels=input_token_list_, yticklabels=layer_list_, ax=ax)

ax.set_title(f'Matrix for the output "{wordM}"')
ax.set_xlabel('Tokens')
ax.set_ylabel('Layers')

fig.savefig(plot_path)
plt.close(fig)