In [1]:
import random
import torch
import copy
import gc

import pandas as pd

from tqdm.notebook import tqdm

from gensim.test.utils import datapath

from scipy.stats import pearsonr, spearmanr, kendalltau

from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

import matplotlib.pyplot as plt
from matplotlib import colormaps as clmp
from matplotlib.ticker import AutoMinorLocator

# Import Models

In [2]:
def noadd(x):
    """Return the word unchanged."""
    return x


def addspace(x):
    """Return the word prefixed with a single space."""
    return " " + x


def addall(x):
    """Return four surface variants of a word.

    The tuple holds the capitalized form, the capitalized form with a leading
    space, the lower-cased form, and the lower-cased form with a leading space.
    """
    capitalized = x.capitalize()
    lowered = x.lower()
    return (capitalized, " " + capitalized, lowered, " " + lowered)

In [3]:
# Hugging Face checkpoints to compare. model_ids, model_names and model_format
# are index-aligned: entry i of each list describes the same model.
model_ids = [
    "gpt2",
    "meta-llama/Llama-2-7b-hf",
    "mistralai/Mistral-7B-v0.1",
    "google/gemma-7b",
]
model_names = ["GPT 2", "LLaMa 2", "Mistral", "Gemma"]
# Word formatter applied to test words before they are tokenized for each model.
model_format = [addspace, noadd, noadd, addspace]
# Checkpoints for which an access token is requested when loading.
need_key = [model_ids[1], model_ids[3]]

# Every tensor created without an explicit device lands on this device.
device = "cuda"
torch.set_default_device(device)

In [4]:
def extract_embeddings(model_ids, need_key_list):
    """Load each model in turn and keep only its tokenizer and embedding layers.

    Every model is loaded in float16, its input embedding layer and an
    embedding built from its LM head are stored, and the model object is then
    dropped before the next one is loaded.

    Args:
        model_ids: Hugging Face model identifiers to load, in order.
        need_key_list: Identifiers that need an access token. The token is
            asked for interactively the first time it is needed and reused for
            the remaining gated models.

    Returns:
        Tuple ``(tokenizers, in_embeddings, out_embeddings)`` of three lists,
        one entry per loaded model.
    """
    # Local import keeps the notebook's import cell untouched. getpass does not
    # echo the token, whereas input() leaves it visible in the cell output.
    from getpass import getpass

    tokenizers = []
    in_embeddings = []
    out_embeddings = []
    cached_key = None
    # Display names come from the notebook-level `model_names` list, which is
    # expected to be aligned with `model_ids`.
    for model_id, model_name in zip(model_ids, model_names):
        print(f"Loading {model_name}...")
        hf_key = None
        if model_id in need_key_list:
            if cached_key is None:
                cached_key = getpass("Hugging Face Key: ")
            hf_key = cached_key
        model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, token=hf_key, torch_dtype=torch.float16)
        tokenizers.append(AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, token=hf_key))
        in_embeddings.append(model.get_input_embeddings())
        out_embeddings.append(get_output_embeddings(model))
        del model
        # Collect garbage first so unreachable tensors are really freed; only
        # then can empty_cache() hand their cached blocks back to the driver.
        gc.collect()
        torch.cuda.empty_cache()
        print(f"{torch.cuda.memory_allocated(0) / 1024**2} ({torch.cuda.memory_reserved(0) / 1024**2}) / {torch.cuda.get_device_properties(0).total_memory / 1024**2}")
    return tokenizers, in_embeddings, out_embeddings

def get_output_embeddings(model):
    """Expose a model's LM head weights as a frozen embedding layer.

    Args:
        model: Object with an ``lm_head`` attribute providing ``weight`` and
            ``bias``.

    Returns:
        ``torch.nn.Embedding`` created from ``lm_head.weight`` with
        ``freeze=True``.
    """
    head = model.lm_head
    # Only the weight matrix is wrapped, so a head bias is dropped; say so
    # instead of silently discarding it.
    if head.bias is not None:
        print("Warning, bias not utilized")
    return torch.nn.Embedding.from_pretrained(head.weight, freeze=True)

In [None]:
# Load every model listed in model_ids and keep, per model, its tokenizer plus
# its input embedding layer and the embedding derived from its LM head.
tokenizers, in_emb, out_emb = extract_embeddings(model_ids, need_key)

In [None]:
#for in_emb_mod, out_emb_mod in zip(in_emb, out_emb):
#    in_emb_mod.to("cuda")
#    out_emb_mod.to("cuda")
#torch.set_default_device("cuda")

# Define Helper Functions

In [None]:
def multiencode(tok, words):
    """Tokenize one word, or several words joined along the last axis.

    Args:
        tok: Tokenizer whose ``encode`` accepts ``return_tensors`` and
            ``add_special_tokens``.
        words: A single word, or a list/tuple of words.

    Returns:
        The tensor returned by ``tok.encode`` for a single word; for a
        list/tuple, the per-word tensors concatenated with ``dim=-1``.
    """
    encode_options = {"return_tensors": "pt", "add_special_tokens": False}
    if not isinstance(words, (list, tuple)):
        # Single word
        return tok.encode(words, **encode_options)
    # Several words: encode each one separately and join the token ids
    encoded_words = [tok.encode(word, **encode_options) for word in words]
    return torch.cat(encoded_words, dim=-1)
    
def avgencode(emb, word, tok=None, avg=True):
    """Embed a word and reduce it to a single vector along axis 1.

    Args:
        emb: Embedding layer applied to the token ids of text input.
        word: A string, a list/tuple of strings, or an already embedded tensor
            whose axis 1 indexes tokens.
        tok: Tokenizer used for text input; required when ``word`` is text.
        avg: When True, several token vectors are averaged into one; when
            False, an input with more than one token is rejected.

    Returns:
        Tensor whose axis 1 has size 1.

    Raises:
        ValueError: If text is given without a tokenizer, or if the input has
            more than one token and ``avg`` is False.
    """
    source = word
    is_text = isinstance(word, str) or (
        isinstance(word, (list, tuple)) and len(word) > 0 and isinstance(word[0], str)
    )
    if is_text:
        # Without this check, text would reach `.shape` below and fail with an
        # unrelated AttributeError.
        if tok is None:
            raise ValueError(f"A tokenizer is required to embed text input: {source!r}")
        word = emb(multiencode(tok, word))
    if word.shape[1] != 1:
        if not avg:
            raise ValueError(f"{source} is not a single token: {word}")
        # Mean over the token axis, keeping that axis with size 1
        word = torch.mean(word, dim=1, keepdim=True)
    return word

In [None]:
def calc_distance(emb, word1, word2, tok=None, avg=True, dist="cosine", multi=False):
    """Distance between two embedded words, or between a matrix and a word.

    Args:
        emb: Embedding layer used to embed text input.
        word1: First word; with ``multi=True`` it is used as given (the caller
            passes an embedding matrix) instead of being embedded.
        word2: Second word, always passed through ``avgencode``.
        tok: Tokenizer forwarded to ``avgencode``.
        avg: Forwarded to ``avgencode``.
        dist: ``"L2"`` for the norm of the difference, ``"cosine"`` for one
            minus the cosine similarity; both are taken over axis 2.
        multi: Skip embedding ``word1``.

    Raises:
        Exception: If ``dist`` is neither ``"L2"`` nor ``"cosine"``.
    """
    first = word1 if multi else avgencode(emb, word1, tok, avg=avg)
    second = avgencode(emb, word2, tok, avg=avg)
    if dist == "L2":
        return torch.norm(first - second, dim=2)
    if dist == "cosine":
        similarity = torch.nn.CosineSimilarity(dim=2)
        return 1 - similarity(first, second)
    raise Exception("Unknown distance")

def get_closest_emb(emb, word, k=1, decode=True, tok=None, avg=True, dist="cosine"):
    """Find the ``k`` rows of ``emb.weight`` closest to ``word``.

    Args:
        emb: Embedding layer whose weight matrix is searched.
        word: Query, as text or as an embedded tensor.
        k: Number of neighbours to return.
        decode: Decode the neighbour ids to strings when a tokenizer is given.
        tok: Tokenizer used to embed text and to decode ids.
        avg: Forwarded to ``calc_distance``.
        dist: Distance name forwarded to ``calc_distance``.

    Returns:
        A list of decoded strings when ``decode`` is set and ``tok`` is given,
        otherwise the tensor of neighbour indices.
    """
    # Distance from every row of the embedding matrix to the query
    all_distances = calc_distance(emb, emb.weight.data, word, tok=tok, avg=avg, dist=dist, multi=True)
    # Indices of the k smallest distances
    nearest = torch.topk(all_distances, k=k, largest=False).indices.squeeze()
    if k == 1:
        # squeeze() removed every axis; restore one so the result is a sequence
        nearest = nearest.unsqueeze(0)
    if not decode or tok is None:
        return nearest
    return [tok.decode(token_id) for token_id in nearest.tolist()]

def emb_arithmetic(emb, tok, words, k=1, dist="cosine"):
    """Evaluate ``words[0] - words[1] + words[2]`` and look up its neighbours.

    Args:
        emb: Embedding layer searched for the nearest tokens.
        tok: Tokenizer used to decode the nearest tokens.
        words: Two or three operands that support ``-`` and ``+`` with each
            other (embedded tensors, or plain numbers such as 0). A missing
            third operand counts as 0.
        k: Number of nearest tokens to return.
        dist: Distance name forwarded to ``get_closest_emb``.

    Returns:
        Tuple ``(w, closest)``: the resulting vector and the decoded ``k``
        nearest tokens.
    """
    w1, w2 = words[0], words[1]
    w3 = words[2] if len(words) > 2 else 0
    # Do embedding arithmetic
    w = w1 - w2 + w3
    if dist == "L2":
        # Scale to unit length along the embedding (last) axis. The operands
        # produced by avgencode have size 1 on axis 0, so normalizing along
        # dim=0 would turn every component into +1/-1 instead.
        w = torch.nn.functional.normalize(w, dim=-1)
    # Get closest k
    closest = get_closest_emb(emb, w, k=k, decode=True, tok=tok, dist=dist)
    return (w, closest)

def print_results(res):
    """Print every item of ``res`` on its own line, numbered from 1, as its repr."""
    for position, item in enumerate(res, start=1):
        print(f"{position}) {item!r}")

In [None]:
def batch_emb_arithmetic(emb, tok, queries, k=5, avg=True, dist="cosine", out=True):
    """Run ``a - b + c`` embedding arithmetic for every query.

    Args:
        emb: Embedding layer used to embed the words and to search neighbours.
        tok: Tokenizer used to encode the words and decode the neighbours.
        queries: Sequence of queries. The first three entries of a query are
            the operands ``a``, ``b`` and ``c``; each is a string or a
            sequence of string variants. Further entries are ignored.
        k: Number of nearest tokens kept per query.
        avg: Forwarded to ``avgencode``.
        dist: Distance name forwarded to ``emb_arithmetic``.
        out: Print a title and the ranked neighbours for every query.

    Returns:
        List with one ``(vector, closest_tokens)`` tuple per query.
    """
    ret = []
    for q in queries:
        # Only three operands take part in the arithmetic. Embedding extra
        # entries (e.g. a trailing expected answer) is wasted work and could
        # raise with avg=False for a value that is never used.
        operands = q[:3]
        if out:
            print("##########################")
            # Print title
            title_q = operands
            if not isinstance(operands[0], str):
                title_q = [qq[0] for qq in operands]
            title = f"{title_q[0]} - {title_q[1]}"
            if len(title_q) > 2:
                title += f" + {title_q[2]}"
            print(f"{title} =")
        # Compute and print results
        res = emb_arithmetic(emb, tok, [avgencode(emb, word, tok, avg=avg) for word in operands], k=k, dist=dist)
        if out:
            print_results(res[1])
        ret.append(res)
    return ret

def advanced_arithmetic(emb, tok, queries, k=5, avg=True, dist="cosine", out=True):
    """Answer every query with ``query[0] + Δ``, Δ being the mean pair offset.

    Each query contributes the reference pair ``(query[0], query[1])``. Δ is
    the mean of ``embedding(second) - embedding(first)`` over the distinct
    pairs, so adding Δ to a first word moves it towards its second word.

    Args:
        emb: Embedding layer used to embed the words and to search neighbours.
        tok: Tokenizer used to encode the words and decode the neighbours.
        queries: Sequence of queries with at least two entries each; an entry
            is a string or a sequence of string variants (the first variant
            names the reference pair).
        k: Number of nearest tokens kept per query.
        avg: Forwarded to ``avgencode``.
        dist: Distance name forwarded to ``emb_arithmetic``.
        out: Print a title and the ranked neighbours for every query.

    Returns:
        List with one ``(vector, closest_tokens)`` tuple per query; empty when
        ``queries`` is empty.
    """
    ret = []
    if len(queries) == 0:
        return ret
    # Collect the distinct reference pairs as ordered tuples, in first-seen
    # order. A frozenset would lose which word comes first, making the sign of
    # each offset depend on set iteration order.
    ref_pairs = {}
    for q in queries:
        ref_q = q
        if not isinstance(q[0], str):
            ref_q = [qq[0] for qq in q]
        ref_pairs[(ref_q[0], ref_q[1])] = None
    # emb_arithmetic is reused for the difference so that any per-distance
    # post-processing it applies matches the one used for the queries below.
    offsets = [
        emb_arithmetic(
            emb,
            tok,
            [avgencode(emb, target, tok, avg=avg), avgencode(emb, source, tok, avg=avg)],
            k=1,
            dist=dist,
        )[0]
        for source, target in ref_pairs
    ]
    # Flatten each offset to one row before averaging; a bare squeeze() would
    # also drop the pair axis when there is a single reference pair.
    delta = torch.stack(offsets).reshape(len(offsets), -1).mean(dim=0)
    for q in queries:
        if out:
            print("##########################")
            # Print title
            title_q = q
            if not isinstance(q[0], str):
                title_q = [qq[0] for qq in q]
            print(f"{title_q[0]} + Δ =")
        # Compute and print results
        res = emb_arithmetic(emb, tok, [avgencode(emb, q[0], tok, avg=avg), 0, delta], k=k, dist=dist)
        if out:
            print_results(res[1])
        ret.append(res)
    return ret

def evaluate_batch(results, solutions, out=True, score="rankscore", k=None, tok=None, subdivide_tokens=False):
    """Score ranked neighbour lists against the accepted solutions.

    For every result the best (lowest) position of any accepted solution in
    its ranked list is taken; a result containing no solution gets rank ``n``,
    the length of the first result's ranked list.

    Args:
        results: Non-empty sequence of ``(vector, ranked_tokens)`` entries;
            only the ranked tokens (index 1) are read.
        solutions: One iterable of accepted strings per result.
        out: Print the per-result ranks and the final score.
        score: ``"rankscore"`` returns ``1 - mean(rank) / n``; ``"topk"``
            returns the fraction of results whose best rank is below ``k``.
        k: Cut-off for ``"topk"``; must be a non-zero number in that mode.
        tok: Tokenizer used when ``subdivide_tokens`` is set.
        subdivide_tokens: Replace every solution the tokenizer splits into
            several tokens by its individually decoded tokens.

    Returns:
        The requested score as a float.

    Raises:
        ValueError: For an unknown ``score``, a missing ``k`` in ``"topk"``
            mode, or empty ``results``.
    """
    # Validate up front so a bad argument fails before any work is done
    if score not in ("rankscore", "topk"):
        raise ValueError("Unknown Score")
    if score == "topk" and not k:
        raise ValueError("Invalid k for topk")
    if len(results) == 0:
        raise ValueError("results must contain at least one entry")

    def get_rank(ranked, solution, missing):
        try:
            return ranked.index(solution)
        except ValueError:
            return missing

    n = len(results[0][1])
    ev = []
    for entry, sol in zip(results, solutions):
        ranked = entry[1]
        # If model encodes solutions with more than one token, add all the tokens separately
        if subdivide_tokens and tok:
            new_sol = []
            for s in sol:
                encoded_s = tok.encode(s, return_tensors="pt", add_special_tokens=False).squeeze()
                if encoded_s.dim() > 0:
                    new_sol.extend([tok.decode(token) for token in encoded_s])
                else:
                    # squeeze() left a scalar: the solution is a single token
                    new_sol.append(s)
            sol = new_sol
        # Best rank over all accepted solutions; no solutions counts as a miss
        ev.append(min((get_rank(ranked, s, n) for s in sol), default=n))
    # Return score
    if score == "rankscore":
        # Averaged over the results actually evaluated
        value = 1 - (sum(ev) / (n * len(ev)))
    else:
        value = len([i for i in ev if i < k]) / len(ev)
    if out:
        print(f"{ev} -> {value}")
    return value

# Define Test Cases

### Capital Arithmetic

In [None]:
# Analogy queries [a, b, c], evaluated as a - b + c.
capital_single = [
    ["Rome", "Italy", "France"],
    ["Rome", "Italy", "Australia"],
    ["Paris", "France", "Italy"],
    ["Paris", "France", "Australia"],
    ["Canberra", "Australia", "Italy"],
    ["Canberra", "Australia", "France"],
]
# [capital, country] reference pairs for the delta-based arithmetic.
capital_advanced = [
    ["Rome", "Italy"],
    ["Paris", "France"],
    ["Canberra", "Australia"],
    ["Ankara", "Turkey"],
    ["Berlin", "Germany"],
    ["Washington", "USA"],
    ["Madrid", "Spain"],
    ["Dublin", "Ireland"],
    # Spelling fixed from "Copenaghen": a misspelled word is tokenized
    # differently from the real city name.
    ["Copenhagen", "Denmark"],
    ["Amsterdam", "Netherlands"],
    ["Vienna", "Austria"],
    ["Tokyo", "Japan"],
    ["Seoul", "South Korea"],
]
# Accepted answers for capital_single, one tuple of variants per query.
capital_sol = [
    addall("Paris"),
    addall("Canberra"),
    addall("Rome"),
    addall("Canberra"),
    addall("Rome"),
    addall("Paris"),
]
# Accepted answers for capital_advanced, one tuple of variants per pair.
capital_advanced_sol = [
    addall("Italy"),
    addall("France"),
    addall("Australia"),
    addall("Turkey"),
    addall("Germany"),
    addall("USA") + addall("America"),
    addall("Spain"),
    addall("Ireland"),
    addall("Denmark"),
    addall("Netherlands"),
    addall("Austria"),
    addall("Japan"),
    addall("South Korea"),
]

### Sex Arithmetic

In [None]:
# Analogy queries [a, b, c], evaluated as a - b + c.
sex_single = [
    ["king", "man", "woman"],
    ["queen", "woman", "man"],
    ["prince", "man", "woman"],
    ["princess", "woman", "man"],
    ["priest", "man", "woman"],
    ["nun", "woman", "man"],
]
# Accepted answers, one tuple of surface variants per query above.
sex_sol = [
    addall(answer)
    for answer in ("Queen", "King", "Princess", "Prince", "Nun", "Priest")
]

# Calculate and Display Test Results

In [None]:
# Each suite: (section heading, label used in the test names, queries, accepted solutions).
# The second section used to be headed "Capital rankings" although it
# evaluates the sex analogies.
rank_suites = [
    ("Capital rankings", "Capital", capital_single, capital_sol),
    ("Sex rankings", "Sex", sex_single, sex_sol),
]
for suite_idx, (heading, label, suite_queries, suite_solutions) in enumerate(rank_suites):
    if suite_idx > 0:
        print("###############################")
    print(heading)
    # `fmt` rather than `format`, so the builtin is not shadowed at notebook scope
    for model_name, tok, in_emb_mod, out_emb_mod, fmt in zip(model_names, tokenizers, in_emb, out_emb, model_format):
        print("---------------------------")
        print(model_name)
        test_format = [[fmt(el) for el in test] for test in suite_queries]
        for name, emb_mod in ((f"Input {label}", in_emb_mod), (f"Output {label}", out_emb_mod)):
            result = batch_emb_arithmetic(emb_mod, tok, test_format, k=100, out=False)
            print("%%%")
            print(name)
            evaluate_batch(result, suite_solutions, tok=tok, subdivide_tokens=True)

# Visualizations 

In [None]:
# Position in model_ids / model_names of the model inspected by the cells
# below; 3 is the entry named "Gemma".
index = 3

In [None]:
# Capital analogies on the selected model's input embeddings; prints the 10
# nearest tokens for every query. The return value is discarded.
_ = batch_emb_arithmetic(in_emb[index], tokenizers[index], [ [model_format[index](el) for el in test] for test in capital_single], k=10)

In [None]:
# Capital analogies on the selected model's output embeddings; prints the 10
# nearest tokens for every query. The return value is discarded.
_ = batch_emb_arithmetic(out_emb[index], tokenizers[index], [ [model_format[index](el) for el in test] for test in capital_single], k=10)

In [None]:
# Sex analogies on the selected model's input embeddings; prints the 10
# nearest tokens for every query. The return value is discarded.
_ = batch_emb_arithmetic(in_emb[index], tokenizers[index], [ [model_format[index](el) for el in test] for test in sex_single], k=10)

In [None]:
# Sex analogies on the selected model's output embeddings; prints the 10
# nearest tokens for every query. The return value is discarded.
_ = batch_emb_arithmetic(out_emb[index], tokenizers[index], [ [model_format[index](el) for el in test] for test in sex_single], k=10)

In [None]:
# The 10 tokens closest to "nun" (formatted for the selected model) in the
# input embedding matrix, printed as a numbered list.
res = get_closest_emb(in_emb[index], model_format[index]("nun"), k=10, tok=tokenizers[index])
print_results(res)

In [None]:
# The 10 tokens closest to "nun" (formatted for the selected model) in the
# output embedding matrix, printed as a numbered list.
res = get_closest_emb(out_emb[index], model_format[index]("nun"), k=10, tok=tokenizers[index])
print_results(res)

## Gensim datasets

In [None]:
def load_question_words(path):
    """Read a question-words analogy file into one DataFrame per category.

    A line starting with ``:`` opens a new category named by the rest of the
    line. Every other non-blank line is split on whitespace and its tokens are
    read, in file order, as the columns ``A``, ``B``, ``Solution``, ``C``.

    Args:
        path: Path of the text file.

    Returns:
        Dict mapping category name to a DataFrame whose columns are ordered
        ``A``, ``B``, ``C``, ``Solution``.
    """
    data = {}
    current_category = None
    with open(path, 'r', encoding='utf-8') as file:
        # Stream the file line by line instead of loading it whole
        for line in file:
            line = line.strip()
            if not line:
                # A blank line would otherwise become an empty row
                continue
            # Check if the line denotes a new category
            if line.startswith(':'):
                current_category = line[1:].strip()
                data[current_category] = []
            else:
                data[current_category].append(line.split())
    # Create DataFrames for each category
    dfs = {}
    for category, attributes in data.items():
        df = pd.DataFrame(attributes, columns=['A', 'B', 'Solution', 'C'])
        # Reassign order so the solution is the last column
        dfs[category] = df.reindex(columns=['A', 'B', 'C', 'Solution'])
    return dfs

def change_words(batch, transform=lambda x: x):
    """Return a copy of ``batch`` with ``transform`` applied to every word.

    ``batch`` is an iterable of entries, each an iterable of words; the result
    is a list of lists. The default transform leaves the words unchanged.
    """
    changed = []
    for entry in batch:
        changed.append(list(map(transform, entry)))
    return changed

In [None]:
# Word-pair similarity judgements; the first two lines of the file are skipped.
data_sim = pd.read_csv(
    datapath('wordsim353.tsv'),
    sep='\t',
    skiprows=2,
    names=["Word1", "Word2", "Human"],
)
# Divide the human score by 10 (presumably mapping 0-10 ratings into [0, 1] -
# TODO confirm against the dataset description).
data_sim["Human"] = (data_sim["Human"] / 10).round(3)
# Analogy questions, one DataFrame per category.
data_quest = load_question_words(datapath('questions-words.txt'))

# Note: DataFrame.size is the number of cells (rows x columns), not of rows.
print("Word-Similarity Data")
print(data_sim.size)
print("#############################################")
print("Question-Words Data")
for category, dataset in data_quest.items():
    print(f"{category:<25} \t Size: {dataset.size}")

### Word similarity

In [None]:
def word_sim_function(emb, tok, x, format, dist="cosine"):
    """Similarity of the word pair in row ``x`` as ``1 - distance``.

    Args:
        emb: Embedding layer used to embed both words.
        tok: Tokenizer used to encode both words.
        x: Mapping with the keys ``"Word1"`` and ``"Word2"``.
        format: Callable applied to each word before encoding.
        dist: Distance name forwarded to ``calc_distance``.

    Returns:
        NumPy value of the similarity, rounded to 3 decimals.
    """
    first_word = format(x["Word1"])
    second_word = format(x["Word2"])
    distance = calc_distance(emb, first_word, second_word, tok=tok, dist=dist)
    similarity = torch.round(torch.squeeze(1 - distance), decimals=3)
    return similarity.detach().cpu().numpy()

In [None]:
# For every model: similarity columns for input and output embeddings, then
# the Pearson, Spearman and Kendall statistics of each column against "Human".
correlations = []
for model_name, tok, in_emb_mod, out_emb_mod, format in zip(model_names, tokenizers, in_emb, out_emb, model_format):
    in_col = model_name + "_in"
    out_col = model_name + "_out"
    word_pairs = data_sim[["Word1", "Word2"]]
    # Calculate similarities
    data_sim[in_col] = word_pairs.apply(lambda x: word_sim_function(in_emb_mod, tok, x, format), axis=1).astype(float)
    data_sim[out_col] = word_pairs.apply(lambda x: word_sim_function(out_emb_mod, tok, x, format), axis=1).astype(float)
    # Generate correlation coefficients: ([input stats], [output stats])
    human = data_sim["Human"]
    correlations.append(tuple(
        [measure(data_sim[column], human) for measure in (pearsonr, spearmanr, kendalltau)]
        for column in (in_col, out_col)
    ))

In [None]:
# (row label, symbol) for the three statistics, in the order they were stored.
coefficient_labels = (
    ("Pearson Correlation Coefficient", " r "),
    ("Spearman Correlation Coefficient", "rho"),
    ("Kendall Correlation Coefficient", "tau"),
)
for model_name, corr in zip(model_names, correlations):
    print("###################################")
    print(f"Model: {model_name}")
    # corr holds the input-embedding statistics first, then the output ones
    for heading, coefficients in zip(("Input Embeddings", "Output Embeddings"), corr):
        print("%%%%")
        print(heading)
        for (label, symbol), stat in zip(coefficient_labels, coefficients):
            print(f"{label:<40} - {symbol} : {stat[0]:.3} pv : {stat[1]:.3}")

### Question Words

In [None]:
# Per model: [results on input embeddings, results on output embeddings],
# each a dict mapping question category to the batch results.
test_questions = []
for model_name, tok, in_emb_mod, out_emb_mod, format in zip(model_names, tokenizers, in_emb, out_emb, model_format):
    per_model = []
    for emb_mod in (in_emb_mod, out_emb_mod):
        per_category = {}
        for category, dataset in tqdm(data_quest.items()):
            formatted_rows = change_words(dataset.values.tolist(), format)
            per_category[category] = batch_emb_arithmetic(emb_mod, tok, formatted_rows, k=50, out=False)
        per_model.append(per_category)
    test_questions.append(per_model)

# Accepted answers per category: all variants of each row's last column.
test_question_sol = [[addall(entry) for entry in dataset.iloc[:, -1].to_list()] for dataset in data_quest.values()]

In [None]:
# One figure per category: top-k accuracy of every model's input and output
# embeddings, saved as analogy_<category>.png.
test_k = [5, 10, 15, 25, 50]
palette = clmp["Paired"]
for category, sol in zip(data_quest.keys(), test_question_sol):
    print("########################")
    print(category)
    for model_idx, (model_name, tq) in enumerate(zip(model_names, test_questions)):
        print("---------------------------")
        sides = (("Input", tq[0][category]), ("Output", tq[1][category]))
        for side_name, side_results in sides:
            print(f"{model_name} {side_name} Rank Score: {evaluate_batch(side_results, sol, out=False, score='rankscore'):.2f}")
        # Two neighbouring palette entries per model: input first, output second
        for offset, (side_name, side_results) in enumerate(sides):
            accuracies = [evaluate_batch(side_results, sol, out=False, score='topk', k=k) for k in test_k]
            plt.plot(test_k, accuracies,
                     marker='o', alpha=0.9, label=f'{model_name} {side_name} Embeddings', c=palette(2 * model_idx + offset))
        plt.xlabel('k')
        plt.ylabel('Accuracy')
        plt.xticks(test_k)
        plt.title(category + " Top-K Accuracy")
    plt.legend()
    plt.gca().xaxis.set_minor_locator(AutoMinorLocator(1))
    plt.grid(linestyle='--', linewidth=0.5, which="minor")
    plt.grid(linestyle='--', linewidth=1, which="major")
    plt.savefig(f"analogy_{category}.png")
    plt.show()

In [None]:
# Delta-based capital test: rank scores for input and output embeddings.
print("Capital rankings")
for model_name, tok, in_emb_mod, out_emb_mod, format in zip(model_names, tokenizers, in_emb, out_emb, model_format):
    print("---------------------------")
    print(model_name)
    test_format = [[format(el) for el in test] for test in capital_advanced]
    test_capital = [
        [name, advanced_arithmetic(emb_mod, tok, test_format, k=100, out=False)]
        for name, emb_mod in (("Input Capital", in_emb_mod), ("Output Capital", out_emb_mod))
    ]
    for name, result in test_capital:
        print("%%%")
        print(name)
        evaluate_batch(result, capital_advanced_sol, tok=tok, subdivide_tokens=True)

In [None]:
# Same layout as the analogy run, but computed with the delta-based
# arithmetic. Note that this rebinds test_questions and test_question_sol.
test_questions = []
for model_name, tok, in_emb_mod, out_emb_mod, format in zip(model_names, tokenizers, in_emb, out_emb, model_format):
    per_model = []
    for emb_mod in (in_emb_mod, out_emb_mod):
        per_category = {}
        for category, dataset in tqdm(data_quest.items()):
            formatted_rows = change_words(dataset.values.tolist(), format)
            per_category[category] = advanced_arithmetic(emb_mod, tok, formatted_rows, k=50, out=False)
        per_model.append(per_category)
    test_questions.append(per_model)

# Accepted answers per category: all variants of each row's second column.
test_question_sol = [[addall(entry) for entry in dataset.iloc[:, 1].to_list()] for dataset in data_quest.values()]

In [None]:
# One figure per category: top-k accuracy of every model's input and output
# embeddings, saved as delta_analogy_<category>.png.
test_k = [5, 10, 15, 25, 50]
palette = clmp["Paired"]
for category, sol in zip(data_quest.keys(), test_question_sol):
    print("########################")
    print(category)
    for model_idx, (model_name, tq) in enumerate(zip(model_names, test_questions)):
        print("---------------------------")
        sides = (("Input", tq[0][category]), ("Output", tq[1][category]))
        for side_name, side_results in sides:
            print(f"{model_name} {side_name} Rank Score: {evaluate_batch(side_results, sol, out=False, score='rankscore'):.2f}")
        # Two neighbouring palette entries per model: input first, output second
        for offset, (side_name, side_results) in enumerate(sides):
            accuracies = [evaluate_batch(side_results, sol, out=False, score='topk', k=k) for k in test_k]
            plt.plot(test_k, accuracies,
                     marker='o', alpha=0.9, label=f'{model_name} {side_name} Embeddings', c=palette(2 * model_idx + offset))
        plt.xlabel('k')
        plt.ylabel('Accuracy')
        plt.xticks(test_k)
        plt.title(category + " Top-K Accuracy")
    plt.legend()
    plt.gca().xaxis.set_minor_locator(AutoMinorLocator(1))
    plt.grid(linestyle='--', linewidth=0.5, which="minor")
    plt.grid(linestyle='--', linewidth=1, which="major")
    plt.savefig(f"delta_analogy_{category}.png")
    plt.show()