In [1]:
!pip3 install torch torchtext torchvision torchaudio portalocker --index-url https://download.pytorch.org/whl/cu121

Looking in indexes: https://download.pytorch.org/whl/cu121
Collecting torchtext
  Downloading https://download.pytorch.org/whl/torchtext-0.16.2%2Bcpu-cp312-cp312-linux_x86_64.whl (2.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m80.1 MB/s[0m eta [36m0:00:00[0m
Collecting portalocker
  Downloading https://download.pytorch.org/whl/portalocker-2.10.1-py3-none-any.whl (18 kB)
Collecting torch
  Downloading https://download.pytorch.org/whl/cu121/torch-2.2.0%2Bcu121-cp312-cp312-linux_x86_64.whl (757.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m757.2/757.2 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
Collecting torchdata==0.7.1 (from torchtext)
  Downloading https://download.pytorch.org/whl/torchdata-0.7.1-py3-none-any.whl (184 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.4/184.4 kB[0m [31m16.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Download

In [2]:
%%writefile dataloader.py
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
import warnings

# Suppress the specific UserWarning from torch.tensor on a tensor
# We are handling this correctly with torch.stack now, but this is good practice.
warnings.filterwarnings("ignore", category=UserWarning, message="To copy construct from a tensor.*")

class AGNewsDataset(Dataset):
    """
    A PyTorch Dataset class to handle the AG News data.
    It processes the raw text iterator into a list of tensors in the constructor.
    """
    def __init__(self, data_iterator, vocab, tokenizer):
        self.data = []
        self.vocab = vocab
        self.tokenizer = tokenizer

        # This loop consumes the iterator and stores the processed data in self.data
        for label, text in data_iterator:
            tokens = self.tokenizer(text)
            indices = self.vocab(tokens)
            indices_tensor = torch.tensor(indices, dtype=torch.long)
            label_tensor = torch.tensor(label - 1, dtype=torch.long)
            self.data.append((label_tensor, indices_tensor))

    def __len__(self):
        """Returns the total number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Returns the processed sample at a given index."""
        return self.data[idx]

def get_dataloaders_and_vocab(batch_size=64):
    """
    Orchestrates the entire data loading process.
    Handles iterator exhaustion correctly by creating fresh iterators for each step.

    Args:
        batch_size (int): The batch size for the DataLoaders.

    Returns:
        tuple: A tuple containing (train_dataloader, test_dataloader, vocab)
    """
    print("--- Starting Data Loading Process ---")

    # --- Step A: Setup Tokenizer ---
    tokenizer = get_tokenizer('basic_english')

    # --- Step B: Build Vocabulary ---
    # We create a fresh iterator here specifically for building the vocabulary.
    # This iterator will be exhausted after this step.
    print("Building vocabulary...")
    train_iter_for_vocab, _ = AG_NEWS(split=('train', 'test'))

    def yield_tokens(data_iter):
        for _, text in data_iter:
            yield tokenizer(text)

    vocab = build_vocab_from_iterator(yield_tokens(train_iter_for_vocab), specials=["<unk>", "<pad>"])
    vocab.set_default_index(vocab["<unk>"])
    print(f"Vocabulary built. Size: {len(vocab)}")

    # --- Step C: Instantiate Datasets ---
    # We create a SECOND set of fresh iterators to pass to our Dataset class.
    print("Processing data and creating Dataset objects...")
    train_iter_for_dataset, test_iter_for_dataset = AG_NEWS(split=('train', 'test'))

    train_dataset = AGNewsDataset(train_iter_for_dataset, vocab, tokenizer)
    test_dataset = AGNewsDataset(test_iter_for_dataset, vocab, tokenizer)
    print("Dataset objects created.")

    # --- Step D: Define the Collate Function ---
    # This function defines how to combine a list of samples into a single batch.
    def collate_batch(batch):
        label_list, text_list = [], []
        for (_label, _text) in batch:
            label_list.append(_label)
            text_list.append(_text)

        # Use torch.stack for labels, as it's the correct way to combine a list of tensors.
        labels_tensor = torch.stack(label_list)

        # Use pad_sequence for text to handle variable lengths.
        texts_tensor = pad_sequence(text_list, batch_first=True, padding_value=vocab['<pad>'])

        return texts_tensor, labels_tensor # Note: Returning (text, label) is more conventional

    # --- Step E: Create DataLoaders ---
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_batch)
    print("--- Data Loading Process Complete ---")

    return train_dataloader, test_dataloader, vocab


if __name__ == '__main__':
    """
    This block runs only when the script is executed directly.
    It's a good way to test that the file is working correctly on its own.
    """
    print("Testing dataloader.py directly...")

    train_dl, test_dl, vocab_obj = get_dataloaders_and_vocab(batch_size=8)

    print(f"\nVocabulary size: {len(vocab_obj)}")

    text_batch, labels_batch = next(iter(train_dl))

    print("\n--- Testing a single batch ---")
    print(f"Text batch shape: {text_batch.shape}")
    print(f"Labels batch shape: {labels_batch.shape}")

    print("\nFirst text tensor in batch:\n", text_batch[0])
    print("\nFirst label in batch:", labels_batch[0])
    print("\ndataloader.py test successful!")

Writing dataloader.py


In [3]:
%%writefile expert.py
import torch
import torch.nn as nn

class Expert(nn.Module):
    """
    A simple feed-forward network, which will serve as an 'expert' in our MoE layer.
    """
    def __init__(self, d_model, d_hidden):
        """
        Args:
            d_model (int): The input and output dimension of the model.
            d_hidden (int): The dimension of the hidden layer.
        """
        super(Expert, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_hidden),
            nn.ReLU(),
            nn.Linear(d_hidden, d_model)
        )

    def forward(self, x):
        """
        Forward pass for the expert.

        Args:
            x (torch.Tensor): The input tensor. Shape: [..., d_model]

        Returns:
            torch.Tensor: The output tensor. Shape: [..., d_model]
        """
        return self.net(x)

Writing expert.py


In [4]:
%%writefile gating.py
import torch
import torch.nn as nn

class Gating(nn.Module):
    """
    A simple linear layer that acts as the gating mechanism in the MoE.
    It decides which experts to route the tokens to.
    """
    def __init__(self, d_model, num_experts):
        """
        Args:
            d_model (int): The dimension of the input tokens.
            num_experts (int): The total number of experts in the MoE layer.
        """
        super(Gating, self).__init__()
        self.gate = nn.Linear(d_model, num_experts)

    def forward(self, x):
        """
        Forward pass for the gating network.

        Args:
            x (torch.Tensor): The input tensor from the Transformer.
                              Shape: [batch_size, seq_len, d_model]

        Returns:
            torch.Tensor: The logits for each expert.
                          Shape: [batch_size, seq_len, num_experts]
        """
        # The output of this layer will be the raw scores (logits) for each expert
        return self.gate(x)

Writing gating.py


In [5]:
%%writefile moe_layer.py
import torch
import torch.nn as nn
import torch.nn.functional as F

from expert import Expert
from gating import Gating

class MoELayer(nn.Module):
    """
    A Mixture of Experts layer.
    """
    def __init__(self, d_model, d_hidden, num_experts, top_k):
        """
        Args:
            d_model (int): The dimension of the input and output.
            d_hidden (int): The hidden dimension of each expert FFN.
            num_experts (int): The total number of experts.
            top_k (int): The number of experts to route each token to.
        """
        super(MoELayer, self).__init__()

        # Basic validation
        if top_k > num_experts:
            raise ValueError("top_k must be less than or equal to num_experts")

        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k

        # Instantiate the experts and the gating network
        self.experts = nn.ModuleList([Expert(d_model, d_hidden) for _ in range(num_experts)])
        self.gating = Gating(d_model, num_experts)

    def compute_load_balancing_loss(self, gating_logits):
        """
        Computes the load balancing loss for the MoE layer.
        This loss encourages the gating network to distribute tokens evenly across experts.

        Args:
            gating_logits (torch.Tensor): The raw logits from the gating network.
                                      Shape: [batch_size * seq_len, num_experts]

        Returns:
            torch.Tensor: A single scalar value representing the load balancing loss.
        """
        #
        # The formula is: alpha * sum(f_i * P_i) for i in experts
        # f_i = fraction of tokens sent to expert i
        # P_i = average probability (gating value) for expert i over tokens sent to it
        #

        # Calculate P_i: softmax over all logits
        gating_probs = F.softmax(gating_logits, dim=-1)

        # Calculate f_i: mean of the one-hot encoding of the chosen expert
        # For top-k > 1, this is more complex. A simplification is to look at the prob distribution.
        # We can calculate the fraction of the "load" each expert gets.
        f_i = gating_probs.mean(dim=0)

        # Calculate P_i: The mean of the probabilities assigned to each expert across all tokens.
        P_i = gating_probs.mean(dim=0)


        # The loss is the dot product of these two vectors, scaled by the number of experts.
        # This encourages the product (and thus both f_i and P_i) to be uniform.
        loss = self.num_experts * torch.sum(f_i * P_i)
        return loss

    def forward(self, x):
        """
        Forward pass for the MoE layer.

        Args:
            x (torch.Tensor): Input tensor. Shape: [batch_size, seq_len, d_model]

        Returns:
            (This will be implemented tomorrow)
        """
        # Reshape input for gating: [batch_size * seq_len, d_model]
        # This treats each token independently.
        batch_size, seq_len, d_model = x.shape
        x_reshaped = x.view(-1, d_model)

        # Get gating logits: [batch_size * seq_len, num_experts]
        gating_logits = self.gating(x_reshaped)

        # Get the top-k experts and their scores (gating values)
        # The scores are softmax-normalized logits for the top-k experts.
        # top_k_gating_values shape: [batch_size * seq_len, top_k]
        # top_k_indices shape: [batch_size * seq_len, top_k]
        top_k_gating_values, top_k_indices = torch.topk(gating_logits, self.top_k, dim=-1)

        # Apply softmax to the top-k logits to get weights
        top_k_gating_values = F.softmax(top_k_gating_values, dim=-1)

        # Create a flat tensor of token indices
        # This will be [0, 0, 1, 1, 2, 2, ...] for top_k=2
        # It helps us track which output belongs to which original token.
        token_indices = torch.arange(x_reshaped.size(0)).repeat_interleave(self.top_k)

        # Create a flat tensor of the chosen expert indices for all tokens
        flat_expert_indices = top_k_indices.flatten()

        # Create our dispatch mask. It's a binary matrix of shape
        # [batch_size * seq_len, num_experts].
        # Entry (i, j) is 1 if token i is routed to expert j, and 0 otherwise.
        dispatch_mask = torch.zeros(x_reshaped.size(0), self.num_experts, device=x.device).bool()
        dispatch_mask.scatter_(1, top_k_indices, True)

        # The final output tensor, initialized to zeros
        final_output = torch.zeros_like(x_reshaped)

        # Now, iterate through each expert.
        for i in range(self.num_experts):
          # Find the tokens that are routed to this expert
          expert_mask = dispatch_mask[:, i]

          # If no tokens are routed to this expert, skip it.
          if not expert_mask.any():
            continue

          # Get the indices of the tokens for this expert
          token_ids_for_expert = expert_mask.nonzero(as_tuple=True)[0]

          # Get the actual input tokens for this expert
          inputs_for_expert = x_reshaped[token_ids_for_expert]

          # Pass the tokens through the expert
          expert_output = self.experts[i](inputs_for_expert)

          # Find the gating values associated with these tokens for this expert
          gating_values_for_expert = top_k_gating_values[dispatch_mask[:, i]]

          # The gating values tensor is currently [num_tokens_for_expert, top_k].
          # We need to find which of the top_k is our current expert 'i'.
          # We create a mask for this.
          k_mask = (top_k_indices[expert_mask] == i)

          # Apply the mask to get the single correct gating value for each token.
          correct_gating_values = gating_values_for_expert[k_mask]

          # Multiply the expert output by the gating values (element-wise)
          weighted_output = expert_output * correct_gating_values.unsqueeze(-1)

          # Add the weighted output to the final output tensor at the correct positions.
          # This is the "combine" step. We use index_add_ for an efficient scatter-add.
          final_output.index_add_(0, token_ids_for_expert, weighted_output)

        # Reshape the final output back to the original input shape
        return final_output.view(batch_size, seq_len, d_model), gating_logits

Writing moe_layer.py


In [6]:
%%writefile model.py
import torch
import torch.nn as nn
import math

from moe_layer import MoELayer

# --- Positional Encoding: A standard, non-learnable component ---
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)

class TransformerEncoderLayerWithMoE(nn.Module):
    def __init__(self, d_model, nhead, d_hidden, num_experts, top_k, dropout=0.1):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.moe_layer = MoELayer(d_model, d_hidden, num_experts, top_k)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, src):
        # Attention block
        attn_output, _ = self.self_attn(src, src, src)
        src = src + self.dropout(attn_output)
        src = self.norm1(src)

        # MoE block
        moe_output, gating_logits = self.moe_layer(src)
        src = src + self.dropout(moe_output)
        src = self.norm2(src)
        return src, gating_logits

# --- The Full Classification Model ---
class MoETransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, nhead, d_hidden, num_experts, top_k, num_classes, num_layers):
        super(MoETransformerClassifier, self).__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model)

        # Stack of our custom MoE-enabled encoder layers
        self.transformer_encoder = nn.ModuleList(
            [TransformerEncoderLayerWithMoE(d_model, nhead, d_hidden, num_experts, top_k) for _ in range(num_layers)]
        )

        # The final classification head
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, src):
        # src shape: [batch_size, seq_len]
        src = self.embedding(src) * math.sqrt(self.d_model)
        src = self.pos_encoder(src)

        all_gating_logits = []

        # Pass through the stack of encoder layers
        for layer in self.transformer_encoder:
            src, gating_logits = layer(src)
            all_gating_logits.append(gating_logits)

        # Pooling: Average the outputs of all tokens in the sequence
        pooled_output = src.mean(dim=1)

        # Final classification
        output_logits = self.classifier(pooled_output)
        return output_logits, all_gating_logits

Writing model.py


In [7]:
%%writefile train.py
import torch
import torch.nn as nn
from tqdm import tqdm

from dataloader import get_dataloaders_and_vocab
from model import MoETransformerClassifier
import os

# --- Configuration ---
BATCH_SIZE = 32
NUM_EPOCHS = 6
LEARNING_RATE = 1e-4
D_MODEL = 128
NHEAD = 4
D_HIDDEN = 512
NUM_EXPERTS = 8
TOP_K = 2
NUM_LAYERS = 2
NUM_CLASSES = 4
LOAD_BALANCING_ALPHA = 0.005

def train_one_epoch(model, dataloader, optimizer, criterion, device, LOAD_BALANCING_ALPHA):
    model.train()
    total_loss = 0

    progress_bar = tqdm(dataloader, desc=f"Training Epoch")

    for text_batch, labels_batch in progress_bar:
        text_batch, labels_batch = text_batch.to(device), labels_batch.to(device)

        # 1. Forward pass
        # This returns the final classification logits and a list of gating logits from each MoE layer
        output_logits, all_gating_logits = model(text_batch)

        # 2. Calculate main classification loss (Cross-Entropy)
        main_loss = criterion(output_logits, labels_batch)

        # 3. Calculate and sum the load balancing loss across ALL MoE layers
        load_balancing_loss = 0

        # We iterate through the list of logits and the list of encoder layers together.
        # all_gating_logits[i] corresponds to the logits from model.transformer_encoder[i].
        for i, logits in enumerate(all_gating_logits):
            moe_layer = model.transformer_encoder[i].moe_layer
            load_balancing_loss += moe_layer.compute_load_balancing_loss(logits)

        # 4. Combine the losses
        # The total loss is the main classification loss plus the scaled sum of all load balancing losses.
        combined_loss = main_loss + (LOAD_BALANCING_ALPHA * load_balancing_loss)

        # 5. Backward pass and optimization
        optimizer.zero_grad()
        combined_loss.backward()
        optimizer.step()

        total_loss += combined_loss.item()

        progress_bar.set_postfix(loss=combined_loss.item(), main_loss=main_loss.item())

    return total_loss / len(dataloader)


def validate_one_epoch(model, dataloader, criterion, device):
    model.eval()
    total_loss = 0
    correct_predictions = 0
    total_samples = 0

    # We don't need to compute gradients during validation
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validation Epoch")
        for text_batch, labels_batch in progress_bar:
            text_batch, labels_batch = text_batch.to(device), labels_batch.to(device)

            # 1. Forward pass (we don't need the gating logits for validation)
            output_logits, _ = model(text_batch)

            # 2. Calculate loss
            loss = criterion(output_logits, labels_batch)
            total_loss += loss.item()

            # 3. Calculate accuracy
            # Get the predicted class by finding the index of the max logit
            _, predicted_labels = torch.max(output_logits, 1)

            # Compare predicted labels with the true labels
            correct_predictions += (predicted_labels == labels_batch).sum().item()
            total_samples += labels_batch.size(0)

    avg_loss = total_loss / len(dataloader)
    accuracy = correct_predictions / total_samples
    return avg_loss, accuracy


def save_model_components(model, save_dir="model_weights"):
    """
    Saves the full model and its individual components (gating and experts).
    """
    print(f"\n--- Saving model components to '{save_dir}' ---")

    # Create the directory if it doesn't exist
    os.makedirs(save_dir, exist_ok=True)

    # 1. Save the full model's state_dict
    full_model_path = os.path.join(save_dir, "full_model.pth")
    torch.save(model.state_dict(), full_model_path)
    print(f"Full model saved to {full_model_path}")

    # 2. Save the components of each TransformerEncoderLayerWithMoE
    # We assume for this project that all MoE layers are identical,
    # so we'll just save the components from the first one.
    # In a real-world scenario with different MoE layers, you'd save them all.

    # Let's get the first MoE encoder layer
    first_moe_encoder_layer = model.transformer_encoder[0]

    # Save the gating network
    gating_path = os.path.join(save_dir, "gating_network.pth")
    torch.save(first_moe_encoder_layer.moe_layer.gating.state_dict(), gating_path)
    print(f"Gating network saved to {gating_path}")

    # Save each expert individually
    for i, expert in enumerate(first_moe_encoder_layer.moe_layer.experts):
        expert_path = os.path.join(save_dir, f"expert_{i}.pth")
        torch.save(expert.state_dict(), expert_path)
        print(f"Expert {i} saved to {expert_path}")

    embedding_path = os.path.join(save_dir, "embedding_layer.pth")
    torch.save(model.embedding.state_dict(), embedding_path)
    print(f"Embedding layer saved to {embedding_path}")

    classifier_path = os.path.join(save_dir, "classifier_head.pth")
    torch.save(model.classifier.state_dict(), classifier_path)
    print(f"Classifier head saved to {classifier_path}")

    print("--- All components saved successfully ---")


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    train_dl, test_dl, vocab = get_dataloaders_and_vocab(batch_size=BATCH_SIZE)
    vocab_size = len(vocab)

    model = MoETransformerClassifier(
        vocab_size, D_MODEL, NHEAD, D_HIDDEN, NUM_EXPERTS, TOP_K, NUM_CLASSES, NUM_LAYERS
    ).to(device)


    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

    print("--- Starting Training & Validation ---")
    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\n--- Epoch {epoch}/{NUM_EPOCHS} ---")

        # Train for one epoch
        avg_train_loss = train_one_epoch(model, train_dl, optimizer, criterion, device, LOAD_BALANCING_ALPHA)

        # Validate for one epoch
        avg_val_loss, val_accuracy = validate_one_epoch(model, test_dl, criterion, device)

        scheduler.step()

        # Print the results for the epoch
        print(f"End of Epoch {epoch}:")
        print(f"\tAverage Training Loss: {avg_train_loss:.4f}")
        print(f"\tAverage Validation Loss: {avg_val_loss:.4f}")
        print(f"\tValidation Accuracy: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")

    save_model_components(model)

if __name__ == "__main__":
    main()

Writing train.py


In [8]:
!python train.py


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "/content/train.py", line 1, in <module>
    import torch
  File "/usr/local/lib/python3.12/dist-packages/torch/__init__.py", line 1471, in <module>
    from .functional import *  # noqa: F403
  File "/usr/local/lib/python3.12/dist-packages/torch/functional.py", line 9, in <module>
    import torch.nn.functional as F
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/__init__.py", line 1, in <module>
    from .modules import *  # noqa: F403
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/__ini