In [1]:
# Change directory to the root of the project.
# The chdir is guarded so that re-running this cell does not climb one level
# further each time: the project root is recognised by the `applications`
# directory that later cells load training data and modules from.
import os
if not os.path.isdir('applications'):
    os.chdir('..')
print(f"Working directory: {os.getcwd()}")

Working directory: /Users/eohjelle/Documents/2025-dots-and-boxes/dots-and-boxes


In this notebook we do single training runs of experimental transformer-based models, using training data generated with a Minimax model. It is based on `tic_tac_toe_transformer_single_run.ipynb`.

In [2]:
# Load the training data

from core.data_structures import ReplayBuffer
import random

# NOTE(review): the path ends in .pkl, so this presumably unpickles the file —
# only load buffers from a trusted source.
buffer = ReplayBuffer.from_file('applications/tic_tac_toe/training_data/transformer.pkl')

# Number of stored positions (first dimension of the state tensor).
buffer_size = buffer.states.shape[0]
print(f"Buffer size: {buffer_size}")

# Print one randomly chosen entry together with its targets and extra data,
# as a quick sanity check of what was loaded.
for sample_idx in random.sample(range(buffer_size), 1):
    sample_state = buffer.states[sample_idx]
    print(f"Buffer state {sample_idx}: {sample_state}")
    for target_name in buffer.targets.keys():
        target_entry = buffer.targets[target_name][sample_idx]
        print(f"Buffer target {target_name} {sample_idx}: {target_entry}")
    for data_name in buffer.data.keys():
        data_entry = buffer.data[data_name][sample_idx]
        print(f"Buffer data {data_name} {sample_idx}: {data_entry}")

Buffer size: 5478
Buffer state 1120: tensor([2, 0, 1, 0, 0, 0, 0, 0, 2], device='mps:0')
Buffer target policy 1120: tensor([0.0000, 0.1667, 0.0000, 0.1667, 0.1667, 0.1667, 0.1667, 0.1667, 0.0000],
       device='mps:0')
Buffer target value 1120: -1.0
Buffer data legal_actions 1120: tensor([0., 1., 0., 1., 1., 1., 1., 1., 0.], device='mps:0')


  checkpoint = torch.load(path, map_location=device)


In [3]:
from applications.tic_tac_toe.game_state import TicTacToeState
from core.implementations.Minimax import Minimax

# Create a minimax agent and expand the game tree; this is used for evaluation later on.
state = TicTacToeState()
minimax_agent = Minimax(state)
# Keep a reference to the agent's root before the agent is called, so that
# minimax_agent_factory can point the shared agent back at it.
minimax_agent_root = minimax_agent.root
# Presumably calling the agent runs the search that expands the tree — TODO confirm against Minimax.__call__.
minimax_agent()

def minimax_agent_factory() -> Minimax:
    """
    Return the shared minimax agent, reset to the root of the game tree.

    No new agent is constructed: the module-level `minimax_agent` is reused and
    only its `root` attribute is pointed back at `minimax_agent_root`.
    """
    agent = minimax_agent
    agent.root = minimax_agent_root
    return agent


In [4]:
# Define config

# Name used when labelling the saved model artifact.
model_type = 'experimental_transformer'

# Hyperparameters for a single run; passed to wandb.init and read back via run.config.
config = dict(
    # Optimizer parameters
    learning_rate=0.01,
    weight_decay=0.0001,

    # Learning rate scheduler parameters
    lr_scheduler='plateau',  # Options: 'step', 'multistep', 'exponential', 'cosine', 'plateau'
    lr_eta_min=1e-6,  # For CosineAnnealingLR
    lr_step_size=30,  # For StepLR
    lr_gamma=0.1,  # For StepLR, MultiStepLR, ExponentialLR, ReduceLROnPlateau
    lr_milestones=[30, 60, 90],  # For MultiStepLR
    lr_t_max=1000,  # For CosineAnnealingLR (usually set to total epochs)
    lr_patience=25,  # For ReduceLROnPlateau
    lr_cooldown=175,  # For ReduceLROnPlateau

    # Model parameters
    embed_dim=32,
    num_heads=4,

    # Training parameters
    epochs=1000,
    batch_size=256,
    mask_illegal_moves=False,
    mask_value=-10.0,
    eval_freq=50,
)

In [5]:
# Define simple training loop

import wandb
from applications.tic_tac_toe import TicTacToeState, TokenizedTensorMapping, TicTacToeExperimentalTransformer, ExperimentalTransformerInitParams
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, random_split
from core import benchmark, ModelInterface
from core.implementations.RandomAgent import RandomAgent
from core.implementations.AlphaZero import AlphaZeroModelAgent

def do_run(run):
    config = run.config
    device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
    model_interface = ModelInterface(
        model_architecture=TicTacToeExperimentalTransformer,
        init_params=ExperimentalTransformerInitParams(
            embed_dim=config.embed_dim,
            num_heads=config.num_heads
        ),
        device=device
    )
    model = model_interface.model

    wandb.watch(
        models=model,
        log="all",
        log_freq=20,
        log_graph=True
    )
    
    # Create optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), 
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # Create learning rate scheduler
    match config['lr_scheduler']:
        case 'step':
            scheduler = torch.optim.lr_scheduler.StepLR(
                optimizer, 
                step_size=config['lr_step_size'], 
                gamma=config['lr_gamma']
            )
        case 'multistep':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, 
                milestones=config['lr_milestones'], 
                gamma=config['lr_gamma']
            )
        case 'exponential':
            scheduler = torch.optim.lr_scheduler.ExponentialLR(
                optimizer, 
                gamma=config['lr_gamma']
                )
        case 'cosine':
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer, 
                T_max=config['lr_t_max'],
                eta_min=1e-6
            )
        case 'plateau':
            scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                optimizer, 
                mode='min', 
                factor=config['lr_gamma'], 
                patience=config['lr_patience'],
                cooldown=config['lr_cooldown']
            )
        case _:
            scheduler = None

    print(f"Using learning rate scheduler: {config['lr_scheduler']}")
    
    # Create datasets
    states = buffer.states
    policy_targets = buffer.targets['policy']
    value_targets = buffer.targets['value']
    legal_actions_mask = buffer.data['legal_actions']
    
    # Create dataset and split into train/val
    dataset = TensorDataset(states, policy_targets, value_targets, legal_actions_mask)
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    
    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False)
    
    # Training loop
    best_val_loss = float('inf')
    
    for epoch in range(1, config.epochs + 1):
        # Training phase
        model.train()
        train_losses = []
        policy_losses = []
        value_losses = []
        
        for batch in train_loader:
            states_batch, policy_targets_batch, value_targets_batch, legal_actions_batch = batch
            
            # Zero gradients
            optimizer.zero_grad()
            
            # Forward pass
            outputs = model(states_batch)
            policy_logits = outputs['policy']
            value_pred = outputs['value']
            
            # Apply mask for illegal moves if enabled
            if config.mask_illegal_moves:
                policy_logits = policy_logits * legal_actions_batch + (1 - legal_actions_batch) * config.mask_value
            
            # Compute losses
            policy_loss = F.cross_entropy(policy_logits, policy_targets_batch)
            value_loss = F.mse_loss(value_pred, value_targets_batch)
            total_loss = policy_loss + value_loss
            
            # Backward pass and optimization
            total_loss.backward()
            optimizer.step()
            
            # Track metrics
            train_losses.append(total_loss.item())
            policy_losses.append(policy_loss.item())
            value_losses.append(value_loss.item())
        
        # Validation phase
        model.eval()
        val_losses = []
        val_policy_losses = []
        val_value_losses = []
        
        with torch.no_grad():
            for batch in val_loader:
                states_batch, policy_targets_batch, value_targets_batch, legal_actions_batch = batch
                
                # Forward pass
                outputs = model(states_batch)
                policy_logits = outputs['policy']
                value_pred = outputs['value']
                
                # Apply mask for illegal moves if enabled
                if config.mask_illegal_moves:
                    policy_logits = policy_logits * legal_actions_batch + (1 - legal_actions_batch) * config.mask_value
                
                # Compute losses
                policy_loss = F.cross_entropy(policy_logits, policy_targets_batch)
                value_loss = F.mse_loss(value_pred, value_targets_batch)
                total_loss = policy_loss + value_loss
                
                # Track metrics
                val_losses.append(total_loss.item())
                val_policy_losses.append(policy_loss.item())
                val_value_losses.append(value_loss.item())
        
        # Calculate average metrics
        avg_train_loss = np.mean(train_losses)
        avg_val_loss = np.mean(val_losses)
        
        current_lr = optimizer.param_groups[0]['lr']

        # Log metrics to wandb
        wandb.log({
            "epoch": epoch,
            "train_loss": avg_train_loss,
            "train_policy_loss": np.mean(policy_losses),
            "train_value_loss": np.mean(value_losses),
            "val_loss": avg_val_loss,
            "val_policy_loss": np.mean(val_policy_losses),
            "val_value_loss": np.mean(val_value_losses),
            "learning_rate": current_lr
        }, step=epoch)
        
        print(f"Epoch {epoch}/{config.epochs}, "
                f"Train Loss: {avg_train_loss:.4f}, "
                f"Val Loss: {avg_val_loss:.4f}, "
                f"Learning rate: {current_lr:.6f}")
        
        # Update the learning rate scheduler
        if scheduler is not None:

            # Update scheduler based on its actual type
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)  # ReduceLROnPlateau needs a metric
            else:
                scheduler.step()  # All other schedulers just need step()
    
        
        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            
            # Save model checkpoint
            if not os.path.exists('checkpoints'):
                os.makedirs('checkpoints')
                
            checkpoint_path = f"checkpoints/run_{run.id}_best_model.pt"
            # model.save_checkpoint(checkpoint_path)
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': best_val_loss,
                'config': dict(config)
            }, checkpoint_path)
            
            # Save as W&B artifact
            model_artifact = wandb.Artifact(
                f"tic_tac_toe_{model_type}_model", 
                type="model",
                description=f"Best model with val_loss: {best_val_loss:.4f}"
            )
            model_artifact.add_file(checkpoint_path)
            run.log_artifact(model_artifact)

        # Evaluate against agents
        if epoch % config.eval_freq == 0 or epoch == config.epochs - 1:
            stats = benchmark(
                create_agent=lambda state: AlphaZeroModelAgent(
                    initial_state=state,
                    model=model_interface,
                    tensor_mapping=TokenizedTensorMapping(),
                    temperature=0.0
                ),
                create_opponents={
                    'random': lambda state: RandomAgent(state),
                    'minimax': lambda state: minimax_agent_factory()
                },
                initial_state=lambda: TicTacToeState(),
                num_games=100
            )
            print(f"RandomAgent score: {stats['random']['win_rate'] - stats['random']['loss_rate']}")
            print(f"Minimax draw rate: {stats['minimax']['draw_rate']}")
            wandb.log({
                'random_win_rate': stats['random']['win_rate'],
                'minimax_win_rate': stats['minimax']['win_rate'],
                'random_draw_rate': stats['random']['draw_rate'],
                'minimax_draw_rate': stats['minimax']['draw_rate'],
                'random_loss_rate': stats['random']['loss_rate'],
                'minimax_loss_rate': stats['minimax']['loss_rate'],
                'random_score': stats['random']['win_rate'] - stats['random']['loss_rate'],
                'minimax_score': stats['minimax']['win_rate'] - stats['minimax']['loss_rate']
            })

In [6]:
# Start a W&B run for this experiment (no explicit run id is given) and train with it.
wandb_init_kwargs = {
    'project': "AlphaZero-TicTacToe",
    'id': None,
    'config': config,
}
run = wandb.init(**wandb_init_kwargs)
do_run(run)

[34m[1mwandb[0m: Currently logged in as: [33meohjelle[0m ([33meigenway[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


[34m[1mwandb[0m: logging graph, to disable use `wandb.watch(log_graph=False)`


Using learning rate scheduler: plateau
Epoch 1/1000, Train Loss: 2.3656, Val Loss: 2.1696, Learning rate: 0.010000
Epoch 2/1000, Train Loss: 1.8734, Val Loss: 1.8026, Learning rate: 0.010000
Epoch 3/1000, Train Loss: 1.6459, Val Loss: 1.7066, Learning rate: 0.010000
Epoch 4/1000, Train Loss: 1.6094, Val Loss: 1.6775, Learning rate: 0.010000
Epoch 5/1000, Train Loss: 1.5573, Val Loss: 1.6491, Learning rate: 0.010000
Epoch 6/1000, Train Loss: 1.5507, Val Loss: 1.6288, Learning rate: 0.010000
Epoch 7/1000, Train Loss: 1.5300, Val Loss: 1.5749, Learning rate: 0.010000
Epoch 8/1000, Train Loss: 1.5014, Val Loss: 1.5828, Learning rate: 0.010000
Epoch 9/1000, Train Loss: 1.4854, Val Loss: 1.6003, Learning rate: 0.010000
Epoch 10/1000, Train Loss: 1.4237, Val Loss: 1.5257, Learning rate: 0.010000
Epoch 11/1000, Train Loss: 1.3720, Val Loss: 1.5042, Learning rate: 0.010000
Epoch 12/1000, Train Loss: 1.3100, Val Loss: 1.4137, Learning rate: 0.010000
Epoch 13/1000, Train Loss: 1.2422, Val Loss: 1

In [7]:
# Close the active W&B run, reporting exit code 0.
wandb.finish(exit_code=0)

0,1
epoch,▁▁▁▁▁▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇▇▇▇█
learning_rate,███████████▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
minimax_draw_rate,▁█▇█▇█▇██▇▇▇▇▇▇███▇▇█
minimax_loss_rate,█▁▂▁▂▁▂▁▁▂▂▂▂▂▂▁▁▁▂▂▁
minimax_score,▁█▇█▇█▇██▇▇▇▇▇▇███▇▇█
minimax_win_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
random_draw_rate,▆▃▄▄▃▃▃█▄▁▂▅▃▆▃▂▆▃▂▃▇
random_loss_rate,▁█▁▁▁█▁▁█▁▁▁▁▁▁▁▁█▁▁▁
random_score,▃▄▅▅▆▄▆▁▃█▇▄▆▃▆▇▃▅▇▆▂
random_win_rate,▃▅▅▅▆▅▆▁▄█▇▄▆▃▆▇▃▆▇▆▂

0,1
epoch,1000.0
learning_rate,0.0
minimax_draw_rate,1.0
minimax_loss_rate,0.0
minimax_score,0.0
minimax_win_rate,0.0
random_draw_rate,0.11
random_loss_rate,0.0
random_score,0.89
random_win_rate,0.89
