# MoE-based GPT-2

## Import Dependencies

In [None]:
import torch
import torch.nn as nn
from torch.optim import Optimizer
from transformers import GPT2LMHeadModel, GPT2Config, GPT2Tokenizer
from transformers import Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset, Dataset
from typing import Optional
import copy
import os
import json
from typing import Callable, Iterable, Tuple
from utils import *

# Prefer the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print(DEVICE)

## MoE Router Module

In [None]:
class TopKRouter(nn.Module):
    """Token-level router that selects the ``top_k`` best-scoring experts.

    A single bias-free linear layer scores every expert for every token. The
    scores are turned into probabilities with a softmax, the ``top_k`` largest
    are kept, and the kept probabilities are re-normalised so they sum to 1.

    Args:
        hidden_size: Size of the last dimension of the incoming hidden states.
        num_experts: Number of experts the router chooses between.
        top_k: How many experts are selected for each token.

    Raises:
        ValueError: If ``top_k`` is not in ``[1, num_experts]``. Failing here
            gives a clearer error than the one ``torch.topk`` would raise
            later, in the middle of a forward pass.
    """

    def __init__(self, hidden_size: int, num_experts: int, top_k: int = 2):
        super().__init__()
        if not 1 <= top_k <= num_experts:
            raise ValueError(
                f"top_k must be in [1, num_experts={num_experts}], got {top_k}"
            )
        self.num_experts = num_experts
        self.top_k = top_k
        # The router is a plain linear map from hidden states to expert scores.
        self.gate = nn.Linear(hidden_size, num_experts, bias=False)

    def forward(self, hidden_states):
        """Route every token to its ``top_k`` experts.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``
                (the MoE layer passes ``[batch, seq, hidden]``).

        Returns:
            Tuple ``(top_k_weights, top_k_indices, router_logits)`` where the
            first two have a trailing dimension of ``top_k`` and the logits
            have a trailing dimension of ``num_experts``.
        """
        router_logits = self.gate(hidden_states)  # [..., num_experts]

        # Probabilities over all experts, then keep the k most likely ones.
        routing_weights = torch.softmax(router_logits, dim=-1)
        top_k_weights, top_k_indices = torch.topk(routing_weights, self.top_k, dim=-1)

        # Re-normalise so the kept weights form a convex combination.
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)

        return top_k_weights, top_k_indices, router_logits

## MoE Layer

In [None]:
class MoELayer(nn.Module):
    """Mixture-of-Experts block that replaces a dense GPT-2 MLP.

    Every expert starts as a deep copy of the dense MLP ("upcycling"). With
    ``drop_ratio > 0`` a random fraction of each expert's parameters is
    re-initialised so the experts do not start out identical
    ("drop-upcycling").

    Args:
        dense_mlp: The dense MLP module to copy. It must expose ``c_fc.weight``
            laid out as ``[input, output]`` (Hugging Face ``Conv1D``).
        num_experts: Number of expert copies to create.
        top_k: Number of experts evaluated per token.
        drop_ratio: Probability with which each individual parameter element
            is re-initialised in every expert copy.
    """

    def __init__(self, dense_mlp, num_experts: int = 8, top_k: int = 2, drop_ratio: float = 0.0):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k

        # Hugging Face GPT-2 uses Conv1D, where weight shape is [input, output],
        # so the first dimension of c_fc's weight is the hidden size.
        hidden_size = dense_mlp.c_fc.weight.shape[0]
        self.router = TopKRouter(hidden_size, num_experts, top_k)

        # One independent copy of the dense MLP per expert.
        self.experts = nn.ModuleList([
            self._copy_mlp_with_drop(dense_mlp, drop_ratio) for _ in range(num_experts)
        ])

    def _copy_mlp_with_drop(self, dense_mlp, drop_ratio: float):
        """Return a deep copy of ``dense_mlp`` with optional drop-upcycling.

        Each parameter element is independently replaced, with probability
        ``drop_ratio``, by a sample from ``N(0, 0.02^2)``. The original
        ``dense_mlp`` is never modified.
        """
        expert = copy.deepcopy(dense_mlp)

        if drop_ratio > 0:
            with torch.no_grad():
                for param in expert.parameters():
                    # Element-wise Bernoulli(drop_ratio) mask of entries to reset.
                    mask = torch.rand_like(param) < drop_ratio
                    if mask.any():
                        param.data[mask] = torch.randn_like(param[mask]) * 0.02

        return expert

    def forward(self, hidden_states):
        """Apply the routed mixture of experts.

        Args:
            hidden_states: Tensor of shape ``[batch, seq, hidden]``.

        Returns:
            Tensor of the same shape: for every token, the weighted sum of the
            outputs of its ``top_k`` selected experts.
        """
        batch_size, seq_len, hidden_size = hidden_states.shape
        total_tokens = batch_size * seq_len

        # Route tokens. The raw logits are returned by the router but are not
        # used here (no auxiliary load-balancing loss is computed).
        top_k_weights, top_k_indices, router_logits = self.router(hidden_states)

        # Flatten to one row per token. reshape() (rather than view()) also
        # accepts non-contiguous inputs such as transposed or sliced tensors.
        flat_hidden = hidden_states.reshape(total_tokens, hidden_size)
        flat_indices_1d = top_k_indices.reshape(-1)  # [total_tokens * top_k]
        flat_weights_1d = top_k_weights.reshape(-1)  # [total_tokens * top_k]

        # Accumulator for the weighted expert outputs.
        output = torch.zeros_like(flat_hidden)

        # Row i of the flattened (token, k) pairs belongs to token i // top_k.
        token_indices = torch.arange(total_tokens, device=hidden_states.device)
        token_indices = token_indices.unsqueeze(1).expand(-1, self.top_k).reshape(-1)

        for expert_idx in range(self.num_experts):
            # All (token, k) pairs that were assigned to this expert.
            expert_mask = (flat_indices_1d == expert_idx)

            if not expert_mask.any():
                continue

            # topk returns distinct experts per token, so a token appears at
            # most once per expert here.
            expert_token_indices = token_indices[expert_mask]
            expert_weights = flat_weights_1d[expert_mask]

            expert_input = flat_hidden[expert_token_indices]
            expert_output = self.experts[expert_idx](expert_input)

            # Scale by the routing weight and add into the owning token's row.
            weighted_output = expert_output * expert_weights.unsqueeze(-1)
            output.index_add_(0, expert_token_indices, weighted_output)

        return output.view(batch_size, seq_len, hidden_size)

In [None]:
def calculate_active_params(model):
    """Count a model's parameters and summarise its MoE layers.

    Returns:
        Tuple ``(total_params, moe_info)``. ``total_params`` is the number of
        elements in ``model.parameters()``. ``moe_info`` sums, over every
        ``MoELayer`` found in the model:

        * ``total_experts``: number of experts,
        * ``active_per_token``: ``top_k``,
        * ``total_expert_params``: parameters held by all experts,
        * ``active_expert_params``: per-expert average parameter count
          multiplied by ``top_k``.
    """
    moe_info = dict.fromkeys(
        ('total_experts', 'active_per_token', 'total_expert_params', 'active_expert_params'),
        0,
    )

    moe_modules = (m for m in model.modules() if isinstance(m, MoELayer))
    for layer in moe_modules:
        params_in_experts = sum(w.numel() for w in layer.experts.parameters())

        moe_info['total_experts'] += layer.num_experts
        moe_info['active_per_token'] += layer.top_k
        moe_info['total_expert_params'] += params_in_experts
        # Only top_k of the num_experts experts run for any given token.
        moe_info['active_expert_params'] += (params_in_experts / layer.num_experts) * layer.top_k

    total_params = sum(w.numel() for w in model.parameters())
    return total_params, moe_info

## Upcycle GPT2 Vanilla Weights to MoE Architecture

In [None]:
def upcycle_gpt2_to_moe(
    model_name: str = 'gpt2',
    num_experts: int = 8,
    top_k: int = 2,
    moe_layers: Optional[list] = None,
    drop_ratio: float = 0.0,
    match_active_params: bool = False
):
    """
    Convert a standard GPT-2 model to MoE architecture

    Args:
        model_name: HuggingFace model name
        num_experts: Number of experts per MoE layer
        top_k: Number of experts to activate per token
        moe_layers: List of layer indices to convert to MoE (None = all layers)
        drop_ratio: Ratio of parameters to re-initialize for drop-upcycling (0.0-1.0)
                   0.0 = standard upcycling, 0.1-0.2 recommended for drop-upcycling
        match_active_params: If True, force top_k=1 (overriding the ``top_k``
                            argument) so each token runs exactly one expert and
                            the active parameter count stays close to the
                            vanilla model. ``num_experts`` is left unchanged.

    Returns:
        Modified GPT2LMHeadModel with MoE layers
    """
    # Load the pre-trained model with LM head
    model = GPT2LMHeadModel.from_pretrained(model_name)
    original_params = sum(p.numel() for p in model.parameters())
    num_blocks = len(model.transformer.h)

    # If no specific layers specified, convert all layers
    if moe_layers is None:
        moe_layers = list(range(num_blocks))

    # Auto-adjust for fair comparison
    if match_active_params:
        top_k = 1
        print(f"Fair comparison mode: Setting top_k={top_k} to match vanilla GPT-2 active params")

    upcycle_type = "Drop-Upcycling" if drop_ratio > 0 else "Standard Upcycling"
    print(f"Converting layers {moe_layers} to MoE with {num_experts} experts (top-{top_k})")
    print(f"Using {upcycle_type}" + (f" with {drop_ratio*100}% parameter re-initialization" if drop_ratio > 0 else ""))

    # Replace MLPs with MoE layers
    for layer_idx in moe_layers:
        if layer_idx >= num_blocks:
            print(f"Warning: Layer {layer_idx} doesn't exist, skipping")
            continue

        original_mlp = model.transformer.h[layer_idx].mlp

        # The MoE layer deep-copies the dense MLP into every expert.
        model.transformer.h[layer_idx].mlp = MoELayer(
            original_mlp,
            num_experts=num_experts,
            top_k=top_k,
            drop_ratio=drop_ratio
        )

        print(f"Converted layer {layer_idx}")

    # Active parameters = everything that is not an expert (attention,
    # embeddings, routers, unconverted MLPs) plus the experts that actually run
    # per token. Adding the active experts on top of `original_params` would
    # count the replaced dense MLPs twice.
    total_params, moe_info = calculate_active_params(model)
    active_params = int(
        total_params - moe_info['total_expert_params'] + moe_info['active_expert_params']
    )

    print("\n" + "=" * 60)
    print("PARAMETER COMPARISON")
    print("=" * 60)
    print(f"Original model params:        {original_params:,}")
    print(f"MoE model total params:       {total_params:,}")
    print(f"MoE model active params:      {active_params:,}")
    print("\nPer-layer breakdown:")
    print(f"  Experts per layer:          {num_experts}")
    print(f"  Active experts per token:   {top_k}")
    print(f"  Active ratio:               {top_k}/{num_experts} = {top_k/num_experts:.1%}")

    if match_active_params:
        print("\nFair comparison mode: Active params ≈ vanilla GPT-2")
    else:
        active_ratio = active_params / original_params
        print(f"\nMoE has {active_ratio:.1f}x active parameters vs vanilla")

    return model

In [None]:
def save_moe_model(model, save_path):
    """Write an MoE model to ``save_path``.

    Three artefacts are produced in the directory (created if missing):

    * ``pytorch_model.bin``: the model's ``state_dict``,
    * the Hugging Face config, via ``model.config.save_pretrained``,
    * ``moe_config.json``: indices of the transformer blocks whose MLP is a
      ``MoELayer``, plus ``num_experts`` / ``top_k`` taken from the first such
      block (both ``None`` when the model has no MoE layer).
    """
    os.makedirs(save_path, exist_ok=True)

    weights_file = os.path.join(save_path, 'pytorch_model.bin')
    torch.save(model.state_dict(), weights_file)
    model.config.save_pretrained(save_path)

    # (block index, MoE module) for every converted transformer block.
    moe_blocks = [
        (idx, block.mlp)
        for idx, block in enumerate(model.transformer.h)
        if isinstance(block.mlp, MoELayer)
    ]
    first_moe = moe_blocks[0][1] if moe_blocks else None

    moe_config = {
        'moe_layers': [idx for idx, _ in moe_blocks],
        'num_experts': None if first_moe is None else first_moe.num_experts,
        'top_k': None if first_moe is None else first_moe.top_k,
    }

    config_file = os.path.join(save_path, 'moe_config.json')
    with open(config_file, 'w') as f:
        json.dump(moe_config, f)

    print(f"Model saved to {save_path}")

In [None]:
def load_moe_model(load_path, device='cpu'):
    """Load an MoE model previously written by ``save_moe_model``.

    The GPT-2 skeleton is built from the config stored in ``load_path``, so
    checkpoints of any GPT-2 size load correctly and no pretrained weights are
    downloaded. Every weight is then overwritten from ``pytorch_model.bin``.

    Args:
        load_path: Directory containing ``moe_config.json``, the Hugging Face
            config and ``pytorch_model.bin``.
        device: Device the checkpoint is mapped to and the returned model is
            moved to.

    Returns:
        ``GPT2LMHeadModel`` with MoE layers, placed on ``device``.
    """
    # Which blocks are MoE and how they are shaped.
    with open(os.path.join(load_path, 'moe_config.json'), 'r') as f:
        moe_config = json.load(f)

    # Randomly initialised skeleton matching the saved architecture.
    config = GPT2Config.from_pretrained(load_path)
    base_model = GPT2LMHeadModel(config)

    # Swap in MoE layers so the module tree matches the saved state dict.
    # drop_ratio is irrelevant here: every parameter is overwritten below.
    for layer_idx in moe_config['moe_layers']:
        original_mlp = base_model.transformer.h[layer_idx].mlp
        base_model.transformer.h[layer_idx].mlp = MoELayer(
            original_mlp,
            num_experts=moe_config['num_experts'],
            top_k=moe_config['top_k'],
            drop_ratio=0.0
        )

    # Load trained weights
    state_dict = torch.load(os.path.join(load_path, 'pytorch_model.bin'), map_location=device)
    base_model.load_state_dict(state_dict)

    # load_state_dict copies into the existing (CPU) parameters, so the model
    # must be moved explicitly to honour `device`.
    base_model.to(device)

    print(f"Model loaded from {load_path}")
    return base_model

In [None]:
# Upcycle every second GPT-2 block into an 8-expert MoE layer with 10%
# drop-upcycling. match_active_params=True forces top_k=1 inside the call.
moe_model = upcycle_gpt2_to_moe(
    model_name='gpt2',
    moe_layers=[1, 3, 5, 7, 9, 11],
    num_experts=8,
    top_k=1,
    drop_ratio=0.1,
    match_active_params=True,
)

# Persist the upcycled checkpoint so the fine-tuning step can reload it.
save_moe_model(moe_model, './gpt2-moe-upcycled')

# Drop the reference so the model's memory can be reclaimed.
del moe_model

## Finetune MoE Model

In [None]:
def finetune_moe_model(
    moe_model_path: str = './gpt2-moe-upcycled',
    output_dir: str = './gpt2-moe-finetuned',
    num_train_steps: int = 5000,
    batch_size: int = 4,
    learning_rate: float = 1e-4,
    warmup_steps: int = 200,
    gradient_accumulation_steps: int = 8,
    save_steps: int = 1000,
    max_length: int = 512,
):
    """
    Fine-tune the upcycled MoE model on a slice of the English C4 stream.

    Args:
        moe_model_path: Directory produced by ``save_moe_model``.
        output_dir: Where trainer checkpoints and the final model are written.
        num_train_steps: Number of optimizer steps (``max_steps``).
        batch_size: Per-device batch size.
        learning_rate: Peak learning rate of the cosine schedule.
        warmup_steps: Warm-up steps of the schedule.
        gradient_accumulation_steps: Micro-batches accumulated per step.
        save_steps: Checkpoint interval in steps.
        max_length: Every example is truncated/padded to this many tokens.

    Returns:
        The fine-tuned model (also saved, with the tokenizer, to ``output_dir``).
    """
    rule = "=" * 60

    print(rule)
    print("FINE-TUNING MoE MODEL")
    print(rule)

    # Rebuild the MoE architecture and restore the upcycled weights.
    print(f"Loading model from {moe_model_path}...")
    model = load_moe_model(moe_model_path)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model loaded with {param_count:,} parameters")

    # GPT-2 ships without a pad token; reuse EOS so padding works.
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token

    print("\nLoading dataset...")
    # NOTE(review): the bare 'c4' dataset id may be unavailable on recent
    # `datasets` releases - confirm it still resolves in this environment.
    dataset = load_dataset('c4', 'en', split='train', streaming=True)

    # One sample per micro-batch element of the whole run.
    num_samples = num_train_steps * batch_size * gradient_accumulation_steps
    print(f"Collecting {num_samples} samples...")

    samples = []
    for seen, example in enumerate(dataset, start=1):
        if seen > num_samples:
            break
        samples.append(example)
        if seen % 1000 == 0:
            print(f"  Collected {seen}/{num_samples} samples...")

    dataset = Dataset.from_list(samples)

    def tokenize_function(examples):
        # Fixed-length sequences: truncate long texts, pad short ones.
        return tokenizer(
            examples['text'],
            truncation=True,
            max_length=max_length,
            padding='max_length',
        )

    print("Tokenizing dataset...")
    tokenized_dataset = dataset.map(
        tokenize_function,
        batched=True,
        remove_columns=['text', 'timestamp', 'url'],
    )

    # mlm=False selects the causal-LM collator (labels copied from inputs).
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    training_args = TrainingArguments(
        output_dir=output_dir,
        max_steps=num_train_steps,
        per_device_train_batch_size=batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        warmup_steps=warmup_steps,
        lr_scheduler_type='cosine',
        save_steps=save_steps,
        save_total_limit=3,
        logging_steps=100,
        fp16=torch.cuda.is_available(),
        dataloader_num_workers=4,
        remove_unused_columns=False,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_dataset,
        data_collator=data_collator,
    )

    print("\n" + rule)
    print("STARTING TRAINING")
    print(rule)
    print(f"Total steps: {num_train_steps}")
    print(f"Effective batch size: {batch_size * gradient_accumulation_steps}")
    print(f"Learning rate: {learning_rate}")
    print(f"Warmup steps: {warmup_steps}\n")

    trainer.train()

    print("\n" + rule)
    print("SAVING FINAL MODEL")
    print(rule)
    # The custom saver is needed so the MoE layout is stored with the weights.
    save_moe_model(model, output_dir)
    tokenizer.save_pretrained(output_dir)
    print(f"Model and tokenizer saved to {output_dir}")

    return model

In [None]:
# Fine-tune the upcycled checkpoint. 3125 steps x (4 x 8) sequences per step
# (default gradient_accumulation_steps=8) = 100,000 training sequences.
finetuned_model = finetune_moe_model(
        moe_model_path='./gpt2-moe-upcycled',
        output_dir='./gpt2-moe-finetuned',
        num_train_steps=3125,
        batch_size=4,
        learning_rate=1e-4,
        warmup_steps=312,
    )

# Report the directory the model was actually written to.
print("Successfully finetuned model at path ./gpt2-moe-finetuned")

# Drop the reference so the model's memory can be reclaimed.
del finetuned_model

## AdamW Optimizer

In [None]:
class AdamW(Optimizer):
    """Adam with decoupled weight decay.

    Args:
        params: Parameters (or parameter groups) to optimise.
        lr: Learning rate, must be >= 0.
        betas: Decay rates of the first and second moment estimates, each in
            ``[0, 1)``.
        eps: Constant added to the denominator, must be >= 0.
        weight_decay: Decoupled weight-decay coefficient.
        correct_bias: Whether to bias-correct the moment estimates.

    Raises:
        ValueError: If a hyper-parameter is outside its valid range.
    """

    def __init__(
            self,
            params: Iterable[torch.nn.parameter.Parameter],
            lr: float = 1e-3,
            betas: Tuple[float, float] = (0.9, 0.999),
            eps: float = 1e-6,
            weight_decay: float = 0.0,
            correct_bias: bool = True,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
        for beta in (betas[0], betas[1]):
            if not 0.0 <= beta < 1.0:
                raise ValueError("Invalid beta parameter: {} - should be in [0.0, 1.0[".format(beta))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(eps))

        super().__init__(
            params,
            dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias),
        )

    def step(self, closure: Callable = None):
        """Perform one optimisation step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss; it is called once, before the update.

        Returns:
            The closure's return value, or ``None`` without a closure.
        """
        loss = closure() if closure is not None else None

        for group in self.param_groups:
            # Hyper-parameters are shared by every parameter of the group.
            lr = group["lr"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            correct_bias = group["correct_bias"]
            beta1, beta2 = group["betas"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError("Adam does not support sparse gradients, please consider SparseAdam instead")

                # Lazily create the per-parameter state on first use.
                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p.data)
                    state["exp_avg_sq"] = torch.zeros_like(p.data)

                first_moment = state["exp_avg"]
                second_moment = state["exp_avg_sq"]
                state["step"] += 1
                t = state["step"]

                # In-place EMA updates:
                #   m_t = beta1 * m_{t-1} + (1 - beta1) * g
                #   v_t = beta2 * v_{t-1} + (1 - beta2) * g^2
                first_moment.mul_(beta1).add_(grad, alpha=1 - beta1)
                second_moment.mul_(beta2).add_(grad * grad, alpha=1 - beta2)

                if correct_bias:
                    # Both moments are corrected explicitly, so the step size
                    # stays at `lr` (scaling it by 1/(1 - beta1^t) as well
                    # would apply the first-moment correction twice).
                    m_hat = first_moment / (1 - beta1 ** t)
                    v_hat = second_moment / (1 - beta2 ** t)
                else:
                    m_hat, v_hat = first_moment, second_moment

                # p <- p - lr * m_hat / (sqrt(v_hat) + eps)
                p.data.add_(m_hat / (v_hat.sqrt() + eps), alpha=-lr)

                # Decoupled weight decay, applied after the Adam update.
                if weight_decay > 0:
                    p.data.add_(p.data, alpha=-lr * weight_decay)

        return loss

## Load NLI Dataset

In [None]:
def compute_accuracy(preds, labels):
    """Return the fraction of predictions that match their label.

    Comparison is case-insensitive and ignores surrounding whitespace.
    Predictions and labels are paired positionally and the number of matches
    is divided by ``len(labels)``.

    Args:
        preds: Sequence of predicted label strings.
        labels: Sequence of reference label strings.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` when ``labels`` is empty (instead of
        raising ``ZeroDivisionError``).
    """
    total = len(labels)
    if total == 0:
        return 0.0
    correct = sum(p.lower().strip() == l.lower().strip() for p, l in zip(preds, labels))
    return correct / total

def generate_gpt2(model, tokenizer, input_ids, max_gen_length=50, device="cuda"):
    """Greedily continue a prompt and return the decoded text.

    The model is switched to eval mode, ``input_ids`` is moved to ``device``
    and ``model.generate`` is run without gradient tracking. Only the first
    sequence of the output is decoded (special tokens skipped).

    Args:
        model: Object exposing ``eval()`` and ``generate(...)``.
        tokenizer: Object exposing ``eos_token_id`` and ``decode(...)``.
        input_ids: Prompt token ids.
        max_gen_length: Maximum number of newly generated tokens.
        device: Device the prompt is moved to.

    Returns:
        Decoded string of the first generated sequence.
    """
    model.eval()
    prompt_ids = input_ids.to(device)

    generation_kwargs = dict(
        max_new_tokens=max_gen_length,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=False,  # greedy decoding
    )
    with torch.no_grad():
        sequences = model.generate(prompt_ids, **generation_kwargs)

    return tokenizer.decode(sequences[0], skip_special_tokens=True)

def evaluate_gpt2_xnli(model, tokenizer, dataloader, max_gen_length=10, device="cuda"):
    """Evaluate generation-based XNLI accuracy.

    Each example's prompt is continued greedily and the text after the last
    "Label:" is taken as the prediction. Examples are generated one at a time,
    with padding stripped via the attention mask, so predictions and labels
    stay aligned for any dataloader batch size.

    Args:
        model: Causal LM passed to ``generate_gpt2``.
        tokenizer: Tokenizer passed to ``generate_gpt2``.
        dataloader: Yields dicts with ``'input_ids'`` (``[B, seq]``),
            ``'label_strs'`` (list of B strings) and optionally
            ``'attention_mask'`` (``[B, seq]``).
        max_gen_length: Maximum number of new tokens per example.
        device: Device used for generation.

    Returns:
        Tuple ``(accuracy, all_preds, all_labels)``.
    """
    model.eval()
    all_preds = []
    all_labels = []
    with torch.no_grad():
        for item in tqdm(dataloader, desc="Generating"):
            batch_ids = item['input_ids']
            batch_mask = item.get('attention_mask')
            for row, label_str in enumerate(item['label_strs']):
                prompt_ids = batch_ids[row]
                if batch_mask is not None:
                    # Drop padding so the prompt ends right after "Label:".
                    prompt_ids = prompt_ids[batch_mask[row].bool()]
                gen_text = generate_gpt2(
                    model,
                    tokenizer,
                    prompt_ids.unsqueeze(0),
                    max_gen_length=max_gen_length,
                    device=device,
                )
                all_preds.append(gen_text.split("Label:")[-1].strip())
                all_labels.append(label_str)
    acc = compute_accuracy(all_preds, all_labels)
    print(f"Evaluation accuracy: {acc*100:.2f}%")
    return acc, all_preds, all_labels

class XNLIDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for XNLI (Cross-lingual Natural Language Inference) task.

    Supports train, dev, and test splits in a specific language,
    tokenizes text inputs for GPT-style models, and optionally subsamples the dataset.

    The base class is spelled out as ``torch.utils.data.Dataset`` because the
    bare name ``Dataset`` is also imported from the Hugging Face ``datasets``
    package in this file, and this class is meant to be used with a PyTorch
    ``DataLoader``.

    Attributes:
        split (str): Dataset split, one of 'train', 'dev', 'test'.
        lang (str): Language code (e.g., 'en', 'zh').
        tokenizer: A HuggingFace tokenizer to convert text to input IDs.
        max_length (int): Maximum sequence length for tokenization.
        LABEL2ID (dict): Mapping from textual labels to integer IDs.
        ID2LABEL (dict): Reverse mapping from integer IDs to textual labels.
        data (pd.DataFrame): The loaded and preprocessed dataset with columns
            'premise', 'hypo' and 'label'.
    """
    def __init__(
        self,
        split="train",
        lang="en",
        train_path_template="XNLI-MT-1.0/multinli/multinli.train.{lang}.tsv",
        test_path="XNLI-1.0/xnli.test.tsv",
        dev_path="XNLI-1.0/xnli.dev.tsv",
        tokenizer=None,
        max_length=1024,
        subset = 1.0  # 0~1
    ):
        """
        Load and preprocess one split.

        Args:
            split: 'train', 'dev' or 'test'.
            lang: Language code; selects the train file and filters dev/test rows.
            train_path_template: Train TSV path containing a ``{lang}`` placeholder.
            test_path: Path of the multilingual test TSV.
            dev_path: Path of the multilingual dev TSV.
            tokenizer: Tokenizer used by ``__getitem__``.
            max_length: Truncation length passed to the tokenizer.
            subset: Fraction (0~1) of rows to keep, taken from the start of
                the file; values >= 1 keep everything.

        Raises:
            ValueError: If ``split`` is not one of the supported values.
        """
        self.split = split
        self.lang = lang
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.LABEL2ID = {"entailment": 0, "contradictory": 1, "neutral": 2}
        self.ID2LABEL = {v: k for k, v in self.LABEL2ID.items()}

        if split == "train":
            path = train_path_template.format(lang=lang)
            df = self.read_xnli_tsv(path, split)
            df = df.dropna(subset=['premise','hypo','label'])
        elif split in ["dev", "test"]:
            path = test_path if split=="test" else dev_path
            df = self.read_xnli_tsv(path, split)
            # dev/test files hold every language; keep only the requested one.
            df = df[df['language']==lang].copy()
            keep_cols = ['sentence1', 'sentence2', 'gold_label']
            df = df[keep_cols].dropna()
            # Align column names and label spelling with the train files.
            df.rename(columns={'sentence1':'premise','sentence2':'hypo','gold_label':'label'}, inplace=True)
            df['label'] = df['label'].replace({'contradiction': 'contradictory'})
        else:
            raise ValueError("split must be one of ['train','dev','test']")

        original_num = len(df)
        if subset < 1.0:
            # Keep at least one row, taken from the head of the file.
            n = max(1, int(len(df) * subset))
            df = df.iloc[:n].reset_index(drop=True)
        subset_num = len(df)

        self.data = df.reset_index(drop=True)
        print(f"Dataset initialized: split='{split}', lang='{lang}', total={original_num}, subset={subset}, subset_count={subset_num}")

    def read_xnli_tsv(self, path, split):
        """
        Read an XNLI TSV file and return it as a pandas DataFrame.

        Rows whose column count differs from the header are skipped and
        reported on stdout. The train split is split on raw tab characters;
        dev/test go through ``csv.reader`` with a tab delimiter.

        Args:
            path (str): Path to the TSV file.
            split (str): One of "train", "dev", "test" indicating the dataset split.

        Returns:
            pd.DataFrame: The dataset as a DataFrame with appropriate columns.
        """
        if split == "train":
            with open(path, "r", encoding="utf-8") as f:
                lines = f.read().splitlines()
            header = lines[0].split("\t")
            data = []
            # start=2: report 1-based file line numbers (the header is line 1).
            for i, line in enumerate(lines[1:], start=2):
                parts = line.split("\t")
                if len(parts) == len(header):
                    data.append(parts)
                else:
                    print(f"skip row {i}: {len(parts)} cols → {parts[:2]}")
        else:
            with open(path, "r", encoding="utf-8") as f:
                reader = csv.reader(f, delimiter="\t")
                rows = list(reader)
            header = rows[0]
            expected_cols = len(header)
            data = []
            for i, row in enumerate(rows[1:], start=2):
                if len(row) == expected_cols:
                    data.append(row)
                else:
                    print(f"skip row {i}: {len(row)} cols → {row[:2]}")
        return pd.DataFrame(data, columns=header)

    def __len__(self):
        """Return the number of examples in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieve a single example by index and tokenize it.

        For training split:
            - Constructs the input as "Premise: ...\\nHypothesis: ...\\nLabel:<id>"
              where <id> is the integer label id as a string.
            - Tokenizes the full input.
            - Masks the prefix tokens in the labels with -100 for GPT loss computation.

        For dev/test split:
            - Constructs the input without label as "Premise: ...\\nHypothesis: ...\\nLabel:"

        Returns:
            dict: Contains 'input_ids', 'attention_mask', 'labels' (train only)
            and 'label_str' (the integer label id as a string).
        """
        row = self.data.iloc[idx]
        premise = row['premise']
        hypo = row['hypo']
        label = row['label']
        if self.lang == 'zh': # de-tokenize for Chinese
            premise = premise.replace(" ", "")
            hypo = hypo.replace(" ", "")

        is_train = self.split == "train"
        prompt = f"Premise: {premise}\nHypothesis: {hypo}\nLabel:"
        label_id_str = str(self.LABEL2ID[label])

        # Training examples carry the answer; evaluation prompts stop at "Label:".
        text = prompt + label_id_str if is_train else prompt
        tokenized = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding=False,
            return_tensors="pt"
        )
        # return_tensors="pt" adds a batch dimension of 1; drop it.
        tokenized = {k: v.squeeze(0) for k, v in tokenized.items()}

        if is_train:
            # Only the label token(s) contribute to the loss: every position
            # belonging to the prompt is set to -100.
            prefix_ids = self.tokenizer(prompt).input_ids
            labels_ids = tokenized['input_ids'].clone()
            labels_ids[:len(prefix_ids)] = -100
            tokenized['labels'] = labels_ids

        tokenized['label_str'] = label_id_str
        return tokenized

    @staticmethod
    def collate_fn(batch):
        """
        Collate a batch of examples into padded tensors.

        Pads 'input_ids' and 'attention_mask' with 0 (on the right) to the max
        length in the batch.
        Pads 'labels' with -100 if present (decided from the first example).
        Collects 'label_str' for reference.

        Returns:
            dict: 'input_ids', 'attention_mask', 'label_strs' and, when the
            examples carry them, 'labels'.
        """
        pad = torch.nn.utils.rnn.pad_sequence

        out = {
            "input_ids": pad([b['input_ids'] for b in batch], batch_first=True, padding_value=0),
            "attention_mask": pad([b['attention_mask'] for b in batch], batch_first=True, padding_value=0),
            "label_strs": [b['label_str'] for b in batch],
        }
        if 'labels' in batch[0]:
            # -100 is skipped by the default cross-entropy ignore_index.
            out["labels"] = pad([b['labels'] for b in batch], batch_first=True, padding_value=-100)
        return out

## Load Fresh Tokenizer and Model

In [None]:
# Hyper-parameters for supervised fine-tuning on XNLI.
EPOCHS = 1
BATCH_SIZE = 4
LR = 5e-5
WEIGHT_DECAY = 0.01
CORRECT_BIAS = True

# Load onto the globally selected DEVICE rather than a hard-coded 'cuda', so
# the notebook also runs on CPU-only machines. The explicit .to(DEVICE)
# guarantees the model lives on the same device the training loop moves its
# batches to.
model = load_moe_model('./gpt2-moe-finetuned', device=DEVICE)
model.to(DEVICE)
tokenizer = GPT2Tokenizer.from_pretrained('./gpt2-moe-finetuned')

## English Baseline

In [None]:
# Fraction of each split to use; 1 keeps the full split.
TRAIN_SUBSET = 1
DEV_SUBSET = 1
TEST_SUBSET = 1

# English XNLI splits, all tokenized with the fine-tuned model's tokenizer.
train_dataset = XNLIDataset(split="train", lang="en", tokenizer=tokenizer, subset=TRAIN_SUBSET)

dev_dataset = XNLIDataset(split="dev", lang="en", tokenizer=tokenizer, subset=DEV_SUBSET)

test_dataset = XNLIDataset(split="test", lang="en", tokenizer=tokenizer, subset=TEST_SUBSET)

In [None]:
# Create DataLoaders for training and validation datasets.
# The dev loader keeps the default batch size of 1 (one prompt per generation).
train_loader = DataLoader(train_dataset,batch_size=BATCH_SIZE,shuffle=True,collate_fn=XNLIDataset.collate_fn)
dev_loader = DataLoader(dev_dataset,shuffle=False,collate_fn=XNLIDataset.collate_fn)

VOCAB_SIZE = tokenizer.vocab_size

# Default ignore_index is -100, so masked prompt/padding positions are skipped.
criterion = torch.nn.CrossEntropyLoss()
# Initialize optimizer
optimizer = AdamW(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY, correct_bias=CORRECT_BIAS)
# Track training progress
global_train_losses = []
total_train_loss = 0.0
total_train_steps = 0
print_interval = 10

# Track best dev accuracy for model saving
# This only works for epoch > 1
best_dev_acc = 0.0
SAVE_DIR = "best_model"
os.makedirs(SAVE_DIR, exist_ok=True)

# Training loop
for epoch in range(EPOCHS):
    print(f"Epoch {epoch+1}/{EPOCHS}")
    model.train()
    # Iterate over batches
    loop = tqdm(train_loader, desc="Training")
    for batch in loop:
        input_ids = batch["input_ids"].to(DEVICE)        # [B, seq_len]
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)              # [B, seq_len]

        optimizer.zero_grad()

        # `model` is the Hugging Face GPT2LMHeadModel returned by
        # load_moe_model: its output already carries vocabulary logits in
        # `.logits`. It has no 'last_hidden_state' entry and no
        # `hidden_state_to_token` method.
        vocabulary_logits = model(input_ids=input_ids, attention_mask=attention_mask).logits  # [B, seq_len, vocab]

        # Next-token objective: the logits at position t predict token t + 1.
        shifted_logits = vocabulary_logits[:, :-1, :].contiguous()
        shifted_labels = labels[:, 1:].contiguous()
        # Flatten using the model's own vocabulary dimension.
        logits_for_loss = shifted_logits.view(-1, shifted_logits.size(-1))
        labels_for_loss = shifted_labels.view(-1)
        loss = criterion(logits_for_loss, labels_for_loss)

        loss.backward()
        optimizer.step()

        # Running mean of the loss over all steps so far.
        total_train_loss += loss.item()
        total_train_steps += 1
        global_train_avg_loss = total_train_loss / total_train_steps
        global_train_losses.append(global_train_avg_loss)

        loop.set_postfix({'avg_loss': f"{global_train_avg_loss:.4f}"})

    print(f"Epoch {epoch+1} finished | Global Avg Loss: {global_train_avg_loss:.4f}")

    # Labels are single-digit ids, so one generated token is enough.
    acc, all_preds, all_labels = evaluate_gpt2_xnli(model, tokenizer, dev_loader, max_gen_length=1, device=DEVICE)
    save_moe_model(model, "./gpt2-moe-finetuned-en")
    tokenizer.save_pretrained("./gpt2-moe-finetuned-en")

    print("Model finetuned on XNLI EN split has been saved to ./gpt2-moe-finetuned-en")
    print(f"The accuracy of this model is: {acc}")

In [None]:
# Drop the last references to the model and tokenizer, then ask PyTorch to
# release its cached GPU memory.
del model, tokenizer
torch.cuda.empty_cache()