In [1]:
import torch
from easy_transformer import EasyTransformer

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using {device} device")

# Autograd is disabled globally here; the training cell further down turns it
# back on explicitly before optimizing.
torch.set_grad_enabled(False)

# Load the pretrained GPT-2 weights, then place the model on the chosen device.
model = EasyTransformer.from_pretrained('gpt2')
model = model.to(device)

# Keep a direct reference to the tokenizer. The notebook-level name `model` is
# rebound to a different module in the training cell, so the helpers below must
# not look the tokenizer up through `model` at call time.
tokenizer = model.tokenizer

# Convenience function for decoding token
decode = tokenizer.decode

# Convenience function for encoding token
def encode(t):
    """Return the vocabulary id of a string that tokenizes to exactly one token.

    Args:
        t: Text to encode.

    Returns:
        The single token id the tokenizer produces for ``t``.

    Raises:
        ValueError: If ``t`` encodes to zero tokens or to more than one token.
    """
    result = tokenizer.encode(t)
    if len(result) != 1:
        raise ValueError(f"Not a single token: {t}")
    return result[0]

# Raw embedding / unembedding weight tensors of the loaded model.
# `unembed` is assumed to be (d_M, d_V): later cells select columns by token id
# -- TODO confirm against the model implementation.
embed = model.embed.W_E.data
unembed = model.unembed.W_U.data

# Residual-stream width and vocabulary size, read from the model config.
d_M, d_V = model.cfg.d_model, model.cfg.d_vocab

# Copy of `unembed` rescaled along dim 0 so each column has unit L2 norm.
unembed_norm = torch.nn.functional.normalize(unembed, p=2.0, dim=0)

Using cuda device
Loading model: gpt2


Using pad_token, but it is not set yet.


Moving model to device:  cuda
Finished loading pretrained model gpt2 into EasyTransformer!
Moving model to device:  cuda


In [32]:
class OrthModel(torch.nn.Module):
    """Learnable square matrix ``m`` with an orthogonality / half-split loss.

    ``forward`` returns a scalar loss built from two terms:

    * ``nm``: the matrix norm of ``m @ m.T - I`` (how far ``m`` is from
      orthogonal), a single scalar shared by the whole batch;
    * a per-column balance term ``min(n0, n1) / (n0 + n1)``, where ``n0`` and
      ``n1`` are the norms of the first and second half of the coordinates of
      ``m @ x``. It is 0 when a column lands entirely in one half and 0.5 when
      it is split evenly.

    The loss is ``sum(max(nm, balance))`` over the batch.

    Args:
        d_model: Side length of ``m``. Defaults to the notebook-level ``d_M``
            when omitted, which matches the original behaviour.
        eps: Lower bound applied to ``n0 + n1`` so that an all-zero input
            column yields a balance of 0 instead of NaN.
    """

    def __init__(self, d_model=None, eps=1e-12):
        super().__init__()
        if d_model is None:
            d_model = d_M
        # Random Gaussian init with every column rescaled to unit L2 norm.
        init = torch.normal(0, 1, (d_model, d_model))
        self.m = torch.nn.Parameter(torch.nn.functional.normalize(init, dim=0).detach())
        # The identity is a constant, not something to optimize: register it as
        # a buffer so it follows .to(device) and state_dict() without being
        # handed to the optimizer via parameters().
        self.register_buffer("eye", torch.eye(d_model))
        self.eps = eps

    def forward(self, x):
        """Compute the scalar loss for a batch of column vectors.

        Args:
            x: Tensor whose second-to-last dimension has size ``d_model``;
                each column is one input vector, e.g. shape ``(d_model, batch)``.

        Returns:
            0-dim tensor: the loss summed over all columns.
        """
        half = self.m.shape[0] // 2
        # Leading singleton dim keeps dim 1 as the coordinate axis after matmul.
        proj = torch.matmul(self.m.unsqueeze(0), x)
        n0 = torch.linalg.vector_norm(proj[:, :half], dim=1)
        n1 = torch.linalg.vector_norm(proj[:, half:], dim=1)
        nm = torch.linalg.matrix_norm(torch.matmul(self.m, self.m.T) - self.eye)
        # Clamp the denominator so a zero column gives 0 rather than 0/0 = NaN.
        balance = torch.minimum(n0, n1) / (n0 + n1).clamp_min(self.eps)
        return torch.maximum(nm, balance).sum()

In [33]:
import random

# NOTE: this rebinds the notebook-level name `model` from the GPT-2 model
# loaded in the first cell to the OrthModel being trained.
model = OrthModel().to(device)
optim = torch.optim.Adam(model.parameters())
torch.set_grad_enabled(True)  # undo the global no-grad mode for training

total_loss = torch.zeros((), device=device)
for i in range(10000):
    optim.zero_grad()
    # 64 token ids drawn with replacement; take those columns of `unembed`.
    v = unembed[:, random.choices(range(d_V), k=64)]
    loss = model(v)
    loss.backward()
    optim.step()
    total_loss += loss.detach()
    # Print the loss accumulated over each window of 100 steps, then reset it.
    if (i + 1) % 100 == 0:
        print(total_loss.item())
        total_loss = torch.zeros((), device=device)


42142.12109375
3169.84814453125
3094.2080078125
3004.171630859375
3129.3564453125
2982.315673828125
2904.879638671875
2815.7685546875
2779.310791015625
2753.8740234375
2720.760498046875
2732.699462890625
2813.479248046875
2730.5810546875
2626.838623046875
2599.5673828125
2652.43212890625
2778.969482421875
2555.29150390625
2542.2626953125
2519.49853515625
2588.35400390625
2668.92138671875
2930.870849609375
2502.163330078125
2477.93408203125
2461.53857421875
2448.879150390625
2441.616455078125
2442.08544921875
2498.725830078125
2470.211669921875
2417.249755859375
2419.16552734375
2407.465576171875
2421.244140625
2572.24658203125
2396.061279296875
2402.81103515625
2395.29736328125
2377.490478515625
2400.641845703125
2637.973876953125
2359.19677734375
2431.5263671875
2351.19775390625
2370.3115234375
2518.938232421875
2458.9609375
2338.905029296875
2384.572021484375
2333.310791015625
2427.00830078125
2472.6923828125
2317.54833984375
2326.238525390625
2324.6171875
2324.80126953125
2327.01440

In [34]:
# Training is over: switch autograd back off for inspection.
torch.set_grad_enabled(False)
mat = model.m

# Spot-check 100 distinct random tokens: project each token's `unembed` column
# through the learned matrix and print the norms of the two coordinate halves.
for i in random.sample(range(d_V), k=100):
    v = unembed[:, i]
    w = mat @ v
    n0, n1 = (
        torch.linalg.vector_norm(part).item()
        for part in (w[:d_M//2], w[d_M//2:])
    )
    print(i, decode(i), n0, n1)

31485  clauses 2.2015624046325684 4.288162708282471
12525  ally 1.978772521018982 3.9806699752807617
10277 uz 2.947826385498047 3.9549875259399414
49511  Useful 2.2316324710845947 3.57596492767334
48772  dehuman 1.7610663175582886 4.324909210205078
1024  De 2.0421578884124756 3.0899152755737305
49863  UFOs 1.950603723526001 4.148329734802246
41723  hesitated 2.081098794937134 3.7363429069519043
17442 hett 3.2232377529144287 4.6391496658325195
37058 Generally 1.6794567108154297 3.5697903633117676
44544  Khalid 2.1657705307006836 4.27364444732666
47627 Pitt 2.0935235023498535 4.602341175079346
23348 falls 3.1604034900665283 4.276700019836426
6433 ework 3.395779848098755 4.734602451324463
6171  syn 2.0913803577423096 4.118034362792969
29224  Leicester 1.5256083011627197 3.992495059967041
24259 Details 2.1429195404052734 3.9507293701171875
10499  Stadium 2.3156659603118896 3.9979496002197266
14800  fierce 1.6110990047454834 3.693042278289795
7694  pref 1.9805717468261719 3.760049343109131
