In [1]:
from tensordict import TensorDict
import torch
# Build a TensorDict with batch size (3, 4): both entries share those two
# leading dimensions and carry one extra trailing dimension of size 5.
data = TensorDict(
    source={
        "key 1": torch.ones((3, 4, 5)),
        "key 2": torch.zeros((3, 4, 5), dtype=torch.bool),
    },
    batch_size=[3, 4],
)

In [4]:
# Look up a single entry by its string key.
data["key 2"]

tensor([[[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]],

        [[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]],

        [[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]]])

In [5]:
# A tuple key addresses a nested entry: this stores "key" inside the "nested"
# sub-TensorDict of `data`.
data["nested", "key"] = torch.zeros(3, 4)  # the batch-size must match

In [10]:
# Display the TensorDict, including the "nested" sub-TensorDict set earlier.
data

TensorDict(
    fields={
        key 1: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        key 2: Tensor(shape=torch.Size([3, 4, 5]), device=cpu, dtype=torch.bool, is_shared=False),
        nested: TensorDict(
            fields={
                key: Tensor(shape=torch.Size([3, 4]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([3, 4]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([3, 4]),
    device=None,
    is_shared=False)

In [29]:
from tensordict.prototype import tensorclass
import torch

@tensorclass
class MyData:
    """Tensor container bundling an image, a mask and a label.

    Instances are created with a ``batch_size`` and can be indexed and
    reshaped as a whole, which the methods below and the example usage
    rely on.
    """

    image: torch.Tensor
    mask: torch.Tensor
    label: torch.Tensor

    def mask_image(self):
        """Return the image entries picked by the mask, one row per batch element.

        The mask is expanded to the image's shape, used as a boolean index
        into the image, and the picked values are viewed as
        ``(*batch_size, -1)``.
        """
        full_mask = self.mask.expand_as(self.image)
        picked = self.image[full_mask]
        return picked.view(*self.batch_size, -1)

    def select_label(self, label):
        """Return the part of the batch whose ``label`` field equals ``label``."""
        matches = self.label == label
        return self[matches]

# Random inputs for 100 samples: 3x64x64 images and integer labels in [0, 10).
images = torch.randn((100, 3, 64, 64))
label = torch.randint(0, 10, (100,))
# A single random boolean 64x64 mask, expanded to all 100 samples.
mask = (
    torch.zeros((1, 64, 64), dtype=torch.bool)
    .bernoulli_()
    .expand(100, 1, 64, 64)
)

data = MyData(image=images, mask=mask, label=label, batch_size=[100])

# First the samples labelled 1, then the masked-image shape for the flat
# batch and for the batch reshaped to (10, 10).
print(data.select_label(1))
print(data.mask_image().shape)
print(data.reshape(10, 10).mask_image().shape)

MyData(
    image=Tensor(shape=torch.Size([6, 3, 64, 64]), device=cpu, dtype=torch.float32, is_shared=False),
    label=Tensor(shape=torch.Size([6]), device=cpu, dtype=torch.int64, is_shared=False),
    mask=Tensor(shape=torch.Size([6, 1, 64, 64]), device=cpu, dtype=torch.bool, is_shared=False),
    batch_size=torch.Size([6]),
    device=None,
    is_shared=False)
torch.Size([100, 6048])
torch.Size([10, 10, 6048])


In [27]:
# Shape of the image field for the whole batch (no label selection applied).
data.image.shape

torch.Size([100, 3, 64, 64])

In [30]:
from torchrl.envs import GymEnv

# Instantiate TorchRL's GymEnv for the "Pendulum-v1" environment.
env = GymEnv("Pendulum-v1")

  logger.warn(
  logger.warn(


In [31]:
# Reset the environment and display the returned TensorDict, which holds the
# initial observation together with the done/terminated/truncated flags.
reset = env.reset()
reset

TensorDict(
    fields={
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [34]:
# The initial observation returned by the reset.
reset["observation"]

tensor([-0.7224,  0.6915,  0.2165])

In [35]:
# Add a randomly sampled "action" entry to the reset data.
# NOTE(review): rand_action presumably writes into `reset` itself rather than
# into a copy — confirm before relying on `reset` staying unchanged.
reset_with_action = env.rand_action(reset)
print(reset_with_action)



TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
        terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
        truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)


In [46]:
# Step the environment with a fresh TensorDict that carries only the action;
# the step results are returned under the "next" entry.
stepped_data = env.step(
    TensorDict(
        {"action": reset_with_action["action"]},
        batch_size=[],
    )
)
print(stepped_data)

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([3]), device=cpu, dtype=torch.float32, is_shared=False),
                reward: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False),
                truncated: Tensor(shape=torch.Size([1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)


In [None]:
from torchrl.envs.transforms import StepCounter