In [9]:
# First we grab the model and the unembedding weight matrix
import torch
from easy_transformer import EasyTransformer

# Pinned to CPU: a later cell hands M_to_V directly to scikit-learn's PCA,
# which converts it to NumPy -- that conversion only works for CPU tensors.
device = 'cpu'
print(f"Using {device} device")
torch.set_grad_enabled(False)  # inference/analysis only; no autograd needed

model = EasyTransformer.from_pretrained('gpt2')
model = model.to(device)

# Shorthands for the model's tokenizer.
tokenizer = model.tokenizer
encode = tokenizer.encode
decode = tokenizer.decode

# Unembedding weights: rows index the model dimension, columns the vocabulary.
M_to_V = model.unembed.W_U.data
d_M, d_V = model.cfg.d_model, model.cfg.d_vocab

Using cpu device
Loading model: gpt2


Using pad_token, but it is not set yet.


Moving model to device:  cpu
Finished loading pretrained model gpt2 into EasyTransformer!
Moving model to device:  cpu


In [35]:
from sklearn.decomposition import PCA

# Pseudo-inverse of the unembedding, shape (d_V, d_M); each token's row is
# rescaled to unit length.
V_to_M = torch.nn.functional.normalize(torch.linalg.pinv(M_to_V), dim=1)

# Every PCA component is kept, so the D basis has the same size as model space.
# Defined up front so everything below (including the optional rescaling) can use it.
d_D = d_M

# PCA over the d_M rows of M_to_V, each row being one d_V-dimensional sample.
# NOTE: scikit-learn's PCA centres the data first, so M_to_D @ D_to_V
# reconstructs M_to_V - pca_model.mean_, not M_to_V itself.
pca_model = PCA(n_components=d_D)
M_to_D = torch.tensor(pca_model.fit_transform(M_to_V), dtype=torch.float)
D_to_V = torch.tensor(pca_model.components_, dtype=torch.float)

# Optional: give every D-axis of M_to_D unit norm and push that scale into
# D_to_V instead. The product M_to_D @ D_to_V is unchanged either way.
NORMALIZE_D = False
norms = torch.linalg.vector_norm(M_to_D, dim=0)
if NORMALIZE_D:
    M_to_D = M_to_D / norms
    D_to_V = D_to_V * norms[:, None]

# Each token's (unit-length) model-space direction expressed in D coordinates.
V_to_D = V_to_M @ M_to_D

print(V_to_M.shape, M_to_D.shape, D_to_V.shape, M_to_V.shape, V_to_D.shape)

# Guard the shape bookkeeping that later cells rely on.
assert V_to_M.shape == (d_V, d_M)
assert M_to_D.shape == (d_M, d_D)
assert D_to_V.shape == (d_D, d_V)
assert V_to_D.shape == (d_V, d_D)

torch.Size([50257, 768]) torch.Size([768, 768]) torch.Size([768, 50257]) torch.Size([768, 50257]) torch.Size([50257, 768])


In [46]:
# Probe words; each must encode to exactly one token (the loop below unpacks one id).
example_tokens = [
    ' cat',
    ' war',
    ' banana',
    ' bat',
    ' bark',
    ' leaves',
]

def print_similar(heading, td, k=20):
    """Print `heading`, then the `k` tokens whose V_to_D rows score highest against `td`.

    Scores are raw dot products `V_to_D @ td` (no further normalisation).
    Each output line shows the decoded token, padded to 20 characters, followed
    by its score. A blank line ends the listing.

    Args:
        heading: label printed on its own line before the list.
        td: query vector in D space, e.g. a row of V_to_D.
        k: number of top-scoring tokens to show (default 20). Clipped to the
            number of rows in V_to_D, so fewer rows simply means a shorter list.
    """
    print(heading)
    similarity_vec = torch.matmul(V_to_D, td)
    # topk avoids building and sorting a Python list with one (score, index)
    # pair per vocabulary entry just to keep the first k.
    k = min(k, similarity_vec.numel())
    top = torch.topk(similarity_vec, k)
    for v, i in zip(top.values.tolist(), top.indices.tolist()):
        print(f'  {decode(i):20}', f'{v:20}')
    print()

for tok in example_tokens:
    ids = encode(tok)
    # The analysis works on single vocabulary entries. Fail loudly and name the
    # offending string rather than with a bare tuple-unpacking error.
    if len(ids) != 1:
        raise ValueError(f'{tok!r} encodes to {len(ids)} tokens {ids}; expected exactly one')
    t, = ids
    td = V_to_D[t,:]
    print_similar(tok, td)

 cat
   cat                       634.8876953125
   cats                  404.85211181640625
   Cat                        401.826171875
  Cat                    327.85736083984375
  cat                     311.1290283203125
   Cats                   273.6062316894531
   dog                   254.74285888671875
   CAT                      220.56103515625
  cats                   218.95774841308594
   catcher                 211.718994140625
   kittens               210.29351806640625
   kitten                202.44485473632812
   dogs                  166.30856323242188
   animal                 162.6132049560547
   pet                   160.79998779296875
   goat                    159.751220703125
   Dog                   155.82252502441406
   rabbit                155.04403686523438
   tiger                  149.2018280029297
   fel                   142.49505615234375

 war
   war                      575.61865234375
   War                   425.15631103515625
   wars              

In [57]:
def subtract_meaning(v0, tok):
    """Remove from `v0` its component along the V_to_D row of token `tok`.

    With v1 = V_to_D[encode(tok)], the result is the orthogonal-projection removal

        v0 - (v0 . v1 / v1 . v1) * v1

    This equals (V_to_D @ v0)[t1] / (V_to_D @ v1)[t1] used as the ratio. Here it
    is computed with two dot products instead of two full (d_V x d_D)
    matrix-vector products.

    Args:
        v0: 1-D vector in D space (e.g. a row of V_to_D, or a previous result).
        tok: string that must encode to exactly one token.

    Returns:
        A new tensor: v0 minus its projection onto v1. If v1 is the zero
        vector there is nothing to project out, and a copy of v0 is returned.

    Raises:
        ValueError: if `tok` does not encode to exactly one token.
    """
    ids = encode(tok)
    if len(ids) != 1:
        raise ValueError(f'{tok!r} encodes to {len(ids)} tokens {ids}; expected exactly one')
    t1, = ids
    v1 = V_to_D[t1,:]
    denom = torch.dot(v1, v1)
    if denom == 0:
        # Projection onto the zero vector is zero; avoid 0/0 -> NaN.
        return v0.clone()
    ratio = torch.dot(v0, v1) / denom
    return v0 - ratio * v1

# ' leaves' has both a plant sense and a "departs" sense; project out
# plant-related directions and look at which neighbours remain.
start_tok = ' leaves'
minus_toks = [' leaf', ' stems']
(t0,) = encode(start_tok)
v0 = V_to_D[t0]
for minus_tok in minus_toks:
    v0 = subtract_meaning(v0, minus_tok)
heading = start_tok + ''.join(f' - {m}' for m in minus_toks)
print_similar(heading, v0)

 leaves -  leaf -  stems
   leaves                  473.925048828125
   leave                   324.559814453125
   leaving                  268.76904296875
   Leaves                 257.9458923339844
   left                  237.17068481445312
   Leaving               174.34120178222656
   depart                167.43597412109375
   gives                 149.91940307617188
   departed               135.7659912109375
  leave                       132.298828125
   begs                  130.66094970703125
  left                   127.87675476074219
  Left                    127.3669662475586
   puts                  125.24998474121094
   makes                 122.95455169677734
   paints                119.03379821777344
  Leave                  117.99737548828125
   departure             117.94159698486328
   Leave                 116.73554992675781
   brings                116.35395812988281



In [61]:
# Reverse experiment: strip verb-like directions from ' leaves' so the
# plant sense dominates the neighbour list.
start_tok = ' leaves'
minus_toks = [' leave', ' begs', ' drops']
(t0,) = encode(start_tok)
v0 = V_to_D[t0]
for minus_tok in minus_toks:
    v0 = subtract_meaning(v0, minus_tok)
heading = ' - '.join([start_tok, *minus_toks])
print_similar(heading, v0)

 leaves -  leave -  begs -  drops
   leaves                 387.2460632324219
   Leaves                 244.9817352294922
   leaf                   185.0762939453125
   foliage               166.66305541992188
   plants                152.86337280273438
   stems                 131.53860473632812
   Leaf                  129.63812255859375
   flowers               114.71847534179688
   left                  114.15510559082031
   trees                 107.00654602050781
   bloss                 106.58168029785156
   branches               98.30564880371094
   poses                  97.27914428710938
   limbs                  97.08460235595703
   flies                  96.93152618408203
   logs                    96.4465560913086
   shoots                 96.31671142578125
   bark                   95.44673919677734
   shapes                 94.48100280761719
   crosses                94.38465881347656

