From Switch Transformer paper:

>In deep learning, models typically reuse the same parameters for all inputs. Mixture of Experts (MoE) defies this and instead selects different parameters for each incoming example. The result is a sparsely-activated model -- with outrageous numbers of parameters -- but a constant computational cost.

A vanilla Transformer block looks like this:

```python
class ModernTransformerBlock(nn.Module):
    def __init__(self, embed_dim, n_heads, up):
        super().__init__()
        self.attn = nn.MultiheadAttention(embed_dim, n_heads)
        self.mlp = nn.Sequential(
            SwishGLU(embed_dim, embed_dim * up),
            nn.Linear(embed_dim * up, embed_dim),
        )
        self.pre_attn_norm = RMSNorm(embed_dim)
        self.pre_mlp_norm = RMSNorm(embed_dim)
    
    def forward(self, x):
        x = x + self.attn(self.pre_attn_norm(x))
        x = x + self.mlp(self.pre_mlp_norm(x))
        return x
```

The Mixture-of-Experts layer replaces the MLP layer. Instead of having one MLP layer, we have `num_experts` different MLP layers called *experts*.

The idea is to process a contextualized token, by sending it to a subset of experts. In this way we could efficiently increase the number of parameters of the model without affecting computational cost too much.

First, the token is fed into *router*, which determines to which experts a token should go to be processed. For computational reasons, there is a fixed limit on:
* how many tokens an expert can process, and
* by how many experts a token is processed.

# Grading
Your task is to implement a Mixture of Experts layer. You can get points for the following subtasks:
1.  (5 points) Naive implementation of MoE layer that works with `num_experts_per_token>=1`
2.  (5 points) Well-vectorized implementation of MoE layer that works with `num_experts_per_token=1`
3.  (5 points) Implementation of a script testing for 1. 2. implementations output equivalence and performance superiority of 2.
4.  (5 points) Well-vectorized implementation of MoE layer that works with `num_experts_per_token>=1`
5.  (Bonus 5 points) Use Huggingface's Trainer class and compare performance of randomly initialized MoE Transformer and standard Transformer on `https://huggingface.co/datasets/imdb` dataset.

20 points scored in this task is equivalent to at least 16% points achievable in this course.

Please submit your assignments until 15th of April, 18:00 CET.

# Rules
- You shouldn't change basic `forward` and `initialization` signatures of the main classes: `Router` and `MoE`. You can add additional arguments with default values.
- As an assignment, provide a Jupyter notebook with a short introduction at the top of what has been done and where.
- You can add or remove any other classes, though you should keep the behaviour of `MLP` class somehow.
- Sensible vectorization is good enough for the maximum amount of points. There is no need to optimize performance to the max, just show that you can identify opportunities for vectorization and you are able to implement complex vectorizations.
- If in doubt, direct questions to either Jan Ludziejewski or Juliusz Straszyński.
- A notebook that is hard to grade (crashing, obfuscated) might be scored for 0 points.

# Hints
- First, write a naive implementation, vectorized operations might be hard to analyze for correctness.
- You can make randomness deterministic by appropriate torch functions.
- If you have a hard time fulfilling fair randomness for token discarding, you can try keeping the earlier tokens.

In [1]:
%pip install torch_tb_profiler einops

Collecting torch_tb_profiler
  Downloading torch_tb_profiler-0.4.3-py3-none-any.whl (1.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting einops
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: einops, torch_tb_profiler
Successfully installed einops-0.7.0 torch_tb_profiler-0.4.3


In [2]:
from torch import nn
import torch
from transformers import PretrainedConfig
import torch.nn.functional as F
from einops import einsum

class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.ReLU(),
            nn.Linear(config.intermediate_size, config.hidden_size),
        )

    def forward(self, x):
        return self.mlp(x)

# Router
The router is a module which assigns tokens to experts. It answers two questions:
1. Which tokens should be assigned to which expert.
2. How much weight should be assigned to each expert. The weight is determined by similarity between the token embedding and the expert embedding

The following conditions must be satisfied:
1. The routing weights must sum to 1 for each token and be non-negative
2. A token should have exactly `num_experts_per_token` non-zero weights

In [3]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, num_experts] - expert routing weights
class Router(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_token = config.num_experts_per_token
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        self.expert_embeddings = nn.Parameter(torch.randn(self.num_experts, self.hidden_size))
        torch.nn.init.kaiming_uniform_(self.expert_embeddings, nonlinearity='linear')

    def forward(self, x):
        pass

The MoE module is a module which wraps around a set of expert modules and a router module.

It takes input embeddings and routes them to the experts.

Each token is processed individually by a subset of experts.

The output token embedding is a weighted sum of the expert outputs.

The weights are determined by the router module.

The subset of experts is determined by non-zero weights in the routing output.

Additionally each expert might process at most `expert_capacity = ceil((batch_size * seq_len) / num_experts * capacity_factor)` tokens

Superfluous tokens to be discarded by a particular expert should be selected uniformly at random.

Discarding should be equivalent to setting the appropriate routing weight to 0, while other weights remain the same.

This means that a token is processed by at most num_experts_per_token experts with a sum of weights of at most 1.

Specifically, this could mean that a token is processed by 0 experts - in this case the resulting embedding should be a zero tensor.

In [4]:
import math

# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class MoE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor

        # You can change experts representation if you want
        self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
        self.router = Router(config)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        expert_capacity = math.ceil(batch_size * seq_len / self.num_experts * self.capacity_factor)
        pass

# Configurations

In [5]:
base_config = dict(
    vocab_size=5000,
    max_position_embeddings=256,
    num_attention_heads=8,
    num_hidden_layers=4,
    hidden_dropout_prob=0.1,
    hidden_size=128,
    intermediate_size=512,
    num_labels=2
)

standard_config = PretrainedConfig(
    **base_config,
    ff_cls=MLP
)

moe_config = PretrainedConfig(
    **base_config,
    num_experts=4,
    capacity_factor=2.0,
    num_experts_per_token=1,
    ff_cls=MoE
)

# Basic Transformer-related classes

In [6]:
from einops import rearrange

class Embedding(nn.Module):
  def __init__(self, config):
    super(Embedding, self).__init__()
    self.word_embed = nn.Embedding(config.vocab_size, config.hidden_size)
    self.pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)
    self.dropout = nn.Dropout(config.hidden_dropout_prob)

  def forward(self, x):
    batch_size, seq_length = x.shape
    device = x.device
    positions = torch.arange(0, seq_length).expand(
        batch_size, seq_length).to(device)
    embedding = self.word_embed(x) + self.pos_embed(positions)
    return self.dropout(embedding)


class MHSelfAttention(nn.Module):
    def __init__(self, config: PretrainedConfig):
        super(MHSelfAttention, self).__init__()
        self.num_attention_heads = config.num_attention_heads
        self.hidden_size = config.hidden_size
        self.head_size = self.hidden_size // self.num_attention_heads
        self.num_attention_heads = config.num_attention_heads
        self.qkv = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)

    def forward(self, embeddings):
        batch_size, seq_length, hidden_size = embeddings.size()

        result = self.qkv(embeddings)
        q, k, v = rearrange(result, 'b s (qkv nah hdsz) -> qkv b nah s hdsz', nah=self.num_attention_heads, qkv=3).unbind(0)

        attention_scores = torch.matmul(q, k.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(hidden_size)
        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        contextualized_layer = torch.matmul(attention_probs, v)

        outputs = rearrange(contextualized_layer, 'b nah s hdsz -> b s (nah hdsz)')
        return outputs

class TransformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.attention = MHSelfAttention(config)
        self.norm1 = nn.LayerNorm(config.hidden_size)
        self.norm2 = nn.LayerNorm(config.hidden_size)
        self.intermediate = config.ff_cls(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, x):
        x =  x + self.norm1(self.dropout(self.attention(x)))
        x =  x + self.norm2(self.dropout(self.intermediate(x)))
        return x

class TransformerClassifier(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.embeddings = Embedding(config)
        self.layer = nn.Sequential(*[TransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, labels=None):
        embedding_output = self.embeddings(input_ids)
        encoding = self.layer(embedding_output)
        pooled_encoding = encoding.mean(dim=1)
        logits = self.classifier(pooled_encoding)
        loss = F.cross_entropy(logits, labels) if labels is not None else None
        return {
            'loss': loss,
            'logits': logits,
        }

# Tokenizer training

In [7]:
%pip install datasets

Collecting datasets
  Downloading datasets-2.18.0-py3-none-any.whl (510 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.5/510.5 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl (116 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m116.3/116.3 kB[0m [31m13.9 MB/s[0m eta [36m0:00:00[0m
Collecting xxhash (from datasets)
  Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.1/194.1 kB[0m [31m23.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting multiprocess (from datasets)
  Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m17.6 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: xxhash, dill, multiprocess, datasets
Successfully installed dataset

In [8]:
from tokenizers import ByteLevelBPETokenizer
from datasets import load_dataset
from tokenizers.processors import BertProcessing

dataset = load_dataset('imdb')

tokenizer = ByteLevelBPETokenizer()
tokenizer.train_from_iterator(
    dataset['train']['text'],
    vocab_size=base_config['vocab_size'],
    special_tokens=["<s>", "</s>", "<pad>"],
    min_frequency=2
)
tokenizer.post_processor = BertProcessing(
    ("</s>", tokenizer.token_to_id("</s>")),
    ("<s>", tokenizer.token_to_id("<s>")),
)

tokenizer.enable_truncation(max_length=base_config['max_position_embeddings'])
tokenizer.enable_padding(pad_id=tokenizer.token_to_id("<pad>"), pad_token="<pad>", length=base_config['max_position_embeddings'])
tokenizer.model_max_length = base_config['max_position_embeddings']
tokenizer.pad_token = "<pad>"

from transformers import Trainer, TrainingArguments

def tokenize(row):
    return {
        'input_ids': tokenizer.encode(row['text']).ids,
    }

tokenized_dataset = dataset.map(tokenize)

Downloading readme:   0%|          | 0.00/7.81k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/20.5M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/42.0M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/25000 [00:00<?, ? examples/s]

Generating unsupervised split:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

# **Setting seed for reproducibility.**

In [39]:
# https://pytorch.org/docs/stable/notes/randomness.html

torch.manual_seed(0)

<torch._C.Generator at 0x7c3eb1f9a930>

# **1. Naive implementation of MoE layer that works with num_experts_per_token>=1 and leftmost token choosing strategy.**

In [40]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, num_experts] - expert routing weights
class Router(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts_per_token = config.num_experts_per_token
        self.hidden_size = config.hidden_size
        self.num_experts = config.num_experts

        self.expert_embeddings = nn.Parameter(torch.randn(self.num_experts, self.hidden_size))
        torch.nn.init.kaiming_uniform_(self.expert_embeddings, nonlinearity='linear')

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        similarity = einsum(x, self.expert_embeddings, 'b s h, e h -> b s e')
        top_experts = torch.topk(similarity, self.num_experts_per_token)
        softmaxed_topk_values = F.softmax(top_experts.values, dim=-1)
        mask = torch.zeros_like(similarity, dtype=torch.bool)
        mask = mask.scatter_(-1, top_experts.indices, 1)
        routing_weights = torch.zeros_like(similarity)
        routing_weights[mask] = softmaxed_topk_values.flatten()

        return routing_weights

In [44]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class NaiveMoE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor

        # You can change experts representation if you want
        self.experts = nn.ModuleList([MLP(config) for _ in range(self.num_experts)])
        self.router = Router(config)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        expert_capacity = torch.ceil(torch.tensor(batch_size * seq_len / self.num_experts * self.capacity_factor, device=x.device, dtype=torch.int))
        routing_weights = self.router(x)
        for i in range(self.num_experts):
            token_indices = torch.nonzero(routing_weights[:, :, i], as_tuple=False)
            if token_indices.shape[0] > expert_capacity:
                routing_weights[token_indices[expert_capacity:, 0], token_indices[expert_capacity:, 1], i] = 0

        expert_outputs = torch.zeros(batch_size, seq_len, self.hidden_size, device=x.device)
        for i in range(self.num_experts):
            token_indices = torch.nonzero(routing_weights[:, :, i], as_tuple=False)
            expert_outputs[token_indices[:, 0], token_indices[:, 1]] += self.experts[i](x[token_indices[:, 0], token_indices[:, 1]]) * routing_weights[token_indices[:, 0], token_indices[:, 1], i].reshape(-1, 1)

        return expert_outputs

In [45]:
from torch.utils.data import DataLoader

naive_moe_config = PretrainedConfig(
    **base_config,
    num_experts=4,
    capacity_factor=2.0,
    num_experts_per_token=4,
    ff_cls=NaiveMoE
)

train_loader = DataLoader(tokenized_dataset['train'], batch_size=16, shuffle=True)
test_loader = DataLoader(tokenized_dataset['test'], batch_size=16, shuffle=False)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TransformerClassifier(naive_moe_config).to(DEVICE)
# model = TransformerClassifier(standard_config).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [46]:
from tqdm import tqdm

NUM_OF_EPOCHS = 20

for epoch in range(NUM_OF_EPOCHS):
    model.train()
    train_progress_bar = tqdm(train_loader, desc=f'Train, Epoch {epoch + 1} / {NUM_OF_EPOCHS}')
    running_loss = 0.
    for i, batch in enumerate(train_progress_bar):
        x, y = batch['input_ids'], batch['label']
        x = torch.stack(x, dim=1).to(DEVICE)
        y = y.to(DEVICE)
        optimizer.zero_grad()
        loss = model(x, y)['loss']
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if i % 10 == 9:
            last_loss = running_loss / 10 # avg loss per batch
            print('batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    model.eval()
    with torch.no_grad():
        total_loss = 0
        total_samples = 0
        correct_samples = 0
        test_progress_bar = tqdm(test_loader, desc=f'Test, Epoch {epoch + 1} / {NUM_OF_EPOCHS}')
        for batch in test_progress_bar:
            x, y = batch['input_ids'], batch['label']
            x = torch.stack(x, dim=1).to(DEVICE)
            y = y.to(DEVICE)
            logits = model(x)['logits']
            total_loss += F.cross_entropy(logits, y, reduction='sum').item()
            total_samples += y.shape[0]
            correct_samples += (logits.argmax(dim=-1) == y).sum().item()

        print(f'Epoch {epoch + 1}, loss: {total_loss / total_samples}, accuracy: {correct_samples / total_samples}')

Train, Epoch 1 / 20:   1%|          | 13/1563 [00:00<01:25, 18.18it/s]

batch 10 loss: 1.2033341526985168


Train, Epoch 1 / 20:   1%|▏         | 23/1563 [00:01<01:27, 17.63it/s]

batch 20 loss: 0.9138149499893189


Train, Epoch 1 / 20:   2%|▏         | 33/1563 [00:01<01:22, 18.52it/s]

batch 30 loss: 0.8427469909191132


Train, Epoch 1 / 20:   3%|▎         | 43/1563 [00:02<01:19, 19.17it/s]

batch 40 loss: 0.9648739159107208


Train, Epoch 1 / 20:   3%|▎         | 52/1563 [00:02<01:16, 19.80it/s]

batch 50 loss: 0.7992942154407501


Train, Epoch 1 / 20:   4%|▍         | 62/1563 [00:03<01:34, 15.94it/s]

batch 60 loss: 0.877907806634903


Train, Epoch 1 / 20:   5%|▍         | 72/1563 [00:03<01:22, 18.12it/s]

batch 70 loss: 0.9652309834957122


Train, Epoch 1 / 20:   5%|▌         | 83/1563 [00:04<01:18, 18.93it/s]

batch 80 loss: 0.7334680557250977


Train, Epoch 1 / 20:   6%|▌         | 93/1563 [00:05<01:17, 18.96it/s]

batch 90 loss: 0.6813925802707672


Train, Epoch 1 / 20:   7%|▋         | 103/1563 [00:05<01:17, 18.80it/s]

batch 100 loss: 0.7527759790420532


Train, Epoch 1 / 20:   7%|▋         | 113/1563 [00:06<01:17, 18.72it/s]

batch 110 loss: 0.7054231286048889


Train, Epoch 1 / 20:   8%|▊         | 123/1563 [00:06<01:15, 18.99it/s]

batch 120 loss: 0.678711074590683


Train, Epoch 1 / 20:   9%|▊         | 133/1563 [00:07<01:26, 16.53it/s]

batch 130 loss: 0.6902370393276215


Train, Epoch 1 / 20:   9%|▉         | 143/1563 [00:07<01:25, 16.52it/s]

batch 140 loss: 0.6843487322330475


Train, Epoch 1 / 20:  10%|▉         | 153/1563 [00:08<01:34, 14.92it/s]

batch 150 loss: 0.7010039031505585


Train, Epoch 1 / 20:  10%|█         | 162/1563 [00:09<01:18, 17.80it/s]

batch 160 loss: 0.7291490793228149


Train, Epoch 1 / 20:  11%|█         | 172/1563 [00:09<01:12, 19.09it/s]

batch 170 loss: 0.7387413561344147


Train, Epoch 1 / 20:  12%|█▏        | 182/1563 [00:10<01:12, 18.97it/s]

batch 180 loss: 0.6992709934711456


Train, Epoch 1 / 20:  12%|█▏        | 192/1563 [00:10<01:10, 19.54it/s]

batch 190 loss: 0.6891151666641235


Train, Epoch 1 / 20:  13%|█▎        | 203/1563 [00:11<01:09, 19.43it/s]

batch 200 loss: 0.6880708396434784


Train, Epoch 1 / 20:  14%|█▎        | 213/1563 [00:11<01:09, 19.37it/s]

batch 210 loss: 0.6906685948371887


Train, Epoch 1 / 20:  14%|█▍        | 223/1563 [00:12<01:08, 19.46it/s]

batch 220 loss: 0.6733020663261413


Train, Epoch 1 / 20:  15%|█▍        | 232/1563 [00:12<01:07, 19.72it/s]

batch 230 loss: 0.7311889231204987


Train, Epoch 1 / 20:  16%|█▌        | 243/1563 [00:13<01:08, 19.33it/s]

batch 240 loss: 0.6868929386138916


Train, Epoch 1 / 20:  16%|█▌        | 253/1563 [00:13<01:07, 19.48it/s]

batch 250 loss: 0.7263181626796722


Train, Epoch 1 / 20:  17%|█▋        | 262/1563 [00:14<01:05, 19.80it/s]

batch 260 loss: 0.7165864109992981


Train, Epoch 1 / 20:  17%|█▋        | 273/1563 [00:14<01:07, 19.19it/s]

batch 270 loss: 0.6461679220199585


Train, Epoch 1 / 20:  18%|█▊        | 283/1563 [00:15<01:06, 19.37it/s]

batch 280 loss: 0.6539095759391784


Train, Epoch 1 / 20:  19%|█▊        | 292/1563 [00:15<01:06, 19.10it/s]

batch 290 loss: 0.6880915343761445


Train, Epoch 1 / 20:  19%|█▉        | 303/1563 [00:16<01:05, 19.32it/s]

batch 300 loss: 0.6141978979110718


Train, Epoch 1 / 20:  20%|██        | 313/1563 [00:16<01:03, 19.63it/s]

batch 310 loss: 0.6594462335109711


Train, Epoch 1 / 20:  21%|██        | 324/1563 [00:17<01:02, 19.78it/s]

batch 320 loss: 0.6481190741062164


Train, Epoch 1 / 20:  21%|██▏       | 333/1563 [00:17<01:04, 19.10it/s]

batch 330 loss: 0.5855267524719239


Train, Epoch 1 / 20:  22%|██▏       | 343/1563 [00:18<01:03, 19.22it/s]

batch 340 loss: 0.6383742272853852


Train, Epoch 1 / 20:  23%|██▎       | 353/1563 [00:19<01:12, 16.58it/s]

batch 350 loss: 0.6169226199388504


Train, Epoch 1 / 20:  23%|██▎       | 363/1563 [00:19<01:16, 15.78it/s]

batch 360 loss: 0.6006479769945144


Train, Epoch 1 / 20:  24%|██▍       | 373/1563 [00:20<01:17, 15.45it/s]

batch 370 loss: 0.5948901355266571


Train, Epoch 1 / 20:  25%|██▍       | 383/1563 [00:21<01:10, 16.68it/s]

batch 380 loss: 0.6531133234500885


Train, Epoch 1 / 20:  25%|██▌       | 392/1563 [00:21<01:02, 18.87it/s]

batch 390 loss: 0.6320235669612885


Train, Epoch 1 / 20:  26%|██▌       | 402/1563 [00:22<01:01, 18.99it/s]

batch 400 loss: 0.591797998547554


Train, Epoch 1 / 20:  26%|██▋       | 413/1563 [00:22<00:58, 19.50it/s]

batch 410 loss: 0.594695508480072


Train, Epoch 1 / 20:  27%|██▋       | 423/1563 [00:23<01:00, 18.81it/s]

batch 420 loss: 0.5109083950519562


Train, Epoch 1 / 20:  28%|██▊       | 433/1563 [00:23<00:58, 19.29it/s]

batch 430 loss: 0.6194938391447067


Train, Epoch 1 / 20:  28%|██▊       | 443/1563 [00:24<00:59, 18.75it/s]

batch 440 loss: 0.6324471950531005


Train, Epoch 1 / 20:  29%|██▉       | 453/1563 [00:24<00:57, 19.15it/s]

batch 450 loss: 0.6117632180452347


Train, Epoch 1 / 20:  30%|██▉       | 462/1563 [00:25<00:56, 19.37it/s]

batch 460 loss: 0.5558372169733048


Train, Epoch 1 / 20:  30%|███       | 471/1563 [00:25<00:55, 19.72it/s]

batch 470 loss: 0.5228474557399749


Train, Epoch 1 / 20:  31%|███       | 483/1563 [00:26<00:54, 19.69it/s]

batch 480 loss: 0.5589699894189835


Train, Epoch 1 / 20:  32%|███▏      | 493/1563 [00:26<00:54, 19.73it/s]

batch 490 loss: 0.5864254713058472


Train, Epoch 1 / 20:  32%|███▏      | 503/1563 [00:27<00:55, 19.02it/s]

batch 500 loss: 0.5574246495962143


Train, Epoch 1 / 20:  33%|███▎      | 512/1563 [00:27<00:53, 19.54it/s]

batch 510 loss: 0.5965816140174866


Train, Epoch 1 / 20:  33%|███▎      | 522/1563 [00:28<00:56, 18.45it/s]

batch 520 loss: 0.5464736521244049


Train, Epoch 1 / 20:  34%|███▍      | 532/1563 [00:28<00:55, 18.46it/s]

batch 530 loss: 0.568380719423294


Train, Epoch 1 / 20:  35%|███▍      | 541/1563 [00:29<00:53, 19.08it/s]

batch 540 loss: 0.6813136011362075


Train, Epoch 1 / 20:  35%|███▌      | 552/1563 [00:29<00:54, 18.68it/s]

batch 550 loss: 0.6177334904670715


Train, Epoch 1 / 20:  36%|███▌      | 563/1563 [00:30<00:51, 19.54it/s]

batch 560 loss: 0.6047531843185425


Train, Epoch 1 / 20:  37%|███▋      | 571/1563 [00:30<00:55, 17.72it/s]

batch 570 loss: 0.6027579754590988


Train, Epoch 1 / 20:  37%|███▋      | 583/1563 [00:31<01:02, 15.78it/s]

batch 580 loss: 0.5948746383190155


Train, Epoch 1 / 20:  38%|███▊      | 591/1563 [00:32<01:04, 15.06it/s]

batch 590 loss: 0.5811483830213546


Train, Epoch 1 / 20:  39%|███▊      | 603/1563 [00:33<01:04, 14.83it/s]

batch 600 loss: 0.5499907732009888


Train, Epoch 1 / 20:  39%|███▉      | 613/1563 [00:33<00:52, 18.24it/s]

batch 610 loss: 0.5348590016365051


Train, Epoch 1 / 20:  40%|███▉      | 623/1563 [00:34<00:50, 18.79it/s]

batch 620 loss: 0.44903946220874785


Train, Epoch 1 / 20:  40%|████      | 633/1563 [00:34<00:49, 18.96it/s]

batch 630 loss: 0.5378324419260025


Train, Epoch 1 / 20:  41%|████      | 643/1563 [00:35<00:48, 18.97it/s]

batch 640 loss: 0.5037701457738877


Train, Epoch 1 / 20:  42%|████▏     | 653/1563 [00:35<00:47, 19.31it/s]

batch 650 loss: 0.5987352877855301


Train, Epoch 1 / 20:  42%|████▏     | 662/1563 [00:36<00:47, 19.15it/s]

batch 660 loss: 0.6033804237842559


Train, Epoch 1 / 20:  43%|████▎     | 672/1563 [00:36<00:45, 19.38it/s]

batch 670 loss: 0.5475032776594162


Train, Epoch 1 / 20:  44%|████▎     | 682/1563 [00:37<00:45, 19.16it/s]

batch 680 loss: 0.5796170949935913


Train, Epoch 1 / 20:  44%|████▍     | 693/1563 [00:37<00:44, 19.56it/s]

batch 690 loss: 0.5290399789810181


Train, Epoch 1 / 20:  45%|████▍     | 702/1563 [00:38<00:44, 19.27it/s]

batch 700 loss: 0.5218061238527298


Train, Epoch 1 / 20:  46%|████▌     | 713/1563 [00:38<00:43, 19.61it/s]

batch 710 loss: 0.5473307013511658


Train, Epoch 1 / 20:  46%|████▋     | 723/1563 [00:39<00:43, 19.22it/s]

batch 720 loss: 0.5848041355609894


Train, Epoch 1 / 20:  47%|████▋     | 733/1563 [00:39<00:43, 19.12it/s]

batch 730 loss: 0.5667793214321136


Train, Epoch 1 / 20:  47%|████▋     | 742/1563 [00:40<00:42, 19.17it/s]

batch 740 loss: 0.5575024068355561


Train, Epoch 1 / 20:  48%|████▊     | 753/1563 [00:40<00:41, 19.53it/s]

batch 750 loss: 0.5108521848917007


Train, Epoch 1 / 20:  49%|████▉     | 763/1563 [00:41<00:41, 19.06it/s]

batch 760 loss: 0.584150618314743


Train, Epoch 1 / 20:  49%|████▉     | 773/1563 [00:41<00:40, 19.46it/s]

batch 770 loss: 0.480792498588562


Train, Epoch 1 / 20:  50%|█████     | 784/1563 [00:42<00:39, 19.53it/s]

batch 780 loss: 0.5570266991853714


Train, Epoch 1 / 20:  51%|█████     | 793/1563 [00:42<00:39, 19.57it/s]

batch 790 loss: 0.48405068218708036


Train, Epoch 1 / 20:  51%|█████▏    | 803/1563 [00:43<00:46, 16.37it/s]

batch 800 loss: 0.4478864729404449


Train, Epoch 1 / 20:  52%|█████▏    | 813/1563 [00:44<00:46, 16.10it/s]

batch 810 loss: 0.5679599046707153


Train, Epoch 1 / 20:  53%|█████▎    | 821/1563 [00:44<00:51, 14.31it/s]

batch 820 loss: 0.4621555358171463


Train, Epoch 1 / 20:  53%|█████▎    | 834/1563 [00:45<00:40, 17.83it/s]

batch 830 loss: 0.42560127675533294


Train, Epoch 1 / 20:  54%|█████▍    | 843/1563 [00:45<00:37, 19.20it/s]

batch 840 loss: 0.4886484816670418


Train, Epoch 1 / 20:  55%|█████▍    | 853/1563 [00:46<00:35, 19.82it/s]

batch 850 loss: 0.5115957766771316


Train, Epoch 1 / 20:  55%|█████▌    | 862/1563 [00:46<00:34, 20.23it/s]

batch 860 loss: 0.5167519673705101


Train, Epoch 1 / 20:  56%|█████▌    | 873/1563 [00:47<00:34, 19.95it/s]

batch 870 loss: 0.48486549556255343


Train, Epoch 1 / 20:  56%|█████▋    | 882/1563 [00:47<00:33, 20.08it/s]

batch 880 loss: 0.4485791951417923


Train, Epoch 1 / 20:  57%|█████▋    | 891/1563 [00:48<00:35, 19.15it/s]

batch 890 loss: 0.4570061445236206


Train, Epoch 1 / 20:  58%|█████▊    | 904/1563 [00:49<00:33, 19.97it/s]

batch 900 loss: 0.5517386436462403


Train, Epoch 1 / 20:  58%|█████▊    | 912/1563 [00:49<00:34, 18.91it/s]

batch 910 loss: 0.5338814318180084


Train, Epoch 1 / 20:  59%|█████▉    | 924/1563 [00:50<00:32, 19.76it/s]

batch 920 loss: 0.4121210664510727


Train, Epoch 1 / 20:  60%|█████▉    | 932/1563 [00:50<00:32, 19.25it/s]

batch 930 loss: 0.4784955054521561


Train, Epoch 1 / 20:  60%|██████    | 942/1563 [00:50<00:31, 19.97it/s]

batch 940 loss: 0.46554372608661654


Train, Epoch 1 / 20:  61%|██████    | 953/1563 [00:51<00:31, 19.37it/s]

batch 950 loss: 0.4900284051895142


Train, Epoch 1 / 20:  62%|██████▏   | 962/1563 [00:52<00:30, 19.71it/s]

batch 960 loss: 0.45619317293167116


Train, Epoch 1 / 20:  62%|██████▏   | 972/1563 [00:52<00:31, 19.03it/s]

batch 970 loss: 0.45685319006443026


Train, Epoch 1 / 20:  63%|██████▎   | 982/1563 [00:53<00:29, 19.73it/s]

batch 980 loss: 0.5136098623275757


Train, Epoch 1 / 20:  63%|██████▎   | 991/1563 [00:53<00:29, 19.41it/s]

batch 990 loss: 0.4826166838407516


Train, Epoch 1 / 20:  64%|██████▍   | 1003/1563 [00:54<00:28, 19.80it/s]

batch 1000 loss: 0.4537782371044159


Train, Epoch 1 / 20:  65%|██████▍   | 1012/1563 [00:54<00:27, 19.82it/s]

batch 1010 loss: 0.5132469058036804


Train, Epoch 1 / 20:  65%|██████▌   | 1022/1563 [00:55<00:27, 19.96it/s]

batch 1020 loss: 0.45485312044620513


Train, Epoch 1 / 20:  66%|██████▌   | 1032/1563 [00:55<00:33, 15.91it/s]

batch 1030 loss: 0.4299334123730659


Train, Epoch 1 / 20:  67%|██████▋   | 1042/1563 [00:56<00:31, 16.80it/s]

batch 1040 loss: 0.38682681918144224


Train, Epoch 1 / 20:  67%|██████▋   | 1052/1563 [00:56<00:33, 15.38it/s]

batch 1050 loss: 0.4831730544567108


Train, Epoch 1 / 20:  68%|██████▊   | 1063/1563 [00:57<00:27, 18.46it/s]

batch 1060 loss: 0.4305664002895355


Train, Epoch 1 / 20:  69%|██████▊   | 1074/1563 [00:58<00:25, 19.45it/s]

batch 1070 loss: 0.5499860376119614


Train, Epoch 1 / 20:  69%|██████▉   | 1081/1563 [00:58<00:24, 19.83it/s]

batch 1080 loss: 0.5648777514696122


Train, Epoch 1 / 20:  70%|██████▉   | 1094/1563 [00:59<00:23, 20.15it/s]

batch 1090 loss: 0.4853793680667877


Train, Epoch 1 / 20:  71%|███████   | 1103/1563 [00:59<00:22, 20.01it/s]

batch 1100 loss: 0.43616433441638947


Train, Epoch 1 / 20:  71%|███████   | 1113/1563 [01:00<00:22, 19.61it/s]

batch 1110 loss: 0.4013599455356598


Train, Epoch 1 / 20:  72%|███████▏  | 1122/1563 [01:00<00:22, 19.87it/s]

batch 1120 loss: 0.4009762853384018


Train, Epoch 1 / 20:  72%|███████▏  | 1132/1563 [01:01<00:21, 19.88it/s]

batch 1130 loss: 0.4393373727798462


Train, Epoch 1 / 20:  73%|███████▎  | 1143/1563 [01:01<00:20, 20.02it/s]

batch 1140 loss: 0.5046923518180847


Train, Epoch 1 / 20:  74%|███████▎  | 1152/1563 [01:02<00:20, 19.99it/s]

batch 1150 loss: 0.4285038888454437


Train, Epoch 1 / 20:  74%|███████▍  | 1163/1563 [01:02<00:20, 19.76it/s]

batch 1160 loss: 0.43479354232549666


Train, Epoch 1 / 20:  75%|███████▌  | 1173/1563 [01:03<00:19, 19.92it/s]

batch 1170 loss: 0.47049427926540377


Train, Epoch 1 / 20:  76%|███████▌  | 1183/1563 [01:03<00:19, 19.69it/s]

batch 1180 loss: 0.4688050389289856


Train, Epoch 1 / 20:  76%|███████▋  | 1193/1563 [01:04<00:18, 19.58it/s]

batch 1190 loss: 0.5012555465102195


Train, Epoch 1 / 20:  77%|███████▋  | 1203/1563 [01:04<00:18, 19.74it/s]

batch 1200 loss: 0.5143014758825302


Train, Epoch 1 / 20:  78%|███████▊  | 1212/1563 [01:05<00:17, 19.72it/s]

batch 1210 loss: 0.45694091320037844


Train, Epoch 1 / 20:  78%|███████▊  | 1224/1563 [01:05<00:17, 19.84it/s]

batch 1220 loss: 0.48451325595378875


Train, Epoch 1 / 20:  79%|███████▉  | 1233/1563 [01:06<00:16, 19.53it/s]

batch 1230 loss: 0.4853356540203094


Train, Epoch 1 / 20:  79%|███████▉  | 1242/1563 [01:06<00:16, 19.77it/s]

batch 1240 loss: 0.49351574331521986


Train, Epoch 1 / 20:  80%|████████  | 1252/1563 [01:07<00:16, 18.89it/s]

batch 1250 loss: 0.4537286788225174


Train, Epoch 1 / 20:  81%|████████  | 1262/1563 [01:07<00:18, 16.30it/s]

batch 1260 loss: 0.4657557547092438


Train, Epoch 1 / 20:  81%|████████▏ | 1272/1563 [01:08<00:17, 16.28it/s]

batch 1270 loss: 0.465218648314476


Train, Epoch 1 / 20:  82%|████████▏ | 1282/1563 [01:09<00:18, 14.99it/s]

batch 1280 loss: 0.4382157474756241


Train, Epoch 1 / 20:  83%|████████▎ | 1294/1563 [01:09<00:14, 19.13it/s]

batch 1290 loss: 0.4752366542816162


Train, Epoch 1 / 20:  83%|████████▎ | 1303/1563 [01:10<00:13, 19.13it/s]

batch 1300 loss: 0.5344160974025727


Train, Epoch 1 / 20:  84%|████████▍ | 1313/1563 [01:10<00:12, 19.82it/s]

batch 1310 loss: 0.42640105783939364


Train, Epoch 1 / 20:  85%|████████▍ | 1323/1563 [01:11<00:12, 19.41it/s]

batch 1320 loss: 0.3347319424152374


Train, Epoch 1 / 20:  85%|████████▌ | 1332/1563 [01:11<00:11, 19.71it/s]

batch 1330 loss: 0.42906643897295


Train, Epoch 1 / 20:  86%|████████▌ | 1343/1563 [01:12<00:11, 19.47it/s]

batch 1340 loss: 0.4255662143230438


Train, Epoch 1 / 20:  87%|████████▋ | 1352/1563 [01:12<00:10, 19.83it/s]

batch 1350 loss: 0.40485686957836153


Train, Epoch 1 / 20:  87%|████████▋ | 1362/1563 [01:13<00:10, 19.46it/s]

batch 1360 loss: 0.3931199565529823


Train, Epoch 1 / 20:  88%|████████▊ | 1373/1563 [01:13<00:09, 20.19it/s]

batch 1370 loss: 0.4482094258069992


Train, Epoch 1 / 20:  89%|████████▊ | 1384/1563 [01:14<00:08, 19.93it/s]

batch 1380 loss: 0.461888262629509


Train, Epoch 1 / 20:  89%|████████▉ | 1392/1563 [01:14<00:08, 20.23it/s]

batch 1390 loss: 0.47674389779567716


Train, Epoch 1 / 20:  90%|████████▉ | 1403/1563 [01:15<00:08, 19.07it/s]

batch 1400 loss: 0.4375310823321342


Train, Epoch 1 / 20:  90%|█████████ | 1413/1563 [01:15<00:07, 19.75it/s]

batch 1410 loss: 0.42160905003547666


Train, Epoch 1 / 20:  91%|█████████ | 1423/1563 [01:16<00:07, 19.12it/s]

batch 1420 loss: 0.3959992229938507


Train, Epoch 1 / 20:  92%|█████████▏| 1432/1563 [01:16<00:06, 19.70it/s]

batch 1430 loss: 0.4214401230216026


Train, Epoch 1 / 20:  92%|█████████▏| 1442/1563 [01:17<00:06, 19.74it/s]

batch 1440 loss: 0.38138877749443056


Train, Epoch 1 / 20:  93%|█████████▎| 1453/1563 [01:17<00:05, 20.19it/s]

batch 1450 loss: 0.4337517753243446


Train, Epoch 1 / 20:  94%|█████████▎| 1462/1563 [01:18<00:05, 19.81it/s]

batch 1460 loss: 0.36272286921739577


Train, Epoch 1 / 20:  94%|█████████▍| 1472/1563 [01:18<00:04, 20.04it/s]

batch 1470 loss: 0.39216705560684206


Train, Epoch 1 / 20:  95%|█████████▍| 1483/1563 [01:19<00:04, 18.37it/s]

batch 1480 loss: 0.4507389485836029


Train, Epoch 1 / 20:  96%|█████████▌| 1493/1563 [01:19<00:04, 17.25it/s]

batch 1490 loss: 0.453522589802742


Train, Epoch 1 / 20:  96%|█████████▌| 1503/1563 [01:20<00:03, 15.91it/s]

batch 1500 loss: 0.38704138398170473


Train, Epoch 1 / 20:  97%|█████████▋| 1511/1563 [01:21<00:03, 15.14it/s]

batch 1510 loss: 0.4552582919597626


Train, Epoch 1 / 20:  97%|█████████▋| 1521/1563 [01:21<00:02, 17.14it/s]

batch 1520 loss: 0.42420624792575834


Train, Epoch 1 / 20:  98%|█████████▊| 1532/1563 [01:22<00:01, 16.76it/s]

batch 1530 loss: 0.388854107260704


Train, Epoch 1 / 20:  99%|█████████▊| 1541/1563 [01:22<00:01, 18.52it/s]

batch 1540 loss: 0.37162585705518725


Train, Epoch 1 / 20:  99%|█████████▉| 1552/1563 [01:23<00:00, 19.08it/s]

batch 1550 loss: 0.5589667811989785


Train, Epoch 1 / 20: 100%|██████████| 1563/1563 [01:24<00:00, 18.60it/s]


batch 1560 loss: 0.47827038168907166


Test, Epoch 1 / 20: 100%|██████████| 1563/1563 [00:37<00:00, 41.64it/s]


Epoch 1, loss: 0.4485974488687515, accuracy: 0.79136


Train, Epoch 2 / 20:   1%|          | 12/1563 [00:00<01:19, 19.48it/s]

batch 10 loss: 0.47310285568237304


Train, Epoch 2 / 20:   2%|▏         | 24/1563 [00:01<01:18, 19.72it/s]

batch 20 loss: 0.4839359313249588


Train, Epoch 2 / 20:   2%|▏         | 34/1563 [00:01<01:16, 19.92it/s]

batch 30 loss: 0.43701818883419036


Train, Epoch 2 / 20:   3%|▎         | 43/1563 [00:02<01:16, 19.74it/s]

batch 40 loss: 0.36176503002643584


Train, Epoch 2 / 20:   3%|▎         | 53/1563 [00:02<01:17, 19.45it/s]

batch 50 loss: 0.504285940527916


Train, Epoch 2 / 20:   4%|▎         | 56/1563 [00:02<01:19, 19.01it/s]


KeyboardInterrupt: 

# **2. and 4. Vectorized implementation of MoE layer that works with num_experts_per_token>=1. Satisfies 2nd and 4th part of the task.**

In [59]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class VectorizedMoE(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor

        # You can change experts representation if you want
        self.expert = torch.nn.Linear(self.hidden_size, self.hidden_size)
        self.expert_weights = torch.nn.Parameter(torch.stack([self.expert.weight for _ in range(self.num_experts)], dim=0))
        self.expert_biases = torch.nn.Parameter(torch.stack([self.expert.bias for _ in range(self.num_experts)], dim=0))
        self.router = Router(config)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        expert_capacity = torch.ceil(torch.tensor(batch_size * seq_len / self.num_experts * self.capacity_factor, device=x.device, dtype=torch.int))
        routing_weights = self.router(x)
        flat_routing_weights = routing_weights.view(-1, self.num_experts)  # Shape: [batch_size * seq_len, num_experts]
        topk_values, topk_indices = flat_routing_weights.topk(k=expert_capacity, dim=0)
        mask = torch.zeros_like(flat_routing_weights).bool()
        mask.scatter_(0, topk_indices, 1)
        flat_routing_weights = flat_routing_weights * mask.float()

        x_flat = x.reshape(-1, x.size(-1))
        inputs_expanded = x_flat.unsqueeze(1).expand(-1, self.num_experts, -1)
        weighted_inputs = inputs_expanded * flat_routing_weights.unsqueeze(-1)
        combined_inputs = weighted_inputs.reshape(-1, self.hidden_size)
        combined_outputs = torch.matmul(combined_inputs, self.expert_weights.view(-1, self.hidden_size).t()) + self.expert_biases.flatten()
        combined_outputs = combined_outputs.view(self.num_experts, batch_size * seq_len, self.num_experts, self.hidden_size)
        expert_outputs = torch.sum(combined_outputs, dim=(0, 2))
        expert_outputs = expert_outputs.view(batch_size, seq_len, self.hidden_size)

        return expert_outputs

In [60]:
from torch.utils.data import DataLoader

vectorized_moe_config = PretrainedConfig(
    **base_config,
    num_experts=4,
    capacity_factor=2.0,
    num_experts_per_token=2,
    ff_cls=VectorizedMoE
)

train_loader = DataLoader(tokenized_dataset['train'], batch_size=16, shuffle=True)
test_loader = DataLoader(tokenized_dataset['test'], batch_size=16, shuffle=False)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TransformerClassifier(vectorized_moe_config).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
from tqdm import tqdm

NUM_OF_EPOCHS = 20

for epoch in range(NUM_OF_EPOCHS):
    model.train()
    train_progress_bar = tqdm(train_loader, desc=f'Train, Epoch {epoch + 1} / {NUM_OF_EPOCHS}')
    running_loss = 0.
    for i, batch in enumerate(train_progress_bar):
        x, y = batch['input_ids'], batch['label']
        x = torch.stack(x, dim=1).to(DEVICE)
        y = y.to(DEVICE)
        optimizer.zero_grad()
        loss = model(x, y)['loss']
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if i % 10 == 9:
            last_loss = running_loss / 10 # avg loss per batch
            print('batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    model.eval()
    with torch.no_grad():
        total_loss = 0
        total_samples = 0
        correct_samples = 0
        test_progress_bar = tqdm(test_loader, desc=f'Test, Epoch {epoch + 1} / {NUM_OF_EPOCHS}')
        for batch in test_progress_bar:
            x, y = batch['input_ids'], batch['label']
            x = torch.stack(x, dim=1).to(DEVICE)
            y = y.to(DEVICE)
            logits = model(x)['logits']
            total_loss += F.cross_entropy(logits, y, reduction='sum').item()
            total_samples += y.shape[0]
            correct_samples += (logits.argmax(dim=-1) == y).sum().item()

        print(f'Epoch {epoch + 1}, loss: {total_loss / total_samples}, accuracy: {correct_samples / total_samples}')

Train, Epoch 1 / 20:   1%|          | 14/1563 [00:00<01:05, 23.67it/s]

batch 10 loss: 1.2465168297290803


Train, Epoch 1 / 20:   1%|▏         | 23/1563 [00:00<01:01, 25.20it/s]

batch 20 loss: 0.8675061404705048


Train, Epoch 1 / 20:   2%|▏         | 32/1563 [00:01<01:00, 25.42it/s]

batch 30 loss: 0.763215708732605


Train, Epoch 1 / 20:   3%|▎         | 44/1563 [00:01<01:03, 23.82it/s]

batch 40 loss: 0.7114193439483643


Train, Epoch 1 / 20:   3%|▎         | 53/1563 [00:02<01:03, 23.87it/s]

batch 50 loss: 0.7123790502548217


Train, Epoch 1 / 20:   4%|▍         | 62/1563 [00:02<01:07, 22.17it/s]

batch 60 loss: 0.7528236389160157


Train, Epoch 1 / 20:   5%|▍         | 74/1563 [00:03<01:06, 22.55it/s]

batch 70 loss: 0.7408448278903961


Train, Epoch 1 / 20:   5%|▌         | 83/1563 [00:03<01:04, 22.82it/s]

batch 80 loss: 0.7350192070007324


Train, Epoch 1 / 20:   6%|▌         | 95/1563 [00:04<00:57, 25.44it/s]

batch 90 loss: 0.6873619556427002


Train, Epoch 1 / 20:   7%|▋         | 104/1563 [00:04<00:56, 25.94it/s]

batch 100 loss: 0.7421616673469543


Train, Epoch 1 / 20:   7%|▋         | 113/1563 [00:04<00:56, 25.64it/s]

batch 110 loss: 0.7086854696273803


Train, Epoch 1 / 20:   8%|▊         | 125/1563 [00:05<00:55, 26.05it/s]

batch 120 loss: 0.7356543421745301


Train, Epoch 1 / 20:   9%|▊         | 134/1563 [00:05<00:54, 26.07it/s]

batch 130 loss: 0.658912593126297


Train, Epoch 1 / 20:   9%|▉         | 143/1563 [00:05<00:53, 26.30it/s]

batch 140 loss: 0.6885289907455444


Train, Epoch 1 / 20:  10%|▉         | 155/1563 [00:06<00:52, 26.92it/s]

batch 150 loss: 0.6801732361316681


Train, Epoch 1 / 20:  10%|█         | 164/1563 [00:06<00:52, 26.50it/s]

batch 160 loss: 0.7042733669281006


Train, Epoch 1 / 20:  11%|█         | 173/1563 [00:06<00:51, 26.82it/s]

batch 170 loss: 0.7078870058059692


Train, Epoch 1 / 20:  12%|█▏        | 185/1563 [00:07<00:51, 26.92it/s]

batch 180 loss: 0.7379485070705414


Train, Epoch 1 / 20:  12%|█▏        | 194/1563 [00:07<00:51, 26.57it/s]

batch 190 loss: 0.6965594291687012


Train, Epoch 1 / 20:  13%|█▎        | 203/1563 [00:08<00:51, 26.45it/s]

batch 200 loss: 0.6714589655399322


Train, Epoch 1 / 20:  14%|█▍        | 215/1563 [00:08<00:52, 25.67it/s]

batch 210 loss: 0.660686856508255


Train, Epoch 1 / 20:  14%|█▍        | 224/1563 [00:08<00:52, 25.73it/s]

batch 220 loss: 0.6838428139686584


Train, Epoch 1 / 20:  15%|█▍        | 233/1563 [00:09<00:50, 26.38it/s]

batch 230 loss: 0.6863567173480988


Train, Epoch 1 / 20:  16%|█▌        | 245/1563 [00:09<00:49, 26.39it/s]

batch 240 loss: 0.6545214593410492


Train, Epoch 1 / 20:  16%|█▋        | 254/1563 [00:10<00:49, 26.49it/s]

batch 250 loss: 0.6294950544834137


Train, Epoch 1 / 20:  17%|█▋        | 263/1563 [00:10<00:48, 26.77it/s]

batch 260 loss: 0.6758016884326935


Train, Epoch 1 / 20:  18%|█▊        | 275/1563 [00:10<00:49, 26.25it/s]

batch 270 loss: 0.6731827020645141


Train, Epoch 1 / 20:  18%|█▊        | 278/1563 [00:10<00:48, 26.43it/s]

# **3. Naive MoE with the same MLP architecture and token choosing strategy as the vectorized version for comparison.**

In [None]:
# Input: [batch_size, seq_len, hidden_size] - input embeddings
# Output: [batch_size, seq_len, hidden_size] - output embeddings
class NaiveMoEForComparison(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        self.hidden_size = config.hidden_size
        self.num_experts_per_token = config.num_experts_per_token
        self.capacity_factor = config.capacity_factor

        # You can change experts representation if you want
        self.experts = nn.ModuleList([torch.nn.Linear(self.hidden_size, self.hidden_size) for _ in range(self.num_experts)])
        self.router = Router(config)

    def forward(self, x):
        batch_size, seq_len, hidden_size = x.shape
        expert_capacity = torch.ceil(torch.tensor(batch_size * seq_len / self.num_experts * self.capacity_factor, device=x.device, dtype=torch.int))
        routing_weights = self.router(x)

        routing_weights = routing_weights.view(-1, self.num_experts) # Shape: [batch_size * seq_len, num_experts]
        for i in range(self.num_experts):
            exp_i_routing_weights = routing_weights[:, i]
            _, topk_indices = exp_i_routing_weights.topk(k=expert_capacity, dim=0)
            mask = torch.zeros_like(exp_i_routing_weights).bool()
            mask.scatter_(0, topk_indices, 1)
            exp_i_routing_weights = exp_i_routing_weights * mask.float()
            routing_weights[:, i] = exp_i_routing_weights

        routing_weights = routing_weights.view(batch_size, seq_len, self.num_experts)
        expert_outputs = torch.zeros(batch_size, seq_len, self.hidden_size, device=x.device)
        for i in range(self.num_experts):
            token_indices = torch.nonzero(routing_weights[:, :, i], as_tuple=False)
            expert_outputs[token_indices[:, 0], token_indices[:, 1]] += self.experts[i](x[token_indices[:, 0], token_indices[:, 1]]) * routing_weights[token_indices[:, 0], token_indices[:, 1], i].reshape(-1, 1)

        return expert_outputs

In [None]:
from torch.utils.data import DataLoader

naive_moe_for_comparison_config = PretrainedConfig(
    **base_config,
    num_experts=4,
    capacity_factor=2.0,
    num_experts_per_token=2,
    ff_cls=NaiveMoEForComparison
)

train_loader = DataLoader(tokenized_dataset['train'], batch_size=16, shuffle=True)
test_loader = DataLoader(tokenized_dataset['test'], batch_size=16, shuffle=False)

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = TransformerClassifier(naive_moe_for_comparison_config).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
from tqdm import tqdm

NUM_OF_EPOCHS = 20

for epoch in range(NUM_OF_EPOCHS):
    model.train()
    train_progress_bar = tqdm(train_loader, desc=f'Train, Epoch {epoch + 1} / {NUM_OF_EPOCHS}')
    running_loss = 0.
    for i, batch in enumerate(train_progress_bar):
        x, y = batch['input_ids'], batch['label']
        x = torch.stack(x, dim=1).to(DEVICE)
        y = y.to(DEVICE)
        optimizer.zero_grad()
        loss = model(x, y)['loss']
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

        if i % 10 == 9:
            last_loss = running_loss / 10 # avg loss per batch
            print('batch {} loss: {}'.format(i + 1, last_loss))
            running_loss = 0.

    model.eval()
    with torch.no_grad():
        total_loss = 0
        total_samples = 0
        correct_samples = 0
        test_progress_bar = tqdm(test_loader, desc=f'Test, Epoch {epoch + 1} / {NUM_OF_EPOCHS}')
        for batch in test_progress_bar:
            x, y = batch['input_ids'], batch['label']
            x = torch.stack(x, dim=1).to(DEVICE)
            y = y.to(DEVICE)
            logits = model(x)['logits']
            total_loss += F.cross_entropy(logits, y, reduction='sum').item()
            total_samples += y.shape[0]
            correct_samples += (logits.argmax(dim=-1) == y).sum().item()

        print(f'Epoch {epoch + 1}, loss: {total_loss / total_samples}, accuracy: {correct_samples / total_samples}')

# **Comparing Naive MoE and the Vectorized version.**