In [1]:
!pip install functorch
!pip install "gym[classic_control]"
!pip install torchrl

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting functorch
  Downloading functorch-0.2.1-cp37-cp37m-manylinux1_x86_64.whl (20.6 MB)
[K     |████████████████████████████████| 20.6 MB 1.2 MB/s 
Installing collected packages: functorch
Successfully installed functorch-0.2.1
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pygame==2.1.0
  Downloading pygame-2.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.3 MB)
[K     |████████████████████████████████| 18.3 MB 76 kB/s 
Installing collected packages: pygame
Successfully installed pygame-2.1.0
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchrl
  Downloading torchrl-0.0.2a0-cp37-cp37m-manylinux1_x86_64.whl (3.3 MB)
[K     |████████████████████████████████| 3.3 MB 5.2 MB/s 
Installing collected packages: torchrl
Successfully installed t

## Data
### TensorDict

In [2]:
import torch
from torchrl.data import TensorDict

In [3]:
# Creating a TensorDict: every entry shares the leading batch dimension,
# which is declared explicitly through `batch_size`.
batch_size = 5
tensordict = TensorDict(
    source={
        "key 1": torch.zeros(batch_size, 3),
        "key 2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)
print(tensordict)

TensorDict(
    fields={
        key 1: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        key 2: Tensor(torch.Size([5, 5, 6]), dtype=torch.bool)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)


In [4]:
# Indexing acts along the batch dimension: every entry is indexed at position 2,
# so the batch dimension disappears from each tensor and from batch_size.
tensordict[2]

TensorDict(
    fields={
        key 1: Tensor(torch.Size([3]), dtype=torch.float32),
        key 2: Tensor(torch.Size([5, 6]), dtype=torch.bool)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [5]:
# Querying keys: item access with a string key and .get() hand back the very
# same tensor object (the identity check below evaluates to True).
tensordict["key 1"] is tensordict.get("key 1")

True

In [6]:
# Stacking tensordicts: two tensordicts with the same keys and batch size are
# stacked along a new leading batch dimension.
tensordict1 = TensorDict(
    source={
        "key 1": torch.zeros(batch_size, 1),
        "key 2": torch.zeros(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)
tensordict2 = TensorDict(
    source={
        "key 1": torch.ones(batch_size, 1),
        "key 2": torch.ones(batch_size, 5, 6, dtype=torch.bool),
    },
    batch_size=[batch_size],
)

tensordict = torch.stack([tensordict1, tensordict2], 0)
tensordict.batch_size, tensordict["key 1"]

(torch.Size([2, 5]), tensor([[[0.],
          [0.],
          [0.],
          [0.],
          [0.]],
 
         [[1.],
          [1.],
          [1.],
          [1.],
          [1.]]]))

In [7]:
# Other functionalities: tensor-like operations are applied to the batch
# dimensions of the stacked tensordict and show up in every entry.

# view(-1) flattens the [2, 5] batch into a single dimension of size 10.
print("view(-1): ", tensordict.view(-1).batch_size, tensordict.view(-1).get("key 1").shape)

# Casting to a device; the printed result reports device=cpu.
print("to device: ", tensordict.to("cpu"))

# print("pin_memory: ", tensordict.pin_memory())

# After share_memory_() the printed tensordict reports is_shared=True.
print("share memory: ", tensordict.share_memory_())

# permute swaps the two batch dimensions; trailing feature dims are untouched.
print("permute(1, 0): ", 
      tensordict.permute(1, 0).batch_size, 
      tensordict.permute(1, 0).get("key 1").shape)

# expand adds a new leading batch dimension of size 3.
print("expand: ", 
      tensordict.expand(3, *tensordict.batch_size).batch_size, 
      tensordict.expand(3, *tensordict.batch_size).get("key 1").shape)

view(-1):  torch.Size([10]) torch.Size([10, 1])
to device:  TensorDict(
    fields={
        key 1: Tensor(torch.Size([2, 5, 1]), dtype=torch.float32),
        key 2: Tensor(torch.Size([2, 5, 5, 6]), dtype=torch.bool)},
    batch_size=torch.Size([2, 5]),
    device=cpu,
    is_shared=False)
share memory:  LazyStackedTensorDict(
    fields={
        key 1: Tensor(torch.Size([2, 5, 1]), dtype=torch.float32),
        key 2: Tensor(torch.Size([2, 5, 5, 6]), dtype=torch.bool)},
    batch_size=torch.Size([2, 5]),
    device=None,
    is_shared=True)
permute(1, 0):  torch.Size([5, 2]) torch.Size([5, 2, 1])
expand:  torch.Size([3, 2, 5]) torch.Size([3, 2, 5, 1])


#### Nested tensordict

In [8]:
# A TensorDict can hold another TensorDict as a value. The nested one declares
# its own batch size ([batch_size, 2]), whose first dimension matches the parent's.
tensordict = TensorDict(
    source={
        "key 1": torch.zeros(batch_size, 3),
        "key 2": TensorDict(
            source={"sub-key 1": torch.zeros(batch_size, 2, 1)},
            batch_size=[batch_size, 2],
        ),
    },
    batch_size=[batch_size],
)
tensordict

TensorDict(
    fields={
        key 1: Tensor(torch.Size([5, 3]), dtype=torch.float32),
        key 2: TensorDict(
            fields={
                sub-key 1: Tensor(torch.Size([5, 2, 1]), dtype=torch.float32)},
            batch_size=torch.Size([5, 2]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

### Replay buffers

In [9]:
from torchrl.data import ReplayBuffer, PrioritizedReplayBuffer

In [10]:
# A plain replay buffer (first argument is presumably the capacity — confirm
# against the torchrl docs). The identity collate_fn leaves the sampled items
# as a plain list, as the output shows.
rb = ReplayBuffer(100, collate_fn=lambda x: x)
rb.add(1)  # store a single item
rb.sample(1)

[1]

In [11]:
# extend() stores several items at once. The sample below contains a repeated
# item, so sampling is done with replacement.
rb.extend([2, 3])
rb.sample(3)

[3, 1, 1]

In [12]:
# Prioritized replay buffer: items are sampled according to their priorities.
rb = PrioritizedReplayBuffer(100, alpha=0.7, beta=1.1, collate_fn=lambda x: x)
rb.add(1)
rb.sample(1)
# Only one item has been stored, in slot 0, so that is the slot whose priority
# is updated. Giving a priority to an empty slot would let the sampler pick it.
rb.update_priority(0, 0.5)

#### Working with tensordicts

In [13]:

# torch.stack is used as collate_fn so that the sampled tensordicts are
# stacked into a single tensordict.
collate_fn = torch.stack
rb = ReplayBuffer(100, collate_fn=collate_fn)
# add() stores one tensordict (with an empty batch size) as one item.
rb.add(TensorDict({"a": torch.randn(3)}, batch_size=[]))
len(rb)

1

In [14]:
# extend() with a tensordict of batch size 2 stores two items: the length
# goes from 1 to 3.
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
len(rb)

3

In [15]:
# Sampling 10 items out of the 3 stored ones returns a lazily stacked
# tensordict with batch size 10.
rb.sample(10)

LazyStackedTensorDict(
    fields={
        a: Tensor(torch.Size([10, 3]), dtype=torch.float32)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)

In [16]:
# contiguous() converts the lazy stack into a regular TensorDict.
rb.sample(2).contiguous()

TensorDict(
    fields={
        a: Tensor(torch.Size([2, 3]), dtype=torch.float32)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

In [17]:
# Prioritized replay buffer specialised for tensordicts.
# priority_key presumably names the entry that priorities are read from when
# update_priority is called with a tensordict — confirm against the torchrl docs.
torch.manual_seed(0)
from torchrl.data import TensorDictPrioritizedReplayBuffer
rb = TensorDictPrioritizedReplayBuffer(100, alpha=0.7, beta=1.1, priority_key="td_error")
rb.extend(TensorDict({"a": torch.randn(2, 3)}, batch_size=[2]))
# The sample carries an extra "index" entry next to the stored data.
tensordict_sample = rb.sample(2).contiguous()
tensordict_sample

TensorDict(
    fields={
        a: Tensor(torch.Size([2, 3]), dtype=torch.float32),
        index: Tensor(torch.Size([2, 1]), dtype=torch.int32)},
    batch_size=torch.Size([2]),
    device=None,
    is_shared=False)

In [18]:
# Indices attached to the sampled items.
tensordict_sample["index"]

tensor([[1],
        [0]], dtype=torch.int32)

In [19]:
# Write new values under the priority key ("td_error") and hand the sample
# back to the buffer so that it can update the priorities.
tensordict_sample["td_error"] = torch.rand(2)
rb.update_priority(tensordict_sample)

In [20]:
# Inspect the values held in the buffer's (private) sum tree.
# NOTE(review): the break test runs after the print, so len(rb) + 1 entries are
# shown; the last one (index 2, value 0.0 in the output) is past the stored items.
for i, val in enumerate(rb._sum_tree):
    print(i, val)
    if i == len(rb):
        break

0 0.28791671991348267
1 0.06984968483448029
2 0.0


## Envs

In [21]:
from torchrl.envs.libs.gym import GymWrapper, GymEnv
import gym

# Two ways of building an environment: wrap an existing gym env ...
gym_env = gym.make("Pendulum-v1")
env = GymWrapper(gym_env)
# ... or build it by name. This rebinds `env`, so the wrapped instance above
# is no longer referenced under that name.
env = GymEnv("Pendulum-v1")

Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.


  "Initializing wrapper in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."
  "Initializing environment in old step API which returns one bool instead of two. It is recommended to set `new_step_api=True` to use new step API. This will be the default behaviour in future."


In [22]:
# Reset the environment and keep the returned tensordict for the next cell.
tensordict = env.reset()

In [23]:
# One step with a random action: the result holds the action, the reward, the
# done flag and the next observation next to the current observation.
env.rand_step(tensordict)

TensorDict(
    fields={
        action: Tensor(torch.Size([1]), dtype=torch.float32),
        done: Tensor(torch.Size([1]), dtype=torch.bool),
        next_observation: Tensor(torch.Size([3]), dtype=torch.float32),
        observation: Tensor(torch.Size([3]), dtype=torch.float32),
        reward: Tensor(torch.Size([1]), dtype=torch.float32)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

### Changing the environment configuration

In [24]:
# Options are passed when the env is constructed. With from_pixels=True and
# pixels_only=False the reset output has both a "pixels" and a "state" entry.
env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env.reset()

Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.


TensorDict(
    fields={
        done: Tensor(torch.Size([1]), dtype=torch.bool),
        pixels: Tensor(torch.Size([500, 500, 3]), dtype=torch.uint8),
        state: Tensor(torch.Size([3]), dtype=torch.float32)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [25]:
# Close the environment and drop the reference to it.
env.close()
del env

In [26]:
# Build a pixel-based env and wrap it with transforms.
# NOTE(review): the same four statements are repeated verbatim in the first
# cell of the "Transforms" section that follows.
from torchrl.envs import Compose, ObservationNorm, ToTensorImage, NoopResetEnv, TransformedEnv
base_env = GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False)
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))
env.append_transform(ObservationNorm(keys_in=["next_pixels"], loc=2, scale=1))

Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.


### Transforms

In [27]:
# Transforms are attached by wrapping the base env in a TransformedEnv;
# further transforms can be appended afterwards.
from torchrl.envs import (
    Compose,
    ObservationNorm,
    ToTensorImage,
    NoopResetEnv,
    TransformedEnv,
)

base_env = GymEnv(
    "Pendulum-v1",
    frame_skip=3,
    from_pixels=True,
    pixels_only=False,
)
env = TransformedEnv(
    base_env,
    Compose(NoopResetEnv(3), ToTensorImage()),
)
env.append_transform(
    ObservationNorm(keys_in=["next_pixels"], loc=2, scale=1)
)

Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.


In [28]:
# After the transforms "pixels" is a float32 tensor of shape [3, 500, 500]
# (it was uint8 with shape [500, 500, 3] for the untransformed env).
env.reset()

TensorDict(
    fields={
        done: Tensor(torch.Size([1]), dtype=torch.bool),
        pixels: Tensor(torch.Size([3, 500, 500]), dtype=torch.float32),
        state: Tensor(torch.Size([3]), dtype=torch.float32)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)

In [29]:
# Indexing the Compose gives the transform at that position. Its parent, as
# printed, is the env with only the preceding transforms applied.
print("env: ", env)
print("last transform parent: ", env.transform[2].parent)

env:  TransformedEnv(
    env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=cpu),
    transform=Compose(
            NoopResetEnv(noops=3, random=True),
            ToTensorImage(keys=['next_pixels']),
            ObservationNorm(loc=2.0000, scale=1.0000, keys=['next_pixels'])))
last transform parent:  TransformedEnv(
    env=GymEnv(env=Pendulum-v1, batch_size=torch.Size([]), device=cpu),
    transform=Compose(
            NoopResetEnv(noops=3, random=True),
            ToTensorImage(keys=['next_pixels'])))


### Vectorized environments

In [30]:
from torchrl.envs import ParallelEnv
# ParallelEnv builds 4 envs from the given constructor; their outputs are
# batched together (batch_size is [4] in the reset output).
base_env = ParallelEnv(4, lambda: GymEnv("Pendulum-v1", frame_skip=3, from_pixels=True, pixels_only=False))
env = TransformedEnv(base_env, Compose(NoopResetEnv(3), ToTensorImage()))  # applies transforms on batch of envs
env.append_transform(ObservationNorm(keys_in=["next_pixels"], loc=2, scale=1))
env.reset()

Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.


TensorDict(
    fields={
        done: Tensor(torch.Size([4, 1]), dtype=torch.bool),
        pixels: Tensor(torch.Size([4, 3, 500, 500]), dtype=torch.float32),
        state: Tensor(torch.Size([4, 3]), dtype=torch.float32)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)

In [31]:
# The action spec describes valid actions: here a single continuous value
# bounded between -2 and 2.
env.action_spec

NdBoundedTensorSpec(
     shape=torch.Size([1]), space=ContinuousBox(minimum=tensor([-2.]), maximum=tensor([2.])), device=cpu, dtype=torch.float32, domain=continuous)

## Modules

### Models
#### MLP

In [32]:
from torchrl.modules import MLP, ConvNet
from torchrl.modules.models.utils import SquashDims
from torch import nn
# MLP with hidden layers of 32 and 64 cells, ELU activations and 4 outputs.
# No input size is given: the first layer is printed as a LazyLinear.
net = MLP(num_cells=[32, 64], out_features=4, activation_class=nn.ELU)
print(net)

MLP(
  (0): LazyLinear(in_features=0, out_features=32, bias=True)
  (1): ELU(alpha=1.0)
  (2): Linear(in_features=32, out_features=64, bias=True)
  (3): ELU(alpha=1.0)
  (4): Linear(in_features=64, out_features=4, bias=True)
)




In [33]:
# Calling the net on a [10, 3] input gives a [10, 4] output.
net(torch.randn(10, 3)).shape

torch.Size([10, 4])

#### CNN

In [34]:
# Convolutional network: two conv layers (32 then 64 channels) whose output
# is handed to a SquashDims aggregator.
cnn = ConvNet(
    num_cells=[32, 64],
    kernel_sizes=[8, 4],
    strides=[2, 1],
    aggregator_class=SquashDims,
)
print(cnn)

ConvNet(
  (0): LazyConv2d(0, 32, kernel_size=(8, 8), stride=(2, 2))
  (1): ELU(alpha=1.0)
  (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(1, 1))
  (3): ELU(alpha=1.0)
  (4): SquashDims()
)


In [35]:
# A batch of 10 images of shape [3, 32, 32] yields a 2-dimensional output.
cnn(torch.randn(10, 3, 32, 32)).shape  # the conv feature dims are squashed into one: [10, 6400]

torch.Size([10, 6400])

### TensorDictModules

In [36]:
from torchrl.modules import TensorDictModule
tensordict = TensorDict({"key 1": torch.randn(10, 3)}, batch_size=[10])
module = nn.Linear(3, 4)
# TensorDictModule wraps a regular module: inputs are read from in_keys and
# outputs are written under out_keys.
td_module = TensorDictModule(module, in_keys=["key 1"], out_keys=["key 2"])
# The call updates `tensordict` itself: "key 2" is present in the print below.
td_module(tensordict)
print(tensordict)

TensorDict(
    fields={
        key 1: Tensor(torch.Size([10, 3]), dtype=torch.float32),
        key 2: Tensor(torch.Size([10, 4]), dtype=torch.float32)},
    batch_size=torch.Size([10]),
    device=None,
    is_shared=False)


### Sequences of modules

In [37]:
from torchrl.modules import TensorDictSequential
# backbone: "observation" -> "hidden"
backbone_module = nn.Linear(5, 3)
backbone = TensorDictModule(backbone_module, in_keys=["observation"], out_keys=["hidden"])
# actor: "hidden" -> "action"
actor_module = nn.Linear(3, 4)
actor = TensorDictModule(actor_module, in_keys=["hidden"], out_keys=["action"])
# value: reads both "hidden" and "action" and writes "value"
value_module = MLP(out_features=1, num_cells=[4, 5])
value = TensorDictModule(value_module, in_keys=["hidden", "action"], out_keys=["value"])

# Chain the three modules; each one reads entries written by the previous ones.
sequence = TensorDictSequential(backbone, actor, value)
print(sequence)

TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=Linear(in_features=5, out_features=3, bias=True), 
          device=cpu, 
          in_keys=['observation'], 
          out_keys=['hidden'])
      (1): TensorDictModule(
          module=Linear(in_features=3, out_features=4, bias=True), 
          device=cpu, 
          in_keys=['hidden'], 
          out_keys=['action'])
      (2): TensorDictModule(
          module=MLP(
            (0): LazyLinear(in_features=0, out_features=4, bias=True)
            (1): Tanh()
            (2): Linear(in_features=4, out_features=5, bias=True)
            (3): Tanh()
            (4): Linear(in_features=5, out_features=1, bias=True)
          ), 
          device=cpu, 
          in_keys=['hidden', 'action'], 
          out_keys=['value'])
    ), 
    device=cpu, 
    in_keys=['observation'], 
    out_keys=['hidden', 'action', 'value'])


In [38]:
# Only "observation" is needed as input: "hidden" and "action" are produced
# inside the sequence and appear among the out_keys.
print(sequence.in_keys, sequence.out_keys)

['observation'] ['hidden', 'action', 'value']


In [39]:
# Calling the three modules one after the other on the same tensordict
# fills in "hidden", "action" and "value" step by step.
tensordict = TensorDict(
    {"observation": torch.randn(3, 5)}, [3],
)
backbone(tensordict)
actor(tensordict)
value(tensordict)

TensorDict(
    fields={
        action: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        hidden: Tensor(torch.Size([3, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([3, 5]), dtype=torch.float32),
        value: Tensor(torch.Size([3, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

In [40]:
# A single call to the sequence produces the same set of entries as the
# three separate calls above.
tensordict = TensorDict(
    {"observation": torch.randn(3, 5)}, [3],
)
sequence(tensordict)
print(tensordict)

TensorDict(
    fields={
        action: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        hidden: Tensor(torch.Size([3, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([3, 5]), dtype=torch.float32),
        value: Tensor(torch.Size([3, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)


### Functional programming (ensembling / meta-RL)

In [41]:
# Split the sequence into a functional module plus its parameters and buffers.
fsequence, (params, buffers) = sequence.make_functional_with_buffers()

In [42]:
# The parameters now live in `params`, outside of the module.
len(list(fsequence.parameters()))  # functional modules have no parameters

0

In [43]:
# Parameters and buffers are passed explicitly at call time.
fsequence(tensordict, params=params, buffers=buffers)

TensorDict(
    fields={
        action: Tensor(torch.Size([3, 4]), dtype=torch.float32),
        hidden: Tensor(torch.Size([3, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([3, 5]), dtype=torch.float32),
        value: Tensor(torch.Size([3, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)

In [44]:
# Ensembling with vmap: give every parameter and every buffer a leading
# dimension of size 4 and map the functional module over it.
params_expand = [p.expand(4, *p.shape) for p in params]
buffers_expand = [b.expand(4, *b.shape) for b in buffers]
# vmap=(0, 0, None): params and buffers are mapped along dim 0, the input
# tensordict is not mapped. The buffers must therefore be the expanded ones;
# passing the unexpanded `buffers` only works while the module has no buffers.
tensordict_exp = fsequence(tensordict, params=params_expand, buffers=buffers_expand, vmap=(0, 0, None))
print(tensordict_exp)

TensorDict(
    fields={
        action: Tensor(torch.Size([4, 3, 4]), dtype=torch.float32),
        hidden: Tensor(torch.Size([4, 3, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([4, 3, 5]), dtype=torch.float32),
        value: Tensor(torch.Size([4, 3, 1]), dtype=torch.float32)},
    batch_size=torch.Size([4, 3]),
    device=None,
    is_shared=False)


### Specialized classes

In [45]:
torch.manual_seed(0)
from torchrl.data import NdBoundedTensorSpec
# Spec bounding each of the 3 output values between -1 and 1.
spec = NdBoundedTensorSpec(-torch.ones(3), torch.ones(3))
base_module = nn.Linear(5, 3)
# The module is given the spec and safe=True (the effect is shown in the next cell).
module = TensorDictModule(module=base_module, spec=spec, in_keys=["obs"], out_keys=["action"], safe=True)
tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[])
module(tensordict)["action"]

tensor([-0.0137,  0.1524, -0.0641], grad_fn=<AddBackward0>)

In [46]:
# Same module, with the input scaled by 100: the printed action sits on the bounds.
tensordict = TensorDict({"obs": torch.randn(5)*100}, batch_size=[])
module(tensordict)["action"]  # safe=True projects the result within the set

tensor([-1.,  1., -1.], grad_fn=<IndexPutBackward0>)

In [47]:
from torchrl.modules import Actor
base_module = nn.Linear(5, 3)
# No out_keys are given here: the output is written under "action".
actor = Actor(base_module, in_keys=["obs"])
tensordict = TensorDict({"obs": torch.randn(5)}, batch_size=[])
actor(tensordict)  # "action" is the default output key

TensorDict(
    fields={
        action: Tensor(torch.Size([3]), dtype=torch.float32),
        obs: Tensor(torch.Size([5]), dtype=torch.float32)},
    batch_size=torch.Size([]),
    device=None,
    is_shared=False)

In [48]:
# Probabilistic modules: a module produces the distribution parameters and
# the probabilistic wrapper writes a sample under the "action" key.
from torchrl.modules import ProbabilisticTensorDictModule
from torchrl.data import TensorDict
from torchrl.modules import TanhNormal, NormalParamWrapper

td = TensorDict({"input": torch.randn(3, 5)}, [3])
# NormalParamWrapper splits the output of the linear layer in loc and scale
net = NormalParamWrapper(nn.Linear(5, 4))
module = TensorDictModule(
    net,
    in_keys=["input"],
    out_keys=["loc", "scale"],
)
td_module = ProbabilisticTensorDictModule(
    module=module,
    dist_param_keys=["loc", "scale"],
    out_key_sample=["action"],
    distribution_class=TanhNormal,
    return_log_prob=False,
)
td_module(td)
print(td)

TensorDict(
    fields={
        action: Tensor(torch.Size([3, 2]), dtype=torch.float32),
        input: Tensor(torch.Size([3, 5]), dtype=torch.float32),
        loc: Tensor(torch.Size([3, 2]), dtype=torch.float32),
        scale: Tensor(torch.Size([3, 2]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)


In [49]:
# Returning the log-probability: with return_log_prob=True an additional
# "sample_log_prob" entry is written next to the sampled action.
td = TensorDict({"input": torch.randn(3, 5)}, [3])
td_module = ProbabilisticTensorDictModule(
    module=module,
    dist_param_keys=["loc", "scale"],
    out_key_sample=["action"],
    distribution_class=TanhNormal,
    return_log_prob=True,
)
td_module(td)
print(td)

TensorDict(
    fields={
        action: Tensor(torch.Size([3, 2]), dtype=torch.float32),
        input: Tensor(torch.Size([3, 5]), dtype=torch.float32),
        loc: Tensor(torch.Size([3, 2]), dtype=torch.float32),
        sample_log_prob: Tensor(torch.Size([3, 1]), dtype=torch.float32),
        scale: Tensor(torch.Size([3, 2]), dtype=torch.float32)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)


In [50]:
# Sampling vs mode / mean: run the same probabilistic module under each
# exploration mode and show the resulting action.
from torchrl.envs.utils import set_exploration_mode
td = TensorDict({"input": torch.randn(3, 5)}, [3,])

torch.manual_seed(0)
for exploration_mode in ("random", "mode", "mean"):
    with set_exploration_mode(exploration_mode):
        td_module(td)
        print(f"{exploration_mode}:", td["action"])

    

random: tensor([[ 0.8728, -0.1335],
        [-0.9833,  0.3497],
        [-0.6889, -0.6433]], grad_fn=<ClampBackward1>)
mode: tensor([[-0.1131,  0.1761],
        [-0.3425, -0.2665],
        [ 0.2915,  0.6207]], grad_fn=<ClampBackward1>)
mean: tensor([[-0.1131,  0.1441],
        [-0.2375, -0.1242],
        [ 0.1372,  0.3810]], grad_fn=<MeanBackward1>)


## Using environments and modules

In [51]:
from torchrl.envs.utils import step_tensordict
env = GymEnv("Pendulum-v1")

# Linear policy mapping the 3-dim observation to a 1-dim action.
action_spec = env.action_spec
actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, spec=action_spec, in_keys=["observation"], out_keys=["action"])

# Seed torch and the env so that this rollout can be compared with the ones below.
torch.manual_seed(0)
env.set_seed(0)

# Manual rollout that writes every step into a preallocated tensordict.
max_steps = 100
tensordict = env.reset()
tensordicts = TensorDict({}, [max_steps])
for i in range(max_steps):
    actor(tensordict)  # writes "action" into tensordict
    tensordicts[i] = env.step(tensordict)
    tensordict = step_tensordict(tensordict)  # roughly equivalent to obs = next_obs
    if env.is_done:
        break

# Keep a copy for the comparisons in the following cells.
tensordicts_prealloc = tensordicts.clone()
# NOTE(review): `i` is the index of the last step taken, so this prints one
# less than the number of steps stored (99 for the 100 steps shown in the output).
print("total steps:", i)
print(tensordicts)

Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.
total steps: 99
TensorDict(
    fields={
        action: Tensor(torch.Size([100, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([100, 1]), dtype=torch.bool),
        next_observation: Tensor(torch.Size([100, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([100, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([100, 1]), dtype=torch.float32)},
    batch_size=torch.Size([100]),
    device=None,
    is_shared=False)


In [52]:
# Equivalent rollout: collect the steps in a list and stack them at the end.
# Same seeds as the preallocated version so that both can be compared.
torch.manual_seed(0)
env.set_seed(0)

max_steps = 100
tensordict = env.reset()
tensordicts = []
for i in range(max_steps):
    actor(tensordict)
    tensordicts.append(env.step(tensordict))
    tensordict = step_tensordict(tensordict)  # roughly equivalent to obs = next_obs
    if env.is_done:
        break
# Stacking the list gives a lazily stacked tensordict (see the output).
tensordicts_stack = torch.stack(tensordicts, 0)
# NOTE(review): as above, `i` is the index of the last step, not the step count.
print("total steps:", i)
print(tensordicts_stack)

total steps: 99
LazyStackedTensorDict(
    fields={
        action: Tensor(torch.Size([100, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([100, 1]), dtype=torch.bool),
        next_observation: Tensor(torch.Size([100, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([100, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([100, 1]), dtype=torch.float32)},
    batch_size=torch.Size([100]),
    device=cpu,
    is_shared=False)


In [53]:
# Both manual rollouts hold the same data.
(tensordicts_stack == tensordicts_prealloc).all()

True

In [54]:
# Helper: env.rollout runs the policy in the env for up to max_steps steps.
# Same seeds again so that the result can be compared with the manual rollouts.
torch.manual_seed(0)
env.set_seed(0)
tensordict_rollout = env.rollout(policy=actor, max_steps=max_steps)
tensordict_rollout

TensorDict(
    fields={
        action: Tensor(torch.Size([100, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([100, 1]), dtype=torch.bool),
        next_observation: Tensor(torch.Size([100, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([100, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([100, 1]), dtype=torch.float32)},
    batch_size=torch.Size([100]),
    device=cpu,
    is_shared=False)

In [55]:
# The helper reproduces the manual rollout.
(tensordict_rollout == tensordicts_prealloc).all()

True

## Collectors

In [56]:
from torchrl.envs import ParallelEnv, EnvCreator
from torchrl.envs.libs.gym import GymEnv
from torchrl.modules import TensorDictModule
from torchrl.collectors import MultiSyncDataCollector, MultiaSyncDataCollector
from torch import nn

# EnvCreator makes sure that we can send a lambda function from process to process
parallel_env = ParallelEnv(3, EnvCreator(lambda: GymEnv("Pendulum-v1")))
# One entry per sub-collector: two sub-collectors, each given the 3-env ParallelEnv.
create_env_fn=[parallel_env, parallel_env]

# Linear policy mapping "observation" to "action".
actor_module = nn.Linear(3, 1)
actor = TensorDictModule(actor_module, in_keys=["observation"], out_keys=["action"])

# Sync data collector
# One device per entry of create_env_fn.
devices = ["cpu", "cpu"]

collector = MultiSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early 
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    passing_devices=devices,  # len must match len of env created
    devices=devices,
)


Discarding frameskip arg. This will be taken care of by TorchRL env wrapper.


In [57]:
# Iterate over the collector: total_frames=240 at 60 frames per batch gives
# 4 batches, so the last index printed is 3.
for i, d in enumerate(collector):
    if i == 0:
        print(d)  # trajectories are split automatically in [6 workers x 10 steps]
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)

TensorDict(
    fields={
        action: Tensor(torch.Size([6, 10, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([6, 10, 1]), dtype=torch.bool),
        mask: Tensor(torch.Size([6, 10, 1]), dtype=torch.bool),
        next_observation: Tensor(torch.Size([6, 10, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([6, 10, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([6, 10, 1]), dtype=torch.float32),
        step_count: Tensor(torch.Size([6, 10, 1]), dtype=torch.int32),
        traj_ids: Tensor(torch.Size([6, 10, 1]), dtype=torch.int64)},
    batch_size=torch.Size([6, 10]),
    device=None,
    is_shared=False)
3


In [58]:

# async data collector: keeps working while you update your model
# Same configuration as the sync collector above.
collector = MultiaSyncDataCollector(
    create_env_fn=create_env_fn,  # either a list of functions or a ParallelEnv
    policy=actor,
    total_frames=240,
    max_frames_per_traj=-1,  # envs are terminating, we don't need to stop them early 
    frames_per_batch=60,  # we want 60 frames at a time (we have 3 envs per sub-collector)
    passing_devices=devices,  # len must match len of env created
    devices=devices,
)

for i, d in enumerate(collector):
    if i == 0:
        print(d)  # here a batch has shape [3 envs x 20 steps] (see output), unlike the [6 x 10] sync batches
    collector.update_policy_weights_()  # make sure that our policies have the latest weights if working on multiple devices
print(i)
# Drop the reference to the collector once the data has been consumed.
del collector

TensorDict(
    fields={
        action: Tensor(torch.Size([3, 20, 1]), dtype=torch.float32),
        done: Tensor(torch.Size([3, 20, 1]), dtype=torch.bool),
        mask: Tensor(torch.Size([3, 20, 1]), dtype=torch.bool),
        next_observation: Tensor(torch.Size([3, 20, 3]), dtype=torch.float32),
        observation: Tensor(torch.Size([3, 20, 3]), dtype=torch.float32),
        reward: Tensor(torch.Size([3, 20, 1]), dtype=torch.float32),
        step_count: Tensor(torch.Size([3, 20, 1]), dtype=torch.int32),
        traj_ids: Tensor(torch.Size([3, 20, 1]), dtype=torch.int64)},
    batch_size=torch.Size([3, 20]),
    device=cpu,
    is_shared=False)
3
