In [1]:
import torch
from easy_transformer import EasyTransformer

#device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'
print(f"Using {device} device")
torch.set_grad_enabled(False)

model = EasyTransformer.from_pretrained('gpt2').to(device)

# Convenience function for decoding token
decode = model.tokenizer.decode

# Convenience function for encoding token
def encode(t):
    global model
    result = model.tokenizer.encode(t)
    if len(result) != 1:
        raise Exception(f"Not a single token: {t}")
    return result[0]

unembed = model.unembed.W_U.data
embed = model.embed.W_E.data
d_M = model.cfg.d_model
d_V = model.cfg.d_vocab

print(unembed.shape, embed.shape, d_M, d_V)

Using cpu device
Loading model: gpt2


  return torch._C._cuda_getDeviceCount() > 0
Using pad_token, but it is not set yet.


Moving model to device:  cpu
Finished loading pretrained model gpt2 into EasyTransformer!
Moving model to device:  cpu
torch.Size([768, 50257]) torch.Size([50257, 768]) 768 50257


In [65]:
from sklearn.cluster import KMeans
unembed_norm = torch.nn.functional.normalize(unembed, dim=0)
cluster_indices = KMeans(n_clusters=50).fit_predict(unembed_norm.T)

In [66]:
import random
from collections import defaultdict
random.seed(12345)
bins = defaultdict(list)
for t in range(d_V):
    cluster = cluster_indices[t]
    bins[cluster].append(t)
for b in sorted(bins.keys()):
    print(f'+++ cluster {b}+++')
    show = random.sample(bins[b], k=10)
    print([decode(t) for t in show])

+++ cluster 0+++
['holm', 'v', 'mega', 'imgur', 'dr', 'len', 'riot', 'nas', 'mas', 'thy']
+++ cluster 1+++
[' its', ' meaning', ' least', ' elsewhere', ' courtesy', ' —', ' beneath', ' anymore', ' those', ' itself']
+++ cluster 2+++
[' Fork', ' Pour', ' Thursday', ' Bin', ' Warn', ' Lot', ' Grove', ' Bub', ' Cre', ' Set']
+++ cluster 3+++
[' nicotine', ' tum', ' amp', ' psycho', ' gravitational', ' prote', ' vaccinated', ' arter', ' flu', ' rodents']
+++ cluster 4+++
['Brother', 'J', 'R', 'Legend', 'Revolution', 'Iron', 'Spell', 'Pub', 'Tumblr', 'Republic']
+++ cluster 5+++
[' dim', ' cer', ' null', ' iso', ' pap', ' inv', ' fur', ' feat', ' javascript', ' fab']
+++ cluster 6+++
[' Planned', ' Championships', ' Governors', ' Clubs', ' Film', ' Special', ' Cabin', ' Associates', ' Republican', ' Budget']
+++ cluster 7+++
[' +++', ' ©', '<<', ' �', '/-', 'Ã', '—"', 'soever', ':', '.']
+++ cluster 8+++
[' striking', ' combating', ' wanting', ' surfing', ' caring', ' writing', ' unfolding'

In [89]:
from sklearn.neighbors import BallTree
from sklearn.decomposition import PCA
unembed_norm = torch.nn.functional.normalize(unembed, dim=0)
ball_tree0 = BallTree(unembed_norm.T)
unembed_norm = torch.nn.functional.normalize(unembed, dim=1)
ball_tree1 = BallTree(unembed_norm.T)

#pca = PCA(n_components=d_M)
#unembed_pca = pca.fit_transform(unembed)
#ball_tree2 = BallTree(unembed_pca)


In [94]:
import numpy as np
toks = [' peace', ' love', ' war', ' cat', ' dog', ' leaves']
ts = [encode(tok) for tok in toks]
v = unembed[:,ts].T
d0,q0 = ball_tree0.query(v, k=20, return_distance=True)
d1,q1 = ball_tree1.query(v, k=20, return_distance=True)
for i,tok in enumerate(toks):
    print(f'+++{tok}+++')
    for j,(t0,t1) in enumerate(zip(q0[i,:],q1[i,:])):
        print('    ', f'{decode(t0):20}', f'{decode(t1):20}')


+++ peace+++
      peace                peace              
      Peace                Peace              
     peace                peace               
     Peace                Peace               
      peaceful             peaceful           
      ceasefire            truce              
      truce                peac               
      peac                 tranqu             
      tranqu               ceasefire          
      freedom              peacefully         
      war                  pacif              
      prosperity           freedom            
      pacif                war                
      security             harmony            
      happiness            security           
      harmony              unity              
      calm                 prosperity         
      peacefully           reconciliation     
      unity                calm               
      conflict             happiness          
+++ love+++
      love                 love    