In [346]:
p = 113
a_vector = einops.repeat(torch.arange(p,device=device), 'i -> (i j)', j=p)
b_vector = einops.repeat(torch.arange(p,device=device), 'j -> (i j)', i=p)
equals_vector = einops.repeat(torch.tensor(113), ' -> (i j)', i=p,j=p)

In [347]:
dataset = torch.stack([a_vector,b_vector,equals_vector],dim=1)
dataset = dataset.to(device)
print(dataset[5:])

tensor([[  0,   5, 113],
        [  0,   6, 113],
        [  0,   7, 113],
        ...,
        [112, 110, 113],
        [112, 111, 113],
        [112, 112, 113]])


In [348]:
labels = (dataset[:,0] + dataset[:,1]) % p
print(labels.shape)
print(labels[:5])

torch.Size([12769])
tensor([0, 1, 2, 3, 4])


In [349]:
indices = torch.randperm(p*p)
cutoff = int(p*p*0.3)
train_indices = indices[:cutoff]
test_indices = indices[cutoff:]
train_data = dataset[train_indices]
test_data = dataset[test_indices]
train_labels = labels[train_indices]
test_labels = labels[test_indices]

In [359]:
indices

tensor([1379, 4519, 5648,  ..., 2701, 9462, 1597])

In [304]:
cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 4,
    d_model = 128,
    d_head = 32,
    d_mlp = 512,
    act_fn = 'relu',
    normalization_type = None,
    d_vocab=p+1,
    d_vocab_out=p,
    n_ctx=3,
    init_weights=True
 )

In [305]:
model = HookedTransformer(cfg)
m = model.to(device)

Moving model to device:  cpu


In [306]:
for name, param in model.named_parameters():
    if "b_" in name:
        param.requires_grad = False

In [307]:
lr = 1e-3
wd = 1.
betas = (0.9,0.98)
num_epochs = 100
checkpoint_every = 10

In [308]:
optimizer = torch.optim.AdamW(model.parameters(),lr=lr,weight_decay=wd,betas=betas)

In [309]:
def loss_fn(logits,labels):
    if len(logits.shape)==3:
        logits = logits[:,-1]
    logits = logits.to(torch.float64)
    log_probs = logits.log_softmax(dim=-1)
    correct_log_probs = log_probs.gather(dim=-1,index=labels[:,None])[:,0]
    return -correct_log_probs.mean()

In [357]:
train_data

tensor([[ 12,  23, 113],
        [ 39, 112, 113],
        [ 49, 111, 113],
        ...,
        [104,  29, 113],
        [ 62,  58, 113],
        [ 88,  20, 113]])

In [355]:
train_labels[114]

tensor(59)

In [310]:
train_logits = model(train_data)
train_loss = loss_fn(train_logits,train_labels)
print(train_loss)
test_logits = model(test_data)
test_loss = loss_fn(test_logits,test_labels)
print(test_loss)

tensor(4.7330, dtype=torch.float64, grad_fn=<NegBackward0>)
tensor(4.7331, dtype=torch.float64, grad_fn=<NegBackward0>)


In [311]:
print(np.log(p))

4.727387818712341


In [312]:
train_losses = []
test_losses = []
model_checkpoints = []
checkpoint_epochs = []
for epoch in tqdm.tqdm(range(num_epochs)):
    train_logits = model(train_data)
    train_loss = loss_fn(train_logits,train_labels)
    train_loss.backward()
    train_losses.append(train_loss.item())

    with torch.inference_mode():
        test_logits = model(test_data)
        test_loss = loss_fn(test_logits,test_labels)
        test_losses.append(test_loss.item())

    if ((epoch)%checkpoint_every)==0:
        checkpoint_epochs.append(epoch)
        model_checkpoints.append(copy.deepcopy(model.state_dict()))
        print(f"Epoch {epoch} Train Loss {train_loss.item()} Test :pss {test_loss.item()}")


  1%|█▊                                                                                                                                                                             | 1/100 [00:00<01:12,  1.37it/s]

Epoch 0 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 11%|███████████████████▏                                                                                                                                                          | 11/100 [00:07<01:01,  1.45it/s]

Epoch 10 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 21%|████████████████████████████████████▌                                                                                                                                         | 21/100 [00:14<00:58,  1.36it/s]

Epoch 20 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 31%|█████████████████████████████████████████████████████▉                                                                                                                        | 31/100 [00:21<00:48,  1.42it/s]

Epoch 30 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 41%|███████████████████████████████████████████████████████████████████████▎                                                                                                      | 41/100 [00:28<00:40,  1.46it/s]

Epoch 40 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 51%|████████████████████████████████████████████████████████████████████████████████████████▋                                                                                     | 51/100 [00:35<00:32,  1.53it/s]

Epoch 50 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 61%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                   | 61/100 [00:41<00:24,  1.58it/s]

Epoch 60 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 71%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                                  | 71/100 [00:48<00:18,  1.56it/s]

Epoch 70 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 81%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                                 | 81/100 [00:54<00:12,  1.58it/s]

Epoch 80 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


 91%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▎               | 91/100 [01:00<00:05,  1.59it/s]

Epoch 90 Train Loss 4.732968422566483 Test :pss 4.7331132198038075


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [01:07<00:00,  1.49it/s]


In [313]:
checkpoint_epochs

[0, 10, 20, 30, 40, 50, 60, 70, 80, 90]

In [314]:
model_checkpoints

[OrderedDict([('embed.W_E',
               tensor([[-0.0106, -0.0661,  0.0012,  ...,  0.0381, -0.0521, -0.0787],
                       [ 0.0327,  0.0118, -0.0100,  ..., -0.0464, -0.0770, -0.0137],
                       [ 0.0692,  0.0479,  0.0567,  ...,  0.0835, -0.0384, -0.0117],
                       ...,
                       [-0.1038,  0.0034,  0.0092,  ...,  0.0478, -0.0129, -0.0111],
                       [ 0.0038,  0.0146, -0.0019,  ...,  0.0053,  0.0078,  0.2036],
                       [-0.0591, -0.0171,  0.0252,  ...,  0.0513,  0.0428,  0.1326]])),
              ('pos_embed.W_pos',
               tensor([[-0.0424,  0.0233,  0.1180, -0.0160, -0.1724,  0.0658, -0.0862,  0.0274,
                        -0.0062, -0.0037, -0.0680,  0.0419, -0.0163,  0.0013, -0.0145,  0.0193,
                        -0.0666, -0.0192,  0.0205, -0.0302,  0.0523, -0.0978,  0.0278,  0.0159,
                         0.0943, -0.0321,  0.0016,  0.0172,  0.0212,  0.0226, -0.0032,  0.1008,
             

TypeError: gather() received an invalid combination of arguments - got (device=str, index=Tensor, dim=int, ), but expected one of:
 * (int dim, Tensor index, *, bool sparse_grad)
 * (name dim, Tensor index, *, bool sparse_grad)


tensor([[-0.0062, -0.0394,  0.1386,  ..., -0.1019, -0.0360, -0.1167],
        [ 0.0274, -0.0405,  0.1726,  ..., -0.0881, -0.0698, -0.1460],
        [-0.0181, -0.0391,  0.1916,  ..., -0.1270, -0.0470, -0.1326],
        ...,
        [-0.0127, -0.1076,  0.1516,  ..., -0.1245, -0.0340, -0.1398],
        [ 0.0103, -0.0076,  0.1515,  ..., -0.0583, -0.0530, -0.1329],
        [ 0.0033, -0.0625,  0.1520,  ..., -0.1035, -0.0948, -0.1225]],
       device='mps:0', grad_fn=<SelectBackward0>)

tensor([[[-3.9132e-04,  1.4174e-02, -1.1057e-02,  ..., -3.3036e-01,
           2.5441e-03,  1.8339e-01],
         [-6.7852e-02,  1.3749e-01,  1.2203e-01,  ...,  1.1942e-01,
           1.2071e-01,  5.9258e-02],
         [ 6.7510e-02,  7.1115e-02, -4.2953e-02,  ..., -7.9911e-02,
           3.8432e-02,  2.6050e-02]],

        [[-3.4140e-02,  1.5122e-01, -8.7454e-03,  ..., -4.1693e-01,
          -9.7607e-03,  3.1843e-01],
         [-1.8030e-01,  2.3272e-02,  9.6851e-02,  ...,  5.2350e-02,
          -8.3729e-04,  8.9723e-02],
         [ 3.7811e-03,  1.0317e-01, -6.0901e-02,  ..., -8.1732e-02,
           2.9960e-02,  1.8591e-02]],

        [[ 6.0347e-02,  1.5281e-01,  7.3794e-02,  ..., -1.2587e-01,
           6.5444e-02, -9.8312e-02],
         [-7.1741e-02,  9.7552e-02,  4.0437e-02,  ...,  1.6383e-02,
          -2.8668e-02, -1.4494e-01],
         [ 7.4807e-02,  7.6854e-02, -2.1289e-02,  ..., -4.3670e-02,
          -2.1867e-02, -4.7954e-02]],

        ...,

        [[ 1.0538e-01,  8.1364e-02,

tensor(88)

tensor([ 36,  52, 113])