In [87]:
%load_ext autoreload
%autoreload 2
from alphatoe import evals, data, game, train
import torch
from torch.nn.functional import cross_entropy
from tqdm import tqdm
from transformer_lens import HookedTransformerConfig, HookedTransformer
import json
from transformer_lens import HookedTransformer, HookedTransformerConfig
from typing import Callable, Any
import einops

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [88]:
# Build the dataset: every possible game plus minimax-derived move labels,
# returned already split into train / test tensors on the GPU.
# NOTE: this takes over a minute (see the progress bar below).
train_data, train_labels, test_data, test_labels = data.gen_data(
    'minimax all',
    device='cuda',
)

Generating all possible games...
Generated 255168 games
Generated array of moves


100%|██████████| 255168/255168 [01:16<00:00, 3317.99it/s]


torch.Size([255168, 10, 10])
Generated all data and minimax labels
torch.Size([255168, 10, 10])
torch.Size([255168, 10, 10])


In [89]:
# Inspect the training-label tensor; axes are (batch, seq_len, one_hots),
# matching the rearrange pattern used in the training loop.
train_labels.size()

torch.Size([204134, 10, 10])

In [90]:
# Model architecture: a 1-layer, 4-head transformer over 10-token sequences,
# emitting a 10-way prediction at every position.
cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 4,
    d_model = 16,
    d_head = 4,          # n_heads * d_head == d_model
    d_mlp = 64,
    act_fn = "relu",
    normalization_type='LN',
    d_vocab=11,
    d_vocab_out=10,
    n_ctx=10,
    init_weights=True,
    device="cuda",
    seed = 1337,         # presumably seeds weight init inside transformer_lens — TODO confirm
)

# Optimiser hyperparameters. These equal the values the recorded training run
# was produced with (AdamW, lr=1e-4, weight_decay=1e-3).
lr = 1e-4
weight_decay = 1e-3

# NOTE(review): this is not passed to data.gen_data, and the observed split is
# 80/20 (204134 of 255168 games in train), not 0.7 — TODO reconcile or remove.
test_train_split = 0.7

epochs = 10_000
batch_size = 4096

In [91]:
# Instantiate the transformer from cfg, then move its weights to the configured device.
model = HookedTransformer(cfg)
model = model.to(cfg.device)

Moving model to device:  cuda


In [92]:
# Loss: the flattened targets have the same (N, C) shape as the flattened
# logits, which cross_entropy interprets as per-row class probabilities.
loss_fn = cross_entropy

# Read the hyperparameters from the config cell instead of hard-coding them
# here, so the values declared there are the ones actually used.
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

In [93]:
def _flatten_positions(x):
    """Merge batch and sequence axes: (batch, seq_len, n) -> (batch * seq_len, n).

    cross_entropy expects 2-D (N, C) inputs, so every position of every
    sequence is scored as an independent prediction.
    """
    return einops.rearrange(x, "batch seq_len one_hots -> (batch seq_len) one_hots")


if len(train_data) == 0:
    # Without this, the epoch-end print would reference losses that were never computed.
    raise ValueError("train_data is empty; nothing to train on")

train_losses = []  # training loss of every optimiser step
test_losses = []   # full test-set loss measured after every optimiser step

# Test labels never change during training, so flatten them once up front.
flat_test_labels = _flatten_positions(test_labels)

# TODO(review): batches are taken in the same fixed order every epoch; consider
# shuffling train_data / train_labels per epoch.
for epoch in range(epochs):
    for start in range(0, len(train_data), batch_size):
        batch_tokens = train_data[start:start + batch_size]  # model inputs, not logits
        batch_labels = train_labels[start:start + batch_size]

        train_loss = loss_fn(
            _flatten_positions(model(batch_tokens)),
            _flatten_positions(batch_labels),
        )
        train_loss.backward()
        train_losses.append(train_loss.item())
        optimizer.step()
        optimizer.zero_grad()

        # Evaluate on the whole held-out set after every step. This forward pass
        # over all of test_data is the dominant per-batch cost; evaluating once
        # per epoch would be much cheaper if per-step test curves aren't needed.
        with torch.inference_mode():
            test_logits = model(test_data)
            test_loss = loss_fn(_flatten_positions(test_logits), flat_test_labels)
            test_losses.append(test_loss.item())

    # Train loss is from the epoch's final batch; test loss is measured right after that step.
    print(f"Epoch {epoch} | Train Loss: {train_losses[-1]} | Test Loss: {test_losses[-1]}")

Epoch 0 | Train Loss: 2.4147329330444336 | Test Loss: 2.4136805534362793
Epoch 1 | Train Loss: 2.257502555847168 | Test Loss: 2.2587597370147705
Epoch 2 | Train Loss: 2.129323959350586 | Test Loss: 2.1316447257995605
Epoch 3 | Train Loss: 2.035799503326416 | Test Loss: 2.0383479595184326
Epoch 4 | Train Loss: 1.973514199256897 | Test Loss: 1.9756914377212524
Epoch 5 | Train Loss: 1.9300286769866943 | Test Loss: 1.931776523590088
Epoch 6 | Train Loss: 1.897434949874878 | Test Loss: 1.898775577545166
Epoch 7 | Train Loss: 1.8707481622695923 | Test Loss: 1.871781587600708
Epoch 8 | Train Loss: 1.8469898700714111 | Test Loss: 1.8478059768676758
Epoch 9 | Train Loss: 1.8250296115875244 | Test Loss: 1.8256806135177612
Epoch 10 | Train Loss: 1.8041012287139893 | Test Loss: 1.804573655128479
Epoch 11 | Train Loss: 1.7838408946990967 | Test Loss: 1.7841761112213135
Epoch 12 | Train Loss: 1.7641794681549072 | Test Loss: 1.764359474182129
Epoch 13 | Train Loss: 1.7447607517242432 | Test Loss: 1.7

KeyboardInterrupt: 