In [18]:
from dataclasses import dataclass
import jax.numpy as jnp
import numpy as np
import jax
import torch
from functools import partial
import abc
import jumanji
from flax import struct
from jumanji.env import State
import jax.random as random
import orbax.checkpoint

import optax

from typing import Optional, Any
import shutil

from flax.training import checkpoints, train_state
from flax import struct, serialization
import orbax.checkpoint

import optax as opx
import flashbax as fbx
import chex
from typing import TypeVar

In [19]:
# Generic placeholder for one experience record; bound to chex.ArrayTree,
# i.e. any pytree of arrays (the demo below uses a dict of arrays).
Experience = TypeVar("Experience", bound=chex.ArrayTree)

@struct.dataclass
class BufferState:
    """Immutable state of a batched experience buffer with deferred rewards.

    Instances are created by ``init_buffer_state`` and updated functionally
    (via ``.replace``) by ``add_experience`` and ``assign_rewards``.
    """
    # Mask of slots still waiting for a reward; init_buffer_state allocates it
    # as a bool array of shape (batch_size, max_len_per_batch, 1).
    needs_reward: jnp.ndarray
    # Stored experience pytree; init_buffer_state gives every leaf two leading
    # axes (batch_size, max_len_per_batch) in front of the per-item shape.
    buffer: Experience
    # Position along the second axis that the next add_experience call writes
    # to; add_experience wraps it modulo max_len_per_batch.
    next_index: int
    # Size of the leading (batch) axis of every buffer leaf.
    batch_size: int
    # Size of the second axis of every buffer leaf.
    max_len_per_batch: int

In [20]:
# Adapted from flashbax.
def init_buffer_state(
    experience: Experience,
    batch_size: int,
    max_len_per_batch: int
) -> BufferState:
    """Allocate an empty buffer shaped after a single example experience.

    Args:
        experience: Example pytree for ONE item (no batch or time axis). Only
            the shape and dtype of each leaf are used, not its values.
        batch_size: Size of the leading axis added to every leaf.
        max_len_per_batch: Size of the second axis added to every leaf.

    Returns:
        A ``BufferState`` whose leaves have shape
        ``(batch_size, max_len_per_batch, *leaf.shape)``, with ``next_index``
        set to 0 and ``needs_reward`` all False.
    """

    def allocate(leaf):
        # Keep only shape/dtype of the example, then prepend the two storage axes.
        template = jnp.empty_like(leaf)
        return jnp.broadcast_to(
            template[None, None, ...],
            (batch_size, max_len_per_batch, *template.shape),
        )

    # jax.tree_util.tree_map is available in every JAX version; the top-level
    # alias jax.tree_map is deprecated and missing from newer releases.
    storage = jax.tree_util.tree_map(allocate, experience)

    return BufferState(
        buffer=storage,
        next_index=0,
        needs_reward=jnp.zeros((batch_size, max_len_per_batch, 1), dtype=jnp.bool_),
        max_len_per_batch=max_len_per_batch,
        batch_size=batch_size
    )

def add_experience(
    buffer_state: BufferState,
    experience: Experience
) -> BufferState:
    """Write one step of experience for every batch row at ``next_index``.

    Args:
        buffer_state: Current buffer state.
        experience: Pytree with the same structure as ``buffer_state.buffer``;
            each leaf is written into ``leaf[:, next_index]`` of the matching
            buffer leaf, so it must be broadcastable to that slice.

    Returns:
        A new state in which the slot at ``next_index`` holds ``experience``,
        that slot is flagged in ``needs_reward`` for every batch row, and
        ``next_index`` has advanced by one, wrapping at ``max_len_per_batch``.
    """
    slot = buffer_state.next_index

    def write_slot(stored, incoming):
        return stored.at[:, slot].set(incoming)

    # jax.tree_util.tree_map is available in every JAX version; the top-level
    # alias jax.tree_map is deprecated and missing from newer releases.
    updated_pytree = jax.tree_util.tree_map(
        write_slot, buffer_state.buffer, experience
    )

    # The freshly written slot has not received its reward yet.
    needs_reward = buffer_state.needs_reward.at[:, slot, 0].set(True)

    # Advance the write position, wrapping around at the end of the buffer.
    updated_next_index = (slot + 1) % buffer_state.max_len_per_batch

    return buffer_state.replace(
        buffer=updated_pytree,
        next_index=updated_next_index,
        needs_reward=needs_reward
    )

def assign_rewards(
    buffer_state: BufferState,
    rewards: jnp.ndarray,
    select_batch: jnp.ndarray,
    reward_field: str = "reward"
) -> BufferState:
    """Add ``rewards`` to every pending slot of the selected batch rows.

    Args:
        buffer_state: Current buffer state; ``buffer_state.buffer`` must be a
            mapping containing ``reward_field``.
        rewards: Reward per batch row, broadcastable against ``needs_reward``
            (the demo passes shape ``(batch_size, 1, 1)``).
        select_batch: 0/1 mask choosing which batch rows receive their reward,
            broadcastable against ``needs_reward``.
        reward_field: Key of the reward leaf inside the buffer mapping.

    Returns:
        A new state whose reward leaf has been incremented on slots that were
        pending in selected rows, and whose ``needs_reward`` flags are cleared
        for those rows.
    """
    pending = buffer_state.needs_reward
    increment = rewards * pending * select_batch

    # Build a new mapping instead of assigning into the existing one: the
    # input state shares that dict, so an in-place `+=` would silently change
    # the caller's "old" state as well.
    updated_buffer = {
        **buffer_state.buffer,
        reward_field: buffer_state.buffer[reward_field] + increment,
    }

    # Logical ops keep needs_reward boolean; `mask * (1 - select)` would
    # promote it to an integer dtype after the first call.
    selected = jnp.asarray(select_batch).astype(jnp.bool_)
    still_pending = jnp.logical_and(pending, jnp.logical_not(selected))

    return buffer_state.replace(
        buffer=updated_buffer,
        needs_reward=still_pending
    )


In [21]:
# Smoke test of the reward buffer: 4 parallel rows, 100 slots per row.
buff_state = init_buffer_state(
    {"obs": jnp.array([0, 0]), "reward": jnp.array([0])},
    batch_size=4,
    max_len_per_batch=100,
)

# Two rounds. Each round appends ten steps (obs = [step, row], reward = 0) and
# then hands out one reward per row, restricted to the rows picked by the mask.
for round_rewards, round_mask in (
    ([1, 2, 3, 4], [1, 1, 0, 1]),
    ([5, 6, 7, 8], [1, 0, 1, 1]),
):
    for j in range(10):
        step = {
            "obs": jnp.array([[j, i] for i in range(4)]),
            "reward": jnp.array([[0] for _ in range(4)]),
        }
        buff_state = add_experience(buff_state, step)

    buff_state = assign_rewards(
        buff_state,
        jnp.array(round_rewards).reshape(-1, 1, 1),
        jnp.array(round_mask).reshape(-1, 1, 1),
    )

In [35]:
@struct.dataclass
class MCTS_State(abc.ABC):
    """Container for the arrays of one Monte Carlo tree search.

    ``init_state`` shows how the fields are allocated; the per-field notes
    below describe those allocations. Meanings inferred from the names alone
    are marked as assumptions.
    """
    # Environment state attached to the search (jumanji State).
    env_state: State
    # int32, shape (total_slots, policy_size) in init_state.
    # presumably maps (node, action) -> child node index — TODO confirm
    action_map: jnp.ndarray
    # float32, shape (total_slots, policy_size) in init_state.
    # presumably prior probabilities / visit counts / accumulated values per
    # (node, action) — TODO confirm
    p_vals: jnp.ndarray
    n_vals: jnp.ndarray
    w_vals: jnp.ndarray
    # int32, shape (max_nodes + 1,) in init_state.
    actions_taken: jnp.ndarray
    # int32, shape (max_nodes + 1,) in init_state.
    visits: jnp.ndarray
    # int32, shape (1,); initialised to 2 in init_state.
    next_empty: jnp.ndarray
    # int32, shape (1,); initialised to 1 in init_state.
    cur_node: jnp.ndarray
    # int32, shape (1,); initialised to 0 in init_state.
    depth: jnp.ndarray
    # int32, shape (1,); initialised to 0 in init_state.
    subtrees: jnp.ndarray
    # int32, shape (1,); initialised to 0 in init_state.
    parents: jnp.ndarray
    # PRNG key carried with the search. Annotated with chex.PRNGKey because
    # `PRNGKeyArray` is not imported or defined anywhere in this notebook.
    rng: chex.PRNGKey

def init_state(
    keys: chex.PRNGKey,
    max_nodes: int = 100,
    policy_size: int = 10,
    env_state: Optional[State] = None,
):
    """Build a fresh ``MCTS_State``.

    Args:
        keys: PRNG key stored in the state's ``rng`` field.
        max_nodes: Node capacity; per-node tables get ``2 + max_nodes`` rows
            and the per-visit vectors get ``max_nodes + 1`` entries.
        policy_size: Number of columns (actions) in the per-node tables.
        env_state: Optional environment state stored in ``env_state``. The
            field is required by ``MCTS_State``, so ``None`` is passed when
            the caller keeps the environment state separately.

    Returns:
        A zero-initialised ``MCTS_State`` with ``visits[0] == 1``,
        ``cur_node == 1`` and ``next_empty == 2``.
    """
    # presumably row 0 is a null sentinel and row 1 the root — TODO confirm
    total_slots = 2 + max_nodes
    std_shape = (total_slots, policy_size)

    # JAX arrays are immutable: `.at[...].set(...)` returns a new array, so
    # the result has to be kept for the update to take effect.
    visits = jnp.zeros(max_nodes + 1, dtype=jnp.int32).at[0].set(1)

    state = MCTS_State(
        env_state=env_state,
        action_map=jnp.zeros(std_shape, dtype=jnp.int32),
        p_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        n_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        w_vals=jnp.zeros(std_shape, dtype=jnp.float32),
        actions_taken=jnp.zeros(max_nodes + 1, dtype=jnp.int32),
        visits=visits,
        next_empty=jnp.full(1, 2, dtype=jnp.int32),
        cur_node=jnp.ones(1, dtype=jnp.int32),
        depth=jnp.zeros(1, dtype=jnp.int32),
        subtrees=jnp.zeros(1, dtype=jnp.int32),
        parents=jnp.zeros(1, dtype=jnp.int32),
        rng=keys
    )

    return state

class MCTS_Evaluator:
    """Pairs an environment with freshly initialised MCTS search state."""

    def __init__(self, env):
        """Store the environment; ``env`` must provide ``reset(key)``."""
        self.state = None
        self.env = env

    def reset(self, keys):
        """Create a new search state and reset the environment.

        Args:
            keys: A single PRNG key (the driver cell vmaps this method over a
                batch of keys).

        Returns:
            ``(mcts_state, env_reset_output)`` where ``env_reset_output`` is
            whatever ``self.env.reset`` returns.
        """
        # Derive independent keys: reusing one key for both the search RNG and
        # the environment reset would correlate their random streams.
        search_key, env_key = jax.random.split(keys)
        return init_state(search_key), self.env.reset(env_key)
        


In [38]:
# Five independent PRNG keys derived from seed 0, one per parallel search.
keys = jax.random.split(jax.random.PRNGKey(0), num=5)

# NOTE(review): `env` is not defined in the visible cells — it must come from
# an earlier cell; confirm it exists on a fresh kernel.
evaluator = MCTS_Evaluator(env)

# Vectorise reset over the leading axis of `keys`.
batched_reset = jax.vmap(evaluator.reset)
mcts, (state, timestep) = batched_reset(keys)