PSEUDOCODE
* one iteration of selfplay collection takes N steps
* environments are reset when they terminate (or are truncated)
* trajectories are placed in batched replay memory buffer
* rewards are assigned to trajectories after the episode completes

* once a selfplay collection iteration is completed, T training steps are taken
* a training step samples a mini-batch of M trajectories from terminated (not truncated) episodes in the replay memory buffer
* a trajectory includes the metadata needed to train a model
    * in the case of AZ (AlphaZero), this includes action visit counts and the final episode reward
* compare the model output to this metadata, compute the loss, take an SGD step, etc.

* C collection steps make up one training epoch
* then evaluate the model (method TBD)


def train():
    for _ in range()

In [1]:
from flax import struct
import jax.numpy as jnp
import jax
from functools import partial

In [2]:
class Experience(struct.PyTreeNode):
    """One replay-buffer entry produced by a single batched self-play step.

    Subclassing ``struct.PyTreeNode`` already makes this a frozen flax
    dataclass registered as a JAX pytree, so a separate ``@struct.dataclass``
    decorator is redundant (on flax versions without the "already a flax
    dataclass" guard it would register the pytree a second time).
    """

    # Observation the policy/evaluation were computed from. The concrete
    # structure comes from the environment wrapper — TODO confirm its type;
    # the annotation is informational only.
    obs: struct.PyTreeNode
    # Policy output for `obs` (called `policy_logits` where it is produced).
    policy: jnp.ndarray
    # Output of `evaluator.evaluate` for `obs` (presumably a value estimate).
    evaluation: jnp.ndarray

# DRAFT / design sketch: the `...` placeholders below stand in for objects
# that do not exist yet, so this cell is not runnable as written.
# load model parameters
model = ...
optimizer = ...
model_state = ...
optimizer_state = ...


# init buff, env, evaluator
buff = ...
env = ...
evaluator = ...

# reset buff, env, evaluator
# NOTE(review): `keys` and `keys2` are not defined anywhere in this cell;
# presumably they are one PRNG key per batched environment — TODO.
env_state, timestep = jax.vmap(env.reset)(keys)
evaluator_state = jax.vmap(evaluator.reset)(keys2)
buff_state = buff.reset()

# combine states into single object
# This 5-tuple is the `jax.lax.scan` carry that the draft `collection_step`
# below unpacks as (env_state, evaluator_state, buff_state, model_state, prev_obs).
state = (env_state, evaluator_state, buff_state, model_state, timestep.observation)

# define collection step
def collection_step(state, keys):
    """Draft of one batched self-play collection step (design sketch).

    Args:
        state: 5-tuple carry ``(env_state, evaluator_state, buff_state,
            model_state, prev_obs)``.
        keys: per-environment PRNG keys for this step.

    Returns:
        ``(new_state, None)`` as ``jax.lax.scan`` expects. The carry keeps the
        same 5-element structure it was given.
    """
    env_state, evaluator_state, buff_state, model_state, prev_obs = state
    # `jax.vmap` has no `static_argnums`; `in_axes=None` alone keeps
    # `model_state` and `env` unbatched.
    evaluator_state, policy, evaluation = jax.vmap(
        evaluator.evaluate,
        in_axes=(0, 0, None, None, 0),
    )(evaluator_state, env_state, model_state, env, keys)
    action = jax.vmap(evaluator.choose_action)(policy, keys)
    env_state, timestep = jax.vmap(env.step)(env_state, action, keys)
    # TODO(review): the sketch originally read `timestep.obs.terminated?`.
    # Where the termination flag lives is unconfirmed; `env_state.terminated`
    # is used here to match the reward-assignment call below.
    evaluator_state = jax.vmap(evaluator.step)(
        evaluator_state, action, env_state.terminated
    )
    # Store the observation the policy/evaluation were computed from.
    buff_state = buff.add_experience(
        buff_state,
        Experience(
            obs=prev_obs,
            policy=policy,
            evaluation=evaluation,
        ),
    )
    rewards = env.get_rewards(env_state)
    buff_state = buff.assign_rewards(buff_state, env_state.terminated, rewards)
    buff_state = buff.truncate(buff_state, env_state.truncated)
    evaluator_state = evaluator.reset(evaluator_state, env_state)
    env_state = env.reset(env_state)
    # `lax.scan` requires the returned carry to have the same structure as the
    # input carry, so `model_state` must be threaded through as well.
    return (env_state, evaluator_state, buff_state, model_state, timestep.observation), None

# make n collection steps
# NOTE(review): `make_more_keys` and `n_timesteps` are placeholders from the
# design sketch and are not defined in this cell yet.
more_keys = make_more_keys(n_timesteps)
# Scan over the freshly generated keys. The sketch previously passed the
# undefined name `keys`, leaving `more_keys` unused.
state, _ = jax.lax.scan(collection_step, state, more_keys)

# define training step
# TODO: the training step will need the model parameters
# (sketch note: `evaluator_state.model_state`).

def training_step(state, keys):
    """Placeholder for one training step; not implemented yet.

    Per the pseudocode notes at the top of the notebook, this should sample a
    mini-batch of completed trajectories from the replay buffer, compare the
    model output against the stored metadata, compute a loss and apply an
    optimizer update.

    Raises:
        NotImplementedError: always, until the step is written.
    """
    raise NotImplementedError("training_step is not implemented yet")


In [3]:
# load/init model parameters
# TODO

# PARAMETERS 
import jumanji
from core_jax.evaluators.evaluator import Evaluator, EvaluatorConfig
from core_jax.utils.replay_memory import EndRewardReplayBuffer
from core_jax.envs.jumanji import JumanjiEnv, make_jumanji_env


# Replay-buffer sizing. `batch_size` is also the number of environments run in
# parallel below; the other two presumably bound stored steps per environment
# and the number of entries drawn per sample — confirm in EndRewardReplayBuffer.
batch_size, max_len_per_batch, sample_batch_size = 10, 1000, 10


# init replay buffer
buff = EndRewardReplayBuffer(
    batch_size=batch_size, max_len_per_batch=max_len_per_batch, sample_batch_size=sample_batch_size
)

# from jumanji.environments.logic.minesweeper.types import State
# from jumanji.environments.logic.minesweeper.utils import get_mined_board

from jumanji.environments.logic.minesweeper.types import State
from jumanji.environments.logic.minesweeper.constants import UNEXPLORED_ID, IS_MINE

def minesweeper_reward_fn(state: jumanji.environments.logic.minesweeper.types.State, action: jnp.ndarray):
    """Return the fraction of board cells that are revealed and are not mines.

    Per the jumanji Minesweeper docs (confirm against the installed version),
    ``state.board`` holds ``UNEXPLORED_ID`` for unexplored cells. For explored
    cells it holds the number of adjacent mines. The previous version also
    compared the board against ``IS_MINE``, a marker used for mine locations,
    not for board values. That made every revealed cell showing that
    adjacent-mine count be treated as unrevealed. Mine cells are now located
    via ``state.flat_mine_locations`` (assumed: flat indices of the mines).

    Args:
        state: Minesweeper state.
        action: the chosen action; unused by this reward.

    Returns:
        Scalar in ``[0, 1]``: revealed non-mine cells / total cells.
    """
    board = state.board
    # Boolean mask of mine positions, shaped like the board.
    is_mine = (
        jnp.zeros(board.size, dtype=bool)
        .at[state.flat_mine_locations]
        .set(True)
        .reshape(board.shape)
    )
    revealed_safe = jnp.not_equal(board, UNEXPLORED_ID) & ~is_mine
    return jnp.sum(revealed_safe) / (board.shape[0] * board.shape[1])


# Minesweeper wrapped for the training loop, scored by `minesweeper_reward_fn`.
env = JumanjiEnv(jumanji.environments.Minesweeper(reward_function=minesweeper_reward_fn))

# Evaluator built from a default-constructed config.
evaluator = Evaluator(EvaluatorConfig())


# PRNG setup. `random_key` is rebound to a fresh, never-used sub-key, because
# later cells (the collection scan) split `random_key` again. JAX keys must not
# be reused once they have been split, or the derived sub-keys overlap.
random_key, env_key, eval_key = jax.random.split(jax.random.PRNGKey(0), 3)
env_keys = jax.random.split(env_key, batch_size)
eval_keys = jax.random.split(eval_key, batch_size)

# Reset every environment and evaluator in the batch.
env_state, observation, reward, terminated = jax.jit(jax.vmap(env.reset))(env_keys)
eval_state = jax.jit(jax.vmap(evaluator.reset))(eval_keys)

# Run one evaluation whose new evaluator state is discarded. It is used purely
# to get example outputs that serve as the buffer's template experience.
# `eval_state` itself is left untouched, so no second reset is needed.
_, policy_logits, evaluation = jax.vmap(evaluator.evaluate, in_axes=(0, None, 0, 0))(
    eval_state, env, env_state, observation
)
buff_state = buff.init(
    template_experience=jax.tree_util.tree_map(
        # Drop the leading (environment) axis and zero-fill: one template entry.
        lambda x: jnp.zeros(x.shape[1:], x.dtype),
        Experience(observation, policy_logits, evaluation),
    )
)

# Initial `jax.lax.scan` carry for `collection_step`.
state = (env_state, eval_state, buff_state, observation)

def collection_step(state, key):
    """Advance every environment in the batch by one self-play step.

    Written as a ``jax.lax.scan`` body.

    Args:
        state: carry ``(env_state, evaluator_state, buff_state, prev_obs)``.
            All entries except ``buff_state`` are batched over environments;
            ``buff_state`` is updated unbatched.
        key: one PRNG key per environment. It is used only when resetting
            environments via ``env.reset_if_terminated``.

    Returns:
        ``((env_state, evaluator_state, buff_state, observation), None)``.
        ``observation`` is the one to act on next step (presumably a fresh
        reset observation for environments that just terminated).
    """
    env_state, evaluator_state, buff_state, prev_obs = state

    # Evaluate the current positions; `env` is shared, not batched.
    evaluator_state, policy_logits, evaluation = jax.vmap(
        evaluator.evaluate, in_axes=(0, None, 0, 0)
    )(evaluator_state, env, env_state, prev_obs)

    evaluator_state, action = jax.vmap(
        evaluator.choose_action, in_axes=(0, None, 0, 0, 0)
    )(evaluator_state, env, env_state, prev_obs, policy_logits)

    env_state, observation, reward, terminated = jax.vmap(env.step)(env_state, action)
    evaluator_state = jax.vmap(evaluator.step_evaluator)(evaluator_state, action, terminated)

    # Store the policy/evaluation together with the observation they were
    # computed from (`prev_obs`), not the post-step observation. A training
    # target for a position must be paired with that same position.
    buff_state = buff.add_experience(
        buff_state,
        Experience(obs=prev_obs, policy=policy_logits, evaluation=evaluation),
    )
    buff_state = buff.assign_rewards(buff_state, reward, terminated)
    # TODO: truncated episodes are not handled yet; the design sketch calls
    # `buff.truncate(buff_state, env_state.truncated)` at this point.
    env_state, observation, reward, terminated = jax.vmap(env.reset_if_terminated)(
        env_state, observation, reward, terminated, key
    )
    return (env_state, evaluator_state, buff_state, observation), None


# Run 1000 batched collection steps inside a single `jax.lax.scan`. Each scan
# iteration receives one row of `batch_size` keys, one per environment.
# (For step-by-step debugging, `collection_step` can instead be called in a
# Python loop with one row of keys at a time.)
collect_keys = jax.random.split(random_key, (1000, batch_size))
(env_state, evaluator_state, buff_state, observation), _ = jax.lax.scan(collection_step, state, collect_keys)

