# Tensordict experiments

The current documentation of the `tensordict` library can be unclear at points. This notebook runs small experiments 
to settle some of those ambiguities.

## Do `tensorclass` objects keep their type when they are stored in a `TensorDict`?

In [1]:
import torch
from tensordict import TensorDict, tensorclass
from IPython.display import HTML, display


@tensorclass
class TestClass:
    """Minimal tensorclass with two tensor fields.

    Used in this cell as a value stored inside a ``TensorDict`` so that the
    type of the stored entry can be inspected afterwards.
    """

    a: torch.Tensor
    b: torch.Tensor


# Build a tensorclass instance with batch size 3 and store it in a TensorDict
# next to a plain tensor that has the same leading dimension.
tc = TestClass(a=torch.randn(3, 6, 4), b=torch.randn(3, 1, 2), batch_size=[3])
tc_key = "tc"
td = TensorDict(
    {
        "value": torch.randn(3, 3),
        tc_key: tc,
    },
    batch_size=[3],
)

print(td)

# An exact type comparison (not isinstance) is deliberate: the question is
# whether the entry comes back as TestClass itself rather than as another type.
verdict = "preserved" if type(td[tc_key]) is TestClass else "not preserved"
display(HTML(f"<p>Type is <strong>{verdict}</strong></p>"))

TensorDict(
    fields={
        tc: TestClass(
            a=Tensor(shape=torch.Size([3, 6, 4]), device=cpu, dtype=torch.float32, is_shared=False),
            b=Tensor(shape=torch.Size([3, 1, 2]), device=cpu, dtype=torch.float32, is_shared=False),
            batch_size=torch.Size([3]),
            device=None,
            is_shared=False),
        value: Tensor(shape=torch.Size([3, 3]), device=cpu, dtype=torch.float32, is_shared=False)},
    batch_size=torch.Size([3]),
    device=None,
    is_shared=False)


## Can `Tensorclass` objects be iterated over and assigned to by index?



In [2]:
import torch
from unipercept.utils.tensorclass import Tensorclass


# NOTE: this re-declares ``TestClass``. From this cell on the name refers to
# this ``Tensorclass`` subclass, not the ``@tensorclass`` version defined in
# the earlier cell.
class TestClass(Tensorclass):
    """Two-field ``Tensorclass`` subclass used to try iteration and indexed assignment."""

    a: torch.Tensor
    b: torch.Tensor


# A batch of three entries; each field holds a single value per entry.
tc = TestClass(
    a=torch.randn(3, 1),
    b=torch.randn(3, 1),
    batch_size=[3],
)

# State of both fields before the round trip.
for field in (tc.a, tc.b):
    print(field.shape, field)

# Iterate over the object, modify a clone of every yielded item and write it
# back by index. Field `a` is scaled by the index and field `b` is shifted by
# it, so the effect of each write-back is visible in the final printout.
for batch_idx, tc_item in enumerate(tc):
    print(batch_idx)

    tc_item = tc_item.clone()
    tc_item.a *= batch_idx
    tc_item.b += batch_idx
    tc[batch_idx] = tc_item


# State of both fields after the round trip.
for field in (tc.a, tc.b):
    print(field.shape, field)

torch.Size([3, 1]) tensor([[-1.8834],
        [ 0.2882],
        [-0.5761]])
torch.Size([3, 1]) tensor([[-0.7000],
        [ 1.6227],
        [-0.3501]])
0
1
2
torch.Size([3, 1]) tensor([[-0.0000],
        [ 0.2882],
        [-1.1522]])
torch.Size([3, 1]) tensor([[-0.7000],
        [ 2.6227],
        [ 1.6499]])


## What does the flat PyTree Spec look like?

In [5]:
from tensordict import TensorDict
from torch.utils._pytree import tree_flatten
import torch
from unipercept.data.points import Image

# Nested TensorDict mixing plain tensors, an `Image` tensor subclass (`h`) and
# a plain Python int (`x`), so the flattened result shows how each is handled.
td = TensorDict.from_dict(
    {
        "a": torch.randn(3, 1),
        "b": torch.randn(3, 5),
        "c": TensorDict.from_dict(
            {
                "d": torch.randn(3, 5, 1),
                "e": torch.randn(3, 5, 6),
                "f": TensorDict.from_dict(
                    {
                        "g": torch.randn(3, 5, 3),
                        "h": torch.randn(3, 5, 3, 2).as_subclass(Image),
                    }
                ),
            }
        ),
        "x": 1,
    }
)

print(td)

# Flatten into a list of leaves plus a spec object describing the structure.
td_flat, td_structure = tree_flatten(td)

# One line per leaf, showing the Python type of each flattened value.
leaf_types = "\n".join(str(type(leaf)) for leaf in td_flat)
print("PyTree: \n" + leaf_types)

# The spec's context followed by one line per direct child spec.
child_specs = "\n".join(map(str, td_structure.children_specs))
print(f"Structure: {td_structure.context}\n" + child_specs)

TensorDict(
    fields={
        a: Tensor(shape=torch.Size([3, 1]), device=cpu, dtype=torch.float32, is_shared=False),
        b: Tensor(shape=torch.Size([3, 5]), device=cpu, dtype=torch.float32, is_shared=False),
        c: TensorDict(
            fields={
                d: Tensor(shape=torch.Size([3, 5, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                e: Tensor(shape=torch.Size([3, 5, 6]), device=cpu, dtype=torch.float32, is_shared=False),
                f: TensorDict(
                    fields={
                        g: Tensor(shape=torch.Size([3, 5, 3]), device=cpu, dtype=torch.float32, is_shared=False),
                        h: Image(shape=torch.Size([3, 5, 3, 2]), device=cpu, dtype=torch.float32, is_shared=False)},
                    batch_size=torch.Size([3, 5, 3]),
                    device=None,
                    is_shared=False)},
            batch_size=torch.Size([3, 5]),
            device=None,
            is_shared=False),
        x: Ten