<a href="https://colab.research.google.com/github/nkthiebaut/guanaco/blob/main/notebooks/Guanaco_step_by_step.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Guanaco: toy Llama3 implementation

In this notebook, we'll implement the Llama3 model architecture brick by brick, starting
from a simple model that averages character embeddings, which we'll apply to short Morse
code strings (vocabulary size 3). We'll then add:
- Rotary Position Encodings (RoPE)
- Self-attention
- The SwiGLU activation function
- The full self-attention block with RMSNorm

Once ready, we'll train the model on the TinyStories dataset, with text tokenized at the
byte level.

Note that the actual Llama3 implementation:
- uses the [Tiktoken tokenizer](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py)
- has Key-Value caching
- uses Grouped query attention


## First step: morse code modeling

In [None]:
%pip install -q datasets einops lightning


## Modeling imports

In [None]:
import torch
import torch.nn.functional as F
from torch import nn

from einops import rearrange
import lightning as L

device = "cuda" if torch.cuda.is_available() else "cpu"

### Toy model: Average Embeddings

Let's start with a simple model that predicts the next token in sequences of Morse code (three characters: dots ".", dashes "-", and spaces " ", encoded as 1, 2, and 0 respectively).

The morse alphabet looks like this:
```python
morse_code_dict = {
    'A': '.-', 'B': '-...', 'C': '-.-.', 'D': '-..', 'E': '.', 'F': '..-.',
    'G': '--.', 'H': '....', 'I': '..', 'J': '.---', 'K': '-.-', 'L': '.-..',
    'M': '--', 'N': '-.', 'O': '---', 'P': '.--.', 'Q': '--.-', 'R': '.-.',
    'S': '...', 'T': '-', 'U': '..-', 'V': '...-', 'W': '.--', 'X': '-..-',
    'Y': '-.--', 'Z': '--..',
    '0': '-----', '1': '.----', '2': '..---', '3': '...--', '4': '....-', '5': '.....',
    '6': '-....', '7': '--...', '8': '---..', '9': '----.'
}
```
and, for instance, `"Houston, we have a problem"` reads: `.... --- ..- ... - --- -. --..-- ....... .-- . ....... .... .- ...- . ....... .- ....... .--. .-. --- -... .-.. . -- .-.-.-`.

Each sequence of length $T$ contains $T-1$ examples, e.g. `"..-."` gives the following 3 (input, output) pairs: `((".", "."), ("..", "-"), ("..-", "."))`.
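
To make the pair construction concrete, here is a minimal sketch. The helper name `next_token_pairs` is hypothetical and is not part of the notebook's code:

```python
def next_token_pairs(sequence: str) -> list[tuple[str, str]]:
    """Return every (prefix, next character) pair contained in a sequence."""
    return [(sequence[:t], sequence[t]) for t in range(1, len(sequence))]


pairs = next_token_pairs("..-.")
print(pairs)  # [('.', '.'), ('..', '-'), ('..-', '.')]
```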

We use a character-level tokenization and map each character to continuous embeddings $x$.

Instead of creating numerous examples for a single input, it is more efficient to use a mask matrix like this

$$
\text{Mask}=
\begin{pmatrix}
1 & 0 & 0 & \cdots & 0 \\
1 & 1 & 0 & \cdots & 0 \\
1 & 1 & 1 & \cdots & 0 \\
\vdots & \vdots & \vdots & \ddots & \vdots \\
1 & 1 & 1 & \cdots & 1
\end{pmatrix}
$$
that is applied to the attention map pre-softmax, to make sure we don't attend to future tokens in the sequence.

In the example above, the `"..-."` input is tokenized as `[1, 1, 2, 1]` and each token is mapped to a randomly initialized embedding. After dropping the last input and the first target, the inputs are
`X = [[0.2, -0.6], [0.2, -0.6], [0.4, 0.1]]` and the targets are `y = [1, 2, 1]`.

If we apply the mask directly to the inputs then
$$\text{Mask} \; X= \begin{pmatrix}
1 & 0 & 0 \\
1 & 1 & 0 \\
1 & 1 & 1
\end{pmatrix}\begin{pmatrix}
0.2 & -0.6 \\
0.2 & -0.6 \\
0.4 & 0.1 \\
\end{pmatrix}= \begin{pmatrix}
0.2 & -0.6 \\
0.4 & -1.2 \\
0.8 & -1.1
\end{pmatrix}
$$

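Thanks to the mask, only the current and previous tokens contribute to each row when predicting the next token (strictly speaking they are summed rather than averaged, since the rows of the mask are not normalized). We can use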

$$
\text{Mask} \; X = \begin{pmatrix}
0.2 & -0.6 \\
0.4 & -1.2 \\
0.8 & -1.1
\end{pmatrix} \quad \text{as inputs and} \qquad
y=\begin{pmatrix}1\\ 2\\ 1\end{pmatrix} \quad \text{as targets,}
$$
without leaking future information.
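
The matrix product above can be checked numerically. This is a small sketch that reuses the example values of `X`; the variable names are chosen here for illustration:

```python
import torch

# Embeddings of the three input tokens from the example above
X = torch.tensor([[0.2, -0.6], [0.2, -0.6], [0.4, 0.1]])
T = X.shape[0]

# Lower-triangular matrix of ones: row t keeps tokens 0..t
mask = torch.tril(torch.ones(T, T))

# Row t of the result is the sum of embeddings 0..t
masked_inputs = mask @ X
print(masked_inputs)

# The same result can be obtained with a cumulative sum over the time dimension
cumulative = torch.cumsum(X, dim=0)
```

Multiplying by a lower-triangular matrix of ones is a cumulative sum over time, which is why no row ever sees a later token.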

In [None]:
class BaselineModel(nn.Module):
    def __init__(self, vocab_size: int, n_dim_emb: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_dim_emb)
        self.output_layer = nn.Linear(n_dim_emb, vocab_size, bias=False)

        self.vocab_size = vocab_size
        self.n_dim_emb = n_dim_emb

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        B, T = inputs.shape  # B: batch_size, T: sequence_length ("timesteps")

        embeddings = self.embedding(inputs)  # (B, T) -> (B, T, C)

        # Mask out future tokens
        mask = torch.tril(torch.ones(T, T)).to(embeddings.device)
        embeddings = mask @ embeddings

        logits = self.output_layer(embeddings)  # (B, T, C) -> (B, T, V)

        return logits

def compute_loss(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    logits = rearrange(logits, 'B T V -> (B T) V') # using https://einops.rocks/
    # Equivalent to logits = logits.reshape(B*T, vocab_size)
    targets = rearrange(targets, 'B T -> (B T)')
    # Equivalent to targets = targets.reshape(B*T)
    loss = F.cross_entropy(logits, targets)
    # Sanity check: should start around -torch.log(torch.tensor(1/vocab_size))
    # for a randomly initialized model
    return loss

morse_corpus = [
    "....................",
    ".-.-.-.-.-.-.-.-.-.-",
    "... --- ..- ... - --- -.",
]

corpus_dummy = ["......", "......"]
corpus_alternating = [".-.-.-.-", "..--..--"]

vocab = [" ", ".", "-"]

token_ids = torch.tensor(
    [
        [vocab.index(char) for char in sample]
        for sample in [".-.-.-.-", ".-.-.-.-"]
    ]
).to(device)

model = BaselineModel(vocab_size=len(vocab), n_dim_emb=2)
model = model.to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

context_length = 100
n_epochs = 100000

for i in range(n_epochs):
    inputs = token_ids[:, :-1]
    targets = token_ids[:, 1:]

    logits = model(inputs)
    loss = compute_loss(logits, targets)
    # Log the loss five times during training, plus once at the very end
    if i % (n_epochs // 5) == 0 or i == n_epochs - 1:
        print(f"i={i}, loss={loss:.3f}")

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

i=0, loss=3.695
i=20000, loss=0.047
i=40000, loss=0.022
i=60000, loss=0.014
i=80000, loss=0.010
i=99999, loss=0.008
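
The sanity check mentioned in `compute_loss` can be made concrete: a model that assigns equal probability to every token has a cross-entropy of $\ln V$, about 1.099 for $V=3$. Here is a minimal pure-Python sketch; the helper `cross_entropy` is hypothetical and written only for illustration:

```python
import math


def cross_entropy(logits: list[float], target: int) -> float:
    """Negative log-probability of `target` under softmax(logits)."""
    log_norm = math.log(sum(math.exp(z) for z in logits))
    return log_norm - logits[target]


vocab_size = 3
# Equal logits mean a uniform distribution over the vocabulary
uniform_loss = cross_entropy([0.0] * vocab_size, target=1)
print(f"{uniform_loss:.4f}")  # 1.0986, i.e. ln(3)
```

A randomly initialized model is only approximately uniform, so its first logged loss can differ from this reference value.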


Q: what would be the loss of a perfect model for this corpus?

In [None]:
# inference: try the following prompt: "..--", ".-.-", "...."
prompt = "..--"
x = torch.tensor([vocab.index(char) for char in prompt]).unsqueeze(0).to(device)

print("Decoded string:", "".join([vocab[i] for i in x[0]]), end="")

for _ in range(10):
    x = x[:, -context_length:]
    pred = model(x)[0, -1].argmax(dim=-1)
    print(vocab[pred.item()], end="")
    x = torch.cat((x[0], pred.unsqueeze(0))).unsqueeze(0)

Decoded string: ..--.-.-.-..-.

### PyTorch Lightning version

PyTorch Lightning is a framework built on top of PyTorch that removes the need for:
- explicit calls to backprop or the optimizer (`loss.backward()`, `optimizer.step()`, and `optimizer.zero_grad()`)
- explicit device transfers (`.to(device)`)

It also makes the code cleaner, has nice logging, and more.

Let's rewrite the training code above using Lightning.

In [None]:
class LightningModel(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return self.model(inputs)

    def training_step(self, tokens: torch.Tensor) -> torch.Tensor:
        inputs = tokens[:, :-1]
        targets = tokens[:, 1:]

        logits = self.forward(inputs)

        logits = rearrange(logits, 'B T C -> (B T) C')
        targets = rearrange(targets, 'B T -> (B T)')

        loss = F.cross_entropy(logits, targets)
        self.log(
            "train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
        )
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer

    def generate(self, x: str, n_tokens: int = 5):
        """Predict next token with greedy decoding."""
        # inference: try the following prompt: "..--", ".-.-", "...."
        # Create the prompt on the same device as the model
        x = torch.tensor([vocab.index(char) for char in x], device=self.device).unsqueeze(0)

        for _ in range(n_tokens):
            pred = self(x)[:, -1, :]  # Logits of the next token prediction (B, V)
            next_tokens = pred.argmax(dim=-1) # Next token_id with highest proba (B)
            next_tokens = rearrange(next_tokens, "B -> B 1")
            x = torch.cat((x, next_tokens), dim=1)
        return "".join([vocab[i] for i in x[0]])

model = LightningModel(BaselineModel(len(vocab), 128))

We need to create a `DataLoader` object to feed the data to the Lightning Trainer.

In [None]:
from torch.utils.data import DataLoader

tiny_dataset = torch.tensor(
    [[vocab.index(char) for char in sample] for sample in corpus_alternating]
)
tiny_dataloader = DataLoader(tiny_dataset, batch_size=1)
next(iter(tiny_dataloader))

tensor([[1, 2, 1, 2, 1, 2, 1, 2]])

In [None]:
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, tiny_dataloader)
model.generate(".-.-")

INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name  | Type          | Params
----------------------------------------
0 | model | BaselineModel | 768   
----------------------------------------
768       Trainable params
0         Non-trainable params
768       Total params
0.003     Total estimated model params size (MB)

INFO: `Trainer.fit` stopped: `max_epochs=10` reached.


'.-.-.-.-.'

In [None]:
import logging
# Silence Lightning GPU reports from now on
# With `import lightning`, the messages are emitted by the "lightning.pytorch" logger
logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

### Rotary Position Embeddings (RoPE)

Since we're just averaging embeddings for the past characters, we're losing the position information and are thus representing "..--" and ".-.-", for instance, with the same average vector.
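
This ordering blindness is easy to verify. The sketch below uses a freshly initialized embedding table (not the trained model) and a hypothetical helper `mean_embedding`:

```python
import torch
from torch import nn

torch.manual_seed(0)
vocab = [" ", ".", "-"]
embedding = nn.Embedding(len(vocab), 4)


def mean_embedding(text: str) -> torch.Tensor:
    """Average the embeddings of all characters in `text`."""
    ids = torch.tensor([vocab.index(char) for char in text])
    return embedding(ids).mean(dim=0)


same = torch.allclose(mean_embedding("..--"), mean_embedding(".-.-"), atol=1e-6)
print(same)  # True: both strings contain two dots and two dashes
```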

In Transformer models, the position information is restored by transforming the embedding vectors in a way that depends on the position index $n$. As of 2024, the dominant approach to position encodings is [RoPE](https://arxiv.org/pdf/2104.09864), which we covered in a [previous notebook](https://colab.research.google.com/drive/1lbKdV9lUWDqrfUlZalnzTRYF7i4FOJY8#scrollTo=p6d2V7RWhEk4).

Our average embeddings model can be viewed as a single self-attention layer model, with unit attention weight matrix, i.e.
$$
A=Q K^T =
\begin{pmatrix}
1 & 1  & \cdots & 1 \\
1 & 1  & \cdots & 1 \\
\vdots & \vdots  & \ddots & \vdots \\
1 & 1 & \cdots & 1
\end{pmatrix}.
$$

To prepare the full Llama implementation with RoPE, we first rewrite the model above as a self-attention layer with uniform attention weights $a_{m,n}=1$.

The self-attention layer with causal mask is given by
$$
\operatorname{Attention}\left ( \color{orange}{Q}, \color{green}{K}, V \right ) = \operatorname{softmax}\left ( \frac{\color{orange}{Q}\color{green}{K}^{T}}{\sqrt{d_{k}}} + M \right )V \quad \text{with} \quad M = \begin{pmatrix}
0 & -\infty  & \cdots & -\infty \\
0 & 0  & \cdots & -\infty \\
\vdots & \vdots  & \ddots & \vdots \\
0 & 0 & \cdots & 0
\end{pmatrix}
$$
Note that softmax maps $-\infty$ to 0, so adding the mask above inside the softmax gives future tokens zero weight while the remaining weights in each row still sum to 1. It plays the same role as the 0/1 mask we used earlier, without the need to renormalize afterwards:
$$
\text{Mask}=
\begin{pmatrix}
1 & 0 & \cdots & 0 \\
1 & 1 & \cdots & 0 \\
\vdots & \vdots & \ddots & \vdots \\
1 & 1 & \cdots & 1
\end{pmatrix}
$$
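
The following sketch checks numerically that adding $M$ before the softmax gives the same weights as zeroing out future positions with the 0/1 mask and renormalizing each row. The random `logits` tensor is a stand-in for $QK^T/\sqrt{d_k}$, and all variable names are chosen here for illustration:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
T = 4
logits = torch.randn(T, T)  # stand-in for Q @ K.T / sqrt(d_k)

# Additive mask: 0 on and below the diagonal, -inf above it
additive_mask = torch.triu(torch.full((T, T), float("-inf")), diagonal=1)
weights = F.softmax(logits + additive_mask, dim=-1)

# Multiplicative 0/1 mask applied to exp(logits), followed by row normalization
binary_mask = torch.tril(torch.ones(T, T))
unnormalized = torch.exp(logits) * binary_mask
weights_ref = unnormalized / unnormalized.sum(dim=-1, keepdim=True)

print(torch.allclose(weights, weights_ref, atol=1e-6))  # True
```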

The query and key matrices are given by
$$
\color{orange}{Q}=W_q x = \left(\begin{array}{ccc}
\text{------} & q_1^T & \text{------} \\
\text{------} & q_2^T & \text{------} \\
& \vdots & \\
\text{------} & q_m^T & \text{------}
\end{array}\right), \qquad
\color{green}{K}^T=W_k x = \left(\begin{array}{cccc}
\mid & \mid & & \mid \\
k_1 & k_2 & \cdots & k_n \\
\mid & \mid & & \mid
\end{array}\right)
$$

Those formulas correspond to our average embedding model if we set $\color{orange}{Q} = \color{green}{K}$ as constant matrices of arbitrary value with shape $T\times d_k$  (where $d_k$ is arbitrary!) and $V=X$.

Let's check that we recover the expected average weights with a "self-attention"-like implementation of average embeddings.

In [None]:
import math

T = 3
head_dim = 4

Q = torch.ones((T, head_dim))
K = torch.ones((T, head_dim))
attention = Q @ K.T / math.sqrt(head_dim)

mask = torch.full((T, T), float("-inf"))
mask = torch.triu(mask, diagonal=1)

scores = F.softmax(attention + mask, dim=-1)
scores

tensor([[1.0000, 0.0000, 0.0000],
        [0.5000, 0.5000, 0.0000],
        [0.3333, 0.3333, 0.3333]])

In [None]:
class UniformAttention(nn.Module):
    def __init__(self, vocab_size: int, n_dim_emb: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, n_dim_emb)
        self.output_layer = nn.Linear(n_dim_emb, vocab_size, bias=False)
        self.head_dim = 1

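    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        B, T = inputs.shape  # B: batch_size, T: sequence_length ("timesteps")

        embeddings = self.embedding(inputs)  # (B, T) -> (B, T, C)

        # Constant queries and keys give uniform attention logits.
        # Use self.head_dim rather than the global `head_dim` defined in an earlier cell.
        Q = torch.ones((T, self.head_dim))
        K = torch.ones((T, self.head_dim))
        attention = (Q @ K.T / math.sqrt(self.head_dim)).to(embeddings.device)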

        mask = torch.full((T, T), float("-inf"))
        mask = torch.triu(mask, diagonal=1).to(embeddings.device)

        scores = F.softmax(attention + mask, dim=-1)
        x = scores @ embeddings

        logits = self.output_layer(x)  # (B, T, C) -> (B, T, V)

        return logits

model = LightningModel(UniformAttention(vocab_size=len(vocab), n_dim_emb=128))
trainer = L.Trainer(max_epochs=100, callbacks=[], enable_progress_bar=False)
trainer.fit(model, tiny_dataloader)

INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name  | Type             | Params
-------------------------------------------
0 | model | UniformAttention | 768   
-------------------------------------------
768       Trainable params
0         Non-trainable params
768       Total params
0.003     Total estimated model params size (MB)

In [None]:
model.generate(".-.-")

'.-.-.-..-'

### Exercise: add RoPE

Using our previous RoPE notebook, implement a RoPE rotation function for queries and keys. Check whether it fixes the ordering blindness problem.

Note: use a generic implementation to anticipate what follows. For this simple attention map, where every query and key is the same constant vector, RoPE takes a simpler form: the attention weights $a_{m,n}=q_m^T k_n$ are replaced, up to a constant factor, by

$$
\operatorname{RoPE}(q_m)^T\operatorname{RoPE}(k_n) \propto \cos((n-m)\theta_1) + \cos((n-m)\theta_2) + \dots = \sum_{i=1}^{d/2} \cos((n-m)\theta_i)
$$
where $\theta_{i}= 10000^{-\frac{i-1}{d/2}}$.
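
The cosine-sum form can be checked with plain Python complex numbers. This sketch assumes all-ones queries and keys, so each coordinate pair $(1, 1)$ contributes $|1+i|^2 = 2$ times a cosine; the function names `rope_dot` and `cosine_sum` are hypothetical:

```python
import cmath
import math


def rope_dot(m: int, n: int, d: int) -> float:
    """Dot product of all-ones q (position m) and k (position n) after RoPE."""
    total = 0.0
    for i in range(1, d // 2 + 1):
        theta = 10000 ** (-(i - 1) / (d / 2))
        # Each pair of coordinates (1, 1) is treated as the complex number 1 + 1j
        q_rot = complex(1, 1) * cmath.exp(1j * m * theta)
        k_rot = complex(1, 1) * cmath.exp(1j * n * theta)
        # The real 2D dot product equals the real part of q * conj(k)
        total += (q_rot * k_rot.conjugate()).real
    return total


def cosine_sum(m: int, n: int, d: int) -> float:
    """Closed form: 2 * sum_i cos((n - m) * theta_i)."""
    thetas = [10000 ** (-(i - 1) / (d / 2)) for i in range(1, d // 2 + 1)]
    return 2 * sum(math.cos((n - m) * theta) for theta in thetas)


matches_closed_form = math.isclose(rope_dot(3, 5, d=8), cosine_sum(3, 5, d=8), abs_tol=1e-9)
depends_on_offset_only = math.isclose(rope_dot(3, 5, d=8), rope_dot(10, 12, d=8), abs_tol=1e-9)
print(matches_closed_form, depends_on_offset_only)  # True True
```

The second check illustrates the key property of RoPE: the rotated dot product depends on the positions only through the offset $n-m$.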

In [None]:
# @title Solution

def compute_complex_rotations(T, C):
    c_values = torch.arange(1, C // 2 + 1)
    # theta_i = 10000^(-2(i-1)/C): note the negative exponent, as in the formula above
    thetas = 10000.0 ** (-2 * (c_values - 1) / C)  # Shape (C/2,)
    timesteps = torch.arange(T, dtype=thetas.dtype)  # Shape (T,)

    # Rotation angles for each (t, c) pair
    omegas = torch.outer(timesteps, thetas)  # Shape (T, C/2)

    # Turn those into unit complex numbers exp(i * omega)
    z = torch.polar(torch.ones_like(omegas), omegas)
    return z

def apply_rope_rotation(q, complex_rotations):
    q_pairs = rearrange(q, 'T (C p) -> T C p', p=2)  # Shape (T, C) -> (T, C/2, 2)
    q_complex = torch.view_as_complex(q_pairs)
    q_rotated = q_complex * complex_rotations
    q_rotated = torch.view_as_real(q_rotated)  # Back to real numbers
    q_rotated = rearrange(q_rotated, 'T C p -> T (C p)')  # Shape (T, C/2, 2) -> (T, C)
    return q_rotated

T = 2
C = 2

q = torch.tensor([[1., 2.], [1., 0.]])
k = torch.randn(T, C)
print(f"Dot-products before rotation: {q @ k.T}")

z = compute_complex_rotations(T, C)
q = apply_rope_rotation(q, z)
k = apply_rope_rotation(k, z)
print(f"Dot-products after rotation: {q @ k.T}")

### Exercise: add RoPE to the training code

Remember to take into account the batch dimension.

In [None]:
# @title Solution preliminary: batch RoPE
B, T, C = 3, 5, 4

# We'll implement a batch version of the rope rotation function
# and check that results match our earlier function
z = compute_complex_rotations(T, C)
Q = torch.randn(B, T, C)
K = torch.randn(B, T, C)

q = apply_rope_rotation(Q[1], z)
k = apply_rope_rotation(K[1], z)
print(f"Dot-products for batch item 1:\n{q @ k.T}")

def apply_rope_batch(q, complex_rotations):
    q_pairs = rearrange(q, 'B T (C p) -> B T C p', p=2)
    q_complex = torch.view_as_complex(q_pairs)
    q_rotated = q_complex * complex_rotations
    q_rotated = torch.view_as_real(q_rotated)  # Back to real numbers
    q_rotated = rearrange(q_rotated, 'B T C p -> B T (C p)')
    return q_rotated

Q = apply_rope_batch(Q, z)
K = apply_rope_batch(K, z)
print(f"Batch version:\n{(Q @ K.mT)[1]}")

In [None]:
# @title Solution

class UniformAttention(nn.Module):
    def __init__(self, T:int, V: int, C: int):
        """
          T: max number of timesteps (context window size)
          V: vocabulary size
          C: number of channels (embedding dimension)
        """
        super().__init__()
        self.embedding = nn.Embedding(V, C)
        self.output_layer = nn.Linear(C, V, bias=False)
        self.head_dim = C

        self.complex_rotations = compute_complex_rotations(T, C)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        B, T = inputs.shape  # B: batch_size, T: sequence_length ("timesteps")
        z = self.complex_rotations[:T, :]

        embeddings = self.embedding(inputs)  # (B, T) -> (B, T, C)

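        # Constant queries and keys. Note that `**` binds tighter than `/`, so
        # `self.head_dim ** 1/4` evaluates to `self.head_dim / 4`; together with
        # `/ T` and `/ 109` it acts as an ad hoc scale that keeps the logits small.
        scale = self.head_dim ** 1/4 / T / 109
        Q = apply_rope_batch(scale * torch.ones((B, T, self.head_dim)), z)
        K = apply_rope_batch(scale * torch.ones((B, T, self.head_dim)), z)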

        attention = (Q @ K.mT / math.sqrt(self.head_dim)).to(embeddings.device)

        mask = torch.full((T, T), float("-inf"))
        mask = torch.triu(mask, diagonal=1).to(embeddings.device)

        scores = F.softmax(attention + mask, dim=-1)

        x = scores @ embeddings

        logits = self.output_layer(x)  # (B, T, C) -> (B, T, V)

        return logits

model = LightningModel(UniformAttention(T=10, V=len(vocab), C=64))
trainer = L.Trainer(max_epochs=100)
trainer.fit(model, tiny_dataloader)

# Check that the representation of the last token depends on the order of the
# preceding tokens
x = torch.tensor([[vocab.index(char) for char in ".-."]])
print(model(x)[0][-1])
x = torch.tensor([[vocab.index(char) for char in "-.."]])
print(model(x)[0][-1])

INFO: GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO: LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO: 
  | Name  | Type             | Params
-------------------------------------------
0 | model | UniformAttention | 384   
-------------------------------------------
384       Trainable params
0         Non-trainable params
384       Total params
0.002     Total estimated model params size (MB)

INFO: `Trainer.fit` stopped: `max_epochs=100` reached.


tensor([-3.3485,  1.2309,  1.5887], grad_fn=<SelectBackward0>)
tensor([-3.3505,  1.2288,  1.5915], grad_fn=<SelectBackward0>)


In [None]:
# Check that the model is sensitive to permutations
print(model.generate(".-.-"))
print(model.generate("..--"))

.-.-.-.-.
..--.-.-.


### Exercise: implement the self-attention layer

Introduce the value vector $v = W_v x$, and replace our dummy $q$ and $k$ with $q=W_q x$ and $k = W_k x$. Here $W_q$, $W_k$, and $W_v$ are learnable matrices of shape $C \times d_k$, where $C$ is the embedding dimension and $d_k$ is the head dimension.

Note that the self-attention layer, like all layers in Transformers, is a [**residual layer**](https://en.wikipedia.org/wiki/Residual_neural_network): it adds its computation to the input (for stability, mainly to avoid vanishing gradients).

In [None]:
class SelfAttention(nn.Module):
    def __init__(self, emb_dim=64, head_dim=64):
        super().__init__()
        self.head_dim = head_dim
        # TODO

    def forward(self, x, complex_rotations, mask):
        # TODO
        # Since the mask and rotations are the same for the whole
        # network, we don't store them in each module but pass them
        # when calling the forward function.
        return x

In [None]:
# @title Solution

class SelfAttention(nn.Module):
    def __init__(self, emb_dim=64, head_dim=64):
        super().__init__()
        self.head_dim = head_dim
        self.Wq = nn.Linear(emb_dim, head_dim, bias=False)
        self.Wk = nn.Linear(emb_dim, head_dim, bias=False)
        self.Wv = nn.Linear(emb_dim, head_dim, bias=False)

    def forward(self, x, complex_rotations, mask):
        # Compute Queries, Keys, and Values from embeddings
        Q = self.Wq(x)
        K = self.Wk(x)
        V = self.Wv(x)

        # Apply RoPE to queries and keys
        Q = apply_rope_batch(Q, complex_rotations)
        K = apply_rope_batch(K, complex_rotations)

        attention = (Q @ K.mT / math.sqrt(self.head_dim)).to(x.device)

        scores = F.softmax(attention + mask, dim=-1)

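        # Residual connection around the attention output.
        # Attend over the values V (this requires head_dim == emb_dim so shapes match).
        return x + scores @ V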


# Test:
B, T, C = 2, 2, 2

x = torch.randn(B, T, C)
complex_rotations = compute_complex_rotations(T, C)
mask = torch.full((T, T), float("-inf"))
mask = torch.triu(mask, diagonal=1)

sa = SelfAttention(emb_dim=C, head_dim=C)
sa(x, complex_rotations, mask)

tensor([[[-1.5206, -3.5257],
         [-1.3390, -0.8051]],

        [[-1.5104, -3.6652],
         [ 1.3377, -0.7909]]], grad_fn=<AddBackward0>)

### Exercise: implement SwiGLU layer

Reuse the implementation from the [SwiGLU notebook](https://colab.research.google.com/drive/1_6oJEHmgO5xJK_Pud5J8oEzzjzRCf46C#scrollTo=_RBR5gI6VsUb).

In [None]:
class FeedForward(nn.Module):
    def __init__(self, hidden_dims=100):
        super().__init__()
        # TODO

    def forward(self, x):
        # TODO
        return x

In [None]:
# @title Solution
class FeedForward(nn.Module):
    def __init__(self, hidden_dims=100, emb_dim=2):
        super().__init__()
        # emb_dim defaults to 2 to match the toy tests in this notebook
        self.fc1 = nn.Linear(emb_dim, hidden_dims, bias=False)  # gate projection
        self.fc2 = nn.Linear(emb_dim, hidden_dims, bias=False)  # up projection
        self.fc3 = nn.Linear(hidden_dims, emb_dim, bias=False)  # down projection

    def forward(self, x):
        gate = F.silu(self.fc1(x))  # SiLU(z) = z * sigmoid(z)
        x = self.fc2(x)
        x = x * gate
        x = self.fc3(x)
        return x
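
In equation form, the solution above computes the following. Here $W_1$, $W_2$, and $W_3$ denote the weights of `fc1`, `fc2`, and `fc3`, and $\sigma$ is the sigmoid; this notation is introduced for illustration only:

$$
\operatorname{FFN}(x) = W_3\left(\operatorname{SiLU}(W_1 x) \odot W_2 x\right), \qquad \operatorname{SiLU}(z) = z\,\sigma(z)
$$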

### Exercise: implement the Transformer Block

Combine our `SelfAttention` and `FeedForward` layers into a single block. Reuse the RMSNorm layer from [this previous notebook](https://colab.research.google.com/drive/1M9uhPAZkzV4ABXkJSQRELtjtLJfyTIZY#scrollTo=UZ3_AVHhBCSF).
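
As a reminder, RMSNorm rescales each embedding vector $x \in \mathbb{R}^{d}$ by its root mean square and applies a learnable per-channel gain $g$ (initialized to ones), with a small constant $\epsilon$ for numerical stability. The symbols $g$ and $d$ are notation chosen here, not taken from the linked notebook:

$$
\operatorname{RMSNorm}(x) = g \odot \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d} x_i^2 + \epsilon}}
$$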



In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, hidden_dims=100):
        super().__init__()
        # TODO

    def forward(self, x):
        # TODO
        return x

In [None]:
# @title Solution
from einops import reduce

class RMSNorm(nn.Module):
    def __init__(self, emb_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        # Note: explicit casting to fp32 to avoid numerical underflow
        x_fp32 = x.to(torch.float32)
        mean_square = reduce(x_fp32**2, '... d -> ... 1', 'mean')
        inverse_rms = torch.rsqrt(mean_square + self.eps)
        inverse_rms = inverse_rms.type_as(x)  # For fp16 compatibility
        return self.weight * x * inverse_rms


class TransformerBlock(nn.Module):
    def __init__(self, emb_dim: int, head_dim: int):
        super().__init__()
        self.emb_dim = emb_dim
        self.head_dim = head_dim

        self.att_norm = RMSNorm(emb_dim)
        self.self_attention = SelfAttention(emb_dim, head_dim)
        self.ffn_norm = RMSNorm(emb_dim)
        self.feed_forward = FeedForward(hidden_dims=emb_dim)


    def forward(
        self,
        x: torch.Tensor,
        complex_rotations: torch.Tensor,
        mask: torch.Tensor,
    ):
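        # Pre-norm: each sub-layer receives a normalized input.
        # SelfAttention already adds its input back (residual), so no extra `x +` here.
        x = self.self_attention(self.att_norm(x), complex_rotations, mask)
        x = x + self.feed_forward(self.ffn_norm(x))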
        return x

# Test:
B, T, C = 2, 2, 2

x = torch.randn(B, T, C)
complex_rotations = compute_complex_rotations(T, C)
mask = torch.full((T, T), float("-inf"))
mask = torch.triu(mask, diagonal=1)

tb = TransformerBlock(emb_dim=C, head_dim=C)
tb(x, complex_rotations, mask)

tensor([[[-0.1386,  1.4345],
         [ 0.9247,  1.1678]],

        [[ 0.2469, -1.3571],
         [ 0.2683, -1.3554]]], grad_fn=<AddBackward0>)