In [6]:
import inspect
import warnings

import torch
import wandb
from applications.tic_tac_toe.train import train

In [7]:
sweep_config = {
    'method': 'bayes',
    'metric': {
        'name': 'loss',
        'goal': 'minimize'
    },
    'parameters': {
        # Optimizer parameters
        'learning_rate': {
            'distribution': 'log_uniform_values',
            'min': 0.00001,
            'max': 0.1
        },
        'weight_decay': {
            'distribution': 'log_uniform_values',
            'min': 0.00001,
            'max': 0.1
        },

        # Model parameters
        'attention_layers': {
            'values': [1, 2, 3, 4]
        },
        'transformer_size': {
            'values': ['tiny', 'small', 'medium', 'large', 'xlarge']
        },
        'dropout': {
            'values': [0.0, 0.1, 0.01, 0.001, 0.0001]
        },
        'norm_first': {
            'values': [True, False]
        },
        'activation': {
            'values': ['relu', 'gelu']
        },

        # Trainer parameters
        'replay_buffer_max_size': {
            'value': 10000
        },
        'value_softness': {
            'distribution': 'uniform',
            'min': 0.0,
            'max': 1.0
        },

        # Training parameters
        'num_iterations': {
            'value': 50
        },
        'games_per_iteration': {
            'value': 10
        },
        'batch_size': {
            'values': [128, 256, 512, 1024]
        },
        'steps_per_iteration': {
            'value': 100
        },
        'num_simulations': {
            'values': [100]
        },
        'checkpoint_frequency': {
            'value': 20
        }
    }
}

# Transformer size mapping
transformer_size_mapping = {
    'tiny': { 'embed_dim': 4, 'num_heads': 1, 'feedforward_dim': 16 },
    'small': { 'embed_dim': 8, 'num_heads': 2, 'feedforward_dim': 32 },
    'medium': { 'embed_dim': 16, 'num_heads': 4, 'feedforward_dim': 64 },
    'large': { 'embed_dim': 32, 'num_heads': 8, 'feedforward_dim': 128 },
    'xlarge': { 'embed_dim': 64, 'num_heads': 16, 'feedforward_dim': 256 }
}


In [8]:
# Some default parameters

from core.implementations.AlphaZero import AlphaZeroConfig

# Search settings used during self-play training. The non-zero Dirichlet
# parameters and temperature 1.0 presumably inject root noise and sample moves
# proportionally (confirm against AlphaZeroConfig's consumers).
alphazero_config = AlphaZeroConfig(
    temperature=1.0,
    exploration_constant=1.0,
    dirichlet_epsilon=0.25,
    dirichlet_alpha=0.3,
)

# Search settings used when evaluating against opponents: same exploration
# constant, with the Dirichlet parameters and temperature zeroed - presumably
# no root noise and deterministic move choice (confirm).
alphazero_eval_config = AlphaZeroConfig(
    temperature=0.0,
    exploration_constant=1.0,
    dirichlet_epsilon=0.0,
    dirichlet_alpha=0.0,
)

In [9]:
from applications.tic_tac_toe.transformer_model import TicTacToeTransformerInterface

def _select_device():
    """Return the preferred torch device name: 'cuda', then 'mps', then 'cpu'."""
    if torch.cuda.is_available():
        return 'cuda'
    if torch.backends.mps.is_available():
        return 'mps'
    return 'cpu'


# Swept model options that are not covered by `transformer_size_mapping`.
_OPTIONAL_MODEL_PARAMS = ('dropout', 'norm_first', 'activation')


def _optional_model_params(run_config):
    """Return the swept optional model options the model constructor accepts.

    Args:
        run_config: the W&B run config (``run.config``) of the current trial.

    Returns:
        A dict mapping each name in ``_OPTIONAL_MODEL_PARAMS`` to its sampled
        value. A name is included only if it is present in ``run_config`` and
        is an explicit parameter of ``TicTacToeTransformerInterface``.

    Options the constructor does not declare are dropped with a warning. Without
    this, the sweep would keep sampling parameters that never reach the model,
    silently wasting search budget.
    """
    requested = {
        name: run_config[name]
        for name in _OPTIONAL_MODEL_PARAMS
        if name in run_config
    }
    declared = inspect.signature(TicTacToeTransformerInterface).parameters
    accepted = {name: value for name, value in requested.items() if name in declared}
    ignored = sorted(set(requested) - set(accepted))
    if ignored:
        warnings.warn(
            f"TicTacToeTransformerInterface does not accept {ignored}; "
            "these swept parameters do not affect the model.",
            stacklevel=2,
        )
    return accepted


def sweep_agent():
    """Run one W&B sweep trial: build the config, construct the model, train it.

    Intended as the ``function`` passed to ``wandb.agent``. W&B fills
    ``run.config`` with the hyperparameters sampled from ``sweep_config``. The
    transformer size name is expanded through ``transformer_size_mapping``.
    The swept dropout / norm_first / activation options are forwarded to the
    model when its constructor supports them.
    """
    with wandb.init(project='AlphaZero-TicTacToe') as run:
        cfg = run.config
        config = {
            'model_type': 'transformer',
            'model_params': {
                'attention_layers': cfg.attention_layers,
                **transformer_size_mapping[cfg.transformer_size],
                **_optional_model_params(cfg),
            },
            'device': _select_device(),
            'tree_search_params': alphazero_config,
            'tree_search_eval_params': alphazero_eval_config,
            'trainer_params': {
                'replay_buffer_max_size': cfg.replay_buffer_max_size,
                'value_softness': cfg.value_softness
            },
            'optimizer_params': {
                'lr': cfg.learning_rate,
                'betas': (0.9, 0.999),
                'eps': 1e-8,
                'weight_decay': cfg.weight_decay,
                'amsgrad': False
            },
            'training_params': {
                'num_iterations': cfg.num_iterations,
                'games_per_iteration': cfg.games_per_iteration,
                'batch_size': cfg.batch_size,
                'steps_per_iteration': cfg.steps_per_iteration,
                'num_simulations': cfg.num_simulations,
                'checkpoint_frequency': cfg.checkpoint_frequency
            }
        }

        model = TicTacToeTransformerInterface(
            device=config['device'],
            **config['model_params']
        )

        # Use training script
        train(
            config=config,
            model=model,
            use_wandb=True,
            wandb_watch_params={
                'watch': True,
                'log': 'all',
                'log_freq': 100,
                'log_graph': True
            },
            wandb_run=run
        )

In [10]:
# W&B destination of the sweep. The agent must look the sweep up in the same
# entity/project it was created in, so both calls share these values explicitly
# instead of relying on ambient defaults.
WANDB_ENTITY = 'eigenway'
WANDB_PROJECT = 'AlphaZero-TicTacToe'

sweep_id = wandb.sweep(
    sweep=sweep_config,
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
)

# Run 20 trials of the sweep in this kernel, one `sweep_agent` call per trial.
wandb.agent(
    sweep_id,
    function=sweep_agent,
    entity=WANDB_ENTITY,
    project=WANDB_PROJECT,
    count=20
)

Create sweep with ID: 1zl9s6kk
Sweep URL: https://wandb.ai/eigenway/AlphaZero-TicTacToe/sweeps/1zl9s6kk


[34m[1mwandb[0m: Agent Starting Run: fjyt2t1t with config:
[34m[1mwandb[0m: 	activation: relu
[34m[1mwandb[0m: 	attention_layers: 3
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	checkpoint_frequency: 20
[34m[1mwandb[0m: 	dropout: 0.001
[34m[1mwandb[0m: 	games_per_iteration: 10
[34m[1mwandb[0m: 	learning_rate: 0.02566299983451523
[34m[1mwandb[0m: 	norm_first: False
[34m[1mwandb[0m: 	num_iterations: 2
[34m[1mwandb[0m: 	num_simulations: 100
[34m[1mwandb[0m: 	replay_buffer_max_size: 10000
[34m[1mwandb[0m: 	steps_per_iteration: 100
[34m[1mwandb[0m: 	transformer_size: large
[34m[1mwandb[0m: 	value_softness: 0.07761465177084503
[34m[1mwandb[0m: 	weight_decay: 0.001376143674359226


Training model: transformer
Using device: mps

Iteration 1/2
Self-play phase...
Playing game 10/10
Generated 79 new positions
Training phase...

Iteration 1 summary:
Average loss: 2.0898
Average policy_loss: 1.9694
Average value_loss: 0.1204
Replay buffer size: 79
Time taken: 9.0s

Iteration 2/2
Self-play phase...
Playing game 10/10
Generated 85 new positions
Training phase...

Evaluating against opponents...

Evaluation results:
MCTS: Win rate = 10.00%, Draw rate = 75.00%
RandomAgent: Win rate = 80.00%, Draw rate = 15.00%

Iteration 2 summary:
Average loss: 1.2942
Average policy_loss: 1.0294
Average value_loss: 0.2649
Replay buffer size: 164
Time taken: 39.9s

Training complete! Total time: 0.0h


0,1
buffer_size,▁█
iteration_time,▁█
loss,█▁
num_games,▁▁
num_positions,▁█
policy_loss,█▁
value_loss,▁█

0,1
best_win_rate,0.9
buffer_size,164.0
iteration_time,39.86773
loss,1.29425
num_games,10.0
num_positions,85.0
policy_loss,1.02938
total_time_hours,0.01357
value_loss,0.26487


[34m[1mwandb[0m: Agent Starting Run: zcrkvkur with config:
[34m[1mwandb[0m: 	activation: relu
[34m[1mwandb[0m: 	attention_layers: 4
[34m[1mwandb[0m: 	batch_size: 256
[34m[1mwandb[0m: 	checkpoint_frequency: 20
[34m[1mwandb[0m: 	dropout: 0.0001
[34m[1mwandb[0m: 	games_per_iteration: 10
[34m[1mwandb[0m: 	learning_rate: 0.04334473458733003
[34m[1mwandb[0m: 	norm_first: True
[34m[1mwandb[0m: 	num_iterations: 2
[34m[1mwandb[0m: 	num_simulations: 100
[34m[1mwandb[0m: 	replay_buffer_max_size: 10000
[34m[1mwandb[0m: 	steps_per_iteration: 100
[34m[1mwandb[0m: 	transformer_size: medium
[34m[1mwandb[0m: 	value_softness: 0.7084836662843019
[34m[1mwandb[0m: 	weight_decay: 0.0002848402357473973


Training model: transformer
Using device: mps

Iteration 1/2
Self-play phase...
Playing game 10/10
Generated 81 new positions
Training phase...

Iteration 1 summary:
Average loss: 2.1954
Average policy_loss: 1.9364
Average value_loss: 0.2590
Replay buffer size: 81
Time taken: 14.5s

Iteration 2/2
Self-play phase...
Playing game 10/10
Generated 99 new positions
Training phase...

Evaluating against opponents...

Evaluation results:
MCTS: Win rate = 0.00%, Draw rate = 50.00%
RandomAgent: Win rate = 70.00%, Draw rate = 20.00%

Iteration 2 summary:
Average loss: 1.0499
Average policy_loss: 0.9126
Average value_loss: 0.1373
Replay buffer size: 180
Time taken: 71.0s

Training complete! Total time: 0.0h


0,1
buffer_size,▁█
iteration_time,▁█
loss,█▁
num_games,▁▁
num_positions,▁█
policy_loss,█▁
value_loss,█▁

0,1
best_win_rate,0.7
buffer_size,180.0
iteration_time,70.96079
loss,1.04986
num_games,10.0
num_positions,99.0
policy_loss,0.91258
total_time_hours,0.02373
value_loss,0.13728
