In [11]:
import torch
from easy_transformer import EasyTransformer

device = 'cuda' if torch.cuda.is_available() else 'cpu'
#device = 'cpu'
print(f"Using {device} device")
torch.set_grad_enabled(False)

model = EasyTransformer.from_pretrained('gpt2').to(device)

# Convenience function for decoding token
decode = model.tokenizer.decode

# Convenience function for encoding token
def encode(t):
    global model
    result = model.tokenizer.encode(t)
    if len(result) != 1:
        raise Exception(f"Not a single token: {t}")
    return result[0]

unembed = model.unembed.W_U.data
embed = model.embed.W_E.data
d_M = model.cfg.d_model
d_V = model.cfg.d_vocab

print(unembed.shape, d_M, d_V)

Using cpu device
Loading model: gpt2


Using pad_token, but it is not set yet.


Moving model to device:  cpu
Finished loading pretrained model gpt2 into EasyTransformer!
Moving model to device:  cpu
torch.Size([768, 50257]) 768 50257


In [18]:
# Choose a random set of tokens (TODO: do better than random)
import random
n_dict = 2000
random.seed(12345)
indices = torch.tensor(random.sample(range(d_V), k=n_dict))
print([decode(i) for i in indices[:10]])

dictionary = embed[indices,:]
print(dictionary.shape)

[' savage', 'Magn', 'ian', './', ' Grove', ' Others', ' faction', 'omsky', 'ewitness', ' Atlantic']
torch.Size([2000, 768])


In [54]:
# Choose another set of tokens
toks = [' peace', ' love', ' war', ' cat', ' dog', ' Mario', ' Giles']
ts = torch.tensor([encode(tok) for tok in toks])
n_t = len(ts)
vs = embed[ts,:]    # gather lookup
print(vs.shape)

from sklearn.decomposition import SparseCoder
transformed = torch.tensor(SparseCoder(dictionary=dictionary, transform_n_nonzero_coefs=5).transform(vs))
last_tok = None
for i in range(n_t):
    print(f'+++{toks[i]}+++')
    nz = torch.nonzero(transformed[i,:])
    values = [(transformed[i,j].item(), decode(indices[j.item()])) for j in nz]
    values.sort(reverse=True)
    for val, tok in values:
        print('    ', f'{tok:20}', val)
    
    v = embed[ts[i],:]
    reconstituted = torch.matmul(transformed[i,:], dictionary)
    #print(v[:20])
    #print(reconstituted[:20])
    print(torch.linalg.vector_norm(v - reconstituted).item())

torch.Size([7, 768])
+++ peace+++
      settlers            0.17874906957149506
      conscience          0.16404199600219727
      cessation           0.14536640048027039
     �                    0.11943838745355606
     afety                0.061829060316085815
3.0729010105133057
+++ love+++
      hatred              0.32344797253608704
     -                    0.17184175550937653
      goodness            0.12907937169075012
      Beautiful           0.11889779567718506
     favorite             0.09814704209566116
2.5666236877441406
+++ war+++
      fight               0.29392004013061523
      hatred              0.2270224690437317
      troop               0.16807401180267334
      arms                0.15888924896717072
      Github              -0.12153104692697525
2.557136297225952
+++ cat+++
      rabbits             0.18691937625408173
      bas                 0.1638754904270172
      sch                 0.16002868115901947
      catch               0.14967380464076996
  