In [None]:
# Imports
import torch
from torch import nn
from torch.nn import DataParallel
import torch.nn.functional as F
from dotenv import load_dotenv
import wandb
import math
from accelerate import Accelerator
from helpers.memory import check_memory

# Load environment variables from the local `secrets.env` file
load_dotenv('secrets.env')
# Device string; passed as `device_map` when the pretrained model is loaded below
main_device = 'cuda:0'
# Project helper from helpers.memory; presumably reports current memory usage — TODO confirm
check_memory()

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Tokenizer for the reference checkpoint: left padding, and BOS/EOS insertion explicitly turned off via kwargs
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False, padding_side = 'left')

# Reference pretrained OLMoE checkpoint, loaded in bfloat16 with `main_device` as the device map
# NOTE(review): trust_remote_code = True allows code shipped with the checkpoint repo to run locally — confirm it is actually required
model = AutoModelForCausalLM.from_pretrained(
    'allenai/OLMoE-1B-7B-0924',
    device_map = main_device, 
    torch_dtype = torch.bfloat16,
    trust_remote_code = True
)


In [22]:
from dataclasses import dataclass

@dataclass
class MoeConf:
    """
    Architecture hyperparameters for the from-scratch OlMoE model; defaults mirror base OlMoE.

    Every attribute carries a type annotation AND a default. A dataclass only turns annotated
    names into fields, and a field without a default may not follow one that has a default.
    """
    vocab_size: int = 50304              # Base OlMoE: 50304 (vocab size)
    D: int = 2048                        # Base OlMoE: 2048 (hidden state dimension)
    H: int = 16                          # Base OlMoE: 16 (number of attention heads)
    router_aux_loss_coef: float = 0.01   # Base OlMoE: 0.01
    n_experts: int = 64                  # Base OlMoE: 64
    top_k: int = 8                       # Base OlMoE: 8
    padding_idx: int = 1                 # Base OlMoE: 1 (index where padding gets mapped to)
    n_layers: int = 16                   # Base OlMoE: 16 (transformer layers)
    rms_norm_eps: float = 1e-05          # Base OlMoE: 1e-05
    rope_theta: float = 10000.0          # Base OlMoE: 10000.0 (base frequency for RoPE)
    max_position_embeddings: int = 4096  # Base OlMoE: 4096 (maximum sequence length for RoPE)
    intermediate_size: int = 1024        # Base OlMoE: 1024 (hidden width of each expert MLP)
    norm_topk_prob: bool = False         # Base OlMoE: False (top-k routing weights are not renormalized)

conf = MoeConf()

In [None]:
from transformers.loss.loss_utils import ForCausalLMLoss # Cross-entropy loss that handles label shifting
from transformers.modeling_flash_attention_utils import _flash_attention_forward # Flash attention forward
from transformers.activations import silu

class OlmoeRMSNorm(nn.Module):
    """
    Root-mean-square layer norm with a learned per-channel scale.
    Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L137-L154
    """
    def __init__(self, hidden_size, eps = 1e-5):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        # Statistics are computed in float32, then the result is cast back to the incoming dtype
        original_dtype = hidden_states.dtype
        upcast = hidden_states.to(torch.float32)
        mean_square = upcast.pow(2).mean(-1, keepdim=True)
        inverse_rms = torch.rsqrt(mean_square + self.variance_epsilon)
        normalized = (upcast * inverse_rms).to(original_dtype)
        return self.weight * normalized
    

class OlmoeAttention(nn.Module):
    """
    Multi-head causal self-attention with QK-norm and rotary position embeddings.

    The hidden state is projected to queries/keys/values, queries and keys are RMS-normalized,
    rotary embeddings are applied, and attention is computed with the selected backend.
    The output projection `o_proj` is owned by this module but is NOT applied in `forward`;
    the caller applies it to the returned tensor.

    Args:
        conf: Model configuration; `D`, `H` and `rms_norm_eps` are read from it.
        attn_method: One of 'normal' (explicit matmul + softmax), 'sdpa' (torch scaled_dot_product_attention)
            or 'fa2' (flash attention).

    Raises:
        ValueError: If `attn_method` is not supported or `D` is not divisible by `H`.
    """
    ATTN_METHODS = ('normal', 'sdpa', 'fa2')

    def __init__(self, conf: "MoeConf", attn_method: str):
        super().__init__()
        if attn_method not in self.ATTN_METHODS:
            raise ValueError(f"attn_method must be one of {self.ATTN_METHODS}, got {attn_method!r}")
        if conf.D % conf.H != 0:
            raise ValueError(f"Hidden dim D={conf.D} must be divisible by the number of heads H={conf.H}")

        self.attn_method = attn_method
        self.D = conf.D # Hidden state dim
        self.H = conf.H # Num of attention heads
        self.Dh = conf.D // conf.H # Dimensions per head

        # Initialize attention layers - no biases following OlMoE architecture https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L318-L325
        self.q_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.k_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.v_proj = nn.Linear(self.D, self.H * self.Dh, bias = False)
        self.o_proj = nn.Linear(self.D, self.D, bias = False)
        # Keys use as many heads as queries here (no grouped-query attention), so both norms span H * Dh = D channels
        self.q_norm = OlmoeRMSNorm(self.D, eps = conf.rms_norm_eps)
        self.k_norm = OlmoeRMSNorm(self.H * self.Dh, eps = conf.rms_norm_eps)

    # See https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py#L223-L255
    def apply_rotary_pos_emb(self, q, k, cos, sin, unsqueeze_dim = 1):
        """
        Rotate `q` and `k` by the rotary position embeddings `cos` / `sin`.
        `cos` and `sin` get an extra axis at `unsqueeze_dim` so they broadcast over the head axis.
        """
        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2 :]
            return torch.cat((-x2, x1), dim=-1)

        cos = cos.unsqueeze(unsqueeze_dim)
        sin = sin.unsqueeze(unsqueeze_dim)
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        return q_embed, k_embed

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: tuple[torch.Tensor, torch.Tensor]):
        """
        Args:
            hidden_state: B x N x D input.
            attention_mask: Backend-specific mask.
                'normal': additive mask broadcastable to B x H x N x N (0 = keep, -inf = block); None builds a causal mask.
                'sdpa': ignored, attention is always purely causal.
                'fa2': forwarded unchanged to `_flash_attention_forward`.
            position_embeddings: (cos, sin) pair consumed by `apply_rotary_pos_emb`.

        Returns:
            B x N x D attention output, BEFORE the `o_proj` output projection.
        """
        B, N , D = hidden_state.shape

        # Project, QK-norm over the full D channels, then split into heads -> B x H x N x Dh
        query_state = self.q_norm(self.q_proj(hidden_state)).view(B, N, self.H, self.Dh).transpose(1, 2)
        key_state = self.k_norm(self.k_proj(hidden_state)).view(B, N, self.H, self.Dh).transpose(1, 2)
        value_state = self.v_proj(hidden_state).view(B, N, self.H, self.Dh).transpose(1, 2)

        cos, sin = position_embeddings
        query_state, key_state = self.apply_rotary_pos_emb(query_state, key_state, cos, sin)

        if self.attn_method == 'normal':
            attn_weights = torch.matmul(query_state, key_state.transpose(2, 3))/math.sqrt(self.Dh)  # B x H x N x N
            if attention_mask is None:
                # No mask supplied: fall back to a plain causal mask (-inf strictly above the diagonal)
                future_positions = torch.ones((N, N), dtype = torch.bool, device = attn_weights.device).triu(diagonal = 1)
                attention_mask = torch.zeros((N, N), dtype = attn_weights.dtype, device = attn_weights.device).masked_fill(future_positions, float('-inf'))
            attn_weights = attn_weights + attention_mask
            attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(query_state.dtype)
            attn_output = torch.matmul(attn_weights, value_state) # B x H x N x Dh
            attn_output = attn_output.transpose(1, 2).contiguous() # Reorder into B x N x H x Dh
            attn_output = attn_output.reshape(B, N, D) # Concatenate heads back into B x N x D
        elif self.attn_method == 'sdpa':
            # NOTE: `attention_mask` is not used here - only causal masking is applied
            attn_output = torch.nn.functional.scaled_dot_product_attention(query_state, key_state, value_state, dropout_p = 0.0, is_causal = True)
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(B, N, D)
        elif self.attn_method == 'fa2':
            # Flash attention takes B x N x H x Dh inputs
            query_state = query_state.transpose(1, 2)
            key_state = key_state.transpose(1, 2)
            value_state = value_state.transpose(1, 2)
            attn_output = _flash_attention_forward(
                query_state, key_state, value_state,
                attention_mask, N, dropout = 0.0, use_top_left_mask = False, is_causal = True
            )
            attn_output = attn_output.reshape(B, N, D).contiguous()
        else:
            raise ValueError(f"attn_method must be one of {self.ATTN_METHODS}, got {self.attn_method!r}")

        return attn_output
    
class OlmoeMLP(nn.Module):
    """
    A single expert: gated MLP computing `down_proj(act(gate_proj(x)) * up_proj(x))`, all without biases.

    Args:
        config: Either a Hugging Face style config (`hidden_size`, `intermediate_size`, `hidden_act`)
            or this notebook's `MoeConf` (`D`, optionally `intermediate_size`).
            A missing `intermediate_size` falls back to 1024 and a missing `hidden_act` to 'silu'
            (the base OlMoE values).
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Accept both HF attribute names and the MoeConf names
        self.hidden_size = config.hidden_size if hasattr(config, 'hidden_size') else config.D
        self.intermediate_size = getattr(config, 'intermediate_size', 1024)
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)

        hidden_act = getattr(config, 'hidden_act', 'silu')
        if hidden_act == 'silu':
            # Default activation needs no lookup table, so this class does not depend on a later cell's import
            self.act_fn = nn.SiLU()
        else:
            # Local import: only non-default activations need the transformers registry
            from transformers.activations import ACT2FN
            self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        """Apply the gated MLP over the last dimension of `x`; the output has the same shape as `x`."""
        gated = self.act_fn(self.gate_proj(x)) * self.up_proj(x)
        return self.down_proj(gated)
    
class OlmoeMoe(nn.Module):
    """
    Sparse mixture-of-experts MLP: a linear router picks `top_k` experts per token and the expert
    outputs are summed, weighted by the router's softmax probabilities.

    Args:
        config: Either a Hugging Face style config (`num_experts`, `num_experts_per_tok`, `hidden_size`,
            `norm_topk_prob`) or this notebook's `MoeConf` (`n_experts`, `top_k`, `D`).
            A missing `norm_topk_prob` is treated as False (top-k weights are not renormalized).
    """
    def __init__(self, config):
        super().__init__()
        # Accept both HF attribute names and the MoeConf names
        self.num_experts = config.num_experts if hasattr(config, 'num_experts') else config.n_experts
        self.top_k = config.num_experts_per_tok if hasattr(config, 'num_experts_per_tok') else config.top_k
        self.norm_topk_prob = getattr(config, 'norm_topk_prob', False)
        hidden_size = config.hidden_size if hasattr(config, 'hidden_size') else config.D
        self.gate = nn.Linear(hidden_size, self.num_experts, bias=False)
        self.experts = nn.ModuleList([OlmoeMLP(config) for _ in range(self.num_experts)])

    def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
            hidden_states: (batch, sequence_length, hidden_dim) input.

        Returns:
            final_hidden_states: (batch, sequence_length, hidden_dim) weighted sum of the selected experts' outputs.
            router_logits: (batch * sequence_length, n_experts) raw router scores, needed for the load balancing loss.
        """
        batch_size, sequence_length, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        # router_logits: (batch * sequence_length, n_experts)
        router_logits = self.gate(hidden_states)

        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
        routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
        # we cast back to the input dtype
        routing_weights = routing_weights.to(hidden_states.dtype)

        final_hidden_states = torch.zeros(
            (batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
        )

        # One hot encode the selected experts to create an expert mask of shape (n_experts, top_k, tokens)
        # this will be used to easily index which expert is going to be selected
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)

        # Loop over all available experts in the model and perform the computation on each expert
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            # idx = which top-k slot picked this expert, top_x = which token it was picked for
            idx, top_x = torch.where(expert_mask[expert_idx])

            # Index the correct hidden states and compute the expert hidden state for
            # the current expert. We need to make sure to multiply the output hidden
            # states by `routing_weights` on the corresponding tokens (for that token's top-k slot)
            current_state = hidden_states[None, top_x].reshape(-1, hidden_dim)
            current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]

            # However `index_add_` only support torch tensors for indexing so we'll use
            # the `top_x` tensor here.
            final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
        final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
        return final_hidden_states, router_logits
    
class OlmoeBlock(nn.Module):
    """
    A single transformer layer: pre-norm self-attention followed by a pre-norm sparse MoE MLP,
    each added back onto the residual stream.

    Args:
        conf: Model configuration; `D` and `rms_norm_eps` are read here, the rest is passed on to the submodules.
        layer_idx: Index of this layer within the model (stored for reference only).
    """
    def __init__(self, conf: "MoeConf", layer_idx: int):
        super().__init__()
        self.D = conf.D
        self.layer_idx = layer_idx
        self.self_attn = OlmoeAttention(conf, attn_method = 'fa2')
        self.mlp = OlmoeMoe(conf)
        self.input_layernorm = OlmoeRMSNorm(conf.D, eps = conf.rms_norm_eps)
        self.post_attention_layernorm = OlmoeRMSNorm(conf.D, eps = conf.rms_norm_eps)

    def forward(self, hidden_state: torch.Tensor, attention_mask: torch.Tensor, position_ids: torch.LongTensor, position_embeddings: tuple[torch.Tensor, torch.Tensor]):
        """
        Args:
            hidden_state: B x N x D residual stream entering the layer.
            attention_mask: Mask handed unchanged to the attention module.
            position_ids: Accepted for interface compatibility; not used, positions arrive via `position_embeddings`.
            position_embeddings: (cos, sin) rotary embeddings shared across layers.

        Returns:
            hidden_state: B x N x D residual stream leaving the layer.
            router_logits: Router scores returned by the MoE module, needed for the load balancing loss.
        """
        ### Pre-SA Residual Stream + Norm ###
        residual = hidden_state
        hidden_state = self.input_layernorm(hidden_state)

        ### Self-attention ###
        attn_output = self.self_attn(hidden_state, attention_mask, position_embeddings)

        ### Post-SA linear layer + Sum to Residual Stream ###
        # The attention module returns the pre-projection output, so the output projection is applied here
        attn_output = self.self_attn.o_proj(attn_output)
        hidden_state = residual + attn_output

        ### Pre-MLP Residual Stream + Norm ###
        residual = hidden_state
        hidden_state = self.post_attention_layernorm(hidden_state)

        ### MLP ###
        # Routing, per-expert computation and the weighted sum all live in the MoE module
        mlp_output, router_logits = self.mlp(hidden_state)

        ### Post-MLP Sum to Residual Stream ###
        hidden_state = residual + mlp_output

        return hidden_state, router_logits


class OlmoeModel(nn.Module):
    """
    The top level model object. Also handles weight initialization and loss calculations.

    Args:
        conf: Model configuration. Reads `vocab_size`, `D`, `H`, `padding_idx`, `n_layers`, `rms_norm_eps`,
            `rope_theta`, `n_experts`, `top_k` and `router_aux_loss_coef`.
    """
    def __init__(self, conf: "MoeConf"):
        super().__init__()
        self.conf = conf

        ### Layers ###
        self.embed_tokens = nn.Embedding(conf.vocab_size, conf.D, conf.padding_idx)
        self.layers = nn.ModuleList(
            [OlmoeBlock(conf, layer_idx) for layer_idx in range(conf.n_layers)]
        )
        self.norm = OlmoeRMSNorm(conf.D, eps = conf.rms_norm_eps)
        self.lm_head = nn.Linear(conf.D, conf.vocab_size, bias = False)

        ### Init ###
        self.apply(self._init_weights)

    # OlMoE weight initiation - see https://github.com/huggingface/transformers/blob/8f1509a96c96747c893051ac947795cfb0750357/src/transformers/modeling_utils.py#L2500-L2515
    # Normal distribution for linear layers + embeddings
    def _init_weights(self, module):
        """Initialize one submodule; meant to be passed to `nn.Module.apply`."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean = 0.0, std = 0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean = 0.0, std = 0.02)
            # Re-zero the padding row that normal_() just overwrote
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @torch.no_grad()
    def rotary_emb(self, x: torch.Tensor, position_ids: torch.LongTensor):
        """
        Compute rotary position embeddings shared by all layers.

        Args:
            x: Any tensor whose device and dtype the result should match.
            position_ids: (B or 1) x N integer positions.

        Returns:
            (cos, sin), each of shape position_ids.shape + (D // H,).
        """
        head_dim = self.conf.D // self.conf.H
        # Frequencies are recomputed in float32 on every call, so casting the model to a low precision dtype cannot degrade them
        exponents = torch.arange(0, head_dim, 2, dtype = torch.float32, device = x.device) / head_dim
        inv_freq = 1.0 / (self.conf.rope_theta ** exponents)
        freqs = position_ids[:, :, None].to(device = x.device, dtype = torch.float32) * inv_freq[None, None, :]
        emb = torch.cat((freqs, freqs), dim = -1)
        return emb.cos().to(x.dtype), emb.sin().to(x.dtype)

    # Copied from https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py
    @staticmethod
    def load_balancing_loss_func(gate_logits, num_experts, top_k, attention_mask):
        """
        Auxiliary load balancing loss over the router logits of all layers.

        Args:
            gate_logits: Sequence with one (tokens, num_experts) router logit tensor per layer.
            num_experts: Number of experts per layer.
            top_k: Number of experts selected per token.
            attention_mask: Optional batch x sequence mask (1 = real token, 0 = padding); padding is excluded when given.

        Returns:
            Scalar tensor: sum over experts of (fraction of tokens routed) * (mean router probability), times `num_experts`.
        """
        compute_device = gate_logits[0].device
        concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim = 0)
        routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1)
        _, selected_experts = torch.topk(routing_weights, top_k, dim=-1)
        expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts)
        if attention_mask is None:
            # Compute the percentage of tokens routed to each experts
            tokens_per_expert = torch.mean(expert_mask.float(), dim=0)
            # Compute the average probability of routing to these experts
            router_prob_per_expert = torch.mean(routing_weights, dim=0)
        else:
            batch_size, sequence_length = attention_mask.shape
            num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length)
            # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask
            expert_attention_mask = (attention_mask[None, :, :, None, None].expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)).reshape(-1, top_k, num_experts).to(compute_device))
            # Compute the percentage of tokens routed to each experts
            tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum(expert_attention_mask, dim=0)
            # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert
            router_per_expert_attention_mask = (attention_mask[None, :, :, None].expand((num_hidden_layers, batch_size, sequence_length, num_experts)).reshape(-1, num_experts).to(compute_device))
            # Compute the average probability of routing to these experts
            router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum(router_per_expert_attention_mask, dim=0)
        overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0))
        return overall_loss * num_experts

    def forward(self, input_ids: torch.LongTensor = None, attention_mask: torch.Tensor = None, position_ids: torch.LongTensor = None, labels: torch.LongTensor = None):
        """
        Run the model and compute the training loss.

        Args:
            input_ids: B x N token ids.
            attention_mask: Optional B x N mask (1 = real token, 0 = padding).
            position_ids: Optional (B or 1) x N positions; defaults to 0..N-1.
            labels: Optional B x N target ids with -100 at ignored positions; defaults to `input_ids`
                with padding tokens replaced by -100. Label shifting is handled by the loss function.

        Returns:
            dict with `loss` (base + router_aux_loss_coef * aux), `base_loss`, `aux_loss` and `logits` (B x N x vocab_size).
        """
        embeds_output = self.embed_tokens(input_ids)
        B, N, D = embeds_output.shape

        if position_ids is None:
            position_ids = torch.arange(0, N, device = embeds_output.device).unsqueeze(0)
        # Flash attention mask: only passed on when there is actual padding to mask out
        causal_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        position_embeddings = self.rotary_emb(embeds_output, position_ids) # Position embeddings to be shared across the decoder layers

        hidden_state = embeds_output

        # Now iterate through the layers
        all_router_logits = () # Save router logits from each layer into this; will be needed for balancing loss
        for layer in self.layers:
            hidden_state, router_logits = layer(hidden_state, causal_mask, position_ids, position_embeddings)
            all_router_logits += (router_logits, ) # Also save router logits from this layer

        hidden_state = self.norm(hidden_state)
        output_logits = self.lm_head(hidden_state)

        ##### Calculate Loss #####
        # The labels object should be a tensor of token IDs or -100 (for padding, since don't want to calculate loss for those)
        if labels is None:
            labels = input_ids.masked_fill(input_ids == self.conf.padding_idx, -100)
        # Get regular loss
        base_loss = ForCausalLMLoss(output_logits, labels, self.conf.vocab_size)
        # Get load balancing loss
        aux_loss = self.load_balancing_loss_func(gate_logits = all_router_logits, num_experts = self.conf.n_experts, top_k = self.conf.top_k, attention_mask = attention_mask)
        # Get total loss = regular loss + router_aux_loss_coef * load bal loss
        loss = base_loss + self.conf.router_aux_loss_coef * aux_loss

        return {'loss': loss, 'base_loss': base_loss, 'aux_loss': aux_loss, 'logits': output_logits}



In [None]:
from transformers.activations import ACT2FN

# Scratch check: display the entry registered under the 'silu' key of the activation lookup table
ACT2FN['silu']

In [None]:
# Scratch check: display the `silu` object imported from transformers.activations in the model cell
silu

In [None]:
# Scratch check: display the tokenizer's padding token id (MoeConf.padding_idx is documented as 1 for base OlMoE)
tokenizer.pad_token_id


In [None]:
from transformers.loss.loss_utils import ForCausalLMLoss

# Scratch check: display the loss function used to compute the base loss in OlmoeModel.forward
ForCausalLMLoss