In [1]:
# Change directory to the root of the project
import os

# Remember the directory the kernel started in the first time this cell runs.
# Anchoring on it keeps the cell idempotent: re-running it always lands three
# levels above the start directory instead of climbing three more levels
# from wherever the previous run left the process.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()

os.chdir(os.path.join(_NOTEBOOK_START_DIR, '..', '..', '..'))
print(f"Working directory: {os.getcwd()}")

Working directory: /Users/eohjelle/Documents/2025-dots-and-boxes/dots-and-boxes


In [2]:
from applications.tic_tac_toe.models.dynamic_mask_experimental_transformer import DynamicMaskExperimentalTransformerInitParams
import torch

# Initialize parameters

## Model parameters
# Hyperparameters for the model; forwarded to train() as `model_params`.
model_params: DynamicMaskExperimentalTransformerInitParams = {
    'embed_dim': 8,
    'num_heads': 2,
    'mask_dim': 4,
}
# Identifier of the model architecture; forwarded to train() as `model_type`.
model_type = 'dynamic_mask_experimental_transformer'
# Prefer Apple's MPS backend (the device this notebook was written for), but
# fall back to CUDA and then CPU so the notebook also runs on machines where
# MPS is not available.
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
model_name = 'tic_tac_toe_dynamic_mask_experimental_transformer'

## Initialize new model
# NOTE(review): `None` presumably tells train() to start from fresh weights
# instead of loading a saved model - confirm against train().
load_model = None
load_model_params = {}


## Optimizer parameters
# Forwarded to train() as `optimizer_type` / `optimizer_params`.
# NOTE(review): the keys look like torch.optim.Adam keyword arguments -
# confirm how train() consumes them.
optimizer_type = 'adam'
optimizer_params = dict(
    lr=1e-2,
    betas=(0.9, 0.999),
    eps=1e-8,
    weight_decay=1e-3,
    amsgrad=False,
)

## Learning scheduler parameters
# Forwarded to train() as `lr_scheduler_type` / `lr_scheduler_params`.
lr_scheduler_type = 'plateau'
lr_scheduler_params = dict(
    factor=0.5,
    patience=25,
    cooldown=50,
    min_lr=1e-6,
)

## Training parameters
# Forwarded to train() as `training_method`, `trainer_params` and
# `training_params` respectively.
training_method = 'supervised'
trainer_params = dict()
training_params = dict(
    epochs=1000,
    batch_size=256,
    eval_freq=25,
    checkpoint_freq=50,
    mask_illegal_moves=False,
    mask_value=-20.0,  # Doesn't matter when mask_illegal_moves is False
    checkpoint_dir='checkpoints',
    start_at=1,
)

## Load training data from wandb
# Forwarded to train() as `load_replay_buffer` / `load_replay_buffer_params`.
# NOTE(review): 'latest' means the data used can change between runs; pin a
# concrete artifact version if exact reproducibility matters.
load_replay_buffer = 'from_wandb'
load_replay_buffer_params = dict(
    project='AlphaZero-TicTacToe',
    artifact_name='tic_tac_toe_TokenizedTensorMapping_training_data',
    artifact_version='latest',
)


In [3]:
# Initialize wandb run
import wandb

# Human-readable label and free-text note attached to this run.
run_name = 'Adaptive Mask Experimental Transformer 4'
notes = 'Same as last time, higher weight decay.'

# Settings recorded with the run: model, optimizer, scheduler and training
# configuration defined in the parameter cell.
config = dict(
    model_type=model_type,
    model_params=model_params,
    optimizer_type=optimizer_type,
    optimizer_params=optimizer_params,
    lr_scheduler_type=lr_scheduler_type,
    lr_scheduler_params=lr_scheduler_params,
    training_method=training_method,
    trainer_params=trainer_params,
    training_params=training_params,
)

# `run` is passed to train() as `wandb_run` and closed later with run.finish().
run = wandb.init(project='AlphaZero-TicTacToe', name=run_name, notes=notes, config=config)

[34m[1mwandb[0m: Currently logged in as: [33meohjelle[0m ([33meigenway[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [4]:
# Perform training

from applications.tic_tac_toe.train import train

# Every argument is one of the settings defined in the cells above; they are
# grouped here by concern. The returned object exposes the model as `.model`.
model_interface = train(
    # Model
    model_type=model_type,
    model_params=model_params,
    model_name=model_name,
    device=device,
    load_model=load_model,
    load_model_params=load_model_params,
    # Optimizer and learning-rate scheduler
    optimizer_type=optimizer_type,
    optimizer_params=optimizer_params,
    lr_scheduler_type=lr_scheduler_type,
    lr_scheduler_params=lr_scheduler_params,
    # Training loop
    training_method=training_method,
    trainer_params=trainer_params,
    training_params=training_params,
    # Training data
    load_replay_buffer=load_replay_buffer,
    load_replay_buffer_params=load_replay_buffer_params,
    # Experiment tracking
    wandb_run=run,
)

[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`
[34m[1mwandb[0m:   1 of 1 files downloaded.  
  checkpoint = torch.load(path, map_location=device)


Epoch 1/1000, Train Loss: 1.9113, Val Loss: 1.7086, Learning rate: 0.010000
Epoch 2/1000, Train Loss: 1.6366, Val Loss: 1.6754, Learning rate: 0.010000
Epoch 3/1000, Train Loss: 1.5669, Val Loss: 1.5858, Learning rate: 0.010000
Epoch 4/1000, Train Loss: 1.4707, Val Loss: 1.4435, Learning rate: 0.010000
Epoch 5/1000, Train Loss: 1.2776, Val Loss: 1.2102, Learning rate: 0.010000
Epoch 6/1000, Train Loss: 1.0804, Val Loss: 1.0452, Learning rate: 0.010000
Epoch 7/1000, Train Loss: 0.9826, Val Loss: 0.9035, Learning rate: 0.010000
Epoch 8/1000, Train Loss: 0.8892, Val Loss: 0.8304, Learning rate: 0.010000
Epoch 9/1000, Train Loss: 0.8395, Val Loss: 0.8383, Learning rate: 0.010000
Epoch 10/1000, Train Loss: 0.8338, Val Loss: 0.8186, Learning rate: 0.010000
Epoch 11/1000, Train Loss: 0.8147, Val Loss: 0.8156, Learning rate: 0.010000
Epoch 12/1000, Train Loss: 0.8123, Val Loss: 0.8298, Learning rate: 0.010000
Epoch 13/1000, Train Loss: 0.8262, Val Loss: 0.7948, Learning rate: 0.010000
Epoch 14

In [5]:
# Finish the wandb run held in `run` (the object returned by wandb.init()
# earlier in this notebook).
run.finish()

0,1
MCTS_draw_rate,▄▄▄▄▄▂▃▁▂▃▄▄▆▄▃▇▇▄▇▇▇▇█▇▇▇█▇██▇█▇▇▇▇▇█▇▇
MCTS_loss_rate,▃▄▃▃▃▅▅▇▇▇▅▅▄▆█▂▁▅▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁
MCTS_score,█▇███▇▆▅▃▃▄▄▄▄▁▇█▄▇▇▇▇▆▆▆▆▆▇▆▆▇▆▇▆▇▇▇▆▇▇
MCTS_win_rate,▇▆▇▇▆█▆▇▅▄▃▃▂▃▂▃▄▄▃▃▃▂▁▂▂▂▂▃▁▂▂▁▃▂▂▃▂▁▃▃
Minimax_draw_rate,▅▅▅▄▄▅▁▄▂▄▃▆▇▄▄▇▇▇█▆██▇▇▇█▇██▇▇▇▇█▇▇▇▇█▇
Minimax_loss_rate,▄▄▄▅▅▄█▅▇▅▆▃▂▅▅▂▂▂▁▃▁▁▂▂▂▁▂▁▁▂▂▂▂▁▂▂▂▂▁▂
Minimax_score,▅▅▅▄▄▅▁▄▂▄▃▆▇▄▄▇▇▇█▆██▇▇▇█▇██▇▇▇▇█▇▇▇▇█▇
Minimax_win_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
RandomAgent_draw_rate,▃▃▄▄▄▅▂▇▄▅▅▃▃▂█▅▅▃▃▃▄▄▃▄▃▃▃▃▅▃▅▂▅▁▄▄▄▄▄▂
RandomAgent_loss_rate,▂▃▁▂▅▅▅▆▆▃▂█▆▃▅▂▃█▁▁▁▁▁▅▂▁▁▁▁▂▂▁▁▂▂▁▁▁▂▂

0,1
MCTS_draw_rate,0.82
MCTS_loss_rate,0.01
MCTS_score,0.16
MCTS_win_rate,0.17
Minimax_draw_rate,0.94
Minimax_loss_rate,0.06
Minimax_score,-0.06
Minimax_win_rate,0.0
RandomAgent_draw_rate,0.03
RandomAgent_loss_rate,0.01


In [6]:
def print_model_parameters(model):
    """
    Print the total number of parameters in a PyTorch model,
    with a breakdown of trainable vs non-trainable parameters,
    followed by the parameter count of every named parameter.

    A model without any parameters is reported with 0.00% shares instead
    of raising ZeroDivisionError.

    Args:
        model: PyTorch model. Only `parameters()` and `named_parameters()`
            are used, and from each parameter only `numel()` and
            `requires_grad`.

    Returns:
        None. The report is written to standard output.
    """
    # Single pass over the parameters, accumulating both totals.
    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count
    non_trainable_params = total_params - trainable_params

    def _share(part):
        # Fraction of all parameters; defined as 0 for a parameter-free model.
        return part / total_params if total_params else 0.0

    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,} ({_share(trainable_params):.2%})")
    print(f"Non-trainable parameters: {non_trainable_params:,} ({_share(non_trainable_params):.2%})")

    # Optional: Print parameters by layer
    print("\nParameters by layer:")
    for name, param in model.named_parameters():
        print(f"{name}: {param.numel():,} parameters")

# Example usage: report parameter counts for `model_interface.model`, where
# `model_interface` is the value returned by train() earlier in this notebook.
print_model_parameters(model_interface.model)

Total parameters: 2,080
Trainable parameters: 2,080 (100.00%)
Non-trainable parameters: 0 (0.00%)

Parameters by layer:
input_embedding.weight: 24 parameters
transformer_block.mask_layer.0.weight: 288 parameters
transformer_block.mask_layer.0.bias: 4 parameters
transformer_block.mask_layer.1.weight: 648 parameters
transformer_block.mask_layer.1.bias: 162 parameters
transformer_block.attn.q_emb.weight: 64 parameters
transformer_block.attn.q_emb.bias: 8 parameters
transformer_block.attn.k_emb.weight: 64 parameters
transformer_block.attn.k_emb.bias: 8 parameters
transformer_block.attn.v_emb.weight: 64 parameters
transformer_block.attn.v_emb.bias: 8 parameters
transformer_block.ff.0.weight: 8 parameters
transformer_block.ff.0.bias: 8 parameters
transformer_block.ff.1.weight: 256 parameters
transformer_block.ff.1.bias: 32 parameters
transformer_block.ff.3.weight: 256 parameters
transformer_block.ff.3.bias: 8 parameters
transformer_block.out_emb.weight: 64 parameters
transformer_block.out_em