#  GPT2 Evaluations  

This notebook analyzes residual stream behavior in **GPT-2**, motivated by the analysis implemented on GPT-2 with TransformerLens from [Exploratory Analysis Demo - TransformerLens (Colab)](https://colab.research.google.com/github/neelnanda-io/TransformerLens/blob/main/demos/Exploratory_Analysis_Demo.ipynb#scrollTo=Q-L0x0cIrkXq). As part of my effort to test my understanding of the concepts from [Superposition & SAEs - Arena 3 Chapter 1](https://arena3-chapter1-transformer-interp.streamlit.app/%5B1.3.1%5D_Superposition_&_SAEs), I have used Hugging Face `transformers` library instead of the library used in the mentioned experiments, `transformer_lens`. 

Below is my implementation and analysis aimed at reinforcing my comprehension of **superposition and sparse autoencoders (SAEs)** in transformers.

- this is useful in the decision making part

In [253]:
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import numpy as np
import pandas as pd

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2', return_dict=True)
model.eval()
#inputs = tokenizer("Hello, my dog is cute", return_tensors="pt", add_special_tokens=True)
#inputs = tokenizer("Oh, I love ice cream so much!", return_tensors="pt", add_special_tokens=True)
inputs = tokenizer("I am not sure what you are saying", return_tensors="pt", add_special_tokens=True)
with torch.no_grad():
    outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
logits = outputs.logits

In [21]:
vocab_ = {idx: i for idx, i in enumerate(tokenizer.get_vocab())}

In [120]:
in_ids = inputs.input_ids[0]

In [164]:
def decode_text(vocab_dict, logits): 
    # tried with log_softmax, min, max.. 
    logs = logits[0].log_softmax(-1)
    votes = [np.argmin(i) for i in logs]
    print([vocab_dict[i.item()] for i in votes])
    print(votes)
    return

def idx_layers(input_ids, logits):
    # output the layers tensors (softmax) for vocab index
    logs = logits[0].softmax(-1)
    posit = []
    for i in input_ids:
        temp = {}
        temp['input_id'] = i.item()
        for idj, j in enumerate(logs):
            #posit.append({'input_id': i.item(), f'layer_{idj}': j[i].item()})
            temp[f'layer_{idj}'] = j[i].item()
        posit.append(temp)
    data = pd.DataFrame(posit)

    mat = np.eye(len(input_ids), dtype=int)
    for idx, row in enumerate(mat):
        data[f'pos_{idx}'] = row
    return data

#decode_text(vocab_, logits)

### text 1

In [167]:
vals = idx_layers(in_ids, logits)

In [168]:
vals

Unnamed: 0,input_id,layer_0,layer_1,layer_2,layer_3,layer_4,layer_5,pos_0,pos_1,pos_2,pos_3,pos_4,pos_5
0,15496,4.5e-05,3.9e-05,2e-06,7.795149e-07,3.51872e-08,1.209484e-07,1,0,0,0,0,0
1,11,0.096024,0.000113,0.001364,0.08521014,0.0007885353,0.133517,0,1,0,0,0,0
2,616,0.00095,0.02245,0.000527,0.0006723823,0.005774269,0.000277839,0,0,1,0,0,0
3,3290,1.8e-05,3.5e-05,0.001355,0.0002823597,0.0002465781,0.0001116447,0,0,0,1,0,0
4,318,0.008551,0.005259,0.000194,0.1778956,8.198121e-05,0.0001858972,0,0,0,0,1,0
5,13779,7e-06,3.2e-05,0.000302,2.071713e-05,0.004165464,0.0001076876,0,0,0,0,0,1


In [251]:
rel = []
testing = []
for i in range(5):
    term = f'pos_{i}'
    relt = vals[['layer_0', 'layer_1', 'layer_2', 'layer_3', 'layer_4', 'layer_5', term]].corr()[term].values[:-1]
    dove = [1 + i for i in relt] # asserts to one direction
    rel.append(dove)
    print(np.argmin(dove)) # gets the index of the min correlated 
    ct = dove - np.min(dove)
    testing.append(ct[i]) # normalise / standardise 


2
1
3
4
2


In [252]:
testing

[0.2894119575631866,
 0.0,
 0.20755465126025996,
 0.02400450020538991,
 0.007921430576414146]

In [227]:
vals.corr()['pos_0'].sort_values(ascending=True)

layer_2    -0.512407
layer_4    -0.362522
layer_3    -0.291883
layer_1    -0.252291
layer_0    -0.222995
layer_5    -0.201228
pos_2      -0.200000
pos_3      -0.200000
pos_4      -0.200000
pos_5      -0.200000
pos_1      -0.200000
input_id    0.681041
pos_0       1.000000
Name: pos_0, dtype: float64

In [229]:
vals.corr()['pos_1'].sort_values(ascending=True)

input_id   -0.383021
layer_1    -0.248246
layer_4    -0.207403
pos_2      -0.200000
pos_3      -0.200000
pos_4      -0.200000
pos_5      -0.200000
pos_0      -0.200000
layer_3     0.273205
layer_2     0.609783
layer_0     0.996248
layer_5     0.999999
pos_1       1.000000
Name: pos_1, dtype: float64

In [230]:
vals.corr()['pos_2'].sort_values(ascending=True)

input_id   -0.341448
layer_3    -0.287429
layer_0    -0.211498
pos_0      -0.200000
pos_1      -0.200000
pos_3      -0.200000
pos_4      -0.200000
pos_5      -0.200000
layer_5    -0.198729
layer_2    -0.079875
layer_4     0.773420
layer_1     0.972652
pos_2       1.000000
Name: pos_2, dtype: float64

### text 2

In [238]:
outputs.logits[0].shape

torch.Size([9, 50257])

In [236]:
text2 = idx_layers(inputs.input_ids[0], outputs.logits)

In [242]:
text2

Unnamed: 0,input_id,layer_0,layer_1,layer_2,layer_3,layer_4,layer_5,layer_6,layer_7,layer_8,pos_0,pos_1,pos_2,pos_3,pos_4,pos_5,pos_6,pos_7,pos_8
0,5812,7.1e-05,1.1e-05,3.062762e-09,8.67939e-08,8.759316e-08,2.939493e-07,1.340631e-07,6.194054e-07,2e-05,1,0,0,0,0,0,0,0,0
1,11,0.074408,6e-05,0.001334366,0.001132241,0.007897571,0.1280085,0.005436836,0.1541862,8.5e-05,0,1,0,0,0,0,0,0,0
2,314,0.008329,0.120552,3.716908e-05,0.0005832549,4.499895e-05,0.0005181681,0.05229215,0.03471816,0.299114,0,0,1,0,0,0,0,0,0
3,1842,0.000199,0.000255,0.01323957,0.0008150191,0.0001043985,3.853797e-05,0.0003477961,5.235338e-05,1.3e-05,0,0,0,1,0,0,0,0,0
4,4771,2.2e-05,1.9e-05,4.131381e-07,0.0001521628,0.0008609202,0.0009561596,1.514738e-05,3.349561e-05,1e-05,0,0,0,0,1,0,0,0,0
5,8566,6e-06,3e-06,9.258783e-07,6.199149e-06,0.8419573,0.0002660695,2.180498e-05,4.69828e-06,2e-06,0,0,0,0,0,1,0,0,0
6,523,0.001847,0.013288,7.787708e-05,0.0003182518,0.0003309257,0.01256997,0.0009461293,0.002561176,0.000125,0,0,0,0,0,0,1,0,0
7,881,0.000756,0.000318,2.033654e-05,4.400202e-05,7.769692e-06,0.0002299094,0.8259637,6.091613e-05,1e-06,0,0,0,0,0,0,0,1,0
8,0,0.002759,5.3e-05,4.552175e-05,0.0001431292,0.002398171,0.06411234,0.000644014,0.1002238,0.000266,0,0,0,0,0,0,0,0,1


In [241]:
t_cols = [
 'layer_0',
 'layer_1',
 'layer_2',
 'layer_3',
 'layer_4',
 'layer_5',
 'layer_6',
 'layer_7',
 'layer_8']

In [247]:
rel = []
testing = []
for i in range(8):
    term = f'pos_{i}'
    relt = text2[t_cols+[term]].corr()[term].values[:-1]
    dove = [1 + i for i in relt] # asserts to one direction
    rel.append(dove)
    print(np.argmin(dove)) # gets the index of the min correlated 
    print((dove - np.min(dove))) # normalise / standardise 
    ct = dove - np.min(dove)
    testing.append(ct[i])

3
[0.18008246 0.18952541 0.18950535 0.         0.2032085  0.1370206
 0.19515668 0.11505376 0.20497997]
1
[1.13419497 0.         0.11398414 0.86339541 0.02379302 1.02349289
 0.012629   0.94785077 0.01524064]
5
[0.16579342 1.18276395 0.05131798 0.4012167  0.06189306 0.
 0.12551504 0.20397536 1.1887755 ]
7
[0.06664515 0.07642766 1.20982698 0.64283371 0.08795117 0.02194526
 0.08023666 0.         0.0895565 ]
7
[0.0640529  0.07433034 0.07426868 0.02622293 0.08908882 0.0297869
 0.07990543 0.         0.08967163]
3
[0.17339499 0.18376721 0.18389744 0.         1.32442293 0.13356851
 0.18949934 0.10939375 0.19922522]
7
[0.07537539 0.18245814 0.06414645 0.16398848 0.07161221 0.11068359
 0.06441529 0.         0.073338  ]
3
[0.14976057 0.15155458 0.15038992 0.         0.1623591  0.09809182
 1.28734959 0.07459406 0.16405068]


In [248]:
testing

[0.18008246411162887,
 0.0,
 0.05131797851911002,
 0.6428337149444726,
 0.08908881767349763,
 0.13356851258712876,
 0.06441529355334963,
 0.07459405813466546]

### text 3

In [254]:
text3 = idx_layers(inputs.input_ids[0], outputs.logits)

In [255]:
text3

Unnamed: 0,input_id,layer_0,layer_1,layer_2,layer_3,layer_4,layer_5,layer_6,layer_7,pos_0,pos_1,pos_2,pos_3,pos_4,pos_5,pos_6,pos_7
0,40,0.002661,6.063355e-07,6.423055e-08,6e-06,1.1e-05,4.170237e-07,2.639108e-07,6e-06,1,0,0,0,0,0,0,0
1,716,0.001746,0.0003026406,0.0001029701,2.6e-05,0.00056,0.0002702318,8.604529e-05,3.8e-05,0,1,0,0,0,0,0,0
2,407,0.002509,0.1058596,0.0004131822,0.000176,0.000145,1.442704e-05,0.0008333408,7.5e-05,0,0,1,0,0,0,0,0
3,1654,0.000338,0.02475403,0.09361645,8.1e-05,3e-06,2.25626e-05,0.0007844536,4e-06,0,0,0,1,0,0,0,0
4,644,0.00055,0.0003524971,0.0007186912,0.139963,0.000465,8.011254e-06,7.344699e-05,0.000308,0,0,0,0,1,0,0,0
5,345,0.002529,0.000279095,0.0002946961,0.008093,0.068503,5.29631e-05,3.192253e-05,0.001501,0,0,0,0,0,1,0,0
6,389,0.006713,9.016027e-06,6.146159e-06,7.9e-05,0.001563,0.2194551,4.04877e-05,0.000317,0,0,0,0,0,0,1,0
7,2282,0.000119,0.001543669,0.0307891,3.4e-05,9e-06,4.123323e-05,0.09786332,4.7e-05,0,0,0,0,0,0,0,1


In [260]:
cols = [i for i in text3.columns if 'layer' in str(i)]
rel = []
testing = []
for i in range(8):
    term = f'pos_{i}'
    relt = text3[cols+[term]].corr()[term].values[:-1]
    dove = [1 + i for i in relt] # asserts to one direction
    rel.append(dove)
    print(np.argmin(dove)) # gets the index of the min correlated 
    print((dove - np.min(dove))) # normalise / standardise 
    ct = dove - np.min(dove)
    testing.append(ct[i])

7
[0.32220221 0.04262744 0.03266816 0.07153088 0.07484477 0.08092448
 0.07814297 0.        ]
7
[0.1219353  0.01988662 0.00788436 0.04565956 0.05802381 0.05629492
 0.05311235 0.        ]
2
[0.25568243 1.15951931 0.         0.03523701 0.03939199 0.04330615
 0.05020627 0.01722835]
0
[0.         0.43282826 1.29118129 0.19236495 0.19491855 0.20125984
 0.20754516 0.1185733 ]
0
[0.         0.12639617 0.12133676 1.30237946 0.16238538 0.16089476
 0.15893061 0.32067811]
2
[0.26081376 0.00941424 0.         0.1017856  1.18759874 0.04494757
 0.04226304 1.15557393]
2
[1.06153272 0.00997705 0.         0.03938444 0.06812869 1.19134159
 0.04587185 0.21499322]
0
[0.         0.22151732 0.56909586 0.23382453 0.23687449 0.24319709
 1.38609417 0.19511951]


In [261]:
testing

[0.3222022062647363,
 0.019886624951979126,
 0.0,
 0.1923649491478262,
 0.16238537787156082,
 0.04494757481957479,
 0.04587184879143125,
 0.19511950529074373]

In [267]:
t_mat = text3[cols].values

In [274]:
[np.mean(i) for i in t_mat] # average logits for token 1 by the logits for other tokens in the text

[0.00033562888189386797,
 0.00039160656297099194,
 0.01375314479548706,
 0.014950348429721316,
 0.01780478736270652,
 0.010160676942177815,
 0.028522721655861005,
 0.016305899860867612]

In [277]:
[np.mean(i) for i in text2[t_cols].values]

[1.1445471724084536e-05,
 0.04139429997343945,
 0.05735431251297188,
 0.0016737773724647316,
 0.00022992320202459067,
 0.0935853628791329,
 0.003562724509619228,
 0.09193350504661642,
 0.01896057293278217]

### Using MLP

In [56]:
mlp_layers = []
for i in range(11):
    mlp_layers.append(model.transformer.h[i].mlp.c_proj.weight)

In [92]:
model.transformer

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [83]:
model.transformer.wpe.weight.shape

torch.Size([1024, 768])

In [91]:
model.transformer.wte.weight.shape

torch.Size([50257, 768])

In [75]:
mlp_layers[0].shape

torch.Size([3072, 768])

In [78]:
outputs.logits.shape

torch.Size([1, 6, 50257])