PSEUDOCODE
* one iteration of selfplay collection takes N steps
* environments are reset when they terminate (or are truncated)
* trajectories are placed in batched replay memory buffer
* rewards are assigned to trajectories after the episode is completed

* once a selfplay collection iteration is completed, T training steps are taken
* a training step samples a mini-batch of M trajectories from episodes in the replay memory buffer that terminated without being truncated
* a trajectory includes the metadata necessary to train a model
    * in the case of AZ (AlphaZero), this includes the action visit counts and the final episode reward
* compare the model outputs to the metadata, compute the loss, take an SGD step, etc.

* C collection steps make up one training epoch
* evaluate the model (evaluation method still to be decided)


def train():
    """Placeholder for the self-play training loop sketched in the pseudocode.

    Intended structure, per the pseudocode notes:
      * one self-play collection iteration runs N environment steps,
      * each collection iteration is followed by T training steps on
        mini-batches of M trajectories from terminated episodes,
      * C collection steps make up one training epoch, after which the
        model is evaluated.

    The original sketch (``for _ in range()``) was syntactically invalid.
    Raising explicitly keeps the cell runnable and makes accidental use
    fail loudly.

    Raises:
        NotImplementedError: always, until the loop is written.
    """
    raise NotImplementedError("train() is not implemented yet")

In [1]:
from flax import struct
import jax.numpy as jnp
import jax
from functools import partial

In [2]:
class Experience(struct.PyTreeNode):
    """Replay-buffer record holding a single observation array.

    Declared as a flax ``struct.PyTreeNode``, so instances are JAX pytrees.
    """
    # Observation array. This notebook uses a length-2 integer vector,
    # e.g. ``jnp.array([0, 0])`` is the buffer's template experience.
    obs: jnp.ndarray

In [3]:
from core_jax.utils.replay_memory import EndRewardReplayBuffer

In [4]:
def test(rng, batch_size, max_len_per_batch, sample_batch_size):
    """Smoke-test ``EndRewardReplayBuffer`` with a fixed 4-environment scenario.

    The scenario runs in this order:
      1. Add 11 steps for every environment.
      2. Assign rewards with a mask.
      3. Add 11 more steps.
      4. Assign rewards again.
      5. Truncate some environments.
      6. Sample a mini-batch.

    Each observation is ``[global_step, env_index]``, so sampled rows can be
    traced back to where they were written.

    Args:
        rng: PRNG key passed to ``buff.sample``.
        batch_size: number of parallel environments in the buffer. The
            hard-coded reward, mask and truncation fixtures below describe
            exactly 4 environments, so this must be 4.
        max_len_per_batch: per-environment capacity passed to the buffer.
        sample_batch_size: sample size passed to the buffer.

    Returns:
        ``(buff.sample(rng), buff)``.

    Raises:
        ValueError: if ``batch_size`` does not match the fixture size (4).
    """
    # Every fixture array below has exactly one row per environment.
    num_envs = 4
    steps_per_phase = 11
    if batch_size != num_envs:
        raise ValueError(
            f"test() fixtures describe {num_envs} environments, "
            f"got batch_size={batch_size}"
        )

    buff = EndRewardReplayBuffer(
        template_experience=Experience(obs=jnp.array([0, 0])),
        batch_size=batch_size,
        max_len_per_batch=max_len_per_batch,
        sample_batch_size=sample_batch_size,
    )

    def add_phase(first_step):
        # Add ``steps_per_phase`` consecutive steps; one obs row per env.
        for step in range(first_step, first_step + steps_per_phase):
            buff.add_experience(
                Experience(
                    obs=jnp.stack(
                        [jnp.array([step, env]) for env in range(num_envs)]
                    )
                )
            )

    # Phase 1: steps 0-10, then per-environment rewards with mask [T, T, F, T].
    add_phase(0)
    buff.assign_rewards(
        jnp.array([[1, 0], [0.5, 0.5], [0, 1], [0.5, 0.5]]),
        jnp.array([True, True, False, True]),
    )

    # Phase 2: steps 11-21, then reward [1, 0] for all with mask [T, F, T, T].
    add_phase(steps_per_phase)
    buff.assign_rewards(
        jnp.array([[1, 0], [1, 0], [1, 0], [1, 0]]),
        jnp.array([True, False, True, True]),
    )

    buff.truncate(jnp.array([False, True, True, False]))

    return buff.sample(rng), buff

In [15]:
# Run the replay-buffer smoke test: 4 environments, max_len_per_batch=30,
# sample_batch_size=10. The displayed value is the result of buff.sample(rng),
# which (per the cell output) is an (Experience, rewards) tuple.
sample, buff = test(jax.random.PRNGKey(6), 4, 30, 10)
sample

(Experience(obs=Array([[ 3,  1],
        [ 3,  2],
        [ 8,  1],
        [ 7,  0],
        [20,  3],
        [21,  2],
        [ 2,  2],
        [ 3,  0],
        [11,  3],
        [12,  0]], dtype=int32)),
 Array([[0.5],
        [0. ],
        [0.5],
        [0. ],
        [0. ],
        [0. ],
        [1. ],
        [0. ],
        [1. ],
        [0. ]], dtype=float32))

In [11]:
# Inspect the buffer's needs_reward flags for batch row 1 (one entry per
# stored slot; the output below has 30 rows, matching max_len_per_batch=30).
buff.state.needs_reward[1]

Array([[False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False],
       [False]], dtype=bool)

In [35]:
@struct.dataclass
class MCTS_State(abc.ABC):
    """Frozen flax dataclass holding one MCTS search tree and its env state.

    Instances are built by ``init_state``. The per-node tensors are
    presumably indexed by node slot (TODO confirm).

    NOTE(review): ``abc``, ``State`` and ``PRNGKeyArray`` are not imported in
    the visible cells. Confirm they are defined before this cell runs,
    otherwise the class definition raises ``NameError``.
    """
    # Environment state paired with this tree (``State`` defined elsewhere).
    env_state: State
    # int32, shape (2 + max_nodes, policy_size) as built by init_state;
    # presumably maps (node, action) -> child slot -- confirm.
    action_map: jnp.ndarray
    # float32 per-(node, action) tensors, same shape as action_map;
    # presumably AlphaZero-style prior P, visit count N, total value W.
    p_vals: jnp.ndarray
    n_vals: jnp.ndarray
    w_vals: jnp.ndarray
    # int32, length max_nodes + 1 as built by init_state.
    actions_taken: jnp.ndarray
    visits: jnp.ndarray
    # int32 shape-(1,) scalars; init_state starts next_empty at 2,
    # cur_node at 1 and the remaining three at 0.
    next_empty: jnp.ndarray
    cur_node: jnp.ndarray
    depth: jnp.ndarray
    subtrees: jnp.ndarray
    parents: jnp.ndarray
    # PRNG key(s) passed to init_state.
    rng: PRNGKeyArray

def init_state(
    keys: PRNGKeyArray,
    *,
    env_state=None,
    max_nodes: int = 100,
    policy_size: int = 10,
):
    """Create an empty ``MCTS_State`` search tree.

    Args:
        keys: PRNG key(s), stored unchanged in the ``rng`` field.
        env_state: value for the ``env_state`` field. ``MCTS_State`` declares
            that field without a default, so it must always be supplied when
            constructing the dataclass. It defaults to ``None`` so that
            ``init_state(keys)`` still builds a complete state.
        max_nodes: sizes the tree. Per-action tensors get ``2 + max_nodes``
            rows; ``visits`` and ``actions_taken`` get ``max_nodes + 1``
            entries. The default matches the previous hard-coded value.
        policy_size: width (number of actions) of the per-action tensors.
            The default matches the previous hard-coded value.

    Returns:
        A freshly initialised ``MCTS_State``.
    """
    total_slots = 2 + max_nodes
    std_shape = (total_slots, policy_size)

    # JAX arrays are immutable: ``.at[...].set`` returns a new array, so its
    # result must be kept. The original code discarded it, leaving every
    # visit count at 0.
    # NOTE(review): cur_node starts at 1, yet slot 0 is the one marked as
    # visited here. Confirm which slot is the root.
    visits = jnp.zeros(max_nodes + 1, dtype=jnp.int32).at[0].set(1)

    # NOTE(review): per-action tensors have ``2 + max_nodes`` rows, but
    # ``visits`` / ``actions_taken`` have only ``max_nodes + 1`` entries,
    # while node allocation starts at slot 2 (next_empty). Confirm lengths.
    return MCTS_State(
        env_state=env_state,
        action_map=jnp.zeros(std_shape, dtype=jnp.int32),
        p_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        n_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        w_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        actions_taken=jnp.zeros(max_nodes + 1, dtype=jnp.int32),
        visits=visits,
        next_empty=jnp.full(1, 2, dtype=jnp.int32),
        cur_node=jnp.ones(1, dtype=jnp.int32),
        depth=jnp.zeros(1, dtype=jnp.int32),
        subtrees=jnp.zeros(1, dtype=jnp.int32),
        parents=jnp.zeros(1, dtype=jnp.int32),
        rng=keys,
    )

class MCTS_Evaluator:
    """Pairs an environment with MCTS search-tree initialisation."""

    def __init__(self, env):
        # Store the environment; ``state`` starts out unset.
        self.env = env
        self.state = None

    def reset(self, keys):
        """Build a fresh search tree and reset the environment.

        Args:
            keys: PRNG key(s), forwarded to both ``init_state`` and
                ``self.env.reset``.

        Returns:
            Tuple ``(init_state(keys), self.env.reset(keys))``.
            ``self.state`` is left untouched.
        """
        # The tree is built first, then the environment is reset.
        fresh_tree = init_state(keys)
        env_after_reset = self.env.reset(keys)
        return fresh_tree, env_after_reset
        
