In [18]:
import torch
import einops
import transformer_lens
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [8]:
p = 113
a_vector = einops.repeat(torch.arange(p,device=device), 'i -> (i j)', j=p)
b_vector = einops.repeat(torch.arange(p,device=device), 'j -> (i j)', i=p)
equals_vector = einops.repeat(torch.tensor(113), ' -> (i j)', i=p,j=p)

In [9]:
dataset = torch.stack([a_vector,b_vector,equals_vector],dim=1)
dataset = dataset.to(device)
print(dataset[5:])

tensor([[  0,   5, 113],
        [  0,   6, 113],
        [  0,   7, 113],
        ...,
        [112, 110, 113],
        [112, 111, 113],
        [112, 112, 113]])


In [10]:
labels = (dataset[:,0] + dataset[:,1]) % p
print(labels.shape)
print(labels[:5])

torch.Size([12769])
tensor([0, 1, 2, 3, 4])


In [11]:
indices = torch.randperm(p*p)
cutoff = int(p*p*0.3)
train_indices = indices[:cutoff]
test_indices = indices[cutoff:]
train_data = dataset[train_indices]
test_data = dataset[test_indices]
train_labels = labels[train_indices]
test_labels = labels[test_indices]

In [359]:
indices

tensor([1379, 4519, 5648,  ..., 2701, 9462, 1597])

In [19]:
cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 4,
    d_model = 128,
    d_head = 32,
    d_mlp = 512,
    act_fn = 'relu',
    normalization_type = None,
    d_vocab=p+1,
    d_vocab_out=p,
    n_ctx=3,
    init_weights=True
 )

In [20]:
model = HookedTransformer(cfg)
m = model.to(device)

Moving model to device:  cpu


In [306]:
for name, param in model.named_parameters():
    if "b_" in name:
        param.requires_grad = False

In [23]:
lr = 1e-3
wd = 1.
betas = (0.9,0.98)
num_epochs = 100
checkpoint_every = 10

In [24]:
optimizer = torch.optim.AdamW(model.parameters(),lr=lr,weight_decay=wd,betas=betas)

In [21]:
def loss_fn(logits,labels):
    if len(logits.shape)==3:
        logits = logits[:,-1]
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1,index=labels[:,None])[:,0]
    return -correct_log_probs.mean()

In [357]:
train_data

tensor([[ 12,  23, 113],
        [ 39, 112, 113],
        [ 49, 111, 113],
        ...,
        [104,  29, 113],
        [ 62,  58, 113],
        [ 88,  20, 113]])

In [355]:
train_labels[114]

tensor(59)

In [310]:
train_logits = model(train_data)
train_loss = loss_fn(train_logits,train_labels)
print(train_loss)
test_logits = model(test_data)
test_loss = loss_fn(test_logits,test_labels)
print(test_loss)

tensor(4.7330, dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(4.7331, dtype=torch.float64, grad_fn=<NegBackward0>)


In [311]:
print(np.log(p))

4.727387818712341


In [312]:
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
for epoch in tqdm.tqdm(range(num_epochs)):
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits,train_labels)
    train_loss.backward()
    train_losses.append(train_loss.item())

    with torch.inference_mode():
        test_logits = model(test_data)
        test_loss = loss_fn(test_logits,test_labels)
        test_losses.append(test_loss.item())

    if ((epoch)%checkpoint_every)==0:
        checkpoint_epochs.append(epoch)
        model_checkpoints.append(copy.deepcopy(model.state_dict()))
        print(f"Epoch {epoch} Train Loss {train_loss.item()} Test :pss {test_loss.item()}")


  1%|█▊                                                                                                                                                                             | 1/100 [00:00<01:12,  1.37it/s]

Epoch 0 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 11%|███████████████████▏                                                                                                                                                          | 11/100 [00:07<01:01,  1.45it/s]

Epoch 10 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 21%|████████████████████████████████████▌                                                                                                                                         | 21/100 [00:14<00:58,  1.36it/s]

Epoch 20 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 31%|█████████████████████████████████████████████████████▉                                                                                                                        | 31/100 [00:21<00:48,  1.42it/s]

Epoch 30 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 41%|███████████████████████████████████████████████████████████████████████▎                                                                                                      | 41/100 [00:28<00:40,  1.46it/s]

Epoch 40 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 51%|████████████████████████████████████████████████████████████████████████████████████████▋                                                                                     | 51/100 [00:35<00:32,  1.53it/s]

Epoch 50 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 61%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                   | 61/100 [00:41<00:24,  1.58it/s]

Epoch 60 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                  | 71/100 [00:48<00:18,  1.56it/s]

Epoch 70 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 81%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                 | 81/100 [00:54<00:12,  1.58it/s]

Epoch 80 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 91%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎               | 91/100 [01:00<00:05,  1.59it/s]

Epoch 90 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:07<00:00,  1.49it/s]
