PSEUDOCODE
* one iteration of selfplay collection takes N steps
* environments are reset when they terminate (or are truncated)
* trajectories are placed in batched replay memory buffer
* rewards are assigned to trajectories after episode is completed

* once a selfplay collection iteration is completed, T training steps are taken
* a training step involves gathering a mini-batch of size M trajectories from non-truncated, terminated episodes in the replay memory buffer
* a trajectory includes metadata necessary to train a model
    * in the case of AZ, this includes action visit counts and the final episode reward
* compare model output to metadata, compute loss, SGD, etc

* C collection steps make up one training epoch
* do whatever to evaluate


def train():
    for _ in range()

In [1]:
from flax import struct
import jax.numpy as jnp
import jax
from functools import partial

In [2]:
class Experience(struct.PyTreeNode):
    """One replay-buffer sample: an observation and the policy logits stored with it.

    Subclassing ``struct.PyTreeNode`` already converts the class into a flax
    dataclass registered as a JAX pytree, so stacking ``@struct.dataclass`` on
    top of it is redundant and has been dropped.
    """
    # Observation the policy logits belong to.
    observation: struct.PyTreeNode
    # Policy logits recorded alongside the observation.
    policy_logits: jnp.ndarray

# load model parameters
# NOTE(review): every `...` in this cell is a literal Ellipsis placeholder left
# by the design sketch, not a real object.
model = ...
optimizer = ...
model_state = ...
optimizer_state = ...


# init buff, env, evaluator
buff = ...
env = ...
evaluator = ...

# reset buff, env, evaluator
# NOTE(review): `keys` and `keys2` are not defined in the visible file. Both
# resets are vmapped, so presumably each is a per-environment batch of PRNG
# keys -- TODO confirm.
env_state, timestep = jax.vmap(env.reset)(keys)
evaluator_state = jax.vmap(evaluator.reset)(keys2)
buff_state = buff.reset()

# combine states into single object
# The element order here is the order in which `collection_step` unpacks its
# first argument.
state = (env_state, evaluator_state, buff_state, model_state, timestep.observation)

# define collection step
def collection_step(state, keys):
    """Advance every environment by one step and record the experience.

    Design sketch of the body handed to ``jax.lax.scan`` for self-play
    collection.

    Args:
        state: Carry tuple ``(env_state, evaluator_state, buff_state,
            model_state, prev_obs)``.
        keys: PRNG keys with a leading per-environment axis; they are mapped
            over axis 0 by the vmapped calls below.

    Returns:
        ``(new_state, None)``. ``new_state`` keeps the same five-element layout
        as ``state`` because ``jax.lax.scan`` requires the carry structure to
        be identical on input and output.
    """
    env_state, evaluator_state, buff_state, model_state, prev_obs = state

    # NOTE(review): the same `keys` feed evaluate, choose_action and env.step.
    # JAX keys should not be reused across independent random consumers --
    # consider splitting them per use.

    # jax.vmap has no `static_argnums` parameter; arguments that must not be
    # batched (model_state, env) are marked with None in `in_axes` instead.
    evaluator_state, policy, evaluation = jax.vmap(
        evaluator.evaluate,
        in_axes=(0, 0, None, None, 0),
    )(evaluator_state, env_state, model_state, env, keys)

    # `policy` is the name bound above; the sketch referred to an undefined
    # `policy_logits` here.
    action = jax.vmap(evaluator.choose_action)(policy, keys)
    env_state, timestep = jax.vmap(env.step)(env_state, action, keys)

    # NOTE(review): the sketch had `timestep.obs.terminated?` here. The
    # termination flag is taken from env_state, matching the assign_rewards
    # call below -- confirm this is the intended source.
    evaluator_state = jax.vmap(evaluator.step)(
        evaluator_state, action, env_state.terminated
    )

    # Store the observation the policy was computed for (prev_obs), using the
    # field names Experience actually declares.
    # NOTE(review): `evaluation` is not stored because Experience has no field
    # for it -- add one if the value is needed for training.
    buff_state = buff.add_experience(
        buff_state,
        Experience(
            observation=prev_obs,
            policy_logits=policy,
        ),
    )
    rewards = env.get_rewards(env_state)
    buff_state = buff.assign_rewards(buff_state, env_state.terminated, rewards)
    buff_state = buff.truncate(buff_state, env_state.truncated)

    # NOTE(review): these resets take states rather than keys and are not
    # vmapped, unlike the resets used at initialisation. Presumably they only
    # reset finished episodes -- confirm.
    evaluator_state = evaluator.reset(evaluator_state, env_state)
    env_state = env.reset(env_state)

    # model_state is carried through unchanged so the carry keeps five elements.
    return (
        env_state,
        evaluator_state,
        buff_state,
        model_state,
        timestep.observation,
    ), None

# make n collection steps
# NOTE(review): `make_more_keys` and `n_timesteps` are not defined in the
# visible file; presumably the helper returns keys with a leading
# `n_timesteps` axis so that scan hands one slice to each step -- TODO confirm.
more_keys = make_more_keys(n_timesteps)
# Scan over the freshly made per-step keys. The sketch passed `keys` (the
# reset keys) here and left `more_keys` unused.
state, _ = jax.lax.scan(collection_step, state, more_keys)

# define training step
# NOTE(review): bare expression with no effect beyond the attribute lookup --
# looks like a reminder that the model state may live on the evaluator state;
# confirm or remove.
evaluator_state.model_state


def training_step(state, keys):
    """Placeholder for a single training update.

    The sketch left this function without a body, which is a syntax error.
    It now fails loudly until the update is actually written.

    Args:
        state: Training carry (layout not yet decided in the sketch).
        keys: PRNG keys for the step.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError("training_step has not been written yet")


In [6]:
# load/init model parameters
# TODO

# PARAMETERS 
import jumanji
from core_jax.evaluators.evaluator import Evaluator, EvaluatorConfig
from core_jax.utils.ranked_reward_replay_memory import RankedRewardReplayBuffer
from core_jax.utils.replay_memory import EndRewardReplayBuffer
from core_jax.envs.jumanji import JumanjiEnv, make_jumanji_env
from core_jax.evaluators.mcts import MCTSConfig
from core_jax.evaluators.randotron import Randotron


# Sizing shared by the environments and the replay buffer.
batch_size, max_len_per_batch, sample_batch_size = 10, 1000, 10

# init buffer, env
buff = RankedRewardReplayBuffer(
    batch_size=batch_size,
    max_len_per_batch=max_len_per_batch,
    sample_batch_size=sample_batch_size,
    quantile=0.75,
    episode_reward_memory_len_per_batch=100,
)

# from jumanji.environments.logic.minesweeper.types import State
# from jumanji.environments.logic.minesweeper.utils import get_mined_board

# Knapsack environment from jumanji, wrapped in the project's env adapter.
env = JumanjiEnv(jumanji.environments.Knapsack())

# Search settings; handed to the evaluator constructor further down.
config = MCTSConfig(
    epsilon=1e-8,
    max_nodes=100,
    puct_coeff=1.0,
    dirichlet_alpha=0.3,
    dirichlet_epsilon=0.25,
)




# Separate key streams for the environments and the evaluators, one key per
# batch element.
random_key = jax.random.PRNGKey(0)
env_key, eval_key = jax.random.split(random_key, 2)
env_keys = jax.random.split(env_key, batch_size)
eval_keys = jax.random.split(eval_key, batch_size)

env_state, terminated = jax.jit(jax.vmap(env.reset))(env_keys)

# The policy size is the legal-action mask shape without its leading batch axis.
evaluator = Randotron(
    policy_size=env_state.legal_action_mask.shape[1:],
    num_players=1,
    config=config,
)

# Run one throwaway evaluation so that `get_raw_policy` yields an array whose
# shape and dtype can seed the buffer's template experience.
eval_state = jax.jit(jax.vmap(evaluator.reset))(eval_keys)
eval_state = jax.vmap(evaluator.evaluate, in_axes=(0, None, 0))(eval_state, env, env_state)
buff_state = buff.init(
    # `jax.tree_map` is deprecated and removed in recent JAX releases;
    # `jax.tree_util.tree_map` is the stable spelling.
    template_experience=jax.tree_util.tree_map(
        # Zero template for a single item: drop the leading batch axis.
        lambda x: jnp.zeros(x.shape[1:], x.dtype),
        Experience(
            observation=env_state._observation,
            policy_logits=jax.vmap(evaluator.get_raw_policy)(eval_state),
        ),
    )
)
# Reset again with the same keys so collection starts from a fresh evaluator
# state rather than the one used to build the template.
eval_state = jax.jit(jax.vmap(evaluator.reset))(eval_keys)

state = (env_state, eval_state, buff_state)

def collection_step(state, _):
    """Run one batched self-play step and record it in the replay buffer.

    Args:
        state: Tuple ``(env_state, evaluator_state, buff_state)``.
        _: Unused; present so the function has the ``(carry, x)`` signature
            expected of a ``jax.lax.scan`` body.

    Returns:
        ``((env_state, evaluator_state, buff_state), None)`` after stepping
        every environment and resetting the ones that terminated.
    """
    env_state, evaluator_state, buff_state = state

    # Capture the observation the evaluator is about to act on. Reading it
    # after `env.step` (as before) stored the *next* state's observation
    # alongside policy logits that were computed for this one.
    observation = env_state._observation

    evaluator_state = jax.vmap(evaluator.evaluate, in_axes=(0, None, 0))(
        evaluator_state,
        env,
        env_state,
    )

    evaluator_state, action = jax.vmap(evaluator.choose_action, in_axes=(0, None, 0))(
        evaluator_state,
        env,
        env_state,
    )
    # Read the policy before `step_evaluator` advances the evaluator state.
    policy_logits = jax.vmap(evaluator.get_raw_policy)(evaluator_state)

    env_state, terminated = jax.vmap(env.step)(env_state, action)

    buff_state = buff.add_experience(
        buff_state,
        Experience(observation=observation, policy_logits=policy_logits),
    )
    buff_state = buff.assign_rewards(buff_state, env_state.reward, terminated)

    # Truncation handling is currently disabled:
    # buff_state = buff.truncate(buff_state, env_state.truncated)

    evaluator_state = jax.vmap(evaluator.step_evaluator)(
        evaluator_state,
        action,
        terminated,
    )

    env_state, terminated = jax.vmap(env.reset_if_terminated)(
        env_state,
        terminated,
    )

    return (env_state, evaluator_state, buff_state), None


# Run 100 collection steps eagerly in a plain Python loop. The loop index is
# passed as collection_step's second argument, which that function ignores.
for i in range(100):
    state, _ = collection_step(state, i)
# Disabled alternative: the same collection expressed as a 20000-step
# jax.lax.scan over collection_step.
# collect_keys = jax.random.split(random_key, (20000, batch_size))
# (env_state, evaluator_state, buff_state), _ = jax.lax.scan(
#     collection_step,
#     state,
#     jnp.arange(20000)
# )



In [None]:
# Draw one sample from the replay buffer; `sample` hands back the updated
# buffer state together with the sampled experience and its rewards.
buff_state, experience, reward = buff.sample(buff_state)
# Display the sampled rewards (the recorded output is a (10, 1) float32 array
# of +1 / -1 values).
reward

Array([[-1.],
       [ 1.],
       [-1.],
       [-1.],
       [-1.],
       [-1.],
       [ 1.],
       [-1.],
       [ 1.],
       [-1.]], dtype=float32)

In [None]:
# Mean over the buffer state's `reward_buffer` field.
buff_state.reward_buffer.mean()

Array(-0.0014, dtype=float32)

In [None]:
# Maximum over the buffer state's `raw_reward_buffer` field (recorded output: 0.0).
buff_state.raw_reward_buffer.max()

Array(0., dtype=float32)