This post contains the code behind this video explanation:

{{< video https://youtu.be/JDy58DtZC_g >}}

In [None]:
#| code-fold: true

import torch
from torch import tensor
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

Imagine a supervised classification task where the features are sequences of varying length and the target is binary.

Let's simulate a toy dataset and take a look at it:

In [None]:
#| code-fold: true

class CustomDataset(torch.utils.data.Dataset):
    """Toy map-style dataset of variable-length integer sequences.

    Item ``idx`` is a dict ``{"x": <list of ints>, "y": <0 or 1>}``. The
    ``x`` lists have different lengths (2 to 5 elements), which is exactly
    what makes default batching fail.
    """

    def __init__(self):
        # Half-open (start, stop) bounds of each feature sequence.
        bounds = [(11, 13), (13, 16), (16, 21), (21, 24), (22, 25), (25, 30)]
        self.xs = [list(range(start, stop)) for start, stop in bounds]
        self.ys = [0, 0, 0, 1, 1, 1]
        assert len(self.xs) == len(self.ys)

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        features, label = self.xs[idx], self.ys[idx]
        return {"x": features, "y": label}

In [None]:
# Instantiate the toy dataset and print every item to inspect it.
dset = CustomDataset()

# CustomDataset defines no __iter__, so iteration presumably falls back to the
# sequence protocol: __getitem__ with 0, 1, 2, ... until IndexError is raised.
for item in dset:
    print(item)

{'x': [11, 12], 'y': 0}
{'x': [13, 14, 15], 'y': 0}
{'x': [16, 17, 18, 19, 20], 'y': 0}
{'x': [21, 22, 23], 'y': 1}
{'x': [22, 23, 24], 'y': 1}
{'x': [25, 26, 27, 28, 29], 'y': 1}


In [None]:
# Uses the default collate function; iterating it fails (see below) because the "x" lists in a batch differ in length.
dloader = DataLoader(dset, batch_size=2, shuffle=False)

In [None]:
# Expected to raise RuntimeError: the default collate cannot stack "x" lists of unequal length.
for batch in dloader:
    print(batch)

RuntimeError: each element in list of batch should be of equal size

## A first solution attempt

We can refactor our dataset so that it generates items whose `x` sequences all have the same length, `max_len`, a parameter that we define beforehand (shorter sequences are padded with zeros).

In [None]:
class CustomDatasetFixLen(torch.utils.data.Dataset):
    """Toy dataset whose ``x`` sequences all have exactly ``max_len`` elements.

    Sequences shorter than ``max_len`` are right-padded with zeros. Longer ones
    are truncated to their first ``max_len`` elements. Every item therefore has
    the same length and can be batched by PyTorch's default collate function.

    Args:
        max_len: target length of every ``x`` sequence (must be >= 0).

    Raises:
        ValueError: if ``max_len`` is negative.
    """

    def __init__(self, max_len=10):
        if max_len < 0:
            raise ValueError(f"max_len must be non-negative, got {max_len}")
        self.max_len = max_len
        self.xs = [
            list(range(11, 13)),
            list(range(13, 16)),
            list(range(16, 21)),
            list(range(21, 24)),
            list(range(22, 25)),
            list(range(25, 30)),
        ]
        self.ys = [0, 0, 0, 1, 1, 1]
        assert len(self.xs) == len(self.ys)

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        # Truncate first: for a sequence longer than max_len, pad_len would be
        # negative and ``[0] * negative`` is just ``[]``, so padding alone would
        # leave the sequence too long and break the fixed-length guarantee.
        x = self.xs[idx][: self.max_len]
        pad_len = self.max_len - len(x)
        x = x + [0] * pad_len
        return {
            "x": np.array(x),
            "y": self.ys[idx],
        }

In [None]:
# Every "x" in this dataset comes out as an array of exactly 10 elements.
dset = CustomDatasetFixLen(max_len=10)

In [None]:
# Inspect the zero-padded, fixed-length items.
for item in dset:
    print(item)

{'x': array([11, 12,  0,  0,  0,  0,  0,  0,  0,  0]), 'y': 0}
{'x': array([13, 14, 15,  0,  0,  0,  0,  0,  0,  0]), 'y': 0}
{'x': array([16, 17, 18, 19, 20,  0,  0,  0,  0,  0]), 'y': 0}
{'x': array([21, 22, 23,  0,  0,  0,  0,  0,  0,  0]), 'y': 1}
{'x': array([22, 23, 24,  0,  0,  0,  0,  0,  0,  0]), 'y': 1}
{'x': array([25, 26, 27, 28, 29,  0,  0,  0,  0,  0]), 'y': 1}


That works, but it is wasteful: we always pad to `max_len` = 10, even when padding to length 3 would suffice (for example, when the batch is formed by the first two items).
That wasted memory could limit the batch size we can use, slowing down training, and it causes unnecessary computation during the forward pass if we feed our batches to the model without masking.
So, ideally, we would like to pad only as much as we need _on each batch_.
In other words, we want to dynamically (per batch basis) adapt the padding.

## There must be a better way

Let's implement our own collate function, i.e. the logic that puts items together into a batch. It will allow us to do the padding on a per-batch basis (hence the name `dynamic_length_collate`).

In [None]:
def dynamic_length_collate(batch, padding_value=0.0):
    """Collate dataset items into a batch, padding ``x`` only as much as needed.

    Each ``x`` is right-padded with ``padding_value`` up to the length of the
    longest ``x`` in *this* batch (not a global maximum), so the amount of
    padding adapts to every batch.

    Args:
        batch: list of items, each a dict with an ``"x"`` 1-D sequence of
            numbers (list, tuple, NumPy array or 1-D tensor) and a scalar
            ``"y"`` label.
        padding_value: value used to fill padded positions (default 0).

    Returns:
        dict with ``"x"``, a float32 tensor of shape
        ``(len(batch), longest_len)``, and ``"y"``, a tensor of the labels.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("cannot collate an empty batch")
    # ``+`` on NumPy arrays is element-wise addition, not concatenation, so
    # convert every sequence to a float tensor and let pad_sequence do the
    # padding; this works uniformly for lists, tuples, arrays and tensors.
    sequences = [torch.as_tensor(item["x"], dtype=torch.float) for item in batch]
    return {
        "x": pad_sequence(sequences, batch_first=True, padding_value=padding_value),
        "y": tensor([item["y"] for item in batch]),
    }

In [None]:
dset = CustomDataset()  # Use our original dataset, without a fixed max_len
# Padding now happens per batch inside our custom collate_fn.
dloader = DataLoader(dset, batch_size=2, shuffle=False,
                     collate_fn=dynamic_length_collate)

In [None]:
# Each batch is padded only up to the length of its own longest sequence.
for batch in dloader:
    print(batch)

{'x': tensor([[11., 12.,  0.],
        [13., 14., 15.]]), 'y': tensor([0, 0])}
{'x': tensor([[16., 17., 18., 19., 20.],
        [21., 22., 23.,  0.,  0.]]), 'y': tensor([0, 1])}
{'x': tensor([[22., 23., 24.,  0.,  0.],
        [25., 26., 27., 28., 29.]]), 'y': tensor([1, 1])}


That works!

For the sake of completeness, let's use our dataloader with the custom collate function and actually feed the data into a (toy) neural network.

In [None]:
# A very toy example of a neural network: a single LSTM layer that reads one
# feature per time step (input_size=1) with a 2-dimensional hidden state.
model = torch.nn.LSTM(input_size=1, hidden_size=2, batch_first=True)

for batch in dloader:
    # batch_first=True: the LSTM wants (batch, seq_len, features), so add a
    # trailing feature axis of size 1 to the (batch, seq_len) padded tensor.
    inputs = batch["x"].unsqueeze(-1)
    # ``pred`` is the tuple (output, (h_n, c_n)) returned by nn.LSTM.
    # NOTE: padded positions are fed to the LSTM as ordinary time steps here
    # (no packing or masking), so they affect the final hidden state.
    pred = model(inputs)
    print(pred)
    # Only the first batch is needed to demonstrate the pipeline.
    break

(tensor([[[ 9.4607e-04,  4.0929e-03],
         [ 5.0468e-04,  5.7644e-03],
         [-1.5826e-01,  1.9474e-02]],

        [[ 2.6432e-04,  2.6775e-03],
         [ 1.3929e-04,  3.5764e-03],
         [ 7.2860e-05,  3.3866e-03]]], grad_fn=<TransposeBackward0>), (tensor([[[-1.5826e-01,  1.9474e-02],
         [ 7.2860e-05,  3.3866e-03]]], grad_fn=<StackBackward0>), tensor([[[-0.2420,  0.0843],
         [ 0.0071,  1.2391]]], grad_fn=<StackBackward0>)))


## Fin

----
Any bugs, questions, comments, suggestions? Ping me on [twitter](https://www.twitter.com/fabridamicelli) or drop me an e-mail (fabridamicelli at gmail).