In [2]:
%load_ext autoreload
%autoreload 2
from alphatoe import evals, data, game, train
import torch
from torch.nn.functional import cross_entropy
from tqdm import tqdm
from transformer_lens import HookedTransformerConfig, HookedTransformer
import json
from transformer_lens import HookedTransformer, HookedTransformerConfig
from typing import Callable, Any
import einops

In [3]:
train_data, train_labels, test_data, test_labels = data.gen_data('prob all', device='cuda')

Generating all possible games...
Generated 255168 games
Generated array of moves


100%|██████████| 255168/255168 [01:01<00:00, 4150.90it/s]


torch.Size([255168, 10, 10])
Generated all data and minimax labels
torch.Size([255168, 10, 10])
torch.Size([255168, 10, 10])


In [4]:
train_labels.shape

torch.Size([204134, 10, 10])

In [5]:
stuff = {
    "experiment_name": "minimax all 8 layer big",
    "gametype": "minimax all",
    "fine_tune": null,
    "n_epochs": 1000,
    "lr": 1e-05,
    "weight_decay": 0.0001,
    "batch_size": 4096,
    "train_test_split": 0.8,
    "n_layers": 1,
    "n_heads": 8,
    "d_model": 128,
    "d_head": 16,
    "d_mlp": 512,
    "act_fn": "relu",
    "normalization_type": null,
    "device": "cuda",
    "seed": 1337,
    "save_losses": false,
    "save_checkpoints": false,
    "eval_model": true

}

NameError: name 'null' is not defined

In [6]:
cfg = HookedTransformerConfig(
    n_layers = 1,
    n_heads = 8,
    d_model = 128,
    d_head = 16,
    d_mlp = 512,
    act_fn = "relu",
    normalization_type=None,
    # normalization_type='LN',
    d_vocab=11,
    d_vocab_out=10,
    n_ctx=10,
    init_weights=True,
    device="cuda",
    seed = 1337,
)

lr = 1e-5
weight_decay = 1e-4
test_train_split = 0.7
epochs = 100
batch_size = 4096 

In [7]:
model = HookedTransformer(cfg).to(cfg.device)

Moving model to device:  cuda


In [8]:
loss_fn = cross_entropy
optimizer =  torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-3)

In [9]:
train_losses = []
test_losses = []
# for epoch in tqdm.tqdm(range(epochs)):
for epoch in range(epochs):
    for batch in range(0, len(train_data), batch_size):
        t_logits = train_data[batch:batch+batch_size]
        t_labels = train_labels[batch:batch+batch_size]
        rearranged_train_logits = einops.rearrange(model(t_logits), "batch seq_len one_hots -> (batch seq_len) one_hots")
        rearranged_train_labels = einops.rearrange(t_labels, "batch seq_len one_hots -> (batch seq_len) one_hots")
        train_loss = loss_fn(rearranged_train_logits,rearranged_train_labels)

        train_loss.backward()

        train_losses.append(train_loss.item())
        optimizer.step()
        optimizer.zero_grad()

        with torch.inference_mode():
            test_logits = model(test_data)

            rearranged_test_logits = einops.rearrange(test_logits, "batch seq_len one_hots -> (batch seq_len) one_hots")
            rearranged_test_labels = einops.rearrange(test_labels, "batch seq_len one_hots -> (batch seq_len) one_hots")

            test_loss = loss_fn(rearranged_test_logits, rearranged_test_labels)
            test_losses.append(test_loss.item())

    print(f"Epoch {epoch} | Train Loss: {train_loss.item()} | Test Loss: {test_loss.item()}")

Epoch 0 | Train Loss: 2.079690933227539 | Test Loss: 2.072047472000122
Epoch 1 | Train Loss: 1.8971009254455566 | Test Loss: 1.8926390409469604
Epoch 2 | Train Loss: 1.8152011632919312 | Test Loss: 1.811684250831604
Epoch 3 | Train Loss: 1.7094589471817017 | Test Loss: 1.705434799194336
Epoch 4 | Train Loss: 1.567064881324768 | Test Loss: 1.5633150339126587


KeyboardInterrupt: 

In [13]:
samples = evals.sample_games(model, 1, 2000)

  0%|          | 0/2000 [00:00<?, ?it/s]

100%|██████████| 2000/2000 [00:14<00:00, 139.06it/s]


In [14]:
evals.eval_model(samples)

2000it [00:00, 21860.19it/s]


{'_check_played_repeat_moves': 0.0685,
 '_check_played_after_player_victory': 0.2725,
 '_check_played_after_draw_game': 0.0,
 'inappropriate_end_state': 0.1295,
 '_check_if_illegal_moves': 0.45}

In [6]:
minimax_all_8_layer_weights = torch.load("../scripts/models/minimax all 8 layer big-20230718-163005.pt")

In [7]:
model.load_state_dict(minimax_all_8_layer_weights)

<All keys matched successfully>

In [90]:
sample = evals.sample_games(model, 1, 20000)

  0%|          | 0/20000 [00:00<?, ?it/s]

100%|██████████| 20000/20000 [01:36<00:00, 206.89it/s]


In [130]:
evals.eval_model(sample)

0it [00:00, ?it/s]

3556it [00:00, 17980.32it/s]

[10, 5, 8, 0, 4, 6, 3, 1, 2, 7, 9]
[10, 8, 4, 2, 5, 3, 6, 0, 1, 7, 9]
[10, 0, 4, 5, 8, 1, 2, 7, 3, 6, 9]
[10, 5, 3, 2, 8, 0, 1, 4, 6, 7, 9]
[10, 4, 8, 1, 7, 6, 2, 5, 3, 0, 9]
[10, 8, 4, 6, 7, 1, 0, 5, 2, 3, 9]
[10, 4, 0, 3, 5, 7, 1, 2, 6, 8, 9]
[10, 5, 8, 0, 3, 1, 2, 6, 4, 7, 9]
[10, 2, 4, 8, 7, 1, 5, 3, 6, 0, 9]
[10, 5, 8, 0, 6, 3, 7, 9, 9, 9, 9]
[10, 2, 4, 3, 0, 8, 5, 1, 7, 6, 9]
[10, 7, 1, 8, 6, 2, 5, 0, 4, 3, 9]
[10, 0, 4, 8, 7, 1, 2, 3, 6, 9, 9]
[10, 2, 4, 1, 0, 8, 6, 5, 9, 9, 9]
[10, 0, 4, 1, 2, 6, 3, 8, 5, 9, 9]
[10, 5, 3, 0, 2, 8, 4, 6, 7, 1, 9]
[10, 5, 2, 6, 4, 0, 3, 7, 8, 1, 9]
[10, 0, 4, 2, 1, 7, 3, 5, 6, 8, 9]
[10, 0, 4, 7, 6, 2, 1, 5, 8, 3, 9]
[10, 3, 0, 2, 4, 8, 5, 1, 6, 7, 9]
[10, 0, 4, 2, 1, 3, 7, 9, 9, 9, 9]
[10, 8, 4, 2, 5, 1, 3, 9, 9, 9, 9]
[10, 2, 4, 7, 8, 0, 1, 3, 6, 5, 9]
[10, 7, 6, 2, 4, 5, 8, 0, 1, 3, 9]
[10, 0, 4, 7, 3, 5, 2, 6, 8, 1, 9]
[10, 8, 4, 3, 0, 7, 6, 2, 5, 1, 9]
[10, 7, 6, 4, 1, 5, 3, 0, 8, 2, 9]
[10, 0, 4, 5, 1, 7, 6, 2, 8, 3, 9]
[10, 3, 6, 8, 4, 2, 

7217it [00:00, 17653.64it/s]

[10, 3, 6, 4, 5, 1, 7, 8, 0, 2, 9]
[10, 7, 4, 8, 6, 2, 3, 5, 9, 9, 9]
[10, 0, 4, 7, 6, 2, 1, 3, 5, 8, 9]
[10, 2, 4, 1, 0, 8, 5, 3, 7, 6, 9]
[10, 8, 4, 1, 5, 3, 0, 2, 6, 7, 9]
[10, 7, 1, 8, 6, 2, 5, 0, 4, 3, 9]
[10, 6, 4, 2, 7, 1, 0, 8, 5, 3, 9]
[10, 3, 0, 4, 5, 7, 8, 6, 2, 9, 9]
[10, 3, 0, 2, 5, 4, 6, 8, 7, 1, 9]
[10, 6, 4, 3, 0, 8, 7, 1, 5, 2, 9]
[10, 8, 4, 7, 6, 2, 5, 3, 1, 0, 9]
[10, 6, 4, 2, 3, 5, 8, 0, 7, 1, 9]
[10, 4, 8, 5, 3, 7, 1, 2, 6, 0, 9]
[10, 3, 5, 1, 0, 8, 6, 2, 7, 4, 9]
[10, 8, 4, 3, 1, 7, 6, 2, 5, 0, 9]
[10, 5, 8, 7, 4, 0, 1, 2, 6, 3, 9]
[10, 0, 4, 6, 3, 5, 7, 1, 2, 8, 9]
[10, 5, 4, 2, 8, 0, 7, 1, 9, 9, 9]
[10, 2, 4, 0, 1, 7, 6, 8, 5, 3, 9]
[10, 1, 0, 3, 4, 8, 6, 2, 5, 7, 9]
[10, 3, 5, 2, 0, 7, 4, 8, 6, 1, 9]
[10, 4, 2, 6, 8, 5, 3, 7, 1, 0, 9]
[10, 5, 4, 0, 2, 6, 3, 1, 8, 7, 9]
[10, 6, 4, 7, 8, 0, 3, 5, 1, 2, 9]
[10, 4, 8, 6, 2, 5, 1, 3, 9, 9, 9]
[10, 7, 6, 3, 1, 2, 4, 8, 5, 0, 9]
[10, 0, 4, 3, 6, 2, 8, 1, 9, 9, 9]
[10, 7, 8, 4, 1, 0, 5, 2, 6, 3, 9]
[10, 0, 4, 8, 3, 5, 

10860it [00:00, 17803.99it/s]

[10, 8, 4, 6, 7, 1, 2, 0, 3, 5, 9]
[10, 1, 2, 5, 7, 6, 3, 0, 8, 4, 9]
[10, 3, 4, 0, 6, 2, 1, 8, 7, 9, 9]
[10, 4, 0, 2, 6, 3, 7, 8, 5, 1, 9]
[10, 6, 4, 1, 5, 0, 2, 3, 9, 9, 9]
[10, 0, 4, 5, 7, 3, 8, 6, 9, 9, 9]
[10, 2, 4, 8, 5, 3, 0, 1, 6, 7, 9]
[10, 8, 4, 2, 5, 3, 6, 1, 0, 7, 9]
[10, 8, 4, 2, 5, 3, 7, 6, 1, 9, 9]
[10, 8, 4, 2, 5, 3, 6, 0, 1, 7, 9]
[10, 4, 0, 2, 6, 3, 5, 1, 7, 8, 9]
[10, 2, 4, 5, 8, 0, 1, 7, 3, 6, 9]
[10, 8, 4, 7, 6, 2, 3, 5, 9, 9, 9]
[10, 8, 4, 1, 5, 3, 0, 7, 6, 2, 9]
[10, 0, 4, 7, 8, 2, 6, 1, 9, 9, 9]
[10, 8, 4, 5, 2, 6, 7, 0, 1, 9, 9]
[10, 2, 4, 8, 5, 3, 7, 1, 6, 0, 9]
[10, 5, 8, 7, 3, 6, 4, 0, 2, 1, 9]
[10, 5, 4, 6, 2, 0, 3, 1, 8, 7, 9]
[10, 5, 3, 6, 8, 1, 4, 0, 2, 7, 9]
[10, 1, 7, 8, 2, 5, 4, 6, 3, 0, 9]
[10, 0, 4, 1, 2, 6, 3, 5, 8, 7, 9]
[10, 8, 4, 1, 0, 6, 7, 2, 3, 5, 9]
[10, 7, 6, 3, 4, 2, 1, 5, 8, 0, 9]
[10, 1, 0, 8, 6, 3, 4, 2, 5, 7, 9]
[10, 3, 0, 8, 4, 7, 6, 2, 5, 1, 9]
[10, 5, 3, 8, 2, 0, 4, 6, 7, 1, 9]
[10, 8, 4, 2, 5, 3, 1, 6, 7, 9, 9]
[10, 2, 4, 8, 5, 3, 

14516it [00:00, 17951.12it/s]

[10, 7, 8, 0, 1, 5, 4, 6, 3, 2, 9]
[10, 7, 1, 4, 8, 2, 5, 6, 9, 9, 9]
[10, 7, 8, 4, 1, 5, 6, 3, 9, 9, 9]
[10, 2, 4, 1, 0, 8, 3, 5, 9, 9, 9]
[10, 1, 2, 5, 4, 6, 0, 8, 7, 3, 9]
[10, 2, 4, 5, 8, 0, 7, 6, 1, 9, 9]
[10, 4, 6, 5, 3, 0, 8, 7, 1, 2, 9]
[10, 7, 8, 5, 4, 0, 2, 6, 3, 1, 9]
[10, 1, 4, 8, 3, 5, 6, 2, 9, 9, 9]
[10, 2, 4, 0, 1, 5, 7, 9, 9, 9, 9]
[10, 8, 4, 7, 6, 2, 5, 0, 3, 9, 9]
[10, 8, 4, 1, 0, 6, 3, 5, 7, 2, 9]
[10, 1, 2, 8, 7, 5, 6, 4, 3, 0, 9]
[10, 1, 2, 4, 7, 8, 0, 3, 5, 6, 9]
[10, 2, 4, 8, 5, 3, 7, 1, 0, 6, 9]
[10, 3, 0, 1, 7, 4, 6, 5, 9, 9, 9]
[10, 8, 4, 7, 6, 2, 1, 5, 9, 9, 9]
[10, 6, 4, 3, 0, 8, 7, 1, 2, 5, 9]
[10, 2, 4, 1, 0, 8, 5, 3, 7, 6, 9]
[10, 5, 3, 4, 0, 1, 6, 9, 9, 9, 9]
[10, 4, 8, 1, 7, 6, 2, 3, 5, 9, 9]
[10, 5, 3, 2, 8, 6, 4, 0, 1, 7, 9]
[10, 6, 4, 2, 5, 3, 7, 0, 9, 9, 9]
[10, 2, 4, 8, 5, 3, 1, 7, 0, 6, 9]
[10, 5, 4, 8, 2, 6, 7, 3, 1, 9, 9]
[10, 6, 4, 0, 3, 5, 1, 7, 8, 2, 9]
[10, 7, 4, 3, 8, 0, 6, 2, 1, 5, 9]
[10, 2, 4, 5, 8, 0, 1, 3, 7, 9, 9]
[10, 8, 4, 6, 7, 1, 

18213it [00:01, 18132.43it/s]

[10, 0, 4, 5, 1, 2, 7, 9, 9, 9, 9]
[10, 6, 4, 1, 0, 8, 7, 2, 5, 3, 9]
[10, 5, 3, 6, 2, 0, 4, 8, 7, 1, 9]
[10, 1, 2, 8, 4, 6, 7, 3, 0, 5, 9]
[10, 4, 6, 2, 8, 7, 1, 3, 5, 0, 9]
[10, 4, 2, 5, 3, 6, 8, 1, 7, 0, 9]
[10, 1, 7, 2, 0, 8, 3, 6, 4, 5, 9]
[10, 1, 4, 0, 2, 6, 3, 5, 8, 7, 9]
[10, 5, 3, 4, 0, 6, 2, 1, 7, 8, 9]
[10, 8, 4, 0, 7, 1, 2, 6, 3, 5, 9]
[10, 6, 4, 7, 8, 0, 2, 1, 5, 9, 9]
[10, 7, 8, 4, 1, 5, 3, 0, 2, 6, 9]
[10, 5, 3, 0, 2, 4, 8, 6, 1, 7, 9]
[10, 5, 4, 8, 2, 6, 7, 1, 0, 3, 9]
[10, 6, 4, 7, 8, 0, 5, 1, 3, 9, 9]
[10, 0, 4, 1, 2, 6, 5, 3, 9, 9, 9]
[10, 1, 0, 3, 4, 8, 2, 6, 7, 5, 9]
[10, 3, 6, 2, 4, 5, 8, 7, 0, 9, 9]
[10, 8, 4, 0, 1, 7, 6, 2, 3, 5, 9]
[10, 6, 4, 0, 3, 5, 2, 8, 7, 1, 9]
[10, 1, 4, 6, 0, 8, 2, 7, 9, 9, 9]
[10, 6, 4, 3, 0, 8, 7, 5, 1, 9, 9]
[10, 6, 4, 3, 0, 8, 7, 5, 1, 9, 9]
[10, 4, 8, 7, 1, 2, 6, 5, 3, 0, 9]
[10, 8, 4, 7, 6, 2, 5, 3, 0, 1, 9]
[10, 6, 4, 5, 7, 1, 0, 8, 2, 3, 9]
[10, 4, 8, 5, 3, 6, 2, 1, 7, 0, 9]
[10, 3, 0, 1, 4, 8, 7, 2, 5, 6, 9]
[10, 7, 1, 2, 8, 6, 

20000it [00:01, 17713.30it/s]

[10, 2, 4, 3, 6, 1, 0, 7, 8, 9, 9]
[10, 5, 8, 7, 1, 0, 4, 2, 3, 6, 9]
[10, 4, 8, 7, 1, 0, 3, 2, 5, 6, 9]
[10, 8, 4, 3, 7, 1, 2, 6, 0, 5, 9]
[10, 4, 2, 1, 7, 5, 3, 8, 0, 6, 9]
[10, 6, 4, 3, 0, 8, 7, 1, 2, 5, 9]
[10, 1, 2, 4, 7, 6, 8, 3, 5, 9, 9]
[10, 8, 4, 2, 5, 0, 1, 3, 7, 9, 9]
[10, 7, 4, 3, 6, 1, 2, 9, 9, 9, 9]
[10, 1, 7, 2, 0, 6, 3, 8, 4, 5, 9]
[10, 5, 4, 7, 2, 6, 0, 8, 9, 9, 9]
[10, 5, 8, 7, 1, 3, 2, 0, 4, 6, 9]
[10, 7, 4, 6, 8, 0, 5, 3, 9, 9, 9]
[10, 5, 4, 7, 2, 6, 8, 1, 0, 9, 9]
[10, 7, 1, 6, 8, 0, 3, 4, 2, 5, 9]
[10, 4, 6, 0, 8, 7, 5, 1, 9, 9, 9]
[10, 3, 0, 2, 5, 6, 4, 8, 1, 7, 9]
[10, 0, 4, 6, 3, 5, 1, 7, 8, 2, 9]
[10, 6, 4, 1, 0, 8, 7, 3, 2, 5, 9]
[10, 6, 4, 8, 7, 1, 2, 3, 0, 5, 9]
[10, 1, 7, 5, 0, 8, 2, 4, 6, 3, 9]
[10, 5, 2, 0, 4, 6, 7, 3, 9, 9, 9]
[10, 5, 3, 6, 2, 0, 4, 8, 7, 1, 9]
[10, 1, 2, 6, 7, 5, 3, 0, 4, 8, 9]
[10, 7, 1, 2, 8, 6, 4, 0, 3, 5, 9]
[10, 7, 1, 4, 2, 0, 8, 6, 5, 9, 9]
[10, 1, 0, 4, 7, 6, 2, 3, 5, 8, 9]
[10, 3, 4, 8, 1, 7, 2, 6, 9, 9, 9]
[10, 8, 4, 6, 7, 1, 




{'_check_played_repeat_moves': 0.0002,
 '_check_played_after_player_victory': 0.0011,
 '_check_played_after_draw_game': 0.0,
 'inappropriate_end_state': 0.00065,
 '_check_if_illegal_moves': 0.00195}

In [56]:
seq = [10,1,2, 5, 4, 6, 8, 0, 3, 7, 9]

In [24]:
board = game.Board()

In [54]:
board.make_move(seq[-1])
board.draw_board()

| X | X | O |
| O | O | X |
| X | X | O |


In [55]:
torch.argmax(model(torch.tensor(seq))[0,-1])

tensor(9, device='cuda:0')

In [57]:
sampled_board = [game.play_game(moves) for moves in sample]

| X | O | X |
| X | O | X |
| O | X | O |
| O | X | X |
| X | O | O |
| O | X | X |
| O | X | O |
| X | X | O |
| O | X | X |
Invalid game
| X |   |   |
| X | O | O |
| X | X | O |
| X | O | X |
| O | O | X |
| X | X | O |
| X | O | X |
| X | O | X |
| O | X | O |
| O | X | X |
| X | O | O |
| X | O | X |
| X | X | X |
| X | O | O |
| O | O | X |
| X | O | X |
| O | O | X |
| X | X | O |
| X | O | X |
| O | X | X |
| O | X | O |
| O | X | O |
| X | X | X |
| O | O | X |
Invalid game
| X | O | O |
|   | O | X |
| X | O | X |
| X | X | O |
| O | O | X |
| X | X | O |
| O | X | X |
| X | O | O |
| O | X | X |
Invalid game
| X | X | O |
|   | O | X |
| O |   |   |
| X | O | X |
| O | X | X |
| O | X | O |
Invalid game
|   | X | O |
| O | O |   |
| X | X | X |
| O | X | X |
| X | O | O |
| O | X | X |
Invalid game
|   | X | X |
| O | O | O |
| O | X | X |
| X | O | O |
| O | O | X |
| X | X | X |
| X | O | X |
| X | O | O |
| O | X | X |
Invalid game
| O | X |   |
| X | O | O |
| X | X | O 

In [58]:
drawn_boards = [board for board in sampled_board if board.game_state == game.State.DRAW]

In [60]:
not_drawn_boards = [board for board in sampled_board if board.game_state != game.State.DRAW]

In [62]:
len(not_drawn_boards)


827

In [73]:
i = 90
not_drawn_boards[i].draw_board()
print(not_drawn_boards[i].moves_played)
print(first_invalid_move(not_drawn_boards[i].moves_played))

| X | X | O |
| O | O | O |
| X | X |   |
[0, 4, 1, 2, 6, 3, 7, 5]
-1


In [67]:
def first_invalid_move(moves: list[int]) -> int:
    board = game.Board()
    for move in moves:
        try:
            board.make_move(move)
        except:
            return move
    return -1


In [134]:
hash("1234")

1596181734667720836

In [141]:
evals._check_minimax_win_rate(model, 100)



100%|██████████| 100/100 [01:00<00:00,  1.64it/s]


{'draw': 1.0}

In [8]:
t_logits = train_data
t_labels = train_labels
rearranged_train_logits = einops.rearrange(model(t_logits), "batch seq_len one_hots -> (batch seq_len) one_hots")
rearranged_train_labels = einops.rearrange(t_labels, "batch seq_len one_hots -> (batch seq_len) one_hots")
train_loss = cross_entropy(rearranged_train_logits,rearranged_train_labels)
print(train_loss)

tensor(0.7350, device='cuda:0', grad_fn=<DivBackward1>)


In [9]:

with torch.inference_mode():
    test_logits = model(test_data)

    rearranged_test_logits = einops.rearrange(test_logits, "batch seq_len one_hots -> (batch seq_len) one_hots")
    rearranged_test_labels = einops.rearrange(test_labels, "batch seq_len one_hots -> (batch seq_len) one_hots")

    test_loss = loss_fn(rearranged_test_logits, rearranged_test_labels)
print(test_loss)

tensor(0.7335, device='cuda:0')
