In [1]:
import pgx
import chex
import jax
import jax.numpy as jnp
import optax

from flax.training.train_state import TrainState

from core.memory.replay_memory import BaseExperience
from core.memory.replay_memory import EpisodeReplayBuffer
from core.training.train_2p import TwoPlayerTrainer
from core.networks.azresnet import AZResnet, AZResnetConfig
from core.evaluators.alphazero import AlphaZero
from core.evaluators.mcts.weighted_mcts import WeightedMCTS
from core.evaluators.mcts.action_selection import PUCTSelector
from core.types import StepMetadata

This is a demo of AlphaZero using weighted MCTS. 

Make sure to specify a Weights & Biases project name (`wandb_project_name` in the trainer setup below) if you have a wandb account and want to track metrics!

Hyperparameters are mostly for the purposes of example; do not assume they are correct!

Weighted MCTS: https://twitter.com/ptrschmdtnlsn/status/1748800529608888362

Implemented here: https://github.com/lowrollr/turbozero/blob/main/core/evaluators/mcts/weighted_mcts.py

The temperature applied to child Q-values before weighted propagation is controlled by `q_temperature` (passed to the AlphaZero initialization below).

For more on turbozero, see the [README](https://github.com/lowrollr/turbozero) and 
[Hello World notebook](https://github.com/lowrollr/turbozero/blob/main/notebooks/hello_world.ipynb)



In [2]:
# root PRNG key for the notebook (fixed seed 0); later keys are split off from it
key = jax.random.PRNGKey(0)

# build the Connect Four environment from pgx
# pgx ships many other games as well (othello, chess, etc.)
env = pgx.make("connect_four")

# define environment dynamics functions
def step_fn(state, action):
    """Apply `action` to `state` and describe the resulting position.

    Returns a `(next_state, metadata)` pair, where `metadata` is a
    `StepMetadata` filled from the post-step state's rewards, termination
    flag, legal-action mask and current player.
    """
    next_state = env.step(state, action)
    return next_state, StepMetadata(
        rewards=next_state.rewards,
        terminated=next_state.terminated,
        action_mask=next_state.legal_action_mask,
        cur_player_id=next_state.current_player,
    )

def init_fn(key):
    """Create a fresh environment state from the PRNG `key`.

    Returns an `(initial_state, metadata)` pair, where `metadata` is a
    `StepMetadata` filled from the initial state's rewards, termination
    flag, legal-action mask and current player.
    """
    initial_state = env.init(key)
    return initial_state, StepMetadata(
        rewards=initial_state.rewards,
        terminated=initial_state.terminated,
        action_mask=initial_state.legal_action_mask,
        cur_player_id=initial_state.current_player,
    )

# collection batch size: how many environments play self-play episodes in parallel
# (a GPU can usually handle far more, e.g. 2048, depending on available memory)
BATCH_SIZE = 128

# ResNet architecture: a small network with one policy output per environment action
resnet_config = AZResnetConfig(
    model_type="resnet",
    policy_head_out_size=env.num_actions,
    num_blocks=4,     # number of residual blocks
    num_channels=16,  # channels per block
)
resnet = AZResnet(resnet_config)

# replay buffer
# store 300 experiences per batch
replay_memory = EpisodeReplayBuffer(capacity=300)

# search settings shared by the self-play and the evaluation evaluators
shared_search_settings = dict(
    num_iterations=100,  # number of MCTS iterations
    max_nodes=200,
    branching_factor=env.num_actions,
    q_temperature=1.0,  # temperature applied to child Q values prior to weighted propagation to parent
)

# AlphaZero evaluator (with weighted MCTS) used during self-play
alphazero = AlphaZero(WeightedMCTS)(
    dirichlet_alpha=0.6,
    temperature=1.0,  # MCTS root action sampling temperature
    action_selection_fn=PUCTSelector(),
    **shared_search_settings,
)

# AlphaZero evaluator used during evaluation games
alphazero_test = AlphaZero(WeightedMCTS)(
    temperature=0.0,  # zero temperature: always pick the most visited action after search
    action_selection_fn=PUCTSelector(),
    **shared_search_settings,
)

# define alphazero leaf evaluation function
def eval_fn(state, params, key):
    """Evaluate a single state with the ResNet.

    A leading batch axis of size one is added to `state.observation` before
    the network is applied in inference mode (`train=False`); `params` is
    forwarded to `resnet.apply` unchanged. `key` is accepted but not used.

    Returns:
        A `(policy, value)` pair: the softmax of the policy logits with the
        batch axis removed, and the value output with all size-one axes
        squeezed away.
    """
    batched_observation = state.observation[None, ...]
    logits, value = resnet.apply(params, batched_observation, train=False)
    policy = jax.nn.softmax(logits, axis=-1).squeeze(0)
    return policy, value.squeeze()

# define a training step
# this looks scary but it's just:
# policy loss = cross entropy loss
# value loss = l2 loss
# + l2 regularization
def train_step(experience: BaseExperience, train_state: TrainState):
    """Run one gradient update on a batch of experience.

    The loss is the sum of three terms:
      * policy loss: softmax cross-entropy between the masked policy logits
        and `experience.policy_weights`
      * value loss: L2 loss between the predicted value and the reward of the
        player to move
      * L2 regularization: 1e-4 times the sum of squared parameters

    Args:
        experience: batch of experience; this function reads
            `env_state.observation`, `env_state.current_player`,
            `policy_mask`, `policy_weights` and `reward`
            (`reward` is indexed as `[sample, player]`).
        train_state: training state providing `apply_fn`, `params` and
            `batch_stats`.

    Returns:
        `(train_state, metrics)`: the state after applying the gradients and
        storing the updated batch statistics, and a dict with `loss`,
        `policy_loss`, `value_loss`, `policy_accuracy` and `value_accuracy`.
    """
    # value target: for every sample, the reward belonging to the player to move
    # (independent of the parameters, so it is computed once outside the loss and
    # reused for the value-accuracy metric)
    current_player = experience.env_state.current_player
    target_value = experience.reward[jnp.arange(experience.reward.shape[0]), current_player]

    def loss_fn(params: chex.ArrayTree):
        # forward pass in training mode; BatchNorm statistics come back in `updates`
        (pred_policy, pred_value), updates = train_state.apply_fn(
            {'params': params, 'batch_stats': train_state.batch_stats},
            x=experience.env_state.observation,
            train=True,
            mutable=['batch_stats']
        )
        # keep logits where `policy_mask` is set; everything else gets the most
        # negative float32 so it receives (effectively) zero probability
        pred_policy = jnp.where(
            experience.policy_mask,
            pred_policy,
            jnp.finfo(jnp.float32).min
        )
        policy_loss = optax.softmax_cross_entropy(pred_policy, experience.policy_weights).mean()
        value_loss = optax.l2_loss(pred_value.squeeze(), target_value).mean()

        # `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map` alias
        l2_reg = 0.0001 * jax.tree_util.tree_reduce(
            lambda x, y: x + y,
            jax.tree_util.tree_map(
                lambda x: (x ** 2).sum(),
                params
            )
        )

        loss = policy_loss + value_loss + l2_reg
        return loss, ((policy_loss, value_loss, pred_policy, pred_value), updates)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, ((policy_loss, value_loss, pred_policy, pred_value), updates)), grads = grad_fn(train_state.params)
    train_state = train_state.apply_gradients(grads=grads)
    train_state = train_state.replace(batch_stats=updates['batch_stats'])
    metrics = {
        'loss': loss,
        'policy_loss': policy_loss,
        'value_loss': value_loss,
        'policy_accuracy': jnp.mean(jnp.argmax(pred_policy, axis=-1) == jnp.argmax(experience.policy_weights, axis=-1)),
        # compare against the same per-sample target the value loss uses, not the
        # full per-player reward array (which would broadcast against the prediction)
        'value_accuracy': jnp.mean(jnp.round(pred_value.squeeze()) == jnp.round(target_value))
    }
    return train_state, metrics

# set up custom training state to handle BatchNorm 
class TrainStateWithBS(TrainState):
    """`TrainState` with one extra field, `batch_stats`, for the BatchNorm statistics.

    The training step passes `batch_stats` to the model alongside `params` and
    stores the updated statistics back on the state after each update.
    """
    # contents of the model's 'batch_stats' variable collection
    batch_stats: chex.ArrayTree

# initialize model parameters
# (the network is initialized in inference mode on an all-zeros dummy observation with batch size 1)
model_param_init_key, key = jax.random.split(key, 2)
dummy_observation = jnp.zeros((1, *env.observation_shape))
variables = resnet.init(model_param_init_key, dummy_observation, train=False)
params, batch_stats = variables['params'], variables['batch_stats']

# initialize flax training state (Adam optimizer + BatchNorm statistics)
optimizer = optax.adam(learning_rate=5e-3)
train_state = TrainStateWithBS.create(
    apply_fn=resnet.apply,
    params=params,
    tx=optimizer,
    batch_stats=batch_stats,
)

# build the two-player trainer from the pieces defined above
# pass `wandb_project_name` (commented out below) to log metrics to wandb!!
trainer = TwoPlayerTrainer(
    # environment dynamics
    env_step_fn=step_fn,
    env_init_fn=init_fn,
    # evaluators: self-play and evaluation games
    evaluator=alphazero,
    evaluator_test=alphazero_test,
    eval_fn=eval_fn,
    # training
    train_step_fn=train_step,
    train_batch_size=512,  # training minibatch size
    memory_buffer=replay_memory,
    # wandb_project_name='weighted_mcts_test'
)



In [None]:
trainer_key, key = jax.random.split(key, 2)

# 42 = max steps in a connect 4 game, so collecting 42 steps from each of the
# `BATCH_SIZE` parallel environments makes one epoch roughly `BATCH_SIZE` games
max_game_steps = 42
train_steps_per_epoch = (BATCH_SIZE * max_game_steps) // trainer.train_batch_size

# run training
output = trainer.train_loop(
    key=trainer_key,
    batch_size=BATCH_SIZE,
    train_state=train_state,
    warmup_steps=max_game_steps,  # number of self-play steps to collect during warmup (per batch)
    collection_steps_per_epoch=max_game_steps,  # number of self-play steps to collect per epoch (per batch)
    train_steps_per_epoch=train_steps_per_epoch,  # train steps per epoch
    test_episodes_per_epoch=64,  # evaluation games per epoch
    num_epochs=50,
)