In [134]:
from common_methods import *
import torch.nn.functional as F
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
from itertools import product
import json
import math

In [135]:
# Load the saved entropy tensor from disk.
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered file cannot execute arbitrary code while being loaded.
all_entropies = torch.load("runs/all_entropies.pt", weights_only=True)

  all_entropies = torch.load("runs/all_entropies.pt")


In [136]:
all_entropies.shape # Note that the new shape reflects we minimised over the layers dim.

torch.Size([215, 32000])

In [137]:
inputs,_,context_outputs,memory_outputs = load_memotrap()

In [6]:
# Model / tokenizer setup: load the checkpoint in half precision, switch it to
# eval mode and move it onto the CUDA device.
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.2"
device = torch.device("cuda")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
model.eval()
model.to(device)
activations = {}

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [7]:
id = 0
prompt = inputs[id]

In [8]:
# Tokenise the current prompt and place the token ids on the GPU.
device = torch.device("cuda")
tokenizer.pad_token = "<s>"
eos_token = tokenizer.eos_token_id
encoded_prompt = tokenizer(prompt, return_tensors="pt", padding=True)
input_ids = encoded_prompt.input_ids.to(device)

In [9]:
# Logits for the position after the last prompt token (batch 0, last position).
# This is inference only, so skip autograd bookkeeping instead of carrying a
# graph that later has to be detached.
with torch.no_grad():
    last_token_logits = model(input_ids).logits[0, -1, :]

We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)


In [10]:
# Normalise over the vocabulary axis. Passing `dim` explicitly avoids the
# deprecated implicit-dimension behaviour (and the warning it prints).
last_token_probs = F.softmax(last_token_logits, dim=-1)

  last_token_probs = F.softmax(last_token_logits)


In [11]:
prompt

'Write a quote that ends in the word "fur": A rag and a bone and a hank of'

In [12]:
# The word requested by the prompt ("context") versus the memorised ending
# ("memory"); show how each one is tokenised.
context, memory = "fur", "hair"
tokenizer(context), tokenizer(memory)

({'input_ids': [1, 2982], 'attention_mask': [1, 1]},
 {'input_ids': [1, 3691], 'attention_mask': [1, 1]})

In [13]:
print(f"Prob('fur') = {last_token_probs[2982].item()}\nProb('hair') = {last_token_probs[3691].item()}")

Prob('fur') = 0.004942772909998894
Prob('hair') = 0.9871297478675842


I'm going to apply entropy-based penalisation so that "fur" ends up with a higher probability and "hair" with a lower one. I need to choose a penalisation strength that achieves this (it is called `beta` in the code below; `alpha` is later used for the filtering threshold). Note that this assumes "fur" has the minimum entropy across all tokens. I don't think that assumption holds in general, but if we first discard all tokens with very low output probability, the entropy criterion may still work.

In [14]:
import math

l = 12 # length of the fur/hair context.
math.log(l) # this tells what the maximum possible entropy is, if the entire row is equal

2.4849066497880004

In [15]:
all_entropies[id] # Note how most vocab tokens have the max possible entropy. Only a few context tokens will differ from this

tensor([2.4849, 2.4849, 2.4849,  ..., 2.4849, 2.4849, 2.4849])

In [16]:
# Vocabulary ids whose entropy for this prompt is below the (rounded) maximum.
# flatten() keeps a 1-D result even when exactly one id matches; squeeze()
# would collapse that case to a 0-d tensor, so tolist() would return a bare int
# and batch_decode would no longer receive a list.
words = torch.nonzero(all_entropies[id] < 2.4849).flatten().tolist()
words_strings = tokenizer.batch_decode(words)

In [17]:
# Emit every decoded token followed by ", " on a single line, then a footer.
print("".join(f"{token_text}, " for token_text in words_strings), end="")
print("\nDone")

a, d, in, ", that, with, an, ', from, up, about, ious, “, every, down, word, program, report, following, short, enc, sum, code, example, story, fur, write, words, hope, write, ‘, ha, anger, song, review, express, written, capt, wrote, Write, excl, letter, insp, blog, statement, describe, script, brief, ended, reflect, iously, …, Fur, motiv, letters, ends, atr, essay, inspired, Write, sentence, ending, quote, writes, describes, ell, smith, phrase, persu, courage, poem, reson, wisdom, quot, summar, paragraph, dialogue, fitting, quotes, punct, â, essays, sentences, ugno, accurately, reflects, rhet, dash, attributed, inspire, fur, …, 
Done


- One thing to note is that very few words have an entropy that even deviates slightly. That's a good sign.
- Furthermore, a lot of these words are closely related to words in the context. For example, "write", "written", "review" and "express" are all similar in meaning. I guess this is what we need for our multi-token approach too.
- There are also a lot of stop words.

Next step, see if there are any very-low-probability tokens here. I guess we want to eliminate those before modifying them.

The paper also says to perform a filtering operation first (as in DoLa): set the lowest-probability tokens to zero, softmax again to renormalise the remaining tokens, and only then apply the entropy penalisation.

In [18]:
# Gather the probabilities of the low-entropy tokens once, instead of
# re-indexing the full probability vector twice per loop iteration.
word_probs = last_token_probs[words].tolist()
for prob, word in zip(word_probs, words_strings):
    # 9.87e-7 is the filtering cutoff alpha * max_prob (1e-6 * 0.987).
    if prob > 9.87e-7:
        print(prob, word)
print("\nDone")

1.6104830137919635e-05 a
1.289014471694827e-05 in
1.8788913394018891e-06 "
3.7512456856347853e-06 that
4.8166953092732e-06 '
3.0616777166869724e-06 every
0.004942772909998894 fur
2.755211653493461e-06 hope
1.1575513099160162e-06 ‘
6.583642061741557e-06 ha
6.532407496706583e-06 Fur
2.776821247607586e-06 courage
1.0770118933578487e-05 fur

Done


This filtering step works well: tokens like "write", which stand out on entropy (below the maximum, because they appear in the context), have a very low output probability (they are not viable next tokens). Such tokens can be eliminated immediately.

Since the max prob is 0.987, we need an alpha such that threshold becomes 1e-6.

In [19]:
alpha = 1e-6 # Let it be 1e-6, just an approximation.
alpha*0.987

9.87e-07

1. Filter out according to this $\alpha=10^{-6}$ by setting the lower ones to -inf.
2. Softmax all tokens so these ones are now zero.
3. Apply the exponential decay to the probabilities with some hyperparameter $\beta$, but only to the tokens whose probability is still non-zero.

In [20]:
last_token_probs.shape

torch.Size([32000])

In [21]:
torch.max(last_token_probs)

tensor(0.9871, device='cuda:0', grad_fn=<MaxBackward1>)

In [22]:
def filter_low_prob_tokens(input_logits, alpha=None):
    """Zero out implausible tokens and renormalise the rest.

    A token is dropped when its softmax probability is below
    ``alpha * max_probability``; the surviving logits are softmaxed again.

    Args:
        input_logits: logits with the vocabulary on the last axis. The tensor
            is NOT modified (the original sliced a view and overwrote it).
        alpha: relative cutoff. When omitted, the notebook-level ``alpha`` is
            used, which is what the original cell read implicitly.

    Returns:
        Probabilities of the same shape as ``input_logits``; filtered tokens
        have probability exactly 0.
    """
    if alpha is None:
        alpha = globals()["alpha"]
    # Derive the probabilities from the argument rather than from the global
    # `last_token_probs`, so the function is correct for any logits passed in.
    probs = F.softmax(input_logits, dim=-1)
    cutoff = alpha * probs.max(dim=-1, keepdim=True).values
    # masked_fill is out-of-place, so the caller's logits stay intact.
    filtered_logits = input_logits.masked_fill(probs < cutoff, float('-inf'))
    return F.softmax(filtered_logits, dim=-1)
filtered_probs = filter_low_prob_tokens(last_token_logits).detach()

  filtered_probs = F.softmax(filtered_logits)


In [52]:
def get_context_length(id):
    """Return the 1-based position of the `":` token in prompt number `id`.

    The prompt `inputs[id]` is tokenised, every token id is decoded back to a
    string, and the index of the first decoded token equal to `":` is located.
    Raises ValueError (from list.index) if no such token exists.
    """
    token_ids = tokenizer(inputs[id]).input_ids
    decoded_tokens = tokenizer.batch_decode(token_ids)
    marker_position = decoded_tokens.index("\":")
    return marker_position + 1
get_context_length(0)

12

In [80]:
def scale_prob_by_entropy(beta,probs):
    """Re-weight `probs` by the per-token entropy of the current prompt.

    Reads the notebook-level `id`, `all_entropies` and `get_context_length`.
    For every vocabulary token the quantity
    ``entropy - log(context_length)`` is computed, scaled by `beta` and
    subtracted from ``log(probs)``; the result is softmaxed.

    Args:
        beta: strength of the entropy penalty.
        probs: 1-D probability vector over the vocabulary.

    Returns:
        The re-weighted probability vector. Diagnostic values for the
        max-probability and min-entropy tokens are printed along the way.
    """
    # Follow the device of `probs` instead of assuming CUDA.
    entropies = all_entropies[id].to(probs.device)
    subtract = entropies-math.log(get_context_length(id))
    min_entropy_token = torch.argmin(subtract).item()
    max_prior_prob_token = torch.argmax(probs).item()

    print("Token with smallest entropy: ",min_entropy_token)
    print("Token with max probab: ",max_prior_prob_token)
    print("\nMaximum probability token (subtraction amt per beta):  ",subtract[max_prior_prob_token].item())
    print("Minimum entropy token (subtraction amt per beta): ",subtract[min_entropy_token].item())
    print()

    print("Maximum probability token (prior probab): ",max_prior_prob_token,probs[max_prior_prob_token].item())
    print("Minimum entropy token (prior probab): ",min_entropy_token,probs[min_entropy_token].item())
    print()
    # Work in log space: subtracting beta * (entropy - max_entropy) raises the
    # score of tokens whose entropy is below the maximum.
    logs = torch.log(probs)
    logs -= beta*subtract
    final_probs = F.softmax(logs,dim=0)
    print("Maximum probability token (posterior prob): ",max_prior_prob_token,final_probs[max_prior_prob_token].item())
    print("Minimum entropy token (posterior prob): ",min_entropy_token,final_probs[min_entropy_token].item())
    return final_probs
final_probs = scale_prob_by_entropy(90,filtered_probs)

Token with smallest entropy:  2982
Token with max probab:  3691

Maximum probability token (subtraction amt per beta):   -2.384185791015625e-07
Minimum entropy token (subtraction amt per beta):  -0.07077932357788086

Maximum probability token (prior probab):  3691 0.9877297282218933
Minimum entropy token (prior probab):  2982 0.004945777356624603

Maximum probability token (posterior prob):  3691 0.25430625677108765
Minimum entropy token (posterior prob):  2982 0.7438071966171265


We see that token 3691 ("hair") got scaled down and token 2982 ("fur") got scaled up. With $\beta=90$ the posterior probability of 2982 (0.744) already exceeds that of 3691 (0.254); increasing $\beta$ further would widen the gap.

In [84]:
# Greedy decoding
tokenizer.batch_decode([torch.argmax(final_probs).tolist()])

['fur']

Great! Now we try it on another input id.

In [87]:
id = 2

prompt = inputs[id]
device = torch.device("cuda")
tokenizer.pad_token = "<s>"
eos_token = tokenizer.eos_token_id
input_ids = tokenizer(prompt,return_tensors="pt",padding=True).input_ids.to(device)
# Inference only: no autograd graph is needed for the next-token logits.
with torch.no_grad():
    last_token_logits = model(input_ids).logits[0,-1,:]
# Explicit `dim` avoids the deprecated implicit-dimension softmax warning.
last_token_probs = F.softmax(last_token_logits,dim=-1)
prompt

  last_token_probs = F.softmax(last_token_logits)


'Write a quote that ends in the word "boat": Rats desert a sinking'

In [89]:
# Vocabulary ids whose entropy for this prompt is below the (rounded) maximum.
# flatten() keeps the result 1-D even for a single match, where squeeze()
# would yield a 0-d tensor and tolist() a bare int.
words = torch.nonzero(all_entropies[id] < 2.4849).flatten().tolist()
words_strings = tokenizer.batch_decode(words)
for word in words_strings:
    print(word,end=", ")
print("\nDone")


, a, d, in, ", that, with, an, ', from, up, about, “, every, down, word, load, program, report, following, short, enc, sum, code, story, write, words, hope, write, ‘, ha, anger, song, review, express, written, capt, wrote, Write, excl, letter, insp, blog, statement, describe, script, brief, ended, reflect, …, motiv, boat, letters, ends, atr, essay, inspired, Write, sentence, ending, quote, writes, describes, ell, quote, smith, phrase, persu, courage, poem, reson, wisdom, quot, summar, paragraph, dialogue, boats, fitting, dock, quotes, punct, â, essays, sentences, accurately, reflects, rhet, dash, attributed, inspire, loads, …, 
Done


In [91]:
alpha = 1e-6

In [92]:
def filter_low_prob_tokens(input_logits, alpha=None):
    """Drop tokens far below the most likely one, then renormalise.

    Tokens whose softmax probability is under ``alpha * max_probability`` get
    a logit of -inf before a second softmax.

    Args:
        input_logits: logits with the vocabulary on the last axis; left
            unmodified (the original wrote -inf into a view of this tensor).
        alpha: relative cutoff; falls back to the notebook-level ``alpha``
            when omitted, matching the original cell.

    Returns:
        Probabilities shaped like ``input_logits`` with filtered entries at 0.
    """
    if alpha is None:
        alpha = globals()["alpha"]
    # Use the argument's own probabilities rather than the global
    # `last_token_probs`, which may belong to a different prompt.
    probs = F.softmax(input_logits, dim=-1)
    cutoff = alpha * probs.max(dim=-1, keepdim=True).values
    filtered_logits = input_logits.masked_fill(probs < cutoff, float('-inf'))
    return F.softmax(filtered_logits, dim=-1)
filtered_probs = filter_low_prob_tokens(last_token_logits).detach()

  filtered_probs = F.softmax(filtered_logits)


In [127]:
def multiplications_needed(a):
    """Echo `a`, then return ceil(-1 - log10(a)).

    This is the power of ten by which `a` would be multiplied so that it lands
    in the interval [0.1, 1), e.g. 2.53e-06 -> 5.
    """
    print(a)
    exponent = -math.log10(a) - 1
    return math.ceil(exponent)

multiplications_needed(0.00000253)

2.53e-06


5

In [149]:
def scale_prob_by_entropy(id,beta,probs):
    """Re-weight filtered probabilities by per-token entropy for prompt `id`.

    Reads the notebook-level `context_outputs`, `memory_outputs`, `tokenizer`,
    `all_entropies` and `get_context_length`.

    Args:
        id: index of the prompt.
        beta: strength of the entropy penalty.
        probs: 1-D probability vector over the vocabulary; entries equal to 0
            are treated as filtered out when searching for the minimum-entropy
            token.

    Returns:
        softmax(log(probs) - beta * scaled_gap), where the gap is
        ``entropy - log(context_length)`` rescaled by a power of ten chosen
        from the minimum-entropy surviving token. Diagnostics are printed.
    """
    wordC,wordM = context_outputs[id],memory_outputs[id]
    # Index 1 skips the first id the tokenizer emits (the earlier output for
    # "fur"/"hair" shows a leading id 1 before the word's own id).
    wordC_token = tokenizer(wordC).input_ids[1]
    wordM_token = tokenizer(wordM).input_ids[1]
    print(f"Prob (context token) = {probs[wordC_token]}")
    print(f"Context: {wordC},{wordC_token}\nMemory: {wordM},{wordM_token}\n")

    # Follow the device of `probs` instead of assuming CUDA.
    entropies = all_entropies[id].to(probs.device)
    subtract = entropies-math.log(get_context_length(id))

    # Min entropy among those tokens whose probabilities are non-zero (exclude tokens which are filtered out)
    candidates = subtract.masked_fill(~(probs > 0.0), float('inf'))
    min_entropy_token = torch.argmin(candidates).item()
    max_prior_prob_token = torch.argmax(probs).item()

    # Rescale so the largest gap lands in [0.1, 1). If no surviving token lies
    # below the maximum entropy the gap is <= 0 and log10 would raise
    # ValueError; in that case leave the gaps unscaled.
    gap = -subtract[min_entropy_token].item()
    if gap > 0.0:
        multiply = math.ceil(-1 - math.log10(gap))
        subtract *= 10**(multiply) # Multiply full thing by a constant

    print("Token with smallest entropy: ",min_entropy_token)
    print("Token with max probab: ",max_prior_prob_token)
    print("\nMaximum probability token (subtraction amt per beta):  ",subtract[max_prior_prob_token].item())
    print("Minimum entropy token (subtraction amt per beta): ",subtract[min_entropy_token].item())
    print()

    print("Maximum probability token (prior probab): ",max_prior_prob_token,probs[max_prior_prob_token].item())
    print("Minimum entropy token (prior probab): ",min_entropy_token,probs[min_entropy_token].item())
    print()
    logs = torch.log(probs)
    logs -= beta*subtract
    final_probs = F.softmax(logs,dim=0)
    print("Maximum probability token (posterior prob): ",max_prior_prob_token,final_probs[max_prior_prob_token].item())
    print("Minimum entropy token (posterior prob): ",min_entropy_token,final_probs[min_entropy_token].item())
    return final_probs
# final_probs = scale_prob_by_entropy(40,filtered_probs)

Now trying another input_id ('child', which was a bit adversarial).

In [147]:
def filter_low_prob_tokens(input_logits,alpha):
    """Remove tokens with probability below ``alpha * max_prob`` and renormalise.

    Args:
        input_logits: logits with the vocabulary on the last axis. Left
            untouched: the original assigned -inf through ``input_logits[:]``,
            which is a view and therefore overwrote the caller's tensor.
        alpha: relative probability cutoff.

    Returns:
        Probabilities shaped like ``input_logits``; filtered tokens are 0.
    """
    # Compute probabilities from the argument itself, not from the global
    # `last_token_probs`, so the mask always matches the logits being filtered.
    probs = F.softmax(input_logits, dim=-1)
    cutoff = alpha * probs.max(dim=-1, keepdim=True).values
    filtered_logits = input_logits.masked_fill(probs < cutoff, float('-inf'))
    # Explicit dim: the implicit form is deprecated and emits a warning.
    return F.softmax(filtered_logits, dim=-1)

In [151]:
id = 1

prompt = inputs[id]
device = torch.device("cuda")
tokenizer.pad_token = "<s>"
eos_token = tokenizer.eos_token_id
input_ids = tokenizer(prompt,return_tensors="pt",padding=True).input_ids.to(device)
# Inference only: no autograd graph is needed for the next-token logits.
with torch.no_grad():
    last_token_logits = model(input_ids).logits[0,-1,:]
# Explicit `dim` avoids the deprecated implicit-dimension softmax warning.
last_token_probs = F.softmax(last_token_logits,dim=-1)
alpha = 1e-6

filtered_probs = filter_low_prob_tokens(last_token_logits,alpha).detach()
final_probs = scale_prob_by_entropy(id,40,filtered_probs)

Prob (context token) = 1.0947568625852e-05
Context: child,1502
Memory: bull,10386

Token with smallest entropy:  13
Token with max probab:  10386

Maximum probability token (subtraction amt per beta):   0.0
Minimum entropy token (subtraction amt per beta):  -0.2956390380859375

Maximum probability token (prior probab):  10386 0.9994819760322571
Minimum entropy token (prior probab):  13 1.7254038766623125e-06

Maximum probability token (posterior prob):  10386 0.8087160587310791
Minimum entropy token (posterior prob):  13 0.1908482313156128


  last_token_probs = F.softmax(last_token_logits)
  filtered_probs = F.softmax(filtered_logits)


In [140]:
tokenizer.batch_decode([13])

['\n']

In [152]:
round(3.3234123,4)

3.3234

In [3]:
import torch
import math
math.log(12)

2.4849066497880004

In [5]:
# Scratch check of a thresholded selection with torch.where on toy values.
eps=1e-7
thresh = 0.005
output_entropies = torch.tensor([2.4400,2.4780,2.4849])
entropies = torch.tensor([[2.4400,2.4780,2.44]])
# Where log(12) - output_entropies exceeds thresh the result is -eps;
# elsewhere it is entropies - output_entropies (broadcast to shape (1, 3)).
torch.where(math.log(12) - output_entropies > thresh,-eps,entropies-output_entropies)

tensor([[-1.0000e-07, -1.0000e-07, -4.4900e-02]])