In [364]:
import getpass
import itertools
import os
import urllib.parse
import warnings
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
from scipy.stats import pearsonr, spearmanr, kendalltau
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, AutoConfig, StoppingCriteriaList, StoppingCriteria

In [None]:
# Model under study; later cells branch on this id for model-specific settings.
model_id = "mistralai/Mistral-7B-v0.1"
# Allocate newly created tensors on the CPU by default.
torch.set_default_device(device="cpu")

In [None]:
# Gated checkpoints (here: Llama-2) need a Hugging Face access token. Prefer the HF_TOKEN
# environment variable; otherwise prompt with getpass so the key is never echoed into the
# notebook or its saved outputs.
hf_key = os.environ.get("HF_TOKEN")
if not hf_key and model_id in ["meta-llama/Llama-2-7b-hf"]:
    hf_key = getpass.getpass("Hugging Face Key: ")
# Pass None rather than an empty string: None is from_pretrained's own default for `token`.
hf_key = hf_key or None
# NOTE(review): trust_remote_code=True executes Python code shipped with the model repo;
# presumably kept for the phi-1_5 branch — confirm it is still needed for each model.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=hf_key)
model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, token=hf_key)
model_config = AutoConfig.from_pretrained(model_id, trust_remote_code=True, token=hf_key)
# Remove the secret from the kernel namespace once the downloads are done.
del hf_key

In [316]:
# Per-model tokenizer/prompt settings:
#   stopgen_tokens       - token-id sequences, presumably used to stop generation
#                          (StoppingCriteria is imported) — confirm against usage
#   prompt_structure     - template the raw prompt is formatted into
#   exclude_token_offset - 3 for phi, None otherwise; NOTE(review): meaning depends on code
#                          outside this view — confirm
#   fix_characters       - (raw, display) string replacements for tokenizer artefacts
if model_id in ["microsoft/phi-1_5"]:
    stopgen_tokens = [
        torch.tensor([198, 198]),  # \n\n
        torch.tensor([628])        # \n\n
    ]
    prompt_structure = "Question: {prompt}\n\nAnswer:"
    exclude_token_offset = 3
    fix_characters = [("Ġ", "␣"), ("Ċ", "\n")]
elif model_id in ["meta-llama/Llama-2-7b-hf", "mistralai/Mistral-7B-v0.1"]:
    stopgen_tokens = [
        torch.tensor([1]),  # <s>
        torch.tensor([2])   # </s>
    ]
    prompt_structure = "{prompt}"
    exclude_token_offset = None
    fix_characters = [("<0x0A>", "\n")]
else:
    # Without this branch an unsupported model_id only surfaces as a NameError on
    # `fix_characters` below, far from the actual cause.
    raise ValueError(f"No tokenizer settings defined for model_id={model_id!r}")

# Appended last so that, if the replacements are applied in list order, newlines produced
# by the entries above are escaped for display as well.
fix_characters += [("\n", "\\n")]

In [402]:
import requests 
import urllib 
  
def ngram_query(words, start_year=2018, end_year=2019, corpus="en-2019", smoothing=0, timeout=60):
    """Fetch Google Books Ngram frequency time series for a list of phrases.

    Args:
        words: Phrases to look up. They are sent as a single comma-separated
            ``content`` parameter, so an individual phrase must not contain a comma.
        start_year: First year requested.
        end_year: Last year requested.
        corpus: Ngram corpus identifier passed through to the service (e.g. ``"en-2019"``).
        smoothing: Smoothing value passed through to the service.
        timeout: Seconds ``requests`` waits on the connection / response before raising;
            without one, a stalled request would block the notebook indefinitely.

    Returns:
        A list of ``(ngram, timeseries)`` tuples, one per entry of the JSON response,
        or ``None`` when ``words`` is empty or the service returns no entries.

    Raises:
        requests.HTTPError: If the service answers with an error status
            (e.g. when rate limiting).
    """
    if not words:
        return None
    # Quote each phrase separately so the commas separating phrases stay literal.
    content = ",".join(urllib.parse.quote(w) for w in words)
    url = (
        "https://books.google.com/ngrams/json"
        f"?content={content}&year_start={start_year}&year_end={end_year}"
        f"&corpus={corpus}&smoothing={smoothing}"
    )
    response = requests.get(url, timeout=timeout)
    # Surface HTTP errors directly instead of failing later with a JSON decode error.
    response.raise_for_status()
    output = response.json()
    if not output:
        return None
    return [(entry["ngram"], entry["timeseries"]) for entry in output]

In [387]:
def get_first_order_model(model, use_bias=False):
    """Extract the embedding-only ("first-order") path of a causal LM.

    The returned pieces let a token be mapped directly from the input embedding to the
    output projection, with none of the model's intermediate layers in between.

    Args:
        model: Hugging Face causal LM exposing ``get_input_embeddings()`` and
            ``get_output_embeddings()``.
        use_bias: If True and the output projection has a bias, return that bias so it
            is added to the logits. By default the bias is ignored (with a warning),
            matching the previous behaviour.

    Returns:
        ``(input_embedding, output_embedding, bias)``: the model's input embedding
        module, a frozen ``torch.nn.Embedding`` wrapping the output projection's weight
        matrix, and either ``0`` or the output bias tensor.
    """
    input_emb = model.get_input_embeddings()
    # Public accessors instead of model.model / model.lm_head, whose names differ
    # between architectures.
    output_layer = model.get_output_embeddings()
    output_bias = getattr(output_layer, "bias", None)
    if output_bias is None:
        bias = 0
    elif use_bias:
        bias = output_bias
    else:
        warnings.warn("Output bias not utilized")
        bias = 0
    output_emb = torch.nn.Embedding.from_pretrained(output_layer.weight, freeze=True)
    return input_emb, output_emb, bias

def get_fom_prediction(fom, tokens):
    """Run the first-order model over every element of ``tokens``.

    Args:
        fom: ``(input_embedding, output_embedding, bias)`` as returned by
            ``get_first_order_model``.
        tokens: Iterable of token-id tensors; iterating a ``(batch, seq)`` tensor
            yields one row per element.

    Returns:
        ``(all_logits, out_tokens)``: two lists with one entry per element of
        ``tokens`` — the logits ``embed(element) @ out_weight.T + bias`` and their
        argmax over the last dimension.
    """
    in_emb, out_emb, out_bias = fom
    unembed_t = out_emb.weight.T
    all_logits = [torch.matmul(in_emb(tok), unembed_t) + out_bias for tok in tokens]
    out_tokens = [lg.argmax(dim=-1) for lg in all_logits]
    return all_logits, out_tokens

def check_2gram(fom, tokenizer, word, batch=500):
    """Look up Google Ngram frequencies of ``word`` followed by each vocabulary token.

    For every vocabulary id the candidate phrase is ``word`` + the decoded text of that
    token. Phrases containing arithmetic symbols or Ngram query separators are dropped.
    The rest are queried via ``ngram_query`` in batches, and each phrase's latest
    frequency is divided by the latest frequency of ``word`` alone.

    Args:
        fom: ``(input_embedding, output_embedding, bias)`` tuple from
            ``get_first_order_model``; only the number of input-embedding rows is used.
        tokenizer: Hugging Face tokenizer of the probed model.
        word: Prefix prepended verbatim to every token text, e.g. ``" how"``.
        batch: Number of phrases sent per Ngram request.

    Returns:
        ``{phrase: (next_token_id, relative_frequency)}`` restricted to phrases that
        tokenize to exactly ``word``'s own token ids plus one more id
        (``next_token_id``), or ``None`` when ``word`` has no positive Ngram frequency.
        Phrases rejected by that tokenization check are printed.
    """
    in_emb, _, _ = fom
    vocab_size = in_emb.weight.shape[0]
    forbidden = ("/", "+", "-", "*", ",", ":")

    # Workaround to keep each token's leading space: decode it behind the BOS token and cut
    # the BOS text off again. Assumes decode([bos, i]) starts with decode([bos]); for the
    # Llama/Mistral tokenizer this matches the former hard-coded decode([1, i])[3:].
    bos_id = tokenizer.bos_token_id
    if bos_id is None:
        token_texts = (tokenizer.decode([i]) for i in range(vocab_size))
    else:
        bos_len = len(tokenizer.decode([bos_id]))
        token_texts = (tokenizer.decode([bos_id, i])[bos_len:] for i in range(vocab_size))

    # Different ids can decode to the same text; dict keys de-duplicate (keeping order) so
    # the same phrase is not requested twice.
    candidates = {}
    for text in token_texts:
        phrase = word + text
        if not any(c in phrase for c in forbidden):
            candidates[phrase] = None
    words = list(candidates)

    # Latest frequency of the source word itself, used to normalise every phrase.
    original = ngram_query([word])
    if not (original and original[0][-1][-1] > 0):
        return None
    original_freq = original[0][-1][-1]

    results = []
    for start in tqdm(range(0, len(words), batch)):
        # ngram_query returns None when the service has no entries for the batch.
        results.extend(ngram_query(words[start:start + batch]) or [])

    source = word.strip()
    source_ids = tokenizer(source, add_special_tokens=False).input_ids
    relative = {}
    for phrase, series in results:
        ids = tokenizer(phrase, add_special_tokens=False).input_ids
        # Keep only "<source word ids> + exactly one id" so the extra id is a valid
        # next-token candidate. Checking the prefix ids (not only the length) rejects
        # phrases the tokenizer splits so that they no longer start with the source token.
        if len(ids) != len(source_ids) + 1 or ids[:-1] != source_ids:
            print(phrase + ' ' + str(ids) + ' ' + str(series[-1] if series else 0))
            continue
        if series and phrase != source and source in phrase:
            relative[phrase] = (ids[-1], series[-1] / original_freq)
    return relative

In [388]:
# Get n-gram relative frequencies for every "<word><token>" continuation.
# Build the first-order model here: this cell previously relied on `fom` being left over
# from a later cell, which fails on a fresh kernel.
fom = get_first_order_model(model)
results_ngram = check_2gram(fom, tokenizer, " how", batch=400)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 78/78 [25:04<00:00, 19.29s/it]


how . " [910, 842, 345] 4.081721840520913e-07
how = " [910, 327, 345] 1.620947820413221e-09
howng [295, 656, 28721] 0
howning [295, 656, 288] 0
how ? " [910, 1550, 345] 7.020018415460072e-07
howness [295, 656, 409] 5.169509265101624e-09
hownot [295, 656, 322] 3.942846049653781e-10
howno [295, 656, 28709] 0.0
howna [295, 656, 28708] 0
how " ; [910, 345, 2753] 2.6285640331025206e-09
how | | [910, 342, 342] 8.761880110341735e-11
hownet [295, 656, 299] 2.6285640331025206e-10
how ' ; [910, 464, 2753] 1.0514256132410083e-09
how . ' [910, 842, 464] 6.38302992683748e-08
how U.S. [910, 500, 28723, 28735, 28723] 7.46263140172232e-08
how . " [910, 842, 345] 4.081721840520913e-07
how " " [910, 345, 345] 7.009504088273388e-10
how " . [910, 345, 842] 1.0470446731858374e-08
how = ' [910, 327, 464] 4.9504622623430805e-09
how_ . [910, 28730, 842] 1.2493245327149793e-10
how ! " [910, 918, 345] 4.3502733859668297e-08
howns [295, 656, 28713] 0.0
howni [295, 656, 28710] 0
hownow [295, 656, 336] 4.380940055

In [397]:
# Get first-order-model (FOM) predictions for the source word.
fom = get_first_order_model(model)
# For this tokenizer the first id is the BOS token (see the printed pairs below).
token_ids = tokenizer("how", return_tensors="pt").input_ids
logits_fom, outputs_fom = get_fom_prediction(fom, token_ids)

# Pair each input token with the FOM's top-1 predicted next token.
# (Previously zipped against an undefined `outputs` instead of `outputs_fom`.)
paired_outputs = [
    (tokenizer.decode(input_tok.item()), tokenizer.decode(output_tok.item()))
    for input_tok, output_tok in zip(token_ids[0], outputs_fom[0])
]

In [410]:
# Print the single most likely continuation under each model.
print("N-Gram")
if results_ngram:
    max_ngram_k = max(results_ngram, key=lambda k: results_ngram[k][1])
    max_ngram_id, max_ngram_p = results_ngram[max_ngram_k]
    print(f"'{max_ngram_k}' P = {max_ngram_p}")
    # Assuming that source word is at position 1 of logits (position 0 is BOS)
    print(f"Corresponding FOM logits: {logits_fom[0][1][max_ngram_id].detach()}")
else:
    print("-")
print("--------------------")
print("FOM")
for input_tok, output_tok, pair in zip(token_ids[0], outputs_fom[0], paired_outputs):
    print(pair)
    # Decode both ids together so the tokenizer decides whether a space separates them;
    # "".join of the separately decoded pieces never inserts one.
    phrase = tokenizer.decode([input_tok.item(), output_tok.item()])
    ngram_prob = ngram_query([phrase])
    print(f"Corresponding N-Gram probability: {ngram_prob[0][-1][-1] if ngram_prob else '-'}")

N-Gram
'how to' P = 0.1313060090509322
Corresponding FOM logits: 0.0004904341185465455
--------------------
FOM
('<s>', 'ocker')
Corresponding N-Gram probability: -
('how', 'Pra')
Corresponding N-Gram probability: -


In [394]:
# Pair each Ngram continuation's relative frequency with the FOM logit for the same next
# token so the two can be correlated.
assert results_ngram, "check_2gram returned no results to compare"
source_pos = 1  # assumes the source word is at position 1 of the logits (position 0 is BOS)
values_ngram = []
values_fom = []
for token_id, rel_freq in results_ngram.values():
    values_ngram.append(rel_freq)
    # .item() stores plain Python floats rather than 0-d tensors for scipy.
    values_fom.append(logits_fom[0][source_pos][token_id].item())

In [396]:
# Correlation between Ngram relative frequencies and FOM logits: Spearman, Pearson and
# Kendall, in that order. FOM logits are unnormalised log-scores while the Ngram values are
# linear frequencies, so the rank-based statistics are the more comparable ones.
for statistic in (spearmanr, pearsonr, kendalltau):
    print(statistic(values_ngram, values_fom))

SignificanceResult(statistic=-0.01452817825887986, pvalue=0.12338516856542413)
PearsonRResult(statistic=0.01366765436597835, pvalue=0.14721156467350693)
SignificanceResult(statistic=-0.009783086174432727, pvalue=0.1223985314364765)
