# In [None]:
import math

import jax
import jax.numpy as jnp
import wandb

from core.collector import Collector
from core.envs.make import make_env
from core.envs.pgx import make_pgx_env
from core.evaluators.alphazero import AlphaZero, AlphaZeroConfig
from core.evaluators.make import make_evaluator
from core.memory.make import make_replay_buffer
from core.memory.replay_memory import EndRewardReplayBuffer
from core.networks.make import make_model
from core.training.train import Trainer, TrainerConfig


# Sizes passed to the replay-buffer config further down.
batch_size = 10
max_len_per_batch = 1000
sample_batch_size = 10

# Othello environment, built through the "pgx" env_type.
env = make_env(
    {
        "env_name": "othello",
        "env_type": "pgx",
        "base_config": {},
    }
)

# The policy head is sized to the total number of actions, i.e. the product
# of the action shape's dimensions. math.prod works on the static shape in
# plain Python (no device array, no host transfer). Wrapping it in int()
# guarantees an integer even for an empty shape (where the old
# jnp.array(()) path produced the float 1.0) or non-int elements.
model = make_model(
    {
        "model_type": "az_resnet",
        "policy_head_out_size": int(math.prod(env.get_action_shape())),
        "value_head_out_size": 1,
        "num_blocks": 2,
        "channels": 4,
    }
)

# AlphaZero evaluator (evaluator_type "alphazero") built for `env`, with
# `model` passed by keyword. The config is built with dict(...) keywords and
# carries the same keys and values as a literal mapping would.
# NOTE(review): max_nodes equals mcts_iters; if every iteration allocates a
# node in addition to the root, the tree could fill up before the last
# iteration -- confirm against the evaluator implementation.
evaluator = make_evaluator(
    dict(
        evaluator_type="alphazero",
        mcts_iters=100,
        temperature=1.0,
        epsilon=1e-8,
        max_nodes=100,
        puct_coeff=1.0,
        dirichlet_alpha=0.3,
        dirichlet_epsilon=0.25,
    ),
    env,
    model=model,
)

# Replay buffer of type "end_reward", sized by the module-level
# batch_size / max_len_per_batch / sample_batch_size values.
buff = make_replay_buffer(
    dict(
        buff_type="end_reward",
        max_len_per_batch=max_len_per_batch,
        batch_size=batch_size,
        sample_batch_size=sample_batch_size,
    )
)

   
# Combine the environment, evaluator, replay buffer and network into a
# Trainer, together with its run configuration.
trainer = Trainer(
    env=env,
    evaluator=evaluator,
    buff=buff,
    model=model,
    config=TrainerConfig(
        # step / epoch schedule
        warmup_steps=100,
        collection_steps_per_epoch=100,
        train_steps_per_epoch=4,
        # optimisation hyper-parameters
        learning_rate=1e-3,
        momentum=0.9,
        # presumably weights the policy term of the loss -- confirm in Trainer
        policy_factor=1.0,
        # checkpointing; relative path, presumably resolved against the
        # notebook's current working directory
        epochs_per_checkpoint=1,
        checkpoint_dir="../checkpoints/",
        max_checkpoints_to_keep=3,
    ),
)

# Open a W&B run in project "test_az_0" that records the config of every
# component, then initialise and train inside it.
# Using the run as a context manager guarantees it is finished when this cell
# completes. A notebook kernel otherwise keeps the run open indefinitely. If
# training raises, the run is closed as failed (non-zero exit code) and the
# exception still propagates.
with wandb.init(
    project="test_az_0",
    config={
        "train_config": trainer.config,
        "model_config": model.config,
        "env_config": env.config,
        "evaluator_config": evaluator.config,
        "buff_config": buff.config,
    },
):
    # Fixed seed so runs are reproducible.
    state = trainer.init(jax.random.PRNGKey(0))

    trainer.train_loop(
        state,
        num_epochs=10,
        warmup=True,
    )
