PSUEDOCODE
* one iteration of selfplay collection takes N steps
* environments are reset when they terminate (or are truncated)
* trajectories are placed in batched replay memory buffer
* rewards are assigned to trajectories after episode is completed

* once a selfplay collection iteration is completed, T training steps are taken
* a training step involves gathering a mini-batch of size M trajectories from non-truncated, terminated episodes in the replay memory buffer
* a trajectory includes metadata necessary to train a model
    * in the case of AZ, this include action visit counts, and final episode reward
* compare model output to metadata, compute loss, SGD, etc

* C collection steps makes up one training epoch
* do whatever to evaluate


def train():
    for _ in range()

In [1]:
from flax import struct
import jax.numpy as jnp
import jax
from functools import partial

In [2]:
@struct.dataclass
class Trajectory:
    obs: jnp.ndarray
    reward: jnp.ndarray

In [7]:
class Experience(struct.PyTreeNode):
    obs: jnp.ndarray

@struct.dataclass
class EndRewardReplayBufferState:
    next_index: jnp.ndarray # next index to write experience to
    next_reward_index: jnp.ndarray # next index to write reward to
    buffer: struct.PyTreeNode # buffer of experiences
    reward_buffer: jnp.ndarray # buffer of rewards
    needs_reward: jnp.ndarray # does experience need reward
    populated: jnp.ndarray # is experience populated

class EndRewardReplayBuffer:
    def __init__(self,
        template_experience: struct.PyTreeNode,
        batch_size: int,
        max_len_per_batch: int,
        sample_batch_size: int,
    ):
        self.sample_batch_size = sample_batch_size
        self.max_len_per_batch = max_len_per_batch
        self.batch_size = batch_size

        experience = jax.tree_map(
            lambda x: jnp.broadcast_to(
                x, (batch_size, max_len_per_batch, *x.shape)
            ),
            template_experience,
        )

        self.state = EndRewardReplayBufferState(
            next_index=jnp.zeros((batch_size,), dtype=jnp.int32),
            next_reward_index=jnp.zeros((batch_size,), dtype=jnp.int32),
            reward_buffer=jnp.zeros((batch_size, max_len_per_batch, 1), dtype=jnp.float32),
            buffer=experience,
            needs_reward=jnp.zeros((batch_size, max_len_per_batch, 1), dtype=jnp.bool_),
            populated=jnp.zeros((batch_size, max_len_per_batch, 1), dtype=jnp.bool_),
        )

    def add_experience(self, experience: struct.PyTreeNode) -> None:
        self.state = add_experience(self.state, experience, self.batch_size, self.max_len_per_batch)    

    def assign_rewards(self, rewards: jnp.ndarray, select_batch: jnp.ndarray) -> None:
        self.state = assign_rewards(self.state, rewards, select_batch, self.max_len_per_batch)

    def sample(self, rng: jax.random.PRNGKey) -> struct.PyTreeNode:
        return sample(self.state, rng, self.batch_size, self.max_len_per_batch, self.sample_batch_size)




@partial(jax.jit, static_argnums=(2,3))
def add_experience(
    buffer_state: EndRewardReplayBufferState,
    experience: struct.PyTreeNode,
    batch_size: int,
    max_len_per_batch: int,
) -> EndRewardReplayBufferState:
    
    def add_item(items, new_item):
        return items.at[jnp.arange(batch_size), buffer_state.next_index].set(new_item)

    return buffer_state.replace(
        buffer = jax.tree_map(add_item, buffer_state.buffer, experience),
        next_index = (buffer_state.next_index + 1) % max_len_per_batch,
        needs_reward = buffer_state.needs_reward.at[:, buffer_state.next_index, 0].set(True),
        populated = buffer_state.populated.at[:, buffer_state.next_index, 0].set(True)
    )

@partial(jax.jit, static_argnums=(3,))
def assign_rewards(
    buffer_state: EndRewardReplayBufferState,
    rewards: jnp.ndarray,
    select_batch: jnp.ndarray,
    max_len_per_batch: int
) -> EndRewardReplayBufferState:
    rolled = jax.vmap(jnp.roll, in_axes=(0, 0))(rewards, buffer_state.next_reward_index)
    tiled = jnp.tile(rolled, (1, max_len_per_batch // rewards.shape[1]))
    
    return buffer_state.replace(
        reward_buffer = jnp.where(
            select_batch[..., None, None] & buffer_state.needs_reward,
            tiled[..., None],
            buffer_state.reward_buffer
        ),
        next_reward_index = jnp.where(
            select_batch,
            buffer_state.next_index,
            buffer_state.next_reward_index
        ),
        needs_reward = jnp.where(
            select_batch[..., None, None],
            False,
            buffer_state.needs_reward
        )
    )

@partial(jax.jit, static_argnums=(2,3,4))
def sample(
    buffer_state: EndRewardReplayBufferState,
    rng: jax.random.PRNGKey,
    batch_size: int,
    max_len_per_batch: int,
    sample_batch_size: int
) -> struct.PyTreeNode:
    probs = ((~buffer_state.needs_reward).reshape(-1) * buffer_state.populated.reshape(-1)).astype(jnp.float32)
    indices = jax.random.choice(
        rng,
        max_len_per_batch * batch_size,
        shape=(sample_batch_size,),
        replace=False,
        p = probs / probs.sum()
    )
    batch_indices = indices // max_len_per_batch
    item_indices = indices % max_len_per_batch

    return jax.tree_util.tree_map(
        lambda x: x[batch_indices, item_indices],
        buffer_state.buffer
    ), buffer_state.reward_buffer[batch_indices, item_indices]

In [8]:
def test(rng, batch_size, max_len_per_batch, sample_batch_size):

    buff = EndRewardReplayBuffer(
        template_experience=Experience(obs=jnp.array([0, 0])),
        batch_size=batch_size,
        max_len_per_batch=max_len_per_batch,
        sample_batch_size=sample_batch_size
    )
    
    for j in range(11):
        buff.add_experience(
            Experience(obs=jnp.stack([jnp.array([j, i]) for i in range(4)]))
        )


    buff.assign_rewards(
        jnp.array([[1,0], [0.5,0.5], [0,1], [0.5, 0.5]]).reshape(-1, 2),
        jnp.array([True, True, False, True])
    )

    for j in range(11):
        buff.add_experience(
            Experience(obs=jnp.stack([jnp.array([j+11, i]) for i in range(4)]))
        )

    buff.assign_rewards(
        jnp.array([[1,0], [1,0], [1,0], [1, 0]]).reshape(-1, 2),
        jnp.array([True, True, True, True])
    )

    return buff.sample(rng)

In [13]:
test(jax.random.PRNGKey(1), 4, 30, 10)

(Experience(obs=Array([[12,  3],
        [15,  0],
        [10,  3],
        [ 9,  0],
        [16,  2],
        [ 7,  2],
        [17,  0],
        [ 2,  1],
        [13,  3],
        [15,  2]], dtype=int32)),
 Array([[0. ],
        [1. ],
        [0.5],
        [0. ],
        [1. ],
        [0. ],
        [1. ],
        [0.5],
        [1. ],
        [0. ]], dtype=float32))

In [14]:
test_jit = jax.jit(test, static_argnums=(1, 2, 3))

In [20]:
sample, state = test_jit(jax.random.PRNGKey(7), 4, 30, 10)
sample

Experience(obs=Array([[16,  2],
       [19,  2],
       [11,  2],
       [21,  0],
       [12,  2],
       [ 6,  3],
       [ 6,  2],
       [12,  0],
       [14,  1],
       [16,  0]], dtype=int32))

In [45]:
sample(buff_state, jax.random.PRNGKey(1))

({'obs': Array([[2, 3],
         [1, 1],
         [5, 2],
         [1, 1],
         [6, 2],
         [1, 2],
         [1, 2],
         [2, 0],
         [7, 2],
         [9, 3]], dtype=int32),
  'reward': Array([[8],
         [0],
         [7],
         [2],
         [7],
         [7],
         [7],
         [5],
         [7],
         [8]], dtype=int32)},
 Array([3, 1, 2, 1, 2, 2, 2, 0, 2, 3], dtype=int32),
 Array([12, 11, 15,  1,  6, 11,  1, 12,  7, 19], dtype=int32))

In [35]:
@struct.dataclass
class MCTS_State(abc.ABC):
    env_state: State
    action_map: jnp.ndarray
    p_vals: jnp.ndarray
    n_vals: jnp.ndarray
    w_vals: jnp.ndarray
    actions_taken: jnp.ndarray
    visits: jnp.ndarray
    next_empty: jnp.ndarray
    cur_node: jnp.ndarray
    depth: jnp.ndarray
    subtrees: jnp.ndarray
    parents: jnp.ndarray
    rng: PRNGKeyArray

def init_state(
    keys: PRNGKeyArray
):
    max_nodes = 100
    policy_size = 10
    total_slots = 2 + max_nodes
    std_shape = (total_slots, policy_size)

    visits = jnp.zeros(max_nodes + 1, dtype=jnp.int32)
    visits.at[0].set(1)

    state = MCTS_State(
        action_map=jnp.zeros(std_shape, dtype=jnp.int32),
        p_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        n_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        w_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        actions_taken=jnp.zeros(max_nodes + 1, dtype=jnp.int32),
        visits=visits,
        next_empty=jnp.full(1, 2, dtype=jnp.int32),
        cur_node=jnp.ones(1, dtype=jnp.int32),
        depth=jnp.zeros(1, dtype=jnp.int32),
        subtrees=jnp.zeros(1, dtype=jnp.int32),
        parents=jnp.zeros(1, dtype=jnp.int32),
        rng=keys
    )

    return state

class MCTS_Evaluator:
    def __init__(self, env):
        self.state = None
        self.env = env

    def reset(self, keys):
        return init_state(keys), self.env.reset(keys)
        


In [38]:
keys = jax.random.split(jax.random.PRNGKey(0), 5)
evaluator = MCTS_Evaluator(env)

mcts, (state, timestep) = jax.vmap(evaluator.reset)(keys)

In [None]:
from flax import linen as nn

In [None]:
class ResiudalBlock(nn.Module):
    in_channels: int
    out_channels: int
    kernel_size: int = 3
    stride: int = 1
    def __call__(self, x):
        xt = nn.Sequential([
            nn.Conv(self.out_channels, self.kernel_size, self.stride),
        ])


class AZResnet(nn.Module):
    def __call__(self, x):
        dtype = jnp.float32
        x = x.astype(dtype)
        


    

In [75]:
jax.lax.scan(lambda x, y: x + y, jnp.arange(10), jnp.arange(10))

TypeError: scan body output must be a pair, got ShapedArray(int32[10]).