# Exercise 6

## Group 
- **ID**: <your group ID>

- **Members**: 
    - <your name1>
    - <your name2>
    - <your name3>

## Hand-in
- Please hand in this notebook with your code implementation via Ilias 
- Please make sure that there is exactly **one** submission per group

## Task Description

In this exercise, you will implement a custom Extended Long Short-Term Memory (xLSTM) model to predict the next token given an input sequence. The model is described in the paper [xLSTM: Extended Long Short-Term Memory](https://arxiv.org/abs/2405.04517).

You will work with the “Tiny Shakespeare” dataset, a character-level corpus of Shakespeare’s plays and sonnets that is commonly used for next-character prediction. The dataset is available on [GitHub](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt).

You will implement a custom character-level tokenizer and DataLoader, write your own model (composed of several classes) and train it, plot the perplexity and loss curves, and finally showcase input–output text samples from your trained xLSTM.

**NEW**:
We provide part of the mLSTM and sLSTM code, as illustrated in Figures 10 and 11 of the xLSTM paper. For this part, you only need to implement the `mLSTMCell` and `sLSTMCell` classes (the gray boxes in those figures) and integrate them with the rest of the code. You’re free to modify any part of the provided code.

## Grading scheme
Total: 5 points
1. **Preparing the Tokenizer and Dataloader** (1 point)
2. **Preparing the Model** (2.5 points)
3. **Train the Model** (1 point)
4. **Showcasing plots and a few input & output examples** (0.5 points)

### Imports

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'

### **Preparing the Tokenizer and Dataloader** (1 point)
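A minimal sketch of a character-level tokenizer and a sliding-window `Dataset`, assuming the corpus has already been read into a Python string `text` (downloading the file is omitted). The names `CharTokenizer`, `CharDataset`, `block_size` and the 90/10 train/validation split are hypothetical choices, not part of the exercise: each sample is a window of `block_size` characters and the target is the same window shifted by one character.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class CharTokenizer:
    """Hypothetical character-level tokenizer: one integer id per distinct character."""

    def __init__(self, text):
        self.chars = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(self.chars)}
        self.itos = {i: ch for ch, i in self.stoi.items()}

    @property
    def vocab_size(self):
        return len(self.chars)

    def encode(self, s):
        return [self.stoi[ch] for ch in s]

    def decode(self, ids):
        return "".join(self.itos[int(i)] for i in ids)


class CharDataset(Dataset):
    """Hypothetical dataset of (x, y) windows of length block_size, y shifted by one."""

    def __init__(self, ids, block_size):
        self.data = torch.tensor(ids, dtype=torch.long)
        self.block_size = block_size

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        chunk = self.data[idx : idx + self.block_size + 1]
        return chunk[:-1], chunk[1:]  # input window, next-character targets


# usage on a tiny stand-in string; replace it with the full Tiny Shakespeare text
text = "First Citizen:\nBefore we proceed any further, hear me speak."
tokenizer = CharTokenizer(text)
ids = tokenizer.encode(text)
n = int(0.9 * len(ids))  # assumed 90/10 train/validation split
train_ds = CharDataset(ids[:n], block_size=8)
val_ds = CharDataset(ids[n:], block_size=4)
train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
xb, yb = next(iter(train_loader))
print(xb.shape, yb.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```

Shuffling windows is fine for a stateless training loop like this one; if you carry recurrent state across batches instead, the windows would need to stay in order.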

### **Preparing the Model** (2.5 points)

#### Components

In [2]:
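class BlockDiagonalProj(nn.Module):
    """Head-wise (block-diagonal) linear projection without bias.

    Input and output shape: (..., input_dim). Each of the num_heads chunks of the
    last dimension is multiplied by its own (head_size x head_size) weight block.
    """
    def __init__(self, input_dim, num_heads):
        super(BlockDiagonalProj, self).__init__()
        assert input_dim % num_heads == 0, "input_dim must be divisible by num_heads"
        self.input_dim = input_dim
        self.num_heads = num_heads
        self.out_head_size = input_dim // num_heads
        self.weight = nn.Parameter(torch.empty(num_heads, self.out_head_size, input_dim // num_heads))
        # torch.empty leaves the memory uninitialized, so initialize the weights explicitly
        nn.init.normal_(self.weight, mean=0.0, std=(input_dim // num_heads) ** -0.5)

    def forward(self, x):
        shape = x.shape
        x = x.reshape(*shape[:-1], self.num_heads, -1)        # (..., heads, head_size)
        x = torch.einsum("...hd,hod->...ho", x, self.weight)  # per-head matrix product
        x = x.reshape(*shape[:-1], -1)                        # back to (..., input_dim)
        return x

class CasualConv1d(nn.Module):
    """Causal depthwise 1D convolution over the sequence dimension.

    Input and output shape: (batch, seq_len, feature_dim).
    """
    def __init__(self, feature_dim, kernel_size, bias=True):
        super(CasualConv1d, self).__init__()
        self.pad = kernel_size - 1
        # padding is applied on both sides; the trailing outputs are dropped in forward,
        # so the output at step t only depends on inputs at steps <= t
        self.conv = nn.Conv1d(in_channels=feature_dim, out_channels=feature_dim, kernel_size=kernel_size, padding=self.pad, groups=feature_dim, bias=bias)

    def forward(self, x):
        y = self.conv(x.transpose(2, 1))  # (B, F, S + pad)
        if self.pad > 0:                  # guard: y[:, :, :-0] would be empty for kernel_size=1
            y = y[:, :, :-self.pad]       # (B, F, S)
        return y.transpose(2, 1)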
       
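To see what `BlockDiagonalProj` computes, the check below (which re-declares the class so it runs on its own) compares it with a dense `nn.functional.linear` whose weight is the block-diagonal matrix assembled by `torch.block_diag` from the per-head blocks. This is an illustrative sanity check under those assumptions, not part of the required solution.

```python
import torch
import torch.nn as nn


class BlockDiagonalProj(nn.Module):
    # same computation as the class above, repeated so this cell is self-contained
    def __init__(self, input_dim, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.out_head_size = input_dim // num_heads
        self.weight = nn.Parameter(torch.randn(num_heads, self.out_head_size, input_dim // num_heads))

    def forward(self, x):
        shape = x.shape
        x = x.reshape(*shape[:-1], self.num_heads, -1)
        x = torch.einsum("...hd,hod->...ho", x, self.weight)
        return x.reshape(*shape[:-1], -1)


torch.manual_seed(0)
proj = BlockDiagonalProj(input_dim=8, num_heads=2)
x = torch.randn(3, 5, 8)  # (batch, seq_len, features)

# dense equivalent W = diag(W_0, W_1): head h only mixes its own 4 features
dense_weight = torch.block_diag(*proj.weight)  # (8, 8)
y_dense = nn.functional.linear(x, dense_weight)

print(torch.allclose(proj(x), y_dense, atol=1e-5))  # True
```

The block-diagonal form uses `num_heads` times fewer weights than a dense `input_dim x input_dim` projection, at the cost of no mixing between heads.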

#### mLSTM block

In [54]:
### COMPLETE THIS CLASS ####
class mLSTMCell(nn.Module):
    def __init__(self,):
        super(mLSTMCell, self).__init__()
    def forward(self,):
        pass
#############################  

class mLSTMLayer(nn.Module): 
    def __init__(self, embedding_dim, proj_blocksize, bias=False):
        super(mLSTMLayer, self).__init__()
        self.outer_embedding_dim = embedding_dim
        self.inner_embedding_dim = 2 * embedding_dim
        self.proj_blocksize = proj_blocksize
        self.bias = bias
        
        self.proj_up = nn.Linear(in_features=self.outer_embedding_dim, 
                                 out_features= 2 * self.inner_embedding_dim, 
                                 bias=bias)
        self.num_proj_heads = self.inner_embedding_dim // proj_blocksize
        self.q_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_proj_heads)
        self.k_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_proj_heads)
        self.v_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=self.num_proj_heads)
        
        self.conv1d = CasualConv1d(feature_dim=self.inner_embedding_dim, kernel_size=4)
        self.conv_swish = nn.SiLU()
        
        ############################     EDIT      ##################################
        self.mlstm_cell = mLSTMCell()
        ##############################################################
        
        self.ogate_swish = nn.SiLU()
        self.learnable_skip_con = nn.Parameter(torch.ones(self.inner_embedding_dim, requires_grad=True))
        self.proj_down = nn.Linear(in_features=self.inner_embedding_dim,
                                 out_features=self.outer_embedding_dim, 
                                 bias=bias)
        
        
        
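    def forward(self, x):
        # x: (B, S, outer_embedding_dim)
        B, S, _ = x.shape
        x_ = F.layer_norm(x, normalized_shape=(self.outer_embedding_dim,))
        # up-projection, then split into the mLSTM branch and the output-gate branch z
        x_inner = self.proj_up(x_)  # (B, S, 2 * inner_embedding_dim)
        x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.inner_embedding_dim, dim=-1)
        x_mlstm_conv = self.conv1d(x_mlstm)
        x_mlstm_conv_act = self.conv_swish(x_mlstm_conv)

        # queries and keys come from the convolved branch, values from the raw branch
        q = self.q_proj(x_mlstm_conv_act)
        k = self.k_proj(x_mlstm_conv_act)
        v = self.v_proj(x_mlstm)

        ##########################     EDIT      ####################################
        mlstm_cell_state = self.mlstm_cell()
        ##############################################################

        # learnable skip connection around the cell
        mlstm_cell_skip = mlstm_cell_state + (self.learnable_skip_con * x_mlstm_conv_act)

        # output gating with swish(z)
        h_state = mlstm_cell_skip * self.ogate_swish(z)

        # down-projection back to outer_embedding_dim plus the residual connection
        y = self.proj_down(h_state) + x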
        
        return y
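One possible sketch of the mLSTM recurrence from the paper, written as an explicit loop over time: $C_t = f_t C_{t-1} + i_t v_t k_t^\top$, $n_t = f_t n_{t-1} + i_t k_t$, $\tilde h_t = C_t q_t / \max(|n_t^\top q_t|, 1)$, with an exponential input gate stabilized by a running maximum $m_t$. Several details are assumptions, not the provided solution: the class name `mLSTMCellSketch`, the `forward(q, k, v)` signature, scalar per-head gates computed by `nn.Linear` from the concatenated `[q, k, v]`, a sigmoid forget gate, and a per-head group norm at the end so the output has shape `(B, S, embedding_dim)` and can be added to the skip connection in `mLSTMLayer`.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class mLSTMCellSketch(nn.Module):
    """Hypothetical step-by-step mLSTM cell with a stabilized exponential input gate.

    q, k, v: (B, S, embedding_dim)  ->  output: (B, S, embedding_dim)
    """

    def __init__(self, embedding_dim, num_heads):
        super().__init__()
        assert embedding_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = embedding_dim // num_heads
        # assumption: one scalar input and forget gate per head and step, computed from [q, k, v]
        self.igate = nn.Linear(3 * embedding_dim, num_heads)
        self.fgate = nn.Linear(3 * embedding_dim, num_heads)

    def forward(self, q, k, v):
        B, S, D = q.shape
        NH, DH = self.num_heads, self.head_dim
        qkv = torch.cat([q, k, v], dim=-1)
        i_pre = self.igate(qkv)                # (B, S, NH) input-gate pre-activations
        log_f = F.logsigmoid(self.fgate(qkv))  # (B, S, NH) log of the sigmoid forget gate

        # split into heads: (B, S, NH, DH); keys are scaled by 1/sqrt(DH)
        q = q.reshape(B, S, NH, DH)
        k = k.reshape(B, S, NH, DH) / math.sqrt(DH)
        v = v.reshape(B, S, NH, DH)

        C = q.new_zeros(B, NH, DH, DH)  # matrix memory
        n = q.new_zeros(B, NH, DH)      # normalizer state
        m = q.new_zeros(B, NH)          # stabilizer state
        outputs = []
        for t in range(S):
            m_new = torch.maximum(log_f[:, t] + m, i_pre[:, t])
            i_gate = torch.exp(i_pre[:, t] - m_new)[..., None]      # (B, NH, 1)
            f_gate = torch.exp(log_f[:, t] + m - m_new)[..., None]  # (B, NH, 1)
            q_t, k_t, v_t = q[:, t], k[:, t], v[:, t]               # (B, NH, DH)
            # C_t = f C_{t-1} + i v_t k_t^T   and   n_t = f n_{t-1} + i k_t
            C = f_gate[..., None] * C + i_gate[..., None] * (v_t.unsqueeze(-1) * k_t.unsqueeze(-2))
            n = f_gate * n + i_gate * k_t
            num = torch.einsum("bhij,bhj->bhi", C, q_t)  # C_t q_t
            # max(|n^T q|, 1) expressed in the stabilized scale exp(-m_t)
            den = torch.maximum((n * q_t).sum(-1).abs(), torch.exp(-m_new))
            outputs.append(num / den[..., None])
            m = m_new
        h = torch.stack(outputs, dim=1)  # (B, S, NH, DH)
        # assumption: per-head normalization as a group norm with one group per head
        h = F.group_norm(h.reshape(B * S, NH * DH), num_groups=NH)
        return h.reshape(B, S, D)
```

Because the states start at zero, rescaling $C_t$ and $n_t$ by $e^{-m_t}$ leaves the output identical to the unstabilized formula while keeping every `exp` argument non-positive. Inside `mLSTMLayer` such a cell would be built with `embedding_dim=self.inner_embedding_dim` and called as `self.mlstm_cell(q, k, v)`. The paper also gives a parallel formulation of the mLSTM, which is faster for training than this sequential loop.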
        

#### sLSTM block

In [53]:
### COMPLETE THIS CLASS ####
class sLSTMCell(nn.Module):
    def __init__(self,):
        super(sLSTMCell, self).__init__()
    def forward(self,):
        pass
#############################  
    
class sLSTMLayer(nn.Module): 
    def __init__(self, embedding_dim, proj_blocksize, conv_block=True, bias=False):
        super(sLSTMLayer, self).__init__()
        self.inner_embedding_dim = embedding_dim
        self.proj_blocksize = proj_blocksize
        self.conv_block = conv_block
        if conv_block:
            self.conv1d = CasualConv1d(feature_dim=self.inner_embedding_dim, kernel_size=4)
            self.conv_swish = nn.SiLU()
        
        self.i_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=4)
        self.f_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=4)
        self.z_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=4)
        self.o_proj = BlockDiagonalProj(input_dim=self.inner_embedding_dim, num_heads=4)
        
        ##############################     EDIT      ################################
        self.slstm_cell = sLSTMCell()
        ##############################################################
        
        self.up_proj1 = nn.Linear(in_features=self.inner_embedding_dim, out_features= int((4/3)*self.inner_embedding_dim), bias=bias)
        self.up_proj2 = nn.Linear(in_features=self.inner_embedding_dim, out_features= int((4/3)*self.inner_embedding_dim), bias=bias)
        self.up_proj2_gelu = nn.GELU()
        
        self.down_proj = nn.Linear(in_features=int((4/3)*self.inner_embedding_dim), out_features=self.inner_embedding_dim, bias=bias)
        
    def forward(self, x):
        # x: (B, S, inner_embedding_dim)
        B, S, _ = x.shape

        x_ = F.layer_norm(x, normalized_shape=(self.inner_embedding_dim,))

        # the optional causal conv + swish only feeds the input and forget gates
        if self.conv_block:
            x_conv = self.conv1d(x_)
            x_conv_act = self.conv_swish(x_conv)
        else:
            x_conv_act = x_
        i = self.i_proj(x_conv_act)
        f = self.f_proj(x_conv_act)
        z = self.z_proj(x_)
        o = self.o_proj(x_)

        ###########################     EDIT      ###################################
        y_ = self.slstm_cell()
        ##############################################################

        # the cell is expected to return (B, num_heads, S, head_dim)
        B_, NH_, S_, DH_ = y_.shape
        # per-head group norm on a (B*S, NH*DH) view, then back to (B, S, NH*DH)
        gn_in = y_.transpose(1, 2).reshape(B_ * S_, NH_ * DH_)
        gn_out = F.group_norm(gn_in, num_groups=NH_)
        out = gn_out.view(B, S, -1)

        # residual connection around the sLSTM part
        skip_con = x + out
        skip_con_layer_norm = F.layer_norm(skip_con, normalized_shape=(self.inner_embedding_dim,))

        # gated feed-forward: GELU(up_proj2(.)) * up_proj1(.), then down-projection
        up_proj1 = self.up_proj1(skip_con_layer_norm)
        up_proj2 = self.up_proj2(skip_con_layer_norm)
        up_proj2_act = self.up_proj2_gelu(up_proj2)
        down_proj = self.down_proj(up_proj2_act * up_proj1)
        # second residual connection
        y = down_proj + skip_con
        return y
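One possible sketch of the sLSTM recurrence from the paper: $c_t = f_t c_{t-1} + i_t z_t$, $n_t = f_t n_{t-1} + i_t$, $h_t = o_t \odot c_t / n_t$ with $z_t = \tanh(\tilde z_t)$, $i_t = \exp(\tilde i_t)$, $o_t = \sigma(\tilde o_t)$, where each pre-activation is the input contribution already computed by `sLSTMLayer` plus a recurrent term $R h_{t-1}$ that is block-diagonal over heads. Assumptions not taken from the exercise: the class name `sLSTMCellSketch`, the `forward(i, f, z, o)` signature, a sigmoid (rather than exponential) forget gate, the initialization of `R`, and a zero-initialized bias. The output shape `(B, NH, S, DH)` matches what `sLSTMLayer` unpacks after the cell.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class sLSTMCellSketch(nn.Module):
    """Hypothetical step-by-step sLSTM cell with head-wise recurrence and stabilized gates.

    i, f, z, o: (B, S, embedding_dim) input contributions W x  ->  output: (B, NH, S, DH)
    """

    def __init__(self, embedding_dim, num_heads):
        super().__init__()
        assert embedding_dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = DH = embedding_dim // num_heads
        # block-diagonal recurrent weights R: one (DH x 4*DH) block per head for [i, f, z, o]
        self.R = nn.Parameter(torch.randn(num_heads, DH, 4 * DH) / DH ** 0.5)
        self.bias = nn.Parameter(torch.zeros(4, num_heads, DH))

    def forward(self, i, f, z, o):
        B, S, D = i.shape
        NH, DH = self.num_heads, self.head_dim
        gates_x = torch.stack([i, f, z, o], dim=2).reshape(B, S, 4, NH, DH)

        h = i.new_zeros(B, NH, DH)  # hidden state
        c = i.new_zeros(B, NH, DH)  # cell state
        n = i.new_zeros(B, NH, DH)  # normalizer state
        # stabilizer starts at -inf so the first step sets m = i~ and n = 1 (avoids 0/0)
        m = torch.full((B, NH, DH), float("-inf"), dtype=i.dtype, device=i.device)
        outputs = []
        for t in range(S):
            # recurrent contribution R h_{t-1}, per head: (B, NH, 4*DH) -> (B, 4, NH, DH)
            rec = torch.einsum("bhd,hde->bhe", h, self.R).reshape(B, NH, 4, DH).transpose(1, 2)
            i_pre, f_pre, z_pre, o_pre = (gates_x[:, t] + rec + self.bias).unbind(dim=1)
            log_f = F.logsigmoid(f_pre)  # sigmoid forget gate, in log space
            m_new = torch.maximum(log_f + m, i_pre)
            i_gate = torch.exp(i_pre - m_new)
            f_gate = torch.exp(log_f + m - m_new)
            c = f_gate * c + i_gate * torch.tanh(z_pre)  # c_t = f c_{t-1} + i z_t
            n = f_gate * n + i_gate                      # n_t = f n_{t-1} + i
            h = torch.sigmoid(o_pre) * c / n             # h_t = o * c_t / n_t
            m = m_new
            outputs.append(h)
        return torch.stack(outputs, dim=2)  # (B, NH, S, DH)
```

Since $h_t$ only depends on the ratio $c_t / n_t$, the stabilized states give the same output as the plain recurrence while every `exp` argument stays non-positive. With the layer above, the cell would be built with `embedding_dim=self.inner_embedding_dim` and `num_heads=4` (matching the gate projections) and called as `self.slstm_cell(i, f, z, o)`.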

### **Train the Model** (1 point)
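A minimal training-loop sketch. Perplexity is the exponential of the mean per-token cross-entropy, $\mathrm{PPL} = \exp\big(\tfrac{1}{N}\sum_{t} -\log p(x_{t+1} \mid x_{\le t})\big)$, so it can be logged next to the loss at no extra cost. To keep the cell runnable on its own, `TinyCharLM` (an embedding followed by a linear head) is a hypothetical stand-in for the xLSTM model, the data is a synthetic repeating pattern, and the optimizer, learning rate, gradient clipping and epoch count are illustrative assumptions.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset


class TinyCharLM(nn.Module):
    """Hypothetical stand-in for the xLSTM model: embedding -> linear head."""

    def __init__(self, vocab_size, dim):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, dim)
        self.head = nn.Linear(dim, vocab_size)

    def forward(self, idx):              # idx: (B, S) token ids
        return self.head(self.emb(idx))  # logits: (B, S, vocab_size)


def run_epoch(model, loader, optimizer=None, device="cpu"):
    """One pass over loader; trains if an optimizer is given. Returns (mean loss, perplexity)."""
    model.train(optimizer is not None)
    total_loss, total_tokens = 0.0, 0
    with torch.set_grad_enabled(optimizer is not None):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                optimizer.step()
            total_loss += loss.item() * y.numel()
            total_tokens += y.numel()
    mean_loss = total_loss / total_tokens
    return mean_loss, math.exp(mean_loss)  # perplexity = exp(mean cross-entropy)


# synthetic data: the next token of a repeating pattern 0, 1, ..., V-1, 0, 1, ...
torch.manual_seed(0)
V, S = 5, 8
seq = torch.arange(200) % V
x = torch.stack([seq[j : j + S] for j in range(100)])
y = torch.stack([seq[j + 1 : j + S + 1] for j in range(100)])
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = TinyCharLM(V, dim=16)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-2)
history = {"train_loss": [], "train_ppl": []}
for epoch in range(20):
    loss, ppl = run_epoch(model, loader, optimizer)
    history["train_loss"].append(loss)
    history["train_ppl"].append(ppl)
print(f"first ppl {history['train_ppl'][0]:.2f}, last ppl {history['train_ppl'][-1]:.2f}")
```

Calling `run_epoch(model, val_loader)` without an optimizer gives the validation loss and perplexity in evaluation mode; the lists in `history` are what you would plot as the loss and perplexity curves.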

### **Showcasing plots and a few input & output examples** (0.5 points)
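To produce input–output samples, a trained character model can be sampled autoregressively: feed the prompt, sample the next character from the softmax of the last logits, append it and repeat. The sketch below assumes a model that maps `(B, S)` token ids to `(B, S, vocab_size)` logits; `generate`, the `temperature`/`top_k` options and the toy model and vocabulary in the usage part are hypothetical placeholders for your trained xLSTM and tokenizer.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Hypothetical autoregressive sampler: idx (B, S) -> (B, S + max_new_tokens).

    The whole prefix is re-fed at every step: simple, but quadratic in the length.
    """
    model.eval()
    for _ in range(max_new_tokens):
        logits = model(idx)[:, -1, :] / temperature  # logits for the next token
        if top_k is not None:
            kth = torch.topk(logits, top_k, dim=-1).values[:, -1, None]
            logits = logits.masked_fill(logits < kth, float("-inf"))  # keep the top_k tokens
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)  # (B, 1)
        idx = torch.cat([idx, next_id], dim=1)
    return idx


# usage with a hypothetical toy model and vocabulary (replace with the trained xLSTM and tokenizer)
vocab = list("abcde")
torch.manual_seed(0)
toy_model = nn.Sequential(nn.Embedding(len(vocab), 8), nn.Linear(8, len(vocab)))
prompt = torch.tensor([[vocab.index("a"), vocab.index("b")]])
out = generate(toy_model, prompt, max_new_tokens=10, temperature=0.8, top_k=3)
print("".join(vocab[i] for i in out[0].tolist()))
```

Lower temperatures and smaller `top_k` make the samples more conservative; with `top_k=1` the sampler reduces to greedy decoding.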