In [25]:
import gc
import warnings
from collections import OrderedDict
from functools import partial
from itertools import combinations
from pathlib import Path

import plotly.express as px
import torch.nn.functional as f
from pandas import DataFrame
from torch import cuda, cdist, no_grad, sum
from transformers import AutoModelForCausalLM, AutoTokenizer

# Silence every warning: Warning is the base class, so no category gets through.
warnings.simplefilter("ignore", category=Warning)

In [2]:
# Checkpoint identifier handed to from_pretrained for both the model and the tokenizer.
MODEL: str = "google/gemma-2-2b"
# Value passed as device_map when the model is loaded.
DEVICE: str = "cpu"

In [3]:
# Collect unreachable Python objects first so any tensors they owned are freed,
# then hand the now-unused cached CUDA blocks back. In the opposite order the
# memory freed by the collector would stay in the allocator cache.
gc.collect()
cuda.empty_cache()

# NOTE(review): trust_remote_code=True allows the Hub repository to execute its
# own Python code while loading — confirm this checkpoint actually needs it.
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    device_map=DEVICE,
    torch_dtype="auto",
    trust_remote_code=True,
)

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [49]:
tokenizer = AutoTokenizer.from_pretrained(MODEL)
emotions = [
    "happy", "elated", "peaceful", "calm", "content", "relaxed", "sad",
    "anxious", "fearful", "scared", "depressed", "lonely", "bitter",
    "jealous", "angry", "guilty", "passionate", "brave", "confident",
]

# The ids are stored exactly as the tokenizer returns them: nothing is squeezed
# or stripped here, so each entry keeps its leading batch dimension and its
# first id (2 in the recorded output, presumably <bos>). Some words are split
# into more than one token.
print("orignal shape: ", tokenizer("happy", return_tensors='pt').input_ids)
emo_tokens = OrderedDict(
    (emo, tokenizer(emo, return_tensors='pt').input_ids) for emo in emotions
)

print(emo_tokens)

orignal shape:  tensor([[    2, 11896]])
OrderedDict({'happy': tensor([[    2, 11896]]), 'elated': tensor([[  2, 521, 840]]), 'peaceful': tensor([[     2, 211749]]), 'calm': tensor([[     2, 116051]]), 'content': tensor([[   2, 3312]]), 'relaxed': tensor([[     2, 163861]]), 'sad': tensor([[    2, 37968]]), 'anxious': tensor([[    2,   481, 24192]]), 'fearful': tensor([[    2, 71339,  1329]]), 'scared': tensor([[     2, 221959]]), 'depressed': tensor([[   2, 3243, 3734]]), 'lonely': tensor([[     2, 151738]]), 'bitter': tensor([[     2, 158930]]), 'jealous': tensor([[    2,  1792, 22108]]), 'angry': tensor([[    2, 70709]]), 'guilty': tensor([[     2, 206971]]), 'passionate': tensor([[    2, 94364,   607]]), 'brave': tensor([[     2, 149142]]), 'confident': tensor([[     2, 131181]])})


In [17]:
# One row per emotion word, one column per recording point: column 0 is filled
# with the input to the first decoder layer and column k + 1 with the output of
# layer k, so there is one more column than there are layers.
layers = [*range(len(model.model.layers) + 1)]
emb_sums = DataFrame(columns=layers, index=emotions)

In [40]:
def pandas_hook(module, input, output, word, layer_id):
    """Forward hook that stores position-summed activations in ``emb_sums``.

    The first three parameters are the ones a forward hook receives; ``word``
    and ``layer_id`` are expected to be bound beforehand (e.g. with
    ``functools.partial``).

    Args:
        module: The layer the hook is attached to (unused).
        input: Positional inputs of the layer; only ``input[0]`` is read.
        output: Return value of the layer; only ``output[0]`` is read.
        word: Row label in ``emb_sums`` to write to.
        layer_id: Index of the hooked layer. Its output goes to column
            ``layer_id + 1``; for layer 0 the input is also stored in column 0.

    Each stored value is ``x.squeeze()`` summed over dimension 0 — presumably
    a sum over token positions of a (1, seq, hidden) tensor; TODO confirm the
    shapes for the installed transformers version.
    """
    # Write with a single .at[row, col] access: the chained form
    # emb_sums[col][row] = ... assigns into a temporary under pandas
    # Copy-on-Write and would leave emb_sums untouched.
    # .detach() keeps the autograd graph from being retained by the table.
    if layer_id == 0:  # put in initial embeddings
        emb_sums.at[word, layer_id] = sum(input[0].squeeze(), 0).detach()
    emb_sums.at[word, layer_id + 1] = sum(output[0].squeeze(), 0).detach()

def pass_word_through_model(word):
    """Run one word through the model and record per-layer activations.

    A ``pandas_hook`` bound to ``word`` is attached to every layer in
    ``model.model.layers``, the pre-tokenized ids in ``emo_tokens[word]`` are
    passed through ``model.model``, and the hooks are removed again.

    Args:
        word: Key into ``emo_tokens``; also the row of ``emb_sums`` the hooks
            write to.
    """
    handles = []
    try:
        for l in layers[:-1]:
            # Keep the handle so exactly this hook can be removed afterwards,
            # instead of overwriting the module's private _forward_hooks dict.
            hook = partial(pandas_hook, word=word, layer_id=l)
            handles.append(model.model.layers[l].register_forward_hook(hook))
        # Only activations are needed, so skip building the autograd graph.
        with no_grad():
            # model.model is called rather than model, which skips whatever the
            # outer wrapper adds on top (presumably the LM head — TODO confirm;
            # the earlier note here said "attention head").
            model.model(emo_tokens[word])
    finally:
        # Always detach the hooks, even if the forward pass raised, so later
        # calls of the model do not overwrite this word's row.
        for handle in handles:
            handle.remove()

In [43]:
# Feed each emotion word through the model; the hooks fill emb_sums as a side effect.
for emo in emotions:
    pass_word_through_model(emo)

# Spot-check column 25 for two of the words.
for word in ("happy", "sad"):
    print(emb_sums[25][word])

tensor([  8.8409,   7.5258,  -6.5818,  ..., -15.0469,  -5.6643,  -5.0516],
       grad_fn=<SumBackward1>)
tensor([ 2.2251,  9.1453,  6.4427,  ..., -5.4765, -3.3644,  6.1680],
       grad_fn=<SumBackward1>)


In [67]:
def cos_dist(emb1, emb2):
    """Return the cosine distance (1 - cosine similarity) as a Python float.

    The similarity is taken along dimension 0 of both tensors.
    """
    similarity = f.cosine_similarity(emb1, emb2, dim=0)
    return 1 - similarity.item()

# Every unordered pair of emotions, keyed "first-second" and mapped to the pair itself.
emo_combos = {
    first + "-" + second: (first, second)
    for first, second in combinations(emotions, r=2)
}

# Build the whole table in one constructor call (one column per pair, one row
# per layer column of emb_sums). Filling it cell by cell with
# cos_df[col][row] = ... is chained assignment, which does not write through
# under pandas Copy-on-Write; this form also yields float columns directly.
cos_df = DataFrame(
    {
        pair: [cos_dist(emb_sums.at[first, l], emb_sums.at[second, l]) for l in layers]
        for pair, (first, second) in emo_combos.items()
    },
    index=layers,
    dtype=float,
)

cos_fig = px.line(
    cos_df,
    title="Cosine Distance Between Emotions",
    labels={"index": "Layer", "value": "Cosine Distance"},
)
cos_fig.show()

In [68]:
def euc(emb1, emb2):
    """Return the distance computed by ``cdist`` between two tensors as a float.

    Each tensor gets a leading dimension of size 1 before the call, so the
    result holds a single value.
    """
    row1 = emb1.unsqueeze(dim=0)
    row2 = emb2.unsqueeze(dim=0)
    distance = cdist(row1, row2)
    return distance.item()

# Same layout as the cosine table: one column per emotion pair, one row per
# layer column of emb_sums. Built in a single constructor call because the
# cell-by-cell form euc_df[col][row] = ... is chained assignment and does not
# write through under pandas Copy-on-Write.
euc_df = DataFrame(
    {
        pair: [euc(emb_sums.at[first, l], emb_sums.at[second, l]) for l in layers]
        for pair, (first, second) in emo_combos.items()
    },
    index=layers,
    dtype=float,
)

# Plot euc_df: the previous version passed cos_df here, so the figure titled
# "Euclidean Distance" actually showed the cosine distances again.
euc_fig = px.line(
    euc_df,
    title="Euclidean Distance Between Emotions",
    labels={"index": "Layer", "value": "Euclidean Distance"},
)
euc_fig.show()

In [69]:
def figures_to_html(figs, filename="dashboard.html"):
    """Write several figures, one after another, into a single HTML file.

    Args:
        figs: Iterable of objects exposing
            ``to_html(full_html=..., include_plotlyjs=...)`` (plotly figures).
        filename: Path of the HTML file to create or overwrite. Missing parent
            directories are created.
    """
    target = Path(filename)
    # Opening the file fails if its directory does not exist yet.
    target.parent.mkdir(parents=True, exist_ok=True)
    with open(target, 'w', encoding='utf-8') as dashboard:
        # Declare the encoding the file is actually written in.
        dashboard.write('<html><head><meta charset="utf-8" /></head><body>' + "\n")
        for position, fig in enumerate(figs):
            # Ask for the bare <div> instead of slicing a full page on <body>,
            # and embed plotly.js only with the first figure; later figures
            # reuse the copy already loaded by the page.
            inner_html = fig.to_html(full_html=False, include_plotlyjs=(position == 0))
            dashboard.write(inner_html)
        dashboard.write("</body></html>" + "\n")

# Save both distance plots, cosine first, into one HTML page.
figures_to_html([cos_fig, euc_fig], filename="results/06.layer_distances.html")