# TensorDict tutorial

`TensorDict` is a new tensor structure introduced in TorchRL. 

In RL, you need to handle multiple tensors such as actions, observations and rewards. `TensorDict` aims at making it more convenient to deal with multiple tensors at the same time.

Furthermore, different RL algorithms can deal with different inputs and outputs. The `TensorDict` class makes it possible to abstract away the differences between these algorithms.

`TensorDict` combines the convenience of using `dict`s to organize your data with the power of PyTorch tensors.


#### Improving the modularity of code

Let's suppose we have 2 datasets: Dataset A which has images and labels and Dataset B which has images, segmentation maps and labels. 

Suppose we want to train a common algorithm over these two datasets (i.e. an algorithm that would ignore the mask or infer it when needed). 

In plain PyTorch, we would need two separate training loops:
```python
# Method A: dataset A yields images and labels
for i in range(optim_steps):
    images, labels = get_data_A()
    loss = loss_module(images, labels)
    loss.backward()
    optim.step()
    optim.zero_grad()
```

```python
# Method B: dataset B also yields masks, so the unpacking has to change
for i in range(optim_steps):
    images, masks, labels = get_data_B()
    loss = loss_module(images, labels)
    loss.backward()
    optim.step()
    optim.zero_grad()
```

This limits the reusability of code: much of the training loop has to be rewritten just because the two datasets provide different modalities.
The idea of TensorDict is to do the following:

```python
# General Method
for i in range(optim_steps):
    tensordict = get_data()
    loss = loss_module(tensordict)
    loss.backward()
    optim.step()
    optim.zero_grad()
```


Now we can reuse the same training loop across datasets and losses.
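To make the idea concrete, here is a minimal sketch of a loss module that only reads the keys it needs, so one loop accepts batches from either dataset. It uses plain PyTorch; `ClassificationLoss`, the `"images"`/`"labels"`/`"masks"` keys and the toy shapes are hypothetical, and a Python dict stands in for the tensordict since both are indexed by key.

```python
import torch
from torch import nn


class ClassificationLoss(nn.Module):
    """Hypothetical loss module: reads only the keys it needs from the batch."""

    def __init__(self, in_features: int, num_classes: int):
        super().__init__()
        self.classifier = nn.Linear(in_features, num_classes)

    def forward(self, batch):
        # Extra entries such as "masks" are simply ignored.
        logits = self.classifier(batch["images"].flatten(start_dim=1))
        return nn.functional.cross_entropy(logits, batch["labels"])


torch.manual_seed(0)
loss_module = ClassificationLoss(in_features=3 * 8 * 8, num_classes=10)
optim = torch.optim.SGD(loss_module.parameters(), lr=0.1)

# Toy batches standing in for dataset A (no masks) and dataset B (with masks).
batch_a = {"images": torch.randn(4, 3, 8, 8), "labels": torch.randint(0, 10, (4,))}
batch_b = dict(batch_a, masks=torch.zeros(4, 8, 8))

# The same training loop body works for both batches.
for batch in (batch_a, batch_b):
    loss = loss_module(batch)
    loss.backward()
    optim.step()
    optim.zero_grad()
```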

#### Can't I do this with a Python dict?

One could argue that the same result can be achieved with a dataset that outputs a Python dict:
```python
class DictDataset(Dataset):
    ...

    def __getitem__(self, idx):
        ...
        return {"images": image, "masks": mask}
```

However, you would then need to write a custom collate function that makes sure every modality is aggregated properly:

```python
import torch
from torch.utils.data import DataLoader


def collate_dict_fn(dict_list):
    # Gather the samples of each key and stack them along a new batch dimension.
    final_dict = {}
    for key in dict_list[0].keys():
        final_dict[key] = []
        for single_dict in dict_list:
            final_dict[key].append(single_dict[key])
        final_dict[key] = torch.stack(final_dict[key], dim=0)
    return final_dict


dataloader = DataLoader(DictDataset(), collate_fn=collate_dict_fn)
```
With TensorDicts this is now much simpler:

```python
class DictDataset(Dataset):
    ...

    def __getitem__(self, idx):
        ...
        # A single sample has no batch dimension, hence the empty batch size.
        return TensorDict({"images": image, "masks": mask}, batch_size=[])
```


Here, the collate function is as simple as:
```python
import torch
from torch.utils.data import DataLoader

collate_tensordict_fn = lambda tds: torch.stack(tds, dim=0)

dataloader = DataLoader(DictDataset(), collate_fn=collate_tensordict_fn)
```
This is even more useful when considering nested structures, which `TensorDict` supports.

`TensorDict` borrows several properties from both `torch.Tensor` and `dict`, which we detail further down.

## `TensorDict` structure

```python
from torchrl.data import TensorDict
from torchrl.data.tensordict.tensordict import (
    UnsqueezedTensorDict,
    ViewedTensorDict,
    PermutedTensorDict,
    LazyStackedTensorDict,
)
import torch
```

`TensorDict` is a data structure indexed by keys. Its values can be tensors, memory-mapped tensors or other `TensorDict`s. All values must live on the same device and in the same kind of memory (e.g. all in shared memory or none of them), but they can have different dtypes.

Another essential property of a `TensorDict` is its `batch_size`, which must be provided when creating the `TensorDict`. The batch size is defined as the first n dimensions shared by all values.

Consequently, nested `TensorDict`s have the following property: the batch size of the parent `TensorDict` must be a prefix of the batch size of the child `TensorDict`.
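Both rules boil down to prefix checks on shapes. The sketch below expresses them in plain Python; `has_batch_prefix` and `check_nested` are hypothetical helpers for illustration, not part of the `TensorDict` API. Shapes are given as tuples (a `torch.Size` is a tuple subclass, so it works too).

```python
def has_batch_prefix(shape, batch_size):
    """Return True if `batch_size` is a prefix of `shape`."""
    batch_size = tuple(batch_size)
    return tuple(shape[: len(batch_size)]) == batch_size


def check_nested(parent_batch_size, child_batch_size, child_shapes):
    """Hypothetical check mirroring the two rules for a nested TensorDict."""
    # Every value of the child must start with the child's batch size ...
    values_ok = all(has_batch_prefix(s, child_batch_size) for s in child_shapes)
    # ... and the child's batch size must start with the parent's batch size.
    return values_ok and has_batch_prefix(child_batch_size, parent_batch_size)


# Parent batch size [3, 4], child batch size [3, 4, 5], as in the next example.
print(check_nested((3, 4), (3, 4, 5), [(3, 4, 5), (3, 4, 5, 6)]))  # True
print(check_nested((4, 3), (3, 4, 5), [(3, 4, 5), (3, 4, 5, 6)]))  # False
```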

```python
a = torch.zeros(3, 4)
b = TensorDict(
    {
        "c": torch.zeros(3, 4, 5, dtype=torch.int32),
        "d": torch.zeros(3, 4, 5, 6, dtype=torch.float32),
    },
    batch_size=[3, 4, 5],
)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])
print(tensordict)
```

```
TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32),
        b: TensorDict(
            fields={
                c: Tensor(torch.Size([3, 4, 5, 1]), dtype=torch.int32),
                d: Tensor(torch.Size([3, 4, 5, 6]), dtype=torch.float32)},
            batch_size=torch.Size([3, 4]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)
```


`TensorDict` does not support algebraic operations.

## `TensorDict` dictionary features

`TensorDict` shares many features with Python dictionaries:

```python
a = torch.zeros(3, 4, 5)
b = torch.zeros(3, 4)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])
print(tensordict)
```

```
TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)
```


### `get(key)`
If we want to access a certain key, we can index the tensordict or alternatively use the `get` method:

```python
print(tensordict["a"] is tensordict.get("a") is a)
print(tensordict["a"].shape)
```

```
True
torch.Size([3, 4, 5])
```


The `get` method also supports default values:

```python
out = tensordict.get("foo", torch.ones(3))
out
```

```
tensor([1., 1., 1.])
```

### `set(key, value)`
The `set()` method can be used to set new values. Regular indexing also does the job:

```python
c = torch.zeros((3, 4, 2, 2))
tensordict.set("c", c)
print(f"td[\"c\"] is c: {c is tensordict['c']}")

d = torch.zeros((3, 4, 2, 2))
tensordict["d"] = d
print(f"td[\"d\"] is d: {d is tensordict['d']}")
```

```
td["c"] is c: True
td["d"] is d: True
```


## Other methods:
### `keys`
We can access the keys of a tensordict:

```python
for key in tensordict.keys():
    print(key)
```

```
a
b
c
d
```


### `values`
The values of a `TensorDict` can be retrieved with the `values()` method. Note that it returns a generator, whereas Python's `dict.values()` returns a view object.
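The practical difference is that a generator can be consumed only once, while a dict view can be iterated repeatedly and reflects later changes to its dict. A small stdlib-only illustration, where a generator expression stands in for the tensordict's `values()`; wrap the generator in `list()` if you need several passes.

```python
d = {"a": 1, "b": 2}

view = d.values()              # dict.values() returns a view object
gen = (v for v in d.values())  # a generator, standing in for TensorDict.values()

print(list(view), list(view))  # [1, 2] [1, 2]: a view can be iterated repeatedly
print(list(gen), list(gen))    # [1, 2] []: a generator is exhausted after one pass

d["c"] = 3
print(list(view))              # [1, 2, 3]: the view reflects later changes to the dict
```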

```python
for value in tensordict.values():
    print(value.shape)
```

```
torch.Size([3, 4, 5])
torch.Size([3, 4, 1])
torch.Size([3, 4, 2, 2])
torch.Size([3, 4, 2, 2])
```


### TensorDict.update()
The `update` method can be used to update a TensorDict with another one (or with a dict):

```python
tensordict.update({"a": torch.ones((3, 4, 5)), "d": 2 * torch.ones((3, 4, 2))})
# Also works with tensordict.update(TensorDict({"a": torch.ones((3, 4, 5)), "c": torch.ones((3, 4, 2))}, batch_size=[3, 4]))
print(f"a is now equal to 1: {(tensordict['a'] == 1).all()}")
print(f"d is now equal to 2: {(tensordict['d'] == 2).all()}")
```

```
a is now equal to 1: True
d is now equal to 2: True
```


### Deleting keys
`TensorDict` also supports key deletion with the `del` operator:

```python
del tensordict["c"]
print(tensordict.keys())
```

```
dict_keys(['a', 'b', 'd'])
```


## TensorDict as a Tensor-like object

But wait, can't we do all of this with a classical dict?
Well, we would like the `TensorDict` to keep some nice PyTorch properties: `TensorDict` combines the advantages of a Python dictionary and of a PyTorch tensor.
A `TensorDict` has a batch size. It is not inferred automatically from the tensors; it must be set when creating the `TensorDict`.

`TensorDict` is a tensor container where all tensors are stored as key-value pairs and share at least the following features:
- device;
- memory location (shared, memory-mapped array, ...);
- batch size (i.e. the first n dimensions).
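The three shared features can be checked on plain tensors with standard `torch.Tensor` attributes. This is a sketch of the invariant only; `shared_features` is a hypothetical helper, not how `TensorDict` validates its entries, and it does not distinguish memory-mapped storage.

```python
import torch


def shared_features(tensors, batch_size):
    """Hypothetical helper: do all tensors share device, shared-memory status
    and the leading `batch_size` dimensions?"""
    batch_size = torch.Size(batch_size)
    n = len(batch_size)
    devices = {t.device for t in tensors}
    shared = {t.is_shared() for t in tensors}
    batch_ok = all(t.shape[:n] == batch_size for t in tensors)
    return len(devices) == 1 and len(shared) == 1 and batch_ok


a = torch.zeros(3, 4, 5)
b = torch.zeros(3, 4, 1)
print(shared_features([a, b], [3, 4]))  # True
print(shared_features([a, b], [3, 5]))  # False: the leading dims are [3, 4]
```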

```python
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
print(tensordict)
```

```
TensorDict(
    fields={
        a: Tensor(torch.Size([3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 4]),
    device=cpu,
    is_shared=False)
```


#### Batch size
`TensorDict` has a batch size which is shared across all tensors. The batch size can be empty (`[]`), one-dimensional or multidimensional, according to your needs.

```python
print(f"Our TensorDict is of size {tensordict.shape}")
```

```
Our TensorDict is of size torch.Size([3, 4])
```


You cannot have items that don't share the batch size inside the same TensorDict:

```python
# we cannot add tensors that violate the batch size:
try:
    tensordict.update({"c": torch.zeros(4, 3, 1)})
except RuntimeError as err:
    print(f"Caramba! We got this error: {err}")
```

```
Caramba! We got this error: batch dimension mismatch, got self.batch_size=torch.Size([3, 4]) and tensor.shape[:self.batch_dims]=torch.Size([4, 3])
```


The batch size can be changed after creation, but the new batch size must be compatible with the shapes of the stored tensors:

```python
tensordict.batch_size = [3]
assert tensordict.batch_size == torch.Size([3])
tensordict.batch_size = [3, 4]
```

```python
try:
    tensordict.batch_size = [4, 4]
except RuntimeError as err:
    print(f"Caramba! We got this error: {err}")
```

```
Caramba! We got this error: the tensor a has shape torch.Size([3, 4, 5]) which is incompatible with the new shape torch.Size([4, 4]).
```


We can also fill a `TensorDict` sequentially, one batch element at a time:

```python
tensordict = TensorDict({}, [10])
for i in range(10):
    tensordict[i] = TensorDict({"a": torch.randn(3, 4)}, [])
print(tensordict)
```

```
TensorDict(
    fields={
        a: Tensor(torch.Size([10, 3, 4]), dtype=torch.float32)},
    batch_size=torch.Size([10]),
    device=cpu,
    is_shared=False)
```


Entries that are not filled keep the default value of zero:

```python
tensordict = TensorDict({}, [10])
for i in range(2):
    tensordict[i] = TensorDict({"a": torch.randn(3, 4)}, [])
assert (tensordict[9]["a"] == torch.zeros((3, 4))).all()

# Reset `tensordict` for the examples below.
tensordict = TensorDict({"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4])
```
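Conceptually, this behaves as if the storage for each key were preallocated with zeros over the whole batch and then written row by row. A plain-PyTorch sketch of that mental model follows; it is an analogy under that assumption, not `TensorDict`'s actual implementation.

```python
import torch

# Storage for "a" allocated with zeros over the full batch of 10 elements ...
a = torch.zeros(10, 3, 4)
# ... and only the first two batch elements are written.
for i in range(2):
    a[i] = torch.randn(3, 4)

print((a[2:] == 0).all())  # tensor(True): unfilled rows keep their zero initialization
```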

#### Devices
A `TensorDict` can be sent to a device like a PyTorch tensor, with `td.cuda()` or `td.to(device)`, where `device` is the desired device.
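For comparison, here is what moving a plain dict of tensors looks like: every value has to be moved explicitly, which is the bookkeeping that a single `td.to(device)` call saves. The `batch` dict is hypothetical and the code falls back to the CPU when no GPU is available.

```python
import torch

# Pick the GPU when one is available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

batch = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
# Without TensorDict, each value must be moved one by one.
batch_on_device = {key: value.to(device) for key, value in batch.items()}
```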

#### Shared and memory-mapped storage
On CPU, `tensordict.memmap_()` stores the tensordict as a collection of memory-mapped tensors, while `tensordict.share_memory_()` places it in shared memory.
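At the level of individual tensors, shared memory relies on `torch.Tensor.share_memory_()`. The sketch below applies it to each value of a plain dict, which is roughly the per-tensor work a single `tensordict.share_memory_()` call saves; the `batch` dict is hypothetical.

```python
import torch

batch = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}
print(all(t.is_shared() for t in batch.values()))  # False: regular CPU tensors

# share_memory_() moves each tensor's storage to shared memory in place, so that
# other processes (e.g. torch.multiprocessing workers) can access it without a copy.
for t in batch.values():
    t.share_memory_()
print(all(t.is_shared() for t in batch.values()))  # True
```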

### Tensor operations
We can perform tensor operations along the batch dimensions:

#### Cloning
`TensorDict` supports cloning. Cloning returns an object of the same class as the original item.

```python
tensordict_clone = tensordict.clone()
tensordict_clone["a"] = torch.ones(*tensordict.shape, 5)
print("redefining a tensor in the clone does not impact the original tensordict: ", (tensordict["a"] == tensordict_clone["a"]).all())
```

```
redefining a tensor in the clone does not impact the original tensordict:  tensor(False)
```


#### Slicing and indexing
Slicing and indexing are supported along the batch dimensions:

```python
tensordict[0]
```

```
TensorDict(
    fields={
        a: Tensor(torch.Size([4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([4]),
    device=cpu,
    is_shared=False)
```

```python
tensordict[1:]
```

```
TensorDict(
    fields={
        a: Tensor(torch.Size([2, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([2, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([2, 4]),
    device=cpu,
    is_shared=False)
```

```python
tensordict[:, 2:]
```

```
TensorDict(
    fields={
        a: Tensor(torch.Size([3, 2, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([3, 2, 1]), dtype=torch.float32)},
    batch_size=torch.Size([3, 2]),
    device=cpu,
    is_shared=False)
```
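Because every value starts with the batch dimensions, an index written against the batch dimensions can be applied to each value as-is, leaving the trailing feature dimensions untouched. A plain-PyTorch sketch over a dict of tensors; `index_batch` is a hypothetical helper and, unlike `TensorDict`, it does not track a batch size.

```python
import torch


def index_batch(values, index):
    """Hypothetical helper: apply the same batch index to every tensor."""
    return {key: value[index] for key, value in values.items()}


values = {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4, 1)}

# Equivalent of tensordict[0]
print({k: tuple(v.shape) for k, v in index_batch(values, 0).items()})
# {'a': (4, 5), 'b': (4, 1)}

# Equivalent of tensordict[:, 2:]
print({k: tuple(v.shape) for k, v in index_batch(values, (slice(None), slice(2, None))).items()})
# {'a': (3, 2, 5), 'b': (3, 2, 1)}
```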

#### Setting values with indexing
We can also edit values at selected indices:

```python
# Indexing with a tensor returns a SubTensorDict, which keeps track of the original
# tensordict instead of copying its data in memory.
subtd = tensordict[:, torch.tensor([1, 3])]
tensordict.fill_("a", -1)
assert (subtd["a"] == -1).all()  # the change to "a" is visible through subtd
```

```python
td2 = TensorDict({"a": torch.zeros(2, 4, 5), "b": torch.zeros(2, 4)}, batch_size=[2, 4])
tensordict[:-1] = td2
tensordict["a"], tensordict["b"]
```

```
(tensor([[[ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.]],
 
         [[ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.],
          [ 0.,  0.,  0.,  0.,  0.]],
 
         [[-1., -1., -1., -1., -1.],
          [-1., -1., -1., -1., -1.],
          [-1., -1., -1., -1., -1.],
          [-1., -1., -1., -1., -1.]]]),
 tensor([[[0.],
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          [0.]],
 
         [[0.],
          [0.],
          [0.],
          [0.]]]))
```

As the example above shows, values can be set simply by indexing the tensordict.

#### Masking
A `TensorDict` can be masked just like a tensor:

```python
mask = torch.tensor([[1, 0, 1, 0], [1, 0, 1, 0], [1, 0, 1, 0]], dtype=torch.bool)
tensordict[mask]
```

```
TensorDict(
    fields={
        a: Tensor(torch.Size([6, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([6, 1]), dtype=torch.float32)},
    batch_size=torch.Size([6]),
    device=cpu,
    is_shared=False)
```
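The resulting shapes follow the usual PyTorch rule for boolean masks: a mask covering the leading (batch) dimensions collapses them into a single dimension whose length is the number of `True` entries, and the remaining dimensions are kept. Checked here on plain tensors with the same shapes as `"a"` and `"b"`:

```python
import torch

a = torch.zeros(3, 4, 5)
b = torch.zeros(3, 4, 1)
mask = torch.tensor([[True, False, True, False]] * 3)  # shape [3, 4], 6 True entries

# The two masked batch dims collapse into one dim of length 6; feature dims are kept.
print(a[mask].shape)  # torch.Size([6, 5])
print(b[mask].shape)  # torch.Size([6, 1])
```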

#### Stacking

`TensorDict` supports stacking. Stacking is done lazily and returns a `LazyStackedTensorDict`.

```python
# Stack
cloned_tensordict = tensordict.clone()
stacked_tensordict = torch.stack([tensordict, cloned_tensordict], dim=0)
print(stacked_tensordict)
if stacked_tensordict[0] is tensordict and stacked_tensordict[1] is cloned_tensordict:
    print("every tensordict is awesome!")
```

```
LazyStackedTensorDict(
    fields={
        a: Tensor(torch.Size([2, 3, 4, 5]), dtype=torch.float32),
        b: Tensor(torch.Size([2, 3, 4, 1]), dtype=torch.float32)},
    batch_size=torch.Size([2, 3, 4]),
    device=cpu,
    is_shared=False)
every tensordict is awesome!
```

If we want a contiguous tensordict, we can call `.to_tensordict()` or `.contiguous()`. For efficiency, it is recommended to do so before accessing the values of the stacked tensordict.

```python
assert isinstance(stacked_tensordict.contiguous(), TensorDict)
assert isinstance(stacked_tensordict.to_tensordict(), TensorDict)
```
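To see what "lazy" means here, the following is a deliberately minimal sketch of a lazy stack over dicts of tensors. `LazyStack` is hypothetical and far simpler than `LazyStackedTensorDict`: it only keeps references to its items and stacks a value when it is requested. In this sketch, every `get` call pays for a fresh `torch.stack`, which illustrates why materializing once is preferable when values are read repeatedly.

```python
import torch


class LazyStack:
    """Hypothetical, minimal lazy stack along dim 0 (illustration only)."""

    def __init__(self, items):
        self.items = list(items)  # keep references: nothing is copied yet

    def __getitem__(self, i):
        # Indexing along the stacking dimension returns the original object.
        return self.items[i]

    def get(self, key):
        # The stacked tensor is only built when a value is requested.
        return torch.stack([item[key] for item in self.items], dim=0)


td0 = {"a": torch.zeros(3, 4, 5)}
td1 = {"a": torch.ones(3, 4, 5)}
stacked = LazyStack([td0, td1])
print(stacked[1] is td1)       # True
print(stacked.get("a").shape)  # torch.Size([2, 3, 4, 5])
```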

#### Unbind
A `TensorDict` can be unbound along one of its batch dimensions:

```python
list_tensordict = tensordict.unbind(0)
assert isinstance(list_tensordict, tuple)
assert len(list_tensordict) == 3
assert (torch.stack(list_tensordict, dim=0).contiguous() == tensordict).all()
```

#### Cat
`TensorDict` supports `torch.cat` to concatenate along a dimension, which must be one of the batch dimensions:

```python
# Cat
list_tensordict = tensordict.unbind(0)
assert torch.cat(list_tensordict, dim=0).shape[0] == 12
```
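For each value, these two operations reduce to the usual tensor ones applied along a batch dimension. On a plain tensor shaped like the `"a"` entry, they give:

```python
import torch

a = torch.zeros(3, 4, 5)

# unbind(0) removes dimension 0 and returns a tuple of 3 tensors of shape [4, 5].
pieces = a.unbind(0)
print(len(pieces), pieces[0].shape)    # 3 torch.Size([4, 5])

# cat along dim 0 joins them end to end: 3 * 4 = 12 rows, feature dim 5 unchanged.
print(torch.cat(pieces, dim=0).shape)  # torch.Size([12, 5])
```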

#### View
The `view` operation is supported and returns a `ViewedTensorDict`. Use `to_tensordict()` to get a regular `TensorDict` back:

```python
assert type(tensordict.view(-1)) is ViewedTensorDict
assert tensordict.view(-1).shape[0] == 12
```

#### Permute
We can permute the batch dimensions of a `TensorDict`. `permute` is a lazy operation that returns a `PermutedTensorDict`; use `to_tensordict()` to convert it to a `TensorDict`:

```python
assert type(tensordict.permute(1, 0)) is PermutedTensorDict
assert tensordict.permute(1, 0).batch_size == torch.Size([4, 3])
```
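Since the permutation is expressed in batch dimensions only, each value keeps its trailing feature dimensions in place. A plain-PyTorch sketch of the per-value equivalent, using a hypothetical `permute_batch` helper:

```python
import torch


def permute_batch(t, dims):
    """Hypothetical helper: permute the leading batch dims of `t`, keep the rest."""
    n = len(dims)
    return t.permute(*dims, *range(n, t.dim()))


a = torch.zeros(3, 4, 5)
print(permute_batch(a, (1, 0)).shape)  # torch.Size([4, 3, 5])
```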

#### Reshape
`reshape` changes the batch size of a `TensorDict`:

```python
assert tensordict.reshape(-1).batch_size == torch.Size([12])
```
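Per value, reshaping the batch amounts to reshaping only the leading batch dimensions and re-appending the feature dimensions. A sketch with a hypothetical `reshape_batch` helper on tensors shaped like `"a"` and `"b"` (batch size `[3, 4]`):

```python
import torch


def reshape_batch(t, batch_dims, new_batch_shape):
    """Hypothetical helper: reshape the first `batch_dims` dims, keep feature dims."""
    return t.reshape(*new_batch_shape, *t.shape[batch_dims:])


a = torch.zeros(3, 4, 5)
b = torch.zeros(3, 4, 1)
print(reshape_batch(a, 2, (-1,)).shape)  # torch.Size([12, 5])
print(reshape_batch(b, 2, (-1,)).shape)  # torch.Size([12, 1])
```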

#### Squeeze and Unsqueeze
`TensorDict` also supports `squeeze` and `unsqueeze`. `unsqueeze` is a lazy operation that returns an `UnsqueezedTensorDict`; use `to_tensordict()` to retrieve a `TensorDict` after unsqueezing. Squeezing the unsqueezed dimension returns the original tensordict:

```python
unsqueezed_tensordict = tensordict.unsqueeze(0)
assert type(unsqueezed_tensordict) is UnsqueezedTensorDict
assert unsqueezed_tensordict.batch_size == torch.Size([1, 3, 4])

assert type(unsqueezed_tensordict.squeeze(0)) is TensorDict
assert unsqueezed_tensordict.squeeze(0) is tensordict
```