In [1]:
# Select the interactive ipympl ("widget") backend so figures are rendered as
# live widgets rather than static images.
%matplotlib widget

In [2]:
# Load GPT-2 through EasyTransformer and keep handles on the pieces that the
# rest of the notebook works with (tokenizer decode, unembedding weights).
import torch
from easy_transformer import EasyTransformer

# Prefer the GPU when torch can see one; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using {device} device")

# Switch autograd off globally for everything that follows.
torch.set_grad_enabled(False)

model = EasyTransformer.from_pretrained('gpt2')
model = model.to(device)

# Shorthand for turning token ids back into text.
decode = model.tokenizer.decode

# Raw unembedding weight tensor plus the model/vocabulary sizes from the config.
unembed = model.unembed.W_U.data
d_model, d_vocab = model.cfg.d_model, model.cfg.d_vocab

  return torch._C._cuda_getDeviceCount() > 0


Using cpu device
Loading model: gpt2


Using pad_token, but it is not set yet.


Moving model to device:  cpu
Finished loading pretrained model gpt2 into EasyTransformer!
Moving model to device:  cpu


In [3]:
# Pseudo-inverse of the unembedding matrix, computed on the CPU and then
# L2-normalised along dim 1.
ausp = torch.nn.functional.normalize(
    torch.linalg.pinv(unembed.cpu()),
    p=2.0,
    dim=1,
)

In [17]:
import math

# Token whose neighbours are listed below. It has to encode to exactly one
# token (note the leading space).
starting_tok = ' cat'
# Ranking criterion, one of:
#   'ausp' - dot product between each row of `ausp` and the starting row
#            (the misspelling 'uasp' is still accepted for old runs)
#   'dot'  - entry of `ausp[starting_t] @ unembed` for each token, relative
#            to the starting token's own entry
# NOTE(review): an 'unembed' mode used to be listed here but has never had a
# branch; ranking the whole vocabulary by `unembed_score` is not implemented.
sorting = 'dot'

token_ids = model.tokenizer.encode(starting_tok)
if len(token_ids) != 1:
    raise ValueError(
        f"{starting_tok!r} encodes to {len(token_ids)} tokens; expected exactly one"
    )
starting_t, = token_ids

# `ausp` lives wherever it was built, while `unembed` follows the model's
# device. Bring the unembedding matrix next to `ausp` so the products below
# never mix devices (a no-op when both already share a device).
unembed_on_ausp = unembed.to(ausp.device)

similarities = torch.matmul(ausp, ausp[starting_t, :])
uvalues = torch.matmul(ausp[starting_t, :], unembed_on_ausp)
cosine_sim = torch.nn.CosineSimilarity(dim=0)


def unembed_score(t):
    """Compare token `t` with the starting token after projecting through `unembed`.

    Args:
        t: index of a row of `ausp`.

    Returns:
        0-dim tensor: cosine similarity between `ausp[t] @ unembed` and
        `ausp[starting_t] @ unembed`.
    """
    utvalues = torch.matmul(ausp[t, :], unembed_on_ausp)
    return cosine_sim(uvalues, utvalues)


# One (similarity, token id, relative projection value) triple per token.
# `.tolist()` yields the same Python floats as per-element `.item()` calls,
# without one tensor round-trip per vocabulary entry.
similarity_list = similarities.tolist()
projection_list = uvalues.tolist()
starting_projection = projection_list[starting_t]
values = [
    (similarity, token, projection - starting_projection)
    for token, (similarity, projection) in enumerate(zip(similarity_list, projection_list))
]

if sorting in ('ausp', 'uasp'):
    # Plain tuple ordering: similarity first, ties broken by the later fields.
    values.sort(reverse=True)
elif sorting == 'dot':
    values.sort(key=lambda v: v[2], reverse=True)
else:
    raise ValueError("No such sorting")

# Show the 50 best-ranked tokens (fewer if there are not that many entries).
for similarity, token, relative_projection in values[:50]:
    print(
        f'{similarity:20}',
        f'{decode(token):15}',
        f'{relative_projection:20}',
        f'{unembed_score(token):20}',
    )


  0.9999999403953552  cat                             0.0   1.0000001192092896
 0.07929762452840805  Cat              -1.036684274673462   0.5645228028297424
 0.36347606778144836 cat              -1.0856115818023682  0.42552128434181213
 0.07540363073348999  cats            -1.1090970039367676   0.5671601295471191
0.058433420956134796 Cat              -1.1182819604873657   0.4754162132740021
  0.3631359040737152  Cats             -1.451160192489624    0.359244704246521
 0.03059900552034378 cats             -1.4670222997665405   0.3103601932525635
 0.17044615745544434  CAT              -1.578359603881836   0.3349841833114624
-0.008014477789402008  dog             -1.6610358357429504   0.3743366003036499
-0.040606603026390076  kitten           -1.700634777545929   0.2816951274871826
 0.22325390577316284  catcher          -1.787339448928833   0.2676182985305786
 0.10287782549858093  kittens         -1.8044873476028442   0.2652002274990082
-0.20099559426307678  rabbit          -1.914835214

In [5]:
# Keep only the first `wanted` rows of `ausp` and take every pairwise dot
# product between them.
wanted = 5000
logits = ausp[:wanted] @ ausp[:wanted].T

In [6]:
from sklearn.manifold import TSNE

# Turn the similarity matrix into a dissimilarity: each row's diagonal entry
# (its self-similarity) minus its similarity to every other row. The diagonal
# is broadcast as a column vector, so it is subtracted row-wise. Negative
# results are clamped to zero because precomputed distances must be
# non-negative.
distances = torch.clamp(torch.diag(logits).unsqueeze(1) - logits, min=0).numpy()

# init='random' is required together with metric='precomputed': PCA
# initialisation (the default since scikit-learn 1.2) needs feature vectors
# rather than distances and raises a ValueError. random_state pins the layout
# so that re-running the cell reproduces the same embedding.
xy = TSNE(
    n_components=2,
    verbose=2,
    metric='precomputed',
    init='random',
    random_state=0,
).fit_transform(distances)



[t-SNE] Computing 91 nearest neighbors...
[t-SNE] Indexed 5000 samples in 0.046s...
[t-SNE] Computed neighbors for 5000 samples in 0.312s...
[t-SNE] Computed conditional probabilities for sample 1000 / 5000
[t-SNE] Computed conditional probabilities for sample 2000 / 5000
[t-SNE] Computed conditional probabilities for sample 3000 / 5000
[t-SNE] Computed conditional probabilities for sample 4000 / 5000
[t-SNE] Computed conditional probabilities for sample 5000 / 5000
[t-SNE] Mean sigma: 0.061407
[t-SNE] Computed conditional probabilities in 0.133s
[t-SNE] Iteration 50: error = 98.2690887, gradient norm = 0.2193233 (50 iterations in 1.290s)
[t-SNE] Iteration 100: error = 98.7448959, gradient norm = 0.1663137 (50 iterations in 1.188s)
[t-SNE] Iteration 150: error = 97.9777298, gradient norm = 0.1918240 (50 iterations in 1.030s)
[t-SNE] Iteration 200: error = 99.9186935, gradient norm = 0.1790054 (50 iterations in 0.902s)
[t-SNE] Iteration 250: error = 101.1616058, gradient norm = 0.150612

In [11]:
import matplotlib.pyplot as plt

# Default to a square 8x8 figure from this cell onwards.
plt.rcParams.update({'figure.figsize': [8, 8]})
# Scatter the two columns of `xy`; picker=True enables picking on the points.
plt.scatter(x=xy[:, 0], y=xy[:, 1], picker=True)

<matplotlib.collections.PathCollection at 0x6b022f8953c0>