In [1]:
import pgx
import jax
import jax.numpy as jnp
from core.memory.replay_memory import BaseExperience, EpisodeReplayBuffer
from core.trees.tree import init_batched_tree
from core.evaluators.mcts.mcts import MCTSTree, MCTSNode, MCTS

env = pgx.make("othello")
BATCH_SIZE = 16
MAX_NODES = 100




replay_memory = EpisodeReplayBuffer(capacity=1000)


In [2]:
import functools
import chex
import optax
from core.evaluators.alphazero import AlphaZero
from core.evaluators.mcts.action_selection import MuZeroPUCTSelector
from core.memory.replay_memory import BaseExperience, EpisodeReplayBuffer
from core.networks.azresnet import AZResnet, AZResnetConfig
from core.training.train import extract_params
from core.training.train_2p import TwoPlayerTrainer
from flax.training.train_state import TrainState

from core.types import StepMetadata

resnet = AZResnet(AZResnetConfig(
    model_type="resnet",
    policy_head_out_size=65,
    num_blocks=2,
    num_channels=4,
))

def train_step(experience: BaseExperience, train_state: TrainState):
    def loss_fn(params: chex.ArrayTree):
        (pred_policy, pred_value), updates = train_state.apply_fn(
            {'params': params, 'batch_stats': train_state.batch_stats}, 
            x=experience.env_state.observation,
            train=True,
            mutable=['batch_stats']
        )
        pred_policy = jnp.where(
            experience.policy_mask,
            pred_policy,
            jnp.finfo(jnp.float32).min
        )
        policy_loss = optax.softmax_cross_entropy(pred_policy, experience.policy_weights).mean()
        # select appropriate value from experience.reward
        current_player = experience.env_state.current_player
        target_value = experience.reward[jnp.arange(experience.reward.shape[0]), current_player]
        value_loss = optax.l2_loss(pred_value.squeeze(), target_value).mean()

        l2_reg = 0.0001 * jax.tree_util.tree_reduce(
            lambda x, y: x + y,
            jax.tree_map(
                lambda x: (x ** 2).sum(),
                params
            )
        )

        loss = policy_loss + value_loss + l2_reg
        return loss, ((policy_loss, value_loss, pred_policy, pred_value), updates)
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, ((policy_loss, value_loss, pred_policy, pred_value), updates)), grads = grad_fn(train_state.params)
    train_state = train_state.apply_gradients(grads=grads)
    train_state = train_state.replace(batch_stats=updates['batch_stats'])
    metrics = {
        'loss': loss,
        'policy_loss': policy_loss,
        'value_loss': value_loss,
        'policy_accuracy': jnp.mean(jnp.argmax(pred_policy, axis=-1) == jnp.argmax(experience.policy_weights, axis=-1)),
        'value_accuracy': jnp.mean(jnp.round(pred_value) == jnp.round(experience.reward))
    }
    return train_state, metrics

def step_fn(state, action):
    state = env.step(state, action)
    metadata = StepMetadata(
        rewards = state.rewards,
        terminated = state.terminated,
        action_mask = state.legal_action_mask,
        cur_player_id = state.current_player
    )
    return state, metadata

def init_fn(key):
    state = env.init(key)
    metadata = StepMetadata(
        rewards = state.rewards,
        terminated = state.terminated,
        action_mask = state.legal_action_mask,
        cur_player_id = state.current_player
    )
    return state, metadata

def eval_fn(state, params):
    policy_logits, value = resnet.apply(params, state.observation[None,...], train=False)
    return jax.nn.softmax(policy_logits, axis=-1).squeeze(0), \
            value.squeeze()

trainer = TwoPlayerTrainer(
    train_batch_size = 32,
    env_step_fn = step_fn,
    env_init_fn = init_fn,
    eval_fn = eval_fn,
    train_step_fn =  train_step,
    evaluator = AlphaZero(
        max_nodes = MAX_NODES,
        branching_factor=65,
        action_selection_fn = MuZeroPUCTSelector()
    ),
    memory_buffer = EpisodeReplayBuffer(capacity=1000),
    evaluator_kwargs_train = dict(num_iterations=100),
    evaluator_kwargs_test = dict(num_iterations=100),
    wandb_project_name='aztest'
)

ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mlowrollr[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
class TrainStateWithBS(TrainState):
    batch_stats: chex.ArrayTree

sample_env_state = trainer.make_template_env_state()

variables = resnet.init(jax.random.PRNGKey(0), sample_env_state.observation[None,...], train=False)
params = variables['params']
batch_stats = variables['batch_stats']

train_state = TrainStateWithBS.create(
    apply_fn = resnet.apply,
    params = params,
    tx = optax.adam(1e-4),
    batch_stats = batch_stats
)

In [4]:
collection_state, train_state, best_params = trainer.train_loop(
    key=jax.random.PRNGKey(0),
    batch_size=BATCH_SIZE,
    train_state=train_state, 
    warmup_steps=64, 
    collection_steps_per_epoch=128,
    train_steps_per_epoch=16,
    test_episodes_per_epoch=10,
    num_epochs=10
)

Epoch 1: {'loss': Array(2.8069592, dtype=float32), 'policy_accuracy': Array(0.17578125, dtype=float32), 'policy_loss': Array(2.1975522, dtype=float32), 'value_accuracy': Array(0.17285156, dtype=float32), 'value_loss': Array(0.597967, dtype=float32)}
Epoch 1: {'batch_stats': {'BatchNorm_0': {'mean': Array([0., 0., 0., 0.], dtype=float32), 'var': Array([1., 1., 1., 1.], dtype=float32)}, 'BatchNorm_1': {'mean': Array([0., 0.], dtype=float32), 'var': Array([1., 1.], dtype=float32)}, 'BatchNorm_2': {'mean': Array([0.], dtype=float32), 'var': Array([1.], dtype=float32)}, 'ResidualBlock_0': {'BatchNorm_0': {'mean': Array([0., 0., 0., 0.], dtype=float32), 'var': Array([1., 1., 1., 1.], dtype=float32)}, 'BatchNorm_1': {'mean': Array([0., 0., 0., 0.], dtype=float32), 'var': Array([1., 1., 1., 1.], dtype=float32)}}, 'ResidualBlock_1': {'BatchNorm_0': {'mean': Array([0., 0., 0., 0.], dtype=float32), 'var': Array([1., 1., 1., 1.], dtype=float32)}, 'BatchNorm_1': {'mean': Array([0., 0., 0., 0.], dty

ApplyScopeInvalidVariablesTypeError: The first argument passed to an apply function should be a dictionary of collections. Each collection should be a dictionary with string keys. (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ApplyScopeInvalidVariablesTypeError)

User defines:
* Environment dynamics
    * env_step_fn
    * env_init_fn
* Model
    * evaluation_fn
* flax.train_state.TrainState
    * train_step_fn
    * extract_model_params_fn

In [None]:
from core.evaluators.mcts.data import tree_to_graph
graph = tree_to_graph(collection_state.eval_state, batch_id=0)
graph.render('graph', format='svg', view=True)

'graph.svg'