# How to turbozero 🏁

`turbozero` provides a vectorized implementation of AlphaZero. 

In a nutshell, this means we can massively speed up training by collecting many self-play games in parallel across one or more GPUs!

As the user, you just need to provide:
* environment dynamics functions (step and init) that adhere to the TurboZero spec
* a conversion function for environment state -> neural net input
* and a few hyperparameters!

TurboZero takes care of the rest. 😀 

## Environments

In order to take advantage of the batched implementation of AlphaZero, we need to pair it with a vectorized environment.

Fortunately, there are many great vectorized RL environment libraries. One I particularly like is [pgx](https://github.com/sotetsuk/pgx).

Let's use the 'othello' environment. You can see its documentation here: https://sotets.uk/pgx/othello/

In [1]:
import pgx
import jax
env = pgx.make('othello')
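
To make "vectorized" concrete, here is a minimal sketch that uses `jax.vmap` to advance a whole batch of games with a single call. The counting game (`toy_init`, `toy_step`) is a made-up stand-in for illustration only; it is not part of `pgx` or TurboZero, and a real environment would carry a much richer state.

```python
import jax
import jax.numpy as jnp

def toy_init(key):
    # Hypothetical toy game: start from a random count in {0, 1, 2}.
    return jax.random.randint(key, (), 0, 3)

def toy_step(count, action):
    # Action 0 adds 1 and action 1 adds 2; the game ends once the count reaches 5.
    new_count = count + action + 1
    return new_count, new_count >= 5

batch_size = 8
keys = jax.random.split(jax.random.PRNGKey(0), batch_size)

# vmap turns the single-game functions into batched ones.
counts = jax.vmap(toy_init)(keys)
actions = jnp.ones(batch_size, dtype=jnp.int32)
next_counts, done = jax.jit(jax.vmap(toy_step))(counts, actions)
```

The single-game functions stay simple, and batching is added from the outside, so the same code can serve one game or thousands.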

## Environment Dynamics

TurboZero needs to interface with the environment in order to build search trees and collect self-play episodes.

We can define this interface with the following functions:
* `env_step_fn`: given an environment state and an action, return the new environment state and its `StepMetadata`
```python
    EnvStepFn = Callable[[chex.ArrayTree, int], Tuple[chex.ArrayTree, StepMetadata]]
```
* `env_init_fn`: given a key, initialize and return a new environment state and its `StepMetadata`
```python
    EnvInitFn = Callable[[jax.random.PRNGKey], Tuple[chex.ArrayTree, StepMetadata]]
```
Fortunately, environment libraries implement these for us! We just need to extract a few key pieces of information
from the environment state so that we can match the TurboZero specification. We store this in a `StepMetadata` object:

In [9]:
from core.types import StepMetadata
%psource StepMetadata

@chex.dataclass(frozen=True)
class StepMetadata:
    rewards: chex.Array
    action_mask: chex.Array
    terminated: bool
    cur_player_id: int


* `rewards`: the rewards emitted for each player at the given timestep
* `action_mask`: a mask across all possible actions, where legal actions are set to `True` and illegal actions are set to `False`
* `terminated`: `True` if the environment is terminated/completed
* `cur_player_id`: id of the current player
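
To see how the four fields change over the course of a game, here is a small sketch in plain Python. Everything in it is hypothetical: `ToyState`, `ToyMetadata`, and the three-cell game (players alternate claiming cells, and whoever fills the last cell wins) are invented for illustration. They mirror the shape of the `(state, metadata)` contract but do not use `chex` or JAX arrays as TurboZero does.

```python
from typing import NamedTuple, Optional, Tuple

class ToyState(NamedTuple):
    board: Tuple[Optional[int], ...]  # None = empty, otherwise the owner's player id
    player: int                       # id of the player to move

class ToyMetadata(NamedTuple):
    # Hypothetical stand-in with the same four fields as StepMetadata.
    rewards: Tuple[float, float]
    action_mask: Tuple[bool, ...]
    terminated: bool
    cur_player_id: int

def describe(state: ToyState, rewards=(0.0, 0.0)) -> ToyMetadata:
    mask = tuple(cell is None for cell in state.board)  # empty cells are legal moves
    return ToyMetadata(rewards, mask, not any(mask), state.player)

def toy_init() -> Tuple[ToyState, ToyMetadata]:
    state = ToyState(board=(None, None, None), player=0)
    return state, describe(state)

def toy_step(state: ToyState, action: int) -> Tuple[ToyState, ToyMetadata]:
    board = list(state.board)
    board[action] = state.player
    new_state = ToyState(tuple(board), 1 - state.player)
    # Made-up rule: whoever fills the last cell wins.
    rewards = (0.0, 0.0)
    if None not in board:
        rewards = (1.0, -1.0) if state.player == 0 else (-1.0, 1.0)
    return new_state, describe(new_state, rewards)

state, meta = toy_init()
for action in (0, 1, 2):
    state, meta = toy_step(state, action)
```

After the three moves the board is full, so `meta.terminated` is `True`, every entry of `meta.action_mask` is `False`, and the rewards favor player 0, who made the last move.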

We can define the environment interface for `Othello` as follows:

In [10]:
def step_fn(state, action):
    new_state = env.step(state, action)
    return new_state, StepMetadata(
        rewards=new_state.rewards,
        action_mask=new_state.legal_action_mask,
        terminated=new_state.terminated,
        cur_player_id=new_state.current_player,
    )

def init_fn(key):
    state = env.init(key)
    return state, StepMetadata(
        rewards=state.rewards,
        action_mask=state.legal_action_mask,
        terminated=state.terminated,
        cur_player_id=state.current_player,
    )

Pretty easy!

## Neural Network

Next, we'll need to define the architecture of the neural network.

A simple implementation of the residual neural network used in the _AlphaZero_ paper is included for your convenience. 

You can implement your own architecture using `flax.linen`.

In [11]:
from core.networks.azresnet import AZResnetConfig, AZResnet

resnet = AZResnet(AZResnetConfig(
    model_type="resnet",
    policy_head_out_size=env.num_actions,
    num_blocks=2,
    num_channels=4,
))

We also need a way to convert our environment's state into something our neural network can take as input (i.e. structured data -> Array). `pgx` conveniently includes this in `state.observation`, but for other environments you may need to perform the conversion yourself.

In [None]:
def state_to_nn_input(state):
    return state.observation
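
For an environment that does not ship a ready-made observation, the conversion might look like the sketch below. It assumes a hypothetical board encoding (`1` and `-1` for the two players' pieces, `0` for empty) and a made-up helper name, `board_to_planes`; neither comes from `pgx` or TurboZero.

```python
import jax.numpy as jnp

def board_to_planes(board, cur_player):
    # board: (H, W) integer array, 1 / -1 for the two players' pieces, 0 for empty.
    # cur_player: 1 or -1, the side to move.
    mine = (board == cur_player).astype(jnp.float32)
    theirs = (board == -cur_player).astype(jnp.float32)
    # Stack into (H, W, 2) so the network always sees the board
    # from the point of view of the player to move.
    return jnp.stack([mine, theirs], axis=-1)

board = jnp.array([
    [ 1, -1, 0],
    [ 0,  1, 0],
    [-1,  0, 0],
])
planes = board_to_planes(board, cur_player=1)
```

Encoding the board relative to the current player is a common design choice for two-player games, since a single network can then play both sides.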

## Evaluator

Next, we can initialize our evaluator, AlphaZero, which takes the following parameters:

* `eval_fn`: function used to evaluate a leaf node (returns a policy and value)
* `num_iterations`: number of MCTS iterations to run before returning the final policy
* `max_nodes`: maximum capacity of search tree
* `branching_factor`: branching factor of the search tree (equal to the policy size)
* `action_selector`: the algorithm used to select an action to take at any given search node, choose between:
    * `PUCTSelector`: AlphaZero action selection algorithm
    * `MuZeroPUCTSelector`: MuZero action selection algorithm
    * or write your own! :)

There are also a few other optional parameters.
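
For intuition about what an action selector does, the sketch below scores each child of a search node with the PUCT rule from the AlphaZero paper, `Q + c * P * sqrt(N_parent) / (1 + N_child)`, and picks the highest score. It is a plain-Python illustration only: the function names and the `c_puct` default are assumptions, and TurboZero's `PUCTSelector` may differ in its details.

```python
import math

def puct_scores(q_values, priors, visit_counts, c_puct=1.0):
    # q_values: mean value estimate per child
    # priors: policy prior per child (from the neural network)
    # visit_counts: number of visits per child
    total_visits = sum(visit_counts)
    return [
        q + c_puct * p * math.sqrt(total_visits) / (1 + n)
        for q, p, n in zip(q_values, priors, visit_counts)
    ]

def select_action(q_values, priors, visit_counts, c_puct=1.0):
    scores = puct_scores(q_values, priors, visit_counts, c_puct)
    return max(range(len(scores)), key=scores.__getitem__)

# Child 0 has the best value so far but has already been visited often;
# child 1 has a strong prior and few visits, so exploration favors it.
q_values = [0.5, 0.0, 0.0]
priors = [0.2, 0.7, 0.1]
visit_counts = [8, 1, 0]
scores = puct_scores(q_values, priors, visit_counts)
best = select_action(q_values, priors, visit_counts)
```

The exploration term shrinks as a child accumulates visits, so the search gradually shifts from following the network's prior to trusting its own value estimates.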

In [6]:
from core.evaluators.alphazero import AlphaZero
from core.evaluators.evaluation_fns import make_nn_eval_fn
from core.evaluators.mcts.action_selection import PUCTSelector
from core.evaluators.mcts.mcts import MCTS

# AlphaZero can take an arbitrary search `backend`
# here we use classic MCTS
az_evaluator = AlphaZero(MCTS)(
    eval_fn = make_nn_eval_fn(resnet, state_to_nn_input),
    num_iterations = 25,
    max_nodes = 50,
    branching_factor=env.num_actions,
    action_selector = PUCTSelector()
)

## Replay Memory Buffer

Next, we'll initialize a replay memory buffer to hold self-play trajectories that we can sample from during training. This actually just defines an interface; the buffer state itself is initialized and managed internally.

The replay buffer is batched: it retains a buffer of trajectories across a batch dimension. We specify a `capacity`: the number of samples stored in a single buffer. The total capacity of the entire replay buffer is then `batch_size * capacity`.
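
The capacity arithmetic can be checked with a toy model of a batched buffer: one fixed-size queue per batch entry, each dropping its oldest sample when full. `ToyBatchedBuffer` is a hypothetical plain-Python illustration, not how `EpisodeReplayBuffer` is implemented, and the eviction order is an assumption.

```python
from collections import deque

class ToyBatchedBuffer:
    def __init__(self, batch_size, capacity):
        # One independent ring buffer per batch entry.
        self.buffers = [deque(maxlen=capacity) for _ in range(batch_size)]

    def add(self, batch_of_samples):
        # batch_of_samples holds one sample per batch entry.
        for buf, sample in zip(self.buffers, batch_of_samples):
            buf.append(sample)

    def total_capacity(self):
        return sum(buf.maxlen for buf in self.buffers)

    def __len__(self):
        return sum(len(buf) for buf in self.buffers)

buffer = ToyBatchedBuffer(batch_size=16, capacity=1000)
for step in range(1200):
    buffer.add([(env_idx, step) for env_idx in range(16)])

total = buffer.total_capacity()  # batch_size * capacity = 16 * 1000
```

After 1200 steps each per-entry buffer has discarded its 200 oldest samples, so the buffer as a whole holds exactly `total` samples.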

In [5]:
from core.memory.replay_memory import EpisodeReplayBuffer

replay_memory = EpisodeReplayBuffer(capacity=1000)

## Trainer Initialization
Now that we have all the proper pieces defined, we are ready to initialize a Trainer and start training!

The trainer will output metrics to the console, but if you'd rather visualize them, it's easy to integrate with Weights and Biases: just pass the desired project name as `wandb_project_name`.

In [7]:
from functools import partial
from core.testing.elo_approx import ApproxEloTester
from core.testing.two_player_tester import TwoPlayerTester
from core.training.loss_fns import az_default_loss_fn
from core.training.train import Trainer
import optax

trainer = Trainer(
    train_batch_size = 16,
    env_step_fn = step_fn,
    env_init_fn = init_fn,
    loss_fn = partial(az_default_loss_fn, l2_reg_lambda = 0.0001),
    nn = resnet,
    optimizer = optax.adam(1e-4),
    state_to_nn_input_fn=state_to_nn_input,
    evaluator = az_evaluator,
    testers = [ApproxEloTester(total_epochs=10, episodes_per_opponent=20, num_opponenets=5, rating_optim_steps=5000, rating_optim_lr=10000)],
    # testers = [TwoPlayerTester(num_episodes=10)], 
    memory_buffer = replay_memory,
    batch_size=16,
    warmup_steps=10,
    collection_steps_per_epoch=10,
    train_steps_per_epoch=10,
    
    # wandb_project_name = 'turbozero-othello'
)



## Training

We can start training by calling `trainer.train_loop`, which, for each of `num_epochs` epochs, will execute:
* `collection_steps_per_epoch` self-play steps, putting experience in the replay memory buffer
* `train_steps_per_epoch` training steps, sampling mini-batches of `train_batch_size` from the replay memory buffer
* `test_episodes_per_epoch` evaluation games, against the current best-performing model parameters

`warmup_steps` self-play steps are executed before the loop begins to populate replay memory with some additional samples if desired.

All self-play collection steps will be parallelized across a batch dimension of size `batch_size`.
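
As a rough back-of-the-envelope check, the sketch below counts how much work the example configuration implies. It assumes that each self-play step advances every environment in the batch by one move, and it ignores evaluation games; `count_work` is a hypothetical helper, not a TurboZero function.

```python
def count_work(num_epochs, warmup_steps, collection_steps_per_epoch,
               train_steps_per_epoch, batch_size, train_batch_size):
    # Warmup happens once, before the epoch loop starts.
    env_steps = warmup_steps * batch_size
    samples_trained = 0
    for _ in range(num_epochs):
        # Each collection step moves all `batch_size` environments forward.
        env_steps += collection_steps_per_epoch * batch_size
        # Each training step consumes one mini-batch.
        samples_trained += train_steps_per_epoch * train_batch_size
    return env_steps, samples_trained

env_steps, samples_trained = count_work(
    num_epochs=10,
    warmup_steps=10,
    collection_steps_per_epoch=10,
    train_steps_per_epoch=10,
    batch_size=16,
    train_batch_size=16,
)
```

With these values the run collects 1760 environment steps and trains on 1600 sampled positions, which is tiny by AlphaZero standards.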

These hyperparameters are just for example purposes; do not expect fantastic performance!

In [8]:
output = trainer.train_loop(key=jax.random.PRNGKey(0), num_epochs=10)

Epoch 0: {'loss': '2.0972', 'policy_loss': '1.9122', 'value_loss': '0.1739'}
Epoch 0: {'elo_rating': '0.0000'}
Epoch 1: {'loss': '2.0119', 'policy_loss': '1.8336', 'value_loss': '0.1674'}
Epoch 1: {'elo_rating': '28.6646'}
Epoch 2: {'loss': '1.9295', 'policy_loss': '1.7581', 'value_loss': '0.1604'}
Epoch 2: {'elo_rating': '0.0000'}
Epoch 3: {'loss': '1.8485', 'policy_loss': '1.6849', 'value_loss': '0.1528'}
Epoch 3: {'elo_rating': '59.4534'}
Epoch 4: {'loss': '2.2547', 'policy_loss': '2.1273', 'value_loss': '0.1167'}
Epoch 4: {'elo_rating': '92.9027'}
Epoch 5: {'loss': '2.8498', 'policy_loss': '2.0592', 'value_loss': '0.7799'}
Epoch 5: {'elo_rating': '72.7003'}
Epoch 6: {'loss': '2.7154', 'policy_loss': '1.9598', 'value_loss': '0.7449'}
Epoch 6: {'elo_rating': '64.9589'}
Epoch 7: {'loss': '2.7585', 'policy_loss': '2.0207', 'value_loss': '0.7272'}
Epoch 7: {'elo_rating': '0.0923'}
Epoch 8: {'loss': '2.5920', 'policy_loss': '1.9243', 'value_loss': '0.6571'}
Epoch 8: {'elo_rating': '56.0113'}
Epoch 9: {'loss': '2.4871', 'policy_loss': '1.9336', 'value_loss': '0.54