# `parllel`

`parllel` is a modular, flexible framework for developing performant algorithms in Reinforcement Learning.

Rather than being a library of algorithm implementations, `parllel` provides primitive types that are useful for RL research, and then makes it easier to optimize algorithms for speed. It supports recurrent agents/algorithms, visual RL, multi-agent RL, and RL on graphs/pointclouds.

## Arrays

One of the most fundamental types in `parllel` is the `Array`. It's similar to a `numpy` array, but is intended for data storage rather than math operations.

```python
import numpy as np
from parllel import Array

array = Array(batch_shape=(5, 4), dtype=np.float32)  # use batch_shape instead of shape
array[:] = np.arange(4)
print(array)
```

To do math operations, we can get a view as an ndarray (this operation does not copy the data).

```python
ndarray = array.to_ndarray()
print(ndarray.sum(axis=-1))
```
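The "does not copy" claim is the same distinction NumPy draws between views and copies. A NumPy-only illustration (it does not use `parllel`; `base`, `view` and `copy` are just example names):

```python
import numpy as np

base = np.zeros((5, 4), dtype=np.float32)
base[:] = np.arange(4)

view = base[...]     # basic indexing returns a view that shares base's memory
copy = base.copy()   # an explicit copy owns separate memory

print(np.shares_memory(base, view))  # True
print(np.shares_memory(base, copy))  # False

view[0, 0] = 100.0                   # writing through the view also changes base
print(base[0, 0])
print(view.sum(axis=-1))             # row sums; the first row now includes the 100.0
```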

#### padding
In RL, we often need to save state between batches/iterations. Since this state is often associated with time (e.g. the next observation or the previous action), a convenient place to store it is in the array itself. For this, we use the `padding` argument.

```python
array = Array(batch_shape=(5, 4), dtype=np.float32, padding=1)
array[:] = np.arange(4)
array[5] = [4, 5, 6, 7]  # note that this appears to be out of bounds!
print(array)
print(array[5])
print(array[array.last + 1])
assert array.last + 1 == 5
```

Indexing with `array.last + 1` is just syntactic sugar that makes it clear we are writing beyond the end of the array.
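One way to picture `padding` is as extra hidden slots around the leading dimension of a larger buffer, with user indices shifted by `padding`. The `PaddedBuffer` class below is a hypothetical, NumPy-only sketch of that idea (integer indices only), not `parllel`'s implementation:

```python
import numpy as np


class PaddedBuffer:
    """Hypothetical sketch: a leading axis with `padding` hidden slots on each side."""

    def __init__(self, T, feature_shape, padding, dtype=np.float32):
        self.T = T
        self.padding = padding
        # storage holds padding + T + padding elements along the leading axis
        self._buf = np.zeros((T + 2 * padding, *feature_shape), dtype=dtype)

    @property
    def last(self):
        return self.T - 1

    def _offset(self, t):
        # user index t maps to storage index t + padding, so user indices
        # from -padding up to last + padding are valid
        if not -self.padding <= t <= self.last + self.padding:
            raise IndexError(t)
        return t + self.padding

    def __getitem__(self, t):
        return self._buf[self._offset(t)]

    def __setitem__(self, t, value):
        self._buf[self._offset(t)] = value

    def visible(self):
        # only the T "real" elements, returned as a view
        return self._buf[self.padding:self.padding + self.T]


buf = PaddedBuffer(T=5, feature_shape=(4,), padding=1)
buf.visible()[:] = np.arange(4)
buf[buf.last + 1] = [4, 5, 6, 7]  # lands in the trailing padding
print(buf.visible().shape)        # the padding is not part of the visible shape
print(buf[5])
```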

The values written into the padding are not "visible" to normal operations or when converting to a numpy array. If we want to access them in the next iteration, we can call `rotate()`.

```python
array[...] = 0
array.rotate()
print(array[0])  # [4, 5, 6, 7] has been copied to the 0th position in the array
```
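The behaviour of `rotate()` shown here (index `last + 1` moves to `0`, and, as the cell in the next section shows, index `last` moves to `-1`) can be reproduced on a raw NumPy buffer laid out as `[leading padding | visible | trailing padding]`. The `rotate` function below is a hypothetical sketch consistent with these cells, assuming equal padding on both ends:

```python
import numpy as np


def rotate(buf, padding):
    """Hypothetical: copy the last `padding` visible elements plus the trailing
    padding into the leading padding plus the first `padding` visible elements."""
    buf[:2 * padding] = buf[-2 * padding:]


T, padding = 5, 1
buf = np.zeros((T + 2 * padding, 4), dtype=np.float32)
visible = buf[padding:padding + T]  # view of the T visible elements

visible[:] = np.arange(4)
visible[-1] = [42, 7, 42, 7]        # NumPy's -1 here is user index `last`
buf[padding + T] = [4, 5, 6, 7]     # user index `last + 1`

rotate(buf, padding)
print(visible[0])        # old `last + 1` is now at user index 0
print(buf[padding - 1])  # old `last` is now at user index -1
```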

#### warning: -1 is not the last element!

One important difference between `Array` and `np.ndarray` is that negative indices are relative to the *beginning* of the array, not the end.

Index `-1` refers to the element just before index `0`, in the padding. After a `rotate()`, it holds the last element of the previous batch. To access the last element of the current batch, use `array.last`.

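```python
print(array[-1])  # the padding element before index 0, not the last element
array[4] = [42, 7, 42, 7]
array.rotate()
print(array[-1])  # the previous batch's last element, [42, 7, 42, 7]
```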
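The indexing difference can be made concrete with the same storage layout as a padded buffer: a user index `t` maps to storage index `t + padding`, so `-1` lands in the leading padding instead of wrapping around to the end. `padded_index` is a hypothetical helper for illustration:

```python
import numpy as np

T, padding = 5, 1
# storage rows are [0, 1], [2, 3], ...; row 0 is the leading padding
buf = np.arange((T + 2 * padding) * 2).reshape(T + 2 * padding, 2)
visible = buf[padding:padding + T]


def padded_index(t, padding):
    # hypothetical mapping: user index t -> storage index t + padding
    return t + padding


print(visible[-1])                      # NumPy semantics: -1 wraps to the last visible row
print(buf[padded_index(-1, padding)])   # padded semantics: the slot just before index 0
last = T - 1                            # the analogue of `array.last`
print(buf[padded_index(last, padding)]) # the last visible row, addressed explicitly
```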

#### full_size
For replay buffers, for example, we may want to allocate a large amount of memory while exposing only a small window for collecting samples from the environment. With each `rotate()`, this window slides along the replay buffer until the buffer is full. We do this using the `full_size` argument.

```python
array = Array(batch_shape=(5, 4), dtype=np.float32, full_size=10)  # full_size replaces the leading batch dimension (5) for allocation
array[...] = 7
array.rotate()
array[...] = 42
array.rotate()
print(array)
array.rotate()
print(array)
```

`padding` and `full_size` can be combined arbitrarily.
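A `full_size` buffer can be pictured as a large allocation viewed through a window that advances on every `rotate()`. The `SlidingWindow` class below is a hypothetical NumPy sketch; it assumes the window wraps back to the start once it reaches the end and that `full_size` is a multiple of the window length, which may differ from `parllel`'s actual rules:

```python
import numpy as np


class SlidingWindow:
    """Hypothetical sketch: `full_size` rows exposed through a window of
    `window` rows that advances on every rotate()."""

    def __init__(self, window, full_size, feature_shape, dtype=np.float32):
        if full_size % window != 0:
            raise ValueError("this sketch assumes full_size is a multiple of window")
        self.window = window
        self.full = np.zeros((full_size, *feature_shape), dtype=dtype)
        self.start = 0

    def view(self):
        # the currently visible window, as a view into the full buffer
        return self.full[self.start:self.start + self.window]

    def rotate(self):
        # advance the window, wrapping back to the start at the end of the buffer
        self.start = (self.start + self.window) % len(self.full)


buf = SlidingWindow(window=5, full_size=10, feature_shape=(4,))
buf.view()[...] = 7
buf.rotate()
buf.view()[...] = 42
buf.rotate()           # wraps around to the first window
print(buf.view()[0])   # the first window still holds the 7s
print(buf.full[:, 0])  # the whole buffer: five 7s followed by five 42s
```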

#### next & previous
RL is often concerned with comparing a value to its past (or future) values. One example is a replay buffer for SAC, where we need both the observation and the next observation to compute the loss for Q-learning. Because Arrays keep track of their indices, we can conveniently access these through the `next` and `previous` attributes.

```python
array = Array(batch_shape=(5, 4), dtype=np.float32, padding=1)
array[...] = np.arange(np.prod(array.shape)).reshape(array.shape)
array[-1] = np.arange(-4, 0)
print(array)
print("previous:\n", array.previous)
print("next:\n", array.next)
```

This also works with slices and elements of the array, not just with the entire array.

```python
print(array[2])
print("previous: ", array[2].previous)
print("next: ", array[2].next)
```
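With one element of padding on each side, `previous` and `next` can be expressed as the visible window shifted by one slot in either direction, which yields views rather than copies. A NumPy-only sketch under that assumption (`next_` avoids shadowing Python's builtin `next`; none of these names come from `parllel`):

```python
import numpy as np

T, padding = 5, 1
buf = np.zeros((T + 2 * padding, 4), dtype=np.float32)
visible = buf[padding:padding + T]
visible[...] = np.arange(T * 4).reshape(T, 4)
buf[padding - 1] = np.arange(-4, 0)  # user index -1, in the leading padding

# shift the window by one slot in each direction; both are views into buf
previous = buf[padding - 1:padding - 1 + T]
next_ = buf[padding + 1:padding + 1 + T]  # next_[T - 1] reads the trailing padding

print(np.shares_memory(previous, buf))
print(previous[0])        # the row stored in the leading padding
print(next_[2])           # equals visible[3]
# e.g. a Q-learning update would pair visible[t] (observation) with next_[t]
```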


#### storage
In RL, we often want to run several environments in parallel to collect samples faster. To avoid copies, these environments can write directly to Arrays allocated in shared memory.

```python
import multiprocessing as mp
from operator import setitem

array = Array(batch_shape=(5, 4), dtype=np.int32, storage="shared")
subarray = array[np.array([1, 3]), np.array([0, 2])]  # unlike for ndarrays, this does not produce a copy

p = mp.Process(target=setitem, args=(subarray, ..., 42))  # executes subarray[...] = 42 in another process
p.start()
p.join()

print(array)

# array.close()  # always close arrays allocated in shared memory
```
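The standard library offers the same underlying mechanism through `multiprocessing.shared_memory`: one handle creates a named block, another attaches to it by name, and NumPy arrays are wrapped around the shared buffer. In this sketch, which uses only the standard library and NumPy rather than `parllel`'s internals, a second handle in the same process stands in for a worker process. It also shows that in NumPy, assigning through fancy indices writes in place, whereas `sub = a[idx]` with fancy indices returns a copy:

```python
import numpy as np
from multiprocessing import shared_memory

shape, dtype = (5, 4), np.int32
nbytes = int(np.prod(shape)) * np.dtype(dtype).itemsize

# the owner allocates a named shared block and views it as an array
shm = shared_memory.SharedMemory(create=True, size=nbytes)
array = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
array[...] = 0

# a worker (normally another process) attaches to the same block by name
worker_shm = shared_memory.SharedMemory(name=shm.name)
worker_view = np.ndarray(shape, dtype=dtype, buffer=worker_shm.buf)
worker_view[[1, 3], [0, 2]] = 42  # fancy-index assignment writes in place

print(array)  # elements (1, 0) and (3, 2) are now 42 in the owner's view

result = array.copy()  # keep a private copy before releasing shared memory
# NumPy arrays must stop referencing the buffers before the handles can close
del worker_view, array
worker_shm.close()
shm.close()
shm.unlink()  # free the block once it is no longer needed
```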

## ArrayDict

Often in RL, it is useful to store data in a tree structure. This allows for uniform handling of data in many cases, even when the underlying structure is different. The ArrayDict is a simple, lightweight data structure that stores any array-like objects and simplifies handling them.

ArrayDict is inspired by TensorDict (and also by JAX pytrees), but is a lot more flexible.

```python
from parllel import ArrayDict

tree = ArrayDict({
    "observation": Array(batch_shape=(10, 5, 4), dtype=np.float32),
    "action": {  # in a multi-agent problem, action might be a dictionary of actions
        "pinky": Array(batch_shape=(10, 5), dtype=np.int64),
        "the_brain": Array(batch_shape=(10, 5, 2), dtype=np.float32),
    },
    "done": Array(batch_shape=(10, 5), dtype=bool),
})

tree[0, 1] = 42 * 10 * 5  # you can write to every leaf at once
print(tree[0, 1])  # you can index the tree
print()
print(tree.dtype)  # you can get attributes of every leaf
print()
print(tree.to_ndarray().mean(axis=(0, 1)))  # convert to np.ndarray and compute the mean across the batch dimensions
```
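Indexing and writing an `ArrayDict` amounts to applying the same index to every leaf. A hypothetical sketch over plain nested dicts of NumPy arrays (`tree_getitem` and `tree_setitem` are illustrative names, not `parllel` functions):

```python
import numpy as np


def tree_getitem(tree, index):
    # hypothetical: apply the same index to every leaf of a nested dict
    if isinstance(tree, dict):
        return {k: tree_getitem(v, index) for k, v in tree.items()}
    return tree[index]


def tree_setitem(tree, index, value):
    # hypothetical: write the same value into every leaf at the given index
    if isinstance(tree, dict):
        for v in tree.values():
            tree_setitem(v, index, value)
    else:
        tree[index] = value


tree = {
    "observation": np.zeros((10, 5, 4), dtype=np.float32),
    "action": {
        "pinky": np.zeros((10, 5), dtype=np.int64),
        "the_brain": np.zeros((10, 5, 2), dtype=np.float32),
    },
    "done": np.zeros((10, 5), dtype=bool),
}

tree_setitem(tree, (0, 1), 42 * 10 * 5)  # the bool leaf stores True
sample = tree_getitem(tree, (0, 1))
print(sample["action"]["the_brain"], sample["done"])
```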

The leaf nodes can be `Array`, `np.ndarray`, `torch.Tensor`, `jax.numpy.array`, etc.

ArrayDict has two methods that are not present in normal Python dictionaries: `to_ndarray` and `apply`.

`to_ndarray` converts all leaf nodes to `np.ndarray`. `apply` calls a function with each leaf node as an argument (as well as any other args and kwargs), and returns a new `ArrayDict` with the result. `map` is an alias for `apply`.

```python
import torch

tensor_tree = tree[0].to_ndarray().apply(torch.from_numpy)
print(tensor_tree)
print()
tensor_tree[1] = 0
tensor_tree[2] = 42 * 10 * 5
print(tree[0, 2])  # Arrays/ndarrays/tensors all share the same storage
```

One of the biggest advantages of `ArrayDict` is that arrays and array trees can be treated identically. In the following code, it does not matter whether `action` is a single array or an array tree (as in multi-agent reinforcement learning).

```python
action = tree.to_ndarray()["action"]

print(action.mean(), action.std())
```
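`apply`/`map` can be sketched as a recursive map over nested dicts in which a bare array counts as a tree with a single leaf; this is what lets the same code handle single-agent and multi-agent actions. `tree_map` and `summarize` are hypothetical helpers, not `parllel`'s implementation:

```python
import numpy as np


def tree_map(fn, tree, *args, **kwargs):
    # hypothetical: call fn on each leaf (plus extra args/kwargs) and rebuild
    # the same nesting; a bare array is treated as a single leaf
    if isinstance(tree, dict):
        return {k: tree_map(fn, v, *args, **kwargs) for k, v in tree.items()}
    return fn(tree, *args, **kwargs)


def summarize(action):
    # works unchanged whether `action` is one array or a dict of arrays
    return tree_map(np.mean, action), tree_map(np.std, action)


single = np.array([1.0, 2.0, 3.0])
multi = {"pinky": np.array([0, 2]), "the_brain": np.array([[1.0, 1.0], [3.0, 3.0]])}

single_stats = summarize(single)
multi_stats = summarize(multi)
print(single_stats)
print(multi_stats)

# basic indexing of a 2-D leaf returns a view, so writes reach the original
rows = tree_map(lambda x: x[0], multi)
rows["the_brain"][:] = -1.0
print(multi["the_brain"])
```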

## Putting it all together: Sampling

Primitives like `Array` and `ArrayDict` allow us to write concise, expressive code. For example, sampling (collecting rollouts from the policy) can be implemented as simply as:

```python
import numpy as np
import torch
from torch.nn import Linear
from gymnasium.envs.classic_control.cartpole import CartPoleEnv

from parllel import Array, ArrayDict, dict_map
from parllel.cages import SerialCage, TrajInfo
from parllel.torch.agents.categorical import CategoricalPgAgent, ModelOutputs, DistParams
from parllel.torch.distributions.categorical import Categorical

batch_T, batch_B = 10, 5

# create environments
envs = [SerialCage(CartPoleEnv, {}, TrajInfo) for _ in range(batch_B)]

# get example action and observation from environment step
envs[0].random_step_async()
action, observation, _, _, _, _ = envs[0].await_step()

# allocate Arrays to store samples based on examples
action = dict_map(Array.from_numpy, action, batch_shape=(batch_T, batch_B))
observation = dict_map(Array.from_numpy, observation, batch_shape=(batch_T, batch_B), padding=1)
reward = Array(batch_shape=(batch_T, batch_B), dtype=np.float32)
terminated = Array(batch_shape=(batch_T, batch_B), dtype=bool)
truncated = Array(batch_shape=(batch_T, batch_B), dtype=bool)
env_info = ArrayDict()

# define a model with the correct output for a Categorical distribution
class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.pi = Linear(4, 2)
        self.value = Linear(4, 1)

    def forward(self, observation):
        probs = self.pi(observation).softmax(dim=-1)
        value = self.value(observation).squeeze(-1)
        return ModelOutputs(dist_params=DistParams(probs=probs), value=value)

# instantiate an agent, which requires a model and a distribution
agent = CategoricalPgAgent(
    model=Model(),
    distribution=Categorical(dim=2),
    example_obs=observation[0],
)

# reset all environments and write reset observation to 0th position
for b, env in enumerate(envs):
    env.reset_async(
        out_obs=observation[0, b],
        out_info=env_info[0, b],
    )

for t in range(batch_T):

    # get new actions from agent
    action[t], _ = agent.step(observation[t])

    # start stepping every environment; outputs are written directly into the Arrays
    for b, env in enumerate(envs):
        env.step_async(
            action[t, b],
            out_obs=observation[t + 1, b],
            out_reward=reward[t, b],
            out_terminated=terminated[t, b],
            out_truncated=truncated[t, b],
            out_info=env_info[t, b],
        )

    # only after all steps have been started, wait for every environment to finish
    for env in envs:
        env.await_step()

# print results
print(ArrayDict({
    "action": action,
    "observation": observation,
    "reward": reward,
    "terminated": terminated,
    "truncated": truncated,
}))
```
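`step_async` and `await_step` are split so that every environment can start its step before any result is waited on. The same two-phase pattern can be sketched without `parllel` using a thread pool writing into preallocated NumPy buffers; `ToyEnv` is a hypothetical environment and random integers stand in for `agent.step`:

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor


class ToyEnv:
    """Hypothetical stand-in: the observation is a counter, the reward equals the action."""

    def __init__(self, seed):
        self.state = float(seed)

    def reset(self):
        return self.state

    def step(self, action):
        self.state += 1.0
        return self.state, float(action)


batch_T, batch_B = 10, 5
envs = [ToyEnv(seed=b) for b in range(batch_B)]

# preallocated storage; observation has one extra time step (like padding=1)
observation = np.zeros((batch_T + 1, batch_B), dtype=np.float32)
action = np.zeros((batch_T, batch_B), dtype=np.int64)
reward = np.zeros((batch_T, batch_B), dtype=np.float32)


def step_into(env, t, b):
    # each worker writes its results directly into the shared buffers
    observation[t + 1, b], reward[t, b] = env.step(action[t, b])


for b, env in enumerate(envs):
    observation[0, b] = env.reset()

rng = np.random.default_rng(0)
with ThreadPoolExecutor(max_workers=batch_B) as pool:
    for t in range(batch_T):
        action[t] = rng.integers(0, 2, size=batch_B)  # stand-in for agent.step
        # phase 1: dispatch a step to every environment
        futures = [pool.submit(step_into, env, t, b) for b, env in enumerate(envs)]
        # phase 2: wait for all of them before moving to the next time step
        for f in futures:
            f.result()

print(observation[:, 0])
```

Keeping the await loop outside the dispatch loop is what allows the steps to overlap; awaiting inside it would serialize the environments again.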

## Next Steps

In many cases, these common operations are already available as pre-built classes (for the above example, `BasicSampler` and `RecurrentSampler`). `parllel` is designed so that these components are as general and interchangeable as possible, allowing you to pick out the objects you need and combine them as desired. Each class is also written to be readable, so that researchers can subclass, override, and customize it as needed.

This is only a small sample of what is possible with parllel. To get a better idea, please take a look at the `examples` folder.