From the Switch Transformer paper:

>In deep learning, models typically reuse the same parameters for all inputs. Mixture of Experts (MoE) defies this and instead selects different parameters for each incoming example. The result is a sparsely-activated model -- with outrageous numbers of parameters -- but a constant computational cost.

A standard (dense, non-MoE) Transformer block, written in a modern pre-norm style, looks like this:

```python
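# SwishGLU and RMSNorm are assumed to be defined elsewhere.
class ModernTransformerBlock(nn.Module):
    def __init__(self, embed_dim, n_heads, up):
        super().__init__()
        # batch_first=True so that inputs are [batch, seq, embed]
        self.attn = nn.MultiheadAttention(embed_dim, n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            SwishGLU(embed_dim, embed_dim * up),
            nn.Linear(embed_dim * up, embed_dim),
        )
        self.pre_attn_norm = RMSNorm(embed_dim)
        self.pre_mlp_norm = RMSNorm(embed_dim)

    def forward(self, x):
        # nn.MultiheadAttention takes (query, key, value) and returns (output, weights)
        h = self.pre_attn_norm(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.mlp(self.pre_mlp_norm(x))
        return x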
```

The Mixture-of-Experts layer replaces the MLP layer. Instead of having one MLP layer, we have `num_experts` different MLP layers called *experts*.

The idea is to process each contextualized token by sending it to only a subset of the experts. This increases the number of model parameters while keeping the computational cost per token almost unchanged.

First, the token is fed into the *router*, which determines which experts should process it. For computational reasons, there is a fixed limit on:
* how many tokens an expert can process, and
* how many experts can process a single token.

# Grading
Your task is to implement a Mixture of Experts layer. You can get points for the following subtasks:
1.  (5 points) A naive implementation of the MoE layer that works with `num_experts_per_token>=1`.
2.  (5 points) A well-vectorized implementation of the MoE layer that works with `num_experts_per_token=1`.
3.  (5 points) A script that tests that implementations 1 and 2 produce equivalent outputs and that implementation 2 is faster.
4.  (5 points) A well-vectorized implementation of the MoE layer that works with `num_experts_per_token>=1`.
5.  (Bonus 5 points) Use Hugging Face's `Trainer` class to compare the performance of a randomly initialized MoE Transformer and a standard Transformer on the `https://huggingface.co/datasets/imdb` dataset.

Scoring 20 points in this task is equivalent to at least 16% of the points achievable in this course.

Please submit your assignment by 15 April, 18:00 CET.

# Rules
- You shouldn't change the basic `forward` and `__init__` signatures of the main classes, `Router` and `MoE`. You can add extra arguments with default values.
- Submit a Jupyter notebook with a short introduction at the top describing what has been done and where.
- You can add or remove any other classes, but the behaviour of the `MLP` class should be preserved in some form.
- Sensible vectorization is enough for the maximum number of points. There is no need to optimize performance to the limit; just show that you can identify opportunities for vectorization and implement complex vectorizations.
- If in doubt, direct questions to either Jan Ludziejewski or Juliusz Straszyński.
- A notebook that is hard to grade (crashing, obfuscated) may receive 0 points.

# Hints
- Write a naive implementation first; vectorized operations can be hard to analyze for correctness.
- You can make randomness deterministic with the appropriate torch functions (for example, by fixing the seed).
- If you have a hard time discarding tokens in a fairly random way, you can try keeping the earlier tokens instead.
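
As a small illustration of the second hint, one option is to draw random numbers from an explicitly seeded `torch.Generator`. This is only a sketch; the helper name `random_scores` is hypothetical and not part of the assignment skeleton.

```python
import torch

def random_scores(num_tokens, seed):
    # A dedicated generator makes the draw reproducible without touching the global RNG state.
    gen = torch.Generator().manual_seed(seed)
    return torch.rand(num_tokens, generator=gen)

a = random_scores(5, seed=0)
b = random_scores(5, seed=0)
c = random_scores(5, seed=1)
print(torch.equal(a, b))  # same seed, same draw
```

Calling `torch.manual_seed` once at the top of the test script is a simpler alternative when a global seed is acceptable.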

In [None]:
%pip install torch_tb_profiler einops

In [None]:
from torch import nn
import torch
from transformers import PretrainedConfig
import torch.nn.functional as F
from einops import einsum

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )

    def forward(self, x):
        return self.mlp(x)

# Router
The router is a module which assigns tokens to experts. It answers two questions:
1. Which tokens should be assigned to which experts?
2. How much weight should each expert receive? The weight is determined by the similarity between the token embedding and the expert embedding.

The following conditions must be satisfied:
1. The routing weights must sum to 1 for each token and be non-negative
2. A token should have exactly `num_experts_per_token` non-zero weights
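
The two conditions can be illustrated with a minimal sketch. It assumes dot-product similarity and a softmax over the selected logits, neither of which is prescribed by the text above; the function name `topk_routing_weights` is hypothetical.

```python
import torch

def topk_routing_weights(x, expert_embeddings, k):
    # x: [batch, seq, hidden]; expert_embeddings: [num_experts, hidden]
    logits = x @ expert_embeddings.T             # assumed similarity: dot product
    top_vals, top_idx = logits.topk(k, dim=-1)   # k most similar experts per token
    top_w = torch.softmax(top_vals, dim=-1)      # positive, sums to 1 over the k experts
    weights = torch.zeros_like(logits)
    return weights.scatter(-1, top_idx, top_w)   # zeros for all other experts

torch.manual_seed(0)
x = torch.randn(2, 3, 8)
emb = torch.randn(4, 8)
w = topk_routing_weights(x, emb, k=2)
print(w.shape)  # torch.Size([2, 3, 4])
```

Applying the softmax over all experts first and renormalizing the top `k` values afterwards would satisfy the same conditions; which variant to use is a design choice.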

In [None]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, num_experts] - expert routing weights
class Router(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_token = config.num_experts_per_token
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        self.expert_embeddings = nn.Parameter(torch.randn(self.num_experts, self.hidden_size))
        torch.nn.init.kaiming_uniform_(self.expert_embeddings, nonlinearity='linear')

    def forward(self, x):
        pass

The MoE module is a module which wraps around a set of expert modules and a router module.

It takes input embeddings and routes them to the experts.

Each token is processed individually by a subset of experts.

The output token embedding is a weighted sum of the expert outputs.

The weights are determined by the router module.

The subset of experts is determined by non-zero weights in the routing output.
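
The weighted sum described above can be sketched with a plain loop over experts. This is an illustrative, unvectorized sketch that ignores the capacity limit; the function name `naive_combine` is hypothetical, and `nn.Linear` layers stand in for the experts.

```python
import torch
from torch import nn

def naive_combine(x, weights, experts):
    # x: [batch, seq, hidden]; weights: [batch, seq, num_experts]
    out = torch.zeros_like(x)
    for e, expert in enumerate(experts):
        mask = weights[..., e] != 0  # tokens routed to expert e
        if mask.any():
            out[mask] += weights[..., e][mask].unsqueeze(-1) * expert(x[mask])
    return out

torch.manual_seed(0)
experts = [nn.Linear(8, 8) for _ in range(3)]  # stand-ins for the MLP experts
x = torch.randn(2, 4, 8)
weights = torch.zeros(2, 4, 3)
weights[..., 1] = 1.0  # route every token to expert 1 with weight 1
out = naive_combine(x, weights, experts)
```

With these weights the output equals the output of expert 1 alone, and a token with all-zero weights keeps a zero embedding.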

Additionally, each expert may process at most `expert_capacity = ceil((batch_size * seq_len) / num_experts * capacity_factor)` tokens.
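
As a worked example of the capacity formula, the sketch below plugs in `seq_len=256`, `num_experts=4` and `capacity_factor=2.0` (the values used in the MoE configuration of this notebook) together with an assumed batch size of 8. The function name `expert_capacity` is only for illustration.

```python
import math

def expert_capacity(batch_size, seq_len, num_experts, capacity_factor):
    return math.ceil(batch_size * seq_len / num_experts * capacity_factor)

print(expert_capacity(8, 256, 4, 2.0))   # 2048 / 4 * 2.0 = 1024
print(expert_capacity(3, 10, 4, 1.25))   # 30 / 4 * 1.25 = 9.375, rounded up to 10
```

With `capacity_factor=1.0` and one expert per token, every token fits only if the router spreads tokens perfectly evenly; larger factors leave headroom for imbalance.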

If more tokens are routed to an expert than its capacity allows, the superfluous tokens that this expert discards should be selected uniformly at random.

Discarding should be equivalent to setting the appropriate routing weight to 0, while other weights remain the same.

This means that a token is processed by at most `num_experts_per_token` experts, with weights summing to at most 1.

In particular, a token may end up being processed by 0 experts; in this case the resulting embedding should be a zero tensor.
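
One possible way to enforce the capacity limit is to give every routed (token, expert) pair a random priority and keep only the first `capacity` tokens per expert. The sketch below assumes CPU tensors and routing weights flattened to `[num_tokens, num_experts]`; the function name `apply_capacity` is hypothetical.

```python
import torch

def apply_capacity(weights, capacity, generator=None):
    # weights: [num_tokens, num_experts], zero where a token is not routed to an expert
    routed = weights != 0
    # Random priority per (token, expert); unrouted entries get +inf so they sort last.
    priority = torch.rand(weights.shape, generator=generator)
    priority = priority.masked_fill(~routed, float("inf"))
    # rank[t, e] = position of token t among all tokens of expert e, in random order
    rank = priority.argsort(dim=0).argsort(dim=0)
    keep = routed & (rank < capacity)
    return weights * keep  # discarded weights become 0, the others are unchanged

gen = torch.Generator().manual_seed(0)
weights = torch.zeros(6, 2)
weights[:, 0] = 1.0  # all six tokens are routed to expert 0
capped = apply_capacity(weights, capacity=4, generator=gen)
print((capped != 0).sum(dim=0))  # tensor([4, 0])
```

Because the priorities are independent uniform draws, every subset of `capacity` tokens is equally likely to be kept for a given expert.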

In [None]:
import math

# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class MoE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor

        # You can change experts representation if you want
        self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
        self.router = Router(config)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        expert_capacity = math.ceil(batch_size * seq_len / self.num_experts * self.capacity_factor)
        pass

# Configurations

In [None]:
base_config = dict(
    vocab_size=5000,
    max_position_embeddings=256,
    num_attention_heads=8,
    num_hidden_layers=4,
    hidden_dropout_prob=0.1,
    hidden_size=128,
    intermediate_size=512,
    num_labels=2
)

standard_config = PretrainedConfig(
    **base_config,
    ff_cls=MLP
)

moe_config = PretrainedConfig(
    **base_config,
    num_experts=4,
    capacity_factor=2.0,
    num_experts_per_token=1,
    ff_cls=MoE
)

# Basic Transformer-related classes

In [None]:
from einops import rearrange

class Embedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.word_embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        batch_size, seq_length = x.shape
        # Position ids 0..seq_length-1, created directly on the input's device
        positions = torch.arange(seq_length, device=x.device).expand(
            batch_size, seq_length)
        embedding = self.word_embed(x) + self.pos_embed(positions)
        return self.dropout(embedding)


class MHSelfAttention(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super(MHSelfAttention, self).__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_attention_heads
        self.qkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)

    def forward(self, embeddings):
        batch_size, seq_length, hidden_size = embeddings.size()

        result = self.qkv(embeddings)
        q, k, v = rearrange(result, 'b s (qkv nah hdsz) -> qkv b nah s hdsz', nah=self.num_attention_heads, qkv=3).unbind(0)

        attention_scores = torch.matmul(q, k.transpose(-1, -2))
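        # Scale by the square root of the per-head dimension, as in standard scaled dot-product attention
        attention_scores = attention_scores / math.sqrt(self.head_size)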
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        contextualized_layer = torch.matmul(attention_probs, v)

        outputs = rearrange(contextualized_layer, 'b nah s hdsz -> b s (nah hdsz)')
        return outputs

class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = MHSelfAttention(config)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.intermediate = config.ff_cls(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        x = x + self.norm1(self.dropout(self.attention(x)))
        x = x + self.norm2(self.dropout(self.intermediate(x)))
        return x

class TransformerClassifier(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = Embedding(config)
        self.layer = nn.Sequential(*[TransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, labels=None):
        embedding_output = self.embeddings(input_ids)
        encoding = self.layer(embedding_output)
        pooled_encoding = encoding.mean(dim=1)
        logits = self.classifier(pooled_encoding)
        loss = F.cross_entropy(logits, labels) if labels is not None else None
        return {
            'loss': loss,
            'logits': logits,
        }

# Tokenizer training

In [None]:
from tokenizers import ByteLevelBPETokenizer
from datasets import load_dataset
from tokenizers.processors import BertProcessing

dataset = load_dataset('imdb')

tokenizer = ByteLevelBPETokenizer()
tokenizer.train_from_iterator(
    dataset['train']['text'],
    vocab_size=base_config['vocab_size'],
    special_tokens=["<s>", "</s>", "<pad>"],
    min_frequency=2
)
tokenizer.post_processor = BertProcessing(
    ("</s>", tokenizer.token_to_id("</s>")),
    ("<s>", tokenizer.token_to_id("<s>")),
)

tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("<pad>"), pad_token="<pad>", length=base_config['max_position_embeddings'])
tokenizer.model_max_length = base_config['max_position_embeddings']
tokenizer.pad_token = "<pad>"

from transformers import Trainer, TrainingArguments

def tokenize(row):
    return {
        'input_ids': tokenizer.encode(row['text']).ids,
    }

tokenized_dataset = dataset.map(tokenize)