In [1]:
import torch

# Toy fixture: 7 samples, one scalar feature each. Rows 0-3 and rows 4-6 each
# start with a reset flag (see `reset`).
data_orig = torch.tensor([1, 2, 3, 4, 2, 3, 4]).unsqueeze(-1)

# Axis (counted from the end) used by the later cells to locate the feature dims.
dim = -1

# One boolean flag per sample; True on rows 0 and 4.
reset = torch.tensor(
    [True, False, False, False, True, False, False]
).unsqueeze(-1)

# Stacked-frame view of `data_orig` with a window of 6 and zero left-padding,
# shape (7, 6, 1). The window of the second group still contains values from
# the first group.
data = torch.tensor([
    [0, 0, 0, 0, 0, 1],
    [0, 0, 0, 0, 1, 2],
    [0, 0, 0, 1, 2, 3],
    [0, 0, 1, 2, 3, 4],

    [0, 1, 2, 3, 4, 2],
    [1, 2, 3, 4, 2, 3],
    [2, 3, 4, 2, 3, 4],
]).unsqueeze(-1)

# Boolean mask of shape (7, 6); 1 marks the leading window slots that the later
# cells overwrite with the reset value.
done_mask_expand = torch.tensor([
    [1, 1, 1, 1, 1, 0],
    [1, 1, 1, 1, 0, 0],
    [1, 1, 1, 0, 0, 0],
    [1, 1, 0, 0, 0, 0],

    [1, 1, 1, 1, 1, 0],
    [1, 1, 1, 1, 0, 0],
    [1, 1, 1, 0, 0, 0],
], dtype=torch.bool)

# Target result: every masked slot of `data` replaced by the first value of the
# group the row belongs to (1 for rows 0-3, 2 for rows 4-6).
expected_res = torch.tensor([
    [1, 1, 1, 1, 1, 1],
    [1, 1, 1, 1, 1, 2],
    [1, 1, 1, 1, 2, 3],
    [1, 1, 1, 2, 3, 4],
    [2, 2, 2, 2, 2, 2],
    [2, 2, 2, 2, 2, 3],
    [2, 2, 2, 2, 3, 4],
]).unsqueeze(-1)


In [2]:
# First axis of the trailing feature dims of `data`, as a non-negative index.
feat_axis = data.ndim + dim
# Axis immediately before the feature dims; the mask is flattened from here on.
d = feat_axis - 1
# Number of scalar features per frame (product of the trailing feature dims).
n_feat = data.shape[feat_axis:].numel()

# Count the masked slots of every sample, then convert the count from
# "masked elements" to "masked frames" by dividing by the feature size.
masked_per_sample = done_mask_expand.flatten(d, -1).sum(-1)
num_repeats_per_sample = masked_per_sample.view(-1) // n_feat
num_repeats_per_sample


tensor([5, 4, 3, 2, 5, 4, 3])

In [8]:
# Collapse the last axis of `reset`: a sample counts as reset if any flag is set.
reset_any = reset.any(dim=-1, keepdim=False)
# Pick the rows of `data_orig` at the reset positions and split them into a
# Python list with one tensor per reset.
reset_rows = data_orig[reset_any]
reset_vals = list(torch.unbind(reset_rows, dim=0))
reset_vals

[tensor([1]), tensor([2])]

In [12]:

# Split the rows of `data` into contiguous [start, end) ranges, one per group of
# samples that share a reset value.
#
# Within a group the number of masked frames per row never increases, so a row
# whose count is strictly larger than the previous row's count opens a new group.
# NOTE: two consecutive rows with the same count are not treated as a boundary
# (strict `>`), so back-to-back resets are not split by this heuristic.
repeat_counts = (
    done_mask_expand.flatten(d, -1).sum(-1).view(-1) // n_feat
).tolist()

ranges = []
range_start = 0
prev_count = float('inf')  # sentinel: the first row can never open a new range

for row_idx, count in enumerate(repeat_counts):
    if count > prev_count:
        # The boundary is the *row index* at which the count jumps. Deriving it
        # from the count itself (count - 1) is only right when the previous
        # group happens to have exactly `count - 1` rows.
        ranges.append([range_start, row_idx])
        range_start = row_idx
    prev_count = count

# Close the last range at the end of the batch.
ranges.append([range_start, data.size(0)])

ranges

[[0, 4], [4, 7]]

In [13]:
# Overwrite the masked leading frames of every sample with the reset value of
# the range it belongs to. Work on a clone so `data` itself is left untouched.
data_copy = data.clone()
for range_num, bounds in enumerate(ranges):
    start, end = bounds
    # Basic slicing returns a view, so writes below land in `data_copy`.
    data_slice = data_copy[start:end]
    repeats_in_range = num_repeats_per_sample[start:end]
    print(data_slice.squeeze(-1))
    slices = []  # never read in this cell
    print(repeats_in_range)

    # Value written into the masked frames of every row of this range.
    fill_value = reset_vals[range_num]
    for sample_idx, num_repeats in enumerate(repeats_in_range):
        if not num_repeats > 0:
            # Nothing is masked for this row.
            continue
        print(f"{sample_idx} {num_repeats} {range_num}")
        data_slice[sample_idx, :num_repeats] = fill_value

# Show the filled tensor and whether it matches the hand-written expectation.
print(data_copy.squeeze(-1))
print((data_copy == expected_res).all())

tensor([[0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1, 2],
        [0, 0, 0, 1, 2, 3],
        [0, 0, 1, 2, 3, 4]])
tensor([5, 4, 3, 2])
0 5 0
1 4 0
2 3 0
3 2 0
tensor([[0, 1, 2, 3, 4, 2],
        [1, 2, 3, 4, 2, 3],
        [2, 3, 4, 2, 3, 4]])
tensor([5, 4, 3])
0 5 1
1 4 1
2 3 1
tensor([[1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 2],
        [1, 1, 1, 1, 2, 3],
        [1, 1, 1, 2, 3, 4],
        [2, 2, 2, 2, 2, 2],
        [2, 2, 2, 2, 2, 3],
        [2, 2, 2, 2, 3, 4]])
tensor(True)


## Repro from user

A user reported a `CatFrames` sampling-speed issue on Discord: <https://discord.com/channels/1171857748607115354/1268244652180377732/1270523213600002129>

Here's the repro:

In [2]:
import torch
from torchrl.envs import CatFrames, Compose
from torchrl.data import PrioritizedSliceSampler, ReplayBuffer, LazyTensorStorage

def catframes_speed_test(with_catframes=True):
    """Time a single ``sample`` call on a prioritized slice replay buffer.

    Args:
        with_catframes: if true, observations are stored with a trailing
            dimension of 1 and two ``CatFrames(N=4, dim=-1)`` transforms (one
            for ``"observation"``, one for ``("next", "observation")``) are
            attached to the buffer. If false, observations are stored with a
            trailing dimension of 4 and the buffer has no transform.

    Prints the wall-clock duration of the ``sample`` call, then empties the
    buffer. Returns nothing.
    """
    import time
    from tensordict import TensorDict

    if not with_catframes:
        obs_shape = (16, 84, 84, 4)
        transform = None
    else:
        obs_shape = (16, 84, 84, 1)
        transform = Compose(
            CatFrames(N=4, dim=-1, in_keys=["observation"], done_key="done"),
            CatFrames(N=4, dim=-1, in_keys=[("next", "observation")], done_key="done"),
        )

    slice_sampler = PrioritizedSliceSampler(
        max_capacity=250_000,
        alpha=0.5,
        beta=0.4,
        strict_length=False,
        num_slices=32 // 4,
        span=[True, False],
    )
    storage = LazyTensorStorage(max_size=250_000, device="cpu")
    exp_buffer = ReplayBuffer(
        storage=storage,
        sampler=slice_sampler,
        batch_size=32,
        transform=transform,
    )

    # All-zero batch: the leading dimension of `obs_shape` is the batch size.
    n_steps = obs_shape[0]
    next_td = {
        "observation": torch.zeros(obs_shape, dtype=torch.float32),
        "done": torch.zeros((n_steps, 1), dtype=torch.bool),
    }
    fake_data = TensorDict(
        {
            "observation": torch.zeros(obs_shape, dtype=torch.float32),
            "next": next_td,
        },
        batch_size=[n_steps],
    )

    # Write the same batch 25 times before sampling.
    n_extends = 25
    for _ in range(n_extends):
        exp_buffer.extend(fake_data)

    # Only the sample call itself is timed.
    t1 = time.perf_counter()
    data, info = exp_buffer.sample(return_info=True)
    t2 = time.perf_counter()
    print(f"Sampling took {t2 - t1} seconds.")

    exp_buffer.empty()


if __name__ == "__main__":
    # Run the baseline first, then the CatFrames variant, each under its banner.
    for banner, use_catframes in (
        ("WITHOUT CATFRAMES", False),
        ("WITH CATFRAMES", True),
    ):
        print(banner)
        catframes_speed_test(with_catframes=use_catframes)

WITHOUT CATFRAMES




: 