In [1]:
import math
from functools import partial

import jax
import jax.numpy as jnp
import pgx
from flax import linen as nn
from flax import struct

class Experience(struct.PyTreeNode):
    """One replay-buffer entry recorded per environment step.

    Subclassing ``struct.PyTreeNode`` already turns this class into a flax
    dataclass registered as a JAX pytree, so an additional
    ``@struct.dataclass`` decorator is redundant.
    """
    # Observation the evaluator acted on (the env observation before stepping).
    observation: jnp.ndarray
    # Policy produced by ``evaluator.get_policy`` for that step.
    policy_logits: jnp.ndarray

class ResidualBlock(nn.Module):
    """Two 3x3 conv + BatchNorm layers wrapped by an identity skip connection.

    Attributes:
        channels: number of conv output features; should match the input's
            channel count so that ``x + h`` is a true residual sum.
    """
    channels: int

    @nn.compact
    def __call__(self, x, training):
        """Apply conv-BN-ReLU, then conv-BN, add the input and apply ReLU.

        Args:
            x: input feature map (channels-last, as ``nn.Conv`` expects).
            training: if True BatchNorm uses batch statistics, otherwise its
                running averages.
        """
        use_running_average = not training
        h = x
        for layer_idx in range(2):
            h = nn.Conv(features=self.channels, kernel_size=(3, 3), strides=(1, 1), padding='SAME')(h)
            h = nn.BatchNorm(use_running_average=use_running_average)(h)
            if layer_idx == 0:
                # Only the first conv gets its own ReLU; the second is
                # activated after the residual addition below.
                h = nn.relu(h)
        return nn.relu(x + h)


class AZResnet(nn.Module):
    """AlphaZero-style residual network with separate policy and value heads.

    Attributes:
        policy_head_out_size: number of policy logits produced per example.
        value_head_out_size: width of the value output (tanh-squashed to [-1, 1]).
        num_blocks: number of ``ResidualBlock`` layers in the trunk.
        channels: feature width of the stem and residual trunk.
    """
    policy_head_out_size: int
    value_head_out_size: int
    num_blocks: int
    channels: int

    @nn.compact
    def __call__(self, x, training=False):
        """Compute ``(policy_logits, value)`` for a batch of observations.

        Args:
            x: batched observation array; the leading axis is kept as the batch
                axis when the head feature maps are flattened.
            training: if True, BatchNorm layers use batch statistics instead of
                running averages.

        Returns:
            Tuple ``(policy, value)`` with shapes ``(batch, policy_head_out_size)``
            and ``(batch, value_head_out_size)``.
        """
        use_running_average = not training

        def conv_bn_relu(h, features):
            # 1x1 convolution -> BatchNorm -> ReLU, shared by the stem and both heads.
            h = nn.Conv(features=features, kernel_size=(1, 1), strides=(1, 1), padding='SAME')(h)
            h = nn.BatchNorm(use_running_average=use_running_average)(h)
            return nn.relu(h)

        def flatten(h):
            # Keep the batch axis, merge all remaining axes.
            return h.reshape((h.shape[0], -1))

        # Stem followed by the residual trunk.
        trunk = conv_bn_relu(x, self.channels)
        for _ in range(self.num_blocks):
            trunk = ResidualBlock(channels=self.channels)(trunk, training=training)

        # Policy head: 2 feature maps -> unnormalized logits.
        policy_features = flatten(conv_bn_relu(trunk, 2))
        policy_logits = nn.Dense(features=self.policy_head_out_size)(policy_features)

        # Value head: 1 feature map -> dense -> tanh.
        value_features = flatten(conv_bn_relu(trunk, 1))
        value_out = nn.Dense(features=self.value_head_out_size)(value_features)
        value_out = nn.tanh(value_out)

        return policy_logits, value_out
    


In [2]:
# TODO: load model parameters (e.g. from a checkpoint); for now they are freshly
# initialized with `model.init` further below.

# PARAMETERS 
import jumanji
from core_jax.envs.pgx import make_pgx_env
from core_jax.evaluators.alphazero import AlphaZero, AlphaZeroConfig
from core_jax.evaluators.evaluator import Evaluator, EvaluatorConfig
from core_jax.utils.ranked_reward_replay_memory import RankedRewardReplayBuffer
from core_jax.utils.replay_memory import EndRewardReplayBuffer
from core_jax.envs.jumanji import JumanjiEnv, make_jumanji_env
from core_jax.evaluators.mcts import MCTSConfig
from core_jax.evaluators.randotron import Randotron


# Collection and replay-buffer sizes.
batch_size = 10            # number of environments / evaluator states run in parallel (vmapped)
max_len_per_batch = 1000   # presumably buffer capacity per environment -- TODO confirm
sample_batch_size = 10     # presumably experiences returned per `buff.sample` call -- TODO confirm


# Replay buffer holding one `Experience` per environment step. The commented
# options are presumably for RankedRewardReplayBuffer (also imported) and are unused here.
buff = EndRewardReplayBuffer(
    batch_size=batch_size,
    max_len_per_batch=max_len_per_batch,
    sample_batch_size=sample_batch_size,
    # quantile=0.75,
    # episode_reward_memory_len_per_batch=100
)

env = make_pgx_env("othello")

# AlphaZero / MCTS search hyperparameters.
config = AlphaZeroConfig(
    mcts_iters=50,
    temperature=1.0,
    epsilon=1e-8,
    max_nodes=100,
    puct_coeff=1.0,
    dirichlet_alpha=0.3,
    dirichlet_epsilon=0.25
)

# Independent PRNG streams for environment resets, evaluator resets and model init.
random_key = jax.random.PRNGKey(0)
env_key, eval_key, model_key = jax.random.split(random_key, 3)
env_keys = jax.random.split(env_key, batch_size)
eval_keys = jax.random.split(eval_key, batch_size)

env_state, terminated = jax.jit(jax.vmap(env.reset))(env_keys)

# The policy head emits one logit per entry of a single environment's legal-action
# mask (batch axis dropped). math.prod on the shape tuple always yields an int
# (1 for an empty shape), whereas jnp.prod over an empty jnp.array yields the
# float 1.0, which is not a valid `features` value for nn.Dense.
model = AZResnet(
    policy_head_out_size=math.prod(env_state.legal_action_mask.shape[1:]),
    value_head_out_size=1,
    num_blocks=2,
    channels=4
)



from typing import Any
from flax.training import train_state

import optax

evaluator = AlphaZero(
    env=env,
    config=config,
    model=model,
)

# Initialize the network variables from a dummy batch containing one observation;
# `training=False` makes BatchNorm use its running averages.
model_params = model.init(
    model_key,
    jnp.zeros((1, *env_state._observation.shape[1:]), jnp.float32),
    training=False
)


# Unused scaffolding for a future training loop (BatchNorm needs batch_stats tracked):
# class TrainState(train_state.TrainState):
#     batch_stats: Any

# state = TrainState.create(
#     apply_fn = model.apply,
#     params = params,
#     batch_stats = batch_stats,
#     tx=optax.sgd(learning_rate=0.01, momentum=0.9)
# )


eval_state = jax.jit(jax.vmap(evaluator.reset))(eval_keys)
# Run one search only to obtain a correctly shaped policy for the buffer template.
eval_state = jax.vmap(evaluator.evaluate, in_axes=(0, 0, None))(eval_state, env_state, model_params)
buff_state = buff.init(
    # `jax.tree_map` is deprecated (and removed in newer JAX); use jax.tree_util.
    template_experience=jax.tree_util.tree_map(
        lambda x: jnp.zeros(x.shape[1:], x.dtype),  # drop the leading env-batch axis
        Experience(
            observation=env_state._observation,
            policy_logits=jax.vmap(evaluator.get_policy)(eval_state)
        )
    )
)
# Discard the search state used for the template so collection starts fresh.
eval_state = jax.jit(jax.vmap(evaluator.reset))(eval_keys)

# Carry for the collection scan: (env_state, evaluator_state, buff_state).
state = (env_state, eval_state, buff_state)

animation_states = []

# animation_states.append(jax.tree_util.tree_map(lambda x: x[0], env_state._state))

def collection_step(state, _):
    """Advance every environment by one step and record the experience.

    Written as a ``jax.lax.scan`` body: ``state`` is the carry
    ``(env_state, evaluator_state, buff_state)``, and the second argument is
    ignored. Relies on the module-level ``env``, ``evaluator``, ``buff`` and
    ``model_params``.

    Returns:
        The updated carry, plus the post-reset environment state as the
        per-step scan output.
    """
    envs, searches, buffer = state

    # Observation the action is chosen from; it is what gets stored below.
    obs_before_step = envs._observation

    batched_evaluate = jax.vmap(evaluator.evaluate, in_axes=(0, 0, None))
    searches = batched_evaluate(searches, envs, model_params)
    searches, actions = jax.vmap(evaluator.choose_action)(searches, envs)

    envs, done = jax.vmap(env.step)(envs, actions)
    # NOTE: appending to a Python list here would not work inside a traced scan body.

    step_policy = jax.vmap(evaluator.get_policy)(searches)
    buffer = buff.add_experience(
        buffer,
        Experience(observation=obs_before_step, policy_logits=step_policy)
    )
    buffer = buff.assign_rewards(buffer, envs.reward, done)

    # NOTE: truncation handling (buff.truncate with env_state.truncated) is disabled.

    searches = jax.vmap(evaluator.step_evaluator)(searches, actions, done)

    envs, done = jax.vmap(env.reset_if_terminated)(envs, done)

    return (envs, searches, buffer), envs


# Eager (non-scan) alternative, handy for debugging:
# for i in range(20):
#     state, _ = collection_step(state, i)

# Number of vectorized environment steps collected by the scan.
num_collection_steps = 1000

# Run the whole collection loop as one `jax.lax.scan`. `states` stacks the per-step
# env states with a leading time axis: (num_collection_steps, batch_size, ...).
(env_state, evaluator_state, buff_state), states = jax.lax.scan(
    collection_step,
    state,
    jnp.arange(num_collection_steps)
)

# To animate env 0's trajectory, index the stacked scan output `states` (time, batch, ...).
# The final carry `env_state` only has a batch axis, so indexing it by step would be wrong.
# animation_states.extend([
#     jax.tree_util.tree_map(lambda x: x[i, 0], states._state) for i in range(num_collection_steps)
# ])
# pgx.save_svg_animation(animation_states, "test2.svg", frame_duration_seconds=3, color_theme='dark')


In [None]:
# Sample a batch from the replay buffer; the first returned value is the new
# buffer state and replaces the previous one.
(buff_state, experience, reward) = buff.sample(buff_state)
# Display the sampled rewards.
reward

In [None]:
# Inspect the first 100 stored reward entries of buffer row 0.
buff_state.reward_buffer[0][:100]