In [1]:
from typing import List, Tuple
import torch
from torch.utils.data import Dataset

In [2]:
class OfficeSeq2Seq(Dataset):
    """Seq2seq dataset that pads per item and applies the teacher-forcing shift.

    Indexing yields a ``(src, dec_in, labels)`` triple of ``torch.long`` tensors:

    * ``src``    -- source ids padded/truncated to ``max_src``; shape ``(max_src,)``
    * ``dec_in`` -- padded target minus its last position; shape ``(max_tgt - 1,)``
    * ``labels`` -- padded target minus its first position; shape ``(max_tgt - 1,)``

    Padding positions in all three tensors hold ``pad_id``.
    NOTE(review): the training loss presumably has to ignore ``pad_id`` in
    ``labels`` -- confirm against the training loop.

    Args:
        enc_src: Token-id sequences for the encoder, one list per sample.
        enc_tgt: Token-id sequences for the decoder, aligned with ``enc_src``.
        max_src: Fixed length every source sequence is padded/truncated to.
        max_tgt: Fixed length every target sequence is padded/truncated to
            *before* the shift; must be at least 2 so the shifted pair is
            non-empty.
        pad_id: Token id used to fill padding positions.

    Raises:
        ValueError: If ``enc_src`` and ``enc_tgt`` differ in length, if
            ``max_src`` is negative, or if ``max_tgt`` is smaller than 2.
    """

    def __init__(self,
                 enc_src: List[List[int]],
                 enc_tgt: List[List[int]],
                 max_src: int,
                 max_tgt: int,
                 pad_id: int) -> None:
        super().__init__()
        # Explicit raise instead of `assert`: asserts are removed under
        # `python -O`, which would defer the failure to an IndexError
        # somewhere inside __getitem__.
        if len(enc_src) != len(enc_tgt):
            raise ValueError(
                f"src/tgt length mismatch: {len(enc_src)} != {len(enc_tgt)}"
            )
        # A negative length would make `ids[:L]` silently drop tokens from the
        # end of every sequence instead of failing.
        if max_src < 0:
            raise ValueError(f"max_src must be non-negative, got {max_src}")
        # The shift removes one position, so anything below 2 leaves nothing.
        if max_tgt < 2:
            raise ValueError(f"max_tgt must be at least 2, got {max_tgt}")
        self.src = enc_src
        self.tgt = enc_tgt
        self.max_src = max_src
        self.max_tgt = max_tgt
        self.pad_id = pad_id

    @staticmethod
    def _pad_to(ids: List[int], L: int, pad_id: int) -> List[int]:
        """Return a new list of exactly ``L`` ids (for non-negative ``L``).

        Sequences with ``L`` or more ids keep only their first ``L``; shorter
        ones are right-padded with ``pad_id``. The input list is not modified.
        NOTE(review): truncation cuts from the end, so a trailing
        end-of-sequence token (if the data has one) is lost for over-long
        sequences.
        """
        clipped = ids[:L]
        # The multiplier is <= 0 when nothing needs padding, giving [].
        return clipped + [pad_id] * (L - len(clipped))

    def __len__(self) -> int:
        """Return the number of (source, target) pairs."""
        return len(self.src)

    def __getitem__(self, i: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(src, dec_in, labels)`` for sample ``i`` as long tensors."""
        src_ids = self._pad_to(self.src[i], self.max_src, self.pad_id)   # (S,)
        tgt_ids = self._pad_to(self.tgt[i], self.max_tgt, self.pad_id)   # (T,)

        # Teacher-forcing shift: dec_in[k] is the token the decoder is fed at
        # step k and labels[k] is the token that follows it in the target.
        dec_in = tgt_ids[:-1]    # (T-1,)
        labels = tgt_ids[1:]     # (T-1,)

        # torch.tensor copies the list into a fresh tensor, which is already
        # contiguous, so no extra .contiguous() call is needed.
        return (
            torch.tensor(src_ids, dtype=torch.long),
            torch.tensor(dec_in, dtype=torch.long),
            torch.tensor(labels, dtype=torch.long),
        )

`__len__`: tells PyTorch how many samples are in the dataset

`__getitem__`: when DataLoader asks for the i-th item in our dataset:
- take the i-th source (self.src[i]) and pad it to max_src
- take the i-th target (self.tgt[i]) and pad it to max_tgt

**Teacher Forcing & Sequence Shifting**:

In seq2seq models (like Transformers), we train the decoder to predict the next token
given all previous true tokens, not its own predictions.
This is called Teacher Forcing.

We prepare the target sequence `t` like this:

| Token role | Example | Explanation |
|-------------|----------|-------------|
| `t` | `[BOS, H, e, l, l, o, EOS]` | the full target sequence |
| Decoder input `dec_in` | `[BOS, H, e, l, l, o]` | shifted right (starts with BOS) |
| Labels `labels` | `[H, e, l, l, o, EOS]` | shifted left (the "next" tokens) |

During training, at each decoding step *k*, the model sees the real target tokens up to that step (`dec_in[0..k]`) and learns to predict the token that follows them (`labels[k]`).

At inference time, we feed back the model's *own* predictions instead
(one token at a time).

**`_pad_to`** (padding function):
- if the list `ids` (token sequence) has $L$ or more tokens, it is truncated to the first $L$: `ids[:L]`
- if shorter, append enough `pad_id` tokens to reach length $L$