In [1]:
import pgx
import jax
import jax.numpy as jnp
from core.memory.replay_memory import BaseExperience, EpisodeReplayBuffer
from core.trees.tree import init_batched_tree
from core.evaluators.mcts.mcts import MCTSTree, MCTSNode, MCTS

env = pgx.make("othello")
BATCH_SIZE = 4
MAX_NODES = 200

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, BATCH_SIZE)
sample_env_state = env.init(key)

node = MCTSNode(
    n=jnp.array(0, dtype=jnp.int32),
    p=jnp.zeros(65, dtype=jnp.float32),
    w=jnp.array(0, dtype=jnp.float32),
    terminated=jnp.array(False, dtype=jnp.bool_),
    embedding=sample_env_state,
)

tree: MCTSTree = init_batched_tree(
    key = jax.random.PRNGKey(0),
    batch_size = BATCH_SIZE,
    max_nodes = MAX_NODES,
    branching_factor = 65,
    template_data=node
)

replay_memory = EpisodeReplayBuffer(capacity=1000)


In [2]:
import functools
import chex
import optax
from core.evaluators.alphazero import AlphaZero
from core.evaluators.mcts.action_selection import MuZeroPUCTSelector
from core.memory.replay_memory import BaseExperience, EpisodeReplayBuffer
from core.networks.azresnet import AZResnet, AZResnetConfig
from core.training.train import extract_params
from core.training.train_2p import TwoPlayerTrainer
from flax.training.train_state import TrainState

from core.types import StepMetadata

resnet = AZResnet(AZResnetConfig(
    model_type="resnet",
    policy_head_out_size=65,
    num_blocks=2,
    num_channels=4,
))

def train_step(experience: BaseExperience, train_state: TrainState):
    def loss_fn(params: chex.ArrayTree):
        (pred_policy, pred_value), updates = train_state.apply_fn(
            {'params': params, 'batch_stats': train_state.batch_stats}, 
            x=experience.env_state.observation,
            train=True,
            mutable=['batch_stats']
        )
        pred_policy = jnp.where(
            experience.policy_mask,
            pred_policy,
            0
        )
        policy_loss = optax.softmax_cross_entropy(pred_policy, experience.policy_weights).mean()
        # select appropriate value from experience.reward
        current_player = experience.env_state.current_player
        target_value = experience.reward[jnp.arange(experience.reward.shape[0]), current_player]
        value_loss = optax.l2_loss(pred_value.squeeze(), target_value).mean()

        l2_reg = 0.0001 * jax.tree_util.tree_reduce(
            lambda x, y: x + y,
            jax.tree_map(
                lambda x: (x ** 2).sum(),
                params
            )
        )

        loss = policy_loss + value_loss + l2_reg
        return loss, ((policy_loss, value_loss, pred_policy, pred_value), updates)
    
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, ((policy_loss, value_loss, pred_policy, pred_value), updates)), grads = grad_fn(train_state.params)
    train_state = train_state.apply_gradients(grads=grads)
    train_state = train_state.replace(batch_stats=updates['batch_stats'])
    metrics = {
        'loss': loss,
        'policy_loss': policy_loss,
        'value_loss': value_loss,
        'policy_accuracy': jnp.mean(jnp.argmax(pred_policy, axis=-1) == jnp.argmax(experience.policy_weights, axis=-1)),
        'value_accuracy': jnp.mean(jnp.round(pred_value) == jnp.round(experience.reward))
    }
    return train_state, metrics

def step_fn(state, action):
    state = env.step(state, action)
    metadata = StepMetadata(
        rewards = state.rewards,
        terminated = state.terminated,
        action_mask = state.legal_action_mask,
        cur_player_id = state.current_player
    )
    return state, metadata

def init_fn(key):
    state = env.init(key)
    metadata = StepMetadata(
        rewards = state.rewards,
        terminated = state.terminated,
        action_mask = state.legal_action_mask,
        cur_player_id = state.current_player
    )
    return state, metadata

def eval_fn(state, params):
    policy_logits, value = resnet.apply(params, state.observation[None,...], train=False)
    return jax.nn.softmax(policy_logits, axis=-1).squeeze(0), \
            value.squeeze()

trainer = TwoPlayerTrainer(
    train_batch_size = 16,
    env_step_fn = step_fn,
    env_init_fn = init_fn,
    eval_fn = eval_fn,
    train_step_fn =  train_step,
    evaluator = AlphaZero(
        max_nodes = MAX_NODES,
        branching_factor=65,
        action_selection_fn = MuZeroPUCTSelector()
    ),
    memory_buffer = EpisodeReplayBuffer(capacity=1000),
    template_env_state = sample_env_state,
    evaluator_kwargs_train = dict(num_iterations=400),
    evaluator_kwargs_test = dict(num_iterations=10),
    # wandb_project_name='az3'
)



In [3]:
from core.training.train import CollectionState


keys = jax.random.split(key, BATCH_SIZE)
env_embedding, metadata = jax.vmap(init_fn)(keys)

template_experience = BaseExperience(
    env_state = sample_env_state,
    policy_weights = jnp.zeros((65), dtype=jnp.float32),
    policy_mask = jnp.ones((65), dtype=jnp.bool_),
    reward = sample_env_state.rewards
)

buffer_state = replay_memory.init_batched_buffer(jax.random.PRNGKey(0), BATCH_SIZE, template_experience)

collection_state = CollectionState(
    key = keys,
    eval_state = tree,
    env_state = env_embedding,
    buffer_state = buffer_state,
    metadata = metadata
)

class TrainStateWithBS(TrainState):
    batch_stats: chex.ArrayTree

variables = resnet.init(jax.random.PRNGKey(0), env_embedding.observation[0:1], train=False)
params = variables['params']
batch_stats = variables['batch_stats']

train_state = TrainStateWithBS.create(
    apply_fn = resnet.apply,
    params = params,
    tx = optax.adam(1e-4),
    batch_stats = batch_stats
)

In [4]:
collection_state, train_state = trainer.train_loop(
    collection_state, train_state, 
    warmup_steps=32, 
    collection_steps_per_epoch=64,
    train_steps_per_epoch=4,
    test_episodes_per_epoch=16,
    num_epochs=2
)

Epoch 0: {'loss': Array(4.7744074, dtype=float32), 'policy_accuracy': Array(0.15625, dtype=float32), 'policy_loss': Array(4.1525707, dtype=float32), 'value_accuracy': Array(0.1484375, dtype=float32), 'value_loss': Array(0.61038065, dtype=float32)}
Epoch 0: {'performance_vs_best_model': Array(0.5, dtype=float32)}
Epoch 1: {'loss': Array(4.7484875, dtype=float32), 'policy_accuracy': Array(0.15625, dtype=float32), 'policy_loss': Array(4.1213527, dtype=float32), 'value_accuracy': Array(0.140625, dtype=float32), 'value_loss': Array(0.6156898, dtype=float32)}
Epoch 1: {'performance_vs_best_model': Array(0.53125, dtype=float32)}


User defines:
* Environment dynamics
    * env_step_fn
    * env_init_fn
* Model
    * evaluation_fn
* flax.train_state.TrainState
    * train_step_fn
    * extract_model_params_fn

In [5]:
from core.evaluators.mcts.data import tree_to_graph
graph = tree_to_graph(collection_state.eval_state, batch_id=0)
graph.render('graph', format='svg', view=True)

'graph.svg'