In [1]:
import pgx
import jax
import jax.numpy as jnp
from core.trees.tree import init_batched_tree
from core.evaluators.mcts.mcts import MCTSTree, MCTSNode, MCTS

# Environment and batched search-tree setup.
env = pgx.make("othello")
BATCH_SIZE = 8
MAX_NODES = 50

key = jax.random.PRNGKey(0)
# One PRNG key per batch element, used below to build the batched root states.
keys = jax.random.split(key, BATCH_SIZE)

# Unbatched state used as the template for the tree's per-node `embedding` field.
sample_env_state = env.init(key)

# Size of the action space (65 for othello), derived from the env so the
# policy vector width and the tree branching factor cannot drift from it.
NUM_ACTIONS = sample_env_state.legal_action_mask.shape[-1]

# Template node passed as `dummy_data`; presumably init_batched_tree uses it to
# infer the storage layout of every node slot — confirm in core.trees.tree.
node = MCTSNode(
    n=jnp.array(0, dtype=jnp.int32),
    p=jnp.zeros(NUM_ACTIONS, dtype=jnp.float32),
    w=jnp.array(0, dtype=jnp.float32),
    terminal=jnp.array(False, dtype=jnp.bool_),
    embedding=sample_env_state,
)

tree: MCTSTree = init_batched_tree(
    key=key,
    batch_size=BATCH_SIZE,
    max_nodes=MAX_NODES,
    branching_factor=NUM_ACTIONS,
    dummy_data=node,
)

# Batched root states: one freshly initialised game per batch element.
env_embedding = jax.vmap(env.init)(keys)

In [2]:
# define neural network
from core.networks.azresnet import AZResnetConfig, AZResnet

# A tiny residual network (2 blocks x 4 channels). The policy head must emit one
# logit per action, so its width is derived from the env's legal-action mask
# rather than hard-coded.
num_actions = sample_env_state.legal_action_mask.shape[-1]

resnet = AZResnet(AZResnetConfig(
    model_type="resnet",
    policy_head_out_size=num_actions,
    num_blocks=2,
    num_channels=4,
))

# Initialise parameters from a dummy batch holding a single observation
# (leading batch dimension of size 1).
params = resnet.init(jax.random.PRNGKey(0), jnp.zeros((1, *env.observation_shape)), train=False)

In [3]:
import functools
from core.evaluators.alphazero import AlphaZero

from core.evaluators.mcts.action_selection import MuZeroPUCTSelector, PUCTSelector


def step_fn(state, action):
    """Apply `action` to `state` via the global pgx `env`.

    Returns a 3-tuple ``(next_state, reward, terminated)``. The reward is the one
    ``next_state.rewards`` assigns to the player who was to move in `state`
    (i.e. the player who just took `action`).
    """
    # Capture the mover before stepping: current_player changes after the step.
    mover = state.current_player
    next_state = env.step(state, action)
    return next_state, next_state.rewards[mover], next_state.terminated

def eval_fn(state, params):
    """Evaluate a single env state with the global `resnet`.

    The observation is given a leading batch axis of size 1 for the network, and
    that axis is removed again from the outputs. Returns
    ``(policy_probabilities, value)``, where the policy is a softmax over the
    last axis of the network's policy logits.
    """
    batched_obs = state.observation[jnp.newaxis]
    logits, value = resnet.apply(params, batched_obs, train=False)
    probabilities = jax.nn.softmax(logits, axis=-1)
    return probabilities.squeeze(0), value.squeeze()

# AlphaZero evaluator: network params are bound now, so the evaluator's eval_fn
# takes only an env state; legal moves come from the pgx state's mask.
az = AlphaZero(
    step_fn=step_fn,
    eval_fn=functools.partial(eval_fn, params=params),
    action_selection_fn=PUCTSelector(),
    action_mask_fn=lambda env_state: env_state.legal_action_mask,
)

In [4]:
from functools import partial

# Number of MCTS iterations run per search.
# NOTE(review): this equals MAX_NODES; if every iteration allocates a new node the
# tree could fill before the search completes — confirm how az.search handles that.
NUM_ITERATIONS = 50

# vmap over the batch axis first, then jit the batched function so the whole batch
# is compiled as one program and the compiled function can be reused across calls.
batched_search = jax.jit(jax.vmap(partial(az.search, num_iterations=NUM_ITERATIONS)))
tree2 = batched_search(tree, env_embedding)

In [5]:
from core.evaluators.mcts.data import tree_to_graph
# Visualise the search tree of batch element 0.
# The SVG is still written to disk (graph.svg), but without view=True: that would
# launch an external viewer, which is not reproducible and may not work on remote or
# headless kernels. Instead the graph is the cell's last expression, so it is
# displayed inline (assumes tree_to_graph returns a graphviz object with a rich
# repr — confirm in core.evaluators.mcts.data).
graph = tree_to_graph(tree2, batch_id=0)
graph.render('graph', format='svg')
graph

'graph.svg'