# How to turbozero 🏁

`turbozero` provides a vectorized implementation of AlphaZero. As the user, you are responsible for providing:
* environment dynamics functions
* a leaf evaluation function
* initialized evaluation parameters
* a training step function

## Environment

There are many great vectorized RL environment libraries; one I particularly like is [pgx](https://github.com/sotetsuk/pgx).

Let's use the 'othello' environment. You can see its documentation here: https://sotets.uk/pgx/othello/

In [1]:
import pgx
import jax
env = pgx.make('othello')

## Environment Dynamics

Turbozero needs to interface with the environment in order to build search trees and collect episodes.
We can define this interface with the following functions:
* `env_step_fn`: given an environment state and an action, return the new environment state along with its metadata
```python
    EnvStepFn = Callable[[chex.ArrayTree, int], Tuple[chex.ArrayTree, StepMetadata]]
```
* `env_init_fn`: given a key, initialize and return a new environment state along with its metadata
```python
    EnvInitFn = Callable[[jax.random.PRNGKey], Tuple[chex.ArrayTree, StepMetadata]]
```
Fortunately, environment libraries implement these for us! We just need to extract a few key pieces of information 
from the environment state. We store this in a StepMetadata object:

In [2]:
from core.types import StepMetadata
%psource StepMetadata

@chex.dataclass(frozen=True)
class StepMetadata:
    rewards: chex.Array
    action_mask: chex.Array
    terminated: bool
    cur_player_id: int


We can define the environment interface for `Othello` as follows:

In [3]:
def step_fn(state, action):
    new_state = env.step(state, action)
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
    )

def init_fn(key):
    state = env.init(key)
    return state, StepMetadata(
        rewards=state.rewards,
        action_mask=state.legal_action_mask,
        terminated=state.terminated,
        cur_player_id=state.current_player,
    )

Pretty easy!
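
To make the two-function contract concrete without depending on pgx, here is a minimal sketch for a hypothetical single-pile Nim game. Everything in it (`ToyStepMetadata`, `NimState`, `toy_init_fn`, `toy_step_fn`) is an illustrative stand-in rather than part of turbozero, and it is written in plain Python for readability. Since self-play is vmapped, a real environment would presumably need to be expressed with JAX arrays instead.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class ToyStepMetadata:
    # Hypothetical stand-in that mirrors the four fields of StepMetadata.
    rewards: List[float]
    action_mask: List[bool]
    terminated: bool
    cur_player_id: int


@dataclass(frozen=True)
class NimState:
    stones: int
    current_player: int


def _metadata(state: NimState, rewards: List[float]) -> ToyStepMetadata:
    # Action 0 removes one stone, action 1 removes two stones.
    mask = [state.stones >= 1, state.stones >= 2]
    return ToyStepMetadata(rewards, mask, state.stones == 0, state.current_player)


def toy_init_fn(key: int) -> Tuple[NimState, ToyStepMetadata]:
    # `key` is unused because this toy game has no randomness.
    state = NimState(stones=5, current_player=0)
    return state, _metadata(state, [0.0, 0.0])


def toy_step_fn(state: NimState, action: int) -> Tuple[NimState, ToyStepMetadata]:
    mover = state.current_player
    new_state = NimState(state.stones - (action + 1), 1 - mover)
    rewards = [0.0, 0.0]
    if new_state.stones == 0:
        # Whoever takes the last stone wins.
        rewards[mover], rewards[1 - mover] = 1.0, -1.0
    return new_state, _metadata(new_state, rewards)


# Play one episode, always choosing the first legal action.
state, meta = toy_init_fn(key=0)
while not meta.terminated:
    action = meta.action_mask.index(True)
    state, meta = toy_step_fn(state, action)
```

The point of the sketch is the shape of the return values: both functions hand back the new state together with rewards, a legal-action mask, a termination flag, and the id of the player to move.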

## Leaf Evaluation

Next, we'll need to define an evaluation function that we can use to evaluate leaf nodes during Monte Carlo Tree Search. 
This function will need to produce a policy and a value for a given game state.
```python
EvalFn = Callable[[chex.ArrayTree, Params, jax.random.PRNGKey], Tuple[chex.Array, float]]
```

You could choose to implement the evaluation function however you like, but given that this project mostly focuses on AlphaZero, 
we will evaluate with a neural network!

A simple implementation of the residual neural network used in the _AlphaZero_ paper is included for your convenience:

In [4]:
from core.networks.azresnet import AZResnetConfig, AZResnet

resnet = AZResnet(AZResnetConfig(
    model_type="resnet",
    policy_head_out_size=env.num_actions,
    num_blocks=2,
    num_channels=4,
))

This network will output a policy vector whose length equals the size of our action space. In Othello, the actions are placing a piece on any of the 64 tiles, or doing nothing (passing), for a total of 64 + 1 = 65.

Next, we can define the evaluation function:

In [5]:
def eval_fn(state, params, rng_key):
    # it's important to package the environment state into a structure that can be consumed by the neural network
    # fortunately, `state.observation` is exactly what we need
    # we will vmap self-play along the batch dimension, so when defining this function
    # we need to add a dummy batch dimension to the neural network input
    # finally, set train=False: this is inference during self-play, not a training step
    policy_logits, value = resnet.apply(params, state.observation[None,...], train=False)

    # the output should not include the dummy batch dimension
    return jax.nn.softmax(policy_logits, axis=-1).squeeze(0), \
            value.squeeze()
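
Since the evaluation function can be implemented however you like, a network-free baseline is also possible. The sketch below returns a uniform policy over legal actions and a value of zero. `ToyState` and `uniform_eval_fn` are hypothetical names used only for illustration; the sketch assumes the state exposes a boolean `legal_action_mask`, as pgx states do.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyState(NamedTuple):
    # Hypothetical stand-in for an environment state.
    legal_action_mask: jnp.ndarray


def uniform_eval_fn(state, params, rng_key):
    # Same (state, params, key) -> (policy, value) shape as EvalFn;
    # `params` and `rng_key` are accepted but unused here.
    mask = state.legal_action_mask.astype(jnp.float32)
    # Guard against division by zero if no action is legal.
    policy = mask / jnp.maximum(mask.sum(), 1.0)
    value = jnp.zeros(())
    return policy, value


toy_state = ToyState(legal_action_mask=jnp.array([True, False, True, True]))
policy, value = uniform_eval_fn(toy_state, params=None, rng_key=jax.random.PRNGKey(0))
```

A baseline like this can be handy for checking that search and self-play run end to end before a network is involved.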

## Train State

Next we need to initialize a training state. This project requires using a flax `TrainState`.

The ResNet architecture uses BatchNorm, which requires some special setup and a custom TrainState class.
You can read more about incorporating BatchNorm into a flax training workflow here: https://flax.readthedocs.io/en/latest/guides/training_techniques/batch_norm.html

In [6]:
import chex
from flax.training.train_state import TrainState
import optax

class TrainStateWithBS(TrainState):
    batch_stats: chex.ArrayTree

sample_env_state = env.init(jax.random.PRNGKey(0))

variables = resnet.init(jax.random.PRNGKey(0), sample_env_state.observation[None,...], train=False)
params = variables['params']
batch_stats = variables['batch_stats']

train_state = TrainStateWithBS.create(
    apply_fn = resnet.apply,
    params = params,
    tx = optax.adam(1e-4),
    batch_stats = batch_stats
)

## Replay Memory Buffer

Next, we'll initialize a replay memory buffer to hold self-play trajectories that we can sample from during training. This actually just defines an interface; the buffer state itself will be initialized and managed internally.

The replay buffer is batched: it retains a buffer of trajectories across a batch dimension. We specify a `capacity`, the number of samples stored in a single buffer. The total capacity of the entire replay buffer is then `batch_size * capacity`.

In [7]:
from core.memory.replay_memory import EpisodeReplayBuffer

replay_memory = EpisodeReplayBuffer(capacity=1000)
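
To illustrate the `batch_size * capacity` arithmetic, here is a small plain-Python sketch of a batched buffer with one fixed-size lane per batch index. `ToyBatchedBuffer` is a hypothetical stand-in, not turbozero's `EpisodeReplayBuffer`, and it assumes the oldest samples are overwritten once a lane is full. The `batch_size=16` value matches the one passed to `train_loop` later in this walkthrough.

```python
from collections import deque


class ToyBatchedBuffer:
    """Hypothetical stand-in: one fixed-capacity FIFO buffer per batch lane."""

    def __init__(self, batch_size: int, capacity: int):
        self.capacity = capacity
        self.lanes = [deque(maxlen=capacity) for _ in range(batch_size)]

    def add(self, samples):
        # One new sample per lane, as one batched self-play step would produce.
        for lane, sample in zip(self.lanes, samples):
            lane.append(sample)

    @property
    def total_capacity(self) -> int:
        return len(self.lanes) * self.capacity

    def __len__(self) -> int:
        return sum(len(lane) for lane in self.lanes)


buffer = ToyBatchedBuffer(batch_size=16, capacity=1000)
for step in range(1200):
    buffer.add([(lane_idx, step) for lane_idx in range(16)])
```

After 1200 steps each lane holds only its most recent 1000 samples, so the buffer as a whole sits at its total capacity of 16 * 1000 samples.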

## Evaluator

Next, we can initialize our evaluator, AlphaZero, which takes the following parameters:

* `eval_fn`: function used to evaluate a leaf node (returns a policy and value)
* `num_iterations`: number of MCTS iterations to run before returning the final policy
* `max_nodes`: maximum capacity of search tree
* `branching_factor`: branching factor of search tree == policy_size
* `action_selector`: the algorithm used to select an action to take at any given search node, choose between:
    * `PUCTSelector`: AlphaZero action selection algorithm
    * `MuZeroPUCTSelector`: MuZero action selection algorithm
    * or write your own! :)

There are also a few other optional parameters.

In [8]:
from core.evaluators.alphazero import AlphaZero
from core.evaluators.mcts.action_selection import PUCTSelector
from core.evaluators.mcts.mcts import MCTS

# alphazero can take an arbitrary search `backend`
# here we use classic MCTS
az_evaluator = AlphaZero(MCTS)(
    eval_fn = eval_fn,
    num_iterations = 25,
    max_nodes = 50,
    branching_factor=env.num_actions,
    action_selector = PUCTSelector()
)
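
For intuition about what a PUCT-style action selector computes, here is a minimal plain-Python sketch of the selection rule described in the AlphaZero paper: each action is scored by its value estimate plus an exploration bonus that grows with the prior and shrinks with the action's visit count. The function name `puct_scores` and the constant `c_puct` are assumptions for illustration; turbozero's `PUCTSelector` may differ in its details.

```python
import math
from typing import List


def puct_scores(
    q_values: List[float],
    priors: List[float],
    visit_counts: List[int],
    c_puct: float = 1.0,
) -> List[float]:
    # score(a) = Q(a) + c_puct * P(a) * sqrt(sum_b N(b)) / (1 + N(a))
    sqrt_total = math.sqrt(sum(visit_counts))
    return [
        q + c_puct * p * sqrt_total / (1 + n)
        for q, p, n in zip(q_values, priors, visit_counts)
    ]


q_values = [0.5, 0.0, 0.0]
priors = [0.2, 0.7, 0.1]
visit_counts = [8, 1, 0]

scores = puct_scores(q_values, priors, visit_counts)
best_action = max(range(len(scores)), key=scores.__getitem__)
```

In this example action 1 is selected: although action 0 has the highest value estimate, it has already been visited often, while action 1 has a large prior and few visits.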

## Define a training step

Lastly, we need to define how to train our model's parameters, given data from the replay memory buffer.

The data will take on the following structure:

In [9]:
from core.memory.replay_memory import BaseExperience
%psource BaseExperience

@dataclass(frozen=True)
class BaseExperience:
    reward: chex.Array
    policy_weights: chex.Array
    policy_mask: chex.Array
    env_state: chex.ArrayTree


This example `train_step` function computes the cross-entropy loss between the target policy `policy_weights` and our predicted policy. It then computes the mean squared error between our predicted evaluation and the game's outcome `reward`.

In [10]:
import jax.numpy as jnp

def train_step(experience: BaseExperience, train_state: TrainState):
    def loss_fn(params: chex.ArrayTree):
        (pred_policy, pred_value), updates = train_state.apply_fn(
            {'params': params, 'batch_stats': train_state.batch_stats}, 
            x=experience.env_state.observation,
            train=True,
            mutable=['batch_stats']
        )
        pred_policy = jnp.where(
            experience.policy_mask,
            pred_policy,
            jnp.finfo(jnp.float32).min
        )
        policy_loss = optax.softmax_cross_entropy(pred_policy, experience.policy_weights).mean()
        # select appropriate value from experience.reward
        current_player = experience.env_state.current_player
        target_value = experience.reward[jnp.arange(experience.reward.shape[0]), current_player]
        value_loss = optax.l2_loss(pred_value.squeeze(), target_value).mean()

        l2_reg = 0.0001 * jax.tree_util.tree_reduce(
            lambda x, y: x + y,
            jax.tree_util.tree_map(
                lambda x: (x ** 2).sum(),
                params
            )
        )

        loss = policy_loss + value_loss + l2_reg
        return loss, ((policy_loss, value_loss, pred_policy, pred_value), updates)
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, ((policy_loss, value_loss, pred_policy, pred_value), updates)), grads = grad_fn(train_state.params)
    train_state = train_state.apply_gradients(grads=grads)
    train_state = train_state.replace(batch_stats=updates['batch_stats'])
    metrics = {
        'loss': loss,
        'policy_loss': policy_loss,
        'value_loss': value_loss,
        'policy_accuracy': jnp.mean(jnp.argmax(pred_policy, axis=-1) == jnp.argmax(experience.policy_weights, axis=-1)),
        'value_accuracy': jnp.mean(jnp.round(pred_value.squeeze()) == jnp.round(experience.reward[jnp.arange(experience.reward.shape[0]), experience.env_state.current_player]))
    }
    return train_state, metrics
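
The two loss terms can be checked by hand on a single sample. The sketch below re-implements them in plain Python under stated assumptions: illegal-action logits are replaced by a very negative number (standing in for `jnp.finfo(jnp.float32).min`), the policy loss is the cross-entropy against the target weights, and the value loss follows `optax.l2_loss`, which is half the squared error. The helper names are hypothetical.

```python
import math


def masked_cross_entropy(logits, mask, target):
    # Push illegal-action logits to a very negative number so softmax ignores them.
    masked = [l if m else -1e30 for l, m in zip(logits, mask)]
    top = max(masked)
    log_z = top + math.log(sum(math.exp(l - top) for l in masked))
    # Only actions with non-zero target weight contribute to the loss.
    return -sum(t * (l - log_z) for t, l in zip(target, masked) if t > 0)


def l2_loss(pred, target):
    # Matches optax.l2_loss: 0.5 * (pred - target) ** 2
    return 0.5 * (pred - target) ** 2


logits = [1.0, 5.0, 1.0]
mask = [True, False, True]   # the middle action is illegal
target = [0.5, 0.0, 0.5]     # search policy over the legal actions

policy_loss = masked_cross_entropy(logits, mask, target)
value_loss = l2_loss(pred=0.5, target=1.0)
```

Because the large logit belongs to an illegal action, it is masked out and the two legal actions split the probability evenly, giving a policy loss of ln 2. The value loss is 0.5 * 0.25 = 0.125.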

## Trainer Initialization
Now that we have all the proper pieces defined, we are ready to initialize a Trainer and start training!

The trainer will output metrics to the console, but if you'd rather visualize them, it's easy to integrate with Weights and Biases:
just pass the desired project name as `wandb_project_name`.

In [11]:
from core.testing.two_player_tester import TwoPlayerTester
from core.training.train import Trainer

trainer = Trainer(
    train_batch_size = 128,
    env_step_fn = step_fn,
    env_init_fn = init_fn,
    train_step_fn = train_step,
    evaluator = az_evaluator,
    testers = [TwoPlayerTester(num_episodes=10)],
    memory_buffer = replay_memory,
    # wandb_project_name = 'turbozero-othello'
)



## Training

We can start training by calling `trainer.train_loop`, which, for each of `num_epochs` epochs, will execute:
 * `collection_steps_per_epoch` self-play steps, putting experience in the replay memory buffer
 * `train_steps_per_epoch` training steps, sampling mini-batches of `train_batch_size` from the replay memory buffer
 * `test_episodes_per_epoch` evaluation games, against the current best-performing model parameters

Before the loop begins, `warmup_steps` self-play steps are executed to pre-populate replay memory with additional samples, if desired.

All self-play collection steps will be parallelized across a batch dimension of size `batch_size`.

These hyperparameters are just for example purposes, so do not expect fantastic performance!

In [None]:
output = trainer.train_loop(
    key=jax.random.PRNGKey(0),
    batch_size=16,
    train_state=train_state, 
    warmup_steps=64, 
    collection_steps_per_epoch=64,
    train_steps_per_epoch=16,
    num_epochs=10
)
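
As a rough sanity check on these hyperparameters, the sketch below works out how many samples are collected and consumed. It assumes each self-play step contributes one sample per batch lane, which should be read as an upper bound rather than a statement about turbozero's internals.

```python
batch_size = 16
warmup_steps = 64
collection_steps_per_epoch = 64
train_steps_per_epoch = 16
train_batch_size = 128
num_epochs = 10

# Samples gathered before the first epoch (assumption: one per lane per step).
warmup_samples = batch_size * warmup_steps

# Samples gathered by self-play in each epoch.
collected_per_epoch = batch_size * collection_steps_per_epoch

# Samples drawn from replay memory for training in each epoch.
trained_per_epoch = train_steps_per_epoch * train_batch_size

# Total self-play samples over the whole run.
total_collected = warmup_samples + num_epochs * collected_per_epoch

# How many samples are trained on per newly collected sample.
reuse_ratio = trained_per_epoch / collected_per_epoch
```

Under this assumption, each epoch collects 1024 new samples and trains on 2048, so on average every collected sample is drawn about twice.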

I'll be adding more evaluation features soon (suggestions welcome!)