In [3]:
import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn import TensorDictModule, TensorDictSequential

In [32]:
td = TensorDict(x=torch.randn(5, 3), z=torch.randn(5, 8), batch_size=[5])

linear0 = TensorDictModule(
    nn.Linear(3, 128), in_keys=[("input", "x")], out_keys=[("hidden", "linear0")]
)
relu0 = TensorDictModule(
    torch.relu, in_keys=[("hidden", "linear0")], out_keys=[("hidden", "relu0")]
)
linear1 = TensorDictModule(
    nn.Linear(128, 128), in_keys=[("hidden", "relu0")], out_keys=[("hidden", "linear1")]
)
relu1 = TensorDictModule(nn.ReLU(), in_keys=[("hidden", "linear1")], out_keys=[("hidden", "relu1")])
linear2 = TensorDictModule(
    nn.Linear(128, 3), in_keys=[("hidden", "relu1")], out_keys=[("hidden", "linear2")]
)

block0 = TensorDictSequential(linear0, relu0)
block1 = TensorDictSequential(linear1, relu1, linear2)

residual = TensorDictModule(lambda x, y: x + y, in_keys=["x", "linear2"], out_keys=["y"])

block = TensorDictSequential(block0, block1, residual)

In [33]:
td

TensorDict(
    fields={
        x: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        z: Tensor(shape=torch.Size([5, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

In [None]:
for key in list(td.keys()):
    td.rename_key_(key, f"input.{key}")
td

TensorDict(
    fields={
        input.x: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        input.z: Tensor(shape=torch.Size([5, 8]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

In [36]:
td[("input", "z")] = torch.randn(5, 16)
td

TensorDict(
    fields={
        input.x: Tensor(shape=torch.Size([5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
        input.z: Tensor(shape=torch.Size([5, 8]), device=cpu, dtype=torch.float32, is_shared=False),
        input: TensorDict(
            fields={
                z: Tensor(shape=torch.Size([5, 16]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([5]),
            device=None,
            is_shared=False)},
    batch_size=torch.Size([5]),
    device=None,
    is_shared=False)

In [43]:
[(key, val.shape) for key, val in td["input"].items(True, True)]

[('z', torch.Size([5, 16]))]

In [None]:
td.items