# Training a Mini-GPT to Learn Two-Digit Addition

[![Twitter Handle](https://img.shields.io/badge/Twitter-@gaohongnan-blue?style=social&logo=twitter)](https://twitter.com/gaohongnan)
[![LinkedIn Profile](https://img.shields.io/badge/@gaohongnan-blue?style=social&logo=linkedin)](https://linkedin.com/in/gao-hongnan)
[![GitHub Profile](https://img.shields.io/badge/GitHub-gao--hongnan-lightgrey?style=social&logo=github)](https://github.com/gao-hongnan)
[![Code](https://img.shields.io/badge/View-Code-blue?style=flat-square&logo=github)](https://github.com/gao-hongnan/omniverse/tree/5221d5d8b9bd845568b2e323d908be282c6e8434/omnivault/transformer/projects/adder)
![Tag](https://img.shields.io/badge/Tag-Structured_Musings-purple)

```{contents}
:local:
```

## Motivation

Generative Pre-trained Transformers (GPTs) are well known to perform poorly on
arithmetic tasks such as addition. This should not come as a surprise, since GPT
is a _language_ model and not a _math_ model. It is trained on a large corpus of
text to learn the patterns and structure of natural language. While the corpus
does contain many arithmetic operations, they are encoded as text, not as
mathematical objects. After all, what GPT does best is predict the next token
from a probability distribution over the entire **vocabulary**.

In one of the examples in the
[minGPT](https://github.com/karpathy/minGPT/tree/master) repository, Karpathy
demonstrates training a GPT model to add two numbers presented as strings. This
simple task illustrates how a decoder-only model can be trained to learn
"addition": the input is a sequence of characters representing an addition
operation (like "12 + 35") and the output is the sequence of characters
representing the result (like "47").

To this end, we replicate his example as a proof of concept that decoder-only
models, which are usually used for language tasks, can also learn other patterns
or "languages," such as the "language" of arithmetic.

```python
from __future__ import annotations

import os
import time
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import rich
import torch
from rich.pretty import pprint
from torch import nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader, Dataset, Subset, random_split
from omegaconf import OmegaConf as om
```


```python
import sys
from pathlib import Path


def find_root_dir(current_path: Path | None = None, marker: str = '.git') -> Path | None:
    """
    Find the root directory by searching for a directory or file that serves as a
    marker.

    Parameters
    ----------
    current_path : Path | None
        The starting path to search from. If None, the current working directory
        `Path.cwd()` is used.
    marker : str
        The name of the file or directory that signifies the root.

    Returns
    -------
    Path | None
        The path to the root directory. Returns None if the marker is not found.
    """
    if not current_path:
        current_path = Path.cwd()
    current_path = current_path.resolve()
    for parent in [current_path, *current_path.parents]:
        if (parent / marker).exists():
            return parent
    return None

current_file_path = Path(os.getcwd())
root_dir          = find_root_dir(current_file_path, marker='omnivault')

if root_dir is not None:
    sys.path.append(str(root_dir))
    from omnivault._types._alias import Accuracy, Loss
    from omnivault.transformer.config.composer import Composer, DataConfig
    from omnivault.transformer.config.constants import MaybeConstant
    from omnivault.transformer.config.decoder import (
        AddNormConfig,
        DecoderBlockConfig,
        DecoderConfig,
        MultiHeadedAttentionConfig,
        PositionwiseFeedForwardConfig,
    )
    from omnivault.transformer.config.global_ import MaybeGlobal
    from omnivault.transformer.config.trainer import TrainerConfig
    from omnivault.transformer.config.optim import AdamConfig, OptimizerConfig
    from omnivault.transformer.core.dataset import AdderDataset, create_loader, split_dataset, construct_dummy_batch_future_masks, construct_dummy_batch_target_padding_masks
    from omnivault.transformer.core.trainer import Trainer
    from omnivault.transformer.core.vocabulary import AdderVocabulary
    from omnivault.transformer.decoder.core import GPTDecoder
    from omnivault.transformer.modules.attention.core import ScaledDotProductAttention
    from omnivault.transformer.utils.reproducibility import seed_all
    from omnivault.transformer.core.tokenizer import AdderTokenizer
    from omnivault.transformer.utils.general_utils import create_directory, download_file, validate_and_cleanup
    from omnivault.transformer.utils.config_utils import load_yaml_config, merge_configs
    from omnivault.core.logger import RichLogger
    from omnivault.utils.inspector.core import get_field_annotations
    import inspect
else:
    raise ImportError("Root directory not found.")
```

## Config

```python
yaml_cfg = load_yaml_config(yaml_path=root_dir / "omnivault/transformer/projects/adder/config.yaml")
cfg = merge_configs(yaml_cfg, args_list=[])
om.resolve(cfg)  # inplace ops
```

```python
constants: MaybeConstant = MaybeConstant(NUM_DIGITS=2, TOKENS=[
            "0",
            "1",
            "2",
            "3",
            "4",
            "5",
            "6",
            "7",
            "8",
            "9",
            "+",
            "*",
            "-",
            "=",
            "<BOS>",
            "<EOS>",
            "<PAD>",
            "<UNK>",
        ]
)
global_config: MaybeGlobal = MaybeGlobal(seed=42, debug=True, debug_samples=100)
data_config: DataConfig = DataConfig(**cfg.data)
optimizer_config = AdamConfig(name="torch.optim.Adam", lr=0.2, betas=(0.9, 0.98), eps=1e-9)
trainer_config = TrainerConfig(device="cpu")

composer = Composer(constants=constants, global_=global_config, data=data_config, optimizer=optimizer_config, trainer=trainer_config)
pprint(composer)

LOGGER = RichLogger(**composer.logger.model_dump(mode="python")).logger
```


## Reproducibility

Reproducibility in deep learning means that experiments can be repeated with
identical results, which is critical for verifying research findings and
deploying reliable models. Distributed training adds complexity because it
involves multiple computation units, which may not keep their random states in
sync. If training is paused and resumed, making sure each unit restarts with the
correct seed, so that it reproduces the exact computational path, becomes
challenging. For more sophisticated treatments, see libraries like MosaicML's
Composer, whose core is built around training deep neural nets in any
environment (distributed or not) with reproducibility in mind.

```{admonition} References
:class: seealso

-   [Composer](https://github.com/mosaicml/composer/blob/dev/composer/utils/reproducibility.py)
-   [PyTorch Reproducibility](https://pytorch.org/docs/stable/notes/randomness.html)
-   [PyTorch Worker](https://pytorch.org/docs/stable/notes/randomness.html#dataloader)
-   [PyTorch deterministic algorithms](https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html)
-   [CUBLAS reproducibility](https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility)
```
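As a rough illustration of what a seeding helper usually covers, here is a
minimal sketch for a single process. `seed_everything` is a hypothetical name;
it is not the `seed_all` function imported from `omnivault`, whose internals are
not shown here, and a real helper would typically also seed NumPy if it is used.

```python
import random

import torch


def seed_everything(seed: int, deterministic: bool = False) -> int:
    """Hypothetical helper: seed Python's and PyTorch's RNGs in one process."""
    random.seed(seed)
    # Seeds the default generator on the CPU and on every CUDA device.
    torch.manual_seed(seed)
    if deterministic:
        # Raise an error when an op has no deterministic implementation.
        torch.use_deterministic_algorithms(True)
        # Stop cuDNN from benchmarking (and possibly switching) kernels.
        torch.backends.cudnn.benchmark = False
    return seed


seed_everything(42)
first = (random.random(), torch.rand(3))
seed_everything(42)
second = (random.random(), torch.rand(3))
assert first[0] == second[0] and torch.equal(first[1], second[1])
```

On GPU, deterministic cuBLAS behaviour additionally requires setting the
`CUBLAS_WORKSPACE_CONFIG` environment variable, as described in the CUBLAS
reference above.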

```python
print(get_field_annotations(func_or_method = seed_all)[0])
print("\n")
print(inspect.getdoc(seed_all))

seed_all(composer.global_.seed, seed_torch=True, set_torch_deterministic=False)
```

```text
[('seed', <class 'int'>, 1992), ('seed_torch', <class 'bool'>, True), ('set_torch_deterministic', <class 'bool'>, True)]


Seeds all relevant random number generators to ensure reproducible
outcomes. Optionally seeds PyTorch and activates deterministic
behavior in PyTorch based on the flags provided.

Parameters
----------
seed : int, default 1992
    The seed number for reproducibility.
seed_torch : bool, default True
    If True, seeds PyTorch's RNGs.
set_torch_deterministic : bool, default True
    If True, activates deterministic mode in PyTorch.

Returns
-------
seed : int
    The seed used for reproducibility.


42
```

## Vocabulary

```python
vocabulary = AdderVocabulary.from_tokens(tokens=constants.TOKENS, num_digits=constants.NUM_DIGITS)  # type: ignore[attr-defined]
token_to_index = vocabulary.token_to_index
index_to_token = vocabulary.index_to_token
vocab_size = vocabulary.vocab_size
pprint(token_to_index)
pprint(index_to_token)
pprint(vocab_size)
```
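The encoded sequences printed in this notebook are consistent with a vocabulary
that assigns each token its position in `TOKENS`. A minimal sketch of that
assumption; `build_vocab` is a hypothetical helper, not
`AdderVocabulary.from_tokens`.

```python
from typing import Dict, List, Tuple

ADDER_TOKENS: List[str] = [
    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
    "+", "*", "-", "=", "<BOS>", "<EOS>", "<PAD>", "<UNK>",
]


def build_vocab(tokens: List[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    """Assumption: a token's index is simply its position in the list."""
    t2i = {token: idx for idx, token in enumerate(tokens)}
    i2t = {idx: token for token, idx in t2i.items()}
    return t2i, i2t


t2i, i2t = build_vocab(ADDER_TOKENS)
print(t2i["<BOS>"], t2i["<EOS>"], t2i["<PAD>"], len(t2i))  # 14 15 16 18
```

Under this assumption, `<BOS>`, `<EOS>` and `<PAD>` map to 14, 15 and 16, and
the vocabulary has 18 entries.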

We assign `vocab_size` to `composer.model` because we do not want to hardcode
`vocab_size` beforehand; instead, we derive concrete values from the
`Vocabulary` object.

```python
try:
    composer.model.vocab_size = vocab_size
except AttributeError as err:
    LOGGER.error(err)
```

Ah okay haha, this is the price of writing overly complex code just to look
fancy: you end up with a mess. Anyway, we will handle this later, when we
explicitly instantiate the model config class.

## Tokenization

```python
tokenizer = AdderTokenizer(vocabulary=vocabulary)
assert tokenizer.vocabulary.token_to_index == token_to_index
assert tokenizer.vocabulary.index_to_token == index_to_token
```

```python
pprint(tokenizer.encode("1"))
```

```python
sequence = "15+57=072"
sequences = ["15+57=072", "01+02=003"]
```

```python
encoded_sentence = tokenizer.encode(sequence)
print(f"Encoded sentence: {encoded_sentence}")

decoded_sentence = tokenizer.decode(encoded_sentence)
print(f"Decoded sentence: {decoded_sentence}")
```

```text
Encoded sentence: [14, 1, 5, 10, 5, 7, 13, 0, 7, 2, 15]
Decoded sentence: 15+57=072
```
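The printed ids are consistent with a simple character-level scheme: each
character maps to its vocabulary index and the sequence is wrapped in `<BOS>`
and `<EOS>`. A minimal sketch under that assumption; `char_encode` and
`char_decode` are hypothetical helpers, not the internals of `AdderTokenizer`.

```python
from typing import Dict, List

ADDER_TOKENS: List[str] = [
    "0", "1", "2", "3", "4", "5", "6", "7", "8", "9",
    "+", "*", "-", "=", "<BOS>", "<EOS>", "<PAD>", "<UNK>",
]
T2I: Dict[str, int] = {token: idx for idx, token in enumerate(ADDER_TOKENS)}
I2T: Dict[int, str] = {idx: token for token, idx in T2I.items()}
SPECIAL_TOKENS = {"<BOS>", "<EOS>", "<PAD>"}


def char_encode(text: str) -> List[int]:
    """Map each character to its index and wrap the result in <BOS>/<EOS>."""
    body = [T2I.get(ch, T2I["<UNK>"]) for ch in text]
    return [T2I["<BOS>"]] + body + [T2I["<EOS>"]]


def char_decode(ids: List[int]) -> str:
    """Drop <BOS>/<EOS>/<PAD> and join the remaining tokens."""
    return "".join(I2T[i] for i in ids if I2T[i] not in SPECIAL_TOKENS)


ids = char_encode("15+57=072")
print(ids)               # [14, 1, 5, 10, 5, 7, 13, 0, 7, 2, 15]
print(char_decode(ids))  # 15+57=072
```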


```python
encoded_sentences = tokenizer.encode_batch(sequences)  # type: ignore[attr-defined]
print(f"Encoded sentences: {encoded_sentences}")
decoded_sentences = tokenizer.decode_batch(encoded_sentences)  # type: ignore[attr-defined]
print(f"Decoded sentences: {decoded_sentences}")
```

```text
Encoded sentences: [[14, 1, 5, 10, 5, 7, 13, 0, 7, 2, 15], [14, 0, 1, 10, 0, 2, 13, 0, 0, 3, 15]]
Decoded sentences: ['15+57=072', '01+02=003']
```


```python
# PAD = vocabulary.token_to_index[vocabulary.PAD]
# UNK = vocabulary.token_to_index[vocabulary.UNK]
# ADD = vocabulary.token_to_index[vocabulary.ADD]
# EQUAL = vocabulary.token_to_index[vocabulary.EQUAL]
# BOS = vocabulary.token_to_index[vocabulary.BOS]
# EOS = vocabulary.token_to_index[vocabulary.EOS]
```

## Create Dataset

```python
def pad_number(num: int, length: int) -> str:
    """
    Pad numbers with zeros in front so that they have uniform length.

    Note, if a + b = c and num digits allowed to add is 2, then for
    a and b we always pad to length 2, but for c we always pad to length 3.

    Example
    -------
    6 + 90 = 96 -> 06 + 90 = 096

    Parameters
    ----------
    num : int
        Number to be padded.
    length : int
        Length of the resulting padded number string.

    Returns
    -------
    str
        Padded number string.
    """
    return str(num).zfill(length)


def equation_to_string(a: int, b: int, c: int, num_digits: int) -> str:
    """
    Formats the addition equation as a string.

    Parameters
    ----------
    a : int
        First addend.
    b : int
        Second addend.
    c : int
        Sum of a and b.
    num_digits : int
        Number of digits each number in the equation should have.

    Returns
    -------
    str
        Formatted equation string.
    """
    padded_a = pad_number(a, num_digits)
    padded_b = pad_number(b, num_digits)
    padded_c = pad_number(c, num_digits + 1)  # note the padding here!
    return f"{padded_a}+{padded_b}={padded_c}"

def decode_equation(vocab: AdderVocabulary, equation: torch.Tensor | List[int]) -> str:
    """
    Convert an equation in list format to string format.

    Parameters
    ----------
    vocab : AdderVocabulary
        The vocabulary used to map indices back to tokens.
    equation : torch.Tensor | List[int]
        The equation in list format.

    Returns
    -------
    str
        The equation in string format, without <BOS> and <EOS>.
    """
    if isinstance(equation, torch.Tensor):
        equation = equation.tolist()

    # Unknown indices fall back to the <UNK> token string (not its index).
    decoded_equation = "".join([str(vocab.index_to_token.get(x, vocab.UNK)) for x in equation])
    return decoded_equation.replace("<BOS>", "").replace("<EOS>", "")

def batch_decode_equation(vocab: AdderVocabulary, equations: torch.Tensor | List[List[int]]) -> List[str]:
    decoded_equations = []
    for equation in equations:
        decoded_equation = decode_equation(vocab, equation)
        decoded_equations.append(decoded_equation)
    return decoded_equations

def encode_equation(vocab: AdderVocabulary, equation: str, num_digits: int, device: torch.device) -> torch.Tensor:
    """
    Convert an equation (up to the equal sign in it) in string format to a list.

    Parameters
    ----------
    vocab : AdderVocabulary
        The vocabulary used to map tokens to indices.
    equation : str
        The equation in string format.
    num_digits : int
        Number of digits each number in the equation should have.
    device : torch.device
        The device to which the tensor should be sent.

    Returns
    -------
    torch.Tensor
        The equation in list format as a tensor.
    """
    plus_idx = equation.index(vocab.ADD)
    equal_idx = equation.index(vocab.EQUAL)

    BOS = vocab.token_to_index[vocab.BOS]
    UNK = vocab.token_to_index[vocab.UNK]

    a = pad_number(int(equation[:plus_idx]), num_digits)
    b = pad_number(int(equation[plus_idx + 1:equal_idx]), num_digits)

    new_equation = f"{a}+{b}="

    return torch.tensor(
        [BOS] + [vocab.token_to_index.get(n, UNK) for n in new_equation],
        dtype=torch.int,
    ).to(device)
```

```python
def create_add_dataset(
    vocab: AdderVocabulary, num_digits: int, dataset_size: int, rng_seed: int = 1337
) -> Tuple[List[torch.Tensor], List[str]]:
    BOS = vocab.token_to_index[vocab.BOS]
    EOS = vocab.token_to_index[vocab.EOS]
    UNK = vocab.token_to_index[vocab.UNK]

    rng = torch.Generator()
    rng.manual_seed(rng_seed)

    max_num = 10**num_digits - 1

    dataset_str = []
    for _ in range(dataset_size):
        a = torch.randint(low=0, high=max_num + 1, size=(1,), generator=rng).item()
        b = torch.randint(low=0, high=max_num + 1, size=(1,), generator=rng).item()
        c = a + b

        equation = equation_to_string(a, b, c, num_digits)

        dataset_str.append(equation)

    # Use the vocabulary passed in rather than the notebook's global mapping.
    dataset_tensor = [
        torch.tensor([BOS] + [vocab.token_to_index.get(n, UNK) for n in x] + [EOS])
        for x in dataset_str
    ]
    return dataset_tensor, dataset_str
```

```python
dataset_tensor, dataset_str = create_add_dataset(vocab=vocabulary, num_digits=2, dataset_size=4)
pprint(dataset_tensor)
pprint(dataset_str)
```

```python
print(f"Decoded equation: {decode_equation(vocabulary, dataset_tensor[0])}")
assert (
    decode_equation(vocabulary, dataset_tensor[0])
    == dataset_str[0]
    == decode_equation(vocabulary, [14, 1, 5, 10, 5, 7, 13, 0, 7, 2, 15])
)
```

```text
Decoded equation: 15+57=072
```


If we encode an equation with `encode_equation`, only the part up to and
including the equal sign is encoded (prefixed with `<BOS>`), as below.

```python
print(f"Encoded equation: {encode_equation(vocabulary, dataset_str[0], num_digits=2, device=composer.trainer.device)}")

torch.testing.assert_close(
    encode_equation(vocabulary, dataset_str[0], num_digits=2, device=composer.trainer.device),
    torch.tensor([14, 1, 5, 10, 5, 7, 13], dtype=torch.int32),
)
```

```text
Encoded equation: tensor([14,  1,  5, 10,  5,  7, 13], dtype=torch.int32)
```


Uncomment the code below to write the generated dataset to a text file (and yes,
I was too lazy to add a config variable for whether to generate the dataset).

```python
# dataset, dataset_str = create_add_dataset(vocab, self.num_digits, self.dataset_size)

# write dataset_str to a file
# with open("dataset_str.txt", "w") as f:
#     for item in dataset_str:
#         f.write("%s\n" % item)
```

### Encoding Strategy Overview

Our strategy for encoding arithmetic expressions is fairly self-explanatory:
given a string `D1 + D2 = D3`, we encode it as `<BOS>D1+D2=D3<EOS>`, where `D1`
and `D2` are zero-padded to `num_digits` digits and `D3` is zero-padded to
`num_digits + 1` digits (e.g. `15+57=072`). This is deliberately verbose for the
sake of clarity. Karpathy's encoding strategy, by contrast, concatenates the
digits of the operands and the result into a single string, without explicit
symbols for the operation or equality. It relies on a fixed number of digits
(`num_digits`) for the operands, which makes the layout of every sequence
predictable. For example, if `num_digits` is set to 2, every encoded expression
follows the same pattern: the first two digits represent the first operand, the
next two digits represent the second operand, and the final three digits
represent the result, because the largest sum of two 2-digit numbers,
$99 + 99 = 198$, has 3 digits. In Karpathy's scheme, the digits of the result
are also written in reverse order (our own encoding above keeps them in normal
order). This counterintuitive choice mirrors the right-to-left way we add by
hand: if the sum were written most significant digit first, the model would have
to resolve the entire carry chain before emitting its first output digit,
whereas in reverse order it emits the units digit first and can handle each
carry one step at a time.

To illustrate, let's examine the encoding of arithmetic expressions with
`num_digits=2`:

For the expression `6 + 39 = 45`, Karpathy's encoding is `0639540`:

-   The first two digits `06` represent the number 6, zero-padded to adhere to
    the `num_digits=2` requirement.
-   The next two digits `39` represent the number 39, which already has two
    digits.
-   The final part `540` represents the result 45: it is zero-padded to three
    digits, `045`, and then reversed to `540`, giving a total length of
    $2n + (n + 1) = 7$ digits for $n = 2$ (`num_digits=2`).
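The layout described above can be sketched in a few lines of plain Python. The
helpers `karpathy_encode` and `karpathy_decode_sum` are hypothetical names that
follow the description in this section; they are not minGPT's actual code.

```python
def karpathy_encode(a: int, b: int, num_digits: int) -> str:
    """Zero-padded operands followed by the zero-padded, reversed sum."""
    a_str = str(a).zfill(num_digits)
    b_str = str(b).zfill(num_digits)
    # A sum of two num_digits-digit numbers has at most num_digits + 1 digits.
    c_str = str(a + b).zfill(num_digits + 1)[::-1]
    return a_str + b_str + c_str


def karpathy_decode_sum(encoded: str, num_digits: int) -> int:
    """Undo the reversal on the last num_digits + 1 characters."""
    return int(encoded[2 * num_digits:][::-1])


print(karpathy_encode(6, 39, num_digits=2))          # 0639540
print(karpathy_encode(99, 99, num_digits=2))         # 9999891
print(karpathy_decode_sum("0639540", num_digits=2))  # 45
```

Because every encoded string has the same length, the answer always occupies the
last `num_digits + 1` positions, so no separator tokens are needed.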


## Dataset

```python
create_directory(composer.data.dataset_dir)
download_file(url=composer.data.dataset_url, output_path=composer.data.dataset_path)
```



```python
with open(composer.data.dataset_path, "r") as file:
    sequences = [line.strip() for line in file]

dataset = AdderDataset(data=sequences, tokenizer=tokenizer)

pprint(next(iter(dataset)))
```

## Construct Batches, Collate Function and DataLoader

We first reverse engineer what our dataset returns. A disclaimer: for
decoder-only models like GPT, many implementations omit the padding mask, since
every sample $\mathbf{x}$ is chunked to the same sequence (context) length $T$,
and the future mask is usually handled inside the `Attention` class, since we
never attend to future tokens. For the sake of clarity, however, we include both
the padding mask and the future mask in the dataset (in truth, this was for the
sake of my own understanding when I started implementing the decoder from
scratch).

```python
input, target, target_padding_mask, future_mask = next(iter(dataset))
```

### Input and Target

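If you've read my
[section here](https://www.gaohongnan.com/transformer/decoder/implementation.html#construction-of-input-and-target-sequences),
you will see that, given an input sequence $\mathbf{x}$, the target sequence
$\mathbf{y}$ is simply $\mathbf{x}$ shifted one time step to the left. There is
one twist visible in the output below: the first six target positions, which
would ask the model to predict the prompt tokens `15+57=`, are set to `16`, the
index of `<PAD>`, so only the answer digits `072` and `<EOS>` remain as real
targets.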

```python
print(f"Input : {input}")
print(f"Target: {target}")
```

```text
Input : tensor([14,  1,  5, 10,  5,  7, 13,  0,  7,  2])
Target: tensor([16, 16, 16, 16, 16, 16,  0,  7,  2, 15])
```
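The printed pair can be reproduced from the encoded sequence with a shift plus a
prompt mask. This is a minimal sketch assuming that "=" has index 13 and
`<PAD>` has index 16, as in the outputs above; `make_input_target` is a
hypothetical helper, and whether `AdderDataset` builds its targets exactly this
way is an assumption.

```python
from typing import List, Tuple

EQUAL_ID, PAD_ID = 13, 16  # indices of "=" and "<PAD>" in this vocabulary


def make_input_target(
    encoded: List[int], equal_id: int = EQUAL_ID, pad_id: int = PAD_ID
) -> Tuple[List[int], List[int]]:
    """Shift by one step and hide the targets that belong to the prompt."""
    x = encoded[:-1]  # everything except the final <EOS>
    y = encoded[1:]   # y[i] is the token that follows x[i]
    equal_pos = encoded.index(equal_id)
    # Targets 0..equal_pos-1 would ask the model to predict prompt tokens
    # (digits, "+" and "="), so they are overwritten with <PAD>.
    y = [pad_id] * equal_pos + y[equal_pos:]
    return x, y


x, y = make_input_target([14, 1, 5, 10, 5, 7, 13, 0, 7, 2, 15])
print(x)  # [14, 1, 5, 10, 5, 7, 13, 0, 7, 2]
print(y)  # [16, 16, 16, 16, 16, 16, 0, 7, 2, 15]
```

If the loss is computed with `nn.CrossEntropyLoss(ignore_index=16)` (an
assumption about the training code, which is not shown here), these padded
target positions contribute nothing to the gradient.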


### Target Padding Mask

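When you are dealing with sequences of different lengths, you pad the shorter
sequences with a special `PAD` token (index 0 is a common choice; in this
notebook `<PAD>` has index 16) to make them the same length as the longest one
in the batch. This is not to be confused with `-100`, the default `ignore_index`
of PyTorch's `nn.CrossEntropyLoss`, which marks target positions to skip in the
loss. The padded positions should not contribute to the model's learning, so
you need to mask them out. In practice, attention layers accept a mask argument:
the attention scores at masked positions are set to `-inf`, so that these
positions become zero after the softmax operation and therefore do not
contribute to the weighted sum over the input sequence. Note that the boolean
convention differs between implementations: in PyTorch's
`nn.MultiheadAttention`, `True` marks a position that is _not_ allowed to
attend, whereas in this notebook `True` means "attend to this token" and `False`
means "ignore this token."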

In a decoder-only model like GPT, the input sequence is essentially the target.
The model aims to generate tokens that come after the given input, treating it
as the "history" or "context" for the task of text generation. Unlike
encoder-decoder models like the original Transformer, where the encoder
processes a source sequence and the decoder generates a target sequence, a
decoder-only model works solely with what would traditionally be considered the
target sequence.

Consequently, the term "target padding mask" is more intuitive in
encoder-decoder models, where the distinction between source (input) and target
(output) sequences is clear. In decoder-only models like GPT, the distinction is
blurred, because the model processes its input to predict the next token in the
same sequence: the source is the target at a different stage of processing,
with the previous tokens (source) used to predict the next token (target). I
kept the name because, during my implementation, I was mainly referring to
transformer models with an encoder-decoder architecture.

A target padding mask is therefore a binary mask that marks the pad tokens to be
ignored in the source input (in a decoder-only model, the source is the target).
Its shape is $(\mathcal{B}, T)$.

Let's illustrate the target padding mask with an example. Suppose we have a
batch of sequences with different lengths:

```python
target_batch = [
    [5, 7, 9],
    [8, 6],
    [3, 12, 4, 11, 17],
    [2, 1, 4, 5],
]
pprint(target_batch)
```

If we try to stack these sequences into a single tensor, PyTorch raises a
`ValueError`, because all rows must have the same length.

```python
try:
    target_batch = torch.tensor(target_batch, dtype=torch.int64)
except ValueError as err:
    LOGGER.error(err)
```

To address this, we pad the shorter sequences with the special `PAD` token so
that they have the same length as the longest one in the batch, and create a
mask to indicate which positions are padded.

```python
PAD = vocabulary.token_to_index[vocabulary.PAD]

max_len = max(len(seq) for seq in target_batch)
target_batch = [seq + [PAD] * (max_len - len(seq)) for seq in target_batch]
pprint(target_batch)

target_batch = torch.tensor(target_batch, dtype=torch.int64)
pprint(target_batch)
```

```python
batch_size, seq_len = target_batch.size()

target_padding_mask = target_batch != PAD

pprint(target_padding_mask)

assert target_padding_mask.size() == (batch_size, seq_len) == (4, 5)
```
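PyTorch also ships `torch.nn.utils.rnn.pad_sequence`, which performs the padding
loop above in a single call. A minimal collate-style sketch; `pad_batch` is a
hypothetical helper, and deriving the mask with `!= pad_id` assumes that real
tokens never use the pad index.

```python
from typing import List, Tuple

import torch
from torch.nn.utils.rnn import pad_sequence


def pad_batch(sequences: List[List[int]], pad_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Right-pad variable-length sequences and return (padded, padding_mask)."""
    tensors = [torch.tensor(seq, dtype=torch.int64) for seq in sequences]
    # batch_first=True gives shape (B, T_max) instead of (T_max, B).
    padded = pad_sequence(tensors, batch_first=True, padding_value=pad_id)
    # True = real token, False = padding, matching the convention above.
    padding_mask = padded != pad_id
    return padded, padding_mask


padded, padding_mask = pad_batch([[5, 7, 9], [8, 6], [3, 12, 4, 11, 17], [2, 1, 4, 5]], pad_id=16)
print(padded.shape, padding_mask[0].tolist())  # torch.Size([4, 5]) [True, True, True, False, False]
```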

Of course, we need a _batch_ of these masks, hence the shape
$(\mathcal{B}, T)$ mentioned above. As we will see later, we still need to
broadcast it to $(\mathcal{B}, 1, T, T)$ to match the shape of the attention
scores.

Theoretically, the sequence length $T$ can vary across samples $\mathbf{x}$. In
GPT, however, all samples usually have the same length, and in this particular
case we know that every sample necessarily has the same length by _design_.
Still, for the sake of explanation, note that our `Dataset` generates one sample
at a time and knows nothing about the sequence lengths of the other samples in
the dataset $\mathcal{S}$. In deep learning we train in mini-batches
$\mathcal{B}$, and samples of different lengths cannot be stacked into a single
batch tensor (so the batched matrix multiplications would not work) unless they
are padded to a common length.

### Future Mask

In the decoder, each position can only attend to positions that come before it
in the sequence to maintain the auto-regressive property. This is different from
the encoder, where all positions can attend to all other positions.

The future mask is a look-ahead mask that ensures each position only attends to
itself and the positions before it in the sequence: we mask out future positions
(i.e., positions that come after the current position) so that they do not
contribute to the current attention scores. Before the softmax operation, we set
the scores at these positions to `-inf`, so that they become zero after the
softmax, effectively zeroing out the attention weights for future positions.
What does zeroing out these masked logits actually do? The attention mechanism
can be thought of as a weighted average of all the tokens in the input sequence.
Each token is assigned a weight, with higher weights indicating more relevance
to the token under consideration. If a certain token should not be considered at
all (e.g., it is a future token that should not be visible at the current
decoder step, or it is a padding token), its weight should be zero.
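To make the "weight should be zero" argument concrete, here is a minimal sketch
of a masked softmax using the notebook's convention (`True` = attend).
`masked_softmax` is a hypothetical helper; the raw scores are all zero so that
each query spreads its weight evenly over the positions it may see.

```python
import torch


def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Softmax over the last dim, giving zero weight wherever mask is False."""
    # Convention used in this notebook: True = attend, False = ignore.
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1)  # exp(-inf) = 0


seq_len_demo = 4
demo_scores = torch.zeros(seq_len_demo, seq_len_demo)
demo_mask = torch.tril(torch.ones(seq_len_demo, seq_len_demo, dtype=torch.bool))
# Row i puts weight 1 / (i + 1) on positions 0..i and exactly 0 on the future.
print(masked_softmax(demo_scores, demo_mask))
```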

The shape of the future mask is $(T, T)$ for a target sequence/sample
$\mathbf{x}$ of length $T$. Let's see a concrete example to illustrate the
future mask.

```python
seq_len = 5
future_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1)
future_mask = future_mask == 0

pprint(future_mask)
assert future_mask.size() == (seq_len, seq_len) == (5, 5)
```
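The same boolean mask can be built in one step with `torch.tril`, which keeps
the lower triangle including the diagonal. A minimal sketch (`causal_mask` is a
hypothetical helper name) that also checks it against the `triu`-then-compare
construction above:

```python
import torch


def causal_mask(seq_len: int) -> torch.Tensor:
    """Boolean (T, T) mask whose entry (i, j) is True iff j <= i."""
    return torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))


# Same mask as building the strict upper triangle and comparing with 0.
via_triu = torch.triu(torch.ones(5, 5), diagonal=1) == 0
assert torch.equal(causal_mask(5), via_triu)
```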

### Merge Padding and Future Masks

In our `decoder` implementation below, one of the methods creates the target
masks: it constructs the target padding masks and the future masks, and merges
them together.


```{code-block} python
---
linenos: true
emphasize-lines: 27
---
def create_target_masks(
    self,
    batch_size: int,
    seq_len: int,
    target_padding_masks: torch.BoolTensor | NotGiven = NOT_GIVEN,
    future_masks: torch.BoolTensor | NotGiven = NOT_GIVEN,
) -> torch.BoolTensor:
    target_masks_shape = (batch_size, 1, seq_len, seq_len)
    if target_padding_masks is NOT_GIVEN and future_masks is NOT_GIVEN:
        target_padding_masks = cast(
            torch.BoolTensor, construct_dummy_batch_target_padding_masks(batch_size, seq_len)
        )
        future_masks = cast(torch.BoolTensor, construct_dummy_batch_future_masks(batch_size, seq_len))

    if target_padding_masks is NOT_GIVEN:
        target_padding_masks = cast(
            torch.BoolTensor, construct_dummy_batch_target_padding_masks(batch_size, seq_len)
        )

    if future_masks is NOT_GIVEN:
        future_masks = cast(torch.BoolTensor, construct_dummy_batch_future_masks(batch_size, seq_len))

    assert target_padding_masks.shape == future_masks.shape == target_masks_shape  # type: ignore[union-attr]

    return cast(
        torch.BoolTensor,
        torch.logical_and(cast(torch.Tensor, target_padding_masks), cast(torch.Tensor, future_masks)).bool(),
    )
```

The purpose of applying `logical_and` between `target_padding_mask` and
`future_mask` is to combine the constraints from both masks when calculating
self-attention scores in the transformer's decoder. The `target_padding_mask` is
designed to mask out the padding tokens in the input sequence, while the
`future_mask` ensures that a given position cannot attend to future positions in
the sequence. By combining these masks, you can perform the necessary masking
for both padding and future tokens in a single step.

Here's how it works:

1. `target_padding_mask`: Masks out the padding tokens so that they don't
   contribute to the attention calculations. True values mean "attend to this
   token," and False values mean "ignore this token."

2. `future_mask`: The future mask is created as a lower triangular matrix, where
   the lower triangle, including the diagonal, is filled with ones, and the
   upper triangle is filled with zeros. Masks out future tokens in a sequence so
   that a token at a given position can only attend to positions that come
   before it (and itself). True values mean "attend to this token," and False
   values mean "ignore this token."

3. `logical_and(target_padding_mask, future_mask)`: Combines the two masks. A
   True in the resulting mask means that the condition for both padding and
   future attention is satisfied.

By combining these two masks, the decoder obeys the autoregressive property,
ensuring it doesn't see future tokens, while also ignoring padding tokens in the
input sequence. We may term it the `target_mask`.
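A compact sketch of the whole recipe for a padded batch, assuming `<PAD>` has
index 16 and a head dimension of size 1, as in the
`(batch_size, 1, seq_len, seq_len)` shape asserted in `create_target_masks`.
`build_target_mask` is a hypothetical helper; it broadcasts the padding mask
along the key axis, which is one common choice.

```python
import torch


def build_target_mask(batch: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Combine padding and causal masks into a (B, 1, T, T) boolean mask."""
    batch_size, seq_len = batch.shape
    # (B, T) -> (B, 1, 1, T): entry [b, 0, i, j] only depends on whether
    # key position j of sample b is a real token, whatever the query i.
    padding = (batch != pad_id).view(batch_size, 1, 1, seq_len)
    # (T, T) -> (1, 1, T, T): one causal mask shared by all samples and heads.
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool, device=batch.device))
    causal = causal.view(1, 1, seq_len, seq_len)
    # Broadcasting yields (B, 1, T, T); True = "query i may attend to key j".
    return torch.logical_and(padding, causal)


demo_batch = torch.tensor(
    [[5, 7, 9, 16, 16], [8, 6, 16, 16, 16], [3, 12, 4, 11, 17], [2, 1, 4, 5, 16]]
)
demo_target_mask = build_target_mask(demo_batch, pad_id=16)
print(demo_target_mask.shape)  # torch.Size([4, 1, 5, 5])
print(demo_target_mask[0, 0].int())
```

For a sample without padding (the third row), the result is just the causal
mask; for padded samples, the padded columns are `False` in every row.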

#### First Sample First Token

-   `target_padding_mask` has size `[4, 5]`.
    -   We zoom in on the first row (sample), which has length 5.
    -   This length 5 is the sequence length, and the row is `T, T, T, F, F`,
        indicating that the last 2 tokens are padding.
-   `future_mask` has size `[5, 5]`.
    -   Note that this is independent of the batch size: every sample has the
        same future mask of shape `[T, T]`.
    -   This `T = 5` must equal the sequence length in `target_padding_mask`.
-   Consider the batch of 4 samples. The first step is to broadcast
    `future_mask` to `[4, 5, 5]`, because we want each sample in the batch to
    have the same future mask, as shown below:

```python
pprint(future_mask)
future_mask = future_mask.view(1, seq_len, seq_len).expand(size=(batch_size, -1, -1))
pprint(future_mask)
pprint(future_mask.shape)
```

-   Now both `target_padding_mask` and `future_mask` share the same first
    dimension (the batch size), so we can zoom in on one particular sample.
-   What is still missing is to broadcast `target_padding_mask` from `[4, 5]` to
    `[4, 5, 5]` so that it has the same shape as `future_mask`. But along which
    axis, and why?
-   Read entry `(i, j)` of a `[5, 5]` mask as "may the token at position `i`
    (the query) attend to the token at position `j` (the key)?". The first row
    of the first sample's `future_mask` is `[T, F, F, F, F]`. It has 5 elements
    because each element refers to one position in the sequence:
    -   The first element (`True`) indicates that the first token can attend to
        itself.
    -   The next four elements (`False`) specify that the first token must not
        attend to any of the future tokens in the sequence.
-   What we want from the padding mask is that **no** token attends to a padded
    position. Whether position `j` is padding depends only on `j`, not on which
    token `i` is asking, so the first sample's padding row `T, T, T, F, F`
    should be repeated for every query row. That is exactly what
    `view(batch_size, 1, seq_len).expand(...)` does in the cell below: every row
    of `target_padding_mask[0]` becomes `[T, T, T, F, F]`.
-   For the first token, the element-wise `AND` of the padding row
    `[T, T, T, F, F]` and the future-mask row `[T, F, F, F, F]` gives
    `[T, F, F, F, F]`: the first token attends only to itself, which is neither
    padding nor in the future.
-   Broadcasting along the other axis, i.e. `view(batch_size, seq_len, 1)`,
    would instead repeat each token's own padding flag across its row (the first
    token's `T` would become `[T, T, T, T, T]`). That masks padded _queries_
    rather than padded _keys_, and it leaves rows that are entirely `False`,
    which turn into `NaN` after a softmax over all `-inf` scores.


In [144]:
pprint(target_padding_mask)
pprint(target_padding_mask[0])

target_padding_mask = target_padding_mask.view(batch_size, 1, seq_len).expand(size=(batch_size, seq_len, seq_len))
pprint(target_padding_mask)
pprint(target_padding_mask.shape)

In [145]:
pprint(target_padding_mask[0])
pprint(future_mask[0])
pprint(target_padding_mask[0] & future_mask[0])
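
The cell above builds the expanded mask with `view(batch_size, 1, seq_len)`. This copies the whole padding mask into every row, so entry `(i, j)` holds the padding status of key `j`. The row-wise picture in the bullets, where a token's own flag is repeated across its row, corresponds instead to `view(batch_size, seq_len, 1)`. The minimal sketch below contrasts the two on a single hypothetical sample with mask `[T, T, T, F, F]`. For the first token, both views give the same result after the AND with the future mask.

```python
import torch

# One hypothetical sample of length 5 whose last two positions are padding.
padding_mask = torch.tensor([[True, True, True, False, False]])  # (B=1, L=5)
batch_size, seq_len = padding_mask.shape

# Key-wise: out[b, i, j] = padding_mask[b, j], so every row is the full mask.
key_view = padding_mask.view(batch_size, 1, seq_len).expand(batch_size, seq_len, seq_len)

# Query-wise: out[b, i, j] = padding_mask[b, i], so row i repeats token i's own flag.
query_view = padding_mask.view(batch_size, seq_len, 1).expand(batch_size, seq_len, seq_len)

# Causal mask: position i may attend to positions j <= i.
future_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool)).expand(batch_size, -1, -1)

print(key_view[0, 0].tolist())    # [True, True, True, False, False]
print(query_view[0, 0].tolist())  # [True, True, True, True, True]
print(query_view[0, 3].tolist())  # [False, False, False, False, False]

# For the first token, both views agree once combined with the future mask.
print((key_view & future_mask)[0, 0].tolist())    # [True, False, False, False, False]
print((query_view & future_mask)[0, 0].tolist())  # [True, False, False, False, False]
```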

#### First Sample Fourth Token

Now let's look at another example: the 4th token in the sequence, where
`target_padding_mask = [T, T, T, F, F]` and `future_mask` is a lower triangular
matrix of `True`s.

1. **4th Token's target_padding_mask**: The 4th token has a value of `F` in
   `target_padding_mask`, indicating it's a padding token.

2. **4th Row of future_mask**: The 4th row in `future_mask` is
   `[True, True, True, True, False]`. This means that if this token were not a
   padding token, it would be allowed to attend to all the previous tokens in
   the sequence and itself, but not to any future token.

3. **Broadcast target_padding_mask**: To align `target_padding_mask` with
   `future_mask`, we'd broadcast `F` from the `target_padding_mask` to
   `[F, F, F, F, F]`. This way, when we consider the 4th token in relation to
   any other token in the sequence, it's still marked as a padding token.

4. **Element-wise AND with future_mask**: After broadcasting, you'd perform an
   element-wise AND between `[F, F, F, F, F]` and
   `[True, True, True, True, False]`, resulting in `[F, F, F, F, F]`.

5. **Interpretation**: This effectively means that the 4th token won't attend to
   any other token in the sequence, and no token will attend to it either, as it
   is a padding token.

So, the masks are doing their jobs correctly: the `target_padding_mask`
indicates whether each token is a padding token or not, and `future_mask`
dictates the "rules" of attention regarding what each token can attend to.
Combining them ensures that both conditions are met.
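
The minimal sketch below checks both halves of step 5 on the same hypothetical sample: the 4th token's row is empty, and so is its column. The column is empty only because the padding sits at the end of the sequence. Earlier tokens cannot look ahead, and the later token is also padding. There is one design caveat. If such a mask is applied with `masked_fill(..., -inf)` before a softmax, a fully masked row produces `NaN`s. Implementations commonly avoid this by masking only the keys, or by using a large negative finite value instead of `-inf`.

```python
import torch

# Same hypothetical sample: the 4th and 5th tokens (indices 3 and 4) are padding.
padding_mask = torch.tensor([True, True, True, False, False])
seq_len = padding_mask.numel()

future_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
row_view = padding_mask.view(seq_len, 1).expand(seq_len, seq_len)  # row i repeats token i's flag
combined = row_view & future_mask

fourth = 3  # 0-based index of the 4th token
print(combined[fourth].tolist())     # row: the 4th token attends to nothing
print(combined[:, fourth].tolist())  # column: no token attends to the 4th token

# Caveat: a fully masked row has no valid softmax when filled with -inf.
scores = torch.zeros(seq_len, seq_len).masked_fill(~combined, float("-inf"))
weights = scores.softmax(dim=-1)
print(torch.isnan(weights[fourth]).all().item())  # True
```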


### Further Add a Singleton Dimension in Target Masks

Now both masks are of shape `(B, L, L)`, but we need to insert a singleton
dimension at position 1, right after the batch dimension, to make them
`(B, 1, L, L)`.

In deep learning frameworks like PyTorch, the dimensions of the tensors involved
in operations like matrix multiplication or attention mechanisms often have
specific semantic meanings. In the context of attention mechanisms, especially
in the transformer architecture, the attention mask usually has a shape that is
compatible with the attention logits for element-wise operations such as masking.

In the transformer model, the attention logits are often computed as a dot
product between query and key vectors, resulting in a tensor of shape
`(Batch size, Num heads, Sequence length, Sequence length)` or `(B, H, L, L)`.
Here, `B` is the batch size, `H` is the number of attention heads, and `L` is
the sequence length.

To make the mask tensor compatible for element-wise operations with this 4D
tensor, it needs to have a shape that can be broadcasted to `(B, H, L, L)`. A
mask of shape `(B, 1, L, L)` fulfills this requirement.

The singleton dimension is added so that the mask can be broadcast to the
shape of the attention logits tensor during the computation. When a tensor of
shape `(B, 1, L, L)` is combined element-wise with a tensor of shape
`(B, H, L, L)`, the singleton dimension (the `1`) lets the same mask apply to
each attention head without explicitly replicating it `H` times. This is more
memory-efficient and often faster.

Thus, adding a singleton dimension in masks is a preparatory step that allows
for efficient element-wise operations later in the model's forward pass.


In [146]:
target_padding_mask = target_padding_mask.unsqueeze(1)
pprint(target_padding_mask.shape)

future_mask = future_mask.unsqueeze(1)
pprint(future_mask.shape)

target_mask = target_padding_mask & future_mask
pprint(target_mask.shape)
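
The broadcast can be seen in action in the minimal sketch below, which uses hypothetical sizes. Random queries and keys produce `(B, H, L, L)` logits. A causal mask of shape `(B, 1, L, L)` is applied with `masked_fill`, and its singleton head axis broadcasts across all `H` heads. Applying a mask with `masked_fill` and `-inf` is a common choice. We do not claim it is exactly what the model in this notebook does internally.

```python
import math
import torch

B, H, L, d_k = 2, 4, 5, 8  # hypothetical sizes for illustration
torch.manual_seed(0)
q = torch.randn(B, H, L, d_k)
k = torch.randn(B, H, L, d_k)

# Attention logits of shape (B, H, L, L).
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)

# Causal mask of shape (B, 1, L, L); the singleton head axis broadcasts over H.
mask = torch.tril(torch.ones(L, L, dtype=torch.bool)).expand(B, L, L).unsqueeze(1)

# Positions where the mask is False get -inf, so softmax gives them zero weight.
masked_scores = scores.masked_fill(~mask, float("-inf"))
weights = masked_scores.softmax(dim=-1)

print(mask.shape, weights.shape)  # torch.Size([2, 1, 5, 5]) torch.Size([2, 4, 5, 5])
```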

### Split to Train-Valid-Test

In [147]:
batch_size   = 256

composer.data.train_loader["batch_size"] = batch_size
composer.data.valid_loader["batch_size"] = batch_size
composer.data.test_loader["batch_size"] = batch_size

train_dataset, valid_dataset, test_dataset = split_dataset(
    dataset=dataset, split=composer.data.split, seed=composer.global_.seed
)

train_size, valid_size, test_size = len(train_dataset), len(valid_dataset), len(test_dataset)
train_size, valid_size, test_size

(7000, 2000, 1000)

In [148]:
# max_seq_len is determined by 1+ num_digits + 1 + num_digits + 1 + num_digits + 1 + 1
# where the 1s represent BOS, Plus sign, Equal sign, the extra digit in the sum, EOS, respectively.
max_seq_len = 1 + 1 + 1 + 1 + 2 * composer.constants.NUM_DIGITS + (composer.constants.NUM_DIGITS + 1)
assert max_seq_len == composer.data.context_length
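
As a quick check of the length formula, the minimal sketch below lays out the character tokens of one zero-padded equation and counts them. The helper `equation_tokens` and the `<BOS>`/`<EOS>` strings are hypothetical. Only the count mirrors the cell above.

```python
from typing import List


def equation_tokens(a: int, b: int, num_digits: int) -> List[str]:
    """Hypothetical helper: character tokens of '<BOS>a+b=c<EOS>' with zero padding."""
    lhs = f"{a:0{num_digits}d}+{b:0{num_digits}d}="
    rhs = f"{a + b:0{num_digits + 1}d}"  # the sum may carry into one extra digit
    return ["<BOS>", *lhs, *rhs, "<EOS>"]


NUM_DIGITS = 2
tokens = equation_tokens(23, 3, NUM_DIGITS)
print(tokens)

# BOS, '+', '=', EOS, then two operands, then (num_digits + 1) digits of the sum.
max_seq_len = 1 + 1 + 1 + 1 + 2 * NUM_DIGITS + (NUM_DIGITS + 1)
print(len(tokens), max_seq_len)  # 11 11
```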

### Create DataLoader

In [149]:
train_loader = create_loader(
    dataset=train_dataset,
    loader_config=composer.data.train_loader,
    collate_fn_config=composer.data.collate_fn,
)

valid_loader = create_loader(
    dataset=valid_dataset,
    loader_config=composer.data.valid_loader,
    collate_fn_config=composer.data.collate_fn,
)

test_loader = create_loader(
    dataset=test_dataset,
    loader_config=composer.data.test_loader,
    collate_fn_config=composer.data.collate_fn,
)

The `collate_fn` defines how to combine variable-length samples into a batch.
This usually involves padding the sequences in the batch to a common length,
typically the length of the longest sequence in the batch. Here the padding in
the collate function is "redundant", because our earlier code already ensured
that all samples have the same number of characters by padding zeros in front.
For example, `23 + 3 = 26` becomes `23 + 03 = 026`. Consequently, all samples in
a mini-batch have the same length by construction.
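
The sketch below shows why the collate padding is a no-op here. It uses `torch.nn.utils.rnn.pad_sequence` as a stand-in for the padding step of `collate_fn`, on the assumption that the step pads to the longest sequence; the actual implementation may differ. The token ids and `PAD_ID = 16` are illustrative only.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

PAD_ID = 16  # hypothetical padding index, for illustration only

# Two zero-padded equations of equal length (illustrative token ids).
same_length = [
    torch.tensor([15, 2, 3, 10, 0, 3, 13, 0, 2, 6]),
    torch.tensor([15, 3, 1, 10, 0, 4, 13, 0, 3, 5]),
]
batch = pad_sequence(same_length, batch_first=True, padding_value=PAD_ID)
print(batch.shape)  # torch.Size([2, 10]); no PAD_ID was inserted

# Without zero padding, "23+3=26" is shorter and gets PAD_ID appended at the end.
ragged = [torch.tensor([15, 2, 3, 10, 3, 13, 2, 6]), same_length[1]]
padded = pad_sequence(ragged, batch_first=True, padding_value=PAD_ID)
print(padded[0].tolist())  # [15, 2, 3, 10, 3, 13, 2, 6, 16, 16]
```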

In [150]:
torch.manual_seed(composer.global_.seed)

batch_index = 0
for batch in train_loader:
    # Each batch is a tuple containing all elements for the batch
    inputs_padded, targets_padded, padding_masks_padded_and_expanded, future_masks_expanded = batch

    # Print the batch size (number of samples in this batch)
    print("Batch Size:", len(inputs_padded))

    # Now you can print shapes or other properties of each batch element
    print("Inputs Shape:", inputs_padded.shape)
    print("Targets Shape:", targets_padded.shape)

    # Decoding and other processing can be done here
    # For example, decoding the first sequence in the batch
    print("Decoded First Equation/Sample of the Batch:", decode_equation(vocabulary, inputs_padded[0].tolist()))

    print("-" * 80)

    batch_index += 1
    if batch_index == 4: break

Batch Size: 256
Inputs Shape: torch.Size([256, 10])
Targets Shape: torch.Size([256, 10])
Decoded First Equation/Sample of the Batch: 31+04=035
--------------------------------------------------------------------------------
Batch Size: 256
Inputs Shape: torch.Size([256, 10])
Targets Shape: torch.Size([256, 10])
Decoded First Equation/Sample of the Batch: 37+49=086
--------------------------------------------------------------------------------
Batch Size: 256
Inputs Shape: torch.Size([256, 10])
Targets Shape: torch.Size([256, 10])
Decoded First Equation/Sample of the Batch: 47+26=073
--------------------------------------------------------------------------------
Batch Size: 256
Inputs Shape: torch.Size([256, 10])
Targets Shape: torch.Size([256, 10])
Decoded First Equation/Sample of the Batch: 53+05=058
--------------------------------------------------------------------------------


## Model

In [151]:
# Create individual component configurations
masked_self_attention_mha_config = MultiHeadedAttentionConfig(
    attention=ScaledDotProductAttention(),
    d_model=128, H=4, dropout=0.1
)

feed_forward_config = PositionwiseFeedForwardConfig(
    d_model=128, d_ff=256, activation=nn.GELU(approximate="tanh"), dropout=0.1, bias=True
)

add_norm_config_1 = AddNormConfig(feature_dim=128, dropout=0.1)
add_norm_config_2 = AddNormConfig(feature_dim=128, dropout=0.1)

# Create DecoderBlockConfig
decoder_block_config = DecoderBlockConfig(
    masked_self_attention_mha=masked_self_attention_mha_config,
    feed_forward=feed_forward_config,
    add_norm_1=add_norm_config_1,
    add_norm_2=add_norm_config_2,
)

# Create the overall DecoderConfig
model_config = DecoderConfig(
    d_model=128,
    vocab_size=vocab_size,
    context_length=max_seq_len,
    num_decoder_blocks=2,
    dropout=0.1,
    decoder_block=decoder_block_config,
)

model = GPTDecoder(model_config).to(composer.trainer.device)

model_size = model.total_trainable_parameters
print(f'model_size: {model_size}, train_set_size: {train_size}')

model_size: 270226, train_set_size: 7000


## Training Paradigm

### Optimizer

### Learning Rate Scheduler

see common utils
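
The `lr_fn` in the training cell below follows the warmup-then-decay schedule of the original Transformer paper, often called the Noam schedule. `LambdaLR` multiplies the optimizer's base learning rate by whatever `lr_lambda` returns, so the effective rate at scheduler step $s$ is

$$
\eta(s) = \eta_{0}\, d_{\text{model}}^{-1/2} \min\left((s+1)^{-1/2},\ (s+1)\, w^{-3/2}\right),
$$

where $\eta_0$ is the base `lr` (0.2 in the optimizer printed below) and $w$ is `warmup_steps`. The two terms inside the min meet at $s + 1 = w$, so the rate peaks at $\eta_0 / \sqrt{d_{\text{model}}\, w}$.

The sketch below re-implements the multiplier as a standalone function; the name `noam_lambda` is ours. It assumes $d_{\text{model}} = 128$ and $w = 3 \times 28 = 84$, since the logs show 28 training batches per epoch.

```python
def noam_lambda(step: int, d_model: int, warmup_steps: int) -> float:
    """Multiplier for LambdaLR: linear warmup, then inverse-square-root decay."""
    s = step + 1  # shift so the first step is 1 and s ** -0.5 is defined
    return d_model ** (-0.5) * min(s ** (-0.5), s * warmup_steps ** (-1.5))


base_lr, d_model, warmup_steps = 0.2, 128, 3 * 28
effective_lr = [base_lr * noam_lambda(step, d_model, warmup_steps) for step in range(300)]

peak_step = max(range(300), key=lambda step: effective_lr[step])
print(peak_step, effective_lr[peak_step])  # peak at step warmup_steps - 1, about 1.93e-3
```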

### Loss

Talk and link to bottom notes
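
Next-token prediction is a classification over the vocabulary at every position. The vocabulary size plays the role of the number of classes: the `(B, L, V)` logits are flattened to `(B * L, V)` and the targets to `(B * L,)`. With `ignore_index=PAD`, as in the `criterion` defined below, positions whose target is `PAD` contribute nothing, and `reduction="mean"` averages over the remaining positions only. The minimal sketch below uses random logits, hypothetical sizes and an illustrative `PAD` index.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, V, PAD = 2, 4, 18, 16  # hypothetical sizes; PAD is an illustrative index

logits = torch.randn(B, L, V)
targets = torch.tensor([[3, 5, 13, 14],
                        [9, 8, PAD, PAD]])  # the second sample ends in padding

criterion = nn.CrossEntropyLoss(ignore_index=PAD, reduction="mean")
loss = criterion(logits.view(-1, V), targets.view(-1))

# Equivalent: drop the PAD positions, then average over what is left.
keep = targets.view(-1) != PAD
manual = nn.functional.cross_entropy(logits.view(-1, V)[keep], targets.view(-1)[keep])
print(loss.item(), manual.item())
```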

In [152]:
warmup_steps = 3 * len(train_loader)
# lr first increases in the warmup steps, and then decays
lr_fn        = lambda step: model_config.d_model**(-0.5) * min([(step+1)**(-0.5), (step+1)*warmup_steps**(-1.5)])
# optimizer    = torch.optim.Adam(model.parameters(), lr=0.2, betas=(0.9, 0.98), eps=1e-9)

# optimizer_config = OptimizerConfig(name="torch.optim.Adam", lr=0.2, betas=(0.9, 0.98), eps=1e-9)
# optimizer   = optimizer_config.build(params=model.parameters())

# optimizer_config = OptimizerConfig(name="torch.optim.Adam", lr=0.2)
# optimizer   = optimizer_config.build(params=model.parameters(), betas=(0.9, 0.98), eps=1e-9)

optimizer   = composer.optimizer.build(params=model.parameters())
print(optimizer)
scheduler    = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_fn)
criterion    = nn.CrossEntropyLoss(ignore_index=PAD, reduction="mean")


@dataclass
class Metrics:
    loss: Loss
    accuracy: Accuracy

Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.98)
    capturable: False
    differentiable: False
    eps: 1e-09
    foreach: None
    fused: None
    lr: 0.2
    maximize: False
    weight_decay: 0.0
)


1. `input` is indeed `[bs, 10]`: the maximum length is 11 and the last token is removed.
2. `target` should be `[bs, 10]`, i.e. the original input shifted left by one token, but somehow I got 11.
3. Think of vocab size to be num classes in my classification problem. But the
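
Items 1 and 2 describe the usual next-token shift. From a full sequence of length 11, the input drops the last token and the target drops the first, so both have length 10 and `targets[t]` is the token that should follow `inputs[: t + 1]`. The minimal sketch below uses the token ids of the `98+35=133` example later in this notebook. It shows only the plain shift; the notebook's own collate function may additionally mask some target positions.

```python
import torch

# "<BOS>98+35=133<EOS>" as token ids (taken from the worked example in this notebook).
sequence = torch.tensor([15, 9, 8, 10, 3, 5, 13, 1, 3, 3, 14])

inputs = sequence[:-1]   # drop the last token -> length 10
targets = sequence[1:]   # drop the first token (shift left) -> length 10

for t in range(3):
    # The prefix the model sees, and the token it should predict next.
    print(inputs[: t + 1].tolist(), "->", targets[t].item())
```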

In [153]:
# Create optimizer based on model parameters
if composer.trainer.apply_weight_decay_to_different_param_groups:
    assert hasattr(composer.optimizer, "weight_decay")
    optimizer = optimizer_pydantic_config.build(
        params=apply_weight_decay_to_different_param_groups(
            model=model, weight_decay=composer.optimizer.weight_decay
        )
    )
else:
    optimizer = optimizer_pydantic_config.build(params=model.parameters())

# Create criterion
criterion = criterion_pydantic_config.create_instance()
assert criterion.ignore_index == vocabulary.token_to_index[vocabulary.PAD]

# Create the Noam scheduler
# TODO: this part is hardcoded in a way since we are using LambdaLR.
# I do not have time to make it more "automated" so this is anti-config-pattern.
warmup_steps = 3 * len(train_loader)

# lr first increases in the warmup steps, and then decays
noam = lambda step: noam_lr_decay(step, d_model=composer.model.d_model, warmup_steps=warmup_steps)  # noqa: E731

scheduler_config_cls = SCHEDULER_REGISTRY[cfg.scheduler.name]

if issubclass(scheduler_config_cls, LambdaLRConfig):
    scheduler_pydantic_config = scheduler_config_cls(lr_lambda=noam, **cfg.scheduler)
else:
    scheduler_pydantic_config = scheduler_config_cls(**cfg.scheduler)  # type: ignore[assignment]

assert composer.scheduler is MISSING  # now it is MISSING for us to fill up.
composer.scheduler = scheduler_pydantic_config
scheduler = scheduler_pydantic_config.build(optimizer=optimizer)


NameError: name 'optimizer_pydantic_config' is not defined

In [None]:
from omnivault.transformer.core.state import State

state = State(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    vocabulary=vocabulary,
    tokenizer=tokenizer,
)
state.pretty_print()

In [None]:
pprint(composer.trainer)

Talk about `State` from the README.

│   │   use_amp=False,
│   │   autocast_config={'enabled': False},
│   │   scaler_config={'enabled': False, 'init_scale': 65536.0, 'growth_factor': 2.0, 'backoff_factor': 0.5, 'growth_interval': 2000},

In [106]:
# pprint(composer)

composer.trainer.use_amp=False
composer.trainer.autocast_config = {'enabled': False, 'dtype': torch.bfloat16, 'cache_enabled': True}
composer.trainer.scaler_config["enabled"]=False

In [154]:
from omnivault.transformer.core.trainer import Trainer, TrainerEvent
from omnivault.transformer.core.callbacks import save_state


In [155]:
trainer = Trainer(
    state=state,
    composer=composer,
    logger=LOGGER,
    device=composer.trainer.device,  # type: ignore[arg-type]
)
trainer.remove_callback(event=TrainerEvent.ON_VALID_EPOCH_END.value, callback=save_state)
# trainer.add_callback(
#     TrainerEvent.ON_VALID_EPOCH_END.value,
#     lambda trainer: evaluate_and_generate_on_valid_epoch_end(trainer, num_batches_to_eval=None),
# )
_trained_state = trainer.fit(train_loader=train_loader, valid_loader=valid_loader, test_loader=test_loader)
_trained_state.pretty_print()
history = _trained_state.history

CPU Autocast only supports dtype of torch.bfloat16 currently.

RuntimeError: Currently, AutocastCPU only support Bfloat16 as the autocast_cpu_dtype

In [91]:
trainer = Trainer(
    model=model,
    train_dataloader=train_loader,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    grad_norm_clip=1.0,
    device=composer.trainer.device,
    valid_dataloader=valid_loader,
    # test_dataloader=test_loader,
    # NOTE: uncomment the above line to enable testing after each epoch
    # but seeding will affect.
)

if DEBUG:
    trained_model = trainer.fit(max_epochs=2) # or 15
    # torch.save(model.state_dict(), 'model_debug.pt')
    # model_debug = torch.load('./model_debug.pt')
    # if are_both_models_same(model.state_dict(), model_debug):
    #     print("Pass")
    # else:
    #     print("Fail")

else:
    trained_model = trainer.fit(max_epochs=30)

    # torch.save(model.state_dict(), 'model_non_debug.pt')

TypeError: __init__() got an unexpected keyword argument 'model'

In [None]:
break

```
Epoch 1/2
----------
100%|██████████| 28/28 [00:04<00:00,  5.60it/s]
Average Epoch Training Loss   : 2.41188
100%|██████████| 8/8 [00:00<00:00, 20.18it/s]
Average Epoch Validation Loss : 1.66422
Epoch 2/2
----------
100%|██████████| 28/28 [00:03<00:00,  9.07it/s]
Average Epoch Training Loss   : 1.36899
100%|██████████| 8/8 [00:00<00:00, 20.31it/s]
Average Epoch Validation Loss : 1.16084
Training complete
```

```
Epoch 1/2
----------
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:03<00:00,  8.95it/s]
Average Epoch Training Loss   : 2.40482
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 15.67it/s]
Average Epoch Validation Loss : 1.72585
Epoch 2/2
----------
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 28/28 [00:02<00:00, 12.24it/s]
Average Epoch Training Loss   : 1.37748
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 29.34it/s]
Average Epoch Validation Loss : 1.15630
Training complete
```

In [None]:
trained_model

In [None]:
batch = next(iter(train_loader))
pprint(batch)

inputs, targets, target_padding_masks, future_masks = batch


# Pass the batch through the trained model
trained_model.eval()  # Set the model to evaluation mode (disables dropout)

with torch.no_grad():
    logits = trained_model(inputs, target_padding_masks=target_padding_masks, future_masks=future_masks)

In [None]:
last_decoder_block = trained_model.decoder_blocks[-1] # take last decoder block? more feature?
# pprint(last_decoder_block)

masked_self_attention_mha = last_decoder_block.masked_self_attention_mha
pprint(masked_self_attention_mha)

context_vector, attention_weights = masked_self_attention_mha.context_vector, masked_self_attention_mha.attention_weights
pprint(attention_weights.shape)
# but has H=4 heads so do we take 1 head and check the heatmap?
# torch.Size([208, 4, 10, 10])

last_batch_last_sample_first_head_attention_weights = attention_weights[-1, 0:1, :, :].squeeze(0)
pprint(last_batch_last_sample_first_head_attention_weights.shape)


The x-axis holds the keys and the y-axis the queries, which matches the layout of `Q @ K.T`: row `i` holds the scores of query `i` against every key.
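
The small sketch below (hypothetical sizes, random vectors) confirms that orientation. Entry `(i, j)` of `Q @ K.T` is the dot product of query `i` with key `j`, so each heatmap row is one query and each column is one key.

```python
import torch

torch.manual_seed(0)
L, d_k = 4, 8  # hypothetical sequence length and head dimension
Q = torch.randn(L, d_k)
K = torch.randn(L, d_k)

scores = Q @ K.T  # (L, L)

# Entry (i, j) is query i dotted with key j: rows are queries (y-axis), columns are keys (x-axis).
i, j = 1, 3
print(scores[i, j].item(), torch.dot(Q[i], K[j]).item())
```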

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# Retrieve the attention weights stored by the last decoder block
last_decoder_block = trained_model.decoder_blocks[-1]
masked_self_attention_mha = last_decoder_block.masked_self_attention_mha
context_vector, attention_weights = masked_self_attention_mha.context_vector, masked_self_attention_mha.attention_weights

# Number of heads
num_heads = attention_weights.size(1)

# Labels for each character in the sequence, including BOS
labels = ['<BOS>'] + list('59+14=073')

# Loop over each head and plot its heatmap
for head in range(num_heads):
    plt.figure(figsize=(10, 10))

    # Extract attention weights for the last sample in the last batch for this head
    attention_matrix = attention_weights[-1, head, :, :].detach().cpu().numpy()  # move to CPU before converting

    sns.heatmap(attention_matrix, annot=True, cmap='viridis', xticklabels=labels, yticklabels=labels)
    plt.title(f"Attention Weights Heatmap for '<BOS>59+14=073' - Head {head+1}")
    plt.xlabel("Keys")
    plt.ylabel("Queries")
    plt.show()


## DEBUG

### W


```
Epoch 1/2
----------
100%|██████████| 28/28 [00:04<00:00,  5.91it/s]
100%|██████████| 8/8 [00:00<00:00, 14.21it/s]
Training Loss   : 2.08882
Validation Loss : 1.27368
Epoch 2/2
----------
100%|██████████| 28/28 [00:04<00:00,  6.42it/s]
100%|██████████| 8/8 [00:00<00:00, 21.24it/s]
Training Loss   : 1.23194
Validation Loss : 1.10291
Training complete
```

After swapping the positions of EOS and BOS:

```
Epoch 1/2
----------
100%|██████████| 28/28 [00:03<00:00,  8.22it/s]
100%|██████████| 8/8 [00:00<00:00, 26.60it/s]
Training Loss   : 2.10450
Validation Loss : 1.28284
Epoch 2/2
----------
100%|██████████| 28/28 [00:02<00:00,  9.82it/s]
100%|██████████| 8/8 [00:00<00:00, 25.31it/s]
Training Loss   : 1.23119
Validation Loss : 1.09374
```

### M 

```
Epoch 1/2
----------
100%|██████████| 28/28 [00:02<00:00, 13.63it/s]
100%|██████████| 8/8 [00:00<00:00, 34.32it/s]
Training Loss   : 2.08863
Validation Loss : 1.26961
Epoch 2/2
----------
100%|██████████| 28/28 [00:01<00:00, 14.40it/s]
100%|██████████| 8/8 [00:00<00:00, 44.49it/s]
Training Loss   : 1.23620
Validation Loss : 1.11484
Training complete

---

Epoch 29/30
----------
100%|██████████| 28/28 [00:01<00:00, 15.37it/s]
100%|██████████| 8/8 [00:00<00:00, 41.86it/s]
Training Loss   : 0.01514
Validation Loss : 0.00067
Epoch 30/30
----------
100%|██████████| 28/28 [00:01<00:00, 15.38it/s]
100%|██████████| 8/8 [00:00<00:00, 42.62it/s]
Training Loss   : 0.01448
Validation Loss : 0.00057
Training complete
```

```
Epoch 1/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.22it/s]
100%|██████████| 8/8 [00:00<00:00, 29.15it/s]
Training Loss   : 2.10712
Validation Loss : 1.27872
Epoch 2/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.78it/s]
100%|██████████| 8/8 [00:00<00:00, 29.36it/s]
Training Loss   : 1.23133
Validation Loss : 1.09508
Epoch 3/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.91it/s]
100%|██████████| 8/8 [00:00<00:00, 29.92it/s]
Training Loss   : 1.03519
Validation Loss : 0.87629
Epoch 4/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.41it/s]
100%|██████████| 8/8 [00:00<00:00, 28.30it/s]
Training Loss   : 0.87725
Validation Loss : 0.78361
Epoch 5/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.86it/s]
100%|██████████| 8/8 [00:00<00:00, 30.18it/s]
Training Loss   : 0.79957
Validation Loss : 0.73302
Epoch 6/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.35it/s]
100%|██████████| 8/8 [00:00<00:00, 28.31it/s]
Training Loss   : 0.76029
Validation Loss : 0.69880
Epoch 7/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.45it/s]
100%|██████████| 8/8 [00:00<00:00, 30.39it/s]
Training Loss   : 0.72721
Validation Loss : 0.68126
Epoch 8/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.53it/s]
100%|██████████| 8/8 [00:00<00:00, 27.81it/s]
Training Loss   : 0.70416
Validation Loss : 0.63668
Epoch 9/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.25it/s]
100%|██████████| 8/8 [00:00<00:00, 29.87it/s]
Training Loss   : 0.64809
Validation Loss : 0.48974
Epoch 10/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.47it/s]
100%|██████████| 8/8 [00:00<00:00, 22.07it/s]
Training Loss   : 0.39999
Validation Loss : 0.16127
Epoch 11/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.28it/s]
100%|██████████| 8/8 [00:00<00:00, 29.57it/s]
Training Loss   : 0.20621
Validation Loss : 0.08199
Epoch 12/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.71it/s]
100%|██████████| 8/8 [00:00<00:00, 15.80it/s]
Training Loss   : 0.13922
Validation Loss : 0.05043
Epoch 13/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.76it/s]
100%|██████████| 8/8 [00:00<00:00, 29.94it/s]
Training Loss   : 0.11169
Validation Loss : 0.03682
Epoch 14/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.03it/s]
100%|██████████| 8/8 [00:00<00:00, 30.37it/s]
Training Loss   : 0.08848
Validation Loss : 0.02700
Epoch 15/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.33it/s]
100%|██████████| 8/8 [00:00<00:00, 30.22it/s]
Training Loss   : 0.07917
Validation Loss : 0.02183
Epoch 16/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.20it/s]
100%|██████████| 8/8 [00:00<00:00, 27.96it/s]
Training Loss   : 0.06974
Validation Loss : 0.01599
Epoch 17/30
----------
100%|██████████| 28/28 [00:02<00:00, 11.07it/s]
100%|██████████| 8/8 [00:00<00:00, 30.04it/s]
Training Loss   : 0.05679
Validation Loss : 0.01285
Epoch 18/30
----------
100%|██████████| 28/28 [00:02<00:00, 11.82it/s]
100%|██████████| 8/8 [00:00<00:00, 18.65it/s]
Training Loss   : 0.04896
Validation Loss : 0.00878
Epoch 19/30
----------
100%|██████████| 28/28 [00:03<00:00,  9.00it/s]
100%|██████████| 8/8 [00:00<00:00, 26.38it/s]
Training Loss   : 0.04387
Validation Loss : 0.00921
Epoch 20/30
----------
100%|██████████| 28/28 [00:02<00:00,  9.98it/s]
100%|██████████| 8/8 [00:00<00:00, 28.77it/s]
Training Loss   : 0.04160
Validation Loss : 0.00447
Epoch 21/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.33it/s]
100%|██████████| 8/8 [00:00<00:00, 28.80it/s]
Training Loss   : 0.03468
Validation Loss : 0.00423
Epoch 22/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.97it/s]
100%|██████████| 8/8 [00:00<00:00, 29.11it/s]
Training Loss   : 0.03085
Validation Loss : 0.00279
Epoch 23/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.41it/s]
100%|██████████| 8/8 [00:00<00:00, 27.02it/s]
Training Loss   : 0.02741
Validation Loss : 0.00197
Epoch 24/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.09it/s]
100%|██████████| 8/8 [00:00<00:00, 29.83it/s]
Training Loss   : 0.02015
Validation Loss : 0.00132
Epoch 25/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.08it/s]
100%|██████████| 8/8 [00:00<00:00, 29.30it/s]
Training Loss   : 0.01844
Validation Loss : 0.00229
Epoch 26/30
----------
100%|██████████| 28/28 [00:02<00:00, 13.23it/s]
100%|██████████| 8/8 [00:00<00:00, 29.35it/s]
Training Loss   : 0.01913
Validation Loss : 0.00103
Epoch 27/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.50it/s]
100%|██████████| 8/8 [00:00<00:00, 27.89it/s]
Training Loss   : 0.01545
Validation Loss : 0.00076
Epoch 28/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.30it/s]
100%|██████████| 8/8 [00:00<00:00, 26.09it/s]
Training Loss   : 0.01616
Validation Loss : 0.00104
Epoch 29/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.28it/s]
100%|██████████| 8/8 [00:00<00:00, 25.84it/s]
Training Loss   : 0.01504
Validation Loss : 0.00092
Epoch 30/30
----------
100%|██████████| 28/28 [00:02<00:00, 12.59it/s]
100%|██████████| 8/8 [00:00<00:00, 25.87it/s]
Training Loss   : 0.01006
Validation Loss : 0.00047
Training complete
```

In [None]:
break

```
x -> tensor([[15,  9,  8, 10,  3,  5, 13]])
future_mask -> 7x7
tensor([[ True, False, False, False, False, False, False],
        [ True,  True, False, False, False, False, False],
        [ True,  True,  True, False, False, False, False],
        [ True,  True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True, False, False],
        [ True,  True,  True,  True,  True,  True, False],
        [ True,  True,  True,  True,  True,  True,  True]])

logits--> 1x7x18 because 1 sample
tensor([[[  7.8,  -0.2,  -2.3,  -1.1,  -0.1,  -3.2,  -4.4,
          -2.4,   3.7,  -0.9,  -5.1,  -4.5,  -5.6,  -2.2,
          -0.5,  -4.2,  -2.9,  -4.9],
        [  0.3,   3.7,   0.9,   1.7,   0.4,  -4.0,  -6.0,
          -2.3,   8.5,   7.3,  -6.0,  -5.1,  -6.2,  -3.0,
         -10.9,  -3.8,  -5.3,  -5.9],
        [-10.5,  -0.4,   4.3,   2.4,  -6.3,  -8.9,  -0.1,
           8.2,   8.6,   0.4,   1.2,   0.9,   0.7,   0.6,
           6.9,   0.0,   0.4,   1.2],
        [ -2.8,   9.6,   2.0,  -6.2,  -8.2,  -2.3,   5.7,
           6.6,  -0.3,  -4.7,  -0.5,  -0.9,  -0.9,   1.2,
           2.3,  -0.4,   0.1,  -1.5],
        [ -2.9,   1.6,  -1.0,  -5.8,  -0.2,   6.2,  14.1,
           8.0,  -4.0,  -9.7,  -2.1,  -3.4,  -3.2,  -1.4,
           0.0,  -1.7,   0.0,  -3.0],
        [ -9.4,   1.7,   5.4,  -1.3,  -6.6,  -4.7,   6.7,
          10.2,   1.9,  -9.6,   0.8,   0.6,   0.7,   1.2,
          10.2,   0.4,   1.3,   1.2],
        [  0.3,  16.1,   3.2,  -4.4,  -5.7,  -2.9,  -3.7,
          -6.1,  -2.1,   4.0,  -0.4,   0.1,  -0.4,   0.0,
           0.6,  -0.6,  -1.2,  -0.7]]])

logits.argmax(dim=-1) -> 1x7
tensor([[0,  8,  8,  1,  6, 14,  1]])
```

`logits.argmax(dim=-1)` compresses the 1x7x18 logits to 1x7: for each of the 7
rows, it returns the index of the largest of the 18 values. For example, in the
first row 7.8 is the largest value, so index 0 is returned. Overall this gives
`tensor([[0,  8,  8,  1,  6, 14,  1]])`.
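
The argmax step can be reproduced on the first three rows of the logits printed above, which the log rounds to one decimal.

```python
import torch

# First three rows of the printed logits (rounded to one decimal in the log).
logits = torch.tensor([[
    [7.8, -0.2, -2.3, -1.1, -0.1, -3.2, -4.4, -2.4, 3.7, -0.9, -5.1, -4.5, -5.6, -2.2, -0.5, -4.2, -2.9, -4.9],
    [0.3, 3.7, 0.9, 1.7, 0.4, -4.0, -6.0, -2.3, 8.5, 7.3, -6.0, -5.1, -6.2, -3.0, -10.9, -3.8, -5.3, -5.9],
    [-10.5, -0.4, 4.3, 2.4, -6.3, -8.9, -0.1, 8.2, 8.6, 0.4, 1.2, 0.9, 0.7, 0.6, 6.9, 0.0, 0.4, 1.2],
]])  # shape (1, 3, 18)

# Collapse the vocabulary axis: (1, 3, 18) -> (1, 3), keeping the index of each row's maximum.
predicted_ids = logits.argmax(dim=-1)
print(predicted_ids.tolist())  # [[0, 8, 8]]
```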

There is meaning here too. Our input `[15, 9, 8, 10, 3, 5, 13]` is the BOS
token (`15`) followed by the equation up to and including the equal sign, and
`[0, 8, 8, 1, 6, 14, 1]` is the model's prediction of the next token at each of
those positions.

1. **Input Sequence**: Our input sequence is `[15, 9, 8, 10, 3, 5, 13]`, where
   `15` is the BOS (Beginning of Sequence) token and `13` is the equal sign.

2. **Output Tensor Interpretation**: The output tensor
   `tensor([[ 0, 8, 8, 1, 6, 14, 1]])` represents the model's sequential
   predictions for each step of the input:

   - The first element `0` is the prediction following the first element `15` of
     the input.
   - The second element `8` is the prediction after seeing the first two
     elements `15, 9` of the input.
   - The third element `8` is predicted after seeing `15, 9, 8`.
   - The fourth element `1` follows after `15, 9, 8, 10`.
   - The sequence continues in this manner, with each new prediction based on an
     increasingly longer prefix of the input sequence.

3. **Sequential Predictions**: This output suggests that the model is working in
   an autoregressive manner. It generates predictions one token at a time, and
   each prediction is based on the sequence of tokens it has seen up to that
   point.

4. **Specific Meanings of Output Tokens**: Each output is a vocabulary index. In
   this task the digits `0` to `9` map to themselves, `10` is `+`, `13` is `=`,
   `14` is EOS and `15` is BOS. Only the last prediction, `1`, comes after the
   equal sign, and it is indeed the first digit of 98 + 35 = 133. The earlier
   predictions fall inside the prompt, where the next token is given rather
   than generated, so they do not affect the answer.

In summary, the output tensor reflects the model's predictions for what comes
next in the sequence, based on the current and all previous input tokens. Each
element in the output is the model's guess for the next token, considering the
sequence of tokens it has seen up to that point.

> Then we move on to the concat operation:


- In our model, after processing the input `[15, 9, 8, 10, 3, 5, 13]`, it
  predicts the next token to be `1`. This prediction is based on the entire
  sequence seen so far.

- The process of extending the input sequence with this new token (`1`) and then
  feeding this extended sequence back into the model for further predictions is
  indeed an example of greedy decoding. The model is iteratively building a
  longer sequence, one token at a time, always choosing the most likely next
  token at each step.

- This process would continue until a stopping condition is met, which might be
  the prediction of an EOS (End of Sentence) token or reaching a maximum
  sequence length.


> `for _ in range(num_digits + 2):`
> This is why the loop runs 4 times in total when `num_digits` is 2: after the
> equal sign, the answer has 3 digits (`xyz`) followed by an EOS token, which is
> our stopping condition.

Lastly, `tensor([[15,  9,  8, 10,  3,  5, 13,  1,  3,  3, 14]])` is the full
predicted sequence once EOS is generated, which decodes to `<BOS>98+35=133<EOS>`.
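
The loop described above can be sketched without a trained network. In this minimal sketch, `oracle` is a hypothetical stand-in that always returns the correct next token for two-digit addition, playing the role of `logits.argmax(-1)[:, -1]`. The ids `13` (`=`) and `14` (EOS) are taken from the example above. The sketch shows both stopping rules: emitting EOS, or exhausting the `num_digits + 2` step budget.

```python
from typing import Callable, List

EQUAL, EOS = 13, 14  # ids as they appear in the example above


def greedy_decode(next_token: Callable[[List[int]], int], prompt: List[int], num_digits: int) -> List[int]:
    """Append the most likely next token until EOS or (num_digits + 2) new tokens."""
    x = list(prompt)
    for _ in range(num_digits + 2):  # (num_digits + 1) answer digits + EOS
        token = next_token(x)        # stands in for logits.argmax(-1)[:, -1]
        x.append(token)
        if token == EOS:             # stopping condition
            break
    return x


def oracle(x: List[int]) -> int:
    """Hypothetical perfect 'model' for two-digit addition, used only to drive the loop."""
    eq = x.index(EQUAL)
    a = x[1] * 10 + x[2]
    b = x[4] * 10 + x[5]
    answer = [int(c) for c in f"{a + b:03d}"] + [EOS]
    return answer[len(x) - eq - 1]


print(greedy_decode(oracle, [15, 9, 8, 10, 3, 5, 13], num_digits=2))
# [15, 9, 8, 10, 3, 5, 13, 1, 3, 3, 14]
```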


In [None]:
config.global_config.seed = 42

In [None]:
def construct_future_mask(seq_len: int) -> torch.BoolTensor:
    # Lower-triangular boolean mask: position i may attend to positions j <= i.
    return torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))

def construct_padding_mask(input_sequence: torch.Tensor, pad_token_id: int) -> torch.BoolTensor:
    # True where the token is a real token, False where it is padding.
    return input_sequence != pad_token_id
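
`torch.triu(..., diagonal=1) == 0` and `torch.tril(...)` describe the same causal mask, and the padding mask is a single comparison. The quick check below runs on a hypothetical length-5 sequence; the pad id `16` is illustrative only.

```python
import torch

seq_len = 5

# Mark future positions with a strict upper triangle, then invert.
future = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool), diagonal=1)
allowed_from_triu = future == 0

# Equivalent one-liner: keep the lower triangle, including the diagonal.
allowed_from_tril = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool))

print(torch.equal(allowed_from_triu, allowed_from_tril))  # True

# Padding mask: True for real tokens, False for the (hypothetical) pad id 16.
pad_token_id = 16
tokens = torch.tensor([[15, 9, 8, 16, 16]])
print((tokens != pad_token_id).tolist())  # [[True, True, True, False, False]]
```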

In [None]:
@torch.no_grad()
def compute_sum(model, x) -> torch.Tensor:
    """Greedily generate the answer digits (and EOS) for a prompt ending in '='.

    `x` has shape (1, T), e.g. [[15, 9, 8, 10, 3, 5, 13]] for "<BOS>98+35=".
    """
    for _ in range(num_digits + 2):  # (num_digits + 1) answer digits + EOS
        batch_size, seq_len = x.size()
        # (1, 1, 1, T): True for real tokens, False for padding.
        pad_mask = (x != PAD).view(1, 1, 1, seq_len).to(DEVICE)
        # (B, 1, T, T): causal mask so position i only sees positions <= i.
        future_mask = construct_future_mask(seq_len=seq_len)
        future_mask = future_mask.view(1, seq_len, seq_len).expand(size=(batch_size, -1, -1)).unsqueeze(1).to(DEVICE)

        logits = model(input_tokens=x, target_padding_masks=pad_mask, future_masks=future_mask)

        # Greedy step: take the most likely token at the last position.
        last_output = logits.argmax(-1)[:, -1].view(1, 1)
        x = torch.cat((x, last_output), 1).to(DEVICE)
        # STOPPING CONDITION!
        if last_output.item() == EOS:
            break
    return x[0]


def evaluate(model, dataloader, num_batch=None):
    """
    Evaluate the model with exact-match accuracy.

    This function takes equations, truncates them right after the equal sign,
    feeds them to the model to generate the answers, compares the predictions
    with the correct answers, and returns the accuracy.
    """
    model.eval()
    acc, count = 0, 0
    num_wrong_to_display = 5
    for idx, batch in enumerate(dataloader):
        (
            inputs,
            targets,
            target_padding_masks,
            future_masks,
        ) = batch  # construct_batches(batch)
        for equation in inputs:
            # pprint(equation)
            # add EOS behind equation
            equation = torch.cat((equation, torch.tensor([EOS])), 0) # TODO: PLEASE DO NOT DO THIS - DO NOT MODIFY LIKE THIS.
            # fmt: off
            loc_equal_sign = equation.tolist().index(EQUAL)
            loc_EOS        = equation.tolist().index(EOS)
            input          = equation[0 : loc_equal_sign + 1].view(1, -1).to(DEVICE)
            ans            = equation[: loc_EOS + 1].tolist()
            ans_pred       = compute_sum(model, input)
            count += 1
            # fmt: on

            if ans == ans_pred.tolist():
                acc += 1
            else:
                if num_wrong_to_display > 0:
                    print(
                        f'correct equation: {decode_equation(vocab=vocab, equation=equation).replace("<PAD>","")}'
                    )
                    print(f"wrongly predicted as:        {decode_equation(vocab=vocab, equation=ans_pred)}")
                    num_wrong_to_display -= 1
        if num_batch and idx > num_batch:
            break
    return acc / count


def what_is(question: str) -> str:
    "function for computing the sum of two numbers with input in literal string format"
    pred = compute_sum(model, encode_equation(question, num_digits).view(1, -1))
    pred = decode_equation(pred)
    pred = pred[pred.index("=") + 1 :]
    return question + pred


The code above implements greedy decoding for sequence generation. Let's break
down how it follows the principles of greedy decoding:

1. **Greedy Decoding Principle**: Greedy decoding in sequence generation models
   involves choosing the most probable next token at each step of the sequence
   generation. This is done iteratively until a stopping condition is met (like
   reaching an EOS token or a maximum length).

2. **Implementation in `compute_sum`**:

   - The `compute_sum` function generates a sequence by repeatedly predicting
     the next token and appending it to the input.
   - For each iteration in `compute_sum`:
     - The call
       `model(input_tokens=x, target_padding_masks=pad_mask, future_masks=future_mask)`
       returns logits for every position of the current sequence `x`; only the
       last position's logits describe the next token.
     - `last_output = logits.argmax(-1)[:, -1].view(1, 1)` picks the most
       probable next token (the one with the highest logit) at the last
       position. This is the essence of greedy decoding.
     - This token is then appended to the sequence:
       `x = torch.cat((x, last_output), 1)`.
   - The process continues until the model generates an EOS token, as checked
     by `if last_output.item() == EOS: break`, or until the loop's budget of
     `num_digits + 2` steps runs out.

3. **Evaluation Function**:

   - The `evaluate` function drives this loop: it truncates each equation right
     after the equal sign, lets `compute_sum` complete it, and counts the
     prediction as correct only if the whole completed equation matches the
     ground truth exactly.

4. **Characteristics of Greedy Decoding**:
   - Greedy decoding is computationally efficient and straightforward but may
     not produce the most probable sequence overall. Each choice is conditioned
     on all previous tokens, but the decoder commits to the locally best token
     and never revisits earlier decisions, so it cannot trade a slightly less
     likely token now for a much more likely continuation later.
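
To make the last point concrete, here is a toy two-step example with made-up next-token probabilities (the tokens `A`, `B`, `x`, `y` and all numbers are hypothetical). Greedy decoding commits to the locally best first token and ends with a lower joint probability than an exhaustive search over both steps.

```python
from itertools import product

# Hypothetical next-token probabilities, for illustration only.
step1 = {"A": 0.6, "B": 0.4}            # p(first token)
step2 = {                               # p(second token | first token)
    "A": {"x": 0.55, "y": 0.45},
    "B": {"x": 0.9, "y": 0.1},
}

def greedy_decode():
    """Pick the most probable token at each step, never revisiting earlier choices."""
    first = max(step1, key=step1.get)
    second = max(step2[first], key=step2[first].get)
    return first + second, step1[first] * step2[first][second]

def exhaustive_decode():
    """Score every full sequence by its joint probability and keep the best one."""
    candidates = ((a + b, step1[a] * step2[a][b]) for a, b in product(step1, ["x", "y"]))
    return max(candidates, key=lambda pair: pair[1])

greedy_seq, greedy_p = greedy_decode()
best_seq, best_p = exhaustive_decode()
print(greedy_seq, round(greedy_p, 2))  # Ax 0.33
print(best_seq, round(best_p, 2))      # Bx 0.36
```

Beam search, which keeps the `k` best partial sequences at every step, is a common compromise between these two extremes.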

In summary, `compute_sum` implements a typical greedy decoder: it iteratively
generates a sequence by choosing the most probable next token at each step and
stops once it emits `EOS`.
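
The loop inside `compute_sum` can be isolated from the Adder-specific masks. The sketch below assumes a hypothetical `next_token_logits` callable that maps a `[1, L]` token tensor to `[1, L, V]` logits; a toy "model" that always predicts `last token + 1` stands in for the trained GPT, and `eos_id=5` is an arbitrary choice.

```python
import torch
import torch.nn.functional as F

def greedy_decode(next_token_logits, prompt: torch.Tensor, eos_id: int, max_new_tokens: int) -> torch.Tensor:
    """Append the argmax token at the last position until EOS or the token budget runs out."""
    x = prompt.clone()
    for _ in range(max_new_tokens):
        logits = next_token_logits(x)                            # [1, L, V]
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [1, 1], last position only
        x = torch.cat([x, next_id], dim=1)
        if next_id.item() == eos_id:                             # stopping condition
            break
    return x[0]

V = 6  # toy vocabulary size

def toy_model(x: torch.Tensor) -> torch.Tensor:
    # Puts all probability mass on (token + 1) % V at every position.
    return F.one_hot((x + 1) % V, num_classes=V).float()

print(greedy_decode(toy_model, torch.tensor([[0]]), eos_id=5, max_new_tokens=10).tolist())  # [0, 1, 2, 3, 4, 5]
print(greedy_decode(toy_model, torch.tensor([[0]]), eos_id=5, max_new_tokens=2).tolist())   # [0, 1, 2]
```

The second call shows why `compute_sum` bounds its loop (with `num_digits + 2`): if `EOS` never comes, the budget still ends generation.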


```python
print("training set examples the model gives an incorrect result:")
seed_all(1992, seed_torch=True)

train_acc = evaluate(model, train_loader, num_batch=2)  # evaluate on a subset of training batches
pprint(train_acc)

print("validation set examples the model gives an incorrect result:")
val_acc = evaluate(model, valid_loader)
pprint(val_acc)

print("test set examples the model gives an incorrect result:")
test_acc = evaluate(model, test_loader)
pprint(test_acc)
```

QUESTION:

Another not-so-smart question of the day: for an input sequence $x_1, x_2, \ldots, x_L$, run the forward pass all the way through the decoder, stopping just before the pre-logits/head/linear layer, and for simplicity squeeze out the batch dimension (only 1 sample). The pre-logits then have shape `[L, D]`, where `L` is the sequence length and `D` is the hidden embedding dimension. Am I right to say that the last row of `[L, D]`, being the last token's representation, holds information about the full context of all previous tokens?

1. Yes. Because of the causal mask, position $t$ attends only to positions $1, \ldots, t$, so the last row is a function of every token in the input sequence, while earlier rows only see their own prefixes. This is why the tutorial uses only the last row's prediction as the next predicted token/word, given all previous tokens.

> It is important to know that the last token, i.e. the last row of `[L, D]`, is a function of all previous tokens: its attention row has no future positions to hide, so it is effectively unmasked.
> So if confused, just remember: the last row of the pre-logits, corresponding to the last token in the input sequence, is a function of all previous tokens.
> That row holds the information (context) of all previous tokens, so we can say it is conditioned on all previous tokens.
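
A quick numerical check of this claim, using a single causal self-attention layer with random weights as a simplified stand-in for the full decoder (sizes and seed are arbitrary): perturbing the last token changes only the last output row, while perturbing the first token changes the last row as well.

```python
import torch

torch.manual_seed(0)
L, D = 5, 8
W_q, W_k, W_v = (torch.randn(D, D) / D ** 0.5 for _ in range(3))

def causal_self_attention(z: torch.Tensor) -> torch.Tensor:
    """Single-head scaled dot-product attention with a causal mask; z is [L, D]."""
    q, k, v = z @ W_q, z @ W_k, z @ W_v
    scores = q @ k.T / D ** 0.5                                   # [L, L]
    future = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(future, float("-inf"))            # hide future positions
    return scores.softmax(dim=-1) @ v                             # [L, D]

z = torch.randn(L, D)
out = causal_self_attention(z)

z_last = z.clone()
z_last[-1] += 1.0                                # perturb only the last token
out_last = causal_self_attention(z_last)

z_first = z.clone()
z_first[0] += 1.0                                # perturb only the first token
out_first = causal_self_attention(z_first)

print(torch.allclose(out[:-1], out_last[:-1]))   # True: earlier rows never see the last token
print(torch.allclose(out[-1], out_first[-1]))    # False: the last row depends on the first token
```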

Train accuracy (debug run): 0.021484375, 0.0185546875

Non-debug run:

```
correct equation: 24+86=110
predicted:        24+86=100
correct equation: 84+26=110
predicted:        84+26=100
validataion set examples the model gives an incorrect result:
test set examples the model gives an incorrect result:
train_size: 7000, train_loss: 0.013642309483007662,
                val_loss: 0.0008140208410623018, test_loss: 0.00040599027124699205,
                test_acc: 1.0, val_acc: 1.0, train_acc: 0.9996448863636364
```

### <a id='toc1_10_5_'></a>[MultiHeadAttention](#toc0_)

We start off by understanding the rationale of the following block:

```python
Q = self.W_Q(query).contiguous() # Z @ W_Q -> BxLxD @ DxD = BxLxD
K = self.W_K(key).contiguous()   # Z @ W_K
V = self.W_V(value).contiguous() # Z @ W_V
```

#### <a id='toc1_10_5_1_'></a>[A Primer](#toc0_)

In the context of the Transformer architecture and self-attention mechanism, the
matrices $\mathbf{W}^{Q}, \mathbf{W}^{K},$ and $\mathbf{W}^{V}$ are learnable
parameters designed to project the input embeddings $\mathbf{Z}$ into distinct
subspaces tailored for attention calculations. Let's explore their purpose and
their resulting transformations:

1. **The Role of Weights**:

   - $\mathbf{W}^{Q}$: Projects input embeddings into a query subspace,
     determining the type of information each token seeks from others.
   - $\mathbf{W}^{K}$: Projects the embeddings into a key subspace, highlighting
     the token features that other tokens would search for.
   - $\mathbf{W}^{V}$: Projects the embeddings into a value subspace, carrying
     the token content that is aggregated according to the attention scores.

2. **Intuitive & Mathematical Interpretations**:

   - **Query Transformation** ($\mathbf{Z} \mathbf{W}^{Q}$): Intuitively, it
     tailors the raw embeddings to query the rest of the sequence effectively.
     Mathematically, it is a linear map from the embedding space into the query
     space (a general linear map, which can rotate, scale, shear, or project),
     learned to emphasize the aspects relevant to querying.

   - **Key Transformation** ($\mathbf{Z} \mathbf{W}^{K}$): Intuitively, it
     accentuates token features that other tokens might seek. Mathematically,
     it's another linear transformation emphasizing aspects that make tokens
     searchable.

   - **Value Transformation** ($\mathbf{Z} \mathbf{W}^{V}$): Intuitively, it
     prepares tokens to share their intrinsic content when beckoned by the
     attention mechanism. Mathematically, it's a linear transformation
     accentuating token content aspects.

3. **Creating Q, K, V**:

   - $\mathbf{Q} = \mathbf{Z} \mathbf{W}^{Q}$
   - $\mathbf{K} = \mathbf{Z} \mathbf{W}^{K}$
   - $\mathbf{V} = \mathbf{Z} \mathbf{W}^{V}$

   These operations recast the embedded tokens into roles for the attention
   mechanism:

   - $\mathbf{Q}$: Information seekers. Each query row describes what its token
     is looking for; the product $\mathbf{Q}\mathbf{K}^{\top}$ scores how much
     each part of the input should be attended to.
   - $\mathbf{K}$: Information advertisers. Each key row describes what its
     token offers to be matched on; every query is compared against every key
     via a dot product to produce these relevance scores.
   - $\mathbf{V}$: Information providers. The values contain the content that
     needs to be retrieved, and once we have the attention weights, we know how
     much of each value to retrieve and combine to form the output.

   Mathematically, the resulting matrices ($\mathbf{Q}, \mathbf{K}, \mathbf{V}$)
   have rows that represent different aspects (querying, key, value) of the
   original tokens.

4. **Relevance to Self-Attention**:

   The transformations set the stage for the attention score calculation. Each
   query vector in $\mathbf{Q}$ computes its similarity (a dot product) with
   every key vector in $\mathbf{K}$. The resulting score matrix says how much
   weight each token gives to every other token in the sequence.

   Specifically, $\mathbf{Q}\mathbf{K}^{\top}$ measures how well each token
   (query) aligns with each token (key), i.e. the relevance of each word to
   every other word in the sequence. In a decoder such as GPT, a causal mask
   additionally restricts each token to itself and the tokens before it.

   The scores are scaled by $1/\sqrt{d_k}$, where $d_k$ is the key dimension,
   and normalized row-wise with softmax to obtain the attention weights:

   $$
   \operatorname{Attention}(\mathbf{Q}, \mathbf{K}, \mathbf{V}) = \operatorname{softmax}\left(\frac{\mathbf{Q}\mathbf{K}^{\top}}{\sqrt{d_k}}\right)\mathbf{V}
   $$

   These weights determine how the value vectors in $\mathbf{V}$ are
   aggregated: each output row is a weighted combination of value rows and so
   carries contextually relevant information from the sequence. This enriched
   output feeds into subsequent transformer layers for further processing.

Overall, through the learned $\mathbf{W}^{Q}, \mathbf{W}^{K},$ and $\mathbf{W}^{V}$
matrices, the transformer learns which inter-token relationships to focus on,
enabling the model to capture intricate contextual nuances within a given
sequence.
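
The shapes involved can be traced with a small multi-head sketch. It assumes illustrative sizes (`B=2, L=5, D=16, H=4`) and plain `nn.Linear` projections; it is not the repository's `MultiHeadAttention` class. In this sketch `.contiguous()` is only needed when the heads are merged back, because `transpose` returns a non-contiguous view that `view` cannot flatten, whereas the outputs of `nn.Linear` are already contiguous.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, D, H = 2, 5, 16, 4       # batch, sequence length, model dim, heads (illustrative)
d_k = D // H

W_Q, W_K, W_V = nn.Linear(D, D), nn.Linear(D, D), nn.Linear(D, D)
Z = torch.randn(B, L, D)

Q, K, V = W_Q(Z), W_K(Z), W_V(Z)                     # each [B, L, D]

def split_heads(t: torch.Tensor) -> torch.Tensor:
    # [B, L, D] -> [B, H, L, d_k]
    return t.view(B, L, H, d_k).transpose(1, 2)

Qh, Kh, Vh = split_heads(Q), split_heads(K), split_heads(V)
scores = Qh @ Kh.transpose(-2, -1) / d_k ** 0.5      # [B, H, L, L]
causal = torch.triu(torch.ones(L, L, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(causal, float("-inf"))   # decoder: no peeking ahead
weights = scores.softmax(dim=-1)                     # each row sums to 1
context = weights @ Vh                               # [B, H, L, d_k]
out = context.transpose(1, 2).contiguous().view(B, L, D)  # merge heads back to [B, L, D]
print(out.shape)  # torch.Size([2, 5, 16])
```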

<img src="transformer.png" width="600">


### <a id='toc1_10_6_'></a>[AddNorm (Residual Connection + Layer Normalization)](#toc0_)

- https://www.d2l.ai/chapter_attention-mechanisms-and-transformers/transformer.html#residual-connection-and-layer-normalization
- https://nlp.seas.harvard.edu/annotated-transformer

#### <a id='toc1_10_6_1_'></a>[Residual Block](#toc0_)

A residual block takes an input $X$ and a sub-layer (or function) $f$, and computes $X + f(X)$.

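```python
from typing import Callable

import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(
        self,
        x: torch.Tensor,
        sublayer: Callable[[torch.Tensor], torch.Tensor],
    ) -> torch.Tensor:
        # Skip connection: add the input back to the sub-layer's output.
        return x + sublayer(x)
```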

The intuition behind a residual block is to facilitate the training of deeper networks by providing a "shortcut" or "skip connection" that allows the gradient to be directly backpropagated to earlier layers. Essentially, in a standard deep learning model, each layer transforms its input. As the network depth increases, these transformations can degrade the network's performance, mainly due to the vanishing or exploding gradient problems. This makes it challenging to train very deep networks.

The residual block addresses this problem by adding the original input back to the output of the layer: if $x$ is the input and $F(x)$ is the layer's transformation (the sub-layer $f$ above), the block computes $F(x) + x$ instead of just $F(x)$.

This architecture has a few advantages:

1. **Easier Learning**: During training, if the best transformation is an identity map (i.e., the output should be the same as the input), the residual block can easily learn this. The layers in $F(x)$ only need to learn to approximate zero in this case, which is generally easier than learning an identity map in a traditional stack of layers.

2. **Mitigating Vanishing/Exploding Gradients**: The skip connections provide an unobstructed path for the gradients to flow, which can help mitigate the vanishing or exploding gradient problems in very deep networks.

3. **Enabling Deeper Networks**: Because of the above advantages, residual blocks make it possible to train very deep networks effectively. Deep networks can represent very complex functions, which can be advantageous for many tasks.

4. **Parameter Efficiency**: The skip connection itself adds no parameters (it is a plain addition), and residual networks have often matched the performance of plain deep networks with fewer parameters.

In summary, the residual block is a simple yet effective idea that has enabled the training of much deeper networks, thereby pushing the boundaries of what is achievable in various machine learning tasks.
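
Two of these claims can be checked directly. In the sketch below, an arbitrary `nn.Linear` stands in for the sub-layer. Zeroing its parameters makes the block an exact identity map, and the gradient of the output sum with respect to the input is still all ones, because it flows through the skip connection untouched.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
sublayer = nn.Linear(4, 4)
nn.init.zeros_(sublayer.weight)   # F(x) = 0 for every x
nn.init.zeros_(sublayer.bias)

x = torch.randn(3, 4, requires_grad=True)
y = x + sublayer(x)               # residual block: x + F(x)

print(torch.equal(y, x))          # True: the block reduces to the identity
y.sum().backward()
print(x.grad)                     # all ones: dy/dx = I + dF/dx, and dF/dx = 0 here
```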

#### <a id='toc1_10_6_2_'></a>[Layer Normalization](#toc0_)

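Layer normalization normalizes each token's feature vector independently. Given a feature tensor $X$ of shape $[B, L, D]$ (where $B$ is the batch size, $L$ is the sequence length, and $D$ is the feature dimension), the mean and variance are computed over the last dimension only, separately for every position $(b, l)$. For one feature vector $\mathbf{x} \in \mathbb{R}^{D}$:

$$
\operatorname{LayerNorm}(\mathbf{x}) = \boldsymbol{\gamma} \odot \frac{\mathbf{x} - \mu}{\sqrt{\sigma^{2} + \epsilon}} + \boldsymbol{\beta}, \qquad \mu = \frac{1}{D} \sum_{d=1}^{D} x_d, \qquad \sigma^{2} = \frac{1}{D} \sum_{d=1}^{D} \left(x_d - \mu\right)^{2}
$$

where $\boldsymbol{\gamma}, \boldsymbol{\beta} \in \mathbb{R}^{D}$ are learnable parameters and $\epsilon$ is a small constant for numerical stability.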

```python
class LayerNorm(nn.Module):
    def __init__(self, feature_dim: int, eps: float = 1e-6) -> None:
        super().__init__()
        # fmt: off
        self.gamma = nn.Parameter(torch.ones(feature_dim))   # learnable scale
        self.beta  = nn.Parameter(torch.zeros(feature_dim))  # learnable shift
        self.eps   = eps
        # fmt: on

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Statistics over the feature dimension only; biased variance with eps
        # inside the square root, as in the formula above.
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        return self.gamma * (x - mean) / torch.sqrt(var + self.eps) + self.beta
```
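
As a sanity check, the formula can be compared with PyTorch's built-in `torch.nn.functional.layer_norm`. The sketch below (arbitrary shapes and seed) implements the formula directly, and also shows that a variant using `x.std`, which applies Bessel's correction by default and adds `eps` outside the square root, does not give the same result.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, L, D = 2, 3, 4
eps = 1e-5
x = torch.randn(B, L, D)
gamma, beta = torch.randn(D), torch.randn(D)

def layer_norm_formula(x, gamma, beta, eps):
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, keepdim=True, unbiased=False)   # biased: divide by D
    return gamma * (x - mean) / torch.sqrt(var + eps) + beta

def layer_norm_std_variant(x, gamma, beta, eps):
    mean = x.mean(dim=-1, keepdim=True)
    std = x.std(dim=-1, keepdim=True)                   # unbiased: divide by D - 1
    return gamma * (x - mean) / (std + eps) + beta

reference = F.layer_norm(x, normalized_shape=(D,), weight=gamma, bias=beta, eps=eps)
print(torch.allclose(layer_norm_formula(x, gamma, beta, eps), reference, atol=1e-6))      # True
print(torch.allclose(layer_norm_std_variant(x, gamma, beta, eps), reference, atol=1e-6))  # False
```

Note that `nn.LayerNorm` defaults to `eps=1e-5`, whereas the `LayerNorm` class above defaults to `1e-6`.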

#### <a id='toc1_10_6_3_'></a>[Combining Both](#toc0_)

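Finally, the residual connection and layer normalization can be combined into a single `AddNorm` block. This follows the post-norm arrangement of the original Transformer, $\operatorname{LayerNorm}(x + \operatorname{Dropout}(\operatorname{Sublayer}(x)))$: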

```python
class AddNorm(nn.Module):
    def __init__(self, feature_dim, dropout_rate):
        super(AddNorm, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)
        self.layer_norm = LayerNorm(feature_dim)

    def forward(self, x, sublayer_output):
        return self.layer_norm(x + self.dropout(sublayer_output))
```

This `AddNorm` class applies dropout to the sub-layer's output, adds the result to the original input (the residual connection), and then normalizes the sum with the `LayerNorm` defined above. Unlike `ResidualBlock`, it receives the sub-layer's *output* rather than the sub-layer itself, so the caller runs the sub-layer first and passes both tensors in.

### <a id='toc1_10_7_'></a>[How Loss is Computed?](#toc0_)

The unreduced loss for the Cross Entropy calculation is given by:

$$
\mathcal{L}(\mathcal{X}, \mathcal{Y}) = \{l_1, \ldots, l_N\}^\top, \quad l_n = -\mathcal{W}_{\mathcal{Y}_n} \cdot \log \left( \frac{\exp(\mathcal{X}_{n, \mathcal{Y}_n})}{\sum_{c=1}^\mathcal{C} \exp(\mathcal{X}_{n, c})} \right) \cdot \mathbb{1}\{\mathcal{Y}_n \neq \text{ignore\_index}\}
$$

where:

- $\mathcal{X}$ is the input tensor of logits, with shape
  $[B, \mathcal{C}, d_1, \ldots, d_K]$, where $\mathcal{C}$ is the number of
  classes and $d_1, \ldots, d_K$ are any additional dimensions (PyTorch expects
  the class dimension second).
- $\mathcal{Y}$ is the target tensor of class indices, with shape
  $[B, d_1, \ldots, d_K]$.
- $\mathcal{W}$ is a tensor of per-class weights (all ones when no `weight` is
  passed).
- $N$ is the product of the batch size and any additional dimensions, i.e.,
  $N = B \times d_1 \times \ldots \times d_K$. It spans all elements in the
  batch and across the additional dimensions, effectively flattening these into
  a single dimension for the loss calculation.

For the reduced loss, the calculation depends on the reduction method (`'mean'`
or `'sum'`). The sum reduction simply adds up the $l_n$; the mean reduction
divides that sum by the total weight of the non-ignored targets, which is just
the number of non-ignored targets when no class weights are given:

$$
\mathcal{L}(\mathcal{X}, \mathcal{Y}) =
\begin{cases}
\dfrac{\sum_{n=1}^{N} l_n}{\sum_{m=1}^{N} \mathcal{W}_{\mathcal{Y}_m} \cdot \mathbb{1}\{\mathcal{Y}_m \neq \text{ignore\_index}\}}, & \text{if reduction = 'mean'}\\
\sum_{n=1}^{N} l_n, & \text{if reduction = 'sum'}
\end{cases}
$$

This formulation emphasizes that the loss is computed element-wise, once per
target position in $\mathcal{Y}$, and then either summed or averaged depending
on the chosen reduction method. The indicator
$\mathbb{1}\{\mathcal{Y}_n \neq \text{ignore\_index}\}$ ensures that positions
whose target equals `ignore_index` contribute to neither the numerator nor the
denominator.
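
A minimal check of the `'mean'` formula, assuming arbitrary shapes, class weights, and `ignore_index=-100` (PyTorch's default): computing $l_n$ by hand and dividing by the total weight of the non-ignored targets reproduces `nn.CrossEntropyLoss`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, C = 6, 4
logits = torch.randn(N, C)
targets = torch.tensor([0, 3, 1, -100, 2, 3])      # one position is ignored
weights = torch.tensor([1.0, 2.0, 0.5, 1.5])       # per-class weights W_c

criterion = nn.CrossEntropyLoss(weight=weights, ignore_index=-100, reduction="mean")
loss_torch = criterion(logits, targets)

keep = targets != -100                             # indicator 1{Y_n != ignore_index}
log_probs = logits.log_softmax(dim=-1)
safe_targets = targets.clamp(min=0)                # any valid index; ignored rows get weight 0 below
w_n = weights[safe_targets] * keep                 # W_{Y_n} * 1{...}
l_n = -w_n * log_probs[torch.arange(N), safe_targets]
loss_manual = l_n.sum() / w_n.sum()                # divide by total non-ignored weight

print(torch.allclose(loss_torch, loss_manual))     # True
```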


1. **Define the Loss Function**: `nn.CrossEntropyLoss` in PyTorch expects:
    - input logits of shape `[N, C, d1, d2, ..., dK]`, where `N` is the batch
      size, `C` is the number of classes, and `d1` to `dK` are optional
      additional dimensions;
    - targets of shape `[N, d1, d2, ..., dK]`.

    Let's first look at a simplified example, as in image classification: each
    target is a single integer class label, and each sample's logits are a
    vector of length `C` (the number of classes).

```python
criterion = nn.CrossEntropyLoss(reduction="mean")
targets = torch.tensor([1, 0, 0, 0])  # sample 0 is class 1; samples 1, 2 and 3 are class 0
logits  = torch.tensor([[0.1, 0.9], [0.9, 0.1], [0.8, 0.2], [0.3, 0.7]])  # [N=4, C=2]
loss    = criterion(logits, targets)
pprint(loss)
```

Here things are simple, because the target is a single integer representing the class label, and the input logits are a vector of length `C` (the number of classes).

The confusion arises when the target is a sequence of integers, as in the case of sequence-to-sequence prediction. In this case, the target is a sequence of integers representing the class labels, and the input logits are a sequence of vectors of length `C` (the number of classes).

Let's walk through a concrete example with the following settings:

- Batch size: 2
- Sequence length: 3
- Number of classes/Vocab size: 4
- Targets is of shape: `[B, L] = [2, 3]`
- Logits is of shape: `[B, L, V] = [2, 3, 4]` where `V` is `C` in the above definition.

```python
# fmt: off
rng        = torch.Generator().manual_seed(config.global_config.seed)

B, L, V    = 2, 3, 4                                                   # B = batch size, L = sequence length, V = vocab size

logits     = torch.randn(B, L, V, generator=rng)                       # logits from the head
targets    = torch.randint(low=0, high=V, size=(B, L), generator=rng)  # targets are the labels
# fmt: on

pprint(logits)
pprint(targets)
pprint(logits[0])   # logits for the first sequence [L=3, V=4]
pprint(targets[0])  # targets for the first sequence [L=3]
```

We establish some conceptual understanding first:

- Each sample in the batch has the following characteristics:
    - Denote `target` and `logit` as the target and logits of one particular sample in the batch.
    - `target` has shape `[L] = [3]`; each element is the class/vocab label of one token in the sequence.
    - `logit` has shape `[L, V] = [3, 4]`; each row holds the logits for one token in the sequence.
    - We therefore compare each row of `logit` with the corresponding element of `target` to compute the loss: each row is the prediction for one position, and the matching element of `target` is the ground truth for that position.
    - Intuitively, each sample contains many "sub-samples", one per token in the sequence. If you can visualize this, there should be no confusion.
- In code, we can proceed as follows:
    - Compute the loss for each token of each sample individually, then sum the results.
    - Reduction by mean then divides `total_loss` by the number of sub-samples. Although the batch technically contains only 2 samples, each token counts as its own sub-sample, so the divisor is `B * L`, where `B` is the batch size and `L` is the sequence length.

```python
criterion  = nn.CrossEntropyLoss(reduction="mean")

total_loss = 0
for b in range(B):
    for l in range(L):
        logit      = logits[b, l].unsqueeze(0)   # [1, V]: one token's logits
        target     = targets[b, l].unsqueeze(0)  # [1]: that token's label
        total_loss += criterion(logit, target)   # mean over one element = that token's loss

pprint(total_loss)
total_loss  = total_loss / (B * L)  # average over all B * L tokens
pprint(total_loss)
```

In PyTorch, however, a logits tensor of shape `[B, L, V]` must be permuted to
`[B, V, L]` to match the layout that `CrossEntropyLoss` expects, so that `V`
(vocab size) is treated as `C` (number of classes) and `L` (sequence length) as
one of the additional dimensions `d1, d2, ..., dK`.

All in all, if you understood the loop above, which computes the loss for each token of each sample individually, sums the results, and divides by the number of tokens to implement mean reduction, then the vectorized call below computes exactly the same thing.

```python
# Permute logits to shape [B, V, L] so the class dimension comes second
logits_permuted = logits.permute(0, 2, 1)

# Instantiate the CrossEntropyLoss.
# With reduction="mean", it averages the per-token losses over all B * L positions.
criterion = nn.CrossEntropyLoss(reduction="mean")

loss = criterion(logits_permuted, targets)
pprint(loss)
```
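
Permuting is not the only option. Another common idiom flattens the batch and sequence dimensions so that every token becomes one row of a plain `[N, C]` classification problem. A self-contained sketch (arbitrary shapes and seed) showing that both give the same mean loss:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, V = 2, 3, 4
logits = torch.randn(B, L, V)
targets = torch.randint(low=0, high=V, size=(B, L))

criterion = nn.CrossEntropyLoss(reduction="mean")
loss_permute = criterion(logits.permute(0, 2, 1), targets)            # [B, V, L] vs [B, L]
loss_flatten = criterion(logits.reshape(-1, V), targets.reshape(-1))  # [B*L, V] vs [B*L]

print(torch.allclose(loss_permute, loss_flatten))  # True
```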

#### Masking and Ignore Index

```python
# fmt: off
rng        = torch.Generator().manual_seed(config.global_config.seed)

B, L, V    = 2, 3, 4                                                   # B = batch size, L = sequence length, V = vocab size

logits     = torch.randn(B, L, V, generator=rng)                       # logits from the head
targets    = torch.randint(low=0, high=V, size=(B, L), generator=rng)  # targets are the labels
# fmt: on

pprint(logits)
pprint(targets)
pprint(logits[0])   # logits for the first sequence [L=3, V=4]
pprint(targets[0])  # targets for the first sequence [L=3]
```

```python
PAD_ = -123             # an arbitrary label value that the loss should ignore
targets[:, 0] = PAD_    # mark the first position of every sequence as ignored
targets
```

```python
criterion  = nn.CrossEntropyLoss(reduction="mean", ignore_index=PAD_)

NON_IGNORE_COUNT = 0

total_loss = 0
for b in range(B):
    for l in range(L):
        logit      = logits[b, l].unsqueeze(0)
        target     = targets[b, l].unsqueeze(0)
        if target.item() == PAD_:  # skip ignored positions entirely
            continue
        total_loss += criterion(logit, target)
        NON_IGNORE_COUNT += 1

pprint(total_loss)
total_loss  = total_loss / NON_IGNORE_COUNT
pprint(total_loss)
```

Note: we divide by `NON_IGNORE_COUNT` rather than `B * L` because, with
`ignore_index`, the mean reduction averages only over the non-ignored targets.

```python
# Permute logits to shape [B, V, L]
logits_permuted = logits.permute(0, 2, 1)

# Instantiate the CrossEntropyLoss.
# With ignore_index set, reduction="mean" averages only over targets != PAD_.
criterion  = nn.CrossEntropyLoss(reduction="mean", ignore_index=PAD_)

loss = criterion(logits_permuted, targets)
pprint(loss)
```
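
An equivalent formulation, used in some codebases instead of `ignore_index`, is to compute the unreduced per-token loss and average it under an explicit boolean mask. A self-contained sketch (arbitrary shapes and seed; `-123` mirrors `PAD_` above):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, L, V, PAD = 2, 3, 4, -123
logits = torch.randn(B, L, V)
targets = torch.randint(low=0, high=V, size=(B, L))
targets[:, 0] = PAD                                       # ignore the first position

loss_ignore = nn.CrossEntropyLoss(ignore_index=PAD)(logits.permute(0, 2, 1), targets)

mask = targets != PAD                                     # True where the loss counts
per_token = nn.CrossEntropyLoss(reduction="none")(
    logits.permute(0, 2, 1), targets.masked_fill(~mask, 0)  # PAD replaced by a valid index
)                                                         # [B, L]
loss_mask = (per_token * mask).sum() / mask.sum()         # average over kept tokens only

print(torch.allclose(loss_ignore, loss_mask))            # True
```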

#### Why mask our target in Adder?

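Simply put, we do not care what the model predicts for anything before the equal sign.

For example, take the equation

```
12+97=109
```

The shifted input/target pair used for next-token training is

```
x = [BOS,1,2,+,9,7,=,1,0,9]
y = [1  ,2,+,9,7,=,1,0,9,EOS]
```

Without masking, this pair would also ask the model to predict `1` given `BOS`, `2` given `BOS,1`, and so on up to `=` given `BOS,1,2,+,9,7`. The operand digits are random, so those targets carry no useful signal, and `+` and `=` are fixed by the format. What we want is for the model to predict what comes after `=`, so all the earlier targets are ignored.

By masking out (ignoring) every target up to and including `=`, so that only `1`, `0`, `9`, and `EOS` contribute to the loss, you guide the model to focus on learning the result of the addition.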
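
One way to see what "ignoring" means for training: positions whose target is ignored receive exactly zero gradient, so the model is never pushed toward any particular prediction there. A sketch with a random `[L, V]` logits tensor standing in for the model output; the vocabulary size, the ignore value `-100`, and the assumption that the first 6 targets belong to the prompt are illustrative.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
L, V, IGNORE = 10, 17, -100
logits = torch.randn(L, V, requires_grad=True)   # stand-in for the model's output
targets = torch.randint(low=0, high=V, size=(L,))
targets[:6] = IGNORE                             # targets up to and including "=" are ignored

loss = nn.CrossEntropyLoss(ignore_index=IGNORE)(logits, targets)
loss.backward()

print(logits.grad[:6].abs().sum().item())       # 0.0: no learning signal for prompt positions
print(logits.grad[6:].abs().sum().item() > 0)   # True: the answer positions drive the update
```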


## <a id='toc1_11_'></a>[Potential to use Module Dict?](#toc0_)

```python
class ModelModuleDict(nn.Module):
    def __init__(self):
        super().__init__()
        # Same layers as the nn.Sequential model below, but stored under names.
        self.layers = nn.ModuleDict({
            'fc1': nn.Linear(2, 5),
            'relu': nn.ReLU(),
            'fc2': nn.Linear(5, 1)
        })

    def forward(self, x):
        # nn.ModuleDict preserves insertion order, so this applies fc1 -> relu -> fc2.
        for layer in self.layers.values():
            x = layer(x)
        return x

# Initialize a random tensor as input
input_tensor = torch.randn(1, 2)
```

```python
seed_all(1, seed_torch=True)
model_sequential = nn.Sequential(
    nn.Linear(2, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)
# Forward pass using nn.Sequential model
model_sequential(input_tensor)
```


```python
seed_all(1, seed_torch=True)
model_moduledict   = ModelModuleDict()
model_moduledict(input_tensor)
```
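
With the same seed, both models create their parameters in the same order, so they should produce identical outputs. What differs is how the parameters are named, which matters when loading a `state_dict` saved from one form into the other. A self-contained check, using `torch.manual_seed` in place of the notebook's `seed_all` helper:

```python
import torch
import torch.nn as nn

class ModelModuleDict(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleDict({"fc1": nn.Linear(2, 5), "relu": nn.ReLU(), "fc2": nn.Linear(5, 1)})

    def forward(self, x):
        for layer in self.layers.values():
            x = layer(x)
        return x

x = torch.randn(1, 2)
torch.manual_seed(1)
sequential = nn.Sequential(nn.Linear(2, 5), nn.ReLU(), nn.Linear(5, 1))
torch.manual_seed(1)
module_dict = ModelModuleDict()

print(torch.equal(sequential(x), module_dict(x)))  # True: same init order, same weights
print(list(sequential.state_dict()))   # ['0.weight', '0.bias', '2.weight', '2.bias']
print(list(module_dict.state_dict()))  # ['layers.fc1.weight', 'layers.fc1.bias', 'layers.fc2.weight', 'layers.fc2.bias']
```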

## <a id='toc1_12_'></a>[Training with GPT-like Model](#toc0_)

If you're working with a GPT-like model, which is a decoder-only architecture, the training mechanics differ slightly compared to the encoder-decoder models like seq2seq. In a GPT-style model, the entire sequence (input and output) is provided to the model at once, and each token is predicted based on the tokens that came before it. The model is still autoregressive, but there's no separate encoder to produce an intermediate representation; the "encoding" is effectively built into the ongoing autoregressive decoding process.

In this project, with equations such as `90+38=128`, the whole sequence `<BOS>90+38=128` is fed to the model in a single pass, the targets are the same sequence shifted left by one (ending in `<EOS>`), and the targets that belong to the prompt `90+38=` are masked so that only `128<EOS>` contributes to the loss. This is still teacher forcing, just as in the decoder of an encoder-decoder model: every position is conditioned on the ground-truth preceding tokens, not on the model's own predictions.

At inference time, given `90+38=`, the model should then predict `1`, `2`, `8` (and finally `<EOS>`) in succession, one token at a time.

### <a id='toc1_12_1_'></a>[Loss Computation](#toc0_)

For training a GPT-like model, you'd usually use a standard loss function such as cross-entropy for each token's prediction: the model's predicted distribution at each position is compared with the actual next token in the target sequence. The per-token losses are then averaged (or summed) over the non-ignored tokens of the sequence or batch, depending on your implementation.

### <a id='toc1_12_2_'></a>[Example](#toc0_)

In a GPT-like model, each token in the sequence is used to predict the next token. The model takes a sequence of tokens and outputs, for every position, logits over the vocabulary for the next token, each computed from that position and all positions before it. The loss is then computed between these predictions and the target sequence.

Let's take a closer look at an example:

- The original tensor: `[15, 9, 0, 10, 3, 8, 13, 1, 2, 8, 14]`, which corresponds to `<SOS>90+38=128<EOS>`
- Input tensor: `[15, 9, 0, 10, 3, 8, 13, 1, 2, 8]`, which corresponds to `<SOS>90+38=128` without `<EOS>`
- Target tensor (the original shifted left by one): `[9, 0, 10, 3, 8, 13, 1, 2, 8, 14]`
- Target tensor after masking (every target up to and including `=` replaced by `16`): `[16, 16, 16, 16, 16, 16, 1, 2, 8, 14]`
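
The input/target construction above can be sketched in a few lines. The ids are read off the example rather than taken from the repository code (`15` is the start token, `13` is `=`, `14` is `<EOS>`, and `16` is the id used to mask targets), and the helper name `make_input_and_target` is hypothetical.

```python
import torch

EQUAL_ID, MASK_ID = 13, 16  # ids read off the example above

def make_input_and_target(sequence: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Shift a full <SOS>...<EOS> sequence into (input, target) and mask the prompt targets."""
    inputs = sequence[:-1]            # drop <EOS>
    targets = sequence[1:].clone()    # the next token at every position
    eq_pos = (sequence == EQUAL_ID).nonzero()[0].item()
    # "=" sits at index eq_pos - 1 of targets, so [:eq_pos] masks everything up to and including it.
    targets[:eq_pos] = MASK_ID
    return inputs, targets

seq = torch.tensor([15, 9, 0, 10, 3, 8, 13, 1, 2, 8, 14])
x, y = make_input_and_target(seq)
print(x.tolist())  # [15, 9, 0, 10, 3, 8, 13, 1, 2, 8]
print(y.tolist())  # [16, 16, 16, 16, 16, 16, 1, 2, 8, 14]
```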

During training:

1. **First Timestep**: The model takes `[15]` (i.e. `[<BOS>]`, if 15 is your BOS token) and predicts the next token, which in the data is `9`. But the target at this position is the masking token `16`, so this prediction should not be scored. `nn.CrossEntropyLoss` has an `ignore_index` argument (default `-100`); setting `ignore_index=16` (or masking with `-100` directly) tells the loss to skip every position whose ground truth is the mask value: it contributes nothing to the loss and is excluded from the mean. This lets the model focus on learning from the relevant part of the sequence while ignoring the masked portion.

2. **Second Timestep**: The model takes `[15, 9]` and predicts the next token, which should be `0`. Again, the target is a masked token `16`.

3. **...**

4. **Seventh Timestep**: The model takes `[15, 9, 0, 10, 3, 8, 13]` (which is `90+38=`) and predicts the next token. Now the target is `1`; there is no mask anymore, so the loss is computed between the predicted token and `1`.
5. **Eighth Timestep**: The model takes `[15, 9, 0, 10, 3, 8, 13, 1]` (which is `90+38=1`) and predicts the next token. Now the target is `2`, so the loss is computed between the predicted token and `2`.
   1. An important point for beginners (including me): in a GPT-like model trained on a task like this one, the model does not use its own predictions as input during training. It uses the original, ground-truth sequence instead. This is known as "teacher forcing": even if the model predicts a wrong token at some timestep, the input for subsequent timesteps is unchanged, and the model keeps receiving the ground-truth sequence throughout training.
   2. So if the model predicts `3` at the seventh timestep, where the ground truth is `1`, it simply incurs a higher loss for that prediction. The input for the eighth timestep is still the ground-truth sequence up to that point, regardless of what the model predicted at the seventh.
   3. This is still autoregressive: each prediction depends only on the tokens before it.
6. **Ninth Timestep**: The model takes `[15, 9, 0, 10, 3, 8, 13, 1, 2]` and predicts the next token, which is `8`.
7. **Tenth (Last) Timestep**: The model takes `[15, 9, 0, 10, 3, 8, 13, 1, 2, 8]` and predicts the next token, which is `14`, the `EOS`.
   1. Predicting `EOS` is necessary because, at inference time, it is the model's only signal that the answer is complete; without it, the model would not know when to stop.

This goes on until the entire sequence is processed. In practice the timesteps are not run one after another: a single forward pass with the causal mask produces all ten predictions at once, and the walkthrough above describes what each position is allowed to see. The model never "sees" the target token at the position it is predicting, only the tokens before it. The predictions are then compared with the target tokens to compute the loss, which is backpropagated to update the model weights.

### <a id='toc1_12_3_'></a>[Confusion: Training versus Inference](#toc0_)

The statement "it generates one token at a time and uses its own previously generated tokens as context for generating subsequent tokens" is generally true for GPT-like models during the inference stage, not during training. During inference (or generation), the model does indeed use its own previously generated tokens to produce the next token, since there is no ground truth sequence to rely on. In that case, if the model makes an incorrect prediction at a certain timestep, that incorrect token is used as part of the context for the following timestep.

During training, however, the model uses the ground-truth tokens of the preceding sequence as context for predicting each next token, as in the walkthrough above. This is teacher forcing: the ground truth, rather than the model's own predictions, provides the context.

So there's no contradiction, but the behavior is context-dependent:

- During training, the ground truth sequence is used for context.
- During inference, the model's own previously generated tokens are used for context.

Both approaches are consistent with the autoregressive nature of the model: in both cases, the token at each position is generated based on the tokens at all previous positions. The difference lies in whether those preceding tokens come from the ground truth (during training) or from the model's own previous outputs (during inference).

### <a id='toc1_12_4_'></a>[Training vs Inference](#toc0_)

In an autoregressive model like a Transformer decoder, the concept of "learning
the representation of the sequence as it goes" does not refer to the model
processing one token at a time during actual forward passes. Instead, it refers
to the model's ability to generate or predict one token at a time during
inference, while training on a full sequence in a batched manner.

During training:

- All tokens are processed in parallel for efficiency. This is possible because
  the entire sequence is known beforehand (it's the training data).
- The "autoregressive" property is enforced by using masks in the self-attention
  mechanism. This masking ensures that the prediction for each token can only
  depend on the tokens that precede it, not on future tokens, which the model
  will not have access to during inference. This is how the model learns the
  conditional probability distribution of each token given the previous tokens,
  despite the parallel processing of tokens.
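
A minimal sketch of how the mask lets all positions be processed in one parallel pass while keeping the autoregressive property. It assumes a single attention head with no learned projections (a simplification, not the project's attention module) and checks that changing the last token leaves the outputs at all earlier positions untouched.

```python
import torch

torch.manual_seed(0)

def causal_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Single-head self-attention without projections; x has shape [L, D]."""
    L, D = x.shape
    scores = x @ x.T / D**0.5                      # [L, L] query-key scores
    # Lower-triangular mask: position i may attend to positions 0..i only.
    causal_mask = torch.tril(torch.ones(L, L, dtype=torch.bool))
    scores = scores.masked_fill(~causal_mask, float("-inf"))
    weights = torch.softmax(scores, dim=-1)        # each row sums to 1
    return weights @ x                             # [L, D]

x = torch.randn(5, 8)          # one sequence: L=5 tokens, D=8 features
out = causal_self_attention(x)

# Perturb only the last token; outputs at positions 0..3 must not change.
x_changed = x.clone()
x_changed[-1] += 10.0
out_changed = causal_self_attention(x_changed)

print(torch.allclose(out[:-1], out_changed[:-1]))  # True
print(torch.allclose(out[-1], out_changed[-1]))    # False
```

All five positions are computed in a single call, yet each row of the output only ever mixes information from its own prefix.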

During inference:

- The model starts with an initial token (such as a start-of-sequence token) and
  generates the next token based on this single input.
- Then, the model uses both the initial token and the newly generated token to
  predict the third token, and so on.
- This process is sequential and each new token is predicted based on the
  previously generated tokens, creating a sequence one token at a time.
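
The inference loop can be sketched as greedy decoding. Here `toy_model` is a hypothetical stand-in for a trained GPT (it always favours "current token + 1"); the point is the loop itself, where only the last position's prediction is used and each predicted token is appended and fed back in as context.

```python
import torch

VOCAB_SIZE = 10  # hypothetical toy vocabulary of token ids 0..9

def toy_model(tokens: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for a GPT forward pass.

    Takes token ids of shape [L] and returns logits of shape [L, VOCAB_SIZE];
    at every position it simply favours (current token + 1) % VOCAB_SIZE.
    """
    next_ids = (tokens + 1) % VOCAB_SIZE
    return torch.nn.functional.one_hot(next_ids, VOCAB_SIZE).float()

@torch.no_grad()
def greedy_generate(model, start: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    tokens = start.clone()
    for _ in range(max_new_tokens):
        logits = model(tokens)                 # forward pass over the whole prefix
        next_token = logits[-1].argmax()       # only the last position's prediction is used
        tokens = torch.cat([tokens, next_token.view(1)])  # feed the prediction back in
    return tokens

print(greedy_generate(toy_model, torch.tensor([7]), max_new_tokens=4))
# tensor([7, 8, 9, 0, 1])
```

If the model picked a wrong token at some step, that token would become part of the context for every later step, which is exactly the difference from teacher forcing during training.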

So, when we say that the model learns the representation of the sequence as it
goes, we mean that the model is trained to handle sequences in such a way that
it can generate them one piece at a time, respecting the causal order inherent
to the task (e.g., language modeling). The parallel processing during training
does not contradict the autoregressive nature of the model; it is simply a
computational efficiency that is enabled by knowing the full sequence in
advance.


## <a id='toc1_13_'></a>[Questions](#toc0_)

### <a id='toc1_13_1_'></a>[Why Do Some Implementations Use `mask == 0`?](#toc0_)

The use of `mask == 0` in the `masked_fill` operation is a result of how the mask is constructed. Essentially, different implementations may represent masks differently:

1. **Boolean Masking with True/False**: In some implementations, the mask is a Boolean tensor where `True` denotes the positions to mask (set to negative infinity) and `False` the positions to keep. In such cases, you can pass the mask to `masked_fill` directly:

    ```python
    attention_scores = attention_scores.masked_fill(mask, float("-inf"))
    ```

    Here, if `mask[i][j]` is `True`, `attention_scores[i][j]` would be set to `-inf`.

2. **Integer Masking with 1/0**: In other implementations, the mask is an integer tensor where `1` denotes the positions to keep and `0` denotes the positions to mask. In such cases, the mask is first turned into a Boolean tensor with `mask == 0`, which is `True` exactly at the positions to mask, before being passed to `masked_fill`:

    ```python
    attention_scores = attention_scores.masked_fill(mask == 0, float("-inf"))
    ```

    Here, if `mask[i][j]` is `0`, `attention_scores[i][j]` would be set to `-inf`.

The core functionality—masking certain positions in the attention scores—is the same in both cases. The difference lies in how the mask tensor is constructed and interpreted. So, if you find an implementation using `mask == 0`, it's likely using an integer mask where `0` signifies positions to mask, whereas if it's directly using `mask`, it's probably a Boolean mask where `True` signifies positions to mask.
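
The two conventions can be checked side by side. In this sketch the scores and masks are made-up values; the integer mask is simply the complement of the Boolean one, so both calls mask the same positions. (In recent PyTorch versions `masked_fill` only accepts Boolean masks, which is one reason an integer mask is compared with `0` first.)

```python
import torch

scores = torch.tensor([[0.5, 1.0, 2.0],
                       [1.5, 0.2, 0.3]])

# Convention 1: Boolean mask, True = position to mask out.
bool_mask = torch.tensor([[False, False, True],
                          [False, True,  True]])
masked_a = scores.masked_fill(bool_mask, float("-inf"))

# Convention 2: integer mask, 1 = keep, 0 = mask out (complement of bool_mask).
int_mask = torch.tensor([[1, 1, 0],
                         [1, 0, 0]])
masked_b = scores.masked_fill(int_mask == 0, float("-inf"))

print(torch.equal(masked_a, masked_b))  # True
```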

### <a id='toc1_13_2_'></a>[Why Set the Masked Attention Scores to Negative Infinity?](#toc0_)


In the attention mechanism, particularly in the Scaled Dot-Product Attention, attention scores are computed for each query-key pair and then passed through a softmax function to obtain attention weights. These weights are used to take a weighted sum of the value vectors, resulting in the final output or the context vectors. The purpose of the mask is to prevent certain tokens (like padding tokens) from being attended to.

The reason for setting masked attention scores to negative infinity (`-inf`) lies in the properties of the softmax function:

1. **Softmax Behavior**: The softmax function transforms its input (the attention scores in this case) into a probability distribution. Mathematically, the softmax function for a given vector $x$ is defined as:

$$
\text{Softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{N} e^{x_j}}
$$

2. **Impact of Negative Infinity**: Since $e^{-\infty} = 0$, a position whose score is set to $-\infty$ receives exactly zero weight after the softmax, as long as at least one position in the same row stays unmasked (otherwise the denominator is zero and the result is undefined; in PyTorch it becomes `NaN`).

$$
x_i = -\infty \implies \text{Softmax}(x)_i = \frac{e^{-\infty}}{\sum_{j=1}^{N} e^{x_j}} = 0
$$

3. **Avoiding Unwanted Attention**: The point of setting these specific positions to `-inf` is to ensure that when softmax is applied, these positions get zero attention weights. This is a way of making sure that the model does not attend to the positions we've masked (like padding tokens or future tokens in the sequence, depending on the mask).

In summary, setting the masked attention scores to `-inf` and then passing them through a softmax effectively nullifies the contribution of the masked positions in the resulting attention-weighted sum of the value vectors. This is a commonly used trick to impose a certain structure (like masking out future information in the decoder) or to handle variable-length sequences with padding.
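
A quick numerical check of both points, using arbitrary example scores: masked entries get exactly zero weight, and a row in which every entry is masked turns into `NaN`.

```python
import torch

scores = torch.tensor([2.0, 1.0, float("-inf"), float("-inf")])
weights = torch.softmax(scores, dim=-1)
print(weights)  # last two entries are exactly 0.0; the first two sum to 1

# Caveat: if every entry in a row is masked, softmax has nothing to normalise over
# and returns NaN, so at least one position per row must stay unmasked.
fully_masked = torch.full((3,), float("-inf"))
print(torch.softmax(fully_masked, dim=-1))  # tensor([nan, nan, nan])
```

For a causal mask this never happens, because every position can at least attend to itself.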

### <a id='toc1_13_3_'></a>[Why Do We Need Both an Ignore Index in the Loss and a Negative Infinity Mask?](#toc0_)

Setting `ignore_index` in PyTorch's `CrossEntropyLoss` excludes certain target positions (like padding tokens) from the loss computation. However, the mask in the attention mechanism and the ignore index in the loss function serve different roles in the model, and they operate at different stages of the computational graph.

1. **Ignore Index in Loss Function**: The "ignore index" in the loss function ensures that the model's output at certain positions (typically corresponding to padding tokens) does not contribute to the loss. This happens at the very end of the forward pass, just before backpropagation begins.

2. **Mask in Attention Mechanism**: The mask in the attention mechanism, on the other hand, operates during the forward pass at the time when attention scores are computed. This is a more "internal" operation and ensures that certain positions do not contribute to the output at all, not just during the loss computation but actually in the intermediate representations (i.e., context vectors) that the model computes.

To put it another way, even if you're ignoring certain tokens in your loss calculation, those tokens can still influence the model's output unless they're masked out in the attention mechanism itself.

For example, consider a decoder in a sequence-to-sequence model:
- If you don't use a mask in the attention mechanism, future tokens could influence the output at the current timestep, which is not desirable.
- Even if you use an "ignore index" in your loss function, it doesn't prevent the model from "cheating" by peeking at the future tokens if they are not masked in the attention mechanism.

So in summary, using an "ignore index" in `CrossEntropyLoss` is not a replacement for using attention masks. Both have specific roles in the model, and they are often used together to ensure both that the model attends to the right tokens and that it is trained properly.
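
The effect of `ignore_index` can be verified directly. This sketch uses random logits and PyTorch's default ignore value of `-100`: the loss with two ignored positions equals the loss computed on the two real positions alone. Note that nothing here constrains *how* the logits at positions 0 and 1 were produced; preventing them from depending on later tokens is the job of the attention mask.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
IGNORE_INDEX = -100  # PyTorch's default ignore_index for cross_entropy

logits = torch.randn(4, 5)                               # 4 positions, vocabulary of 5
targets = torch.tensor([1, 3, IGNORE_INDEX, IGNORE_INDEX])  # last two positions are padding

loss_ignored = F.cross_entropy(logits, targets, ignore_index=IGNORE_INDEX)
loss_manual = F.cross_entropy(logits[:2], targets[:2])   # only the real positions

print(torch.allclose(loss_ignored, loss_manual))  # True
```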

### <a id='toc1_13_4_'></a>[Target and Preds/Logits Shape](#toc0_)

The target tensor for the cross-entropy loss function should typically have a shape of `[batch_size, sequence_length]` where each entry in the tensor is an integer representing the index of the true class (i.e., the actual word/token from the vocabulary) for that position in the sequence. Here `batch_size` refers to the number of sequences in each batch, and `sequence_length` is the length of each sequence.

Let's break it down step-by-step:

1. **Last Linear Layer of Decoder**: The last linear layer of the decoder maps each hidden state to `vocab_size` scores, so at a single time step its output has shape `[bs, vocab_size]`: for each example in the batch, one score per vocabulary entry. These values are unnormalized logits; a softmax turns them into a probability distribution over the next token.

2. **Target Shape**: The target tensor, by contrast, contains the actual tokens (as integer indices) at each position of each sequence in the batch. It does not need a `vocab_size` dimension because it is not a distribution; it holds the indices of the actual next tokens. Thus, it has shape `[bs, sequence_length]`.

3. **Cross-Entropy Loss**: Over the whole sequence, the logits have shape `[bs, sequence_length, vocab_size]` and the targets have shape `[bs, sequence_length]`. PyTorch's cross-entropy loss applies a log-softmax to the logits internally and then takes the negative log-likelihood of the target class. Note, however, that it expects the class dimension to come second, i.e. inputs of shape `(N, C)` or `(N, C, d_1, ...)`. The logits must therefore either be flattened to `[bs * sequence_length, vocab_size]` (with the targets flattened to `[bs * sequence_length]`) or permuted to `[bs, vocab_size, sequence_length]` before being passed in.

To sum up, the decoder produces `[bs, vocab_size]` logits per time step, i.e. `[bs, sequence_length, vocab_size]` for the whole sequence, while the targets have shape `[bs, sequence_length]`; reshape them in one of the two ways above before calling the loss.
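
Both ways of lining up the shapes give the same loss. The sizes below are arbitrary toy values; passing the `[B, L, V]` logits unchanged would instead make PyTorch treat `L` as the class dimension.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, L, V = 2, 4, 7  # batch size, sequence length, vocabulary size (toy values)

logits = torch.randn(B, L, V)            # decoder output: one score vector per position
targets = torch.randint(0, V, (B, L))    # one class index per position

# F.cross_entropy wants the class dimension second: (N, C) or (N, C, d1, ...).
loss_flat = F.cross_entropy(logits.reshape(B * L, V), targets.reshape(B * L))
loss_perm = F.cross_entropy(logits.permute(0, 2, 1), targets)  # [B, V, L] vs [B, L]

print(torch.allclose(loss_flat, loss_perm))  # True
```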

### <a id='toc1_13_5_'></a>[Why Do We Flatten the Logits and Targets?](#toc0_)

Flattening both the predicted logits and the target labels serves a specific purpose when using the cross-entropy loss function for sequence data. Let's dig into each component to understand why this is done:

#### <a id='toc1_13_5_1_'></a>[Background](#toc0_)

1. **Logits Tensor**: In a sequence model, you usually generate a sequence of logits for each item in your batch. The logits at each position form a vector of size `vocab_size`, which a softmax turns into a probability distribution over all possible tokens.

   Shape: `[batch_size, sequence_length, vocab_size]`

2. **Targets Tensor**: Your ground truth data, the `targets`, are integers representing the correct class labels (or tokens) at each sequence position.

   Shape: `[batch_size, sequence_length]`

#### <a id='toc1_13_5_2_'></a>[Traditional Loss Computation](#toc0_)

Typically, the cross-entropy loss between predicted probabilities and target labels for one data point is computed, and then you average over all data points. In sequence-to-sequence models, you can think of each position in the sequence as a separate data point.

#### <a id='toc1_13_5_3_'></a>[Why Flatten?](#toc0_)

1. **Batch and Sequence Unification**: The idea of flattening both logits and targets is to treat each `(batch, sequence_position)` pair as an independent data point. Instead of having a batch of sequences, you have a "flattened" batch of tokens. This converts the 3D logits tensor and 2D targets tensor into 2D and 1D tensors, respectively, which is exactly the `(N, C)` and `(N,)` layout that PyTorch's cross-entropy loss expects.

2. **Efficiency**: Flattening is a cheap reshape (a view when the tensors are contiguous), and it lets a single vectorized call compute the loss for every token at once instead of looping over sequences and positions.

3. **Alignment**: The key is to ensure that each row in the flattened logits corresponds to the same position in the flattened targets. This alignment is crucial for the correct computation of the loss.

#### <a id='toc1_13_5_4_'></a>[Step-by-step Flattening](#toc0_)

1. **Logits Flattening**: `logits.view(-1, logits.size(-1))` will take the 3D tensor `[batch_size, seq_length, vocab_size]` and reshape it into a 2D tensor of shape `[batch_size * seq_length, vocab_size]`.

2. **Targets Flattening**: `targets.view(-1)` will take the 2D tensor `[batch_size, seq_length]` and convert it into a 1D tensor of shape `[batch_size * seq_length]`.

3. **Loss Calculation**: Both flattened tensors are then used in the cross-entropy loss function. The loss between each row in the flattened logits and the corresponding element in the flattened targets is computed.

By flattening the tensors this way, you maintain the correspondence between each logit and its corresponding target, enabling you to correctly compute the loss for each token across all sequences and batches.
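
To see that flattening really treats every `(batch, position)` pair as one data point, this sketch (random toy tensors) compares the flattened loss with an explicit loop that scores each token separately and averages.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, L, V = 2, 3, 5
logits = torch.randn(B, L, V)
targets = torch.randint(0, V, (B, L))

# Flattened: every (batch, position) pair becomes one row / one label.
loss_flat = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))

# Explicit loop over every (batch, position) pair, for comparison only.
per_token = [
    F.cross_entropy(logits[b, t].unsqueeze(0), targets[b, t].unsqueeze(0))
    for b in range(B)
    for t in range(L)
]
loss_loop = torch.stack(per_token).mean()

print(torch.allclose(loss_flat, loss_loop))  # True
```

One practical caveat: if the targets come from slicing, e.g. `tokens[:, 1:]`, they are not contiguous and `.view(-1)` raises an error, whereas `.reshape(-1)` works in either case.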

### <a id='toc1_13_6_'></a>[Why Do We Sometimes Unsqueeze Masks?](#toc0_)

The `unsqueeze` operation is used to add an additional dimension to the tensor. In attention mechanisms, particularly the scaled dot-product attention used in models like the Transformer, the masks usually need to have the same number of dimensions as the attention logits for proper broadcasting.

For instance, let's say your source tensor (`src`) has a shape of $B \times L$ where $B$ is the batch size and $L$ is the sequence length. The attention logit tensor resulting from the query-key dot product would then have shape $B \times N \times L \times L$, where $N$ is the number of attention heads.

The mask needs to align with the $L \times L$ dimensions of this 4D tensor. In order to accomplish that, you add singleton dimensions to make it compatible with the attention logit tensor. By unsqueezing the mask tensor from $B \times L$ to $B \times 1 \times 1 \times L$, you enable broadcasting such that the mask effectively gets expanded to $B \times N \times L \times L$ during the attention calculation, perfectly aligning with the attention logits.

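For example, in the Annotated Transformer the line

```python
self.src_mask = (src != pad).unsqueeze(-2)  # [B, L] -> [B, 1, L]
```

adds one singleton dimension, converting the shape from $B \times L$ to $B \times 1 \times L$. The multi-head attention module then calls `mask.unsqueeze(1)` so that the same mask is applied to all heads, giving $B \times 1 \times 1 \times L$ for proper broadcasting during the attention computations.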
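
The two unsqueeze steps and the resulting broadcast can be checked on toy data. The padding id `0` and all sizes here are hypothetical values chosen for the example.

```python
import torch

B, N, L = 2, 4, 5   # batch size, number of heads, sequence length
PAD = 0             # hypothetical padding id for this toy example

src = torch.tensor([[3, 7, 2, PAD, PAD],
                    [5, 1, 8, 9, 4]])
src_mask = (src != PAD).unsqueeze(-2)         # [B, L] -> [B, 1, L]
head_mask = src_mask.unsqueeze(1)             # [B, 1, L] -> [B, 1, 1, L]

scores = torch.randn(B, N, L, L)              # query-key scores for every head
masked = scores.masked_fill(head_mask == 0, float("-inf"))  # broadcasts to [B, N, L, L]

print(src_mask.shape, head_mask.shape)  # torch.Size([2, 1, 5]) torch.Size([2, 1, 1, 5])
print(torch.isinf(masked[0, :, :, 3:]).all().item())  # True: padded keys masked for every head and query
```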

### <a id='toc1_13_7_'></a>[Why Does the Sequence Length Differ for Source and Target? I Thought It Was Always L.](#toc0_)

...

### <a id='toc1_13_8_'></a>[Am I right that in an autoregressive, decoder-only (GPT-like) model, a sample of sequence length L yields L rows, each contributing to the loss, so that one sample can intuitively be viewed as L samples? Do not hesitate to correct me.](#toc0_)


Your description captures an important aspect of autoregressive models like decoder-only Transformers (e.g., GPT). Specifically, you're talking about how the model treats a sequence during training. Let's break down your understanding step by step.

1. **Sequence Length**: When you mention "L rows," where "L" is the sequence length, you're essentially pointing out that each sequence is divided into "L" time steps (or tokens). Each time step becomes an input-output pair for training the model.

2. **One Sequence as Multiple Samples**: You're correct to intuit that a single sequence of length "L" can be treated like "L" samples, at least in the context of loss calculation. This is because, during training, the model computes the loss at each time step by comparing the predicted token with the actual next token in the sequence.

3. **Loss Computation**: The loss is often computed at each position and then averaged over the sequence length or summed up, depending on the specific loss function or training regime.

However, it's crucial to clarify that although a single sequence may contribute "L" terms to the loss function, this is not equivalent to having "L" independent samples. The key difference lies in the autoregressive property: the prediction at each time step is conditioned on the preceding tokens. This introduces a temporal dependency across the "L" positions, making them not entirely independent samples.

In other words, while it's accurate to say that a single sequence contributes multiple terms to the loss function, these terms are correlated because they come from the same sequence and share the same prefix: each prediction is conditioned on the tokens before it.

To summarize, you're mostly correct in your understanding that a single sequence is broken down into multiple steps for the purpose of loss computation, but it's important to remember that these steps are not independent samples due to the autoregressive nature of the model.
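
A small illustration of the "one sample, many terms" view, using the character string `"49+13=062"` as a toy encoding (an assumption for readability; the project works with integer token ids). Shifting by one position yields one `(context, target)` pair per input position, and the overlapping contexts show why the pairs are not independent.

```python
sequence = list("49+13=062")   # one training sample as characters (toy encoding)

inputs, targets = sequence[:-1], sequence[1:]   # shift by one position
pairs = [(inputs[: t + 1], targets[t]) for t in range(len(targets))]

for context, target in pairs:
    print("".join(context), "->", target)
# 4 -> 9
# 49 -> +
# ...
# 49+13=06 -> 2
```

In practice the model does not run one forward pass per pair: thanks to the causal mask, all of these predictions come out of a single forward pass over `inputs`.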





## Some Implementation Details

```
Performs one decoder forward pass given encoder hidden states, the decoder input tokens and attention masks.
B = batch size
S = source sequence length
T = target sequence length
E = embedding dimensionality
V = vocabulary size
```

### Input

Let's view input's first two samples:

```
tensor([[15,  4,  9, 10,  1,  3, 13,  0,  6,  2],
        [15,  3,  5, 10,  4,  6, 13,  0,  8,  1]])
```

which is

-   shape `[2, 10]`, i.e. `BxL`.
-   `49+13=062`, with the sum zero-padded to three digits; there is no `EOS`
    because the last token was truncated.
-   `35+46=081`, likewise without `EOS` because the last token was truncated.

### Positional Encodings

#### Why do we hardcode batch size of 1 when creating P?

The tensor $P$ for positional encoding is initialized with a batch size of 1.
This makes it easy to add to the actual input sequences later, during the
forward pass. Positional encodings are not dependent on the specific input
sequence but are a function of the position within the sequence. Therefore, they
can be precomputed and stored. When you look at the forward pass:

```python
def forward(self, Z: torch.Tensor) -> torch.Tensor:
    Z = self._add_positional_encoding(Z)
    return self.dropout(Z)
```

and the `_add_positional_encoding` method:

```python
def _add_positional_encoding(self, Z: torch.Tensor) -> torch.Tensor:
    """Add the positional encoding tensor to the input tensor."""
    return Z + self.P[:, : Z.shape[1], :].to(Z.device)
```

You'll see that $P$ is sliced to match the sequence length of $Z$ and then added
to $Z$. Because of broadcasting rules in PyTorch, $P$ will automatically be
broadcasted to the batch size of $Z$ during this addition. This is why $P$ is
initialized with a batch size of 1; it keeps the implementation flexible while
making the broadcasting implicit.

#### Why do we register P as a buffer in PyTorch?

In the `PositionalEncoding` class, the tensor `self.P` holds the pre-computed
positional encodings. If this tensor should automatically move to the correct
device when the module is moved, and it should not be a learnable parameter,
then registering it as a buffer is a good idea. This ensures that `self.P` is
part of the module's state but is not updated during backpropagation.

You could register `self.P` as a buffer right after you initialize it in the
`_init_positional_encoding` method:

```python
def _init_positional_encoding(self) -> torch.Tensor:
    """Initialize the positional encoding tensor."""
    P = torch.zeros((1, self.max_seq_len, self.d_model))
    position = self._get_position_vector()
    div_term = self._get_div_term_vector()
    P[:, :, 0::2] = torch.sin(position / div_term)
    P[:, :, 1::2] = torch.cos(position / div_term)
    self.register_buffer("P", P, persistent=True)
    return P
```

Using `register_buffer` ensures that:

1. `self.P` is automatically moved to the device the model is moved to (e.g.,
   from CPU to GPU).
2. With `persistent=True` (the default), `self.P` is included in the module's
   `state_dict`, so it is saved with `torch.save(model.state_dict(), ...)` and
   restored with `load_state_dict`.

Passing `persistent=False` instead keeps the buffer out of the `state_dict`,
meaning it won't be saved or loaded with the model, although it is still moved
along with the module. Since the code above passes `persistent=True`, which is
the default, the argument could simply be omitted.
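
A toy module (hypothetical, not the project's `PositionalEncoding`) makes the distinction between parameters, persistent buffers and non-persistent buffers concrete. Converting the module with `.to(...)` also converts its buffers, just as moving it to a GPU would move them.

```python
import torch
from torch import nn

class WithBuffers(nn.Module):
    """Toy module (hypothetical) contrasting a parameter and two kinds of buffers."""

    def __init__(self) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(2))                          # learnable
        self.register_buffer("P", torch.zeros(1, 4, 2))                     # persistent (default)
        self.register_buffer("scratch", torch.zeros(3), persistent=False)  # not saved

m = WithBuffers()
print([name for name, _ in m.named_parameters()])  # ['weight']
print(list(m.state_dict().keys()))                  # ['weight', 'P']
print(m.P.requires_grad)                            # False

m = m.to(torch.float64)                             # buffers follow module conversions
print(m.P.dtype)                                    # torch.float64
```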

### Attention

#### Why do we call contiguous on Q, K and V?

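D2L's code uses `reshape` to reshape `Q`, `K` and `V`, whereas other code, such
as the Annotated Transformer, uses `view`. `view` never copies data, so it only
works when the tensor's memory layout (its strides) is compatible with the new
shape. After a `transpose`, for example when merging the heads back together,
this is generally not the case, so such code calls `.contiguous()` first to copy
the tensor into a contiguous layout. `reshape` performs that copy automatically
when it is needed.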
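
A small check of the difference, with arbitrary sizes: after a transpose, `view` fails, while `.contiguous().view(...)` and `reshape` both succeed and agree.

```python
import torch

B, H, L, d_k = 2, 4, 5, 8
heads = torch.randn(B, H, L, d_k)            # per-head attention outputs

merged_t = heads.transpose(1, 2)             # [B, L, H, d_k], now non-contiguous
print(merged_t.is_contiguous())              # False

try:
    merged_t.view(B, L, H * d_k)             # view cannot merge these strides
except RuntimeError as err:
    print("view failed:", type(err).__name__)

a = merged_t.contiguous().view(B, L, H * d_k)  # copy into contiguous memory, then view
b = merged_t.reshape(B, L, H * d_k)            # reshape copies only when needed
print(torch.equal(a, b))                       # True
```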

#### Why do we want to transpose Q, K, and V?

The transposition of $Q$, $K$, and $V$ in multi-head attention serves a specific
purpose: to allow for parallel computation across multiple attention heads. In
the original shape, the "heads" dimension does not exist; the tensor is simply
$B \times L \times D$, where $B$ is the batch size, $L$ is the sequence length,
and $D$ is the model dimension. By reshaping to
$B \times L \times H \times (D/H)$ and then transposing the second and third
dimensions, we create a new shape $B \times H \times L \times (D/H)$, where $H$
is the number of heads. This enables the following:

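1. **Parallelization**: Batched matrix multiplication treats the leading
   $B \times H$ dimensions as batch dimensions, so the attention for all heads,
   which operate independently of each other, is computed in a single call.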
2. **Optimization**: Modern hardware accelerators like GPUs are optimized for
   certain tensor operations, and having a shape that aligns well with these
   optimizations can result in faster computation.
3. **Readability and Maintainability**: It's easier to understand and debug the
   operations for each head when they're isolated like this.
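
The split can be sketched as follows, with arbitrary sizes and random tensors standing in for the projected queries and keys. One batched matmul yields an $L \times L$ score matrix per head, and each head's scores match what that head would compute on its own.

```python
import torch

B, L, D, H = 2, 5, 16, 4
d_k = D // H

def split_heads(x: torch.Tensor) -> torch.Tensor:
    """[B, L, D] -> [B, H, L, D/H]: reshape the last dim into heads, then move heads forward."""
    return x.view(B, L, H, d_k).transpose(1, 2)

Q = split_heads(torch.randn(B, L, D))
K = split_heads(torch.randn(B, L, D))

# Batched matmul treats [B, H] as batch dims, so all heads are scored in one call.
scores = Q @ K.transpose(-2, -1) / d_k**0.5
print(Q.shape, scores.shape)  # torch.Size([2, 4, 5, 4]) torch.Size([2, 4, 5, 5])
```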

#### Why do we want to reverse the transpose after attention?

After the attention scores are computed and used to weight $V$, we get a new
tensor for each head. However, these tensors are still in the transposed shape
$B \times H \times L \times (D/H)$, and they need to be concatenated and
linearly transformed to continue through the network. The reverse transposition
essentially does the following:

1. **Concatenation**: Converts the multiple heads back into a single tensor.
   This is required because subsequent layers (like feed-forward neural
   networks) expect input in the original $D$-dimensional space.

2. **Compatibility**: The rest of the neural network architecture often expects
   input tensors to have a specific shape (usually $B \times L \times D$).
   Reverse transposing ensures that the output of the multi-head attention block
   can be fed into subsequent layers without issue.

3. **No Memory Savings**: The reverse transposition does not reduce the amount
   of data: $B \times H \times L \times (D/H)$ and $B \times L \times D$ contain
   the same number of elements. It only changes the layout, and the
   `.contiguous()` call that usually accompanies it may even allocate a copy.

In summary, the initial transposition is done to facilitate parallel computation
across heads, and the reverse transposition is done to concatenate these heads
and prepare the tensor for subsequent layers.
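
The split and merge are exact inverses, which this sketch checks on a random tensor (the attention step in between is skipped; in a real layer the weighted values are merged and then passed through an output projection). It also confirms that the number of elements never changes.

```python
import torch

B, L, D, H = 2, 5, 16, 4
d_k = D // H

x = torch.randn(B, L, D)
split = x.view(B, L, H, d_k).transpose(1, 2)                     # [B, H, L, d_k]
merged = split.transpose(1, 2).contiguous().view(B, L, H * d_k)  # back to [B, L, D]

print(torch.equal(merged, x))                 # True: heads are concatenated in order
print(split.numel() == merged.numel())        # True: same number of elements either way
```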


## Why We Need Positional Encoding

Positional encoding is critical because, without it, "the cat ate the mouse"
looks the same to the model as "the mouse ate the cat".

Without positional information, (unmasked) self-attention is permutation
equivariant: the $QK^\top$ scores depend only on token content, so shuffling the
input tokens merely shuffles the outputs in the same way, and word order is
lost. Adding positional information fixes this. For example, in "the cat ate the
mouse", the last token "mouse" can gather information from every other word in
the sentence while also knowing each token's position, including the fact that
it is the last token.
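
The permutation equivariance can be checked numerically. This sketch assumes unmasked single-head attention with no projections and no positional encoding; the permutation is arbitrary.

```python
import torch

torch.manual_seed(0)

def self_attention(x: torch.Tensor) -> torch.Tensor:
    """Unmasked single-head self-attention without projections or positional info."""
    weights = torch.softmax(x @ x.T / x.shape[-1] ** 0.5, dim=-1)
    return weights @ x

x = torch.randn(5, 8)                 # 5 token embeddings, no positional encoding added
perm = torch.tensor([4, 2, 0, 3, 1])  # reorder the "sentence"

out = self_attention(x)
out_perm = self_attention(x[perm])

# Shuffling the input just shuffles the output rows: word order is invisible.
print(torch.allclose(out_perm, out[perm], atol=1e-5))  # True
```

Adding a position-dependent vector to each row of `x` before attention breaks this symmetry, which is exactly what the positional encoding is for.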


## <a id='toc1_14_'></a>[TODO](#toc0_)

1. Add Positional Encoding
2. Add LR Scheduler
3. Check why we need `torch.nn.utils.clip_grad_norm_` to clip gradients.
4. Why unsqueeze mask?
5. Can weights be initialized inside the Encoder instead of outside?
6. Add epoch and batch state (see my old code).
7. Important: use a `Vocab` class like in https://github.com/jsbaan/transformer-from-scratch/blob/main/vocabulary.py.

https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/tutorial6/Transformers_and_MHAttention.html



## <a id='toc1_15_'></a>[References and Further Readings](#toc0_)

- https://slds-lmu.github.io/seminar_nlp_ss20/