In [1]:
import torch
from torchrl.data import ReplayBuffer, LazyTensorStorage

# Minimal replay-buffer demo: store three (S, A, R, S') transitions and draw a
# batch of two from them.
buf = ReplayBuffer(
    storage=LazyTensorStorage(
        max_size=5,
        # Prefer the GPU, but fall back to the CPU so the cell also runs on
        # machines without CUDA instead of raising at storage creation.
        device='cuda' if torch.cuda.is_available() else 'cpu',
    ),
    batch_size=2,  # Draw 2 transitions per sample() call.
)

# Each tuple is ONE transition (S, A, R, S'), not a full trajectory. The
# variable names are kept as-is in case other cells refer to them.
trajectory1 = (torch.as_tensor(1), torch.as_tensor(2), torch.as_tensor(3), torch.as_tensor(4))
trajectory2 = (torch.as_tensor(5), torch.as_tensor(6), torch.as_tensor(7), torch.as_tensor(8))
trajectory3 = (torch.as_tensor(9), torch.as_tensor(10), torch.as_tensor(11), torch.as_tensor(12))

for transition in (trajectory1, trajectory2, trajectory3):
    buf.add(transition)

# Last expression of the cell, so the notebook displays it: a tuple of four
# batched tensors (state batch, action batch, reward batch, next-state batch).
# Sampling is random, so the values shown differ between runs.
buf.sample()


(tensor([5, 5], device='cuda:0'),
 tensor([6, 6], device='cuda:0'),
 tensor([7, 7], device='cuda:0'),
 tensor([8, 8], device='cuda:0'))

In [2]:
from tensordict import TensorDict
from torchrl.data import SliceSampler, TensorDictReplayBuffer, LazyMemmapStorage

# SliceSampler demo: store 10 steps belonging to 4 episodes, then sample
# contiguous slices so that every slice stays inside a single episode.
size = 20

# Episode id of each of the 10 stored steps: 3 steps of episode 1, 2 of
# episode 2, 2 of episode 3 and 3 of episode 4.
episode = torch.repeat_interleave(
    torch.tensor([1, 2, 3, 4], dtype=torch.int),
    torch.tensor([3, 2, 2, 3]),
)
# Step counter that restarts from 0 at the beginning of every episode.
steps = torch.cat([torch.arange(length) for length in (3, 2, 2, 3)])

# The random payloads are drawn once and broadcast along the leading (batch)
# dimension of size 10.
data = TensorDict(
    {
        "episode": episode,
        "obs": torch.randn(3, 4, 5).expand(10, -1, -1, -1),
        "act": torch.randn(20).expand(10, -1),
        "other": torch.randn(20, 50).expand(10, -1, -1),
        "steps": steps,
    },
    batch_size=[10],
)

# Every sample() call returns 8 items made of 4 slices, grouped by "episode".
rb = TensorDictReplayBuffer(
    storage=LazyMemmapStorage(size),
    sampler=SliceSampler(traj_key="episode", num_slices=4),
    batch_size=8,
)
rb.extend(data)

sample = rb.sample()
print("episode are grouped", sample["episode"])
print("steps are successive", sample["steps"])

episode are grouped tensor([1, 1, 4, 4, 2, 2, 2, 2], dtype=torch.int32)
steps are successive tensor([0, 1, 1, 2, 0, 1, 0, 1])


In [8]:
# Nested TensorDict demo: one tensor at the root and one inside the "next"
# sub-tensordict, both sharing the leading batch dimension of size 3.
data = TensorDict(
    {
        "observation": torch.ones((3, 4)),
        "next": {"observation": torch.ones((3, 4))},
    },
    batch_size=[3],
)

# Reach the nested entry by indexing the sub-tensordict first, then the leaf.
a = data["next"]["observation"]
print(a)

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
