In [14]:
# Imports
import torch
from torch import nn
import torch.nn.functional as F
from dotenv import load_dotenv
import wandb
import math
from helpers.memory import check_memory, profile_memory
from helpers.logging import get_gradient_stats
from helpers.tree import enumerate_paths
from dataclasses import dataclass, asdict
import time
from collections import defaultdict
import os
import glob 
import json
from datetime import datetime


# Run-wide settings referenced by later cells:
# - main_device: device the model (dense layers, routers, experts) is placed on
# - seed: passed to torch.manual_seed before the model is built
main_device, seed = "cuda:0", 1234
check_memory()  # print current GPU memory usage (helper from helpers.memory)

Device 0: NVIDIA A100 80GB PCIe
  Allocated: 1.45 GB
  Reserved: 1.53 GB
  Total: 79.25 GB



### Set model configuration settings

In [15]:
"""
Create a conf with configurable model settings.
- These will be passed into the model class during model initialization, so add new confs needed for whatever architecture is used.
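- Each default is annotated with the corresponding base OlMoE-7B value (7B total params/1B active/64 experts). Note that this notebook
   replaces OlMoE's flat expert layer with a hierarchical tree of experts plus a shared expert (and defaults top_k to 2), so the defaults do NOT reproduce OlMoE exactly.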
"""

@dataclass
class ModelConf:
    """
    General config settings for this hierarchical MoE.
    - Field order is part of the interface (fields can be passed positionally), so append new fields at the end.
    - `__post_init__` rejects settings the model classes in this notebook cannot run with, so mistakes
      surface when the config is built instead of deep inside a forward pass.
    """
    vocab_size: int = 50304 # Base OlMoE: 50304 (vocab size)
    D: int = 2048 # Base OlMoE: 2048 (hidden state dimension)
    H: int = 16 # Base OlMoE: 16 (number of attention heads); D must be divisible by H
    I: int = 1024 # Base OlMoE: 1024 (expert MLP dimension; tree experts on level l use I / n_branches_tree**l)

    top_k: int = 2 # Number of root-to-leaf tree paths selected per token (Base OlMoE: top-8 of 64 flat experts)
    norm_topk_prob: bool = False # Base OlMoE: false (normalize top-k weights to sum to 1); not used by the hierarchical OlmoeMoe
    padding_idx: int = 1 # Base OlMoE: 1 (index where padding gets mapped to)
    n_layers: int = 16 # Base OlMoE: 16 (transformer layers)
    rms_norm_eps: float = 1e-05 # Base OlMoE: 1e-05
    rope_theta: float = 10000.0 # Base OlMoE: 10000.0 (RoPE base frequency)
    max_position_embeddings: int = 4096 # Base OlMoE: 4096 (not read by the simplified RoPE in this notebook)
    attn_method: str = 'fa2' # In OlMoE this is chosen automatically, here we explicitly pass it - choose 'normal', 'sdpa', or 'fa2'

    # Hierarchical MoE settings
    n_layers_tree: int = 6 # Number of routing levels in the tree, at least 1
    n_branches_tree: int = 2 # Branches per router, at least 1 (1 gives a degenerate chain, useful for debugging)

    def __post_init__(self):
        """Validate the settings; raises ValueError on a configuration the model cannot run with."""
        if self.H <= 0 or self.D % self.H != 0:
            raise ValueError(f"D ({self.D}) must be divisible by the number of heads H ({self.H}).")
        if (self.D // self.H) % 2 != 0:
            # RoPE rotates channel pairs, so the per-head dimension must be even.
            raise ValueError(f"The per-head dimension D / H = {self.D // self.H} must be even for RoPE.")
        if self.attn_method not in ('normal', 'sdpa', 'fa2'):
            raise ValueError(f"attn_method must be 'normal', 'sdpa', or 'fa2', got {self.attn_method!r}.")
        if self.n_layers_tree < 1:
            raise ValueError(f"n_layers_tree must be at least 1, got {self.n_layers_tree}.")
        if self.n_branches_tree < 1:
            raise ValueError(f"n_branches_tree must be at least 1, got {self.n_branches_tree}.")
        n_paths = self.n_branches_tree ** self.n_layers_tree
        if not 1 <= self.top_k <= n_paths:
            raise ValueError(f"top_k must be between 1 and the number of tree paths ({n_paths}), got {self.top_k}.")

conf = ModelConf(
    D = 768,
    H = 8,
    I = int(768 * 4),
    n_layers_tree = 4,
    n_branches_tree = 3,
    max_position_embeddings = 2048,
)

In [16]:
# Build the routing-path lookup tables once and keep them on the main device.
# presumably one row per root-to-leaf path of the tree -- see how OlmoeMoe.forward indexes them.
n_layers_tree, n_branches_tree = conf.n_layers_tree, conf.n_branches_tree
node_paths, branch_paths, expert_indices = (
    path_table.to(main_device)
    for path_table in enumerate_paths(n_layers_tree, n_branches_tree)
)


### Helper funs

In [17]:
""" 
This is a collection of helper functions used by the model layers to run forward/backward passes correctly.
- `_prepare_4d_causal_attention_mask_with_cache_position` is used to create the upper-triangular mask of large negative values for attention (not used by flash attention).
- `load_balancing_loss_func` is the usual load balancing function (its call in `OlmoeModel.forward` is currently commented out).
- Add any new functions here if needed, but most experiments won't need to touch this section.
"""

# Create the upper-trangular matrix of infinities to mask future tokens in the attention softmax (needed for SDPA + normal attention)
# Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L1099C1-L1152 
def _prepare_4d_causal_attention_mask_with_cache_position(attention_mask: torch.Tensor, sequence_length: int, target_length: int, dtype: torch.dtype, device: torch.device, cache_position: torch.Tensor, batch_size: int):
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        min_dtype = torch.finfo(dtype).min
        causal_mask = torch.full((sequence_length, target_length), fill_value = min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )
    return causal_mask

# Auxiliary load-balancing loss over a flat set of experts.
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py
def load_balancing_loss_func(gate_logits, num_experts, top_k, attention_mask):
    """
    Params:
        gate_logits: sequence of per-layer router logits, each (tokens, num_experts); they are concatenated along dim 0.
        num_experts: number of experts (size of the last logit dimension).
        top_k: number of experts counted as selected for each token.
        attention_mask: None, or a (batch, seq) mask; tokens where it is 0 are excluded from both averages.
    Returns:
        num_experts * sum over experts of (fraction of top-k assignments) * (mean router probability).
    NOTE(review): OlmoeMoe returns 3D (tokens, n_routers, n_branches_tree) router logits, so they would need
    to be converted before being passed here.
    """
    device = gate_logits[0].device
    logits = torch.cat([g.to(device) for g in gate_logits], dim=0)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    chosen = torch.topk(probs, top_k, dim=-1)[1]
    assignment = torch.nn.functional.one_hot(chosen, num_experts)  # (tokens, top_k, num_experts)

    if attention_mask is None:
        frac_tokens = assignment.float().mean(dim=0)
        mean_prob = probs.mean(dim=0)
    else:
        batch_size, seq_len = attention_mask.shape
        n_layers = logits.shape[0] // (batch_size * seq_len)
        # Broadcast the padding mask to the shape of `assignment` (one copy per layer).
        token_mask = (
            attention_mask[None, :, :, None, None]
            .expand((n_layers, batch_size, seq_len, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(device)
        )
        frac_tokens = (assignment.float() * token_mask).sum(dim=0) / token_mask.sum(dim=0)
        # ... and to the shape of `probs`.
        prob_mask = (
            attention_mask[None, :, :, None]
            .expand((n_layers, batch_size, seq_len, num_experts))
            .reshape(-1, num_experts)
            .to(device)
        )
        mean_prob = (probs * prob_mask).sum(dim=0) / prob_mask.sum(dim=0)

    return (frac_tokens * mean_prob.unsqueeze(0)).sum() * num_experts

### Define model layers

In [18]:
""" 
First let's define the RMSNorm, ROPE, and self-attention layers.
- These are basically taken straight from the OlMoE source code, but heavily simplified/cleaned up.
- Note that RMSNorm is the ONLY norm type we define (same as OlMoE).
- These layers generally do not need to be modified for MoE experiments.
"""
from transformers.modeling_flash_attention_utils import _flash_attention_forward # Flash attention forward

class OlmoeRMSNorm(nn.Module):
    """
    Root-mean-square layer norm (no mean centering, no bias), following the HF OlMoE implementation:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L137-L154

    This is the only norm type in the model. Each transformer layer uses it four times (query norm,
    key norm, pre-attention norm, pre-MoE norm), and it is used once more before the LM head.
    """
    def __init__(self, D, eps = 1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(D))  # learnable per-channel scale, initialised to 1
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Normalise in float32, cast back to the input dtype, then apply the learned scale.
        original_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        inv_rms = torch.rsqrt(x32.pow(2).mean(-1, keepdim = True) + self.variance_epsilon)
        normalized = (x32 * inv_rms).to(original_dtype)
        return self.weight * normalized

class OlmoeRotaryEmbedding(nn.Module):
    """
    Produce the cos/sin tables for rotary position embeddings (RoPE).
    - Simplified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L161-L219
    - Only the default (unscaled) RoPE variant is implemented; no dynamic scaling.
    """
    def __init__(self, conf: ModelConf):
        super().__init__()
        head_dim = int(conf.D / conf.H)
        exponents = torch.arange(0, head_dim, 2, dtype = torch.int64).float() / head_dim
        # One inverse frequency per pair of head channels; excluded from the state dict.
        self.register_buffer("inv_freq", 1.0 / (conf.rope_theta ** exponents), persistent = False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Params:
            x: only its device and dtype are used.
            position_ids: (batch, seq_len) positions.
        Returns:
            (cos, sin), each (batch, seq_len, 2 * len(inv_freq)) -- i.e. head_dim for an even head_dim -- cast to x.dtype.
        """
        n_batch = position_ids.shape[0]
        freq_col = self.inv_freq[None, :, None].float().expand(n_batch, -1, 1)  # (batch, len(inv_freq), 1)
        pos_row = position_ids[:, None, :].float()  # (batch, 1, seq_len)
        # Keep the angle computation in float32 by disabling autocast
        # (see https://github.com/huggingface/transformers/pull/29285); MPS uses the "cpu" autocast type.
        autocast_type = x.device.type
        if not (isinstance(autocast_type, str) and autocast_type != "mps"):
            autocast_type = "cpu"
        with torch.autocast(device_type = autocast_type, enabled = False):
            angles = (freq_col.float() @ pos_row.float()).transpose(1, 2)  # (batch, seq_len, len(inv_freq))
            table = torch.cat((angles, angles), dim = -1)
            cos, sin = table.cos(), table.sin()
        return cos.to(dtype = x.dtype), sin.to(dtype = x.dtype)

class OlmoeAttention(nn.Module):
    """
    Multi-head self-attention with query/key RMSNorm and rotary position embeddings.
    - Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L288-L391
    - Handles plain ('normal'), PyTorch SDPA ('sdpa') and flash-attention-2 ('fa2') attention within this one class
    - No GQA support (OlMoE doesn't use it anyway)
    """
    _ATTN_METHODS = ('normal', 'sdpa', 'fa2')

    def __init__(self, conf: ModelConf):
        super().__init__()
        # Fail at construction time rather than with an unbound variable inside forward.
        if conf.attn_method not in self._ATTN_METHODS:
            raise ValueError(f"Unknown attn_method {conf.attn_method!r}; expected one of {self._ATTN_METHODS}.")
        if conf.D % conf.H != 0:
            raise ValueError(f"Hidden size D={conf.D} must be divisible by the number of heads H={conf.H}.")
        self.attn_method = conf.attn_method
        self.D = conf.D # Hidden state dim
        self.H = conf.H # Num of attention heads
        self.Dh = conf.D // conf.H # Dimensions per head

        # Initialize attention layers - no biases following OlMoE architecture
        self.q_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.k_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.v_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.o_proj = nn.Linear(self.D, self.D, bias = False)
        # The query/key norms act on the full D-wide projections, before the split into heads.
        self.q_norm = OlmoeRMSNorm(self.D, eps = conf.rms_norm_eps)
        self.k_norm = OlmoeRMSNorm(self.D, eps = conf.rms_norm_eps)

    # Taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L223-L255
    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim = 1):
        """Rotate q and k by the RoPE angles; cos/sin get a head axis inserted at `unsqueeze_dim` to broadcast over heads."""
        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.LongTensor, position_embeddings: tuple[torch.Tensor, torch.Tensor]):
        """
        Params:
            hidden_state: (B, N, D) input.
            attention_mask: for 'normal'/'sdpa', an additive mask broadcastable to (B, H, N, N) -- OlmoeModel passes the
                (B, 1, N, N) causal mask from `_prepare_4d_causal_attention_mask_with_cache_position`; for 'fa2', the raw
                (B, N) padding mask or None.
            position_ids: unused; kept for interface compatibility.
            position_embeddings: (cos, sin) tables from OlmoeRotaryEmbedding.
        Returns:
            (B, N, D) attention output after the output projection.
        """
        B, N , D = hidden_state.shape

        # Project (+ norm for q/k), then split heads: (B, N, D) -> (B, H, N, Dh)
        query_state = self.q_norm(self.q_proj(hidden_state)).view(B, N, self.H, self.Dh).transpose(1, 2)
        key_state = self.k_norm(self.k_proj(hidden_state)).view(B, N, self.H, self.Dh).transpose(1, 2)
        value_state = self.v_proj(hidden_state).view(B, N, self.H, self.Dh).transpose(1, 2)

        cos, sin = position_embeddings
        query_state, key_state = self.apply_rotary_pos_emb(query_state, key_state, cos, sin)

        if self.attn_method == 'normal':
            attn_weights = torch.matmul(query_state, key_state.transpose(2, 3))/math.sqrt(self.Dh)  # B x H x N x N
            attn_weights = attn_weights + attention_mask # Additive mask: large negative values on blocked positions (required here)
            attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(query_state.dtype)
            attn_output = torch.matmul(attn_weights, value_state) # B x H x N x Dh
            attn_output = attn_output.transpose(1, 2).contiguous() # Reorder into B x N x H x Dh
            attn_output = attn_output.reshape(B, N, D) # Concatenate heads back into B x N x D

        elif self.attn_method == 'sdpa':
            # SDPA raises an error when both attn_mask and is_causal are set. OlmoeModel's 4D mask already
            # encodes causality, so only request the built-in causal mask when no explicit mask is given.
            attn_output = F.scaled_dot_product_attention(
                query_state, key_state, value_state,
                attn_mask = attention_mask, dropout_p = 0.0, is_causal = attention_mask is None and N > 1
            )
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(B, N, D)

        elif self.attn_method == 'fa2':
            # Flash attention expects B x N x H x Dh inputs.
            query_state = query_state.transpose(1, 2)
            key_state = key_state.transpose(1, 2)
            value_state = value_state.transpose(1, 2)
            attn_output = _flash_attention_forward(
                query_state, key_state, value_state,
                attention_mask, N, dropout = 0.0, use_top_left_mask = False, is_causal = True
            )
            attn_output = attn_output.reshape(B, N, D).contiguous()

        else:
            raise ValueError(f"Unknown attn_method {self.attn_method!r}; expected one of {self._ATTN_METHODS}.")

        attn_output = self.o_proj(attn_output)
        return attn_output

In [19]:
""" 
Now let's define the MLP layer and the MoE layer.
- The MLP layer is simple; modify as needed.
- However, the MoE layer is much more complex, and this layer will probably need to be modified heavily for most experiments.
  - This notebook's `OlmoeMoe` has a single `forward` method. It routes each token down a tree of routers
    (`n_layers_tree` levels, `n_branches_tree` branches per router), runs the experts on the selected paths, and adds a shared expert.
    - It loops over the experts in Python, which keeps it easy to modify but will likely get slower as the number of experts grows.
    - For initial testing, modify this `forward` directly; consider a batched/asynchronous version only for large-scale training runs.
  - `forward` must return a tuple `(output, router_logits, token_expert_weights)`, where output is the B x N x D MoE layer output.
    - To return more, you'll need to also modify the transformer layer class in the next section.
"""
from transformers.activations import silu

class OlmoeMLP(nn.Module):
    """
    SwiGLU-style feed-forward network used for every routed expert and for the shared expert.
    - Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L258-L272
    - `reduction_factor` shrinks the intermediate width to max(1, int(conf.I / reduction_factor)).
    """
    def __init__(self, conf: ModelConf, reduction_factor: int = 1):
        super().__init__()
        width = max(1, int(conf.I / reduction_factor))
        self.conf = conf
        self.D = conf.D
        self.I = width
        # Creation order (gate, up, down) is kept so parameter order / RNG consumption is unchanged.
        self.gate_proj = nn.Linear(self.D, width, bias = False)
        self.up_proj = nn.Linear(self.D, width, bias = False)
        self.down_proj = nn.Linear(width, self.D, bias = False)
        self.act_fn = silu

    def forward(self, x):
        """Map (..., D) -> (..., D) as down_proj(silu(gate_proj(x)) * up_proj(x))."""
        gated = self.act_fn(self.gate_proj(x))
        return self.down_proj(gated * self.up_proj(x))


In [20]:
# Complete routing tree: every internal node is a router with `n_branches_tree` children and every non-root node
# is an expert. Expert indices run level by level (0..b-1 are level 1, the next b^2 are level 2, ...), so an
# expert's level -- and therefore its size -- follows from its index alone.
class OlmoeMoe(nn.Module):
    """
    Hierarchical (tree-routed) mixture-of-experts layer with an always-on shared expert.
    - `n_layers_tree` = L routing levels, `n_branches_tree` = b branches per router.
    - Routers: 1 + b + ... + b^(L-1); experts: b + b^2 + ... + b^L.
    - An expert on level l has intermediate width max(1, int(conf.I / b^l)); the shared expert uses the full conf.I.
    """

    def __init__(self, conf):
        super(OlmoeMoe, self).__init__()

        self.n_branches_tree = conf.n_branches_tree
        self.n_layers_tree = conf.n_layers_tree

        # Exact integer counts (the closed form (b^L - 1) / (b - 1) divides by zero when b == 1).
        self.n_routers, self.n_experts = self._tree_sizes(self.n_branches_tree, self.n_layers_tree)

        self.experts = nn.ModuleList([OlmoeMLP(conf, reduction_factor = self.get_expert_reduction_factor(expert_idx)) for expert_idx in range(self.n_experts)])

        # All routers flattened into one projection; forward reshapes its output to (tokens, n_routers, n_branches_tree).
        self.routers = nn.Linear(conf.D, self.n_branches_tree * self.n_routers, bias=False)

        self.n_all_paths = self.n_branches_tree ** self.n_layers_tree

        # Module-level path tables from enumerate_paths; forward uses the tensors passed in as arguments instead.
        self.node_paths = node_paths
        self.branch_paths = branch_paths
        self.top_k = conf.top_k

        self.shared_expert = OlmoeMLP(conf)

        self.to(main_device)

    @staticmethod
    def _tree_sizes(n_branches: int, n_layers: int) -> tuple[int, int]:
        """Return (n_routers, n_experts) of a complete tree, using exact integer sums (valid for n_branches == 1 too)."""
        n_routers = sum(n_branches ** level for level in range(n_layers))
        n_experts = sum(n_branches ** level for level in range(1, n_layers + 1))
        return n_routers, n_experts

    def get_expert_reduction_factor(self, index):
        """
        Return n_branches_tree ** level, where level (1-based) is the tree level holding expert `index`.
        The level is found by walking level sizes with exact integers; a float-log formulation misplaces
        indices on a level boundary (e.g. math.log(243, 3) evaluates to 4.999...).
        """
        b = self.n_branches_tree
        if b == 1:
            return 1
        level_start, level_size = 0, b  # level 1 covers indices [0, b)
        while index >= level_start + level_size:
            level_start += level_size
            level_size *= b
        return level_size

    def forward(self, hidden_state: torch.Tensor, node_paths, branch_paths, expert_indices) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Route each token down the tree, run the experts on its selected paths, and add the shared expert.

        Params:
            hidden_state: (B, N, D) token representations.
            node_paths, branch_paths: index tensors such that routing_weights[:, node_paths, branch_paths] is
                (BN, n_all_paths, n_layers_tree) -- presumably the router node and chosen branch at each level of
                every root-to-leaf path (from helpers.tree.enumerate_paths; confirm there).
            expert_indices: expert visited at each level of each path; expected (n_all_paths, n_layers_tree)
                given how it is indexed below -- confirm against enumerate_paths.
        Returns:
            aggregated_out: (B, N, D) weighted routed-expert output plus the shared-expert output.
            router_logits: (BN, n_routers, n_branches_tree) raw router scores.
            token_expert_weights: (BN, n_experts) weight applied to each expert per token (0 = expert not used).
        """
        B, N, D = hidden_state.shape
        # Flatten B x N tokens into BN tokens.
        hidden_state_flat = hidden_state.reshape(B * N, D)  # (BN, D)

        # 1) Routing: a softmax per router, then the probability of every root-to-leaf path.
        router_logits = self.routers(hidden_state_flat).reshape(B * N, self.n_routers, self.n_branches_tree)  # (BN, n_routers, n_branches_tree)
        routing_weights = F.softmax(router_logits, dim=-1)  # (BN, n_routers, n_branches_tree)
        routing_weights = routing_weights[:, node_paths, branch_paths]  # (BN, n_all_paths, n_layers_tree): branch prob at each level
        routing_weights_finalprob = routing_weights.prod(dim=-1)  # (BN, n_all_paths): full path probability
        routing_weights_probs = routing_weights.cumprod(dim=-1)  # (BN, n_all_paths, n_layers_tree): prob of reaching each node
        _, topk_indices = torch.topk(routing_weights_finalprob, self.top_k, dim=-1)  # (BN, top_k)
        selected_experts = expert_indices[topk_indices]  # (BN, top_k, n_layers_tree)
        routing_weights_probs = torch.gather(input=routing_weights_probs, dim=1, index=topk_indices.unsqueeze(-1).expand(-1, -1, self.n_layers_tree))  # (BN, top_k, n_layers_tree)

        # Scatter the per-node probabilities into a dense (BN, n_experts) weight table.
        # NOTE(review): an expert lying on several selected paths (e.g. a shared ancestor) receives the sum of
        # its weight from each of those paths -- confirm this is intended.
        one_hot = F.one_hot(selected_experts, num_classes=self.n_experts).to(hidden_state_flat.device)  # (BN, top_k, n_layers_tree, n_experts)
        weighted_one_hot = one_hot * routing_weights_probs.unsqueeze(-1)  # probs broadcast as (BN, top_k, n_layers_tree, 1)
        token_expert_weights = weighted_one_hot.sum(dim=(1, 2))  # (BN, n_experts)
        token_expert_weights = token_expert_weights.to(hidden_state_flat.dtype)
        # 2) use_lflb and norm_topk_prob are not used in this implementation.

        # 3) Dispatch tokens to experts and accumulate their weighted outputs.
        mlp_output = torch.zeros((B * N, D), dtype=hidden_state.dtype, device=hidden_state.device)
        for expert_ix, expert in enumerate(self.experts):
            token_mask = token_expert_weights[:, expert_ix] > 0
            if token_mask.sum() == 0:
                continue  # No tokens routed to this expert.
            token_indices = token_mask.nonzero(as_tuple=True)[0]  # indices of activated tokens

            # Experts may live on other devices (see OlmoeModel._move_experts_to_devices).
            expert_device = next(expert.parameters()).device
            tokens_for_expert = hidden_state_flat[token_indices, :].to(expert_device)  # (num_tokens, D)
            expert_output = expert(tokens_for_expert)  # (num_tokens, D)

            weights = token_expert_weights[token_indices, expert_ix].unsqueeze(1).to(expert_device)  # (num_tokens, 1)
            expert_output = expert_output * weights

            mlp_output.index_add_(0, token_indices, expert_output.to(mlp_output.device).to(hidden_state.dtype))

        aggregated_out = mlp_output.view(B, N, D)

        # 4) Add the shared expert output (applied to every token).
        shared_output = self.shared_expert(hidden_state)
        aggregated_out = aggregated_out + shared_output

        return aggregated_out, router_logits, token_expert_weights
            

        
    

In [21]:
# Smoke-test a single hierarchical MoE layer on random input and check the output shapes.
hmoe = OlmoeMoe(conf)

# Create a dummy hidden_state tensor with shape (B, N, D).
B = 2    # batch size
N = 10   # number of tokens per batch
hidden_state = torch.randn(B, N, conf.D).to(main_device)

# Run the forward pass.
with torch.no_grad():
    aggregated_out, router_logits, token_expert_weights = hmoe(hidden_state, node_paths, branch_paths, expert_indices)

# Print out the results.
print("Aggregated output shape:", aggregated_out.shape)
print("Router logits shape:", router_logits.shape)
print("Expert weights shape:", token_expert_weights.shape)

# Check the expected shapes instead of only printing them.
assert aggregated_out.shape == (B, N, conf.D)
assert router_logits.shape == (B * N, hmoe.n_routers, conf.n_branches_tree)
assert token_expert_weights.shape == (B * N, hmoe.n_experts)
# Weights are sums of products of softmax probabilities, so they can never be negative.
assert (token_expert_weights >= 0).all()

Aggregated output shape: torch.Size([2, 10, 768])
Router logits shape: torch.Size([20, 40, 3])
Expert weights shape: torch.Size([20, 120])


In [22]:
""" 
Now let's define the transformer block.
- Most likely, there is nothing to change here, unless you need to change the input/outputs from the MoE layer.
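- Note that this forward pass is wrapped in a `custom_forward` closure so that gradient checkpointing can be added; checkpointing is not currently applied (`OlmoeModel.forward` accepts `use_checkpointing` but does not use it).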
"""

class OlmoeBlock(nn.Module):
    """
    One pre-norm transformer layer:
    RMSNorm -> self-attention -> residual add, then RMSNorm -> hierarchical MoE -> residual add.
    `layer_idx` is accepted for interface compatibility but is not used.
    """
    def __init__(self, conf: ModelConf, layer_idx: int):
        super().__init__()
        self.D = conf.D
        self.self_attn = OlmoeAttention(conf = conf)
        # self.self_attn = torch.compile(self.self_attn) # attn layers can be compiled for speed
        self.moe = OlmoeMoe(conf)
        self.input_layernorm = OlmoeRMSNorm(conf.D, eps = conf.rms_norm_eps)
        self.post_attention_layernorm = OlmoeRMSNorm(conf.D, eps = conf.rms_norm_eps)

    def forward(
        self,
        hidden_state: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: torch.LongTensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor]
        ):
        """
        Returns (hidden_state, router_logits, token_expert_weights); the last two are the MoE layer's
        outputs, which OlmoeModel collects per layer.
        """
        def custom_forward(x, mask, pos_ids, pos_emb):
            # Kept as a closure so it can be handed to a gradient-checkpointing wrapper; called directly below.
            # Attention sub-layer (pre-norm) with residual connection.
            attn_out = self.self_attn(
                self.input_layernorm(x),
                attention_mask = mask,
                position_ids = pos_ids,
                position_embeddings = pos_emb,
            )
            after_attn = x + attn_out

            # MoE sub-layer (pre-norm) with residual connection; the path tables are module-level globals.
            moe_out, router_logits, expert_weights = self.moe(
                self.post_attention_layernorm(after_attn), node_paths, branch_paths, expert_indices
            )
            return after_attn + moe_out, router_logits, expert_weights

        return custom_forward(hidden_state, attention_mask, position_ids, position_embeddings)

In [23]:
""" 
Now define the top-level model.
- This class is initialized with the `ModelConf` config settings as well as a list of expert-device mappings (leave blank for single-GPU tests).
- After initialization, it creates all child layers and moves the experts to their correct devices. 
  - All other parameters stay on `primary_device`.
- Modify `_init_weights` to change the weight initialization scheme.
- The forward pass calls the children layers and also calculates the loss (standard cross-entropy + aux loss). 
"""
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting

class OlmoeModel(nn.Module):
    """
    The top level model object. Also handles weight initialization and loss calculations.
    """
    def __init__(self, conf: ModelConf, primary_device: str, expert_device_map: None|list[str] = None):
        """
        Params:
            @conf: A configuration object of class ModelConf.
            @primary_device: Device every layer is moved to after initialization (dense layers, routers, shared experts,
              and also the routed experts unless `expert_device_map` is given).
            @expert_device_map: Optional list with one device per expert of an MoE layer; the same mapping is applied in
              every transformer layer. If `None`, the experts stay on `primary_device`.
              For example, `expert_device_map = ['cuda:0', 'cuda:1', 'cuda:1', 'cuda:2']` means to store expert 0 on cuda:0, experts 1-2 on the device cuda:1, and expert 3 on cuda:2.
        Raises:
            ValueError: if `expert_device_map` does not have exactly one entry per expert.
        """
        super().__init__()
        self.conf = conf

        ### Layers ###
        self.embed_tokens = nn.Embedding(self.conf.vocab_size, self.conf.D, self.conf.padding_idx)
        self.rotary_emb = OlmoeRotaryEmbedding(conf = self.conf)
        self.layers = nn.ModuleList([OlmoeBlock(self.conf, layer_idx) for layer_idx in range(self.conf.n_layers)])
        self.norm = OlmoeRMSNorm(self.conf.D, eps = self.conf.rms_norm_eps)
        self.lm_head = nn.Linear(self.conf.D, self.conf.vocab_size, bias = False)

        ### Initialize weights ###
        self.apply(self._init_weights)

        ### Model ###
        self.to(primary_device)

        ### Experts ###
        if expert_device_map is not None:
            self._move_experts_to_devices(expert_device_map)

    # OlMoE weight initiation - see https://github.com/huggingface/transformers/blob/8f1509a96c96747c893051ac947795cfb0750357/src/transformers/modeling_utils.py#L2500-L2515
    # Normal(0, 0.02) for linear layers + embeddings; other modules (e.g. RMSNorm) keep their default initialization.
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean = 0.0, std = 0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean = 0.0, std = 0.02)
            # In the vocab -> embedding layer, set the embedding of the padding token (conf.padding_idx) to 0
            if module is self.embed_tokens:
                self.embed_tokens.weight.data[self.conf.padding_idx].zero_()

    def _move_experts_to_devices(self, expert_device_map: list[str]):
        """
        Move expert i of every layer's MoE to expert_device_map[i].
        Raises ValueError if the map does not have exactly one entry per expert.
        """
        b, L = self.conf.n_branches_tree, self.conf.n_layers_tree
        # b + b^2 + ... + b^L experts per MoE layer (exact integers; the closed form divides by zero when b == 1).
        n_experts = sum(b ** level for level in range(1, L + 1))
        if len(expert_device_map) != n_experts:
            raise ValueError(f"expert_device_map has length {len(expert_device_map)} but n_experts = {n_experts}.")

        for layer in self.layers:
            for ex_idx, expert in enumerate(layer.moe.experts):
                expert.to(expert_device_map[ex_idx])

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.Tensor, moe_method: str, use_lflb: bool = False, use_checkpointing : bool = False):
        """
        Params:
            @input_ids: A tensor of input IDs of size B x N, where B is the batch size and N is the sequence length.
            @attention_mask: An attention mask tensor of size B x N.
            @moe_method: Not used by this implementation (OlmoeMoe has a single forward); kept for interface compatibility.
            @use_lflb: Loss-free balancing flag; currently not used.
            @use_checkpointing: Gradient checkpointing flag; currently not used.
        Returns:
            A dict with 'all_router_logits' (per-layer tuple), 'all_topk_experts' (per-layer tuple of the MoE's
            token-expert weights), 'logits' (B x N x vocab_size), 'aux_loss' (currently always 0) and 'base_loss'.
        Raises:
            ValueError: if conf.attn_method is not 'normal', 'sdpa' or 'fa2'.
        """
        hidden_state = self.embed_tokens(input_ids)
        B, N, D = hidden_state.shape

        ### Prep rotary embeddings + attention masks  ###
        cache_position = torch.arange(0, N, device = hidden_state.device)
        position_ids = cache_position.unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_state, position_ids) # Position embeddings shared across transformer layers

        if self.conf.attn_method in ('normal', 'sdpa'):
            # Additive causal (+ padding) mask of large negative values.
            causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(attention_mask, N, N, hidden_state.dtype, hidden_state.device, cache_position, B)
        elif self.conf.attn_method == 'fa2':
            # Flash attention takes only the original padding mask, or None when nothing is padded.
            causal_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            raise ValueError(f"Unknown attn_method {self.conf.attn_method!r}; expected 'normal', 'sdpa', or 'fa2'.")

        ### Transformer layers ###
        all_router_logits = () # Router logits from each layer (for a load balancing loss)
        all_topk_experts = () # Token-expert weights from each layer

        for layer in self.layers:
            hidden_state, router_logits, topk_experts = layer(
                hidden_state,
                attention_mask = causal_mask,
                position_ids = position_ids,
                position_embeddings = position_embeddings
            )
            all_router_logits += (router_logits, )
            all_topk_experts += (topk_experts,)

        hidden_state = self.norm(hidden_state)
        output_logits = self.lm_head(hidden_state)

        ##### Calculate Loss #####
        # Labels are token IDs, with padding positions set to -100 so they are ignored by the loss
        label_ids = torch.where(input_ids == self.conf.padding_idx, -100, input_ids)
        # Get regular loss
        base_loss = ForCausalLMLoss(output_logits, label_ids, self.conf.vocab_size)
        # Load balancing loss is disabled: a zero scalar on the activations' device.
        aux_loss = torch.zeros((), device = hidden_state.device)
        # TODO: load_balancing_loss_func expects 2D (tokens, num_experts) logits, while OlmoeMoe returns
        # (tokens, n_routers, n_branches_tree); adapt before re-enabling.
        # aux_loss = load_balancing_loss_func(gate_logits = all_router_logits, num_experts = n_experts, top_k = self.conf.top_k, attention_mask = attention_mask)

        return {
            'all_router_logits': all_router_logits,
            'all_topk_experts': all_topk_experts,
            'logits': output_logits,
            'aux_loss': aux_loss,
            'base_loss': base_loss
        }

### Test the model

In [24]:
import os
# Log the reasons torch.compile recompiles a function.
# NOTE: torch parses the TORCH_LOGS environment variable when it is first imported (in the first cell), so assigning
# os.environ["TORCH_LOGS"] at this point does not affect this process; configure logging through the API instead.
# set_logs() is ignored while TORCH_LOGS is present in the environment, so drop any value left by an earlier version of this cell.
os.environ.pop("TORCH_LOGS", None)
torch._logging.set_logs(recompiles = True)
""" 
Let's load the model
- Set the default_device to specify where all the non-expert layers live (the experts are moved on model init)
- Set the default_dtype to specify the model dtype; all params will be in this dtype except for those explicitly specified differently in the class definition
  - In the default OlMoE, RMSNorm is required to be f32 whereas all other params are bf16. 
"""
# torch.set_default_device(main_device) # This is buggy, don't use
torch.set_default_dtype(torch.bfloat16)
torch.set_float32_matmul_precision('medium') # See https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html 
torch.manual_seed(seed)

# Total number of experts: sum_{l=1..L} b**l with b = conf.n_branches_tree and L = conf.n_layers_tree.
# This equals the closed form b * (b**L - 1) / (b - 1) but, unlike it, does not divide by zero when b == 1.
n_experts = int(sum(conf.n_branches_tree ** level for level in range(1, conf.n_layers_tree + 1)))

model = OlmoeModel(
    conf,
    primary_device = main_device, # Where to store dense layers and shared experts
    expert_device_map = [main_device] * n_experts # Here let's test them with all of them on the primary device
)
model = torch.compile(model)
# TODO: make the output in model.forward in moe.py a tensor of shape (BN, D)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
check_memory()

Total parameters: 679,121,664
Device 0: NVIDIA A100 80GB PCIe
  Allocated: 1.38 GB
  Reserved: 2.68 GB
  Total: 79.25 GB



In [25]:
"""
Smoke-test a single forward pass on a batch of 2 prompts to make sure the model runs end to end.
- If you have multiple working forward methods, this is a good place to check that their outputs agree.
"""
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)
prompt = ['I am a dog and I like to eat. My favorite food is', 'My cat is']
inputs = tokenizer(prompt, truncation = True, max_length = 128, padding = 'max_length', return_tensors = 'pt').to(main_device)

with torch.no_grad():
    # Alternative MoE forward paths - swap one in to compare against the others
    # output = model(inputs['input_ids'], inputs['attention_mask'], moe_method = 'forward_slow', use_lflb = True, use_checkpointing = False)
    # output = model(inputs['input_ids'], inputs['attention_mask'], moe_method = 'forward_fast', use_lflb = True, use_checkpointing = False)
    output = model(inputs['input_ids'], inputs['attention_mask'], moe_method = 'forward_async', use_lflb = True, use_checkpointing = False)

# Greedy prediction at every position: (batch, seq) token ids
output_ids = output['logits'].argmax(dim = 2)
# Position of the last non-padding token in each row (assumes right padding - TODO confirm tokenizer.padding_side)
last_token_ixs = (inputs["attention_mask"].sum(dim = -1) - 1).tolist()
for row_ix, last_ix in enumerate(last_token_ixs):
    print(tokenizer.decode(output_ids[row_ix, last_ix], skip_special_tokens=True))

W0212 15:25:24.730000 35197 torch/_dynamo/convert_frame.py:844] [26/8] torch._dynamo hit config.cache_size_limit (8)
W0212 15:25:24.730000 35197 torch/_dynamo/convert_frame.py:844] [26/8]    function: 'forward' (/tmp/ipykernel_35197/750249474.py:31)
W0212 15:25:24.730000 35197 torch/_dynamo/convert_frame.py:844] [26/8]    last reason: 26/0: tensor 'L['x']' size mismatch at index 0. expected 252, actual 4
W0212 15:25:24.730000 35197 torch/_dynamo/convert_frame.py:844] [26/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0212 15:25:24.730000 35197 torch/_dynamo/convert_frame.py:844] [26/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.


 audience
�


In [26]:
with torch.no_grad():
    # Time and measure peak memory of the reference 'forward_slow' MoE path, without LFLB or activation checkpointing
    print(
        profile_memory(
            model,
            input_ids = inputs['input_ids'],
            attention_mask = inputs['attention_mask'],
            moe_method = 'forward_slow',
            use_lflb = False,
            use_checkpointing = False,
        )
    )

{'runs': 10, 'average_time': '0.37910819s', 'average_peak_mem': '1532.0788MB', 'average_increase_mem_MB': '2.7723MB'}


### Test training

In [27]:
"""
Load the validation data as a dataloader; it is needed for evaluation inside the training loop later.
- `load_shard_as_dataloader` loads JSON data shards, concatenates them with an EOS separator in between, then splits them into chunks of the given seq length.
- The training dataloader is not built here: the full training set is too large, so training shards are loaded one at a time inside the training loop.
"""
from helpers.dataset import load_shard_as_dataloader

val_dl = load_shard_as_dataloader(
    './../../data/val_shard.json',
    tokenizer,
    seq_len = 2048,
    batch_size = 32,
    eos_seperator_id = tokenizer.eos_token_id,
)

In [28]:
""" 
Now, define a function for calculating validation metrics. This will be later used in the training loop.
- An untrained model should typically return a base cross-entropy validation loss of roughly ln(vocab_size), i.e. ~10.8 for a 50304-token vocab.
"""
@torch.no_grad()
def get_val_stats(model, val_dl, router_aux_loss_coef, device = None):
    """
    Get eval set metrics, averaged over all batches of `val_dl` (a mean of per-batch mean losses).

    Params:
        @model: The model to evaluate. It is called as
          `model(input_ids, attention_mask, moe_method = 'forward_slow', use_checkpointing = False)`
          and must return a dict with scalar tensors under 'base_loss' and 'aux_loss'.
        @val_dl: An iterable of batches, each a dict with 'input_ids' and 'attention_mask' tensors.
        @router_aux_loss_coef: Weight of the aux loss inside the combined 'loss' metric.
        @device: Device the batches are moved to; defaults to the notebook-level `main_device`.

    Returns:
        A dict with Python floats under 'loss', 'base_loss' and 'aux_loss'.

    Raises:
        ValueError: if `val_dl` yields no batches.

    The model is evaluated in eval mode, and its previous train/eval mode is restored afterwards
    (also when an exception is raised).
    """
    if device is None:
        device = main_device

    # Remember the caller's mode so we restore it exactly, instead of unconditionally switching to train mode
    was_training = model.training
    model.eval()
    try:
        val_loss_sum = 0.0
        val_base_sum = 0.0
        val_aux_sum = 0.0
        val_steps = 0

        for val_batch in val_dl:
            val_input_ids = val_batch['input_ids'].to(device)
            val_attn_mask = val_batch['attention_mask'].to(device)

            test_outputs = model(val_input_ids, val_attn_mask, moe_method = 'forward_slow', use_checkpointing = False)

            val_base_sum += test_outputs['base_loss'].detach().cpu().item()
            val_aux_sum  += test_outputs['aux_loss'].detach().cpu().item()
            # Combine in tensor space (same dtype promotion as the training loss) before converting to a float
            val_loss_sum += (test_outputs['base_loss'] + router_aux_loss_coef * test_outputs['aux_loss']).detach().cpu().item()

            val_steps += 1

        if val_steps == 0:
            raise ValueError("get_val_stats: val_dl yielded no batches; cannot compute validation metrics")

        return {
            "loss": val_loss_sum / val_steps,
            "base_loss": val_base_sum / val_steps,
            "aux_loss": val_aux_sum / val_steps
        }
    finally:
        model.train(was_training)

# Sanity check on the untrained model (aux coefficient 1e-2); the result is the cell's displayed output
get_val_stats(model, val_dl, router_aux_loss_coef = 1e-2)

{'loss': 10.99523800611496, 'base_loss': 10.99523800611496, 'aux_loss': 0.0}

In [29]:
""" 
Set training constants to be used for training later.
- The batch size will be equal to micro_batch_size * accumulation_steps.
"""
@dataclass
class TrainConf:
    """
    Training hyperparameters used by the training loop.
    - One optimizer/scheduler step is taken per big batch of micro_batch_size * accumulation_steps sequences,
      so warmup_steps and decay_steps count big batches.
    - All fields are logged to W&B together with the model conf (via `asdict`).
    """
    # NOTE(review): the model's forward currently hard-codes aux_loss = 0 (load_balancing_loss_func is commented out),
    # so this coefficient has no effect until the balancing loss is re-enabled.
    router_aux_loss_coef: float = 0.005  # Base OlMoE: 0.01 (relative weight of balancing loss)
    use_lflb: bool = False # Use loss-free load balancing
    bias_update_rate: float = .001 # Bias update rate for lflb
    # Peak LR: 5e-4 scaled by (64 * 8) / 256, then by 1.2 (about 1.2e-3). Warmup starts from 0.2x of this value.
    # NOTE(review): 64 * 8 = 512 does not match micro_batch_size * accumulation_steps (16 * 8 = 128) - confirm which batch size this scaling is meant to reflect.
    lr: float = 5e-4 * (64 * 8)/(256) * 1.2 # The starting LR (after warmup)
    min_lr: float = 5e-5 # The minimum LR
    warmup_steps: int = 500 # How long it takes to warmup to the starting LR
    decay_steps: int = 19500 # How long it takes to decay from the starting LR to the minimum LR
    max_grad_norm: float = 1.0 # Gradient clipping for non-expert grads
    max_expert_grad_norm: float = 1.0 # Gradient clipping for expert grads
    micro_batch_size: int = 16 # Size of a microbatch
    accumulation_steps: int = 8 # Number of microbatches within a batch
    seq_len: int = 2048 # The sequence length

train_conf = TrainConf()

In [30]:
"""
Set up a W&B run for logging - pick a run name and notes for the run below.
"""
RUN_NAME = 'test-11 -single-gpu -experts-16 -topk-3(+1) -forward-slow -lfbl'
RUN_NOTES = 'Baseline test with 16 experts (3+1), and LFBL. Memory savings with compile + fused AdamW. Slight param change from test-10 (higher peak LR + lower lfbl coef)'

# The API key is read from a secrets file rather than being hard-coded in the notebook
load_dotenv('./../../secrets.env')
wandb.login(key = os.getenv('WANDB_API_KEY'))

# The run config is the merge of both conf dataclasses (train_conf keys win on a name clash)
run = wandb.init(
    project = 'interpretable-moes',
    name = RUN_NAME,
    notes = RUN_NOTES,
    config = {**asdict(conf), **asdict(train_conf)},
)

# (Optional) Also log a human-readable snapshot of the setup as an HTML media object
additional_log_notes = dict(
    run_name = RUN_NAME,
    notes = RUN_NOTES,
    created_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    total_model_params = sum(param.numel() for param in model.parameters()),
    available_cuda_gpus = [torch.cuda.get_device_properties(gpu_ix).name for gpu_ix in range(torch.cuda.device_count())],
    model_conf = asdict(conf),
    train_conf = asdict(train_conf),
)

wandb.log({
    'conf': wandb.Html(f"<pre style='font-size:12px;'>{json.dumps(additional_log_notes, indent = 2)}</pre>")
})

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33myuanbo096[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


In [None]:
"""
Let's train the model.
- The training loop will loop through training data shards. Each shard will be loaded and concatenated into chunks of size seq_len.
- Things to consider implementing in the future: more aggressive router LR decay (to encourage router stability)
"""
# Initialize optimizer/scheduler. The scheduler combines a linear warmup (0.2x -> 1x of train_conf.lr over warmup_steps)
# with cosine annealing (down to min_lr over decay_steps).
# NOTE(review): CosineAnnealingLR is periodic, so if training runs past warmup_steps + decay_steps the LR will climb back up.
optimizer = torch.optim.AdamW(model.parameters(), lr = train_conf.lr, fused = True)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers = [
        torch.optim.lr_scheduler.LinearLR(optimizer, start_factor = 0.2, end_factor = 1.0, total_iters = train_conf.warmup_steps), # Warmup
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, train_conf.decay_steps, eta_min = train_conf.min_lr, last_epoch = -1) # Cosine annealing
    ],
    milestones = [train_conf.warmup_steps]
)

# Look for all training data files
shard_files = sorted(glob.glob("./../../data/train_shard_*.json"))
print(f"Found {len(shard_files)} shards.")

# Split params into non-expert vs expert groups (by parameter name) for separate gradient clipping.
# The parameter set is fixed during training, so build these lists once instead of on every step.
shared_params = [p for n, p in model.named_parameters() if 'expert' not in n]
expert_params = [p for n, p in model.named_parameters() if 'expert' in n]

# torch.save does not create missing parent directories, so make sure the checkpoint dir exists before the first save
os.makedirs('saves', exist_ok = True)

# Initialize step count
step = 0
total_tokens_trained = 0
model.train()
torch.manual_seed(seed)

for shard_idx, shard_path in enumerate(shard_files):

    print(f"\n=== Loading shard {shard_path} (index {shard_idx}) ===")
    shard_dl = load_shard_as_dataloader(shard_path, tokenizer, batch_size = train_conf.micro_batch_size * train_conf.accumulation_steps, seq_len = train_conf.seq_len, eos_seperator_id = tokenizer.eos_token_id)

    for batch_idx, batch in enumerate(shard_dl):

        # ====================== SPLIT BATCH INTO MICRO-BATCHES ======================
        input_ids = batch['input_ids'].to(main_device)
        attention_mask = batch['attention_mask'].to(main_device)

        # Skip short (leftover) batches so every optimizer step sees exactly accumulation_steps full micro-batches
        if input_ids.size(0) < (train_conf.accumulation_steps * train_conf.micro_batch_size):
            print(f"Skipping leftover batch, need at least {train_conf.accumulation_steps * train_conf.micro_batch_size}")
            continue

        sub_input_ids = input_ids.split(train_conf.micro_batch_size, dim = 0) 
        sub_attn_mask = attention_mask.split(train_conf.micro_batch_size, dim = 0)

        # ====================== ZERO GRAD ONCE PER "BIG BATCH" ======================
        optimizer.zero_grad()
        
        # We'll track times and losses across micro-batches
        total_fwd_time = 0.0
        total_bwd_time = 0.0
        total_loss_val = 0.0
        start_batch = time.time()

        # One dict per layer, mapping expert_id -> usage count summed over all micro-batches of this big batch
        usage_accum = [defaultdict(int) for _ in range(model.conf.n_layers)]
        # The LFLB bias update consumes these counts, so they must be collected whenever LFLB is enabled
        # (otherwise it would receive empty dicts). Collection forces a GPU->CPU sync per layer, so it stays off otherwise.
        collect_usage = train_conf.use_lflb

        # ====================== MICRO-BATCH LOOP ======================
        for i in range(train_conf.accumulation_steps):

            mb_input_ids = sub_input_ids[i]
            mb_attn_mask = sub_attn_mask[i]

            # ---------------------- Forward ----------------------
            start_fwd = time.time()
            outputs = model(mb_input_ids, mb_attn_mask, moe_method = 'forward_slow', use_lflb = train_conf.use_lflb, use_checkpointing = True)
            loss = outputs['base_loss'] + train_conf.router_aux_loss_coef * outputs['aux_loss']
            fwd_time = time.time() - start_fwd
            total_fwd_time += fwd_time

            # ---------------------- Collect Expert Usage for This Micro-Batch ----------------------
            # Presumably each entry of all_topk_experts holds the selected expert ids for that layer - TODO confirm shape
            if collect_usage:
                with torch.no_grad():
                    for layer_idx, topk_expert_tensor in enumerate(outputs['all_topk_experts']):
                        ex_ids, ex_counts = topk_expert_tensor.reshape(-1).unique(return_counts = True)
                        for ex_id, ex_count in zip(ex_ids.tolist(), ex_counts.tolist()):
                            usage_accum[layer_idx][int(ex_id)] += ex_count

            # ---------------------- Backward ----------------------
            # Divide by accumulation_steps so total gradient matches "big batch" size
            scaled_loss = loss / train_conf.accumulation_steps
            start_bwd = time.time()
            scaled_loss.backward()
            bwd_time = time.time() - start_bwd
            total_bwd_time += bwd_time

            total_loss_val += loss.item()

        # ====================== GRAD CLIPPING & OPT STEP ======================
        torch.nn.utils.clip_grad_norm_(shared_params, train_conf.max_grad_norm)
        torch.nn.utils.clip_grad_norm_(expert_params, train_conf.max_expert_grad_norm)

        optimizer.step()
        scheduler.step()

        # ====================== LOSS-FREE BIAS UPDATE ======================
        # We'll do sign-based bias updates after each "big batch"
        if train_conf.use_lflb:
            for layer_ix in range(model.conf.n_layers):
                model.layers[layer_ix].moe.update_expert_biases(usage_accum[layer_ix], train_conf.bias_update_rate)

        # ============== METRICS ==============
        avg_loss = total_loss_val / train_conf.accumulation_steps # Take the average loss over micro-batches. total_loss_val is the sum of 'loss.item()'.
        total_tokens_trained += attention_mask.sum().detach().cpu().item()
        metrics = {
            'step': step,
            'shard_idx': shard_idx,
            'batch_size': input_ids.shape[0],
            'total_tokens_trained': total_tokens_trained,
            'lr': optimizer.param_groups[0]['lr'],
            'aux_coef': train_conf.router_aux_loss_coef,
            'train': {
                'loss': avg_loss,
                'base_loss': outputs['base_loss'].detach().cpu().item(), # From last microbatch only
                'aux_loss':  outputs['aux_loss'].detach().cpu().item() # From last microbatch only
            },
            'fwd_time':  total_fwd_time,
            'bwd_time':  total_bwd_time,
            'batch_time':  time.time() - start_batch
        }

        # ============== EXPENSIVE METRICS (EVERY 10 STEPS) ==============
        if step % 10 == 0:
            
            metrics['gradients'] = get_gradient_stats(model)

            # Convert usage_accum (list of defaultdicts) into a more standard dict for logging
            # (the per-layer dicts are empty unless usage collection is enabled above)
            usage_dict_final = {}
            for layer_idx, ex_dict in enumerate(usage_accum):
                usage_dict_final[layer_idx] = dict(ex_dict)  # convert defaultdict -> normal dict
            metrics['expert_usage'] = usage_dict_final

        # ============== EXTRA EXPENSIVE METRICS (EVERY 250 STEPS) ==============
        if step % 250 == 0:
            metrics['val'] = get_val_stats(model, val_dl, train_conf.router_aux_loss_coef)

        # ============== SAVE (EVERY 2500 STEPS, INCLUDING STEP 0) ==============
        if step % 2500 == 0:
            torch.save(
                {
                    'model_state_dict': model.state_dict(), 
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),
                    'total_tokens_trained': total_tokens_trained,
                    'step': step
                },
                f"saves/checkpoint_{step:08d}.pt"
            )

        # ============== LOG TO W&B ==============
        wandb.log(metrics)

        # ============== PRINT ==============
        if step <= 10 or (step <= 100 and step % 10 == 0) or (step > 100 and step % 100 == 0):
            print(f"Step {step}: avg_loss={metrics['train']['loss']:.4f} "
                f"| fwd_time={metrics['fwd_time']:.2f}s | bwd_time={metrics['bwd_time']:.2f}s | batch_time = {metrics['batch_time']:.2f} "
                f"| lr={metrics['lr']:.1e}"
            )
            
        step += 1

wandb.finish()

Found 400 shards.

=== Loading shard ./../../data/train_shard_0.json (index 0) ===
Step 0: avg_loss=10.9893 | fwd_time=19.66s | bwd_time=13.98s | batch_time = 34.36 | lr=2.4e-04
Step 1: avg_loss=10.3227 | fwd_time=12.30s | bwd_time=13.57s | batch_time = 26.31 | lr=2.4e-04
Step 2: avg_loss=9.7379 | fwd_time=10.50s | bwd_time=12.02s | batch_time = 22.94 | lr=2.5e-04
Step 3: avg_loss=9.4710 | fwd_time=9.50s | bwd_time=10.71s | batch_time = 20.66 | lr=2.5e-04
Step 4: avg_loss=9.2875 | fwd_time=9.35s | bwd_time=10.07s | batch_time = 19.82 | lr=2.5e-04
Step 5: avg_loss=9.1368 | fwd_time=8.67s | bwd_time=9.70s | batch_time = 18.75 | lr=2.5e-04
Step 6: avg_loss=8.9381 | fwd_time=8.44s | bwd_time=9.42s | batch_time = 18.24 | lr=2.5e-04
Step 7: avg_loss=8.7441 | fwd_time=8.37s | bwd_time=9.33s | batch_time = 18.07 | lr=2.6e-04
Step 8: avg_loss=8.6018 | fwd_time=8.11s | bwd_time=8.98s | batch_time = 17.44 | lr=2.6e-04
Step 9: avg_loss=8.4448 | fwd_time=8.08s | bwd_time=8.97s | batch_time = 17.41 

In [None]:
""" 
Final save
- Writes the same checkpoint payload as the periodic in-loop saves; `step` here is the loop counter after its last increment.
"""
from helpers.memory import clear_all_cuda_memory

# torch.save does not create missing parent directories, so make sure the checkpoint dir exists
os.makedirs('saves', exist_ok = True)
torch.save(
    {
        'model_state_dict': model.state_dict(), 
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'total_tokens_trained': total_tokens_trained,
        'step': step,
    },
    f"saves/checkpoint_{step:08d}.pt"
)

clear_all_cuda_memory()
check_memory()

In [None]:
"""
Qualitative test
- Greedily generate up to 255 tokens from a short prompt, stopping early at EOS or a newline.
"""
# Load the tokenizer BEFORE it is used below (previously it was re-created only after encoding the prompt)
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)

prompt = 'My dog likes to eat '

inputs = tokenizer(prompt, return_tensors = 'pt').to(main_device)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

# Token ids that end generation; computed once instead of re-encoding '\n' on every iteration
stop_token_ids = {tokenizer.eos_token_id, tokenizer.encode('\n')[0]}

# Generate in eval mode (the training loop leaves the model in train mode), then restore the previous mode
was_training = model.training
model.eval()
try:
    # Iteratively generate tokens (full re-forward each step; no KV cache)
    with torch.no_grad():
        for _ in range(255):
            output = model(input_ids, attention_mask, moe_method = 'forward_slow', use_checkpointing = False)['logits']

            # Greedy decoding: highest-scoring next token for the single sequence in the batch
            next_token_id = torch.argmax(output[0, -1, :], dim = -1).unsqueeze(0)

            input_ids = torch.cat([input_ids, next_token_id.unsqueeze(0)], dim = 1)
            attention_mask = torch.cat([attention_mask, torch.ones((1, 1), dtype = torch.long, device = input_ids.device)], dim = 1)

            if next_token_id.item() in stop_token_ids:
                break
finally:
    model.train(was_training)

# Decode final sequence
generated_text = tokenizer.decode(input_ids[0], skip_special_tokens = False)
print(generated_text)