PSEUDOCODE
* one iteration of selfplay collection takes N steps
* environments are reset when they terminate (or are truncated)
* trajectories are placed in batched replay memory buffer
* rewards are assigned to trajectories after episode is completed

* once a selfplay collection iteration is completed, T training steps are taken
* a training step involves gathering a mini-batch of size M trajectories from non-truncated, terminated episodes in the replay memory buffer
* a trajectory includes metadata necessary to train a model
    * in the case of AZ, this includes action visit counts and the final episode reward
* compare model output to metadata, compute loss, SGD, etc

* C collection steps make up one training epoch
* do whatever to evaluate


def train():
    for _ in range()

In [2]:
from flax import struct
import jax.numpy as jnp
import jax
from functools import partial

In [3]:
class Experience(struct.PyTreeNode):
    """Minimal experience record used to exercise the replay buffer."""
    # Observation payload; the test cell below stores length-2 integer
    # vectors here (template value is jnp.array([0, 0])).
    obs: jnp.ndarray

In [4]:
from core_jax.utils.replay_memory import EndRewardReplayBuffer

In [5]:
def test(rng, batch_size, max_len_per_batch, sample_batch_size):
    """Drive an EndRewardReplayBuffer through a fixed add/reward/truncate script.

    The script pushes 11 experiences, assigns rewards under a mask, pushes 11
    more, assigns rewards again, truncates under a mask and finally samples.
    Every pushed experience has a leading dimension of 4 regardless of
    ``batch_size``.

    Args:
        rng: forwarded unchanged to ``buff.sample``.
        batch_size: forwarded to the buffer constructor.
        max_len_per_batch: forwarded to the buffer constructor.
        sample_batch_size: forwarded to the buffer constructor.

    Returns:
        A ``(sample, buff)`` tuple: the result of ``buff.sample(rng)`` and the
        buffer object itself.
    """
    template = Experience(obs=jnp.array([0, 0]))
    buff = EndRewardReplayBuffer(
        template_experience=template,
        batch_size=batch_size,
        max_len_per_batch=max_len_per_batch,
        sample_batch_size=sample_batch_size,
    )

    def push_steps(offset):
        # Each experience is a (4, 2) integer array whose rows are
        # [step_index, env_index].
        for step in range(offset, offset + 11):
            rows = [[step, env_idx] for env_idx in range(4)]
            buff.add_experience(Experience(obs=jnp.array(rows)))

    # First round: steps 0..10, then rewards with the third entry masked out.
    push_steps(0)
    first_rewards = jnp.array(
        [[1, 0], [0.5, 0.5], [0, 1], [0.5, 0.5]]
    ).reshape(-1, 2)
    first_mask = jnp.array([True, True, False, True])
    buff.assign_rewards(first_rewards, first_mask)

    # Second round: steps 11..21, then rewards with the second entry masked out.
    push_steps(11)
    second_rewards = jnp.array(
        [[1, 0], [1, 0], [1, 0], [1, 0]]
    ).reshape(-1, 2)
    second_mask = jnp.array([True, False, True, True])
    buff.assign_rewards(second_rewards, second_mask)

    truncate_mask = jnp.array([False, True, True, False])
    buff.truncate(truncate_mask)

    sampled = buff.sample(rng)
    return sampled, buff

In [7]:
# import jumanji

from core_jax.envs.jumanji import make_jumanji_env

# Smoke test: roll a batch of Minesweeper environments forward with random
# legal actions and push each resulting observation into the replay buffer.
num_envs = 10
env = make_jumanji_env('Minesweeper-v0')

key = jax.random.PRNGKey(0)
keys = jax.random.split(key, num_envs)
state, observation, reward, terminated = jax.jit(jax.vmap(env.reset))(keys)

buff = EndRewardReplayBuffer(
    batch_size=num_envs,
    max_len_per_batch=100,
    sample_batch_size=10
)

# The template is a single (unbatched) observation: drop the leading batch
# axis of every leaf. `jax.tree_util.tree_map` is used because the top-level
# `jax.tree_map` alias is deprecated and absent from newer JAX releases.
buff_state = buff.init(
    template_experience=jax.tree_util.tree_map(
        lambda x: jnp.zeros(x.shape[1:], x.dtype), observation
    )
)

# Build the batched callables once. Wrapping `env.step` in a fresh
# jit(vmap(...)) on every step creates a new function object each time and
# therefore forces a retrace.
batched_random_action = jax.vmap(env.get_random_legal_action)
batched_step = jax.jit(jax.vmap(env.step))

for _ in range(3):
    actions = batched_random_action(state, observation)
    state, observation, reward, terminated = batched_step(state, actions)
    buff_state = buff.add_experience(
        buff_state,
        observation
    )

# Trailing expression so the notebook displays the last observation.
observation


Observation(board=Array([[[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [ 1,  1,  2, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1]],

       [[-1, -1, -1, -1, -1, -1, -1, -1, -1,  0],
        [ 0,  2, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [-1, -1, -1, -1, -1, -

In [23]:
# Ask the wrapped environment's action spec for a sample value, to see what
# an action looks like (the cell output below is a length-2 int32 array).
env._env.action_spec().generate_value()


Array([0, 0], dtype=int32)

# Design sketch of the training setup; `...` marks placeholders that are not
# implemented yet.

# load model parameters
model = ...
optimizer = ...
model_state = ...
optimizer_state = ...


# init buff, env, evaluator
buff = ...
env = ...
evaluator = ...

# reset buff, env, evaluator
# NOTE(review): `keys2` is not defined in the visible notebook — presumably a
# second batch of PRNG keys distinct from `keys`; confirm before running.
env_state, timestep = jax.vmap(env.reset)(keys)
evaluator_state = jax.vmap(evaluator.reset)(keys2)
buff_state = buff.reset()

# combine states into single object
# Layout matches what collection_step unpacks:
# (env_state, evaluator_state, buff_state, model_state, prev_obs).
state = (env_state, evaluator_state, buff_state, model_state, timestep.observation)

# define collection step
def collection_step(state, keys):
    """Run one batched self-play step and record it in the replay buffer.

    Written as a ``jax.lax.scan`` body: it takes the carry plus one slice of
    the scanned input and returns the new carry plus a per-step output.

    Args:
        state: carry tuple ``(env_state, evaluator_state, buff_state,
            model_state, prev_obs)``.
        keys: PRNG keys for this step, batched along axis 0 (they are mapped
            over with ``jax.vmap``).

    Returns:
        ``(new_state, None)``. ``new_state`` has the same five-element layout
        as ``state``; ``jax.lax.scan`` requires the carry structure to be
        identical on input and output.
    """
    env_state, evaluator_state, buff_state, model_state, prev_obs = state

    # Derive independent keys for the three random operations below instead
    # of feeding the same key to each of them.
    split_keys = jax.vmap(lambda k: jax.random.split(k, 3))(keys)
    eval_keys = split_keys[:, 0]
    action_keys = split_keys[:, 1]
    step_keys = split_keys[:, 2]

    # `jax.vmap` has no `static_argnums`; the unbatched arguments
    # (`model_state`, `env`) are closed over instead of being passed through.
    def evaluate_one(single_evaluator_state, single_env_state, key):
        return evaluator.evaluate(
            single_evaluator_state, single_env_state, model_state, env, key
        )

    evaluator_state, policy, evaluation = jax.vmap(evaluate_one)(
        evaluator_state, env_state, eval_keys
    )
    action = jax.vmap(evaluator.choose_action)(policy, action_keys)
    env_state, timestep = jax.vmap(env.step)(env_state, action, step_keys)

    # NOTE(review): the sketch was unsure where the termination flag lives;
    # `env_state.terminated` is used here to agree with the reward assignment
    # further down — confirm against the environment API.
    evaluator_state = jax.vmap(evaluator.step)(
        evaluator_state, action, env_state.terminated
    )

    # Store the observation the policy was computed from, not the one
    # produced by this step.
    buff_state = buff.add_experience(
        buff_state,
        Experience(
            obs=prev_obs,
            policy=policy,
            evaluation=evaluation,
        ),
    )

    rewards = env.get_rewards(env_state)
    buff_state = buff.assign_rewards(buff_state, env_state.terminated, rewards)
    buff_state = buff.truncate(buff_state, env_state.truncated)

    # NOTE(review): presumably these only reset finished episodes — confirm,
    # and reconcile with the key-based `reset(keys)` signatures used elsewhere.
    evaluator_state = evaluator.reset(evaluator_state, env_state)
    env_state = env.reset(env_state)

    new_state = (
        env_state,
        evaluator_state,
        buff_state,
        model_state,
        timestep.observation,
    )
    return new_state, None

# make n collection steps
# Scan over the freshly generated per-step keys; the original passed `keys`
# (the reset keys) and never used `more_keys`.
more_keys = make_more_keys(n_timesteps)
state, _ = jax.lax.scan(collection_step, state, more_keys)

# define training step
evaluator_state.model_state

def training_step(state, keys):
    


In [71]:
# Shape of the `board` field stored in the replay buffer (see output below).
buff.state.buffer.board.shape

(10, 100, 10, 10)

In [72]:
# Shape of the batched environment `board` (see output below), for comparison
# with the buffer's stored shape.
state.board.shape

(10, 10, 10)

In [11]:
# Inspect the `needs_reward` flags at index 1 of the buffer state
# (presumably the second batch slot — TODO confirm).
buff.state.needs_reward[1]

Array([[False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False]], dtype=bool)

In [35]:
@struct.dataclass
class MCTS_State(abc.ABC):
    """Array-backed state of an MCTS search tree.

    Shapes and dtypes noted on the fields are those produced by `init_state`
    in this file, where ``total_slots = 2 + max_nodes``.
    """
    # Environment state attached to the tree; `State` is declared elsewhere.
    env_state: State
    # int32, (total_slots, policy_size). Presumably maps (node, action) to a
    # child slot — TODO confirm.
    action_map: jnp.ndarray
    # float32, (total_slots, policy_size). Presumably per-action prior
    # probabilities — TODO confirm.
    p_vals: jnp.ndarray
    # float32, (total_slots, policy_size). Presumably per-action visit
    # counts — TODO confirm.
    n_vals: jnp.ndarray
    # float32, (total_slots, policy_size). Presumably per-action accumulated
    # values — TODO confirm.
    w_vals: jnp.ndarray
    # int32, (max_nodes + 1,).
    actions_taken: jnp.ndarray
    # int32, (max_nodes + 1,).
    visits: jnp.ndarray
    # int32, shape (1,), initialised to 2.
    next_empty: jnp.ndarray
    # int32, shape (1,), initialised to 1.
    cur_node: jnp.ndarray
    # int32, shape (1,), initialised to 0.
    depth: jnp.ndarray
    # int32, shape (1,), initialised to 0.
    subtrees: jnp.ndarray
    # int32, shape (1,), initialised to 0.
    parents: jnp.ndarray
    # PRNG key(s) carried with the tree.
    rng: PRNGKeyArray

def init_state(
    keys: PRNGKeyArray,
    max_nodes: int = 100,
    policy_size: int = 10,
    env_state: "State | None" = None,
):
    """Build a freshly initialised MCTS_State.

    Args:
        keys: PRNG key(s) stored on the returned state as ``rng``.
        max_nodes: capacity used to size the tree arrays. Defaults to 100,
            the value that was previously hard-coded.
        policy_size: size of the per-node action axis. Defaults to 10, the
            value that was previously hard-coded.
        env_state: value for the ``env_state`` field. Defaults to ``None`` so
            existing ``init_state(keys)`` calls keep working; MCTS_State
            declares the field without a default, so omitting it from the
            constructor call raises TypeError.

    Returns:
        An MCTS_State with zeroed tree arrays, ``visits[0] == 1``,
        ``next_empty == 2`` and ``cur_node == 1``.
    """
    total_slots = 2 + max_nodes
    std_shape = (total_slots, policy_size)

    # JAX arrays are immutable: `.at[...].set(...)` returns a new array, so
    # the result has to be kept (the original discarded it, leaving visits
    # all zero).
    visits = jnp.zeros(max_nodes + 1, dtype=jnp.int32).at[0].set(1)

    return MCTS_State(
        env_state=env_state,
        action_map=jnp.zeros(std_shape, dtype=jnp.int32),
        p_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        n_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        w_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        actions_taken=jnp.zeros(max_nodes + 1, dtype=jnp.int32),
        visits=visits,
        next_empty=jnp.full(1, 2, dtype=jnp.int32),
        cur_node=jnp.ones(1, dtype=jnp.int32),
        depth=jnp.zeros(1, dtype=jnp.int32),
        subtrees=jnp.zeros(1, dtype=jnp.int32),
        parents=jnp.zeros(1, dtype=jnp.int32),
        rng=keys,
    )

class MCTS_Evaluator:
    """Pairs an environment with the MCTS state used to evaluate it."""

    def __init__(self, env):
        """Store the environment; no search state is held until one is set."""
        self.env = env
        self.state = None

    def reset(self, keys):
        """Create a fresh search state and reset the environment.

        Args:
            keys: PRNG key(s) passed to both ``init_state`` and
                ``self.env.reset``.

        Returns:
            A ``(mcts_state, env_reset_result)`` tuple.
        """
        fresh_tree = init_state(keys)
        env_reset_result = self.env.reset(keys)
        return fresh_tree, env_reset_result
        
