In [None]:
# Imports
import torch
from torch import nn
import torch.nn.functional as F
from dotenv import load_dotenv
import wandb
import math
from helpers.memory import check_memory
from helpers.logging import get_gradient_stats
from dataclasses import dataclass, asdict
import time
from collections import defaultdict
import os
import glob 

# Default device string for this notebook (first CUDA GPU) - presumably consumed by later cells; TODO confirm where it is read
main_device = 'cuda:0'
# Project helper from helpers.memory; its behavior is not visible here - presumably reports current memory usage, verify against helpers/memory.py
check_memory()

### Set model configuration settings

In [2]:
"""
Create a conf with configurable model settings.
- These will be passed into the model class during model initialization, so add new confs needed for whatever architecture is used.
- If you use the default conf values with the default model class defined later, it will exactly replicate the OlMoE-7B model,
   with 7B total params/1B active/64 experts.
"""

@dataclass
class BaseConf:
    """
    Hyperparameters shared by every layer class in this notebook.

    Each default carries the "Base OlMoE" value noted beside it; pass keyword overrides
    at construction time to build a smaller or larger variant.
    """
    vocab_size: int = 50304 # Base OlMoE: 50304 (vocab size; rows of the token embedding and output width of the LM head)
    D: int = 2048 # Base OlMoE: 2048 (hidden state dimension)
    H: int = 16 # Base OlMoE: 16 (number of attention heads; per-head width is D / H)
    I: int = 1024 # Base OlMoE: 1024 (intermediate width of each expert MLP)
    n_experts: int = 64 # Base OlMoE: 64 (experts per MoE layer; also the router's output width)
    top_k: int = 8 # Base OlMoE: 8 (experts selected per token)
    norm_topk_prob: bool = False # Base OlMoE: false (whether to normalize so that expert weights sum to 1 after topk)
    padding_idx: int = 1 # Base OlMoE: 1 (token id treated as padding by the embedding layer)
    n_layers: int = 16 # Base OlMoE: 16 (transformer layers)
    rms_norm_eps: float = 1e-05 # Base OlMoE: 1e-05 (epsilon added to the mean square inside every RMSNorm)
    rope_theta: float = 10000.0 # Base OlMoE: 10000.0 (base of the RoPE inverse-frequency schedule)
    max_position_embeddings: int = 4096 # Base OlMoE: 4096 (RoPE-related; not read by the layers defined above the model class - TODO confirm where it is used)
    router_aux_loss_coef: float = 0.01  # Base OlMoE: 0.01 (relative weight of balancing loss; presumably applied where the total loss is assembled - verify)
    attn_method: str = 'fa2' # In OlMoE this is chosen automatically, here we explicitly pass it - choose 'normal', 'sdpa', or 'fa2'

# Below settings lead to a ~450M param model.
# Any field not listed here (vocab_size, norm_topk_prob, attn_method, ...) keeps its BaseConf default.
conf = BaseConf(
    D = 768,
    H = 8, # 768 / 8 = 96 dimensions per attention head
    I = int(768 * 4), # Expert MLP width set to 4x the hidden size
    n_layers = 12,
    n_experts = 4,
    top_k = 2, # Each token is routed to 2 of the 4 experts
    router_aux_loss_coef = 1e-2,
    max_position_embeddings = 2048
)

### Helper functions

In [3]:
""" 
This is a dump of helper functions called by the model layers, needed to run the forward/backward passes correctly.
- `_prepare_4d_causal_attention_mask_with_cache_position` is used to create the upper-triangular infinity mask for attention (not used by flash attention).
- `load_balancing_loss_func` is the usual load balancing function.
- Add any new functions here if needed, but most experiments won't need to touch this section.
"""

# Create the upper-triangular matrix of infinities to mask future tokens in the attention softmax (needed for SDPA + normal attention)
# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L1099C1-L1152 
def _prepare_4d_causal_attention_mask_with_cache_position(attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int):
    """
    Build the additive 4D causal mask of shape (batch_size, 1, sequence_length, target_length).

    Allowed positions hold 0 and blocked positions hold the most negative finite value of `dtype`.

    Params:
        @attention_mask: Either `None`, a 4D mask (returned untouched), or a 2D (batch, key_length) mask whose
          zero entries mark padding keys that must also be blocked.
        @sequence_length: Number of query positions.
        @target_length: Number of key positions.
        @dtype: dtype of the returned mask.
        @device: Device of the returned mask.
        @cache_position: Index of each query position in the full sequence; key j is blocked for query i when j > cache_position[i].
        @batch_size: Leading dimension of the returned mask.
    """
    # A 4D mask is assumed to already be in additive (inverted) form - hand it back unchanged
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask

    neg_fill = torch.finfo(dtype).min

    # Start fully blocked, then keep only the strictly-upper triangle (the future keys)
    mask = torch.full((sequence_length, target_length), fill_value = neg_fill, dtype = dtype, device = device)
    if sequence_length != 1:
        mask = torch.triu(mask, diagonal = 1)

    # Zero out every key that is not strictly after the query's cache position
    mask *= torch.arange(target_length, device = device) > cache_position.reshape(-1, 1)
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1)

    if attention_mask is None:
        return mask

    # `expand` produced a broadcast view, so materialize it before writing into it
    mask = mask.clone()
    visible_len = attention_mask.shape[-1]
    window = mask[:, :, :, :visible_len]
    # Causally allowed (0) + padding key (0) sums to 0: those are the extra positions to block
    is_padding = (window + attention_mask[:, None, None, :]) == 0
    mask[:, :, :, :visible_len] = window.masked_fill(is_padding, neg_fill)
    return mask

# Load balancing loss, copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py
def load_balancing_loss_func(gate_logits, num_experts, top_k, attention_mask):
    """
    Auxiliary load-balancing loss computed over the router logits of all layers.

    Params:
        @gate_logits: Sequence of per-layer router logits, each (tokens, num_experts). They are moved to the
          device of the first entry and concatenated along the token axis.
        @num_experts: Number of experts (width of the logits).
        @top_k: Number of experts selected per token.
        @attention_mask: Optional (batch_size, sequence_length) mask; tokens with value 0 are excluded from both averages.

    Returns:
        A scalar tensor: sum over (top-k slot, expert) of
        (fraction of tokens choosing the expert in that slot) * (mean router probability of the expert), times `num_experts`.
    """
    target_device = gate_logits[0].device
    stacked_logits = torch.cat([layer_logits.to(target_device) for layer_logits in gate_logits], dim = 0)

    probs = torch.nn.functional.softmax(stacked_logits, dim = -1)
    chosen = torch.topk(probs, top_k, dim = -1).indices
    chosen_one_hot = torch.nn.functional.one_hot(chosen, num_experts)  # (tokens, top_k, num_experts)

    if attention_mask is not None:
        batch_size, sequence_length = attention_mask.shape
        n_layers = stacked_logits.shape[0] // (batch_size * sequence_length)

        # Broadcast the token mask to the shape of the one-hot selection tensor
        slot_mask = (
            attention_mask[None, :, :, None, None]
            .expand((n_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(target_device)
        )
        # Share of non-padding tokens routed to each expert, per top-k slot
        tokens_per_expert = torch.sum(chosen_one_hot.float() * slot_mask, dim = 0) / torch.sum(slot_mask, dim = 0)

        # Same mask, broadcast to the shape of the probability tensor
        prob_mask = (
            attention_mask[None, :, :, None]
            .expand((n_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(target_device)
        )
        # Mean router probability per expert over non-padding tokens
        router_prob_per_expert = torch.sum(probs * prob_mask, dim = 0) / torch.sum(prob_mask, dim = 0)
    else:
        # No mask: plain averages over every token
        tokens_per_expert = torch.mean(chosen_one_hot.float(), dim = 0)
        router_prob_per_expert = torch.mean(probs, dim = 0)

    return torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) * num_experts

### Define model layers

In [4]:
""" 
First let's define the RMSNorm, ROPE, and self-attention layers.
- These are basically taken straight from the OlMoE source code, but heavily simplified/cleaned up.
- Note that RMSNorm is the ONLY norm type we define (same as OlMoE).
- These layers generally do not need to be modified for MoE experiments.
"""
from transformers.modeling_flash_attention_utils import _flash_attention_forward # Flash attention forward

class OlmoeRMSNorm(nn.Module):
    """
    Root-mean-square layer norm with a learned per-feature scale.
    - Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L137-L154
    - The statistics are computed in float32; the normalized values are cast back to the input dtype
      before being multiplied by the scale.
    """
    def __init__(self, D, eps = 1e-5):
        """
        Params:
            @D: Size of the last (feature) dimension being normalized.
            @eps: Constant added to the mean square before taking the reciprocal square root.
        """
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(D))  # Scale starts at 1, i.e. pure normalization

    def forward(self, hidden_states):
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim = True)
        normalized = upcast * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(original_dtype)

class OlmoeRotaryEmbedding(nn.Module):
    """
    Produces the cos/sin tables used by rotary position embeddings (RoPE).
    - Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L161-L219
    - Heavily simplified: a single fixed frequency schedule, no dynamic RoPE scaling
    """
    def __init__(self, conf: BaseConf):
        super().__init__()
        head_dim = int(conf.D/conf.H)
        # One inverse frequency per pair of head dimensions: theta^(-2i / head_dim)
        exponents = torch.arange(0, head_dim, 2, dtype = torch.int64).float()/head_dim
        # Non-persistent buffer: follows the module across devices but is left out of the state dict
        self.register_buffer("inv_freq", 1.0 / (conf.rope_theta ** exponents), persistent = False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Params:
            @x: Any tensor; only its device type and dtype are read.
            @position_ids: Token positions; indexed as (batch, sequence).

        Returns:
            A `(cos, sin)` pair in the dtype of `x`, each with the per-position angles duplicated along the last dimension.
        """
        n_rows = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(n_rows, -1, 1)
        position_row = position_ids[:, None, :].float()

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        source_device = x.device.type
        autocast_device = source_device if isinstance(source_device, str) and source_device != "mps" else "cpu"
        with torch.autocast(device_type = autocast_device, enabled = False):
            # Outer product of frequencies and positions, then put the position axis first
            angles = (freq_column.float() @ position_row.float()).transpose(1, 2)
            angles = torch.cat((angles, angles), dim = -1)
            cos_table, sin_table = angles.cos(), angles.sin()

        return cos_table.to(dtype = x.dtype), sin_table.to(dtype = x.dtype)

class OlmoeAttention(nn.Module):
    """
    Multi-head causal self-attention with RMS-normalized queries/keys and rotary position embeddings.
    - Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L288-L391
    - Simplified to handle base attention ('normal'), SDPA ('sdpa') and flash attention ('fa2') within this one class
    - Doesn't support GQA (OlMoE doesn't use it anyway)
    """
    def __init__(self, conf: "BaseConf"):
        super().__init__()
        if conf.D % conf.H != 0:
            raise ValueError(f'Hidden size D={conf.D} must be divisible by the number of heads H={conf.H}.')

        self.attn_method = conf.attn_method
        self.D = conf.D # Hidden state dim
        self.H = conf.H # Num of attention heads
        self.Dh = conf.D // conf.H # Dimensions per head

        # Initialize attention layers - no biases following OlMoE architecture
        self.q_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.k_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.v_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.o_proj = nn.Linear(self.D, self.D, bias = False)
        # Q/K norms act on the full D-wide projection, before it is split into heads
        self.q_norm = OlmoeRMSNorm(self.D, eps = conf.rms_norm_eps)
        self.k_norm = OlmoeRMSNorm(self.D, eps = conf.rms_norm_eps)

    # Taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L223-L255
    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim = 1):
        """
        Rotate queries and keys by the given cos/sin tables.

        Params:
            @q, @k: Query/key tensors whose last dimension is the per-head width.
            @cos, @sin: Tables from the rotary embedding; a singleton axis is inserted at `unsqueeze_dim` so they broadcast over heads.
            @unsqueeze_dim: Position of the head axis in `q`/`k`.

        Returns:
            The rotated `(q, k)` pair, same shapes as the inputs.
        """
        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.LongTensor, position_embeddings: tuple[torch.Tensor, torch.Tensor]):
        """
        Params:
            @hidden_state: (B, N, D) input.
            @attention_mask: For 'normal' and 'sdpa', an additive mask broadcastable to (B, H, N, N), or `None` for plain
              causal attention. For 'fa2' it is forwarded unchanged to `_flash_attention_forward`.
            @position_ids: Accepted for interface compatibility; not read here (positions arrive via `position_embeddings`).
            @position_embeddings: The `(cos, sin)` pair produced by the rotary embedding.

        Returns:
            (B, N, D) attention output after the output projection.

        Raises:
            ValueError: If `self.attn_method` is not one of 'normal', 'sdpa', 'fa2'.
        """
        B, N, D = hidden_state.shape

        # Project (B, N, D), normalize q/k, then split into heads: (B, N, H, Dh) -> (B, H, N, Dh)
        query_state = self.q_norm(self.q_proj(hidden_state)).view(B, N, self.H, self.Dh).transpose(1, 2)
        key_state = self.k_norm(self.k_proj(hidden_state)).view(B, N, self.H, self.Dh).transpose(1, 2)
        value_state = self.v_proj(hidden_state).view(B, N, self.H, self.Dh).transpose(1, 2)

        cos, sin = position_embeddings
        query_state, key_state = self.apply_rotary_pos_emb(query_state, key_state, cos, sin)

        if self.attn_method == 'normal':
            attn_weights = torch.matmul(query_state, key_state.transpose(2, 3))/math.sqrt(self.Dh)  # B x H x N x N
            if attention_mask is None:
                # No mask supplied: fall back to a plain causal mask so future tokens stay hidden
                attention_mask = torch.full(
                    (N, N), torch.finfo(attn_weights.dtype).min, dtype = attn_weights.dtype, device = attn_weights.device
                ).triu(diagonal = 1)
            attn_weights = attn_weights + attention_mask # Additive mask: large negative values above the diagonal
            attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(query_state.dtype)
            attn_output = torch.matmul(attn_weights, value_state) # B x H x N x Dh
            attn_output = attn_output.transpose(1, 2).contiguous() # Reorder into B x N x H x Dh
            attn_output = attn_output.reshape(B, N, D) # Concatenate heads back into B x N x D

        elif self.attn_method == 'sdpa':
            # SDPA rejects an explicit mask combined with is_causal=True, so only request the
            # built-in causal mask when the caller did not provide one.
            use_builtin_causal = attention_mask is None and N > 1
            attn_output = F.scaled_dot_product_attention(
                query_state, key_state, value_state,
                attn_mask = attention_mask, dropout_p = 0.0, is_causal = use_builtin_causal
            )
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(B, N, D)

        elif self.attn_method == 'fa2':
            # Flash attention expects B x N x H x Dh
            query_state = query_state.transpose(1, 2)
            key_state = key_state.transpose(1, 2)
            value_state = value_state.transpose(1, 2)
            attn_output = _flash_attention_forward(
                query_state, key_state, value_state,
                attention_mask, N, dropout = 0.0, use_top_left_mask = False, is_causal = True
            )
            attn_output = attn_output.reshape(B, N, D).contiguous()

        else:
            raise ValueError(f'Attention method "{self.attn_method}" not implemented.')

        attn_output = self.o_proj(attn_output)
        return attn_output

In [5]:
""" 
Now let's define the MLP layer and the MoE layer.
- The MLP layer is simple; modify as needed.
- However, the MoE layer is much more complex, and this layer will probably need to be modified heavily for most experiments.
  - By default, I've defined three forward methods here. As currently implemented, they all generate IDENTICAL outputs but become increasingly more efficient yet complex.
    - `forward_slow` is the most straightforward implementation (similar to the original OlMoE code).
    - `forward_fast` is fast when you have a large number of experts, as it makes all the relevant states for a single expert contiguous in memory.
    - `forward_async` is fast for large GPU counts + large # of experts, as it batches all experts who belong on one device together, and also runs them async.
    - For initial testing, it's probably best to modify just `forward_slow`, and only modify the others once you want to run a large-scale training run.
  - Each forward method must return a tuple where the first element is the B x N x D MoE layer output, and the second element is the router logits. 
    - To return more, you'll need to also modify the transformer layer class in the next section.
"""
from transformers.activations import silu

class OlmoeMLP(nn.Module):
    """
    A single expert: gated feed-forward network D -> I -> D without biases.
    - Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L258-L272
    """
    def __init__(self, conf: BaseConf):
        super().__init__()
        self.conf = conf
        self.D, self.I = conf.D, conf.I
        self.gate_proj = nn.Linear(self.D, self.I, bias = False)
        self.up_proj = nn.Linear(self.D, self.I, bias = False)
        self.down_proj = nn.Linear(self.I, self.D, bias = False)
        self.act_fn = silu

    def forward(self, x):
        # Activated gate branch modulates the linear "up" branch elementwise, then project back to D
        gate_branch = self.act_fn(self.gate_proj(x))
        up_branch = self.up_proj(x)
        return self.down_proj(gate_branch * up_branch)

class OlmoeMoe(nn.Module):
    """
    Entire MLP layer including router: a linear gate scores every token against `n_experts` experts,
    the `top_k` best experts process the token, and their outputs are summed weighted by the gate probabilities.
    - Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L604-L649
    - Three interchangeable forward implementations are provided (`forward_slow`, `forward_fast`, `forward_async`).
      All of them return the same values and propagate the same gradients, including the gradient that reaches
      the router through the gate weights.
    """
    def __init__(self, conf: "BaseConf"):
        super().__init__()
        self.n_experts = conf.n_experts
        self.top_k = conf.top_k
        self.norm_topk_prob = conf.norm_topk_prob
        self.gate = nn.Linear(conf.D, self.n_experts, bias = False) # Router
        self.experts = nn.ModuleList([OlmoeMLP(conf) for _ in range(self.n_experts)]) # Create experts using OlmoeMLP

        # Expert -> device strings as of construction time. The forward methods look up the live device of each
        # expert instead, so they stay correct even if experts are moved without refreshing this list.
        self.expert_device_map = [str(next(expert.parameters()).device) for expert in self.experts]

    def forward(self, hidden_state: torch.Tensor, moe_method: str) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Forward method routes to one of several possible other forward methods

        Params:
            @hidden_state: (B, N, D) input.
            @moe_method: One of 'forward_slow', 'forward_fast', 'forward_async'.

        Returns:
            `(mlp_output, router_logits)` with shapes (B, N, D) and (B * N, n_experts).

        Raises:
            ValueError: If `moe_method` is not one of the names above.
        """
        if moe_method == 'forward_slow':
            return self.forward_slow(hidden_state)
        elif moe_method == 'forward_fast':
            return self.forward_fast(hidden_state)
        elif moe_method == 'forward_async':
            return self.forward_async(hidden_state)
        else:
            raise ValueError(f'Method "{moe_method}" not implemented.')

    def _route(self, hidden_state_flat: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Score each flattened token against every expert and keep the `top_k` best experts.

        Returns:
            `(router_logits, topk_weights, topk_experts)` with shapes (BN, n_experts), (BN, top_k), (BN, top_k).
            `topk_weights` is in the dtype of the input and stays attached to the autograd graph.
        """
        router_logits = self.gate(hidden_state_flat)  # (BN, n_experts)
        routing_weights = F.softmax(router_logits, dim = 1, dtype = torch.float)  # softmax in float32 for stability

        topk_weights, topk_experts = torch.topk(routing_weights, self.top_k, dim = -1)  # both (BN, top_k)
        topk_weights = topk_weights.to(hidden_state_flat.dtype)
        # If you want the top-k weights to sum to 1
        if self.norm_topk_prob:
            topk_weights = topk_weights / (topk_weights.sum(dim = -1, keepdim = True) + 1e-9)
        return router_logits, topk_weights, topk_experts

    def _group_by_expert(self, topk_weights: torch.Tensor, topk_experts: torch.Tensor):
        """
        Flatten the (token, top-k slot) assignments and sort them by expert id so that every expert owns one contiguous segment.

        Returns:
            `(sorted_tokens, sorted_weights, segments)` where the first two are (BN * top_k,) tensors in sorted order and
            `segments` is a list of `(expert_id, start, end)` Python ints covering only the experts that received tokens.
        """
        BN = topk_experts.shape[0]
        # token ids [0..BN-1], each repeated top_k times (one per selected expert)
        token_indices_flat = torch.arange(BN, device = topk_experts.device).unsqueeze(1).expand(BN, self.top_k).reshape(-1)

        sorted_experts, sort_indices = torch.sort(topk_experts.reshape(-1))
        sorted_tokens = token_indices_flat[sort_indices]
        sorted_weights = topk_weights.reshape(-1)[sort_indices]

        # Segment boundaries; a single .tolist() per tensor means one host sync instead of one per expert
        expert_ids, counts = torch.unique_consecutive(sorted_experts, return_counts = True)
        segments = []
        start_idx = 0
        for expert_id, count in zip(expert_ids.tolist(), counts.tolist()):
            segments.append((expert_id, start_idx, start_idx + count))
            start_idx += count
        return sorted_tokens, sorted_weights, segments

    def forward_slow(self, hidden_state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This is the more intuitive forward pass which loops through each expert slowly
        """
        B, N, D = hidden_state.shape

        # Flatten out B x N x D to BN x D (flattened token-level reps) to route all tokens separately
        hidden_state = hidden_state.reshape(B * N, D)

        # Router logits plus, for each token, the selected top_k experts and their weights
        router_logits, routing_weights, selected_experts = self._route(hidden_state)

        # One hot encode - for each expert, which topk x token is active - e.g. expert_assignment_mask[0, :] will be 0s if the first expert is never chosen
        expert_assignment_mask = F.one_hot(selected_experts, num_classes = self.n_experts).permute(2, 1, 0) # (N_EXPERTS, TOP_K, BN)

        # Initialize MLP output - expert outputs are summed onto this
        mlp_output = torch.zeros((B * N, D), dtype = hidden_state.dtype, device = hidden_state.device)

        for expert_ix, expert in enumerate(self.experts):
            # For this expert, the (topk slot, token) coordinates that use it
            topk_slot, token_indices = torch.where(expert_assignment_mask[expert_ix])
            if token_indices.numel() == 0:
                continue

            expert_device = next(expert.parameters()).device # Device this expert lives on

            # Inputs AND their gate weights must both be on the expert's device before they are multiplied
            tokens_for_expert = hidden_state[token_indices].to(expert_device)
            weights_for_expert = routing_weights[token_indices, topk_slot].unsqueeze(1).to(expert_device)

            expert_output = expert(tokens_for_expert) * weights_for_expert

            # Back to the original device/dtype, then accumulate at the owning token rows
            mlp_output.index_add_(0, token_indices, expert_output.to(device = hidden_state.device, dtype = hidden_state.dtype))

        mlp_output = mlp_output.reshape(B, N, D) # Convert back from BN x D -> B x N x D
        return mlp_output, router_logits

    def forward_fast(self, hidden_state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Efficient MoE routing that batches tokens for each expert in a single pass using gather -> scatter operations.
        - This will be much faster for a large number of experts, but possibly slower for low expert counts.
        """
        B, N, D = hidden_state.shape
        BN = B * N

        # Flatten so we can route each token individually
        hidden_state_flat = hidden_state.reshape(BN, D)

        # 1) Router logits and top-k
        router_logits, topk_weights, topk_experts = self._route(hidden_state_flat)

        # 2) Sort the (token, slot) assignments by expert and gather the inputs in that order
        sorted_tokens, sorted_weights, segments = self._group_by_expert(topk_weights, topk_experts)
        sorted_inputs = hidden_state_flat[sorted_tokens]  # (BN * top_k, D)

        # Expert outputs are accumulated in a flat (BN, D) buffer
        mlp_output_flat = torch.zeros((BN, D), dtype = hidden_state.dtype, device = hidden_state.device)

        # 3) One expert call per contiguous segment
        for expert_id, start_idx, end_idx in segments:
            expert = self.experts[expert_id]
            expert_device = next(expert.parameters()).device

            chunk_output = expert(sorted_inputs[start_idx:end_idx].to(expert_device))  # (count, D)
            # Multiply by the top-k gate weight
            chunk_output = chunk_output * sorted_weights[start_idx:end_idx].unsqueeze(1).to(expert_device)

            # Bring it back to the main device/dtype and scatter-add to the correct token positions
            chunk_output = chunk_output.to(device = mlp_output_flat.device, dtype = mlp_output_flat.dtype)
            mlp_output_flat.index_add_(0, sorted_tokens[start_idx:end_idx], chunk_output)

        # Reshape from (BN, D) back to (B, N, D)
        return mlp_output_flat.view(B, N, D), router_logits

    def forward_async(self, hidden_state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Async MoE forward for multi-GPU setups that:
          1) Flattens tokens
          2) Does gating + top-k
          3) Sorts assignments by expert and buckets the per-expert segments by the device the expert lives on
          4) Runs each CUDA device's experts on its own CUDA stream so transfers and compute can overlap
          5) Accumulates outputs back on the input's device
        Experts that live on a non-CUDA device are simply run inline, so this method also works on CPU.
        """
        B, N, D = hidden_state.shape
        BN = B * N
        main_device = hidden_state.device

        # ---------------- 1) Gating, top-k, sort by expert ----------------
        hidden_state_flat = hidden_state.reshape(BN, D)
        router_logits, topk_weights, topk_experts = self._route(hidden_state_flat)
        sorted_tokens, sorted_weights, segments = self._group_by_expert(topk_weights, topk_experts)
        sorted_inputs = hidden_state_flat[sorted_tokens]  # gather on the main device

        # ---------------- 2) Bucket expert segments by device ----------------
        device_to_segments = defaultdict(list)
        for expert_id, start_idx, end_idx in segments:
            expert_device = next(self.experts[expert_id].parameters()).device
            device_to_segments[expert_device].append((expert_id, start_idx, end_idx))

        def run_segments(device, device_segments):
            """Run every expert segment that lives on `device`; returns [(token_ids, output on main device), ...]."""
            outputs = []
            for expert_id, start_idx, end_idx in device_segments:
                chunk_input = sorted_inputs[start_idx:end_idx].to(device, non_blocking = True)
                # Slicing (not rebuilding) the weights keeps them attached to the router's autograd graph
                chunk_weight = sorted_weights[start_idx:end_idx].unsqueeze(1).to(device, non_blocking = True)
                chunk_output = self.experts[expert_id](chunk_input) * chunk_weight
                chunk_output = chunk_output.to(device = main_device, dtype = hidden_state.dtype, non_blocking = True)
                outputs.append((sorted_tokens[start_idx:end_idx], chunk_output))
            return outputs

        # ---------------- 3) Dispatch: one stream per CUDA device ----------------
        pending = []  # list of (stream or None, outputs)
        for device, device_segments in device_to_segments.items():
            if device.type != 'cuda':
                pending.append((None, run_segments(device, device_segments)))
                continue

            stream = torch.cuda.Stream(device = device)
            if main_device.type == 'cuda':
                # The side stream must not start before the routing tensors are ready on the main stream
                stream.wait_stream(torch.cuda.current_stream(main_device))
            with torch.cuda.stream(stream):
                pending.append((stream, run_segments(device, device_segments)))

        # ---------------- 4) Synchronize & accumulate on the main device ----------------
        mlp_output_flat = torch.zeros((BN, D), dtype = hidden_state.dtype, device = main_device)

        for stream, outputs in pending:
            if stream is not None:
                # Wait for everything launched in that stream to finish
                stream.synchronize()
            for token_ids, chunk_output in outputs:
                if stream is not None and chunk_output.is_cuda:
                    # Memory possibly allocated on the side stream is about to be read on the current stream
                    chunk_output.record_stream(torch.cuda.current_stream(chunk_output.device))
                mlp_output_flat.index_add_(0, token_ids, chunk_output)

        # Reshape to [B, N, D]
        return mlp_output_flat.view(B, N, D), router_logits


In [6]:
""" 
Now let's define the transformer block.
- Most likely, there is nothing to change here, unless you need to change the input/outputs from the MoE layer.
- Note that this forward pass is nested within a `custom_forward` call in order to support gradient checkpointing.
"""
class OlmoeBlock(nn.Module):
    """
    One pre-norm transformer layer: RMSNorm -> self-attention -> residual add, then RMSNorm -> MoE -> residual add.
    """
    def __init__(self, conf: BaseConf, layer_idx: int):
        """
        Params:
            @conf: Model configuration.
            @layer_idx: Position of this layer in the stack; accepted but not read by this class.
        """
        super().__init__()
        self.D = conf.D
        self.self_attn = OlmoeAttention(conf = conf)
        self.moe = OlmoeMoe(conf)
        self.input_layernorm = OlmoeRMSNorm(conf.D, eps = conf.rms_norm_eps)
        self.post_attention_layernorm = OlmoeRMSNorm(conf.D, eps = conf.rms_norm_eps)

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.LongTensor, position_embeddings: tuple[torch.Tensor, torch.Tensor], moe_method: str, use_checkpointing: bool):
        """
        Run the layer, optionally under gradient checkpointing.

        Returns:
            `(hidden_state, router_logits)` - the updated residual stream and this layer's router logits.
        """
        def run_layer(stream: torch.Tensor, mask: torch.Tensor, pos_ids: torch.LongTensor, pos_emb: tuple[torch.Tensor, torch.Tensor], method: str):
            # The layer body lives in a closure so it can be handed to the checkpoint utility as a plain function

            ### Attention sub-layer: normalize, attend, add back onto the residual stream ###
            attn_output = self.self_attn(
                self.input_layernorm(stream),
                attention_mask = mask,
                position_ids = pos_ids,
                position_embeddings = pos_emb
            )
            stream = stream + attn_output

            ### MoE sub-layer: normalize, route through experts, add back onto the residual stream ###
            mlp_output, layer_router_logits = self.moe(self.post_attention_layernorm(stream), moe_method = method)
            stream = stream + mlp_output

            return stream, layer_router_logits

        layer_args = (hidden_state, attention_mask, position_ids, position_embeddings, moe_method)

        if not use_checkpointing:
            # Normal forward pass
            hidden_state, router_logits = run_layer(*layer_args)
        else:
            # Use gradient checkpointing to reduce activation memory (arguments must be positional here)
            hidden_state, router_logits = torch.utils.checkpoint.checkpoint(
                run_layer, *layer_args, use_reentrant = True
            )

        return hidden_state, router_logits

In [7]:
""" 
Now define the top-level model.
- This class is initialized with the `BaseConf` config settings as well as a list of expert-device mappings (leave blank for single-GPU tests).
- After initialization, it creates all child layers and moves the experts to their correct devices. 
  - All other parameters will continue to exist on the default device.
- Modify `_init_weights` to change the weight initialization scheme.
- The forward pass calls the children layers and also calculates the loss (standard cross-entropy + aux loss). 
"""
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting

class OlmoeModel(nn.Module):
    """
    The top level model object. Also handles weight initialization and loss calculations.

    Holds the token embedding, the shared rotary embedding module, the stack of `OlmoeBlock` layers, the final RMSNorm
    and the LM head. `forward` returns the logits together with the cross-entropy loss, the load-balancing (aux) loss
    and their weighted sum.
    """
    def __init__(self, conf: BaseConf, expert_device_map: None | list[str] = None):
        """
        Params:
            @conf: A configuration object of class BaseConf.
            @expert_device_map: A list of devices to store experts on. If `None`, stores them all on whatever the torch default device is.
              For example, `expert_device_map = ['cuda:0', 'cuda:1', 'cuda:1', 'cuda:2']` means to store expert 0 on cuda:0, experts 1-2 on the device cuda:1, and expert 3 on cuda:2.
              Must contain exactly `conf.n_experts` entries; the same mapping is applied to every layer.

        Raises:
            ValueError: If `expert_device_map` is given and its length differs from `conf.n_experts`.
        """
        super().__init__()
        self.conf = conf
        
        ### Layers ###
        self.embed_tokens = nn.Embedding(self.conf.vocab_size, self.conf.D, self.conf.padding_idx)
        self.rotary_emb = OlmoeRotaryEmbedding(conf = self.conf)
        self.layers = nn.ModuleList([OlmoeBlock(self.conf, layer_idx) for layer_idx in range(self.conf.n_layers)])
        self.norm = OlmoeRMSNorm(self.conf.D, eps = self.conf.rms_norm_eps)
        self.lm_head = nn.Linear(self.conf.D, self.conf.vocab_size, bias = False)
        
        ### Initialize weights ###
        self.apply(self._init_weights)

        ### Experts ###
        # Experts are moved only after initialization, so every weight is drawn on the device it was created on.
        if expert_device_map is not None:
            self._move_experts_to_devices(expert_device_map)

    # OlMoE weight initiation - see https://github.com/huggingface/transformers/blob/8f1509a96c96747c893051ac947795cfb0750357/src/transformers/modeling_utils.py#L2500-L2515
    # Normal distribution for linear layers + embeddings
    def _init_weights(self, module):
        """
        Per-module initializer, applied to every submodule via `self.apply`.
        - nn.Linear: weights ~ N(0, 0.02), biases set to 0.
        - nn.Embedding: weights ~ N(0, 0.02); the padding row of `embed_tokens` is zeroed again afterwards.
        - Any other module type is left with its own default initialization.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean = 0.0, std = 0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean = 0.0, std = 0.02)
            # In the vocab -> embedding layer, set all embeddings to 0 for the padding token (tokenizer.pad_token_id)
            if module is self.embed_tokens:
                self.embed_tokens.weight.data[self.conf.padding_idx].zero_()
            
    def _move_experts_to_devices(self, expert_device_map: list[str]):
        """
        Move each expert in each layer's MoE to the specified device.

        Params:
            @expert_device_map: One device per expert; entry `i` is the target device for expert `i` of every layer.

        Raises:
            ValueError: If the mapping does not contain exactly `conf.n_experts` entries.
        """
        # Require that the length of expert_device_map equal the length of conf.n_experts.
        n_experts = self.conf.n_experts
        if len(expert_device_map) != n_experts:
            raise ValueError(f"expert_device_map has length {len(expert_device_map)} but n_experts = {n_experts}.")
            
        for layer in self.layers:
            for expert, target_dev in zip(layer.moe.experts, expert_device_map):
                expert.to(target_dev)
            
    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.Tensor, moe_method: str, use_checkpointing : bool = False):
        """
        Run the full model and compute the training losses.

        Params:
            @input_ids: (B, N) tensor of token IDs. Positions equal to `conf.padding_idx` are excluded from the cross-entropy loss.
            @attention_mask: (B, N) padding mask, or `None`. Used to build the attention mask and passed to the load balancing loss.
            @moe_method: Name of the MoE forward implementation; forwarded unchanged to every layer.
            @use_checkpointing: Forwarded to every layer to enable gradient checkpointing.

        Returns:
            A dict with keys `all_router_logits` (tuple with one entry per layer), `logits`, `aux_loss`, `base_loss` and
            `loss` (= base_loss + conf.router_aux_loss_coef * aux_loss).

        Raises:
            ValueError: If `conf.attn_method` is not one of 'normal', 'sdpa' or 'fa2'.
        """
        hidden_state = self.embed_tokens(input_ids)
        B, N, _ = hidden_state.shape

        ### Prep rotary embeddings + attention masks  ###
        cache_position = torch.arange(0, N, device = hidden_state.device)
        position_ids = cache_position.unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_state, position_ids) # Position embeddings to be shared across transformer layers

        attn_method = self.conf.attn_method
        # This is the upper-triangular matrix of infinities to mask future tokens in the attention softmax;
        if attn_method in ['normal', 'sdpa']:
            causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(attention_mask, N, N, hidden_state.dtype, hidden_state.device, cache_position, B)
        # The flash attention mask is simpler - takes only the original attention mask or None
        elif attn_method == 'fa2':
            causal_mask  = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # Without this branch `causal_mask` would be unbound and the failure would surface as an opaque NameError below.
            raise ValueError(f"Unsupported attn_method {attn_method!r}; choose 'normal', 'sdpa', or 'fa2'.")
        
        ### Transformer layers ###
        router_logits_per_layer = [] # Save router logits from each layer into this; will be needed for load balancing loss
        for layer in self.layers:
            hidden_state, router_logits = layer(
                hidden_state,
                attention_mask = causal_mask,
                position_ids = position_ids,
                position_embeddings = position_embeddings,
                moe_method = moe_method,
                use_checkpointing = use_checkpointing
            )
            router_logits_per_layer.append(router_logits)
        # Kept as a tuple, exactly as before, for the load balancing loss and for callers reading the output dict.
        all_router_logits = tuple(router_logits_per_layer)

        hidden_state = self.norm(hidden_state)
        output_logits = self.lm_head(hidden_state)

        ##### Calculate Loss #####
        # The labels object should be a tensor of token IDs or -100 (for padding positions, since we don't want to calculate loss for those).
        # masked_fill keeps the result on input_ids' device/dtype without creating a scalar tensor on the ambient default device.
        label_ids = input_ids.masked_fill(input_ids == self.conf.padding_idx, -100)
        # Get regular loss
        base_loss = ForCausalLMLoss(output_logits, label_ids, self.conf.vocab_size)
        # Get load balancing loss
        aux_loss = load_balancing_loss_func(gate_logits = all_router_logits, num_experts = self.conf.n_experts, top_k = self.conf.top_k, attention_mask = attention_mask)
        # Get total loss = regular loss + router_aux_loss_coef * load bal loss
        loss = base_loss + self.conf.router_aux_loss_coef * aux_loss 

        return {
            'all_router_logits': all_router_logits,
            'logits': output_logits,
            'aux_loss': aux_loss,
            'base_loss': base_loss,
            'loss': loss
        }

### Test the model

In [None]:
""" 
Let's load the model
- Set the default_device to specify where all the non-expert layers live (the experts are moved on model init)
- Set the default_dtype to specify the model dtype; all params will be in this dtype except for those explicitly specified differently in the class definition
  - In the default OlMoE, RMSNorm is required to be f32 whereas all other params are bf16. 
"""
# New tensors/parameters are created on `main_device` and in bf16 from here on, unless a layer picks its own dtype/device.
torch.set_default_device(main_device)
torch.set_default_dtype(torch.bfloat16)

model = OlmoeModel(
    conf,
    # The model requires exactly one device entry per expert (it raises ValueError otherwise), so derive the list from
    # the config instead of hard-coding its length. For this test every expert lives on main_device.
    expert_device_map = [main_device] * conf.n_experts
)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
check_memory()

In [None]:
""" 
Let's run a forward pass with a batch size of 2 to make sure the model is able to run
- If you have multiple working forward methods, this is a good chance to test them for equality
"""
from transformers import AutoTokenizer

# Tokenize two prompts of different lengths, padded to a common length of 128 tokens.
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)
prompt = ['I am a dog and I like to eat. My favorite food is', 'My cat is']
inputs = tokenizer(prompt, truncation = True, max_length = 128, padding = 'max_length', return_tensors = 'pt').to(main_device)

# Run each forward implementation on the same batch; `output` ends up holding the result of the last one.
with torch.no_grad():
    for method in ('forward_slow', 'forward_fast'): # append 'forward_async' here to exercise it too
        output = model(inputs['input_ids'], inputs['attention_mask'], moe_method = method, use_checkpointing = False)

# Greedy prediction at every position, then print the prediction at index (number of attended tokens - 1) for each prompt.
output_ids = output['logits'].argmax(dim = -1)
last_token_ix = inputs["attention_mask"].sum(dim = -1) - 1
for row_ids, ix in zip(output_ids, last_token_ix.tolist()):
    print(tokenizer.decode(row_ids[ix], skip_special_tokens=True))

### Test training

In [10]:
"""
First, load some validation data as a dataloader. This will be needed for evaluation during model training later. 
- The function `load_shard_as_dataloader` loads JSON data shards, concatenates them with a separator in between, then splits them by a given seq length.
- We do not load the dataloader for training data yet; the total size is too large, so instead these will be loaded later as needed in the training loop.
""" 
from helpers.dataset import load_shard_as_dataloader

# Validation dataloader, built once and reused by every evaluation pass below.
# Batches hold 32 sequences of 2048 tokens; the tokenizer's EOS id is passed as the separator id.
val_dl = load_shard_as_dataloader(
    './../../data/val_shard.json', tokenizer,
    seq_len = 2048,
    batch_size = 32,
    eos_seperator_id = tokenizer.eos_token_id,
)

In [None]:
""" 
Now, define a function for calculating validation metrics. This will be later used in the training loop.
- An untrained model should typically return validation loss of ~10 for the base cross-entropy loss.
"""
@torch.no_grad()
def get_val_stats(model, val_dl, device = None, moe_method = 'forward_slow'):
    """
    Get eval set metrics: the mean total, base (cross-entropy) and aux (load balancing) loss over all batches of `val_dl`.

    Params:
        @model: The model to evaluate. Called as `model(input_ids, attention_mask, moe_method = ..., use_checkpointing = False)`
          and expected to return a dict holding scalar tensors under 'loss', 'base_loss' and 'aux_loss'.
        @val_dl: An iterable of batches, each a dict with 'input_ids' and 'attention_mask' tensors.
        @device: Device the batches are moved to before the forward pass. Defaults to the notebook-level `main_device`.
        @moe_method: MoE forward implementation to evaluate with.

    Returns:
        A dict with keys 'loss', 'base_loss' and 'aux_loss', each a Python float averaged over batches (unweighted).

    Raises:
        ValueError: If `val_dl` yields no batches.
    """
    if device is None:
        device = main_device

    # Remember the caller's mode so evaluation does not silently flip an eval-mode model into train mode.
    was_training = model.training
    model.eval()

    loss_sums = {'loss': 0.0, 'base_loss': 0.0, 'aux_loss': 0.0}
    val_steps = 0

    try:
        for val_batch in val_dl:
            val_input_ids = val_batch['input_ids'].to(device)
            val_attn_mask = val_batch['attention_mask'].to(device)

            test_outputs = model(val_input_ids, val_attn_mask, moe_method = moe_method, use_checkpointing = False)

            # .item() already copies the scalar to the host; no explicit detach()/cpu() needed under no_grad.
            for key in loss_sums:
                loss_sums[key] += test_outputs[key].item()

            val_steps += 1
    finally:
        # Restore the original mode even if a batch fails part-way through.
        model.train(was_training)

    if val_steps == 0:
        raise ValueError("val_dl yielded no batches; cannot compute validation metrics.")

    return {key: total / val_steps for key, total in loss_sums.items()}

# Sanity check on the freshly initialized model: as the cell's last expression, the returned metrics dict is displayed.
get_val_stats(model, val_dl)

In [None]:
"""
Setup a Wandb run for logging.
"""
# Read the W&B API key from the local secrets file into the environment, authenticate with it,
# then open the run that the training loop logs to. The model config is stored with the run.
load_dotenv('./../../secrets.env')
wandb.login(key = os.environ.get('WANDB_API_KEY'))
run = wandb.init(
    config = asdict(conf),
    project = 'interpretable-moes', 
    name = 'test-5-slowbatch',
    notes = '',
)

In [None]:
"""
Let's train the model. The training loop will loop through training data shards. Each shard will be loaded and concatenated into chunks of a specified size.
"""

# Set training constants. The batch size will be equal to micro_batch_size * accumulation_steps.
lr = 5e-4 # Peak learning rate, reached at the end of warmup
min_lr = 5e-5 # Floor of the cosine schedule (eta_min)
warmup_steps = 100 # Length of the linear warmup; also the milestone at which the cosine schedule takes over
decay_steps = 10000 # T_max of the cosine schedule
max_grad_norm = 1.0 # Clipping threshold for parameters whose name does not contain 'expert'
max_expert_grad_norm = 2.0 # Clipping threshold for parameters whose name contains 'expert'
micro_batch_size = 32 # Sequences per forward/backward pass
accumulation_steps = 8 # Micro-batches accumulated per optimizer step
seq_len = 2048 # Sequence length passed to the shard dataloader

# Initialize optimizer/scheduler. The scheduler combines a warmup + cosine annealing.
# The training loop calls scheduler.step() once per optimizer step, so warmup_steps/decay_steps count optimizer steps.
# NOTE: both child schedulers are attached to the same optimizer; keep their construction order as is.
optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers = [
        torch.optim.lr_scheduler.LinearLR(optimizer, start_factor = 0.2, end_factor = 1.0, total_iters = warmup_steps), # Warmup: 0.2 * lr -> lr
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, decay_steps, eta_min = min_lr, last_epoch = -1) # Cosine annealing: lr -> min_lr
    ],
    milestones = [warmup_steps]
)

# Look for all training data files; sorting makes the shard order deterministic (lexicographic by path)
shard_files = sorted(glob.glob("./../../data/train_shard_*.json"))
print(f"Found {len(shard_files)} shards.")

# Initialize step count
step = 0 # Number of optimizer steps taken so far
total_tokens_trained = 0 # Running sum of attention_mask entries over all trained batches
model.train()

# The parameter split used for gradient clipping is fixed for the whole run, so build it once instead of on every step.
shared_params = [p for n,p in model.named_parameters() if 'expert' not in n]
expert_params = [p for n,p in model.named_parameters() if 'expert' in n]

# torch.save does not create missing directories; make sure the checkpoint folder exists before any compute is spent.
os.makedirs('saves', exist_ok = True)

for shard_idx, shard_path in enumerate(shard_files):

    print(f"\n=== Loading shard {shard_path} (index {shard_idx}) ===")
    shard_dl = load_shard_as_dataloader(shard_path, tokenizer, batch_size = micro_batch_size * accumulation_steps, seq_len = seq_len, eos_seperator_id = tokenizer.eos_token_id)

    for batch_idx, batch in enumerate(shard_dl):

        # Drop a trailing partial batch before paying for the host-to-device copy.
        if batch['input_ids'].size(0) < (accumulation_steps * micro_batch_size):
            print(f"Skipping leftover batch, need at least {accumulation_steps * micro_batch_size}")
            continue

        # ====================== SPLIT BATCH INTO MICRO-BATCHES ======================
        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)

        sub_input_ids = input_ids.split(micro_batch_size, dim = 0) 
        sub_attn_mask = attention_mask.split(micro_batch_size, dim = 0)

        # ====================== ZERO GRAD ONCE PER "BIG BATCH" ======================
        optimizer.zero_grad()
        
        # We'll track times and losses across micro-batches
        total_fwd_time = 0.0
        total_bwd_time = 0.0
        total_loss_val = 0.0
        start_batch = time.time()

        # Gradient stats and expert usage are only logged every 10 steps, so only collect usage on those steps.
        log_expensive = (step % 10 == 0)

        # We'll keep a list of dictionaries, one per layer, each mapping expert_id -> usage_count
        usage_accum = [defaultdict(int) for _ in range(model.conf.n_layers)]

        # ====================== MICRO-BATCH LOOP ======================
        for i in range(accumulation_steps):

            mb_input_ids = sub_input_ids[i]
            mb_attn_mask = sub_attn_mask[i]

            # ---------------------- Forward ----------------------
            start_fwd = time.time()
            outputs = model(mb_input_ids, mb_attn_mask, moe_method = 'forward_slow', use_checkpointing = True)
            loss = outputs['loss']
            fwd_time = time.time() - start_fwd
            total_fwd_time += fwd_time

            # ---------------------- Collect Expert Usage for This Micro-Batch ----------------------
            if log_expensive:
                with torch.no_grad():
                    for layer_idx, router_logits in enumerate(outputs["all_router_logits"]):
                        _, selected_experts = torch.topk(router_logits, model.conf.top_k, dim=-1)
                        # One unique-with-counts call per layer instead of one GPU sync per expert.
                        expert_ids, expert_counts = selected_experts.flatten().unique(return_counts = True)
                        for ex_id, ex_count in zip(expert_ids.tolist(), expert_counts.tolist()):
                            usage_accum[layer_idx][ex_id] += ex_count

            # ---------------------- Backward ----------------------
            # Divide by accumulation_steps so total gradient matches "big batch" size
            scaled_loss = loss / accumulation_steps
            start_bwd = time.time()
            scaled_loss.backward()
            bwd_time = time.time() - start_bwd
            total_bwd_time += bwd_time

            total_loss_val += loss.item()

        # ====================== GRAD CLIPPING & OPT STEP ======================
        # Shared and expert parameters are clipped separately, each against its own threshold.
        torch.nn.utils.clip_grad_norm_(shared_params, max_grad_norm)
        torch.nn.utils.clip_grad_norm_(expert_params, max_expert_grad_norm)

        optimizer.step()
        scheduler.step()

        # ============== METRICS ==============
        avg_loss = total_loss_val / accumulation_steps # Take the average loss over micro-batches. total_loss_val is the sum of 'loss.item()'.
        total_tokens_trained += attention_mask.sum().item()
        metrics = {
            'step': step,
            'shard_idx': shard_idx,
            'batch_size': input_ids.shape[0],
            'total_tokens_trained': total_tokens_trained,
            'lr': optimizer.param_groups[0]['lr'], # Read after scheduler.step(), i.e. the value the next step will use
            'train': {
                'loss': avg_loss,
                'base_loss': outputs['base_loss'].item(), # From last microbatch only
                'aux_loss':  outputs['aux_loss'].item() # From last microbatch only
            },
            'fwd_time':  total_fwd_time,
            'bwd_time':  total_bwd_time,
            'batch_time':  time.time() - start_batch
        }

        # ============== EXPENSIVE METRICS (EVERY 10 STEPS) ==============
        if log_expensive:
            
            # Gradients are still populated here (they are only zeroed at the start of the next batch) and are post-clipping.
            metrics['gradients'] = get_gradient_stats(model)

            # Convert expert-usage list of defaultdicts (usage_accum) into a more standard dict for logging
            metrics['expert_usage'] = {layer_idx: dict(ex_dict) for layer_idx, ex_dict in enumerate(usage_accum)}

        # ============== EXTRA EXPENSIVE METRICS (EVERY 1000 STEPS) ==============
        if step % 1000 == 0:
            metrics['val'] = get_val_stats(model, val_dl)

        # ============== SAVE (EVERY 5000 STEPS) ==============
        if step % 5000 == 0:
            torch.save(
                {
                    'model_state_dict': model.state_dict(), 
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'total_tokens_trained': total_tokens_trained,
                    'step': step + 1, # The step to resume from; the filename carries the step just completed
                },
                f"saves/checkpoint_{step:08d}.pt"
            )

        # ============== LOG TO W&B ==============
        wandb.log(metrics)

        # ============== PRINT ==============
        if step <= 10 or (step <= 100 and step % 10 == 0) or (step > 100 and step % 100 == 0):
            print(f"Step {step}: avg_loss={metrics['train']['loss']:.4f} "
                f"| fwd_time={metrics['fwd_time']:.2f}s | bwd_time={metrics['bwd_time']:.2f}s | batch_time = {metrics['batch_time']:.2f} "
                f"| lr={metrics['lr']:.1e}"
            )
            
        step += 1

wandb.finish()

In [None]:
from helpers.memory import clear_all_cuda_memory
# NOTE(review): both helpers live in helpers.memory and are not visible here — presumably the first releases the CUDA
# memory held by this session and the second reports what is still in use; confirm against the helper module.
clear_all_cuda_memory()
check_memory()