In [1]:
import pgx

from flax import struct
import jax.numpy as jnp
import jax
from functools import partial
from flax import linen as nn

@struct.dataclass
class Experience:
    """One recorded training sample: an observation and the policy logits for it.

    Declared with ``flax.struct.dataclass``, so instances are frozen dataclasses
    that JAX treats as pytrees (they can pass through ``jit``/``vmap``/``scan``).
    """

    # Observation for this sample. NOTE(review): typed as a generic pytree node;
    # presumably the env observation/state pytree — confirm against the producer.
    observation: struct.PyTreeNode
    # Policy logits associated with ``observation``. NOTE(review): presumably
    # unnormalized scores over the flattened action space — confirm the shape.
    policy_logits: jnp.ndarray

class ResidualBlock(nn.Module):
    """Residual block: (3x3 conv -> BN -> ReLU -> 3x3 conv -> BN) + skip, then ReLU.

    Attributes:
        channels: number of feature maps produced by each 3x3 convolution. The
            skip connection adds the block input elementwise to the block output,
            so ``channels`` should match the input's channel count.
    """

    channels: int

    @nn.compact
    def __call__(self, x, training):
        """Apply the block to ``x``.

        Args:
            x: input feature maps (assumed channels-last, as flax ``nn.Conv`` expects).
            training: if True, BatchNorm uses batch statistics; otherwise it uses
                its running averages.

        Returns:
            ``relu(x + residual(x))`` with the same spatial size as ``x``
            (convolutions use 'SAME' padding and stride 1).
        """
        frozen_stats = not training
        h = x
        # Two conv+BN stages; only the first is followed by its own ReLU — the
        # second ReLU is applied after the skip connection below.
        for stage in range(2):
            h = nn.Conv(
                features=self.channels,
                kernel_size=(3, 3),
                strides=(1, 1),
                padding='SAME',
            )(h)
            h = nn.BatchNorm(use_running_average=frozen_stats)(h)
            if stage == 0:
                h = nn.relu(h)
        return nn.relu(h + x)


class AZResnet(nn.Module):
    """AlphaZero-style residual tower with separate policy and value heads.

    Attributes:
        policy_head_out_size: width of the final policy Dense layer.
        value_head_out_size: width of the final value Dense layer; its outputs
            are squashed with tanh into [-1, 1].
        num_blocks: number of ``ResidualBlock`` layers in the tower.
        channels: feature-map width of the stem and of every residual block.
    """

    policy_head_out_size: int
    value_head_out_size: int
    num_blocks: int
    channels: int

    @nn.compact
    def __call__(self, x, training=False):
        """Run the network on a batch.

        Args:
            x: input batch with the batch axis first. NOTE(review): assumed
                channels-last (N, H, W, C) as flax ``nn.Conv`` expects — confirm
                against the environment's observation layout.
            training: if True, BatchNorm layers normalize with batch statistics
                (flax then requires the 'batch_stats' collection to be mutable);
                if False they use their running averages.

        Returns:
            ``(policy, value)`` where ``policy`` has shape
            ``(N, policy_head_out_size)`` and ``value`` has shape
            ``(N, value_head_out_size)`` with entries in [-1, 1].
        """
        use_running_avg = not training

        def conv_bn_relu(h, features, kernel):
            # Conv (stride 1, 'SAME') -> BatchNorm -> ReLU.
            h = nn.Conv(features=features, kernel_size=kernel, strides=(1, 1), padding='SAME')(h)
            h = nn.BatchNorm(use_running_average=use_running_avg)(h)
            return nn.relu(h)

        def head(h, features, out_size):
            # 1x1 conv bottleneck, flatten everything except the batch axis, then project.
            h = conv_bn_relu(h, features, (1, 1))
            h = h.reshape((h.shape[0], -1))
            return nn.Dense(features=out_size)(h)

        # Stem: 1x1 projection to `channels` feature maps, then the residual tower.
        trunk = conv_bn_relu(x, self.channels, (1, 1))
        for _ in range(self.num_blocks):
            trunk = ResidualBlock(channels=self.channels)(trunk, training=training)

        # Heads are built policy-first, then value, so submodule auto-names
        # (Conv_1/BatchNorm_1/Dense_0 then Conv_2/BatchNorm_2/Dense_1) stay stable.
        policy = head(trunk, 2, self.policy_head_out_size)
        value = nn.tanh(head(trunk, 1, self.value_head_out_size))
        return policy, value
    


In [4]:
from core_jax.collector import Collector
from core_jax.envs.pgx import make_pgx_env
from core_jax.evaluators.alphazero import AlphaZero, AlphaZeroConfig
from core_jax.utils.replay_memory import EndRewardReplayBuffer


# Collection settings. NOTE(review): their exact semantics are defined by
# Collector / EndRewardReplayBuffer — confirm there.
batch_size = 10            # number of parallel environments / evaluator states
max_len_per_batch = 1000   # passed to EndRewardReplayBuffer
sample_batch_size = 10     # passed to EndRewardReplayBuffer
num_collect_steps = 1000   # how many times collector.collect_step is applied below

env = make_pgx_env("othello")

az_config = AlphaZeroConfig(
    mcts_iters=50,
    temperature=1.0,
    epsilon=1e-8,
    max_nodes=100,
    puct_coeff=1.0,
    dirichlet_alpha=0.3,
    dirichlet_epsilon=0.25
)

# The policy head width is the total number of entries in the env's action shape.
model = AZResnet(
    policy_head_out_size=jnp.prod(jnp.array(env.get_action_shape())).item(),
    value_head_out_size=1,
    num_blocks=2,
    channels=4
)

evaluator = AlphaZero(
    env=env,
    config=az_config,
    model=model,
)

buff = EndRewardReplayBuffer(
    max_len_per_batch=max_len_per_batch,
    batch_size=batch_size,
    sample_batch_size=sample_batch_size,
)

collector = Collector(
    config=None,
    env=env,
    evaluator=evaluator,
    buff=buff
)

# Independent PRNG keys for model init, env resets and evaluator resets.
random_key = jax.random.PRNGKey(0)
random_key, model_key, env_key, eval_key = jax.random.split(random_key, 4)

# One key per parallel environment / evaluator state.
env_keys = jax.random.split(env_key, batch_size)
eval_keys = jax.random.split(eval_key, batch_size)

env_state, _ = jax.vmap(env.reset)(env_keys)
eval_state = jax.vmap(evaluator.reset)(eval_keys)

cstate = collector.init(env_state, eval_state)

model_params = evaluator.init_params(model_key)

collect_fn = partial(
    collector.collect_step,
    eval_args={
        "model_params": model_params,
    }
)


def _collect_scan_step(carry, _):
    # Advance the collector state by one step; nothing is stacked per step.
    return collect_fn(carry), None


# No per-step inputs are needed, so scan with xs=None and an explicit length
# instead of materialising an unused index array.
cstate, _ = jax.lax.scan(
    _collect_scan_step,
    cstate,
    xs=None,
    length=num_collect_steps,
)

# NOTE(review): the saved output of this cell is
# "AttributeError: 'CollectorState' object has no attribute 'buff'". `collector`
# here is a Collector, so the failing `.buff` access presumably happens inside
# core_jax (e.g. in collect_step / init) — TODO investigate before relying on it.
collector.buff.sample(cstate.buff_state)

AttributeError: 'CollectorState' object has no attribute 'buff'