In [42]:
from common_methods import *
import torch.nn.functional as F
import torch
from itertools import product

In [3]:
inputs,plain_inputs,context_outputs,outputs = load_memotrap()

In [4]:
torch.cuda.is_available()

True

In [5]:
model_id = "mistralai/Mistral-7B-Instruct-v0.2"
model = AutoModelForCausalLM.from_pretrained(model_id,torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_id)
device = torch.device("cuda")
model.to(device)
activations = {}

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [6]:
def get_hook(layer_num):
    def hook(model,input,output):
        activations[layer_num] = output[0].detach() # not just last token, entire set of activations
    return hook

layer_list = [4,8,16,18,20,22,24,26,28,30,31,32]
for i in layer_list:
    model.model.layers[i-1].register_forward_hook(get_hook(i))

In [7]:
out = regular_decoding(model,tokenizer,inputs[0],debug=False,max_tokens=1,show_tqdm=True, return_prob=False)

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
  last_token_probs = F.softmax(last_token_logits)


In [8]:
inputs[0],out

('Write a quote that ends in the word "fur": A rag and a bone and a hank of',
 'hair')

In [20]:
attrs = {'prompt':inputs[0],'context_output':context_outputs[0],'memory_output':outputs[0],'cont_correct?':0,'mem_correct?':0} # populate with keys {"prompt":"",output:"","correct":True/False,'correct_pred':True/False}
attrs


{'prompt': 'Write a quote that ends in the word "fur": A rag and a bone and a hank of',
 'context_output': 'fur',
 'memory_output': 'hair',
 'cont_correct?': 0,
 'mem_correct?': 0}

Now we try and predict whether our output is correct using only internal activations (layers 24,26,28) and save it to the dict.

In [11]:
activations[26],activations[26].shape

(tensor([[[-0.1450, -0.3086, -0.0419,  ..., -0.0010,  0.0010,  0.0802],
          [ 0.4109, -0.5459, -0.0330,  ...,  0.0524, -0.3105,  0.4011],
          [ 0.2129, -0.0693,  0.2617,  ...,  0.0240, -0.2040,  0.4707],
          ...,
          [ 0.1296,  0.0505,  0.1825,  ..., -0.1674,  0.1279, -0.1014],
          [-0.1222, -0.0828,  0.1844,  ...,  0.1724,  0.0905, -0.1588],
          [-0.0698,  0.1282, -0.0683,  ..., -0.0945, -0.2358,  0.2554]]],
        device='cuda:0', dtype=torch.float16),
 torch.Size([1, 22, 4096]))

In [13]:
input_token_list = tokenizer.batch_decode(tokenizer(inputs[0]).input_ids)

In [25]:
input_token_list

['<s>',
 'Write',
 'a',
 'quote',
 'that',
 'ends',
 'in',
 'the',
 'word',
 '"',
 'fur',
 '":',
 'A',
 'rag',
 'and',
 'a',
 'bone',
 'and',
 'a',
 'h',
 'ank',
 'of']

In [18]:
wordO = out
wordO_tok = tokenizer(wordO).input_ids[1]
print(wordO_tok)
tokenizer.batch_decode([3691])

2982 3691


['fur', 'hair']

Okay, instead of this, we can use wordC and wordM to check whether the activation idea is actually sensible. If it works, then it will easily be able to predict whether whatever output is the correct output or the wrong one. Predicting like this is the first step to efficient decoding.

In [19]:
wordC = "fur"
wordM = "hair"
wordC_tok,wordM_tok = tokenizer(wordC).input_ids[1],tokenizer(wordM).input_ids[1]
print(wordC_tok,wordM_tok)
tokenizer.batch_decode([2982,3691])

2982 3691


['fur', 'hair']

In [35]:
len(input_token_list),activations[22].shape,input_token_list[10:12]

(22, torch.Size([1, 22, 4096]), ['fur', '":'])

In [53]:

layers_to_check = [22,24,26,28]
intokens_to_check = [9,10,11] # These are the token positions at the input.

actC = False
actM = False

for layer, intoken in product(layers_to_check,intokens_to_check):
    act = F.softmax(model.lm_head(model.model.norm(activations[layer][0,intoken,:])),dim=0)

    # print(f"Layer {layer}, Intoken {intoken}")
    top_values, top_indices = torch.topk(act, k=50)
    print(tokenizer.batch_decode(top_indices.tolist()))
    
    if wordC_tok in top_indices:
        actC=True
    if wordM_tok in top_indices:
        actM = True



['ing', 'justice', 'inspire', 'wisdom', 'mel', 'courage', 'innovation', 'bekan', 'mor', 'ends', 'hope', 'quot', 'qu', 'revolution', 'ment', 'love', 'aph', 'ize', 'democracy', 'ins', 'ugno', 'adventure', 'insp', 'parad', 'smithy', 'em', 'mente', 'úblic', '‘', 'Dream', 'ende', 'inspiration', 'butter', 'creativity', 'th', 'reflection', '–', 'fish', 'momentum', 'happiness', 'ly', '–', 'Budd', 'soul', '̶', 'ment', 'ál', 'joy', 'st', 'ep']
['fur', 'fur', 'Fur', 'iously', 'ious', 'rier', 'rows', 'ther', 'ry', 'thur', 'ance', 'coat', 'THER', 'gery', 'ged', 'ball', 'thy', 'balls', 'ges', 'iety', 'riers', 'rr', 'Bent', 'ries', 'ment', 'row', 'thers', 'rest', 'dy', 'iance', 'flying', 'rying', 'sten', 'or', 'fol', 'med', 'ces', 'ugno', 'mouse', 'cats', 'lég', 'ment', 'th', 'rab', 'Für', 'animals', 'bear', 'ring', 'um', 'uent']
['fur', 'coat', 'fur', 'Fur', 'litter', '"', 'dogs', 'quote', 'rier', 'dog', 'cats', 'cat', 'pets', 'animal', 'wides', 'honor', 'favor', 'ending', 'thur', 'iously', '\n', 'r

In [47]:
attrs['cont_correct?'] = actC
attrs['mem_correct?'] = actM

In [48]:
attrs

{'prompt': 'Write a quote that ends in the word "fur": A rag and a bone and a hank of',
 'context_output': 'fur',
 'memory_output': 'hair',
 'cont_correct?': True,
 'mem_correct?': False}

In [51]:
input_token_list.index("\":")

11

In [52]:
input_token_list[11]

'":'