These are my notes on the [Data Collection and Storage tutorial](https://pytorch.org/rl/stable/tutorials/getting-started-3.html) in TorchRL, plus some notes at the end about how replay buffers currently interact with `torch.compile`.

## Data collectors

TorchRL has the concept of a data collector, which has the following responsibilities:

* Execute whatever policy you give it within any environment you give it.
* Reset the environment when necessary, e.g. upon termination. Unlike `rollout()`, a collector does not automatically reset the environment between consecutive batches of data.
* Collect batches of data of a predefined size.
* Record, at each step, the action generated by the policy and the state of the environment.

A collector needs to be given the arguments:

* The size of each batch to collect, called `frames_per_batch`.
* The total number of frames to iterate over, called `total_frames`. If `total_frames=-1`, then the number of frames is infinite.
* A policy.
* An environment.

In [5]:
import torch
import torchrl

torch.manual_seed(0)
env = torchrl.envs.GymEnv("CartPole-v1")
env.set_seed(0)
policy = torchrl.envs.utils.RandomPolicy(env.action_spec)
collector = torchrl.collectors.SyncDataCollector(
    env,
    policy,
    frames_per_batch=200,
    total_frames=-1
)

In [6]:
for data in collector:
    print(data)
    break

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([200, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([200]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([200]),
            device=None,
            is_shared=False),
        done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([200, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor

Since we are guaranteed to get a batch of size 200, and a single trajectory may be shorter than that, one batch can contain frames from several trajectories. The collected data therefore labels each frame with the ID number of the trajectory it belongs to.

In [7]:
print(data['collector', 'traj_ids'])

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3,
        3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6,
        6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
        7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9,
        9, 9, 9, 9, 9, 9, 9, 9])
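
To make the batching behaviour concrete, here is a minimal pure-Python sketch of a collector loop. It is not TorchRL's implementation: `toy_collector`, its random termination rule, and the plain-dict batch layout are all assumptions made for illustration. It only shows how fixed-size batches and per-frame trajectory IDs can coexist when the environment is reset on termination but not between batches.

```python
import random
from collections import Counter
from typing import Dict, Iterator, List


def toy_collector(
    frames_per_batch: int, total_frames: int, seed: int = 0
) -> Iterator[Dict[str, List[int]]]:
    """Yield fixed-size batches from a toy environment (hypothetical helper)."""
    rng = random.Random(seed)
    traj_id = 0   # ID of the trajectory currently being rolled out
    frames = 0    # total number of frames produced so far
    batch: Dict[str, List[int]] = {"traj_ids": [], "done": []}
    while total_frames == -1 or frames < total_frames:
        # Toy termination rule: each step ends the episode with probability 0.1.
        done = rng.random() < 0.1
        batch["traj_ids"].append(traj_id)
        batch["done"].append(int(done))
        frames += 1
        if done:
            # "Reset" on termination: the next frame starts a new trajectory.
            traj_id += 1
        if len(batch["traj_ids"]) == frames_per_batch:
            yield batch
            # No reset here, so a trajectory can continue into the next batch.
            batch = {"traj_ids": [], "done": []}


collector = toy_collector(frames_per_batch=8, total_frames=-1)
first = next(collector)
second = next(collector)
# Number of frames each trajectory contributes to the first batch.
frames_per_traj = Counter(first["traj_ids"])
```

In this sketch, the second batch starts with the same trajectory ID the first batch ended on unless the last frame of the first batch was terminal.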


## Replay buffers

After collecting data, we need to store it. Typically, the data is only stored temporarily and then cleared according to some heuristic. TorchRL provides replay buffers for this task. The following aspects of a replay buffer can be specified:

* The storage type.
* The sampling technique.
* The writing heuristic.
* Transforms to apply to the data.

A generic replay buffer just needs to know the storage type. There are several storage types in TorchRL, and they usually need to be given the maximum size of the data that they will hold. We'll use `LazyMemmapStorage`, a storage type with two useful features:

* It is initialized lazily, so you don't need to tell it ahead of time what kind of data it will hold.
* It stores tensors as `MemoryMappedTensor`s, so the data added to it is memory mapped and can be saved to disk efficiently.


In [17]:
buffer = torchrl.data.replay_buffers.ReplayBuffer(
    storage=torchrl.data.replay_buffers.LazyMemmapStorage(max_size=1000)
)

We can add data to the buffer with either `add()`, which adds a single element, or `extend()`, which adds multiple elements. Since the data we collected earlier has 200 elements, we'll use `extend()`.

In [18]:
indices = buffer.extend(data)

In [24]:
print(len(data))
print(len(buffer))
assert len(buffer) == collector.frames_per_batch

200
200
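
The relationship between `add()`, `extend()`, the returned indices, and the maximum size can be sketched in plain Python. This is a toy model only, assuming round-robin writing (the default writer printed later in these notes is a `RoundRobinWriter`); the class `ToyRoundRobinBuffer` and its attributes are hypothetical and do not mirror TorchRL's internals.

```python
from typing import Any, Iterable, List


class ToyRoundRobinBuffer:
    """Fixed-capacity buffer that overwrites the oldest slots first (toy model)."""

    def __init__(self, max_size: int) -> None:
        self.max_size = max_size
        self._storage: List[Any] = [None] * max_size
        self._cursor = 0  # index of the next slot to write
        self._len = 0     # number of slots that currently hold data

    def add(self, item: Any) -> int:
        """Write one element and return the index it was written to."""
        index = self._cursor
        self._storage[index] = item
        self._cursor = (self._cursor + 1) % self.max_size
        self._len = min(self._len + 1, self.max_size)
        return index

    def extend(self, items: Iterable[Any]) -> List[int]:
        """Write several elements and return their indices."""
        return [self.add(item) for item in items]

    def __len__(self) -> int:
        return self._len


buf = ToyRoundRobinBuffer(max_size=5)
first = buf.extend(range(3))   # fills slots 0, 1, 2
second = buf.extend(range(4))  # fills slots 3, 4, then wraps around to 0, 1
```

Once the buffer is full, its length stays at `max_size` and new writes replace the oldest entries.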


To take samples from the replay buffer, we can use the `sample()` method. Since we didn't specify a sampling technique, sampling is purely random, so uniqueness is not guaranteed: the same element may appear more than once in a sample. Let's take a sample of 30 elements.

In [21]:
sample = buffer.sample(batch_size=30)
print(sample)
print(sample['collector', 'traj_ids'])

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([30, 2]), device=cpu, dtype=torch.int64, is_shared=False),
        collector: TensorDict(
            fields={
                traj_ids: Tensor(shape=torch.Size([30]), device=cpu, dtype=torch.int64, is_shared=False)},
            batch_size=torch.Size([30]),
            device=cpu,
            is_shared=False),
        done: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([30, 4]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([30, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=to

The structure of the tensordict returned by `sample()` is the same as that returned by the collector, except that we have a different batch size, and the order is random.
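
Why uniqueness is not guaranteed can be shown with a small sketch, assuming indices are drawn independently and uniformly (the diff later in these notes shows `torch.randint` being used to generate random indices). The variable names and sizes below are illustrative only.

```python
import random

rng = random.Random(0)
buffer_len = 10  # pretend the buffer holds 10 elements
batch_size = 30  # ask for more samples than there are elements

# Each index is drawn independently, i.e. sampling with replacement.
indices = [rng.randrange(buffer_len) for _ in range(batch_size)]
num_unique = len(set(indices))
```

Because 30 indices are drawn from only 10 slots, repeats are unavoidable here. With 200 stored elements and a batch of 30, repeats are merely possible.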

## Compiling replay buffers

At the moment, some of the samplers have an argument that enables the `sample()` method to be compiled. In particular, the `*SliceSampler` classes have this feature. However, we don't just want to compile the sampler; we would also like to compile the `extend()` method on replay buffers.

The first thing to do is to support compiling `extend()` for `ReplayBuffer(storage=LazyTensorStorage(1000))`.

We should probably use the same interface for compiling that `SliceSampler` offers, which is an arg called `compile` that accepts either a `bool` or a dict of compiler args. If it is set to `True` or given a dict, `SliceSampler.__init__` runs `self._get_index = torch.compile(self._get_index, **kwargs)` to compile the method that does the sampling work.

So we should add the same arg to `ReplayBuffer` and compile the method that does the work of extending the buffer.
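
A rough sketch of what that could look like follows. This is not the actual `ReplayBuffer` code: `maybe_compile`, `ToyBuffer`, and the `compile_fn` parameter are hypothetical names introduced for illustration, and `compile_fn` exists only so the plumbing can be exercised without invoking `torch.compile`.

```python
from typing import Any, Callable, Dict, List, Optional, Union


def maybe_compile(
    fn: Callable,
    compile: Union[bool, Dict[str, Any]] = False,
    compile_fn: Optional[Callable] = None,
) -> Callable:
    """Hypothetical helper: wrap `fn` with a compiler if `compile` is truthy."""
    if not compile:
        # False (or an empty dict, in this sketch) leaves the function untouched.
        return fn
    # A dict is forwarded as keyword arguments; True means "use the defaults".
    kwargs = compile if isinstance(compile, dict) else {}
    if compile_fn is None:
        import torch  # deferred so the sketch can be imported without torch

        compile_fn = torch.compile
    return compile_fn(fn, **kwargs)


class ToyBuffer:
    """Stand-in for a replay buffer; only the `compile` plumbing matters here."""

    def __init__(self, compile=False, compile_fn=None) -> None:
        self._items: List[Any] = []
        # Same pattern as the sampler: rebind the method that does the work.
        self._extend = maybe_compile(self._extend, compile, compile_fn)

    def _extend(self, items) -> List[int]:
        start = len(self._items)
        self._items.extend(items)
        return list(range(start, len(self._items)))

    def extend(self, items) -> List[int]:
        return self._extend(items)
```

Keeping the public `extend()` as a thin wrapper around a private method means only the inner work is handed to the compiler, which mirrors how the sampler compiles `_get_index` rather than `sample()`.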

In [25]:
buffer = torchrl.data.ReplayBuffer(
    storage=torchrl.data.LazyTensorStorage(1000)
)

In [27]:
print(type(buffer._sampler))
print(type(buffer._writer))

<class 'torchrl.data.replay_buffers.samplers.RandomSampler'>
<class 'torchrl.data.replay_buffers.writers.RoundRobinWriter'>


### Recompiles during test

@vmoens, you mentioned offline that you've seen recompiles every time you tried to call `extend` on a `ReplayBuffer(storage=LazyTensorStorage(1000))`. Could you share with me a case where that happens? The test I added in #2504 only has recompiles on the first two calls.

As for the recompiles that I have seen, if I set `num_extend_before_capture = 0` in the test, it no longer ignores those recompiles, and I get these recompile records:

```
$ python test/test_rb.py -k test_extend_recompile[100-ReplayBuffer-LazyTensorStorage-tensor-RoundRobinWriter-RandomSampler]
...
[11/1] [__recompiles] Recompiling function _lazy_call_fn in /home/endoplasm/develop/torchrl-0/torchrl/_utils.py:389
[11/1] [__recompiles]     triggered by the following guard failure(s):
[11/1] [__recompiles]     - 11/0: L['self'].func_name == 'torchrl.data.replay_buffers.storages.TensorStorage.set'
[12/1] [__recompiles] Recompiling function torch_dynamo_resume_in__lazy_call_fn_at_394 in /home/endoplasm/develop/torchrl-0/torchrl/_utils.py:394
[12/1] [__recompiles]     triggered by the following guard failure(s):
[12/1] [__recompiles]     - 12/0: len(L['args']) == 3                                         
[18/1] [__recompiles] Recompiling function torch_dynamo_resume_in_set_at_713 in /home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py:713
[18/1] [__recompiles]     triggered by the following guard failure(s):
[18/1] [__recompiles]     - 18/0: ___check_obj_id(L['self'].initialized, 8907584) 
```

So there are three recompiles to look into.

The first and second recompiles are caused by the use of the `implement_for` decorator [here](https://github.com/pytorch/rl/blob/baba52b9a13d5416ff622e486fee9b3f05f51f2f/torchrl/data/replay_buffers/storages.py#L686). This is the first time the `_lazy_call_fn` function within the `implement_for.__call__` method is called with the string `self.func_name = "torchrl.data.replay_buffers.storages.TensorStorage.set"`, and `torch.compile` has to recompile every time a function is given a different string argument. I'm not sure there is much we can or should do about this one. It might be possible to change how `implement_for` works to avoid this recompile, but it's probably not worth the trouble since these only happen once.

The third recompile is caused by the fact that [this branch of `TensorStorage.set()`](https://github.com/pytorch/rl/blob/baba52b9a13d5416ff622e486fee9b3f05f51f2f/torchrl/data/replay_buffers/storages.py#L715) is only visited in the first call. I'd guess that this one isn't worth the trouble to fix either.
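
The reason these recompiles only show up on the first calls can be illustrated with a toy model of guard-based specialization. This is emphatically not how Dynamo is implemented: `ToyGuardedFunction` is a hypothetical class that "compiles" once per distinct combination of guarded Python values (a string and a flag, standing in for `func_name` and `initialized` in the log above) and reuses the cached result afterwards.

```python
from typing import Callable, Dict, Tuple


class ToyGuardedFunction:
    """Toy model of guard-based specialization; not Dynamo's real machinery."""

    def __init__(self, fn: Callable[[str, bool], str]) -> None:
        self.fn = fn
        self.compile_count = 0
        self._cache: Dict[Tuple[str, bool], Callable[[str, bool], str]] = {}

    def __call__(self, func_name: str, initialized: bool) -> str:
        # The "guards": cached code is only valid for these exact values.
        guard_key = (func_name, initialized)
        if guard_key not in self._cache:
            # Guard failure: specialize again for the new values.
            self.compile_count += 1
            self._cache[guard_key] = self.fn
        return self._cache[guard_key](func_name, initialized)


def set_item(func_name: str, initialized: bool) -> str:
    # The first call takes the initialization branch; later calls skip it.
    return "write" if initialized else "init+write"


guarded = ToyGuardedFunction(set_item)
name = "TensorStorage.set"
results = [guarded(name, initialized=(i > 0)) for i in range(5)]
```

In this model, five calls trigger two specializations: one for the uninitialized first call and one for every call after it, after which nothing changes and the cache is always hit.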


I tried changing `TensorStorage._len_value` to be just an int, rather than `mp.Value`, along with the corresponding `_cursor_value` and `_write_count_value` in `RoundRobinWriter`, and that did not avoid the recompiles.

```
diff --git a/torchrl/data/replay_buffers/storages.py b/torchrl/data/replay_buffers/storages.py
index 217229b5..145bcdd5 100644
--- a/torchrl/data/replay_buffers/storages.py
+++ b/torchrl/data/replay_buffers/storages.py
@@ -149,6 +149,7 @@ class Storage:
         if self.ndim == 1:
             return torch.randint(
                 0,
+                #len(self),
                 self._len,
                 (batch_size,),
                 generator=self._rng,
@@ -458,15 +459,12 @@ class TensorStorage(Storage):
     def _len(self):
         _len_value = self.__dict__.get("_len_value", None)
         if _len_value is None:
-            _len_value = self._len_value = mp.Value("i", 0)
-        return _len_value.value
+            _len_value = self._len_value = 0
+        return _len_value
 
     @_len.setter
     def _len(self, value):
-        _len_value = self.__dict__.get("_len_value", None)
-        if _len_value is None:
-            _len_value = self._len_value = mp.Value("i", 0)
-        _len_value.value = value
+        self._len_value = value
 
     @property
     def _total_shape(self):
@@ -606,7 +604,7 @@ class TensorStorage(Storage):
     def __setstate__(self, state):
         len = state.pop("len__context", None)
         if len is not None:
-            _len_value = mp.Value("i", len)
+            _len_value = len
             state["_len_value"] = _len_value
         self.__dict__.update(state)
```

```
diff --git a/torchrl/data/replay_buffers/writers.py b/torchrl/data/replay_buffers/writers.py
index 3a95c397..115e4ff0 100644
--- a/torchrl/data/replay_buffers/writers.py
+++ b/torchrl/data/replay_buffers/writers.py
@@ -214,29 +214,23 @@ class RoundRobinWriter(Writer):
     def _cursor(self):
         _cursor_value = self.__dict__.get("_cursor_value", None)
         if _cursor_value is None:
-            _cursor_value = self._cursor_value = mp.Value("i", 0)
-        return _cursor_value.value
+            _cursor_value = self._cursor_value = 0
+        return _cursor_value
 
     @_cursor.setter
     def _cursor(self, value):
-        _cursor_value = self.__dict__.get("_cursor_value", None)
-        if _cursor_value is None:
-            _cursor_value = self._cursor_value = mp.Value("i", 0)
-        _cursor_value.value = value
+        self._cursor_value = value
 
     @property
     def _write_count(self):
         _write_count = self.__dict__.get("_write_count_value", None)
         if _write_count is None:
-            _write_count = self._write_count_value = mp.Value("i", 0)
-        return _write_count.value
+            _write_count = self._write_count_value = 0
+        return _write_count
 
     @_write_count.setter
     def _write_count(self, value):
-        _write_count = self.__dict__.get("_write_count_value", None)
-        if _write_count is None:
-            _write_count = self._write_count_value = mp.Value("i", 0)
-        _write_count.value = value
+        self._write_count_value = 0
 
     def __getstate__(self):
         state = super().__getstate__()
@@ -249,7 +243,7 @@ class RoundRobinWriter(Writer):
     def __setstate__(self, state):
         cursor = state.pop("cursor__context", None)
         if cursor is not None:
-            _cursor_value = mp.Value("i", cursor)
+            _cursor_value = cursor
             state["_cursor_value"] = _cursor_value
         self.__dict__.update(state)
```

```
I1024 12:22:47.898000 677981 site-packages/torch/_utils_internal.py:116] [27/0_1] CompilationMetrics(compile_id='27/0', frame_key='65', co_name='_rand_given_ndim', co_filename='/home/endoplasm/develop/torchrl-0/torchrl/data/replay_buffers/storages.py', co_firstlineno=164, cache_size=0, accumulated_cache_size=0, guard_count=10, shape_env_guard_count=0, graph_op_count=0, graph_node_count=0, graph_input_count=0, start_time=1729797767.891117, entire_frame_compile_time_s=0.0068094730377197266, backend_compile_time_s=None, inductor_compile_time_s=None, code_gen_time_s=None, fail_type=None, fail_reason=None, fail_user_frame_filename=None, fail_user_frame_lineno=None, non_compliant_ops=set(), compliant_custom_ops=set(), restart_reasons={'Graph break due to unsupported builtin None.SemLock.acquire. This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind). If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround. If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use torch.compiler.allow_in_graph.'}, dynamo_time_before_restart_s=0.0027790069580078125, has_guarded_code=True, possibly_missed_reinplacing_opportunities=0, remote_cache_time_saved_s=0, structured_logging_overhead_s=None, config_suppress_errors=False, config_inline_inbuilt_nn_modules=True, specialize_float=True)
```