In [1]:
import pgx
import jax
import jax.numpy as jnp
from core.trees.tree import init_batched_tree
from core.evaluators.mcts.mcts import MCTSTree, MCTSNode, MCTS

# --- Environment, search-tree and initial-state setup ----------------------
env = pgx.make("othello")
BATCH_SIZE = 1   # number of environments / trees stepped in parallel
MAX_NODES = 200  # node capacity passed to init_batched_tree
# Action-space size, read from the environment instead of repeating a literal
# (65 for Othello, the value the rest of this notebook hard-codes).
NUM_ACTIONS = env.num_actions

# JAX PRNG keys must not be consumed more than once: derive one independent
# sub-key per consumer instead of reusing the same key for every call.
key = jax.random.PRNGKey(0)
key, sample_key, tree_key, env_key = jax.random.split(key, 4)
keys = jax.random.split(env_key, BATCH_SIZE)

# Un-batched environment state, used only as a structural template below.
sample_env_state = env.init(sample_key)

# Template node: its fields are handed to init_batched_tree as `template_data`.
node = MCTSNode(
    n=jnp.array(0, dtype=jnp.int32),
    p=jnp.zeros(NUM_ACTIONS, dtype=jnp.float32),
    w=jnp.array(0, dtype=jnp.float32),
    terminal=jnp.array(False, dtype=jnp.bool_),
    embedding=sample_env_state,
)

tree: MCTSTree = init_batched_tree(
    key=tree_key,
    batch_size=BATCH_SIZE,
    max_nodes=MAX_NODES,
    branching_factor=NUM_ACTIONS,
    template_data=node,
)

# Batched initial environment states, one per entry of `keys`.
env_embedding = jax.vmap(env.init)(keys)

In [2]:
# define neural network
from core.networks.azresnet import AZResnetConfig, AZResnet

# Network configuration: two residual blocks with four channels each, and a
# policy head with 65 outputs.
resnet_config = AZResnetConfig(
    model_type="resnet",
    policy_head_out_size=65,
    num_blocks=2,
    num_channels=4,
)
resnet = AZResnet(resnet_config)

# Initialise parameters by tracing the network on an all-zero observation with
# a leading batch dimension of 1.
dummy_observation = jnp.zeros((1, *env.observation_shape))
params = resnet.init(jax.random.PRNGKey(0), dummy_observation, train=False)

In [3]:
import functools
from core.evaluators.alphazero import AlphaZero

from core.evaluators.mcts.action_selection import MuZeroPUCTSelector, PUCTSelector


def step_fn(state, action):
    """Advance the environment by one action.

    Returns a tuple ``(next_state, reward, terminated)``. The reward is read
    from ``next_state.rewards`` at the index of ``next_state.current_player``,
    i.e. the current player of the state *after* the step.
    """
    next_state = env.step(state, action)
    reward = next_state.rewards[next_state.current_player]
    return next_state, reward, next_state.terminated

def eval_fn(state, params):
    """Evaluate a single environment state with the network.

    A leading batch axis is added to ``state.observation`` before calling
    ``resnet.apply`` in inference mode (``train=False``).

    Returns a tuple ``(policy, value)``: the softmax of the policy logits over
    the last axis with the batch axis removed, and the fully squeezed value
    output.
    """
    batched_observation = state.observation[None, ...]
    policy_logits, value = resnet.apply(params, batched_observation, train=False)
    policy = jax.nn.softmax(policy_logits, axis=-1)
    return policy.squeeze(0), value.squeeze()

def _legal_action_mask(env_state):
    """Return the legal-action mask stored on an environment state."""
    return env_state.legal_action_mask


# Evaluator wiring: environment step, network evaluation (with the current
# `params` bound in), action selection and legal-move masking.
az = AlphaZero(
    step_fn=step_fn,
    eval_fn=functools.partial(eval_fn, params=params),
    action_selection_fn=MuZeroPUCTSelector(),
    action_mask_fn=_legal_action_mask,
)

In [4]:
from functools import partial
from chex import dataclass
import chex
from core.trees.tree import get_subtree

@dataclass(frozen=True)
class Experience:
    """One replay-memory record: an environment state plus the search's action weights."""
    # Environment state pytree; one_step stores the state the search was run from.
    env_embedding: chex.ArrayTree
    # Action weights from the search output; the template below uses a float32 vector of length 65.
    action_weights: chex.Array

# Structural template for replay-memory entries: a sample environment state
# and an all-zero float32 vector of 65 action weights.
template_experience = Experience(
    env_embedding=sample_env_state,
    action_weights=jnp.zeros(65, dtype=jnp.float32),
)


from core.memory.replay_memory import EpisodeReplayBuffer
# Replay buffer constructed with capacity=1000; the batched buffer is
# initialised for BATCH_SIZE entries from `template_experience`.
replay_memory = EpisodeReplayBuffer(capacity=1000)
memory_buffer = replay_memory.init_batched_buffer(jax.random.PRNGKey(0), BATCH_SIZE, template_experience)

# JIT-compiled entry points called from one_step.
# The search is bound to 200 iterations per call.
search = jax.jit(partial(az.search, num_iterations=200))
# NOTE: this rebinds the imported `get_subtree` name to its jitted wrapper.
get_subtree = jax.jit(get_subtree)
add_experience = jax.jit(replay_memory.add_experience)

In [5]:
from core.trees.tree import reset_tree


def one_step(env_embedding, tree, memory_buffer):
    """Run one search-and-act step for a single (un-batched) environment.

    Steps, in order:
      1. Run ``search`` from ``env_embedding`` using ``tree``.
      2. Keep the subtree under the sampled action.
      3. Store the pre-step state and the search's action weights in the
         replay memory.
      4. Step the environment with the sampled action.
      5. If the stepped state is terminated: assign the reward to the replay
         memory, reset the tree, and replace the state with a fresh
         ``env.init`` state.

    Returns the tuple ``(env_embedding, tree, memory_buffer)`` for the next step.
    """
    result = search(tree, env_embedding)
    chosen_action = result.sampled_action
    subtree = get_subtree(result.tree, chosen_action)

    # Record the state the search was run from, together with its action weights.
    experience = Experience(
        env_embedding=env_embedding,
        action_weights=result.action_weights,
    )
    memory_buffer = add_experience(memory_buffer, experience)

    next_state = env.step(env_embedding, chosen_action)
    episode_over = next_state.terminated

    # NOTE(review): the reward is read for the current player of the post-step
    # state; confirm this is the perspective assign_reward expects.
    memory_buffer = jax.lax.cond(
        episode_over,
        lambda buf: replay_memory.assign_reward(
            buf, next_state.rewards[next_state.current_player]
        ),
        lambda buf: buf,
        memory_buffer,
    )
    tree = jax.lax.cond(
        episode_over,
        lambda t: reset_tree(t),
        lambda t: t,
        subtree,
    )
    # NOTE(review): every reset uses the same fixed PRNGKey(0).
    env_embedding = jax.lax.cond(
        episode_over,
        lambda _: env.init(jax.random.PRNGKey(0)),
        lambda state: state,
        next_state,
    )
    return env_embedding, tree, memory_buffer

# Batched, JIT-compiled version of one_step: vmap maps it over the leading
# axis of each of the three arguments.
one_step_ = jax.jit(jax.vmap(one_step))
# Example single-step call:
# env_embedding, tree, buffer = one_step_(env_embedding, tree, memory_buffer);



In [6]:
def _rollout_body(_, carry):
    """fori_loop body: unpack the carried (state, tree, buffer) and advance one step."""
    return one_step_(*carry)


# Run 500 batched steps, threading the state, tree and replay memory through
# the loop. Note that this cell overwrites its own inputs.
env_embedding, tree, memory_buffer = jax.lax.fori_loop(
    0, 500, _rollout_body, (env_embedding, tree, memory_buffer)
)

* Define the search output specification.
* Integrate with the custom replay memory buffer.
* What does a loose API surrounding training look like?
* Can we update model params without re-compiling jitted code?
    * Can we do this inside jitted code?
* Does performance improve for a basic example?

In [7]:
from core.evaluators.mcts.data import tree_to_graph
# Convert the search tree to a graph object for inspection; batch_id=0
# presumably selects the first tree in the batch (confirm in tree_to_graph).
graph = tree_to_graph(tree, batch_id=0)
# Render to SVG; the cell output below shows the result path 'graph.svg'.
# NOTE(review): view=True presumably opens the file in an external viewer,
# which may not work in a headless environment — confirm.
graph.render('graph', format='svg', view=True)

'graph.svg'

In [8]:
# Display the current batched environment state after the rollout above.
env_embedding