# Guanaco: toy Llama3 implementation

In this notebook, we'll implement the Llama3 model architecture brick by brick. We start
from a simple model that averages character embeddings, applied to simple Morse
code strings (vocabulary size 3). We'll then add:
- Rotary Position Encodings (RoPE)
- Self-attention
- The SwiGLU activation function
- The full self-attention block with RMSNorm

Once ready, we'll train the model on the TinyStories dataset, with text tokenized at the
byte level.

Note that the actual Llama3 implementation:
- uses the [Tiktoken tokenizer](https://github.com/meta-llama/llama3/blob/main/llama/tokenizer.py)
- has Key-Value caching
- uses grouped-query attention (GQA)

In [None]:
%pip install -q datasets einops lightning

## Modeling imports

In [None]:
import math
import torch
import torch.nn.functional as F
from torch import nn

from einops import rearrange
import lightning as L

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
def compute_complex_rotations(T, C):
    """Precompute the RoPE rotation factor of every (position, channel pair).

    Args:
        T: Number of positions to precompute (maximum sequence length).
        C: Size of the dimension being rotated (the attention head dimension).
            Channels are rotated in pairs, so it must be even.

    Returns:
        Complex tensor of shape (T, C/2). Entry (t, c) is the unit complex
        number exp(i * t * theta_c) with theta_c = 10000 ** (-2 * c / C) for
        c = 0 .. C/2 - 1.

    Raises:
        ValueError: If C is odd.
    """
    if C % 2 != 0:
        raise ValueError(f"RoPE needs an even dimension, got C={C}")

    pair_index = torch.arange(C // 2, dtype=torch.float32)  # Shape (C/2,)
    # The exponent is negative: per-pair frequencies decay geometrically from
    # 1 down to ~1/10000, so late pairs rotate slowly and encode long ranges.
    thetas = 10000.0 ** (-2 * pair_index / C)  # Shape (C/2,)
    timesteps = torch.arange(T, dtype=torch.float32)  # Shape (T,)

    # Rotation angle for each (t, c) pair
    omegas = torch.outer(timesteps, thetas)  # Shape (T, C/2)

    # Unit-modulus complex numbers with those angles
    return torch.polar(torch.ones_like(omegas), omegas)


def apply_rope(q, complex_rotations):
    """Rotate consecutive channel pairs of `q` by position-dependent angles (RoPE).

    Args:
        q: Real tensor of shape (..., T, C) with C even, e.g. (B, T, C) or
            (B, H, T, C). Channels (2c, 2c + 1) are treated as the real and
            imaginary parts of one complex number.
        complex_rotations: Complex tensor of shape (T, C/2) holding the
            rotation factors; it is broadcast over the leading dimensions.

    Returns:
        Real tensor with the same shape as `q`. Half-precision inputs
        (float16 / bfloat16) are rotated in float32 and cast back.
    """
    # torch.view_as_complex is documented for float32/float64 only, so
    # half-precision activations are upcast for the rotation.
    needs_upcast = q.dtype in (torch.float16, torch.bfloat16)
    q_real = q.float() if needs_upcast else q

    # (..., T, C) -> (..., T, C/2, 2) -> complex (..., T, C/2)
    q_pairs = q_real.reshape(*q_real.shape[:-1], -1, 2)
    q_rotated = torch.view_as_complex(q_pairs) * complex_rotations

    # Back to interleaved real channels: (..., T, C/2, 2) -> (..., T, C)
    q_rotated = torch.view_as_real(q_rotated).flatten(-2)
    return q_rotated.type_as(q) if needs_upcast else q_rotated

In [None]:
class SelfAttention(nn.Module):
    """One attention head with rotary position encodings on queries and keys."""

    def __init__(self, emb_dim=64, head_dim=64):
        super().__init__()
        self.head_dim = head_dim
        # Bias-free projections from the embedding space to the head space.
        self.Wq = nn.Linear(emb_dim, head_dim, bias=False)
        self.Wk = nn.Linear(emb_dim, head_dim, bias=False)
        self.Wv = nn.Linear(emb_dim, head_dim, bias=False)

    def forward(self, x, complex_rotations, mask):
        """Attend over the sequence.

        Args:
            x: Embeddings of shape (B, T, emb_dim).
            complex_rotations: RoPE factors, multiplied with the queries and
                keys viewed as (B, T, head_dim / 2) complex numbers.
            mask: Additive mask added to the (T, T) attention logits before
                the softmax (-inf entries remove a connection).

        Returns:
            Tensor of shape (B, T, head_dim).
        """
        # Position information is injected into queries and keys only.
        queries = apply_rope(self.Wq(x), complex_rotations)
        keys = apply_rope(self.Wk(x), complex_rotations)
        values = self.Wv(x)

        # Scaled dot-product logits, shape (B, T, T).
        logits = queries @ keys.mT / math.sqrt(self.head_dim)

        weights = F.softmax(logits + mask, dim=-1)
        return weights @ values

# Shape check on random data: batch of 4, 6 timesteps, 8 embedding channels.
B, T, C = 4, 6, 8
x = torch.randn(B, T, C)
# The head dimension is C//2 here, so rotations are computed for that size.
complex_rotations = compute_complex_rotations(T, C//2)
# Causal mask: -inf strictly above the diagonal, 0 elsewhere.
mask = torch.full((T, T), float("-inf"))
mask = torch.triu(mask, diagonal=1)
sa = SelfAttention(C, C//2)
# Displays the output shape of the head.
sa(x, complex_rotations, mask).shape

torch.Size([4, 6, 4])

In [None]:
class FeedForward(nn.Module):
    """SwiGLU feed-forward network: fc3(silu(fc1(x)) * fc2(x)), all bias-free.

    Args:
        emb_dims: Size of the input and output features.
        hidden_dims: Size of the gated hidden representation.
    """

    def __init__(self, emb_dims: int, hidden_dims: int):
        super().__init__()
        self.fc1 = nn.Linear(emb_dims, hidden_dims, bias=False)  # gate branch
        self.fc2 = nn.Linear(emb_dims, hidden_dims, bias=False)  # value branch
        self.fc3 = nn.Linear(hidden_dims, emb_dims, bias=False)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to the last dimension of `x`.

        Args:
            x: Tensor of shape (..., emb_dims).

        Returns:
            Tensor of shape (..., emb_dims).
        """
        gate = F.silu(self.fc1(x))  # silu(z) = z * sigmoid(z)
        hidden = self.fc2(x) * gate
        return self.fc3(hidden)

In [None]:
from einops import reduce

class RMSNorm(nn.Module):
    """Root-mean-square layer normalization with a learned per-channel gain.

    Args:
        emb_dim: Size of the last (normalized) dimension.
        eps: Constant added to the mean square before the inverse square root.
    """

    def __init__(self, emb_dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(emb_dim))

    def forward(self, x):
        """Scale `x` to unit RMS over its last dimension, then apply the gain.

        Args:
            x: Tensor of shape (..., emb_dim).

        Returns:
            Tensor of shape (..., emb_dim).
        """
        # The whole normalization (squares, mean, rsqrt AND the scaling
        # product) runs in fp32 so half-precision inputs neither overflow nor
        # underflow; only the normalized result is cast back to x's dtype.
        x_fp32 = x.to(torch.float32)
        mean_square = x_fp32.pow(2).mean(dim=-1, keepdim=True)
        normalized = x_fp32 * torch.rsqrt(mean_square + self.eps)
        return self.weight * normalized.type_as(x)

# Smoke test on random data; the cell displays the first two batch elements.
x = torch.randn(4, 6, 8).to(device)
RMSNorm(8).to(device)(x)[:2]

tensor([[[-0.3956, -1.6693, -0.1431,  0.4340,  0.0379,  0.8714, -1.9350,
           0.5856],
         [-0.1804, -0.5587, -0.3852, -1.2573, -0.3222,  2.3822, -0.0231,
          -0.3837],
         [ 1.8508,  1.6662,  0.6516,  0.1083, -0.1892, -0.3640, -0.0362,
           1.0920],
         [-0.5574,  0.3649,  1.8399,  0.1525,  0.5354,  1.0126,  0.2384,
          -1.6670],
         [-1.4708,  1.0690,  0.3660, -0.1334, -1.3622,  0.4177, -1.4851,
          -0.5541],
         [ 1.4139,  1.4970,  0.4411,  0.2976, -0.0950,  0.6230, -1.7535,
           0.0713]],

        [[ 0.4707, -0.7305, -1.0720, -0.5796, -0.5046,  1.0500, -1.8574,
          -0.9761],
         [ 0.3643, -0.1469,  0.2433,  1.3644, -0.2506, -0.7742,  2.1016,
          -0.9196],
         [ 1.3076, -0.0627,  1.2623, -0.3011,  0.3594, -0.7715, -0.2176,
           1.9572],
         [-0.5783,  0.3152, -0.0031, -1.3697, -2.3325, -0.1721, -0.3228,
           0.3401],
         [ 0.3280, -1.3511,  1.7302, -0.8761, -0.9936, -0.6455,  0.4

In [None]:
class MultiHeadSelfAttentionBlock(nn.Module):
    """Pre-norm Transformer block: multi-head self-attention, then a SwiGLU FFN.

    Each sub-layer reads an RMS-normalized copy of the residual stream and adds
    its output back onto the un-normalized stream.

    Args:
        emb_dim: Size of the residual stream; must be divisible by `n_heads`.
        n_heads: Number of attention heads, each of size emb_dim // n_heads.
    """

    def __init__(self, emb_dim=64, n_heads=4):
        super().__init__()
        assert emb_dim % n_heads == 0
        self.head_dim = emb_dim // n_heads

        self.att_norm = RMSNorm(emb_dim)
        # Build one module per head. `[module] * n_heads` would put the SAME
        # instance in the list n_heads times, i.e. a single head with shared
        # weights whose output is merely repeated.
        self.heads = nn.ModuleList(
            [SelfAttention(emb_dim, self.head_dim) for _ in range(n_heads)]
        )
        # Mixes the concatenated head outputs back into the residual stream.
        self.projection = nn.Linear(emb_dim, emb_dim)

        self.ffn_norm = RMSNorm(emb_dim)
        self.feed_forward = FeedForward(emb_dim, 4*emb_dim)

    def forward(self, x, complex_rotations, mask):
        """Run the block.

        Args:
            x: Residual stream of shape (B, T, emb_dim).
            complex_rotations: RoPE factors forwarded to every head.
            mask: Additive attention mask forwarded to every head.

        Returns:
            Tensor of shape (B, T, emb_dim).
        """
        # Attention sub-layer: normalize, attend, project, add to the
        # ORIGINAL x (not to the normalized copy).
        normed = self.att_norm(x)
        heads_out = torch.cat(
            [head(normed, complex_rotations, mask) for head in self.heads], dim=-1
        )
        x = x + self.projection(heads_out)

        # Feed-forward sub-layer, same pre-norm residual pattern.
        x = x + self.feed_forward(self.ffn_norm(x))
        return x

# Test: shape check of the full block on random data.
B, T, C = 4, 6, 8
n_heads = 2

x = torch.randn(B, T, C)
# Rotations are sized for one head (C // n_heads channels).
complex_rotations = compute_complex_rotations(T, C // n_heads)
# Causal mask: -inf strictly above the diagonal, 0 elsewhere.
mask = torch.full((T, T), float("-inf"))
mask = torch.triu(mask, diagonal=1)

mhsa = MultiHeadSelfAttentionBlock(C, n_heads)
# Displays the output shape of the block.
mhsa(x, complex_rotations, mask).shape

torch.Size([4, 6, 8])

In [None]:
class Guanaco(nn.Module):
    """Toy Llama-style decoder mapping token ids to next-token logits.

    Args:
        vocab_size: Number of distinct token ids.
        emb_dim: Size of the token embeddings / residual stream.
        n_heads: Attention heads per block.
        n_layers: Number of stacked attention blocks.
        max_len: Longest sequence the model attends over; longer inputs are
            cropped to their last `max_len` tokens.
    """

    def __init__(self, vocab_size=3, emb_dim=64, n_heads=4, n_layers=2, max_len=128):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, emb_dim)
        self.emb_dim = emb_dim
        self.max_len = max_len

        # Build one block per layer. `[block] * n_layers` would reuse a single
        # instance, so every "layer" would share the same weights.
        self.blocks = nn.ModuleList(
            [MultiHeadSelfAttentionBlock(emb_dim, n_heads) for _ in range(n_layers)]
        )

        self.output_norm = RMSNorm(emb_dim)
        self.output_layer = nn.Linear(emb_dim, vocab_size)

        # Registered as buffers so they follow the module across devices.
        complex_rotations = compute_complex_rotations(max_len, emb_dim // n_heads)
        self.register_buffer('complex_rotations', complex_rotations)

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere.
        mask = torch.triu(torch.full((max_len, max_len), float("-inf")), diagonal=1)
        self.register_buffer('mask', mask)

    def forward(self, token_ids):
        """Compute next-token logits.

        Args:
            token_ids: Integer tensor of shape (B, T).

        Returns:
            Logits of shape (B, min(T, max_len), vocab_size).
        """
        # Keep only the most recent max_len tokens.
        token_ids = token_ids[:, -self.max_len:]
        B, T = token_ids.shape

        # Slice the precomputed tables down to the current sequence length.
        complex_rotations = self.complex_rotations[:T, :]
        mask = self.mask[:T, :T]

        x = self.embeddings(token_ids) # (B,T,C)
        for block in self.blocks:
            x = block(x, complex_rotations, mask) # (B,T,C)
        x = self.output_norm(x) # (B,T,C)
        logits = self.output_layer(x) # (B,T,V)
        return logits

# Smoke test with the default 3-token vocabulary: one sequence of three ids.
model = Guanaco().to(device)
x = torch.tensor([[1, 2, 0]]).to(device)
# Displays the (1, 3, 3) logits tensor.
model(x)


tensor([[[ 0.8426,  0.6673,  0.6552],
         [-0.4229, -0.5656,  0.1760],
         [ 0.8213, -0.3349,  0.9779]]], device='cuda:0',
       grad_fn=<ViewBackward0>)

In [None]:
class LightningModel(L.LightningModule):
    """Lightning wrapper that trains a language model on next-token prediction.

    Args:
        model: Module mapping token ids (B, T) to logits (B, T, V).
        vocab: Vocabulary object; stored on the module but not read by any
            method of this class.
    """

    def __init__(self, model, vocab):
        super().__init__()
        self.model = model
        self.vocab = vocab

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's logits for `inputs`."""
        return self.model(inputs)

    def training_step(self, tokens: torch.Tensor) -> torch.Tensor:
        """Compute and log the next-token cross-entropy for one batch.

        Args:
            tokens: Integer tensor of shape (B, T) of token ids.

        Returns:
            Scalar loss tensor.
        """
        # Shift by one: position t is trained to predict token t + 1.
        inputs = tokens[:, :-1]
        targets = tokens[:, 1:]

        logits = self.forward(inputs)

        # cross_entropy expects (N, C) logits and (N,) class indices.
        logits = rearrange(logits, 'B T C -> (B T) C')
        targets = rearrange(targets, 'B T -> (B T)')

        loss = F.cross_entropy(logits, targets)
        self.log(
            "train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
        )
        return loss

    def configure_optimizers(self):
        """Use AdamW with a fixed learning rate of 1e-3 over all parameters."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=1e-3)
        return optimizer

## TinyStories dataset

In [None]:
from datasets import load_dataset

# Load the TinyStories dataset from the Hugging Face Hub.
dataset = load_dataset("roneneldan/TinyStories")
# Ask the dataset to return the "text" column in torch format.
dataset.set_format(type="torch", columns=["text"])
# Displays the available splits and their row counts.
dataset

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Repo card metadata block was not found. Setting CardData to empty.


DatasetDict({
    train: Dataset({
        features: ['text'],
        num_rows: 2119719
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 21990
    })
})

### Tokenization

#### Custom byte-level tokenization

In [None]:
def str_to_code_units(text: str) -> list[int]:
  """Encodes a string as UTF-8 and returns its bytes as a list of integers (0-255)."""
  encoded = text.encode('utf-8')
  return list(encoded)

def code_units_to_str(token_ids: list[int]) -> str:
  """Decodes a list of byte values as UTF-8; invalid sequences become U+FFFD."""
  raw = bytes(token_ids)
  return raw.decode('utf-8', errors="replace")

def bytes_to_bits(integers: list[int]) -> list[str]:
  """Converts a list of integers to binary strings, left-padded with zeros to 8 digits."""
  bit_strings = []
  for value in integers:
    digits = bin(value)[2:]  # strip the "0b" prefix
    bit_strings.append(digits.zfill(8))
  return bit_strings


# Show how ASCII, accented and CJK text map to one, two and three bytes per character.
for example in ["e", "é", "落語家", "Hello"]:
    tokens = str_to_code_units(example)
    print(f"{example} -> {tokens} (binary representation {bytes_to_bits(tokens)})")

e -> [101] (binary representation ['01100101'])
é -> [195, 169] (binary representation ['11000011', '10101001'])
落語家 -> [232, 144, 189, 232, 170, 158, 229, 174, 182] (binary representation ['11101000', '10010000', '10111101', '11101000', '10101010', '10011110', '11100101', '10101110', '10110110'])
Hello -> [72, 101, 108, 108, 111] (binary representation ['01001000', '01100101', '01101100', '01101100', '01101111'])


In [None]:
# Character-level vocabulary: space, lowercase letters, comma and period.
# NOTE: the byte-level tokenizer defined below does not use these tables.
vocab = " abcdefghijklmnopqrstuvwxyz,."

# Index 0 is <pad>, index 1 is <unk>.
special_tokens = ["<pad>", "<unk>"]

vocab = special_tokens + list(vocab)
vocab_size = len(vocab)

# Lookup tables in both directions.
string_to_index = {string: i for i, string in enumerate(vocab)}
index_to_string = {i: string for i, string in enumerate(vocab)}

In [None]:
def tokenize(text: str) -> list[int]:
    """Convert text to token ids using byte-level (UTF-8) tokenization."""
    # Character-level alternative (unknown characters map to index 1):
    #   [string_to_index.get(string, 1) for string in text]
    return str_to_code_units(text)

def detokenize(token_ids: list[int]) -> str:
    """Convert byte-level token ids back to text."""
    # Character-level alternative:
    #   "".join([index_to_string[idx] for idx in token_ids])
    return code_units_to_str(token_ids)

# Round trip: "é" takes two bytes, so this 3-character string yields 4 tokens.
print(tokenize("é a"))
detokenize(tokenize("é a"))

[195, 169, 32, 97]


'é a'

In [None]:
import torch
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence

# Maximum number of tokens (bytes) kept per story.
context_length = 128
# Number of stories per batch.
batch_size= 1024

def collate_fn(batch):
    """Tokenize, truncate and pad a batch of dataset rows.

    Args:
        batch: Iterable of rows, each providing its story under the "text" key.

    Returns:
        Long tensor of shape (len(batch), L), where L is the longest tokenized
        story in the batch after truncation to `context_length` tokens. Shorter
        stories are right-padded with token id 0.
    """
    token_ids = [
        # An explicit dtype matters for empty stories: torch.tensor([]) is
        # float32, which would otherwise leak a float dtype into the batch.
        torch.tensor(tokenize(example["text"])[:context_length], dtype=torch.long)
        for example in batch
    ]
    return pad_sequence(token_ids, batch_first=True, padding_value=0)

# Training batches are reshuffled every epoch; worker processes are kept alive
# between epochs (persistent_workers) and batches are placed in pinned memory.
train_dataloader = DataLoader(
    dataset["train"], batch_size=batch_size, shuffle=True, collate_fn=collate_fn,
    num_workers=4, pin_memory=True, persistent_workers=True
)
# Validation batches keep the dataset order.
val_dataloader = DataLoader(
    dataset["validation"], batch_size=batch_size, shuffle=False, collate_fn=collate_fn,
    num_workers=4,
)
# Displays the first sequence of one training batch.
next(iter(train_dataloader))[0]

tensor([ 79, 110,  99, 101,  32, 117, 112, 111, 110,  32,  97,  32, 116, 105,
        109, 101,  44,  32, 116, 104, 101, 114, 101,  32, 119,  97, 115,  32,
         97,  32,  99, 108, 101, 118, 101, 114,  32,  98, 117, 110, 110, 121,
         32, 110,  97, 109, 101, 100,  32,  66, 101, 110, 110, 121,  46,  32,
         72, 101,  32, 108, 111, 118, 101, 100,  32, 116, 111,  32, 104, 111,
        112,  32,  97, 114, 111, 117, 110, 100,  32, 105, 110,  32, 116, 104,
        101,  32, 115, 112, 114, 105, 110, 103, 116, 105, 109, 101,  44,  32,
        119, 104, 101, 110,  32, 116, 104, 101,  32, 102, 108, 111, 119, 101,
        114, 115,  32, 119, 101, 114, 101,  32,  98, 108, 111, 111, 109, 105,
        110, 103])

## Exercise: train our toy Llama model (Guanaco) on the TinyStories dataset

- (make sure you're using a GPU)
- Create a training and a validation set. Train the small model. Do you see indications of learning? Overfitting?
- Increase the parameter count: number of layers, embedding dimensions, ... What is the impact on performance?
- Run inference with your model on different inputs

In [15]:
class GuanacoModule(L.LightningModule):
    """Lightning wrapper for next-token-prediction training and validation.

    Args:
        model: Module mapping token ids (B, T) to logits (B, T, V).
        learning_rate: Learning rate passed to AdamW.
    """

    def __init__(self, model, learning_rate=1e-4):
        super().__init__()
        self.model = model
        self.learning_rate = learning_rate

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return the wrapped model's logits for `inputs`."""
        return self.model(inputs)

    def _next_token_loss(self, tokens: torch.Tensor) -> torch.Tensor:
        """Cross-entropy between the logits at position t and the token at t + 1."""
        inputs = tokens[:, :-1]
        targets = tokens[:, 1:]

        logits = self.forward(inputs)

        # cross_entropy expects (N, C) logits and (N,) class indices.
        logits = rearrange(logits, 'B T C -> (B T) C')
        targets = rearrange(targets, 'B T -> (B T)')

        return F.cross_entropy(logits, targets)

    def training_step(self, tokens: torch.Tensor) -> torch.Tensor:
        """Compute the loss of one training batch and log it as "train_loss".

        Args:
            tokens: Integer tensor of shape (B, T) of token ids.

        Returns:
            Scalar loss tensor.
        """
        loss = self._next_token_loss(tokens)
        self.log(
            "train_loss", loss, prog_bar=True, logger=True, on_step=True, on_epoch=True
        )
        return loss

    def validation_step(self, tokens: torch.Tensor) -> torch.Tensor:
        """Compute the loss of one validation batch and log it as "val_loss".

        The loss is computed directly rather than through `training_step`, so
        validation batches are never recorded under "train_loss".
        """
        loss = self._next_token_loss(tokens)
        self.log(
            "val_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            logger=True,
            sync_dist=True,
        )
        return loss

    def configure_optimizers(self):
        """Use AdamW over all parameters with the configured learning rate."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

# vocab_size=256 covers every possible byte value produced by the byte-level
# tokenizer; max_len matches the 128-token truncation used when batching.
model = GuanacoModule(
    Guanaco(vocab_size=256, emb_dim=512, n_heads=1, n_layers=1, max_len=128),
    learning_rate=1e-3
)

# Train for two epochs on a single device.
trainer = L.Trainer(max_epochs=2, devices=1)
# Development trick: use overfit_batches=0.01 to make sure you can overfit small samples
trainer.fit(model, train_dataloader, val_dataloader)

INFO: GPU available: True (cuda), used: True
INFO:lightning.pytorch.utilities.rank_zero:GPU available: True (cuda), used: True
INFO: TPU available: False, using: 0 TPU cores
INFO:lightning.pytorch.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO: IPU available: False, using: 0 IPUs
INFO:lightning.pytorch.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO: HPU available: False, using: 0 HPUs
INFO:lightning.pytorch.utilities.rank_zero:HPU available: False, using: 0 HPUs
INFO: You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:lightning.pytorch.utilities.rank_zero:You are using a CUDA device ('NVIDIA A100-SXM4-40GB') that has Tensor Cores. To properly 

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

  self.pid = os.fork()


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO: `Trainer.fit` stopped: `max_epochs=2` reached.
INFO:lightning.pytorch.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=2` reached.


## Generate stories with trained model

In [18]:
def generate(model, x: str, n_tokens: int = 5, device="cuda"):
    """Extend the prompt `x` by `n_tokens` tokens using greedy decoding.

    Args:
        model: Module mapping token ids (B, T) to logits (B, T, V). It is moved
            to `device`; its train/eval mode is restored before returning.
        x: Prompt text; must tokenize to at least one token.
        n_tokens: Number of tokens to append.
        device: Device on which generation runs.

    Returns:
        The prompt followed by the generated continuation, as a string.

    Raises:
        ValueError: If the prompt is empty.
    """
    prompt_ids = tokenize(x)
    if not prompt_ids:
        raise ValueError("generate() needs a non-empty prompt")

    token_ids = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(0)
    model = model.to(device)

    was_training = model.training
    model.eval()
    try:
        # Inference only: without no_grad every step would keep its autograd
        # graph alive, growing memory with the number of generated tokens.
        with torch.no_grad():
            for _ in range(n_tokens):
                logits = model(token_ids)[:, -1, :]  # Next-token logits (B, V)
                next_tokens = logits.argmax(dim=-1, keepdim=True)  # Most likely id (B, 1)
                token_ids = torch.cat((token_ids, next_tokens), dim=1)
    finally:
        model.train(was_training)

    return detokenize(token_ids[0].tolist())

# Greedy continuation of a classic story opening: 200 generated bytes.
generate(model, "Once upon a time", n_tokens=200)

'Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big red ball with her mommy and daddy. One day, Lily wanted to play outside and said her to help her'