# Exercise 6

## Group
- **ID**: 5

- **Members**:
    - Hasan Algafri
    - Emre Dursunluer
    - Taha El Amine Kassabi

## Hand-in
- Please hand in this notebook with your code implementation via Ilias 
- Please make sure that there is exactly **one** submission per group

## Task Description

In this exercise, you will implement a custom Extended Long Short-Term Memory (xLSTM) model to predict the next tokens given an input sequence. The model is described in the paper [xLSTM: Extended Long Short-Term Memory](https://arxiv.org/abs/2405.04517).

You will work with the “Tiny Shakespeare” dataset, a character-level corpus of Shakespeare’s plays and sonnets, commonly used for next-character prediction. The dataset is available on [GitHub](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).

You will implement a custom character-level tokenizer and DataLoader, write your custom model (split into several classes) and train it, plot the perplexity and the loss curve, and finally showcase input–output text samples from your trained xLSTM.

**NEW**:
We provide some of the mLSTM and sLSTM code, as illustrated in Figures 10 and 11 of the xLSTM paper. For this part, you only need to implement the mLSTMCell and sLSTMCell classes (the gray boxes shown in those figures) and integrate them with the rest of the code. You’re free to modify any part of the provided code.
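
For orientation, the recurrences that the two cells are meant to implement can be summarized as follows. This is our reading of the paper's formulation, not a verbatim copy: tildes mark gate pre-activations, primes mark the stabilized gates, and $m_t$ is the stabilizer state. The paper also describes per-head gating and key scaling, which are omitted here.

```latex
\begin{aligned}
\text{Stabilizer:}\quad
  & m_t = \max\bigl(\log f_t + m_{t-1},\ \log i_t\bigr), \\
  & i'_t = \exp\bigl(\log i_t - m_t\bigr), \qquad
    f'_t = \exp\bigl(\log f_t + m_{t-1} - m_t\bigr) \\[4pt]
\text{sLSTM:}\quad
  & z_t = \tanh(\tilde z_t), \qquad o_t = \sigma(\tilde o_t), \\
  & c_t = f'_t\, c_{t-1} + i'_t\, z_t, \qquad
    n_t = f'_t\, n_{t-1} + i'_t, \qquad
    h_t = o_t \odot \frac{c_t}{n_t} \\[4pt]
\text{mLSTM:}\quad
  & C_t = f'_t\, C_{t-1} + i'_t\, v_t k_t^{\top}, \qquad
    n_t = f'_t\, n_{t-1} + i'_t\, k_t, \\
  & h_t = o_t \odot \frac{C_t\, q_t}{\max\bigl(\lvert n_t^{\top} q_t \rvert,\ 1\bigr)}
\end{aligned}
```

With exponential gates, $\log i_t = \tilde i_t$ and $\log f_t = \tilde f_t$, so the stabilizer can be computed directly from the gate logits.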

## Grading scheme
Total: 5 points
1. **Preparing the Tokenizer and Dataloader** (1 point)
2. **Preparing the Model** (2.5 points)
3. **Train the Model** (1 point)
4. **Showcasing plots and a few input & output examples** (0.5 points)

### Imports

In [1]:
from collections import defaultdict

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm, trange

# MPS availability is queried through torch.backends.mps
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

In [None]:
os.makedirs("../models/", exist_ok=True)
os.makedirs("../data/", exist_ok=True)

### **Preparing the Tokenizer and Dataloader** (1 point)

In [2]:
!wget -nc -O ../data/tinyshakespeare.txt https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

File ‘../data/tinyshakespeare.txt’ already there; not retrieving.


In [3]:
class CharTokenizer:
    '''
    Simple character level tokenizer; maps each unique character to an integer index.
    '''

    def __init__(self, text):
        chars = sorted(set(text))
        self.vocab_size = len(chars)
        self.char_to_tok = {ch: i for i, ch in enumerate(chars)}
        self.tok_to_char = {i: ch for ch, i in self.char_to_tok.items()}

    def __len__(self):
        return self.vocab_size

    def __call__(self, text):
        return self.encode(text)

    def encode(self, text):
        return [self.char_to_tok[ch] for ch in text]

    def decode(self, tokens):
        # Join the characters so that decode() is the inverse of encode()
        return "".join(self.tok_to_char[tok] for tok in tokens)


class CharTokenizedText(Dataset):
    def __init__(self, text, tokenizer, seq_len=128):
        self.text = text
        self.seq_len = seq_len
        self.data = torch.tensor(tokenizer(text), dtype=torch.long)

    def __len__(self):
        return len(self.text) - self.seq_len

    def __getitem__(self, idx):
        x = self.data[idx: idx + self.seq_len]
        y = self.data[idx + 1: idx + self.seq_len + 1]
        return x, y


def load_data(train_proportion=0.8, batch_size=128, seq_len=128):
    with open("../data/tinyshakespeare.txt", "r", encoding="utf-8") as f:
        text = f.read()

    tokenizer = CharTokenizer(text)

    ds = lambda txt: CharTokenizedText(txt, tokenizer, seq_len)
    dl = lambda ds, shuffle=False: DataLoader(ds, batch_size=batch_size, shuffle=shuffle)

    train_chars = int(len(text) * train_proportion)
    train_text, val_text = text[:train_chars], text[train_chars:]
    train_ds, val_ds = ds(train_text), ds(val_text)
    return dl(train_ds, True), dl(val_ds), tokenizer
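
The dataset above pairs each input window with the same window shifted by one character. The following minimal sketch illustrates that indexing on a hypothetical toy string (`text`, `window`, and `num_windows` are illustrative names, not part of the notebook code):

```python
import torch

# Hypothetical toy corpus; the notebook itself uses Tiny Shakespeare.
text = "hello world"
chars = sorted(set(text))
char_to_tok = {ch: i for i, ch in enumerate(chars)}

data = torch.tensor([char_to_tok[ch] for ch in text], dtype=torch.long)
seq_len = 4


def window(idx):
    # Input is seq_len tokens; the target is the same window shifted by one.
    x = data[idx: idx + seq_len]
    y = data[idx + 1: idx + seq_len + 1]
    return x, y


# Number of valid start indices; the last one is num_windows - 1.
num_windows = len(data) - seq_len
x, y = window(num_windows - 1)

# Every target token is the input token one step later.
print(torch.equal(x[1:], y[:-1]))
print(x.shape, y.shape)
```

Even for the last valid index, the target window still has the full length, which is why the dataset length is the text length minus `seq_len`.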

### **Preparing the Model** (2.5 points)

#### components

In [4]:
class BlockDiagonalProj(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(BlockDiagonalProj, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.out_head_size = input_dim // num_heads
        self.weight = nn.Parameter(torch.empty(num_heads, self.out_head_size, input_dim // num_heads))
        nn.init.normal_(
            self.weight.data,
            mean=0.0,
            std=(2.0 / 5.0 / self.weight.shape[-1]) ** .5
        )

    def forward(self, x):
        shape = x.shape
        x = x.view(*shape[:-1], self.num_heads, -1)
        x = torch.einsum("...hd,hod->...ho", x, self.weight)
        x = x.reshape(*shape[:-1], -1)
        return x


class CasualConv1d(nn.Module):
    def __init__(self, feature_dim, kernel_size, bias=True):
        super(CasualConv1d, self).__init__()
        self.pad = (kernel_size - 1)
        self.conv = nn.Conv1d(in_channels=feature_dim, out_channels=feature_dim, kernel_size=kernel_size, padding=self.pad, groups=feature_dim, bias=bias)

    def forward(self, x):
        y = x.transpose(2, 1)
        y = self.conv(y)
        return y[:, :, : -self.pad].transpose(2, 1)
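
A quick way to convince yourself that the padded-then-trimmed convolution is causal is to perturb only future positions and check that earlier outputs stay the same. The sketch below rebuilds the same padding scheme in a standalone function (`causal_conv` and the chosen sizes are assumptions for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
feature_dim, kernel_size, seq_len = 8, 4, 16
pad = kernel_size - 1

# Depthwise convolution padded on both sides; the right-hand padding is cut off afterwards.
conv = nn.Conv1d(feature_dim, feature_dim, kernel_size, padding=pad, groups=feature_dim)


def causal_conv(x):
    # x: (batch, seq, features) -> (batch, seq, features)
    y = conv(x.transpose(2, 1))
    return y[:, :, :-pad].transpose(2, 1)


x = torch.randn(2, seq_len, feature_dim)
x_future = x.clone()
t = 10
x_future[:, t:, :] += 1.0  # perturb only positions >= t

with torch.no_grad():
    y, y_future = causal_conv(x), causal_conv(x_future)

# Outputs before t must not depend on inputs at or after t.
past_unchanged = torch.allclose(y[:, :t], y_future[:, :t])
future_changed = not torch.allclose(y[:, t:], y_future[:, t:])
print(past_unchanged, future_changed)
```

Both flags are expected to be `True`. Note that the slice `[:-pad]` only works for `kernel_size > 1`, since a padding of zero would produce an empty slice.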


#### mLSTM block

In [5]:
### COMPLETE THIS CLASS ####
class mLSTMCell(nn.Module):
    def __init__(self, input_dim):
        super(mLSTMCell, self).__init__()
        self.i = nn.Linear(input_dim, input_dim)
        self.f = nn.Linear(input_dim, input_dim)
        self.o = nn.Linear(input_dim, input_dim)

    def forward(self, q, k, v):
        B, S, D = q.shape
        h = torch.empty_like(q)
        C = torch.zeros(B, D, D, device=q.device, dtype=q.dtype)
        n = torch.zeros(B, D, device=q.device, dtype=q.dtype)
        m_prev = torch.zeros_like(n)

        i = self.i(k)
        f = self.f(k)
        o = self.o(k).sigmoid()

        for t in range(S):
            q_t = q[:, t, :]  # (B, D)
            k_t = k[:, t, :]
            v_t = v[:, t, :]

            i_t = i[:, t, :]
            f_t = f[:, t, :]
            o_t = o[:, t, :]

            m_t = torch.maximum(f_t + m_prev, i_t)

            i_t = (i_t - m_t).exp()
            f_t = (f_t + m_prev - m_t).exp()
            n = f_t * n + i_t * k_t
            C = f_t.unsqueeze(-1) * C + i_t.unsqueeze(-1) * (v_t.unsqueeze(-1) @ k_t.unsqueeze(-2))

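            # Normalize by max(|n^T q|, 1) as in the paper; abs() prevents a negative dot product from being clamped to 1
            denom = (n.unsqueeze(-2) @ q_t.unsqueeze(-1)).abs().clamp(min=1)
            h_t_raw = C @ q_t.unsqueeze(-1) / denom
            h[:, t, :] = o_t * h_t_raw.squeeze(-1)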

            m_prev = m_t

        return h


#############################

class mLSTMLayer(nn.Module):
    def __init__(self, embedding_dim, proj_blocksize, bias=False):
        super(mLSTMLayer, self).__init__()
        self.outer_embedding_dim = embedding_dim
        self.inner_embedding_dim = 2 * embedding_dim
        self.proj_blocksize = proj_blocksize
        self.bias = bias

        self.proj_up = nn.Linear(in_features=self.outer_embedding_dim,
                                 out_features=2 * self.inner_embedding_dim,
                                 bias=bias)
        self.num_proj_heads = self.inner_embedding_dim // proj_blocksize
        self.q_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_proj_heads)
        self.k_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_proj_heads)
        self.v_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_proj_heads)

        self.conv1d = CasualConv1d(feature_dim=self.inner_embedding_dim, kernel_size=4)
        self.conv_swish = nn.SiLU()

        ############################     EDIT      ##################################
        self.mlstm_cell = mLSTMCell(self.inner_embedding_dim)
        ##############################################################

        self.ogate_swish = nn.SiLU()
        self.learnable_skip_con = nn.Parameter(torch.ones(self.inner_embedding_dim, requires_grad=True))
        self.proj_down = nn.Linear(in_features=self.inner_embedding_dim,
                                   out_features=self.outer_embedding_dim,
                                   bias=bias)

    def forward(self, x):
        B, S, _ = x.shape
        x_ = F.layer_norm(x, normalized_shape=(self.outer_embedding_dim,))
        x_inner = self.proj_up(x_)
        x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.inner_embedding_dim, dim=-1)
        x_mlstm_conv = self.conv1d(x_mlstm)
        x_mlstm_conv_act = self.conv_swish(x_mlstm_conv)

        q = self.q_proj(x_mlstm_conv_act)
        k = self.k_proj(x_mlstm_conv_act)
        v = self.v_proj(x_mlstm)

        ##########################     EDIT      ####################################
        mlstm_cell_state = self.mlstm_cell(q, k, v)
        ##############################################################

        mlstm_cell_skip = mlstm_cell_state + (self.learnable_skip_con * x_mlstm_conv_act)

        h_state = mlstm_cell_skip * self.ogate_swish(z)

        y = self.proj_down(h_state) + x

        return y


#### sLSTM block

In [6]:
### COMPLETE THIS CLASS ####
class sLSTMCell(nn.Module):
    def __init__(self, input_dim, num_heads):
        super(sLSTMCell, self).__init__()
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.head_dim = input_dim // num_heads

        self.r_z = nn.Linear(input_dim, input_dim, bias=False)
        self.r_i = nn.Linear(input_dim, input_dim, bias=False)
        self.r_f = nn.Linear(input_dim, input_dim, bias=False)
        self.r_o = nn.Linear(input_dim, input_dim, bias=False)

    def forward(self, i_logits, f_logits, z_logits, o_logits):
        B, S, D = i_logits.shape
        NH, DH = self.num_heads, self.head_dim

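        # (B, S, D) -> (B, S, NH, DH) -> (B, NH, S, DH); viewing directly as (B, NH, S, DH) would mix time steps and heads
        i_logits = i_logits.view(B, S, NH, DH).transpose(1, 2)
        f_logits = f_logits.view(B, S, NH, DH).transpose(1, 2)
        z_logits = z_logits.view(B, S, NH, DH).transpose(1, 2)
        o_logits = o_logits.view(B, S, NH, DH).transpose(1, 2)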

        c = torch.zeros(B, NH, DH, device=i_logits.device, dtype=i_logits.dtype)
        n = torch.zeros_like(c)
        m_prev = torch.zeros_like(c)
        h = torch.zeros_like(i_logits)
        h_t_prev = h[:, :, 0, :].reshape(B, -1)

        for t in range(S):
            r_z_t = self.r_z(h_t_prev).view(B, NH, DH)
            r_i_t = self.r_i(h_t_prev).view(B, NH, DH)
            r_f_t = self.r_f(h_t_prev).view(B, NH, DH)
            r_o_t = self.r_o(h_t_prev).view(B, NH, DH)

            # The cell input uses tanh and the output gate uses sigmoid, as in the xLSTM paper
            z_t = (z_logits[:, :, t, :] + r_z_t).tanh()
            o_t = (o_logits[:, :, t, :] + r_o_t).sigmoid()

            i_t = i_logits[:, :, t, :] + r_i_t
            f_t = f_logits[:, :, t, :] + r_f_t
            m_t = torch.maximum(f_t + m_prev, i_t)

            i_t = (i_t - m_t).exp()
            f_t = (f_t + m_prev - m_t).exp()

            c = f_t * c + i_t * z_t
            n = f_t * n + i_t

            h_t_raw = c / n
            h[:, :, t, :] = o_t * h_t_raw
            h_t_prev = h[:, :, t, :].reshape(B, -1)

            m_prev = m_t

        return h


#############################

class sLSTMLayer(nn.Module):
    def __init__(self, embedding_dim, proj_blocksize, conv_block=True, bias=False):
        super(sLSTMLayer, self).__init__()
        self.inner_embedding_dim = embedding_dim
        self.proj_blocksize = proj_blocksize
        self.conv_block = conv_block
        if conv_block:
            self.conv1d = CasualConv1d(feature_dim=self.inner_embedding_dim, kernel_size=4)
            self.conv_swish = nn.SiLU()

        self.i_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=4)
        self.f_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=4)
        self.z_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=4)
        self.o_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=4)

        ##############################     EDIT      ################################
        self.slstm_cell = sLSTMCell(self.inner_embedding_dim, num_heads=4)
        ##############################################################

        self.up_proj1 = nn.Linear(in_features=self.inner_embedding_dim, out_features=int((4 / 3) * self.inner_embedding_dim), bias=bias)
        self.up_proj2 = nn.Linear(in_features=self.inner_embedding_dim, out_features=int((4 / 3) * self.inner_embedding_dim), bias=bias)
        self.up_proj2_gelu = nn.GELU()

        self.down_proj = nn.Linear(in_features=int((4 / 3) * self.inner_embedding_dim), out_features=self.inner_embedding_dim, bias=bias)

    def forward(self, x):
        B, S, _ = x.shape

        x_ = F.layer_norm(x, normalized_shape=(self.inner_embedding_dim,))

        if self.conv_block:
            x_conv = self.conv1d(x_)
            x_conv_act = self.conv_swish(x_conv)
        else:
            x_conv_act = x_
        i = self.i_proj(x_conv_act)
        f = self.f_proj(x_conv_act)
        z = self.z_proj(x_)
        o = self.o_proj(x_)

        ###########################     EDIT      ###################################
        y_ = self.slstm_cell(i, f, z, o)
        ##############################################################

        B_, NH_, S_, DH_ = y_.shape
        gn_in_1 = y_.transpose(1, 2)
        gn_in_2 = gn_in_1.reshape(B_ * S_, NH_ * DH_)
        gn_out = F.group_norm(gn_in_2, num_groups=NH_)
        out = gn_out.view(B_, S_, NH_, DH_).transpose(1, 2)
        out = out.transpose(1, 2).view(B, S, -1)

        skip_con = x + out
        skip_con_layer_norm = F.layer_norm(skip_con, normalized_shape=(self.inner_embedding_dim,))

        up_proj1 = self.up_proj1(skip_con_layer_norm)
        up_proj2 = self.up_proj2(skip_con_layer_norm)
        up_proj2_act = self.up_proj2_gelu(up_proj2)
        down_proj = self.down_proj(up_proj2_act * up_proj1)
        y = down_proj + skip_con
        return y
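
The max-based stabilizer used in the cells rescales `c` and `n` by the same factor, so the ratio `c / n` is unchanged while the exponentials stay bounded. The following scalar sketch (function names `naive` and `stabilized` are illustrative, and it deliberately ignores heads, recurrent weights, and the output gate) compares both variants:

```python
import torch


def naive(i_logits, f_logits, z):
    # Exponential gates without stabilization.
    c = n = torch.zeros((), dtype=z.dtype)
    hs = []
    for i_l, f_l, z_t in zip(i_logits, f_logits, z):
        i_t, f_t = i_l.exp(), f_l.exp()
        c = f_t * c + i_t * z_t
        n = f_t * n + i_t
        hs.append(c / n)
    return torch.stack(hs)


def stabilized(i_logits, f_logits, z):
    # Same recurrence, with the running maximum m_t subtracted inside the exponentials.
    c = n = m_prev = torch.zeros((), dtype=z.dtype)
    hs = []
    for i_l, f_l, z_t in zip(i_logits, f_logits, z):
        m_t = torch.maximum(f_l + m_prev, i_l)
        i_t = (i_l - m_t).exp()
        f_t = (f_l + m_prev - m_t).exp()
        c = f_t * c + i_t * z_t
        n = f_t * n + i_t
        hs.append(c / n)
        m_prev = m_t
    return torch.stack(hs)


torch.manual_seed(0)
steps = 20

# Moderate logits: both variants should agree.
i_small = torch.randn(steps, dtype=torch.float64)
f_small = torch.randn(steps, dtype=torch.float64)
z_small = torch.randn(steps, dtype=torch.float64).tanh()
same = torch.allclose(naive(i_small, f_small, z_small), stabilized(i_small, f_small, z_small))

# Large logits in float32: exp(100) overflows in the naive variant.
i_big = torch.full((steps,), 100.0)
f_big = torch.full((steps,), 100.0)
z_big = torch.randn(steps).tanh()
naive_finite = torch.isfinite(naive(i_big, f_big, z_big)).all().item()
stable_finite = torch.isfinite(stabilized(i_big, f_big, z_big)).all().item()

print(same, naive_finite, stable_finite)
```

The equivalence holds because the stabilized states equal the naive ones multiplied by `exp(-m_t)`, which cancels in the ratio.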

In [7]:
class xLSTM(nn.Module):
    def __init__(self, vocab_size, layers=(3, 1), embedding_dim=128, proj_blocksize=64):
        assert sum(layers) > 0, "Minimum 1 layer."
        super(xLSTM, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim)

        m_layers, s_layers = layers
        self.layers = nn.Sequential(
            *[mLSTMLayer(embedding_dim, proj_blocksize) for _ in range(m_layers)],
            *[sLSTMLayer(embedding_dim, proj_blocksize) for _ in range(s_layers)]
        )

        self.layer_norm = nn.LayerNorm(embedding_dim)
        self.fc = nn.Linear(embedding_dim, vocab_size)

        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)
        elif isinstance(module, nn.Parameter):
            nn.init.normal_(module, mean=0, std=0.02)
        elif isinstance(module, BlockDiagonalProj):
            self._init_weights(module.weight)

    def forward(self, x):
        x = self.embedding(x)
        x = self.layers(x)
        x = self.layer_norm(x)
        return self.fc(x)

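    @torch.no_grad()
    def generate(self, tokenizer, prompt, max_length=200):
        device = next(self.parameters()).device
        # encode() returns a list of ints, so build the (1, T) input tensor explicitly
        x = torch.tensor(tokenizer.encode(prompt), dtype=torch.long, device=device).unsqueeze(0)
        for _ in range(max_length):
            logits = self(x)
            # Greedy decoding: append only the prediction for the last position
            next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            x = torch.cat([x, next_tok], dim=1)
        return "".join(tokenizer.decode(x.squeeze(0).tolist()))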

### **Train the Model** (1 point)

In [8]:
class Configuration:
    def __init__(self):
        self.device = device

        self.num_epochs = 5
        self.lr = 1e-4

        train_proportion = 0.8
        batch_size = 64
        seq_len = 128

        self.train_dl, self.val_dl, self.tokenizer = load_data(train_proportion, batch_size, seq_len)
        self.model = xLSTM(len(self.tokenizer))

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.lr)


config = Configuration()

In [9]:
def train_one_epoch(config):
    model = config.model
    device = config.device
    train_dl = config.train_dl
    criterion = config.criterion
    optimizer = config.optimizer

    total_loss = 0

    model.train()
    for X, y in tqdm(train_dl, desc="Train", leave=False):
        X, y = X.to(device), y.to(device)

        logits = model(X)
        loss = criterion(logits.view(-1, logits.shape[-1]), y.view(-1))

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Detach so the running sum does not keep every batch's autograd graph alive
        total_loss += loss.detach()

    total_loss /= len(train_dl)
    return total_loss.item(), total_loss.exp().item()


@torch.inference_mode()
def validate(config):
    model = config.model
    device = config.device
    val_dl = config.val_dl
    criterion = config.criterion

    total_loss = 0
    model.eval()
    for X, y in tqdm(val_dl, desc="Validate", leave=False):
        X, y = X.to(device), y.to(device)
        logits = model(X)
        total_loss += criterion(logits.view(-1, logits.shape[-1]), y.view(-1))

    total_loss /= len(val_dl)
    return total_loss.item(), total_loss.exp().item()


def train(config):
    config.model.to(config.device)
    log = defaultdict(list)
    for epoch in trange(config.num_epochs, desc="Epoch"):
        train_loss, train_ppl = train_one_epoch(config)
        val_loss, val_ppl = validate(config)

        log["train_loss"].append(train_loss)
        log["train_ppl"].append(train_ppl)
        log["val_loss"].append(val_loss)
        log["val_ppl"].append(val_ppl)

        print(f"Epoch {epoch + 1}: Train Loss: {train_loss:.4f}, Train PPL: {train_ppl:.4f}, Val Loss: {val_loss:.4f}, Val PPL: {val_ppl:.4f}")

        # val_loss was already appended to the log, so use <= to save whenever it is the best so far
        if val_loss <= min(log["val_loss"]):
            torch.save({
                "epoch": epoch,
                "model_state_dict": config.model.state_dict(),
                "optimizer_state_dict": config.optimizer.state_dict(),
                "loss": val_loss,
            }, "../models/ex06_xlstm_min_val_ppl.pth")

    return log
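
The perplexity reported above is the exponential of the mean cross-entropy loss. As a sanity check, a model that predicts a uniform distribution should have a perplexity equal to the vocabulary size. The sketch below uses all-zero logits as a stand-in for such a model (the batch shape is arbitrary; the vocabulary size of 65 is the one printed for this dataset):

```python
import math

import torch
import torch.nn.functional as F

vocab_size = 65  # vocabulary size reported for Tiny Shakespeare in this notebook
batch, seq_len = 4, 16

# All-zero logits correspond to a uniform distribution over the vocabulary.
logits = torch.zeros(batch, seq_len, vocab_size)
targets = torch.randint(0, vocab_size, (batch, seq_len))

# Same flattening as in the training loop: (B * S, V) logits against (B * S,) targets.
loss = F.cross_entropy(logits.view(-1, vocab_size), targets.view(-1))
ppl = loss.exp()

print(f"loss={loss.item():.4f}, log(V)={math.log(vocab_size):.4f}, ppl={ppl.item():.2f}")
```

A freshly initialized model should therefore start at a loss of roughly `log(65)`, and any trained model should end up clearly below that.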

In [10]:
print(f"Vocabulary size: {len(config.tokenizer)}")
print(f"Train dataset size: {len(config.train_dl.dataset)}")
print(f"Validation dataset size: {len(config.val_dl.dataset)}")
print(f"Model parameters: {sum(p.numel() for p in config.model.parameters() if p.requires_grad):,}")

Vocabulary size: 65
Train dataset size: 892187
Validation dataset size: 222951
Model parameters: 1,203,905


In [None]:
log = train(config)


### **Showcasing plots and a few input & output examples** (0.5 points)

In [None]:
import matplotlib.pyplot as plt

In [None]:
plt.figure(figsize=(15, 5))

plt.subplot(1, 2, 1)
plt.plot(log["train_loss"], label="Train Loss", marker="o")
plt.plot(log["val_loss"], label="Validation Loss", marker="s")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training and Validation Loss")
plt.legend()
plt.grid(True)

plt.subplot(1, 2, 2)
plt.plot(log["train_ppl"], label="Train Perplexity", marker="o")
plt.plot(log["val_ppl"], label="Validation Perplexity", marker="s")
plt.xlabel("Epoch")
plt.ylabel("Perplexity")
plt.title("Training and Validation Perplexity")
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

In [None]:
prompts = [
    "ROMEO:",
    "To be ",
    "JULIET:",
    "O Romeo",
]

fmt = f"""Prompt: '{{}}'
{"-" * 40}
{{}}

"""

for prompt in prompts:
    generated = config.model.generate(config.tokenizer, prompt=prompt)
    print(fmt.format(prompt, generated))