In [1]:
from google.colab import drive
drive.mount('/content/drive')

save_base_path = "/content/drive/MyDrive/MoE"

Mounted at /content/drive


# MoE based GPT 2

## Import Dependencies

In [2]:

import torch
import torch.nn as nn
from torch.optim import Optimizer
from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset, Dataset
from typing import Optional
import copy
import os
import json
from typing import Callable, Iterable, Tuple
from utils import *
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import csv
from tqdm import tqdm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(DEVICE)

cuda


## MoE Router Module

In [3]:

class TopKRouter(nn.Module):
    def __init__(self, hidden_size: int, num_experts: int, top_k: int = 1):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)
        torch.nn.init.normal_(self.gate.weight, mean=0.0, std=0.08)

    def forward(self, hidden_states):
        # hidden_states: [batch, seq, hidden]?
        batch_size, seq_len, _ = hidden_states.shape
        router_logits = self.gate(hidden_states)  # [batch, seq, num_experts]

        # Calculate probabilities
        routing_weights = torch.softmax(router_logits, dim=-1)

        # --- Load Balancing Loss (Switch Transformer style) ---
        # Two components:
        # 1. f_i = fraction of tokens dispatched to expert i (hard assignment)
        # 2. P_i = average probability assigned to expert i (soft assignment)
        # Loss = num_experts * sum(f_i * P_i)

        # Get top-k for hard assignment
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)

        # Compute f_i: fraction of tokens assigned to each expert
        # Create one-hot encoding of expert assignments
        expert_mask = torch.zeros(
            batch_size, seq_len, self.num_experts,
            device=hidden_states.device
        )
        # For top-k, we scatter 1.0 for selected experts
        expert_mask.scatter_(
            2,
            top_k_indices,
            torch.ones_like(top_k_indices, dtype=expert_mask.dtype)
        )

        # Fraction of tokens routed to each expert (averaged over batch and sequence)
        tokens_per_expert = expert_mask.sum(dim=(0, 1))  # [num_experts]
        fraction_per_expert = tokens_per_expert / (batch_size * seq_len)  # f_i

        # Compute P_i: average routing probability for each expert
        prob_per_expert = routing_weights.mean(dim=(0, 1))  # [num_experts]

        # Load balancing loss
        # Multiply by num_experts to make it scale-invariant
        aux_loss = self.num_experts * torch.sum(fraction_per_expert * prob_per_expert)

        # Normalize top-k weights
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        return top_k_weights, top_k_indices, router_logits, aux_loss

## MoE Layer

In [4]:
class MoELayer(nn.Module):
    """Mixture of Experts layer replacing the MLP"""
    def __init__(self, dense_mlp, num_experts: int = 8, top_k: int = 4, drop_ratio: float = 0.0, load_balance_weight: float = 0.01):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.load_balance_weight = load_balance_weight # <--- New parameter

        # Hugging Face GPT-2 uses Conv1D, where weight shape is [input, output].
        hidden_size = dense_mlp.c_fc.weight.shape[0]
        # Create router
        self.router = TopKRouter(hidden_size, num_experts, top_k)

        # Create experts by copying the dense MLP weights with optional drop-upcycling
        self.experts = nn.ModuleList([
            self._copy_mlp_with_drop(dense_mlp, drop_ratio) for _ in range(num_experts)
        ])

        # Placeholder for the loss to be accessed by the Trainer
        self.aux_loss = 0.0

    def _copy_mlp_with_drop(self, dense_mlp, drop_ratio: float):
        """
        Create a copy of the dense MLP with drop-upcycling.
        Re-initializes drop_ratio% of parameters to promote diversity.
        """
        expert = copy.deepcopy(dense_mlp)

        if drop_ratio > 0:
            with torch.no_grad():
                for name, param in expert.named_parameters():
                    # Create a mask for parameters to re-initialize
                    mask = torch.rand_like(param) < drop_ratio

                    # Re-initialize masked parameters with small random values
                    if mask.any():
                        param.data[mask] = torch.randn_like(param[mask]) * 0.02

        return expert

    def forward(self, hidden_states):
        batch_size, seq_len, hidden_size = hidden_states.shape
        total_tokens = batch_size * seq_len

        # Route tokens - NOW UNPACKING 4 VALUES
        top_k_weights, top_k_indices, router_logits, aux_loss = self.router(hidden_states)

        # Store weighted aux_loss for the trainer to retrieve later
        self.aux_loss = aux_loss * self.load_balance_weight

        # Flatten everything
        flat_hidden = hidden_states.view(total_tokens, hidden_size)
        flat_weights = top_k_weights.view(total_tokens, self.top_k)
        flat_indices = top_k_indices.view(total_tokens, self.top_k)

        # Initialize output
        output = torch.zeros_like(flat_hidden)

        # Create a dispatch mask: [num_experts, total_tokens * top_k]
        flat_indices_1d = flat_indices.view(-1)  # [total_tokens * top_k]
        flat_weights_1d = flat_weights.view(-1)  # [total_tokens * top_k]

        # Token indices repeated for each k
        token_indices = torch.arange(total_tokens, device=hidden_states.device)
        token_indices = token_indices.unsqueeze(1).expand(-1, self.top_k).reshape(-1)

        # Process each expert
        for expert_idx in range(self.num_experts):
            expert_mask = (flat_indices_1d == expert_idx)

            if not expert_mask.any():
                continue

            expert_token_indices = token_indices[expert_mask]
            expert_weights = flat_weights_1d[expert_mask]
            expert_input = flat_hidden[expert_token_indices]

            expert_output = self.experts[expert_idx](expert_input)
            weighted_output = expert_output * expert_weights.unsqueeze(-1)
            output.index_add_(0, expert_token_indices, weighted_output)

        output = output.view(batch_size, seq_len, hidden_size)

        return output

In [5]:
def calculate_active_params(model):
    """Calculate total parameters and active parameters in MoE layers"""
    total_params = sum(p.numel() for p in model.parameters())

    # Calculate MoE-specific info
    moe_info = {
        'total_experts': 0,
        'active_per_token': 0,
        'total_expert_params': 0,
        'active_expert_params': 0,
    }

    for name, module in model.named_modules():
        if isinstance(module, MoELayer):
            moe_info['total_experts'] += module.num_experts
            moe_info['active_per_token'] += module.top_k

            # Count expert parameters
            expert_params = sum(p.numel() for p in module.experts.parameters())
            moe_info['total_expert_params'] += expert_params
            moe_info['active_expert_params'] += (expert_params / module.num_experts) * module.top_k

    return total_params, moe_info

## Upcycle GPT2 Vanilla Weights to MoE Architecture

In [6]:
def upcycle_gpt2_to_moe(
    model_name: str = 'gpt2',
    num_experts: int = 8,
    top_k: int = 2,
    moe_layers: Optional[list] = None,
    drop_ratio: float = 0.05,
    match_active_params: bool = False,
    load_balance_weight: float = 0.01
):
    """
    Convert a standard GPT-2 model to MoE architecture
    """
    # Load the pre-trained model with LM head
    model = GPT2LMHeadModel.from_pretrained(model_name)
    original_params = sum(p.numel() for p in model.parameters())

    # If no specific layers specified, convert all layers
    if moe_layers is None:
        moe_layers = list(range(len(model.transformer.h)))

    # Auto-adjust for fair comparison
    if match_active_params:
        top_k = 1
        print(f"Fair comparison mode: Setting top_k={top_k} to match vanilla GPT-2 active params")

    upcycle_type = "Drop-Upcycling" if drop_ratio > 0 else "Standard Upcycling"
    print(f"Converting layers {moe_layers} to MoE with {num_experts} experts (top-{top_k})")
    print(f"Using {upcycle_type}" + (f" with {drop_ratio*100}% parameter re-initialization" if drop_ratio > 0 else ""))

    # Replace MLPs with MoE layers
    for layer_idx in moe_layers:
        if layer_idx >= len(model.transformer.h):
            print(f"Warning: Layer {layer_idx} doesn't exist, skipping")
            continue

        original_mlp = model.transformer.h[layer_idx].mlp

        model.transformer.h[layer_idx].mlp = MoELayer(
            original_mlp,
            num_experts=num_experts,
            top_k=top_k,
            drop_ratio=drop_ratio,
            load_balance_weight=load_balance_weight
        )

        print(f"Converted layer {layer_idx}")

    # Print parameter comparison
    total_params, moe_info = calculate_active_params(model)
    print("\n" + "=" * 60)
    print("PARAMETER COMPARISON")
    print("=" * 60)
    print(f"Original model params:        {original_params:,}")
    print(f"MoE model total params:       {total_params:,}")
    print(f"MoE model active params:      {int(original_params + moe_info['active_expert_params']):,}")
    print(f"\nPer-layer breakdown:")
    print(f"  Experts per layer:          {num_experts}")
    print(f"  Active experts per token:   {top_k}")
    print(f"  Active ratio:               {top_k}/{num_experts} = {top_k/num_experts:.1%}")

    if match_active_params:
        print(f"\nFair comparison mode: Active params â‰ˆ vanilla GPT-2")
    else:
        active_ratio = (original_params + moe_info['active_expert_params']) / original_params
        print(f"\nMoE has {active_ratio:.1f}x active parameters vs vanilla")

    return model

In [7]:
def save_moe_model(model, save_path):
    """Save MoE model with custom layers"""
    os.makedirs(save_path, exist_ok=True)

    # Save model state dict
    torch.save(model.state_dict(), os.path.join(save_path, 'pytorch_model.bin'))

    # Save config
    model.config.save_pretrained(save_path)

    # Save MoE configuration
    moe_config = {
        'moe_layers': [],
        'num_experts': None,
        'top_k': None,
    }

    for layer_idx, layer in enumerate(model.transformer.h):
        if isinstance(layer.mlp, MoELayer):
            moe_config['moe_layers'].append(layer_idx)
            if moe_config['num_experts'] is None:
                moe_config['num_experts'] = layer.mlp.num_experts
                moe_config['top_k'] = layer.mlp.top_k

    with open(os.path.join(save_path, 'moe_config.json'), 'w') as f:
        json.dump(moe_config, f)

    print(f"Model saved to {save_path}")

In [8]:
def load_moe_model(load_path, device='cuda'):
    """Load MoE model with custom layers"""
    # Load MoE config
    with open(os.path.join(load_path, 'moe_config.json'), 'r') as f:
        moe_config = json.load(f)

    # Load base model
    base_model = GPT2LMHeadModel.from_pretrained('gpt2')

    # Convert to MoE architecture
    for layer_idx in moe_config['moe_layers']:
        original_mlp = base_model.transformer.h[layer_idx].mlp
        base_model.transformer.h[layer_idx].mlp = MoELayer(
            original_mlp,
            num_experts=moe_config['num_experts'],
            top_k=moe_config['top_k'],
            drop_ratio=0.0
        )

    # Load trained weights
    state_dict = torch.load(os.path.join(load_path, 'pytorch_model.bin'), map_location=device)
    base_model.load_state_dict(state_dict)

    print(f"Model loaded from {load_path}")
    return base_model

In [None]:
moe_model = upcycle_gpt2_to_moe(
        model_name='gpt2',
        num_experts=4,
        top_k=2,
        drop_ratio=0.3,
        load_balance_weight=0.01,
        moe_layers=[1, 3, 5, 7, 9, 11],
    )

save_moe_model(moe_model, f'{save_base_path}/gpt2-moe-upcycled')

del moe_model

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

Converting layers [1, 3, 5, 7, 9, 11] to MoE with 4 experts (top-2)
Using Drop-Upcycling with 30.0% parameter re-initialization
Converted layer 1
Converted layer 3
Converted layer 5
Converted layer 7
Converted layer 9
Converted layer 11

PARAMETER COMPARISON
Original model params:        124,439,808
MoE model total params:       209,462,016
MoE model active params:      181,108,992

Per-layer breakdown:
  Experts per layer:          4
  Active experts per token:   2
  Active ratio:               2/4 = 50.0%

MoE has 1.5x active parameters vs vanilla
Model saved to /content/drive/MyDrive/MoE/gpt2-moe-upcycled


## Finetune MoE Model

In [None]:
class MoETrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        outputs = model(**inputs)

        if isinstance(outputs, dict):
            lm_loss = outputs['loss']
        else:
            lm_loss = outputs[0]

        total_aux_loss = 0.0

        for name, module in model.named_modules():
            if isinstance(module, MoELayer) and hasattr(module, 'aux_loss'):
                total_aux_loss += module.aux_loss

        total_loss = lm_loss + total_aux_loss

        return (total_loss, outputs) if return_outputs else total_loss

def finetune_moe_model(
    moe_model_path: str = f'{save_base_path}/gpt2-moe-upcycled',
    output_dir: str = f'{save_base_path}/gpt2-moe-finetuned',
    num_train_steps: int = 5000,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    warmup_steps: int = 200,
    gradient_accumulation_steps: int = 8,
    save_steps: int = 1000,
    max_length: int = 512,
):
    """
    Fine-tune the upcycled MoE model
    """
    print("=" * 60)
    print("FINE-TUNING MoE MODEL")
    print("=" * 60)

    # Load the model using custom loader
    print(f"Loading model from {moe_model_path}...")
    model = load_moe_model(moe_model_path)

    print(f"Model loaded with {sum(p.numel() for p in model.parameters()):,} parameters")

    # Load tokenizer
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    print(f"\nLoading dataset...")
    dataset = load_dataset('allenai/c4', 'en', split='train', streaming=True)
    dataset = dataset.shuffle(seed=42, buffer_size=1000).take(100_000)

    # Tokenize dataset
    def tokenize_function(examples):
        return tokenizer(
            examples['text'],
            truncation=True,
            max_length=max_length,
            padding=False,
            return_tensors=None,
        )

    print("Tokenizing dataset...")
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['text', 'timestamp', 'url']
    )

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False
    )

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        max_steps=num_train_steps,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        warmup_steps=312,
        lr_scheduler_type='cosine',
        max_grad_norm=0.5,
        fp16=False,

        save_strategy='epoch',
        save_total_limit=3,
        logging_steps=10,
        logging_first_step=True,
        dataloader_num_workers=2,
        remove_unused_columns=False,
        report_to='none',
    )

    # Create trainer
    trainer = MoETrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    # Train
    print("\n" + "=" * 60)
    print("STARTING TRAINING")
    print("=" * 60)
    print(f"Total steps: {num_train_steps}")
    print(f"Effective batch size: {batch_size * gradient_accumulation_steps}")
    print(f"Learning rate: {learning_rate}")
    print(f"Warmup steps: {warmup_steps}\n")

    trainer.train()

    print("\n" + "=" * 60)
    print("SAVING FINAL MODEL")
    print("=" * 60)
    save_moe_model(model, output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model and tokenizer saved to {output_dir}")

    return model

In [None]:
finetuned_model = finetune_moe_model(
        moe_model_path=f'{save_base_path}/gpt2-moe-upcycled',
        output_dir=f'{save_base_path}/gpt2-moe-finetuned',
        num_train_steps=12500,
        batch_size=16,
        gradient_accumulation_steps=2,
        learning_rate=5e-5,
        warmup_steps=500,
    )

print("Successfully fintetuned model at path ./gpt2-moe-fintuned")

del finetuned_model

FINE-TUNING MoE MODEL
Loading model from /content/drive/MyDrive/MoE/gpt2-moe-upcycled...
Model loaded from /content/drive/MyDrive/MoE/gpt2-moe-upcycled
Model loaded with 209,462,016 parameters


tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]


Loading dataset...


README.md: 0.00B [00:00, ?B/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Tokenizing dataset...

STARTING TRAINING
Total steps: 12500
Effective batch size: 32
Learning rate: 5e-05
Warmup steps: 500



HfHubHTTPError: Caught HfHubHTTPError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_http.py", line 402, in hf_raise_for_status
    response.raise_for_status()
  File "/usr/local/lib/python3.12/dist-packages/requests/models.py", line 1026, in raise_for_status
    raise HTTPError(http_error_msg, response=self)
requests.exceptions.HTTPError: 429 Client Error: Too Many Requests for url: https://us.gcp.cdn.hf.co/xet-bridge-us/621ffdd236468d709f182a80/01d3f3cd44a4bddd2e1523c934aee67449e4ccb5e6609a4c37813f2170a96da3?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27c4-train.00051-of-01024.json.gz%3B+filename%3D%22c4-train.00051-of-01024.json.gz%22%3B&response-content-type=application%2Fgzip&Expires=1764720114&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiRXBvY2hUaW1lIjoxNzY0NzIwMTE0fX0sIlJlc291cmNlIjoiaHR0cHM6Ly91cy5nY3AuY2RuLmhmLmNvL3hldC1icmlkZ2UtdXMvNjIxZmZkZDIzNjQ2OGQ3MDlmMTgyYTgwLzAxZDNmM2NkNDRhNGJkZGQyZTE1MjNjOTM0YWVlNjc0NDllNGNjYjVlNjYwOWE0YzM3ODEzZjIxNzBhOTZkYTNcXD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSomcmVzcG9uc2UtY29udGVudC10eXBlPSoifV19&Signature=qZ7HKsD7-LfroiBOgut5OyT5i56T6X5QOIvP~ylzdPuPmNcgRLQfWD1Ag5zWzmpwMTvpaW3NFlQtQtxeTmSQmgE~V9tORQ7DkUMWtSsI1fzx0ve-zNvbXRUoU~cQttIjX2URTMtGPVSlUV-A8v96JXWPyux9cA3xUAORTbYkTjqlXGbCbiO5sMg0UWQ53IpAbzcBSu7ZWgnwdNU7J6EZkUmzskMB5-bCwWGwrkSB8~7-Rs0153640Q1N~OnY60QNyt4ZZX-q3cp~y~aC6wwvMhkSb2ZObbM1PhTZI5Mlj0Hrafaedj8szZWGJh4Wl6SmAYWrYRXxN2Ebj7DSAnf8sw__&Key-Pair-Id=KJLH8B0YWU4Y8M

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/_utils/worker.py", line 349, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/_utils/fetch.py", line 33, in fetch
    data.append(next(self.dataset_iter))
                ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 2347, in __iter__
    yield from self._iter_pytorch()
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 2262, in _iter_pytorch
    for key, example in ex_iterable:
                        ^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1107, in __iter__
    yield from self._iter()
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1286, in _iter
    for key, transformed_example in outputs:
                                    ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1283, in <genexpr>
    for key, transformed_batch in outputs
                                  ^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1267, in iter_outputs
    for i, key_example in inputs_iterator:
                          ^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1129, in iter_batched_inputs
    for key, example in iterator:
                        ^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1892, in __iter__
    for key, example in self.ex_iterable:
                        ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1745, in __iter__
    for key_example in islice(self.ex_iterable, self.n - ex_iterable_num_taken):
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 1558, in __iter__
    for x in self.ex_iterable:
             ^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/iterable_dataset.py", line 325, in __iter__
    for key, pa_table in self.generate_tables_fn(**gen_kwags):
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/packaged_modules/json/json.py", line 123, in _generate_tables
    batch = f.read(self.config.chunksize)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/datasets/utils/file_utils.py", line 807, in read_with_retries
    out = read(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/gzip.py", line 338, in read
    return self._buffer.read(size)
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/_compression.py", line 68, in readinto
    data = self.read(len(byte_view))
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/gzip.py", line 544, in read
    if not self._read_gzip_header():
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/gzip.py", line 513, in _read_gzip_header
    last_mtime = _read_gzip_header(self._fp)
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.12/gzip.py", line 468, in _read_gzip_header
    magic = fp.read(2)
            ^^^^^^^^^^
  File "/usr/lib/python3.12/gzip.py", line 104, in read
    self.file.read(size-self._length+read)
  File "/usr/local/lib/python3.12/dist-packages/huggingface_hub/hf_file_system.py", line 1016, in read
    return super().read(length)
           ^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fsspec/spec.py", line 2083, in read
    out = self.cache._fetch(self.loc, self.loc + length)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/fsspec/caching.py", line 249, in _fetch
    self.cache = self.fetcher(start, end)  # new block replaces old
                 ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/huggingface_hub/hf_file_system.py", line 977, in _fetch_range
    hf_raise_for_status(r)
  File "/usr/local/lib/python3.12/dist-packages/huggingface_hub/utils/_http.py", line 475, in hf_raise_for_status
    raise _format(HfHubHTTPError, str(e), response) from e
huggingface_hub.errors.HfHubHTTPError: 429 Client Error: Too Many Requests for url: https://us.gcp.cdn.hf.co/xet-bridge-us/621ffdd236468d709f182a80/01d3f3cd44a4bddd2e1523c934aee67449e4ccb5e6609a4c37813f2170a96da3?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27c4-train.00051-of-01024.json.gz%3B+filename%3D%22c4-train.00051-of-01024.json.gz%22%3B&response-content-type=application%2Fgzip&Expires=1764720114&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiRXBvY2hUaW1lIjoxNzY0NzIwMTE0fX0sIlJlc291cmNlIjoiaHR0cHM6Ly91cy5nY3AuY2RuLmhmLmNvL3hldC1icmlkZ2UtdXMvNjIxZmZkZDIzNjQ2OGQ3MDlmMTgyYTgwLzAxZDNmM2NkNDRhNGJkZGQyZTE1MjNjOTM0YWVlNjc0NDllNGNjYjVlNjYwOWE0YzM3ODEzZjIxNzBhOTZkYTNcXD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSomcmVzcG9uc2UtY29udGVudC10eXBlPSoifV19&Signature=qZ7HKsD7-LfroiBOgut5OyT5i56T6X5QOIvP~ylzdPuPmNcgRLQfWD1Ag5zWzmpwMTvpaW3NFlQtQtxeTmSQmgE~V9tORQ7DkUMWtSsI1fzx0ve-zNvbXRUoU~cQttIjX2URTMtGPVSlUV-A8v96JXWPyux9cA3xUAORTbYkTjqlXGbCbiO5sMg0UWQ53IpAbzcBSu7ZWgnwdNU7J6EZkUmzskMB5-bCwWGwrkSB8~7-Rs0153640Q1N~OnY60QNyt4ZZX-q3cp~y~aC6wwvMhkSb2ZObbM1PhTZI5Mlj0Hrafaedj8szZWGJh4Wl6SmAYWrYRXxN2Ebj7DSAnf8sw__&Key-Pair-Id=KJLH8B0YWU4Y8M


## MoE Base Model Sanity Check

In [28]:
from torch.utils.data import DataLoader
import math

load_path = f'{save_base_path}/gpt2-moe-finetuned/checkpoint-8000'

# --- 2. Load MoE Configuration (moe_config.json assumed to be present) ---
with open(os.path.join(load_path, 'moe_config.json'), 'r') as f:
    moe_config = json.load(f)

# --- 3. Rebuild MoE Architecture (Same as before) ---
model = GPT2LMHeadModel.from_pretrained('gpt2')
for layer_idx in moe_config['moe_layers']:
    original_mlp = model.transformer.h[layer_idx].mlp
    model.transformer.h[layer_idx].mlp = MoELayer(
        original_mlp,
        num_experts=moe_config['num_experts'],
        top_k=moe_config['top_k'],
        drop_ratio=0.0
    )
tokenizer = GPT2Tokenizer.from_pretrained(f'{save_base_path}/gpt2-moe-finetuned')
tokenizer.pad_token = tokenizer.eos_token
model.to('cuda')

dataset = load_dataset('allenai/c4', 'en', split='validation', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10).take(1000)

def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=512,
        padding=False,
        return_tensors=None,
    )

print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=['text', 'timestamp', 'url']
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

val_loader = DataLoader(
    tokenized_dataset,
    batch_size=1,
    collate_fn=data_collator
)

total_loss = 0.0
total_steps = 0
total_correct = 0
total_tokens = 0

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Validating"):
        batch = {k: v.to('cuda') for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs.loss
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = batch["labels"][..., 1:].contiguous()

        predictions = torch.argmax(shift_logits, dim=-1)

        mask = shift_labels != -100

        correct = (predictions == shift_labels) & mask

        total_correct += correct.sum().item()
        total_tokens += mask.sum().item()
        total_loss += loss.item()
        total_steps += 1

avg_loss = total_loss / total_steps
perplexity = math.exp(avg_loss) if avg_loss < 100 else float('inf')
accuracy = total_correct / total_tokens if total_tokens > 0 else 0.0

print("\n" + "-" * 30)
print(f"Validation Loss: {avg_loss:.4f}")
print(f"Perplexity:      {perplexity:.2f}")
print(f"Token Accuracy:  {accuracy:.2%}")
print("-" * 30)

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Tokenizing dataset...


Validating: 0it [00:00, ?it/s]`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.
Validating: 1000it [00:30, 32.42it/s]


------------------------------
Validation Loss: 3.6683
Perplexity:      39.18
Token Accuracy:  36.74%
------------------------------





In [None]:
import torch.nn.functional as F

model = load_moe_model(f'{save_base_path}/gpt2-moe-finetuned', device='cuda')

def check_expert_similarity(model):
    print(f"{'Layer':<10} | {'Pair':<15} | {'Cosine Similarity':<20} | {'Status'}")
    print("-" * 65)

    for layer_idx, layer in enumerate(model.transformer.h):
        if not isinstance(layer.mlp, MoELayer):
            continue

        # Get the experts
        experts = layer.mlp.experts
        num_experts = len(experts)

        # We'll check similarity between Expert 0 and the others
        # Flatten weights for comparison: [hidden_size * intermediate_size]
        w_0 = experts[0].c_fc.weight.flatten()

        for i in range(1, num_experts):
            w_i = experts[i].c_fc.weight.flatten()

            # Compute Cosine Similarity
            similarity = F.cosine_similarity(w_0.unsqueeze(0), w_i.unsqueeze(0)).item()

            status = "DIVERGED (Good)" if similarity < 0.95 else "REDUNDANT (Bad)"
            print(f"{layer_idx:<10} | {'0 vs ' + str(i):<15} | {similarity:.4f}               | {status}")

# Run the check
check_expert_similarity(model)

del model
torch.cuda.empty_cache()

## Vanilla GPT2 Baseline

In [None]:
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.to('cuda')

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

output_dir = f'{save_base_path}/gpt2-finetuned'

num_train_steps =
batch_size = 32
learning_rate = 1e-4
warmup_steps = 312
gradient_accumulation_steps = 1
max_length = 512

print(f"\nLoading dataset...")
dataset = load_dataset('allenai/c4', 'en', split='train', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=1000).take(100_000)

# Tokenize dataset
def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=512,
        padding=False,
        return_tensors=None,
    )

print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=['text', 'timestamp', 'url']
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

# Training arguments
training_args = TrainingArguments(
    output_dir=output_dir,
    max_steps=num_train_steps,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=learning_rate,
    warmup_steps=312,
    lr_scheduler_type='cosine',
    max_grad_norm=0.5,
    fp16=False,

    logging_steps=10,
    logging_first_step=True,
    dataloader_num_workers=2,
    remove_unused_columns=False,
    report_to='none',
)

# Create trainer
trainer = MoETrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset,
    data_collator=data_collator,
)

# Train
print("\n" + "=" * 60)
print("STARTING TRAINING")
print("=" * 60)
print(f"Total steps: {num_train_steps}")
print(f"Effective batch size: {batch_size * gradient_accumulation_steps}")
print(f"Learning rate: {learning_rate}")
print(f"Warmup steps: {warmup_steps}\n")

trainer.train()

print("\n" + "=" * 60)
print("SAVING FINAL MODEL")
print("=" * 60)
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
print(f"Model and tokenizer saved to {output_dir}")


Loading dataset...


Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Tokenizing dataset...

STARTING TRAINING
Total steps: 3125
Effective batch size: 32
Learning rate: 0.0001
Warmup steps: 312



Step,Training Loss
1,3.8321
10,3.7492
20,3.7366
30,3.7582
40,3.6537
50,3.63
60,3.5194
70,3.688
80,3.514
90,3.6249


'(ProtocolError('Connection aborted.', RemoteDisconnected('Remote end closed connection without response')), '(Request ID: 1582bf83-7c3d-4e5b-890d-a866bbd43be5)')' thrown while requesting GET https://huggingface.co/datasets/allenai/c4/resolve/1588ec454efa1a09f29cd18ddd04fe05fc8653a2/en/c4-train.00051-of-01024.json.gz
Retrying in 1s [Retry 1/5].



SAVING FINAL MODEL
Model and tokenizer saved to /content/drive/MyDrive/MoE/gpt2-finetuned


In [None]:
from torch.utils.data import DataLoader
import math

model = GPT2LMHeadModel.from_pretrained(f'{save_base_path}/gpt2-finetuned')
model.to('cuda')
model.eval()
tokenizer = tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

dataset = load_dataset('allenai/c4', 'en', split='validation', streaming=True)
dataset = dataset.shuffle(seed=42, buffer_size=10).take(1000)

def tokenize_function(examples):
    return tokenizer(
        examples['text'],
        truncation=True,
        max_length=512,
        padding=False,
        return_tensors=None,
    )

print("Tokenizing dataset...")
tokenized_dataset = dataset.map(
    tokenize_function,
    batched=True,
    remove_columns=['text', 'timestamp', 'url']
)

data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=False
)

val_loader = DataLoader(
    tokenized_dataset,
    batch_size=1,
    collate_fn=data_collator
)

total_loss = 0.0
total_steps = 0
total_correct = 0
total_tokens = 0

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Validating"):
        batch = {k: v.to('cuda') for k, v in batch.items()}

        outputs = model(**batch)
        loss = outputs.loss
        logits = outputs.logits
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = batch["labels"][..., 1:].contiguous()

        predictions = torch.argmax(shift_logits, dim=-1)

        mask = shift_labels != -100

        correct = (predictions == shift_labels) & mask

        total_correct += correct.sum().item()
        total_tokens += mask.sum().item()
        total_loss += loss.item()
        total_steps += 1

avg_loss = total_loss / total_steps
perplexity = math.exp(avg_loss) if avg_loss < 100 else float('inf')
accuracy = total_correct / total_tokens if total_tokens > 0 else 0.0

print("\n" + "-" * 30)
print(f"Validation Loss: {avg_loss:.4f}")
print(f"Perplexity:      {perplexity:.2f}")
print(f"Token Accuracy:  {accuracy:.2%}")
print("-" * 30)

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/1024 [00:00<?, ?it/s]

Tokenizing dataset...


Validating: 1000it [00:18, 52.86it/s]


------------------------------
Validation Loss: 3.4644
Perplexity:      31.96
Token Accuracy:  38.09%
------------------------------





## Load NLI Datset

In [9]:
def compute_accuracy(preds, labels):
    correct = sum(p.lower().strip() == l.lower().strip() for p, l in zip(preds, labels))
    return correct / len(labels)

def generate_gpt2(model, tokenizer, input_ids, max_gen_length=50, device="cuda"):
    model.eval()
    input_ids = input_ids.to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=max_gen_length,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=False  # Greedy decoding
        )

    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return generated_text

def evaluate_gpt2_xnli(model, tokenizer, dataloader, max_gen_length=10, device="cuda"):
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for item in tqdm(dataloader, desc="Generating"):
            input_ids = item['input_ids']
            gen_text = generate_gpt2(model, tokenizer, input_ids, max_gen_length=max_gen_length, device=device)
            pred_label = gen_text.split("Label:")[-1].strip()
            all_preds.append(pred_label)
            all_labels.extend(item['label_strs'])
    acc = compute_accuracy(all_preds, all_labels)
    print(f"Evaluation accuracy: {acc*100:.2f}%")
    return acc, all_preds, all_labels

class XNLIDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for XNLI (Cross-lingual Natural Language Inference) task.

    Supports train, dev, and test splits in a specific language,
    tokenizes text inputs for GPT-style models, and optionally subsamples the dataset.

    Attributes:
        split (str): Dataset split, one of 'train', 'dev', 'test'.
        lang (str): Language code (e.g., 'en', 'zh').
        tokenizer: A HuggingFace tokenizer to convert text to input IDs.
        max_length (int): Maximum sequence length for tokenization.
        LABEL2ID (dict): Mapping from textual labels to integer IDs.
        ID2LABEL (dict): Reverse mapping from integer IDs to textual labels.
        data (pd.DataFrame): The loaded and preprocessed dataset.
    """
    def __init__(
        self,
        split="train",
        lang="en",
        train_path_template="XNLI-MT-1.0/multinli/multinli.train.{lang}.tsv",
        test_path="XNLI-1.0/xnli.test.tsv",
        dev_path="XNLI-1.0/xnli.dev.tsv",
        tokenizer=None,
        max_length=1024,
        subset = 1.0  # 0~1
    ):
        self.split = split
        self.lang = lang
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.LABEL2ID = {"entailment": 0, "contradictory": 1, "neutral": 2}
        self.ID2LABEL = {v: k for k, v in self.LABEL2ID.items()}

        if split == "train":
            path = train_path_template.format(lang=lang)
            df = self.read_xnli_tsv(path, split)
            df = df.dropna(subset=['premise','hypo','label'])
        elif split in ["dev", "test"]:
            path = test_path if split=="test" else dev_path
            df = self.read_xnli_tsv(path, split)
            df = df[df['language']==lang].copy()
            keep_cols = ['sentence1', 'sentence2', 'gold_label']
            df = df[keep_cols].dropna()
            df.rename(columns={'sentence1':'premise','sentence2':'hypo','gold_label':'label'}, inplace=True)
            df['label'] = df['label'].replace({'contradiction': 'contradictory'})
        else:
            raise ValueError("split must be one of ['train','dev','test']")

        original_num = len(df)
        if subset < 1.0:
            n = max(1, int(len(df) * subset))
            df = df.iloc[:n].reset_index(drop=True)
        subset_num = len(df)

        self.data = df.reset_index(drop=True)
        print(f"Dataset initialized: split='{split}', lang='{lang}', total={original_num}, subset={subset}, subset_count={subset_num}")

    def read_xnli_tsv(self, path, split):
        """
        Read an XNLI TSV file and return it as a pandas DataFrame.

        Args:
            path (str): Path to the TSV file.
            split (str): One of "train", "dev", "test" indicating the dataset split.

        Returns:
            pd.DataFrame: The dataset as a DataFrame with appropriate columns.
        """
        if split == "train":
            with open(path, "r", encoding="utf-8") as f:
                lines = f.read().splitlines()
            header = lines[0].split("\t")
            data = []
            for i, line in enumerate(lines[1:], start=2):
                parts = line.split("\t")
                if len(parts) == len(header):
                    data.append(parts)
                else:
                    print(f"skip row {i}: {len(parts)} cols â†’ {parts[:2]}")
        else:
            with open(path, "r", encoding="utf-8") as f:
                reader = csv.reader(f, delimiter="\t")
                rows = list(reader)
            header = rows[0]
            expected_cols = len(header)
            data = []
            for i, row in enumerate(rows[1:], start=2):
                if len(row) == expected_cols:
                    data.append(row)
                else:
                    print(f"skip row {i}: {len(row)} cols â†’ {row[:2]}")
        return pd.DataFrame(data, columns=header)

    def __len__(self):
        """Return the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieve a single example by index and tokenize it.

        For training split:
            - Constructs the input as "Premise: ... Hypothesis: ... Label: ..."
            - Tokenizes the full input.
            - Masks the prefix tokens in the labels with -100 for GPT loss computation.

        For dev/test split:
            - Constructs the input without label as "Premise: ... Hypothesis: ... Label:"

        Returns:
            dict: Contains 'input_ids', 'attention_mask', 'labels' (train only), 'label_str'
        """
        row = self.data.iloc[idx]
        premise = row['premise']
        hypo = row['hypo']
        label = row['label']
        if self.lang == 'zh': # de-tokenize for Chinese
            premise = premise.replace(" ", "")
            hypo = hypo.replace(" ", "")

        if self.split == "train":
            prefix = f"Premise: {premise}\nHypothesis: {hypo}\nLabel:"
            full_text = prefix + str(self.LABEL2ID[label])
            tokenized = self.tokenizer(
                full_text,
                truncation=True,
                max_length=self.max_length,
                padding=False,
                return_tensors="pt"
            )
            tokenized = {k: v.squeeze(0) for k, v in tokenized.items()}

            prefix_ids = self.tokenizer(prefix).input_ids
            labels_ids = tokenized['input_ids'].clone()
            labels_ids[:len(prefix_ids)] = -100 # Masks the prefix tokens in the labels with -100 for GPT loss computation.
            tokenized['labels'] = labels_ids
            tokenized['label_str'] = str(self.LABEL2ID[label])
            return tokenized
        else:
            text = f"Premise: {premise}\nHypothesis: {hypo}\nLabel:"
            tokenized = self.tokenizer(
                text,
                truncation=True,
                max_length=self.max_length,
                padding=False,
                return_tensors="pt"
            )
            tokenized = {k: v.squeeze(0) for k, v in tokenized.items()}
            tokenized['label_str'] = str(self.LABEL2ID[label])
            return tokenized

    @staticmethod
    def collate_fn(batch):
        """
        Collate a batch of examples into padded tensors.

        Pads 'input_ids' and 'attention_mask' to the max length in the batch.
        Pads 'labels' with -100 if present.
        Collects 'label_str' for reference.

        Returns:
            dict: Padded tensors and label strings for the batch.
        """

        input_ids = torch.nn.utils.rnn.pad_sequence(
            [b['input_ids'] for b in batch],
            batch_first=True,
            padding_value=50256
        )
        attention_mask = torch.nn.utils.rnn.pad_sequence(
            [b['attention_mask'] for b in batch],
            batch_first=True,
            padding_value=0
        )

        if 'labels' in batch[0]:
            labels = torch.nn.utils.rnn.pad_sequence(
                [b['labels'] for b in batch],
                batch_first=True,
                padding_value=-100
            )
        else:
            labels = None

        label_strs = [b['label_str'] for b in batch]

        out = {"input_ids": input_ids, "attention_mask": attention_mask, "label_strs": label_strs}
        if labels is not None:
            out["labels"] = labels
        return out

## Load Fresh Tokenizer and Model

In [11]:
EPOCHS = 1
BATCH_SIZE = 4
LR = 5e-5
WEIGHT_DECAY = 0.01
CORRECT_BIAS = True

model = load_moe_model(f'{save_base_path}/gpt2-moe-finetuned', device='cuda')
model.to('cuda')
tokenizer = GPT2Tokenizer.from_pretrained(f'{save_base_path}/gpt2-moe-finetuned')

## Adam Optimizer

In [12]:
class AdamW(Optimizer):
    def __init__(
            self,
            params: Iterable[torch.nn.parameter.Parameter],
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-6,
            weight_decay: float = 0.0,
            correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(betas[1]))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
        super().__init__(params, defaults)

    def step(self, closure: Callable = None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                # State should be stored in this dictionary.
                state = self.state[p]

                # Access hyperparameters from the `group` dictionary.
                lr = group["lr"]
                eps = group["eps"]
                weight_decay = group["weight_decay"]
                correct_bias = group["correct_bias"]
                beta1, beta2 = group["betas"]
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                state["step"] += 1
                t = state["step"]

                """
                TODO-6: Implement the AdamW parameter update for this step.

                Implementation hints:
                1. Update biased first moment estimate:
                    m_t = beta1 * m_{t-1} + (1 - beta1) * grad
                2. Update biased second raw moment estimate:
                    v_t = beta2 * v_{t-1} + (1 - beta2) * grad^2
                3. Apply bias correction if correct_bias=True:
                    m_hat = m_t / (1 - beta1^t)
                    v_hat = v_t / (1 - beta2^t)
                4. Compute step size:
                    step_size = lr (or lr / (1 - beta1^t) if bias correction)
                5. Update parameters:
                    p = p - step_size * m_hat / (sqrt(v_hat) + eps)
                6. Apply decoupled weight decay after the parameter update (if weight_decay > 0):
                    p = p - lr * weight_decay * p
                Reference:
                Algorithm 1 in "Adam: A Method for Stochastic Optimization"
                https://arxiv.org/abs/1412.6980
                """

                m_t = exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                v_t = exp_avg_sq.mul_(beta2).add_(grad.square(), alpha=1 - beta2)

                if correct_bias:
                    m_hat = m_t.div(1 - beta1**t)
                    v_hat = v_t.div(1 - beta2**t)
                    step_size = lr
                    # FIXME: following the step size in the comments raise assertion error in sanity check
                    # step_size = lr / (1 - beta1**t)
                else:
                    m_hat = exp_avg
                    v_hat = exp_avg_sq
                    step_size = lr

                denom = torch.sqrt(v_hat).add(eps)
                update_direction = m_hat.div(denom)
                p.data.add_(update_direction, alpha=-step_size)

                if weight_decay > 0:
                    p.data.add_(p.data, alpha=-lr * weight_decay)

        return loss

## English Baseline

In [None]:
TRAIN_SUBSET = 1
DEV_SUBSET = 1
TEST_SUBSET = 1

train_dataset = XNLIDataset(
    split="train",
    lang="en",
    tokenizer=tokenizer,
    subset=TRAIN_SUBSET
)

dev_dataset = XNLIDataset(
    split="dev",
    lang="en",
    tokenizer=tokenizer,
    subset=DEV_SUBSET
)

test_dataset = XNLIDataset(
    split="test",
    lang="en",
    tokenizer=tokenizer,
    subset=TEST_SUBSET
)

In [None]:
# Create DataLoaders for training and validation datasets
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,collate_fn=XNLIDataset.collate_fn)
dev_loader = DataLoader(dev_dataset,shuffle=False,collate_fn=XNLIDataset.collate_fn)

VOCAB_SIZE = tokenizer.vocab_size

criterion = torch.nn.CrossEntropyLoss()
# Initialize optimizer
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, correct_bias=CORRECT_BIAS)
# Track training progress
global_train_losses = []
total_train_loss = 0.0
total_train_steps = 0
print_interval = 10

# Track best dev accuracy for model saving
# This only works for epoch > 1
best_dev_acc = 0.0
SAVE_DIR = "best_model"
os.makedirs(SAVE_DIR, exist_ok=True)

# Training loop
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    model.train()

    # Iterate over batches
    loop = tqdm(train_loader, desc="Training")
    for batch in loop:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch.get("labels").to(DEVICE)

        optimizer.zero_grad()

        outputs = model(input_ids=input_ids, attention_mask=attention_mask)

        # 2. Get Logits directly
        vocabulary_logits = outputs.logits

        # 3. Shift logits and labels for next-token prediction
        shifted_logits = vocabulary_logits[:, :-1, :].contiguous()
        shifted_labels = labels[:, 1:].contiguous()

        # Flatten for CrossEntropyLoss
        logits_for_loss = shifted_logits.view(-1, VOCAB_SIZE)
        labels_for_loss = shifted_labels.view(-1)

        # 4. Compute Language Modeling Loss
        lm_loss = criterion(logits_for_loss, labels_for_loss)

        # 5. --- FIX: Add Load Balancing Auxiliary Loss ---
        total_aux_loss = 0.0
        for module in model.modules():
            if isinstance(module, MoELayer):
                total_aux_loss += module.aux_loss

        # Combine losses
        loss = lm_loss + total_aux_loss

        # 6. Backpropagation
        loss.backward()
        optimizer.step()

        # Update stats
        total_train_loss += loss.item()
        total_train_steps += 1
        global_train_avg_loss = total_train_loss / total_train_steps
        global_train_losses.append(global_train_avg_loss)

        loop.set_postfix({
            'avg_loss': f"{global_train_avg_loss:.4f}",
            'aux_loss': f"{total_aux_loss:.4f}" # Optional: Monitor aux loss separately
        })

    print(f"Epoch {epoch+1} finished | Global Avg Loss: {global_train_avg_loss:.4f}")

    # Evaluation
    acc, all_preds, all_labels = evaluate_gpt2_xnli(model, tokenizer, dev_loader, max_gen_length=1, device=DEVICE)

    # Save model
    save_moe_model(model, f'{save_base_path}/gpt2-moe-finetuned-en')
    tokenizer.save_pretrained(f'{save_base_path}/gpt2-moe-finetuned-en')

    print("Model finetuned on XNLI EN split has been saved to ./gpt2-moe-finetuned-en")
    print(f"The accuracy of this model is: {acc}")

In [None]:
del model
del tokenizer
torch.cuda.empty_cache()

## Cross Lingual Transfer

In [14]:
langs = ['en', 'ar', 'bg', 'de','el','es','fr','hi','ru','sw','th','tr','ur','vi','zh']

## French Finetune

In [None]:
import os
import json
import torch
from safetensors import safe_open

load_path = f'{save_base_path}/gpt2-moe-finetuned (1)/checkpoint-8000' # Your checkpoint path
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- 2. Load MoE Configuration (moe_config.json assumed to be present) ---
with open(os.path.join(load_path, 'moe_config.json'), 'r') as f:
    moe_config = json.load(f)

# --- 3. Rebuild MoE Architecture (Same as before) ---
model = GPT2LMHeadModel.from_pretrained('gpt2')
for layer_idx in moe_config['moe_layers']:
    original_mlp = model.transformer.h[layer_idx].mlp
    model.transformer.h[layer_idx].mlp = MoELayer(
        original_mlp,
        num_experts=moe_config['num_experts'],
        top_k=moe_config['top_k'],
        drop_ratio=0.0
    )
tokenizer = GPT2Tokenizer.from_pretrained(f'{save_base_path}/gpt2-moe-finetuned')

# --- 4. Load Safetensors Weights (The NEW Step) ---
state_dict = {}
safetensors_path = os.path.join(load_path, 'model.safetensors')

# Use safe_open to read the weights and store them in the state_dict
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
    for key in f.keys():
        state_dict[key] = f.get_tensor(key)

# --- 5. Apply State Dictionary and Move to GPU ---
# This line maps the safetensors weights (including experts and router)
# onto the rebuilt MoE structure.
model.load_state_dict(state_dict, strict=False)
model.to('cuda')

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768

In [None]:
TRAIN_SUBSET = 1
DEV_SUBSET = 1
TEST_SUBSET = 1

train_dataset = XNLIDataset(
    split="train",
    lang="fr",
    tokenizer=tokenizer,
    subset=TRAIN_SUBSET
)

dev_dataset = XNLIDataset(
    split="dev",
    lang="fr",
    tokenizer=tokenizer,
    subset=DEV_SUBSET
)

test_dataset = XNLIDataset(
    split="test",
    lang="fr",
    tokenizer=tokenizer,
    subset=TEST_SUBSET
)

Dataset initialized: split='train', lang='fr', total=392702, subset=1, subset_count=392702
Dataset initialized: split='dev', lang='fr', total=2490, subset=1, subset_count=2490
Dataset initialized: split='test', lang='fr', total=5010, subset=1, subset_count=5010


In [None]:
# Create DataLoaders for training and validation datasets
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,collate_fn=XNLIDataset.collate_fn)
dev_loader = DataLoader(dev_dataset,shuffle=False,collate_fn=XNLIDataset.collate_fn)

VOCAB_SIZE = tokenizer.vocab_size

criterion = torch.nn.CrossEntropyLoss()
# Initialize optimizer
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, correct_bias=CORRECT_BIAS)
# Track training progress
global_train_losses = []
total_train_loss = 0.0
total_train_steps = 0
print_interval = 10

# Track best dev accuracy for model saving
# This only works for epoch > 1
best_dev_acc = 0.0
SAVE_DIR = "best_model"
os.makedirs(SAVE_DIR, exist_ok=True)

# Training loop
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    model.train()

    # Iterate over batches
    loop = tqdm(train_loader, desc="Training")
    for batch in loop:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch.get("labels").to(DEVICE)

        optimizer.zero_grad()

        # 1. Forward pass to get hidden states
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        vocabulary_logits = outputs.logits

        # 3. Shift logits and labels for next-token prediction
        shifted_logits = vocabulary_logits[:, :-1, :].contiguous()
        shifted_labels = labels[:, 1:].contiguous()

        # Flatten for CrossEntropyLoss
        logits_for_loss = shifted_logits.view(-1, VOCAB_SIZE)
        labels_for_loss = shifted_labels.view(-1)

        # 4. Compute Language Modeling Loss
        lm_loss = criterion(logits_for_loss, labels_for_loss)

        # 5. --- FIX: Add Load Balancing Auxiliary Loss ---
        total_aux_loss = 0.0
        for module in model.modules():
            if isinstance(module, MoELayer):
                total_aux_loss += module.aux_loss

        # Combine losses
        loss = lm_loss + total_aux_loss

        # 6. Backpropagation
        loss.backward()
        optimizer.step()

        # Update stats
        total_train_loss += loss.item()
        total_train_steps += 1
        global_train_avg_loss = total_train_loss / total_train_steps
        global_train_losses.append(global_train_avg_loss)

        loop.set_postfix({
            'avg_loss': f"{global_train_avg_loss:.4f}",
            'aux_loss': f"{total_aux_loss:.4f}" # Optional: Monitor aux loss separately
        })

    print(f"Epoch {epoch+1} finished | Global Avg Loss: {global_train_avg_loss:.4f}")

    # Evaluation
    acc, all_preds, all_labels = evaluate_gpt2_xnli(model, tokenizer, dev_loader, max_gen_length=1, device=DEVICE)

    # Save model
    save_moe_model(model, f'{save_base_path}/gpt2-moe-finetuned-fr')
    tokenizer.save_pretrained(f'{save_base_path}/gpt2-moe-finetuned-fr')

    print("Model finetuned on XNLI EN split has been saved to ./gpt2-moe-finetuned-fr")
    print(f"The accuracy of this model is: {acc}")

Epoch 1/1


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 98176/98176 [3:12:54<00:00,  8.48it/s, avg_loss=0.9598, aux_loss=0.1206]


Epoch 1 finished | Global Avg Loss: 0.9598


Generating:   0%|          | 0/2490 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2490/2490 [01:15<00:00, 32.97it/s]


Evaluation accuracy: 63.57%
Model saved to /content/drive/MyDrive/MoE/gpt2-moe-finetuned-fr
Model finetuned on XNLI EN split has been saved to ./gpt2-moe-finetuned-fr
The accuracy of this model is: 0.6357429718875502


In [None]:
del model
del tokenizer
torch.cuda.empty_cache()

## French Finetune Eavluation

In [None]:
finetuned_model = load_moe_model(f'{save_base_path}/gpt2-moe-finetuned-fr', device='cuda')
finetuned_model.to('cuda')
tokenizer = GPT2Tokenizer.from_pretrained(f'{save_base_path}/gpt2-moe-finetuned-fr')

all_test_datasets = {}
all_test_loader = {}
for lang in langs:
    test_dataset = XNLIDataset(split="test", lang=lang, tokenizer=tokenizer, max_length=1024, subset=TEST_SUBSET)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=XNLIDataset.collate_fn)
    all_test_datasets[lang] = test_dataset
    all_test_loader[lang] = test_loader

all_results = {}
for lang in langs:
    test_loader = all_test_loader[lang]
    if lang == "fr":
        print(f"Evaluating on {lang}...")
    else:
        print(f"Evaluating zero-shot on {lang}...")
    acc, all_preds, all_labels = evaluate_gpt2_xnli(finetuned_model, tokenizer, test_loader, max_gen_length=1, device=DEVICE)
    all_results[lang] = acc

print("Zero-shot cross-lingual accuracy per language:")
for lang, acc in all_results.items():
    print(f"{lang}: {acc*100:.2f}%")

Model loaded from /content/drive/MyDrive/MoE/gpt2-moe-finetuned-fr
Dataset initialized: split='test', lang='en', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='ar', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='bg', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='de', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='el', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='es', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='fr', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='hi', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='ru', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='sw', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='th', total=5010, subset=1, subse

Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:32<00:00, 32.91it/s]


Evaluation accuracy: 56.09%
Evaluating zero-shot on ar...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:38<00:00, 31.62it/s]


Evaluation accuracy: 35.41%
Evaluating zero-shot on bg...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:40<00:00, 31.17it/s]


Evaluation accuracy: 36.01%
Evaluating zero-shot on de...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:33<00:00, 32.56it/s]


Evaluation accuracy: 45.81%
Evaluating zero-shot on el...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:42<00:00, 30.76it/s]


Evaluation accuracy: 36.83%
Evaluating zero-shot on es...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:32<00:00, 32.85it/s]


Evaluation accuracy: 53.53%
Evaluating on fr...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:33<00:00, 32.63it/s]


Evaluation accuracy: 64.17%
Evaluating zero-shot on hi...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:45<00:00, 30.28it/s]


Evaluation accuracy: 35.67%
Evaluating zero-shot on ru...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:42<00:00, 30.88it/s]


Evaluation accuracy: 36.03%
Evaluating zero-shot on sw...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:33<00:00, 32.59it/s]


Evaluation accuracy: 41.46%
Evaluating zero-shot on th...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:52<00:00, 29.05it/s]


Evaluation accuracy: 34.85%
Evaluating zero-shot on tr...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:37<00:00, 31.84it/s]


Evaluation accuracy: 43.69%
Evaluating zero-shot on ur...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:46<00:00, 30.10it/s]


Evaluation accuracy: 34.95%
Evaluating zero-shot on vi...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:43<00:00, 30.63it/s]


Evaluation accuracy: 38.56%
Evaluating zero-shot on zh...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:41<00:00, 31.06it/s]

Evaluation accuracy: 37.58%
Zero-shot cross-lingual accuracy per language:
en: 56.09%
ar: 35.41%
bg: 36.01%
de: 45.81%
el: 36.83%
es: 53.53%
fr: 64.17%
hi: 35.67%
ru: 36.03%
sw: 41.46%
th: 34.85%
tr: 43.69%
ur: 34.95%
vi: 38.56%
zh: 37.58%





In [None]:
del finetuned_model
del tokenizer
torch.cuda.empty_cache()

## Multi Language Finetune

In [18]:
import os
import json
import torch
from safetensors import safe_open

load_path = f'{save_base_path}/gpt2-moe-finetuned/checkpoint-8000'
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- 2. Load MoE Configuration (moe_config.json assumed to be present) ---
with open(os.path.join(load_path, 'moe_config.json'), 'r') as f:
    moe_config = json.load(f)

# --- 3. Rebuild MoE Architecture (Same as before) ---
model = GPT2LMHeadModel.from_pretrained('gpt2')
for layer_idx in moe_config['moe_layers']:
    original_mlp = model.transformer.h[layer_idx].mlp
    model.transformer.h[layer_idx].mlp = MoELayer(
        original_mlp,
        num_experts=moe_config['num_experts'],
        top_k=moe_config['top_k'],
        drop_ratio=0.0
    )
tokenizer = GPT2Tokenizer.from_pretrained(f'{save_base_path}/gpt2-moe-finetuned')

# --- 4. Load Safetensors Weights (The NEW Step) ---
state_dict = {}
safetensors_path = os.path.join(load_path, 'model.safetensors')

# Use safe_open to read the weights and store them in the state_dict
with safe_open(safetensors_path, framework="pt", device="cpu") as f:
    for key in f.keys():
        state_dict[key] = f.get_tensor(key)

# --- 5. Apply State Dictionary and Move to GPU ---
# This line maps the safetensors weights (including experts and router)
# onto the rebuilt MoE structure.
model.load_state_dict(state_dict, strict=False)
model.to('cuda')

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768

In [20]:
from torch.utils.data import ConcatDataset

"""
Load NLI datasets for fine-tuning and evaluation on multiple non-english splits.
Final training and evaluation should use the full dataset (SUBSET=1).
"""

target_langs = ["en", "fr", "es"]

train_datasets = []
for l in target_langs:
    ds = XNLIDataset(
        split="train",
        lang=l,
        tokenizer=tokenizer,
        subset=1.0
    )
    train_datasets.append(ds)

combined_train_dataset = ConcatDataset(train_datasets)

train_loader = DataLoader(
    combined_train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=XNLIDataset.collate_fn
)

dev_datasets = []
for l in target_langs:
    ds = XNLIDataset(
        split="dev",
        lang=l,
        tokenizer=tokenizer,
        subset=1.0
    )
    dev_datasets.append(ds)

combined_dev_dataset = ConcatDataset(dev_datasets)

dev_loader = DataLoader(
    combined_dev_dataset,
    shuffle=False,
    collate_fn=XNLIDataset.collate_fn
)

Dataset initialized: split='train', lang='en', total=392702, subset=1.0, subset_count=392702
Dataset initialized: split='train', lang='fr', total=392702, subset=1.0, subset_count=392702
Dataset initialized: split='train', lang='es', total=392702, subset=1.0, subset_count=392702
Dataset initialized: split='dev', lang='en', total=2490, subset=1.0, subset_count=2490
Dataset initialized: split='dev', lang='fr', total=2490, subset=1.0, subset_count=2490
Dataset initialized: split='dev', lang='es', total=2490, subset=1.0, subset_count=2490


In [21]:
VOCAB_SIZE = tokenizer.vocab_size

criterion = torch.nn.CrossEntropyLoss()
# Initialize optimizer
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, correct_bias=CORRECT_BIAS)
# Track training progress
global_train_losses = []
total_train_loss = 0.0
total_train_steps = 0
print_interval = 10

# Track best dev accuracy for model saving
# This only works for epoch > 1
best_dev_acc = 0.0
SAVE_DIR = "best_model"
os.makedirs(SAVE_DIR, exist_ok=True)

# Training loop
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    model.train()

    # Iterate over batches
    loop = tqdm(train_loader, desc="Training")
    for batch in loop:
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch.get("labels").to(DEVICE)

        optimizer.zero_grad()

        # 1. Forward pass to get hidden states
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        vocabulary_logits = outputs.logits

        # 3. Shift logits and labels for next-token prediction
        shifted_logits = vocabulary_logits[:, :-1, :].contiguous()
        shifted_labels = labels[:, 1:].contiguous()

        # Flatten for CrossEntropyLoss
        logits_for_loss = shifted_logits.view(-1, VOCAB_SIZE)
        labels_for_loss = shifted_labels.view(-1)

        # 4. Compute Language Modeling Loss
        lm_loss = criterion(logits_for_loss, labels_for_loss)

        # 5. --- FIX: Add Load Balancing Auxiliary Loss ---
        total_aux_loss = 0.0
        for module in model.modules():
            if isinstance(module, MoELayer):
                total_aux_loss += module.aux_loss

        # Combine losses
        loss = lm_loss + total_aux_loss

        # 6. Backpropagation
        loss.backward()
        optimizer.step()

        # Update stats
        total_train_loss += loss.item()
        total_train_steps += 1
        global_train_avg_loss = total_train_loss / total_train_steps
        global_train_losses.append(global_train_avg_loss)

        loop.set_postfix({
            'avg_loss': f"{global_train_avg_loss:.4f}",
            'aux_loss': f"{total_aux_loss:.4f}" # Optional: Monitor aux loss separately
        })

    print(f"Epoch {epoch+1} finished | Global Avg Loss: {global_train_avg_loss:.4f}")

    # Evaluation
    acc, all_preds, all_labels = evaluate_gpt2_xnli(model, tokenizer, dev_loader, max_gen_length=1, device=DEVICE)

    # Save model
    save_moe_model(model, f'{save_base_path}/gpt2-moe-finetuned-en-es-fr')
    tokenizer.save_pretrained(f'{save_base_path}/gpt2-moe-finetuned-en-es-fr')

    print("Model finetuned on XNLI EN split has been saved to ./gpt2-moe-finetuned-en-es-fr")
    print(f"The accuracy of this model is: {acc}")

Epoch 1/1


Training: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 294527/294527 [9:21:24<00:00,  8.74it/s, avg_loss=0.8632, aux_loss=0.1205]


Epoch 1 finished | Global Avg Loss: 0.8632


Generating:   0%|          | 0/7470 [00:00<?, ?it/s]The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 7470/7470 [03:40<00:00, 33.91it/s]


Evaluation accuracy: 68.45%
Model saved to /content/drive/MyDrive/MoE/gpt2-moe-finetuned-en-es-fr
Model finetuned on XNLI EN split has been saved to ./gpt2-moe-finetuned-en-es-fr
The accuracy of this model is: 0.684471218206158


In [23]:
finetuned_model = load_moe_model(f'{save_base_path}/gpt2-moe-finetuned-en-es-fr', device='cuda')
finetuned_model.to('cuda')
tokenizer = GPT2Tokenizer.from_pretrained(f'{save_base_path}/gpt2-moe-finetuned-en-es-fr')

all_test_loader = {}

for l in langs:
    test_dataset = XNLIDataset(split="test", lang=l, tokenizer=tokenizer, max_length=1024, subset=TEST_SUBSET)
    test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=XNLIDataset.collate_fn)
    all_test_loader[l] = test_loader

all_results = {}

for test_lang in langs:
    test_loader = all_test_loader[test_lang]

    if test_lang in target_langs:
        print(f"Evaluating on {test_lang}...")
    else:
        print(f"Evaluating zero-shot on {test_lang}...")

    acc, all_preds, all_labels = evaluate_gpt2_xnli(finetuned_model, tokenizer, test_loader, max_gen_length=1, device=DEVICE)

    all_results[test_lang] = acc

print(f"\nResults for model trained on {target_langs}:")
for l, acc in all_results.items():
    print(f"{l}: {acc*100:.2f}%")

Model loaded from /content/drive/MyDrive/MoE/gpt2-moe-finetuned-en-es-fr
Dataset initialized: split='test', lang='en', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='ar', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='bg', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='de', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='el', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='es', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='fr', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='hi', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='ru', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='sw', total=5010, subset=1, subset_count=5010
Dataset initialized: split='test', lang='th', total=5010, subset=1,

Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:26<00:00, 34.12it/s]


Evaluation accuracy: 74.57%
Evaluating zero-shot on ar...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:34<00:00, 32.51it/s]


Evaluation accuracy: 37.68%
Evaluating zero-shot on bg...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:35<00:00, 32.19it/s]


Evaluation accuracy: 38.72%
Evaluating zero-shot on de...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:29<00:00, 33.51it/s]


Evaluation accuracy: 44.79%
Evaluating zero-shot on el...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:37<00:00, 31.87it/s]


Evaluation accuracy: 39.32%
Evaluating on es...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:29<00:00, 33.54it/s]


Evaluation accuracy: 65.33%
Evaluating on fr...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:28<00:00, 33.72it/s]


Evaluation accuracy: 64.05%
Evaluating zero-shot on hi...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:39<00:00, 31.39it/s]


Evaluation accuracy: 36.07%
Evaluating zero-shot on ru...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:35<00:00, 32.15it/s]


Evaluation accuracy: 38.84%
Evaluating zero-shot on sw...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:28<00:00, 33.83it/s]


Evaluation accuracy: 42.71%
Evaluating zero-shot on th...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:41<00:00, 31.11it/s]


Evaluation accuracy: 36.89%
Evaluating zero-shot on tr...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:30<00:00, 33.18it/s]


Evaluation accuracy: 42.42%
Evaluating zero-shot on ur...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:36<00:00, 32.00it/s]


Evaluation accuracy: 37.52%
Evaluating zero-shot on vi...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:33<00:00, 32.70it/s]


Evaluation accuracy: 40.06%
Evaluating zero-shot on zh...


Generating: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 5010/5010 [02:31<00:00, 33.10it/s]

Evaluation accuracy: 39.40%

Results for model trained on ['en', 'fr', 'es']:
en: 74.57%
ar: 37.68%
bg: 38.72%
de: 44.79%
el: 39.32%
es: 65.33%
fr: 64.05%
hi: 36.07%
ru: 38.84%
sw: 42.71%
th: 36.89%
tr: 42.42%
ur: 37.52%
vi: 40.06%
zh: 39.40%



