In [None]:
# Imports
import torch
from torch import nn
from torch.nn import DataParallel
import torch.nn.functional as F
from dotenv import load_dotenv
import wandb
import math
from accelerate import Accelerator
from helpers.memory import check_memory
from dataclasses import dataclass

# Load environment variables from the local `secrets.env` file
# (presumably credentials such as the W&B key - confirm against the file).
load_dotenv('secrets.env')
# Device string reused by later cells as the default device and as the home of all dense layers.
main_device = 'cuda:0'
# Project helper from helpers.memory; presumably reports current device memory usage (implementation not shown here).
check_memory()

### 1. Create a model class that can be used for training.
- The model class does not need to handle sharding or moving experts between devices; a wrapper class defined later takes care of that.

In [11]:
@dataclass
class MoeConf:
    """
    General config settings for this MoE. Defaults mirror the base OlMoE configuration.

    Raises:
        ValueError: On construction, if the settings cannot produce a working model
            (see `__post_init__`).
    """
    vocab_size: int = 50304 # Base OlMoE: 50304 (vocab size)
    D: int = 2048 # Base OlMoE: 2048 (hidden state dimension)
    H: int = 16 # Base OlMoE: 16 (number of attention heads)
    I: int = 1024 # Base OlMoE: 1024 (expert MLP dimension)
    n_experts: int = 64 # Base OlMoE: 64
    top_k: int = 8 # Base OlMoE: 8 
    norm_topk_prob: bool = False # Base OlMoE: false (whether to normalize so that expert weights sum to 1 after topk)
    padding_idx: int = 1 # Base OlMoE: 1 (index where padding gets mapped to)
    n_layers: int = 16 # Base OlMoE: 16 (transformer layers)
    rms_norm_eps: float = 1e-05 # Base OlMoE: 1e-05
    rope_theta: float = 10000.0 # Base OlMoe: 10000.0 (this is something needed for ROPE)
    max_position_embeddings: int = 4096 # Base OlMoE: 4096 (this is something needed for ROPE)
    router_aux_loss_coef: float = 0.01  # Base OlMoE: 0.01 (relative weight of balancing loss)
    attn_method: str = 'fa2' # In OlMoE this is chosen automatically, here we explicitly pass it - choose 'normal', 'sdpa', or 'fa2'
    torch_dtype: torch.dtype = torch.bfloat16

    def __post_init__(self):
        """
        Fail fast on settings that would otherwise only break deep inside a forward pass.
        """
        if self.H <= 0 or self.D % self.H != 0:
            raise ValueError(f"D ({self.D}) must be divisible by a positive H ({self.H})")
        # RoPE splits each head into two halves, so the per-head dimension must be even
        if (self.D // self.H) % 2 != 0:
            raise ValueError(f"Per-head dimension D/H ({self.D // self.H}) must be even for RoPE")
        # torch.topk cannot select more experts than exist
        if not 1 <= self.top_k <= self.n_experts:
            raise ValueError(f"top_k ({self.top_k}) must be between 1 and n_experts ({self.n_experts})")
        if self.attn_method not in ('normal', 'sdpa', 'fa2'):
            raise ValueError(f"attn_method must be 'normal', 'sdpa', or 'fa2', got {self.attn_method!r}")

conf = MoeConf()

In [12]:
# Some helper functions that will be later needed by the class

# Create the upper-triangular matrix of large negative values that masks future tokens in the attention softmax
def _prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    if attention_mask is not None and attention_mask.dim() == 4:
        # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
        causal_mask = attention_mask
    else:
        min_dtype = torch.finfo(dtype).min
        causal_mask = torch.full(
            (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device
        )
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            mask_length = attention_mask.shape[-1]
            padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
            padding_mask = padding_mask == 0
            causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                padding_mask, min_dtype
            )

    return causal_mask

# Load balancing loss, copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py
def load_balancing_loss_func(gate_logits, num_experts, top_k, attention_mask):
    """
    Auxiliary load-balancing loss over the router logits of all layers.

    Params:
        @gate_logits: Sequence with one router-logit tensor per layer, each of shape
            (n_tokens, num_experts). An empty sequence or None yields a loss of 0.
        @num_experts: Total number of experts per layer.
        @top_k: Number of experts selected per token.
        @attention_mask: Optional (batch_size, sequence_length) mask; positions equal to 0 are
            excluded from the statistics. Pass None to weight every token equally.

    Returns:
        A scalar tensor: sum over (top-k slot, expert) of
        (fraction of tokens routed to the expert) * (mean router probability of the expert),
        multiplied by `num_experts`. Returns the integer 0 when there are no gate logits.
    """
    # Nothing to balance: without this guard gate_logits[0] raises on empty/None input
    if gate_logits is None or len(gate_logits) == 0:
        return 0

    compute_device = gate_logits[0].device
    # Stack the per-layer logits along the token dimension -> (n_layers * n_tokens, num_experts)
    concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim = 0)
    routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
    _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
    # One-hot over experts for every (token, top-k slot) -> (n_layers * n_tokens, top_k, num_experts)
    expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)

    if attention_mask is None:
        # Compute the percentage of tokens routed to each expert
        tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
        # Compute the average probability of routing to these experts
        router_prob_per_expert = torch.mean(routing_weights, dim=0)
    else:
        batch_size, sequence_length = attention_mask.shape
        num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
        # Broadcast the padding mask to the shape of expert_mask so padding tokens count as 0
        expert_attention_mask = (
            attention_mask[None, :, :, None, None]
            .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts))
            .reshape(-1, top_k, num_experts)
            .to(compute_device)
        )
        # Compute the percentage of non-padding tokens routed to each expert
        tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(expert_attention_mask, dim=0)
        # Broadcast the padding mask to the shape of routing_weights
        router_per_expert_attention_mask = (
            attention_mask[None, :, :, None]
            .expand((num_hidden_layers, batch_size, sequence_length, num_experts))
            .reshape(-1, num_experts)
            .to(compute_device)
        )
        # Compute the average probability of routing to these experts over non-padding tokens
        router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(router_per_expert_attention_mask, dim=0)

    overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
    return overall_loss * num_experts

In [4]:
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting
from transformers.modeling_flash_attention_utils import _flash_attention_forward # Flash attention forward
from transformers.activations import silu

class OlmoeRMSNorm(nn.Module):
    """
    Root-mean-square layer norm with a learnable per-feature scale.
    - Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L137-L154
    - This is the only norm used in OlMoE!
      - It's used 4 times per layer (attention key norm, attention query norm, layer residual pre-attention norm, post-attention norm)
      - Also one additional time before the final LM head
    """
    def __init__(self, D, eps = 1e-5):
        """
        Params:
            @D: Size of the last (feature) dimension being normalized
            @eps: Constant added to the mean square before the reciprocal square root
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(D))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Statistics are computed in float32, then cast back to the caller's dtype before scaling
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)
        mean_square = x32.pow(2).mean(-1, keepdim = True)
        normalized = x32 * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)

class OlmoeRotaryEmbedding(nn.Module):
    """
    Produces the cos/sin tables used for rotary position embeddings (RoPE)
    - Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L161-L219
    - Code has been simplified heavily since we're not using dynamic ROPE scaling
    """
    def __init__(self, conf: MoeConf):
        super().__init__()
        head_dim = int(conf.D/conf.H)
        # One inverse frequency per pair of head dimensions: theta^(-2i/head_dim)
        exponents = torch.arange(0, head_dim, 2, dtype = torch.int64).float()/head_dim
        self.register_buffer("inv_freq", 1.0 / (conf.rope_theta ** exponents), persistent = False)

    @torch.no_grad()
    def forward(self, x, position_ids):
        """
        Params:
            @x: Only its device type and dtype are read
            @position_ids: Integer positions; indexed here as (batch, sequence)

        Returns:
            (cos, sin), each cast to x.dtype, with the per-pair angles repeated twice along the last dim
        """
        n_rows = position_ids.shape[0]
        freq_column = self.inv_freq[None, :, None].float().expand(n_rows, -1, 1)
        position_row = position_ids[:, None, :].float()

        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        if not isinstance(device_type, str) or device_type == "mps":
            device_type = "cpu"
        with torch.autocast(device_type = device_type, enabled = False):
            # Outer product of frequencies and positions, rearranged to (batch, sequence, head_dim/2)
            angles = (freq_column.float() @ position_row.float()).transpose(1, 2)
            angles_full = torch.cat((angles, angles), dim = -1)
            cos = angles_full.cos()
            sin = angles_full.sin()

        return cos.to(dtype = x.dtype), sin.to(dtype = x.dtype)

class OlmoeAttention(nn.Module):
    """
    Attention implementation
    - Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L288-L391
    - Simplified to handle base attention/sdpa/flash attention within this one class
    - Also doesn't support GQA (OlMoE doesn't use it anyway)
    """
    def __init__(self, conf: MoeConf):
        super().__init__()
        self.attn_method = conf.attn_method
        self.D = conf.D # Hidden state dim
        self.H = conf.H # Num of attention heads
        self.Dh = int(conf.D/conf.H) # Dimensions per head
        
        # Initialize attention layers - no biases following OlMoE architecture
        self.q_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.k_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.v_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.o_proj = nn.Linear(self.D, self.D, bias = False)
        # Query/key norms act on the full projected vector, before it is split into heads
        self.q_norm = OlmoeRMSNorm(self.D, eps = conf.rms_norm_eps)
        self.k_norm = OlmoeRMSNorm(self.D, eps = conf.rms_norm_eps)

    # Taken from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L223-L255
    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim = 1):
        """
        Rotate queries and keys by the given cos/sin tables.

        Params:
            @q, @k: Query/key tensors laid out as B x H x N x Dh
            @cos, @sin: RoPE tables; a singleton dim is inserted at `unsqueeze_dim` so they broadcast over heads
            @unsqueeze_dim: Position of the inserted broadcast dimension

        Returns:
            The rotated (q, k) pair, same shapes as the inputs
        """
        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)
            
        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.LongTensor, position_embeddings: tuple[torch.Tensor, torch.Tensor]):
        """
        Params:
            @hidden_state: B x N x D input
            @attention_mask: For 'normal'/'sdpa', the additive 4D mask (already causal + padding).
                For 'fa2', the original 2D padding mask or None.
            @position_ids: Unused here; kept for interface compatibility
            @position_embeddings: (cos, sin) tables from the rotary embedding module

        Returns:
            B x N x D attention output after the output projection

        Raises:
            ValueError: if `attn_method` is not one of 'normal', 'sdpa', 'fa2'
        """
        B, N , D = hidden_state.shape

        # Project (B x N x D), then split into heads and move heads forward -> B x H x N x Dh
        query_state = self.q_norm(self.q_proj(hidden_state)).view(B, N, self.H, self.Dh).transpose(1, 2)
        key_state = self.k_norm(self.k_proj(hidden_state)).view(B, N, self.H, self.Dh).transpose(1, 2)
        value_state = self.v_proj(hidden_state).view(B, N, self.H, self.Dh).transpose(1, 2)

        cos, sin = position_embeddings
        query_state, key_state = self.apply_rotary_pos_emb(query_state, key_state, cos, sin)
        
        if self.attn_method == 'normal':
            attn_weights = torch.matmul(query_state, key_state.transpose(2, 3))/math.sqrt(self.Dh)  # Should be shape B x H x N x N
            attn_weights = attn_weights + attention_mask # Additive mask: large negative values on blocked positions
            attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(query_state.dtype)
            attn_output = torch.matmul(attn_weights, value_state) # B x H x N x D/H
            attn_output = attn_output.transpose(1, 2).contiguous() # Reorder into B x N x H x D/H
            attn_output = attn_output.reshape(B, N, D) # Concatenate heads back into B x N x D
            
        elif self.attn_method == 'sdpa':
            # SDPA raises if both an explicit mask and is_causal are set. The 4D mask already
            # encodes causality, so only request implicit causal masking when no mask is given.
            attn_output = torch.nn.functional.scaled_dot_product_attention(
                query_state, key_state, value_state,
                attn_mask = attention_mask, dropout_p = 0.0, is_causal = attention_mask is None
            )
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(B, N, D)
            
        elif self.attn_method == 'fa2':
            # Flash attention expects B x N x H x Dh
            query_state = query_state.transpose(1, 2)
            key_state = key_state.transpose(1, 2)
            value_state = value_state.transpose(1, 2)
            attn_output = _flash_attention_forward(
                query_state, key_state, value_state,
                attention_mask, N, dropout = 0.0, use_top_left_mask = False, is_causal = True
            )
            attn_output = attn_output.reshape(B, N, D).contiguous()

        else:
            # Previously fell through and failed later with an UnboundLocalError
            raise ValueError(f"Unknown attn_method {self.attn_method!r}; choose 'normal', 'sdpa', or 'fa2'")
            
        attn_output = self.o_proj(attn_output)
        return attn_output
    
class OlmoeMLP(nn.Module):
    """
    A single expert: gated feed-forward network, down_proj(act(gate_proj(x)) * up_proj(x))
    - Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L258-L272
    """
    def __init__(self, conf: MoeConf):
        super().__init__()
        self.conf = conf
        self.D = conf.D # Hidden state dim
        self.I = conf.I # Expert intermediate dim
        # No biases on any of the three projections
        self.gate_proj = nn.Linear(self.D, self.I, bias = False)
        self.up_proj = nn.Linear(self.D, self.I, bias = False)
        self.down_proj = nn.Linear(self.I, self.D, bias = False)
        self.act_fn = silu

    def forward(self, x):
        # The activated gate branch modulates the up branch elementwise before projecting back to D
        gate_branch = self.act_fn(self.gate_proj(x))
        up_branch = self.up_proj(x)
        return self.down_proj(gate_branch * up_branch)


class OlmoeMoe(nn.Module):
    """
    Full mixture-of-experts MLP layer: router plus the expert MLPs
    - Modified from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L604-L649
    """
    def __init__(self, conf: MoeConf):
        super().__init__()
        self.n_experts = conf.n_experts
        self.top_k = conf.top_k
        self.norm_topk_prob = conf.norm_topk_prob
        self.gate = nn.Linear(conf.D, self.n_experts, bias = False) # Router
        self.experts = nn.ModuleList([OlmoeMLP(conf) for _ in range(self.n_experts)]) # One OlmoeMLP per expert

    def forward(self, hidden_state: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Params:
            @hidden_state: B x N x D input

        Returns:
            (mlp_output, router_logits): the B x N x D weighted sum of the selected experts' outputs,
            and the BN x N_EXPERTS router logits (needed later for the balancing loss)
        """
        B, N, D = hidden_state.shape

        # Route every token independently: B x N x D -> BN x D
        flat_tokens = hidden_state.view(B * N, D)
        router_logits = self.gate(flat_tokens) # BN x N_EXPERTS
        routing_probs = F.softmax(router_logits, dim = 1, dtype = torch.float)

        # Both are BN x TOP_K. The kept weights do NOT sum to 1 unless norm_topk_prob renormalizes them
        topk_weights, topk_experts = torch.topk(routing_probs, self.top_k, dim = -1)
        if self.norm_topk_prob:
            topk_weights /= topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights.to(flat_tokens.dtype)

        # (N_EXPERTS, TOP_K, BN): dispatch_mask[e] marks the (slot, token) pairs routed to expert e
        dispatch_mask = F.one_hot(topk_experts, num_classes = self.n_experts).permute(2, 1, 0)

        # Accumulator the weighted expert outputs are summed into
        combined = torch.zeros((B * N, D), dtype = flat_tokens.dtype, device = flat_tokens.device)

        for expert_ix in range(len(self.experts)):
            slot_ix, token_ix = torch.where(dispatch_mask[expert_ix, :])
            if token_ix.numel() > 0:
                # Routing weight of this expert for each of its assigned tokens
                gate_scale = topk_weights[token_ix, slot_ix].unsqueeze(1)
                # Run the expert only on its own tokens (num_assigned_tokens x D) and scale the result
                weighted_output = self.experts[expert_ix](flat_tokens[token_ix, :]) * gate_scale
                combined.index_add_(0, token_ix, weighted_output.to(flat_tokens.dtype))

        # BN x D -> B x N x D
        return combined.reshape(B, N, D), router_logits

class OlmoeBlock(nn.Module):
    """
    One transformer layer: pre-norm self-attention followed by a pre-norm MoE MLP,
    each added back onto the residual stream.
    """
    def __init__(self, conf: MoeConf, layer_idx: int):
        super().__init__()
        self.D = conf.D
        self.self_attn = OlmoeAttention(conf = conf)
        self.moe = OlmoeMoe(conf)
        self.input_layernorm = OlmoeRMSNorm(conf.D, eps = conf.rms_norm_eps)
        self.post_attention_layernorm = OlmoeRMSNorm(conf.D, eps = conf.rms_norm_eps)

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.LongTensor, position_embeddings: tuple[torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Returns:
            (hidden_state, router_logits): the updated residual stream and this layer's router logits
        """
        ### Attention sub-layer: norm -> self-attention -> add to residual stream ###
        attn_output = self.self_attn(
            self.input_layernorm(hidden_state),
            attention_mask = attention_mask,
            position_ids = position_ids,
            position_embeddings = position_embeddings
        )
        after_attn = hidden_state + attn_output

        ### MoE sub-layer: norm -> routed experts -> add to residual stream ###
        mlp_output, router_logits = self.moe(self.post_attention_layernorm(after_attn))

        return after_attn + mlp_output, router_logits

class OlmoeModel(nn.Module):
    """
    The top level model object. Also handles weight initialization and loss calculations.
    """
    def __init__(self, conf: MoeConf):
        """
        Params:
            @conf: A configuration object of class MoeConf
        """
        super().__init__()
        self.conf = conf
        
        ### Layers ###
        self.embed_tokens = nn.Embedding(self.conf.vocab_size, self.conf.D, self.conf.padding_idx)
        self.rotary_emb = OlmoeRotaryEmbedding(conf = self.conf)
        self.layers = nn.ModuleList(
            [OlmoeBlock(self.conf, layer_idx) for layer_idx in range(self.conf.n_layers)]
        )
        self.norm = OlmoeRMSNorm(self.conf.D, eps = self.conf.rms_norm_eps)
        self.lm_head = nn.Linear(self.conf.D, self.conf.vocab_size, bias = False)
        
        ### Initialize weights ###
        self.apply(self._init_weights)

    # OlMoE weight initiation - see https://github.com/huggingface/transformers/blob/8f1509a96c96747c893051ac947795cfb0750357/src/transformers/modeling_utils.py#L2500-L2515
    # Normal distribution for linear layers + embeddings
    def _init_weights(self, module):
        """
        Per-module initializer passed to `self.apply`: N(0, 0.02) weights for Linear and Embedding
        modules, zero biases, and a zero vector for the padding token's embedding.
        Other module types keep their own default initialization.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean = 0.0, std = 0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean = 0.0, std = 0.02)
            # In the vocab -> embedding layer, set all embeddings to 0 for the padding token (tokenizer.pad_token_id)
            if module is self.embed_tokens:
                self.embed_tokens.weight.data[self.conf.padding_idx].zero_() 
    
    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.Tensor):
        """
        Params:
            @input_ids: B x N token ids; positions equal to `conf.padding_idx` are excluded from the LM loss
            @attention_mask: B x N padding mask (0 marks padding), or None

        Returns:
            (output_logits, loss): B x N x vocab_size logits and the scalar total loss
            (causal LM loss + router_aux_loss_coef * load balancing loss)

        Raises:
            ValueError: if `conf.attn_method` is not one of 'normal', 'sdpa', 'fa2'
        """
        hidden_state = self.embed_tokens(input_ids)
        B, N, D = hidden_state.shape

        ### Prep rotary embeddings + attention masks  ###
        cache_position = torch.arange(0, N, device = hidden_state.device)
        position_ids = cache_position.unsqueeze(0)
        position_embeddings = self.rotary_emb(hidden_state, position_ids) # Position embeddings to be shared across the decoder layers

        # This is the upper-triangular additive mask that hides future tokens in the attention softmax
        if self.conf.attn_method in ['normal', 'sdpa']:
            causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(attention_mask, N, N, hidden_state.dtype, hidden_state.device, cache_position, B)
        # The flash attention mask is simpler - takes only the original attention mask, or None when nothing is padded
        elif self.conf.attn_method == 'fa2':
            causal_mask  = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            # Previously left causal_mask unbound and failed with an UnboundLocalError below
            raise ValueError(f"Unknown attn_method {self.conf.attn_method!r}; choose 'normal', 'sdpa', or 'fa2'")
        
        ### Transformer layers ###
        all_router_logits = () # Save router logits from each layer into this; will be needed for balancing loss
        for layer in self.layers:
            hidden_state, router_logits = layer(
                hidden_state,
                causal_mask,
                position_ids,
                position_embeddings
            )
            all_router_logits += (router_logits, )

        hidden_state = self.norm(hidden_state)
        output_logits = self.lm_head(hidden_state)

        ##### Calculate Loss #####
        # Labels are the token ids themselves, with padding replaced by -100 so it is ignored by the loss.
        # masked_fill keeps the result on input_ids' device regardless of the default device.
        label_ids = input_ids.masked_fill(input_ids == self.conf.padding_idx, -100)
        # Get regular loss
        base_loss = ForCausalLMLoss(output_logits, label_ids, self.conf.vocab_size)
        # Get load balancing loss
        aux_loss = load_balancing_loss_func(gate_logits = all_router_logits, num_experts = self.conf.n_experts, top_k = self.conf.top_k, attention_mask = attention_mask)
        # Get total loss = regular loss + router_aux_loss_coef * load bal loss
        loss = base_loss + self.conf.router_aux_loss_coef * aux_loss 

        return output_logits, loss

In [None]:
# Smoke test: run one forward pass through the single-device model.
# The model is freshly initialized (no pretrained weights are loaded), so the decoded token is arbitrary.

# Load the model - everything on the main device with bf16
# Both defaults are global: tensors and modules created after this point use them.
torch.set_default_device(main_device)
torch.set_default_dtype(torch.bfloat16)

model = OlmoeModel(conf)
check_memory()

# Test a forward pass
from transformers import AutoTokenizer
# Left padding, and no automatically added BOS/EOS tokens
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False, padding_side = 'left')

prompt = 'I am a dog and I like to eat. My favorite food is'
inputs = tokenizer(prompt, return_tensors = 'pt').to(main_device)
# Inference only; the model returns (logits, loss)
with torch.no_grad():
    output = model(inputs['input_ids'], inputs['attention_mask'])
    
# Greedy pick at every position of the first sequence, then decode the prediction that follows the last prompt token
output_ids = torch.argmax(output[0][0, :, :], dim = 1)
print(tokenizer.decode(output_ids[-1]))

### 2. Now wrap it in a parent class to handle device movement

In [15]:
class DeviceWrappedExpert(nn.Module):
    """
    Holds one expert on its own device and moves activations to that device and back.

    On CUDA devices the expert runs on a dedicated stream. That stream is explicitly
    ordered after the caller's pending work and the caller's stream is ordered after
    the expert's work, so results are never read before they are ready.
    """
    def __init__(self, expert: nn.Module, device: torch.device):
        """
        Params:
            @expert: The module to wrap (an OlmoeMLP in this notebook); it is moved to `device`
            @device: Device the expert lives on
        """
        super().__init__()
        self.device = torch.device(device)
        self.expert = expert.to(self.device)
        # Dedicated stream per expert; streams only exist for CUDA devices
        self.stream = torch.cuda.Stream(device = self.device) if self.device.type == 'cuda' else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Params:
            @x: Input activations on any device

        Returns:
            The expert output, on the same device `x` came from
        """
        orig_device = x.device

        # Non-CUDA expert: plain synchronous execution
        if self.stream is None:
            return self.expert(x.to(self.device)).to(orig_device)

        # The side stream must not start before work already queued on this device has produced `x`
        self.stream.wait_stream(torch.cuda.current_stream(self.device))
        if x.device == self.device:
            # `x` was allocated on another stream; tell the allocator it is also in use on ours
            x.record_stream(self.stream)

        with torch.cuda.stream(self.stream):
            x = x.to(self.device, non_blocking = True)
            out = self.expert(x)
            out = out.to(orig_device, non_blocking = True)

        if orig_device.type == 'cuda':
            # The caller consumes `out` on its own stream: wait for the expert stream to finish
            # and keep the memory from being reused before that consumption has run
            consumer_stream = torch.cuda.current_stream(orig_device)
            consumer_stream.wait_stream(self.stream)
            out.record_stream(consumer_stream)
        else:
            # Non-blocking copy to host memory: make sure it has completed before returning
            self.stream.synchronize()

        return out

class DistributedOlmoeModel(nn.Module):
    """
    Wrapper class for OlmoeModel that handles multi-device expert distribution.
    Keeps the original implementation intact while managing device placement.
    """
    def __init__(self, conf: MoeConf, main_device: str | torch.device, expert_device_map: dict[int, str | torch.device] | None = None, expert_device_list: list[str | torch.device] | None = None):
        """
        This is a wrapper around the main model class that just handles device movement. Either set `expert_device_map` to explicitly map experts to devices, or `expert_device_list` to equally partition experts among devices. If both are None (or empty), all experts are placed on the `main_device`.

        Params:
            @conf: A configuration object of class MoeConf
            @main_device: A torch.device object where all dense layers will be stored, such as "cuda:0"
            @expert_device_map: A mapping of experts to devices, e.g. `{0: "cuda:1", 1: "cuda:2", 2: "cuda:1", ...}`. Experts missing from the map stay on `main_device`.
            @expert_device_list: A list of devices to equally allocate the experts, e.g. `["cuda:1", "cuda:2"]`.

        Raises:
            ValueError: if both `expert_device_map` and `expert_device_list` are given
        """
        super().__init__()
        
        # Validate device configurations
        if expert_device_map and expert_device_list:
            raise ValueError("Cannot specify both expert_device_map and expert_device_list")

        # Exposed so callers can read settings (e.g. top_k) directly from the wrapper
        self.conf = conf
        
        # Convert main_device to torch.device if it's a string
        self.main_device = torch.device(main_device)

        # Auto-generate expert devices if a non-empty list is provided (an empty list means "no placement given")
        if expert_device_map is None and expert_device_list:
            expert_device_map = self._create_block_expert_devices(conf.n_experts, [torch.device(d) for d in expert_device_list])
        
        # Default expert devices if not provided
        self.expert_device_map = expert_device_map or {}
        
        # Initialize base model on main device
        self.model = OlmoeModel(conf).to(self.main_device)

        # Distribute experts
        self._distribute_experts()

    def _create_block_expert_devices(self, n_experts: int, device_list: list[str | torch.device]) -> dict[int, str | torch.device]:
        """
        Distribute experts in contiguous blocks across devices. When the experts do not divide
        evenly, the first `n_experts % len(device_list)` devices receive one extra expert.
        
        Example: 8 experts, 2 devices → [0-3] on device 0, [4-7] on device 1

        Raises:
            ValueError: if `device_list` is empty
        """
        num_devices = len(device_list)
        if num_devices == 0:
            raise ValueError("device_list must contain at least one device")
        experts_per_device, remainder = divmod(n_experts, num_devices)
        
        expert_map = {}
        expert_idx = 0
        
        for dev_idx, device in enumerate(device_list):
            # Calculate how many experts this device gets
            count = experts_per_device + (1 if dev_idx < remainder else 0)
            
            # Assign contiguous block
            for _ in range(count):
                expert_map[expert_idx] = device
                expert_idx += 1
                
        return expert_map
    
    def _distribute_experts(self):
        """
        Replace every expert in every MoE layer with a DeviceWrappedExpert placed on its mapped
        device. The same expert-index -> device mapping is used for all layers.
        """
        for layer in self.model.layers:
            if not hasattr(layer, 'moe'):
                continue  # Skip non-MoE layers

            moe_layer = layer.moe
            new_experts = nn.ModuleList()
            
            for expert_idx, expert in enumerate(moe_layer.experts):
                # Get device for this expert index; unmapped experts stay on the main device
                device = self.expert_device_map.get(expert_idx, self.main_device)
                
                # Wrap expert with device handler
                wrapped_expert = DeviceWrappedExpert(expert, torch.device(device))
                new_experts.append(wrapped_expert)
                
            # Replace original experts with wrapped versions
            moe_layer.experts = new_experts

    def forward(self, input_ids: torch.LongTensor, attention_mask: torch.Tensor):
        """
        Run the wrapped model with the inputs moved to the main device.

        Params:
            @input_ids: B x N token ids
            @attention_mask: B x N padding mask, or None (the wrapped model accepts None)

        Returns:
            Whatever the wrapped OlmoeModel returns: (output_logits, loss)
        """
        # Ensure base inputs stay on main device
        if attention_mask is not None:
            attention_mask = attention_mask.to(self.main_device)
        return self.model(input_ids.to(self.main_device), attention_mask)
        
    def state_dict(self, *args, **kwargs):
        """Return the state dict of the underlying model"""
        return self.model.state_dict(*args, **kwargs)
    
    def load_state_dict(self, state_dict, *args, **kwargs):
        """Load state dict into the underlying model"""
        return self.model.load_state_dict(state_dict, *args, **kwargs)


In [None]:
# Smoke test for the distributed wrapper: rebuild the model with experts spread over cuda:1 and cuda:2,
# then run one forward pass. The weights are freshly initialized, so the decoded token is arbitrary.

# Drop the model from the previous cell (if any) and release cached GPU memory before rebuilding
if 'model' in locals() or 'model' in globals():
    del model
torch.cuda.empty_cache()
check_memory()

# Load the model - everything on the main device with bf16
# Both defaults are global: tensors and modules created after this point use them.
torch.set_default_device(main_device)
torch.set_default_dtype(torch.bfloat16)

# Dense layers live on main_device; experts are split into contiguous blocks across the listed devices
model = DistributedOlmoeModel(conf, main_device = main_device, expert_device_list = ['cuda:1', 'cuda:2'])
print('Model loaded')
check_memory()

# Test a forward pass
from transformers import AutoTokenizer
# Left padding, and no automatically added BOS/EOS tokens
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False, padding_side = 'left')

prompt = 'I am a dog and I like to eat. My favorite food is'
inputs = tokenizer(prompt, return_tensors = 'pt').to(main_device)
# Inference only; the model returns (logits, loss)
with torch.no_grad():
    output = model(inputs['input_ids'], inputs['attention_mask'])
    
# Greedy pick at every position of the first sequence, then decode the prediction that follows the last prompt token
output_ids = torch.argmax(output[0][0, :, :], dim = 1)
print(tokenizer.decode(output_ids[-1]))

### 3. Let's use some training data

In [None]:
from datasets import load_dataset
from transformers import AutoTokenizer
from torch.utils.data import DataLoader, Dataset
from itertools import islice

def get_fineweb_edu_ds(n_samples: int = 1000):
    """
    Stream FineWeb-Edu and return the text of the first `n_samples` English documents with score >= 4.

    Params:
        @n_samples: Maximum number of documents to pull from the stream

    Returns:
        A list of document strings (fewer than `n_samples` if the stream ends early)
    """
    dataset = load_dataset("HuggingFaceFW/fineweb-edu", name = "default", split = 'train', streaming = True)
    # Rows without a score are dropped instead of raising a TypeError on `None >= 4`
    dataset = dataset.filter(lambda x: x.get('language') == 'en' and (x.get('score') or 0) >= 4)
    # Materialize only the first n_samples rows of the stream and keep their text
    return [row['text'] for row in islice(dataset, n_samples)]

class TestDataset(Dataset):
    """
    Map-style dataset over a tokenizer output; each item is one row of
    `input_ids` paired with the matching row of `attention_mask`.
    """
    def __init__(self, tokenizer_output):
        # Only these two keys of the tokenizer output are used
        self.input_ids = tokenizer_output['input_ids']
        self.attention_mask = tokenizer_output['attention_mask']

    def __len__(self):
        return len(self.input_ids)

    def __getitem__(self, idx):
        return dict(
            input_ids = self.input_ids[idx],
            attention_mask = self.attention_mask[idx],
        )
    
# Pull 100 filtered FineWeb-Edu documents
fw_data = get_fineweb_edu_ds(100)
# Truncate or pad every document to exactly 4096 tokens; relies on `tokenizer` from an earlier cell
fw_tokenized = tokenizer(fw_data, truncation = True, max_length = 4096, padding = 'max_length', return_tensors = 'pt')
fw_ds = TestDataset(fw_tokenized)
# The shuffling generator is created on main_device (presumably to match the default device set earlier - confirm)
fw_dl = DataLoader(fw_ds, batch_size = 2, shuffle = True, generator = torch.Generator(device = main_device))

In [None]:
# Test training code
import time
from collections import defaultdict

# Split parameters once into expert / shared groups; reused for the optimizer and for gradient clipping
expert_params = [p for n,p in model.named_parameters() if 'expert' in n]
shared_params = [p for n,p in model.named_parameters() if 'expert' not in n]

# Set different LRs for experts/non-experts
optimizer = torch.optim.AdamW([
    {'params': expert_params, 'lr': 1e-4},
    {'params': shared_params, 'lr': 3e-5}
], weight_decay = 0.01)

max_grad_norm = 1.0  # Set the value for gradient clipping
step = 0

# The config lives on the wrapped model, not on the DistributedOlmoeModel wrapper itself
top_k = model.model.conf.top_k
n_experts = model.model.conf.n_experts

def _sync_devices():
    """Wait for all queued CUDA work so wall-clock timings measure execution, not just kernel launches."""
    if torch.cuda.is_available():
        for dev_idx in range(torch.cuda.device_count()):
            torch.cuda.synchronize(dev_idx)

# Each MoE layer returns (mlp_output, router_logits) but stores nothing on the module, so the
# router logits needed for the usage metrics are captured with forward hooks
router_logits_by_layer = {}

def _make_router_hook(layer_idx):
    """Build a forward hook that records the given layer's latest router logits."""
    def _hook(module, inputs, output):
        router_logits_by_layer[layer_idx] = output[1].detach()
    return _hook

hook_handles = [
    layer.moe.register_forward_hook(_make_router_hook(layer_idx))
    for layer_idx, layer in enumerate(model.model.layers)
]

model.train()

try:
    for batch in fw_dl:
        # Device-aware batch preparation
        input_ids = batch['input_ids'].to(model.main_device)
        attention_mask = batch['attention_mask'].to(model.main_device)

        # Forward pass with timing
        optimizer.zero_grad()
        _sync_devices()
        start_time = time.time()
        logits, loss = model(input_ids, attention_mask)
        _sync_devices()
        fwd_time = time.time() - start_time

        # Backward pass
        start_bwd = time.time()
        loss.backward()
        _sync_devices()
        bwd_time = time.time() - start_bwd

        # MoE-specific gradient clipping: different limits for experts vs shared params
        torch.nn.utils.clip_grad_norm_(expert_params, max_grad_norm * 2)  # Looser for experts
        torch.nn.utils.clip_grad_norm_(shared_params, max_grad_norm)

        # Optimizer step
        optimizer.step()

        # MoE metrics collection
        metrics = {
            'loss': loss.item(),
            'fwd_time': fwd_time,
            'bwd_time': bwd_time,
            'expert_usage': defaultdict(int)
        }

        # Calculate expert utilization: number of (token, top-k slot) assignments per expert
        # in this step. Padding tokens are included in the counts.
        with torch.no_grad():
            for layer_idx in sorted(router_logits_by_layer):
                router_logits = router_logits_by_layer[layer_idx]
                _, selected_experts = torch.topk(router_logits, top_k, dim=-1)
                assignment_counts = torch.bincount(selected_experts.flatten(), minlength = n_experts).tolist()
                for expert_idx, count in enumerate(assignment_counts):
                    if count > 0:
                        metrics['expert_usage'][f'layer_{layer_idx}/expert_{expert_idx}'] += count

        print(f"Step {step}:")
        print(f"  Loss: {metrics['loss']:.3f}")
        print(f"  Fwd/Bwd Time: {metrics['fwd_time']:.2f}s/{metrics['bwd_time']:.2f}s")
        print("  Expert Usage:")
        for expert, count in metrics['expert_usage'].items():
            print(f"    {expert}: {count}")

        step = step + 1
finally:
    # Always detach the hooks so re-running this cell does not stack duplicates
    for handle in hook_handles:
        handle.remove()

In [None]:
# TBD: Implement streaming to avoid sequential looping over experts, concurrent launch should get ~3-5x speedup

In [None]:
# import inspect
# from transformers.modeling_rope_utils import _compute_default_rope_parameters
# print(inspect.getsource(_compute_default_rope_parameters))