In [1]:
import pgx
import chex
import jax
import jax.numpy as jnp
import optax
from functools import partial

from core.memory.replay_memory import EpisodeReplayBuffer
from core.networks.azresnet import AZResnet, AZResnetConfig
from core.evaluators.alphazero import AlphaZero
from core.evaluators.mcts.weighted_mcts import WeightedMCTS
from core.evaluators.mcts.action_selection import PUCTSelector
from core.evaluators.evaluation_fns import make_nn_eval_fn
from core.testing.two_player_tester import TwoPlayerTester
from core.training.train import Trainer
from core.training.loss_fns import az_default_loss_fn
from core.types import StepMetadata

This is a demo of AlphaZero using weighted MCTS. 

If you have a Weights & Biases (wandb) account, make sure to specify a project name (`wandb_project_name` in the `Trainer` below) so you can track metrics!

The hyperparameters here are chosen for illustration only; do not assume they are tuned or correct!

Weighted MCTS: https://twitter.com/ptrschmdtnlsn/status/1748800529608888362

Implemented here: https://github.com/lowrollr/turbozero/blob/main/core/evaluators/mcts/weighted_mcts.py

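The temperature applied to child Q-values before they are propagated (weighted) to their parent is controlled by `q_temperature`, which is passed when initializing the AlphaZero evaluators below.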

For more on turbozero, see the [README](https://github.com/lowrollr/turbozero) and the
[Hello World notebook](https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb), which explains each component we set up in this notebook!


In [2]:
# Connect Four environment provided by pgx.
# pgx offers many other games as well (othello, chess, ...);
# swap the id string below to experiment with one of them.
env = pgx.make("connect_four")

In [3]:
# define environment dynamics functions
def _make_step_metadata(state):
    """Build a `StepMetadata` record from a pgx state.

    Shared by `step_fn` and `init_fn` so both always report the same fields
    (rewards, termination flag, legal-action mask, current player).
    """
    return StepMetadata(
        rewards=state.rewards,
        terminated=state.terminated,
        action_mask=state.legal_action_mask,
        cur_player_id=state.current_player,
    )


def step_fn(state, action):
    """Advance the environment by one action.

    Args:
        state: current pgx state.
        action: action to apply via `env.step`.

    Returns:
        `(next_state, metadata)`, where `metadata` describes `next_state`.
    """
    state = env.step(state, action)
    return state, _make_step_metadata(state)


def init_fn(key):
    """Create a fresh environment state.

    Args:
        key: passed straight to `env.init` (presumably a JAX PRNG key).

    Returns:
        `(state, metadata)` for the newly initialized state.
    """
    state = env.init(key)
    return state, _make_step_metadata(state)

In [4]:
# AlphaZero-style ResNet used as the evaluator network.
# Deliberately small for this demo: 4 residual blocks of 16 channels each.
resnet = AZResnet(
    AZResnetConfig(
        num_blocks=4,                          # depth: how many residual blocks are stacked
        num_channels=16,                       # width of every residual block
        policy_head_out_size=env.num_actions,  # one policy output per environment action
    )
)


In [5]:
# Replay buffer that stores self-play experience for training.
# NOTE(review): confirm whether `capacity` (300) counts experiences per
# parallel environment or across the whole self-play batch.
replay_memory = EpisodeReplayBuffer(
    capacity=300,
)

In [6]:
# define conversion fn for environment state to nn input
def state_to_nn_input(state):
    # pgx does this for us with state.observation!
    return state.observation

In [7]:
# Self-play evaluator: AlphaZero driven by weighted MCTS.
alphazero = AlphaZero(WeightedMCTS)(
    # --- network evaluation ---
    eval_fn=make_nn_eval_fn(resnet, state_to_nn_input),
    # --- search budget and tree shape ---
    num_iterations=100,                # MCTS iterations per search
    max_nodes=200,
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    # --- exploration during self-play ---
    dirichlet_alpha=0.6,
    temperature=1.0,                   # sampling temperature for the root action after search
    # --- weighted-MCTS specific ---
    q_temperature=1.0,                 # applied to child Q values before weighted propagation to the parent
)

In [8]:
# Evaluation-game evaluator: same weighted-MCTS search, but played greedily.
alphazero_test = AlphaZero(WeightedMCTS)(
    # --- network evaluation ---
    eval_fn=make_nn_eval_fn(resnet, state_to_nn_input),
    # --- search budget and tree shape ---
    num_iterations=100,
    max_nodes=200,
    branching_factor=env.num_actions,
    action_selector=PUCTSelector(),
    # --- action choice ---
    temperature=0.0,    # zero temperature: always pick the most-visited action after search
    # --- weighted-MCTS specific ---
    q_temperature=1.0,
)


In [9]:
# initialize trainer
# Uncomment `wandb_project_name` below to log metrics to Weights & Biases.

# The sizes that `train_steps_per_epoch` is derived from live in one place, so
# changing the batch size or collection length cannot silently desync the formula.
SELF_PLAY_BATCH_SIZE = 128        # number of parallel environments to collect self-play games from
TRAIN_BATCH_SIZE = 512            # training minibatch size
COLLECTION_STEPS_PER_EPOCH = 42   # self-play collection steps per epoch
WARMUP_STEPS = 42                 # presumably collection steps run before training begins
# Presumably each collection step adds one experience per parallel environment, so this
# trains on roughly as many samples per epoch as were collected (128 * 42 // 512 == 10).
TRAIN_STEPS_PER_EPOCH = (SELF_PLAY_BATCH_SIZE * COLLECTION_STEPS_PER_EPOCH) // TRAIN_BATCH_SIZE

trainer = Trainer(
    batch_size=SELF_PLAY_BATCH_SIZE,
    train_batch_size=TRAIN_BATCH_SIZE,
    warmup_steps=WARMUP_STEPS,
    collection_steps_per_epoch=COLLECTION_STEPS_PER_EPOCH,
    train_steps_per_epoch=TRAIN_STEPS_PER_EPOCH,
    nn=resnet,
    loss_fn=partial(az_default_loss_fn, l2_reg_lambda=0.0001),
    optimizer=optax.adam(5e-3),
    evaluator=alphazero,
    memory_buffer=replay_memory,
    env_step_fn=step_fn,
    env_init_fn=init_fn,
    state_to_nn_input_fn=state_to_nn_input,
    testers=[TwoPlayerTester(num_episodes=64)],
    evaluator_test=alphazero_test,
    # wandb_project_name='weighted_mcts_test'
)



In [None]:
# Run training for 20 epochs with a fixed seed.
output = trainer.train_loop(
    num_epochs=20,
    seed=0,
)