## Setup

In [None]:
# Imports

import torch
from torch.nn import DataParallel
import torch.nn.functional as F
from dotenv import load_dotenv
import wandb
import math
from accelerate import Accelerator
from helpers.memory import check_memory

# Load the key/value pairs from secrets.env into the process environment (python-dotenv)
load_dotenv('secrets.env')
# Device string that later cells use for the tokenized inputs and for collecting expert outputs
main_device = 'cuda:0'
# Helper from helpers.memory (implementation not visible here); presumably reports current memory usage - TODO confirm
check_memory()

In [None]:
# Test Multi-GPU Access

# Verify PyTorch can actually compute on each GPU
def check_communication():
    """
    Smoke-test every CUDA device visible to PyTorch.

    For each device index reported by `torch.cuda.device_count()`, a small tensor is
    created on that device, a tiny computation is run on it, and the result is copied
    back to the host and compared against the known answer. One status line is printed
    per device. Failures are printed rather than raised so the remaining devices are
    still checked. Prints nothing when no CUDA device is visible.

    Returns:
        None
    """
    for i in range(torch.cuda.device_count()):
        device = torch.device(f"cuda:{i}")
        try:
            x = torch.tensor([1.0, 2.0, 3.0], device = device)
            # Allocating a tensor alone does not prove the device can compute. Run an op and
            # read the result back with .item(), which waits for the device to finish.
            total = (x * 2.0).sum().item()
            if total != 12.0:
                raise RuntimeError(f"unexpected result {total} (expected 12.0)")
            print(f"GPU {i}: Computation successful.")
        except Exception as e:
            print(f"GPU {i}: Computation failed. Error: {e}")

# Run the per-GPU check; it prints one status line per visible device
check_communication()

# Initialize accelerator - will later use for training
accelerator = Accelerator(
    device_placement = True,
    mixed_precision = 'bf16', # bf16 mixed precision; the model in the next section is also loaded as torch.bfloat16
    gradient_accumulation_steps = 1, # Temp (placeholder value)
    split_batches = False
)

## Test Inference with Base HF Model

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Tokenizer for the OLMoE checkpoint: BOS/EOS insertion is switched off and padding (if any) goes on the right
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False, padding_side = 'right')

# Load the whole model onto main_device in bfloat16; a later cell moves the experts onto the other GPUs
model = AutoModelForCausalLM.from_pretrained(
    'allenai/OLMoE-1B-7B-0924',
    device_map = main_device, 
    torch_dtype = torch.bfloat16,
    # NOTE(review): this flag permits running code shipped with the checkpoint. Later cells cite
    # modeling_olmoe.py from the transformers repo, so it may be unnecessary here - confirm and drop if so
    trust_remote_code = True
)

# Helper from helpers.memory (implementation not visible here); presumably reports memory after the load - TODO confirm
check_memory()

In [None]:
# Split up experts but keep everything else on the main GPU 
# (to avoid having to constantly switch devices for other computations)
class DistributedExpertWrapper(torch.nn.Module):
    """
    Hosts one expert module on its own device while staying transparent to callers.

    The wrapped module is moved to `target_device` on construction. Each forward call
    copies the input over to that device, runs the expert there, and copies the result
    back to whichever device the input came from.

    Attributes:
        expert: the wrapped module, living on `target_device`.
        target_device: the device the expert was placed on, as passed by the caller.
    """
    def __init__(self, expert, target_device):
        super().__init__()
        self.target_device = target_device
        self.expert = expert.to(self.target_device)

    def forward(self, x):
        """Run the expert on its own device and return the result on the device `x` came from."""
        source_device = x.device
        x_on_expert_device = x.to(self.target_device)
        expert_result = self.expert(x_on_expert_device)
        return expert_result.to(source_device)


# Shard every layer's experts across the visible GPUs in contiguous, near-equal groups.
num_gpus = torch.cuda.device_count()
if num_gpus == 0:
    raise RuntimeError('No CUDA devices are visible; cannot distribute experts across GPUs.')

for layer_idx, layer in enumerate(model.model.layers):
    experts = layer.mlp.experts
    num_experts = len(experts)
    # Ceiling division: at least 1 even when there are more GPUs than experts (floor division gave 0
    # and a ZeroDivisionError below), and remainder experts land on the last GPUs rather than
    # wrapping back onto GPU 0. Identical to the floor result when the split is even.
    experts_per_gpu = (num_experts + num_gpus - 1) // num_gpus

    for expert_idx in range(num_experts):
        current = experts[expert_idx]
        # Keep the cell idempotent: wrapping an already-wrapped expert would hide the real expert
        # (and its gate_proj / up_proj / down_proj) behind a second wrapper. Compare by class name
        # because re-running the cell also re-creates the wrapper class, which defeats isinstance().
        if type(current).__name__ == 'DistributedExpertWrapper':
            continue

        target_gpu = expert_idx // experts_per_gpu # Always < num_gpus thanks to the ceiling division
        target_device = f"cuda:{target_gpu}"
        layer.mlp.experts[expert_idx] = DistributedExpertWrapper(
            current,
            target_device
        )

check_memory()

In [None]:
# Test with model.generate()
@torch.no_grad()
def eval_model_v1(model, tokenizer, prompt, max_new_tokens):
    """
    Greedy-decode a continuation of `prompt` through the model's own `generate` method.

    Params:
        model: object exposing `generate(**inputs, max_new_tokens, do_sample, eos_token_id)`.
        tokenizer: callable tokenizer exposing `eos_token_id` and `batch_decode`.
        prompt: text to continue.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The decoded text of the first (only) generated sequence. The raw generated ids
        are printed before decoding.
    """
    encoded_prompt = tokenizer(prompt, return_tensors = "pt").to(main_device)
    generation_options = {
        'max_new_tokens': max_new_tokens,
        'do_sample': False, # Greedy decoding
        'eos_token_id': [tokenizer.eos_token_id],
    }
    generated_ids = model.generate(**encoded_prompt, **generation_options)
    print(generated_ids)
    decoded_sequences = tokenizer.batch_decode(generated_ids)
    return decoded_sequences[0]

# Sanity check: generate a single new token for the test prompt and print the decoded text
print(eval_model_v1(
    model,
    tokenizer,
    'I am a dog and I like to eat. My favorite food is',
    1
))

In [None]:
# Test with token-by-token generation
@torch.no_grad()
def eval_model_v2(model, tokenizer, prompt, max_new_tokens):
    """
    Greedy-decode a continuation of `prompt` one token at a time with plain forward passes.

    Each step re-runs the model on the whole sequence so far (no cache), takes the
    highest-scoring token at the last position, and appends it. Decoding stops after
    `max_new_tokens` tokens, or right after a stop token (the tokenizer's EOS id or the
    id of "<|end|>") has been appended. Only a single sequence (batch size 1) is supported.

    Params:
        model: callable taking a (1, N) id tensor and returning a mapping with 'logits' of shape (1, N, vocab).
        tokenizer: callable tokenizer exposing `eos_token_id`, `convert_tokens_to_ids` and `batch_decode`.
        prompt: text to continue.
        max_new_tokens: upper bound on the number of generated tokens.

    Returns:
        The decoded text of prompt + generated tokens. Each chosen token and the final
        id tensor are printed along the way.
    """
    tokens = tokenizer(prompt, return_tensors = 'pt').to(main_device)['input_ids']
    # The stop ids never change between steps, so resolve them once instead of on every iteration
    stop_token_ids = {tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|end|>")}

    for _ in range(max_new_tokens):
        logits = model(tokens)['logits']
        # Index batch 0 / last position explicitly. squeeze() also dropped the sequence axis for
        # a one-token prompt, which made the old dim = 1 reduction fail. Softmax is monotonic, so
        # the argmax of the raw logits is the same token without the extra pass.
        output_token = torch.argmax(logits[0, -1, :])
        print(output_token)
        tokens = torch.cat((tokens, output_token.view(1, 1)), dim = 1)

        # The stop token itself stays in the output, it just ends the loop
        if output_token.item() in stop_token_ids:
            break

    print(tokens)
    return tokenizer.batch_decode(tokens)[0]

# Same prompt and token budget as the eval_model_v1 check, so the two printed outputs can be compared
print(eval_model_v2(
    model,
    tokenizer,
    'I am a dog and I like to eat. My favorite food is',
    1
))

## Reverse Engineer the Class

In [171]:
prompt = 'I am a dog and I like to eat. My favorite food is' # Correct next token output is 'steak'
# Tokenize once and keep the tensors on main_device; the cells below all reuse input_ids / attention_mask
inputs = tokenizer(prompt, return_tensors = 'pt').to(main_device)
input_ids = inputs['input_ids']
attention_mask = inputs['attention_mask']

In [None]:
# 1. Split up LM head + rest of model
# Runs the forward pass in two explicit stages: model.model (everything before the LM head), then model.lm_head
with torch.no_grad():

    # Everything before the LM head
    decoder_output = model.model(
        input_ids,
        attention_mask
    )['last_hidden_state']
    
    # The LM head
    output_logits = model.lm_head(decoder_output)

# Argmax token at every position of the first sequence; the last position is the next-token prediction
output_ids = torch.argmax(output_logits[0, :, :], dim = 1)
print(tokenizer.decode(output_ids[-1]))

In [None]:
# 2. Split out embeddings layer + separate decoder layers
with torch.no_grad():

    # Embedding layer
    embeds_output = model.model.embed_tokens(input_ids)
    B, N, D = embeds_output.shape # (batch, tokens, hidden dim)

    # Positions 0..N-1 (no cache is used), with a leading batch axis for position_ids
    cache_position = torch.arange(0, N, device = embeds_output.device)
    position_ids = cache_position.unsqueeze(0)

    # Private HF helper called positionally; the trailing None / False are presumably
    # past_key_values and output_attentions - TODO confirm against the installed transformers version
    causal_mask = model.model._update_causal_mask(attention_mask, embeds_output, cache_position, None, False)

    position_embeddings = model.model.rotary_emb(embeds_output, position_ids) # Position embeddings to be shared across the decoder layers

    hidden_state = embeds_output

    # Now iterate through the layers
    for i, layer in enumerate(model.model.layers):
        
        # We can ignore all the arguments related to caching/intermediate outputs
        layer_output = layer(
            hidden_state,
            causal_mask,
            position_ids,
            position_embeddings = position_embeddings
        )

        # First element of the layer output is the updated hidden state
        hidden_state = layer_output[0]
        
    # RMS Norm
    hidden_state = model.model.norm(hidden_state)
    
    # The LM head
    output_logits = model.lm_head(hidden_state)


# Argmax token at every position of the first sequence; the last position is the next-token prediction
output_ids = torch.argmax(output_logits[0, :, :], dim = 1)
print(tokenizer.decode(output_ids[-1]))


In [None]:
# 3. Further split out the transformer layers into SA module + MLP module
with torch.no_grad():

    # Embedding layer
    embeds_output = model.model.embed_tokens(input_ids)
    B, N, D = embeds_output.shape # (batch, tokens, hidden dim)

    # Positions 0..N-1 (no cache is used), with a leading batch axis for position_ids
    cache_position = torch.arange(0, N, device = embeds_output.device)
    position_ids = cache_position.unsqueeze(0)

    # Private HF helper called positionally; the trailing None / False are presumably
    # past_key_values and output_attentions - TODO confirm against the installed transformers version
    causal_mask = model.model._update_causal_mask(attention_mask, embeds_output, cache_position, None, False)

    position_embeddings = model.model.rotary_emb(embeds_output, position_ids) # Position embeddings to be shared across the decoder layers

    hidden_state = embeds_output

    # Now iterate through the layers
    for i, layer in enumerate(model.model.layers):
        
        # Keep the pre-norm stream so the attention output can be added back onto it
        residual = hidden_state

        hidden_state = layer.input_layernorm(hidden_state)
        ### SA ###
        # Unpacks three return values and keeps only the first (the attention output);
        # NOTE(review): the tuple length depends on the installed transformers version - confirm
        hidden_state, _, _ = layer.self_attn(
            hidden_state,
            attention_mask = causal_mask,
            position_ids = position_ids,
            position_embeddings = position_embeddings
        )
        hidden_state = residual + hidden_state

        ### MLP ###
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)
        # The MoE block returns its output together with the router logits (unused here)
        hidden_state, router_logits = layer.mlp(hidden_state)
        hidden_state = residual + hidden_state
                
    # RMS Norm
    hidden_state = model.model.norm(hidden_state)
    
    # The LM head
    output_logits = model.lm_head(hidden_state)


# Argmax token at every position of the first sequence; the last position is the next-token prediction
output_ids = torch.argmax(output_logits[0, :, :], dim = 1)
print(tokenizer.decode(output_ids[-1]))

In [None]:
# 4. Further split out the SA module into raw matmul components (or SDPA sub-module), and split out the MLP layer into router + individual expert operations

from helpers.olmoe import apply_rotary_pos_emb

# The model supports 3 attention implementations
# - normal = calculate attention normally with matrix operations
# - sdpa = use pytorch sdpa attention implementation (same result as normal but faster with fused kernel operations)
# - flash attention 2 = will get different results, so don't implement this
attention_method = ['normal', 'sdpa'][1] 

with torch.no_grad():

    # Embedding layer
    embeds_output = model.model.embed_tokens(input_ids)
    B, N, D = embeds_output.shape # (batch, tokens, hidden dim)

    cache_position = torch.arange(0, N, device = embeds_output.device)
    position_ids = cache_position.unsqueeze(0)
    # This is the upper-triangular matrix of infinities to mask future tokens in the attention softmax; only needed when attention_method = 'normal'
    causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(attention_mask, N, N, embeds_output.dtype, embeds_output.device, cache_position, B)
    position_embeddings = model.model.rotary_emb(embeds_output, position_ids) # Position embeddings to be shared across the decoder layers

    hidden_state = embeds_output

    # Now iterate through the layers
    for i, layer in enumerate(model.model.layers):

        ### Pre-SA Residual Stream + Norm ###
        residual = hidden_state
        hidden_state = layer.input_layernorm(hidden_state)
        
        ### Self-attention ###
        H = layer.self_attn.num_heads # Number of attention heads
        Dh = D // H # Dimensions per head (integer division, no float round-trip)
        
        # Project to B x N x D, split the last axis into H heads of Dh dims, then move heads forward: B x H x N x Dh
        query_state = layer.self_attn.q_norm(layer.self_attn.q_proj(hidden_state)).view(B, N, H, Dh).transpose(1, 2) # B x H x N x Dh
        key_state = layer.self_attn.k_norm(layer.self_attn.k_proj(hidden_state)).view(B, N, H, Dh).transpose(1, 2) # B x H x N x Dh
        value_state = layer.self_attn.v_proj(hidden_state).view(B, N, H, Dh).transpose(1, 2) # B x H x N x Dh

        # Rotary position embeddings are applied to queries and keys only
        cos, sin = position_embeddings
        query_state, key_state = apply_rotary_pos_emb(query_state, key_state, cos, sin)
        
        if attention_method == 'normal':
            # See OlMoeAttention class https://github.com/huggingface/transformers/blob/main/src/transformers/models/olmoe/modeling_olmoe.py
            attn_weights = torch.matmul(query_state, key_state.transpose(2, 3))/math.sqrt(Dh)  # Should be shape B x H x N x N
            attn_weights = attn_weights + causal_mask # Attention mask is upper triangular of negative infinity
            attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(query_state.dtype) # Softmax in fp32, then back to the working dtype
            attn_output = torch.matmul(attn_weights, value_state) # B x H x N x D/H
            attn_output = attn_output.transpose(1, 2).contiguous() # Reorder into B x N x H x D/H
            attn_output = attn_output.reshape(B, N, D) # Concatenate vertically back into B x N x D

        elif attention_method == 'sdpa':
            # See OlmoeSdpaAttention class
            # Don't pass causal mask at all, let it create it itself
            # NOTE: this means attention_mask is not applied on this path, which is fine for a single unpadded prompt
            attn_output = torch.nn.functional.scaled_dot_product_attention(query_state, key_state, value_state, dropout_p = 0.0, is_causal = True)
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(B, N, D)
        
        else:
            raise Exception('No')

        ### Post-SA linear layer + Sum to Residual Stream ###
        attn_output = layer.self_attn.o_proj(attn_output)
        hidden_state = residual + attn_output

        ### Pre-MLP Residual Stream + Norm ###
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)
        
        ### MLP ###
        TOP_K = layer.mlp.top_k # 8
        N_EXPERTS = layer.mlp.num_experts # 64

        hidden_state = hidden_state.view(B * N, D) # Flatten out B x N x D to BN x D (flattened token-level reps) to route all tokens separately
        router_logits = layer.mlp.gate(hidden_state) # Output BN x N_EXPERTS (routing logits for each token)
        routing_weights = F.softmax(router_logits, dim = 1, dtype = torch.float)

        # Below both routing_weights and selected_experts are of size BN x TOP_K (for each token, the selected TOP_K experts and corresponding weights)
        # Weights do NOT sum to 1 since we only top_k'd after the softmax
        routing_weights, selected_experts = torch.topk(routing_weights, TOP_K, dim = -1) 
        routing_weights = routing_weights.to(hidden_state.dtype)
        
        mlp_output = torch.zeros((B * N, D), dtype = hidden_state.dtype, device = hidden_state.device) # Initialize MLP output - later iterate through experts and sum onto this object
        # One hot encode - for each expert, which topk x token is active - e.g. expert_assignment_mask[0, :] will be 0s if the first expert is never chosen
        expert_assignment_mask = torch.nn.functional.one_hot(selected_experts, num_classes = N_EXPERTS).permute(2, 1, 0) # Creates (N_EXPERTS, TOP_K, BN)

        # Iterate through all the experts, apply each expert to the tokens where the expert is relevant, multiply output by the weights for the topk/token for that expert, then sum onto the mlp_output obj
        for expert_ix, expert_wrapper in enumerate(layer.mlp.experts):
            expert_device = expert_wrapper.target_device
            expert = expert_wrapper.expert

            # For this expert, gives the (topk, token) coordinates which uses the expert
            topk_slot, token_indices = torch.where(expert_assignment_mask[expert_ix, :])
            # No token routed to this expert: it contributes nothing, so skip the transfers and matmuls
            if token_indices.numel() == 0:
                continue

            # Get hidden states for tokens that use this expert - shape of num_assigned_tokens x D
            # Moved to the expert's device once and reused for both projections (previously transferred twice)
            tokens_for_expert = hidden_state[token_indices, :].to(expert_device)

            # Get expert output, multiply by routing weights
            expert_output = expert.down_proj(expert.act_fn(expert.gate_proj(tokens_for_expert)) * expert.up_proj(tokens_for_expert)) # Shape = num_assigned_tokens x D
            expert_output = expert_output.to(main_device) * routing_weights[token_indices, topk_slot].unsqueeze(1) # For each num_assigned_tokens, multiplies it by the corresponding weight in topk_slot for that token_index

            # Accumulate: a token receives one contribution from each of its TOP_K experts
            mlp_output.index_add_(0, token_indices, expert_output.to(hidden_state.dtype))

        mlp_output = mlp_output.reshape(B, N, D) # Convert back from BN x D -> B x N x D

        ### Post-MLP Sum to Residual Stream ###
        hidden_state = residual + mlp_output
                
    # RMS Norm
    hidden_state = model.model.norm(hidden_state)
    
    # LM head
    output_logits = model.lm_head(hidden_state)


# Argmax token at every position of the first sequence; the last position is the next-token prediction
output_ids = torch.argmax(output_logits[0, :, :], dim = 1)
print(tokenizer.decode(output_ids[-1]))

## Add back in the loss functions

In [None]:
# Manual forward pass (embeddings -> per-layer attention + MoE MLP -> final norm -> LM head), this time
# with the explicit matmul attention path selected. Relies on apply_rotary_pos_emb imported in an earlier cell.
attention_method = 'normal'

with torch.no_grad():

    # Embedding layer
    embeds_output = model.model.embed_tokens(input_ids)
    B, N, D = embeds_output.shape # (batch, tokens, hidden dim)

    cache_position = torch.arange(0, N, device = embeds_output.device)
    position_ids = cache_position.unsqueeze(0)
    # This is the upper-triangular matrix of infinities to mask future tokens in the attention softmax; only needed when attention_method = 'normal'
    causal_mask = model.model._prepare_4d_causal_attention_mask_with_cache_position(attention_mask, N, N, embeds_output.dtype, embeds_output.device, cache_position, B)
    position_embeddings = model.model.rotary_emb(embeds_output, position_ids) # Position embeddings to be shared across the decoder layers

    hidden_state = embeds_output

    # Now iterate through the layers
    for i, layer in enumerate(model.model.layers):

        ### Pre-SA Residual Stream + Norm ###
        residual = hidden_state
        hidden_state = layer.input_layernorm(hidden_state)
        
        ### Self-attention ###
        H = layer.self_attn.num_heads # Number of attention heads
        Dh = D // H # Dimensions per head (integer division, no float round-trip)
        
        # Project to B x N x D, split the last axis into H heads of Dh dims, then move heads forward: B x H x N x Dh
        query_state = layer.self_attn.q_norm(layer.self_attn.q_proj(hidden_state)).view(B, N, H, Dh).transpose(1, 2) # B x H x N x Dh
        key_state = layer.self_attn.k_norm(layer.self_attn.k_proj(hidden_state)).view(B, N, H, Dh).transpose(1, 2) # B x H x N x Dh
        value_state = layer.self_attn.v_proj(hidden_state).view(B, N, H, Dh).transpose(1, 2) # B x H x N x Dh

        # Rotary position embeddings are applied to queries and keys only
        cos, sin = position_embeddings
        query_state, key_state = apply_rotary_pos_emb(query_state, key_state, cos, sin)
        
        if attention_method == 'normal':
            attn_weights = torch.matmul(query_state, key_state.transpose(2, 3))/math.sqrt(Dh)  # Should be shape B x H x N x N
            attn_weights = attn_weights + causal_mask # Attention mask is upper triangular of negative infinity
            attn_weights = F.softmax(attn_weights, dim = -1, dtype = torch.float32).to(query_state.dtype) # Softmax in fp32, then back to the working dtype
            attn_output = torch.matmul(attn_weights, value_state) # B x H x N x D/H
            attn_output = attn_output.transpose(1, 2).contiguous() # Reorder into B x N x H x D/H
            attn_output = attn_output.reshape(B, N, D) # Concatenate vertically back into B x N x D
        elif attention_method == 'sdpa':
            # NOTE: is_causal = True builds the causal mask internally; attention_mask is not applied on this path
            attn_output = torch.nn.functional.scaled_dot_product_attention(query_state, key_state, value_state, dropout_p = 0.0, is_causal = True)
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(B, N, D)
        else:
            raise Exception('No')

        ### Post-SA linear layer + Sum to Residual Stream ###
        attn_output = layer.self_attn.o_proj(attn_output)
        hidden_state = residual + attn_output

        ### Pre-MLP Residual Stream + Norm ###
        residual = hidden_state
        hidden_state = layer.post_attention_layernorm(hidden_state)
        
        ### MLP ###
        TOP_K = layer.mlp.top_k # 8
        N_EXPERTS = layer.mlp.num_experts # 64

        hidden_state = hidden_state.view(B * N, D) # Flatten out B x N x D to BN x D (flattened token-level reps) to route all tokens separately
        router_logits = layer.mlp.gate(hidden_state) # Output BN x N_EXPERTS (routing logits for each token)
        routing_weights = F.softmax(router_logits, dim = 1, dtype = torch.float)

        # Below both routing_weights and selected_experts are of size BN x TOP_K (for each token, the selected TOP_K experts and corresponding weights)
        # Weights do NOT sum to 1 since we only top_k'd after the softmax
        routing_weights, selected_experts = torch.topk(routing_weights, TOP_K, dim = -1) 
        routing_weights = routing_weights.to(hidden_state.dtype)

        mlp_output = torch.zeros((B * N, D), dtype = hidden_state.dtype, device = hidden_state.device) # Initialize MLP output - later iterate through experts and sum onto this object
        # One hot encode - for each expert, which topk x token is active - e.g. expert_assignment_mask[0, :] will be 0s if the first expert is never chosen
        expert_assignment_mask = torch.nn.functional.one_hot(selected_experts, num_classes = N_EXPERTS).permute(2, 1, 0) # Creates (N_EXPERTS, TOP_K, BN)
        
        # Iterate through all the experts, apply each expert to the tokens where the expert is relevant, multiply output by the weights for the topk/token for that expert, then sum onto the mlp_output obj
        for expert_ix, expert_wrapper in enumerate(layer.mlp.experts):
            expert_device = expert_wrapper.target_device
            expert = expert_wrapper.expert

            # For this expert, gives the (topk, token) coordinates which uses the expert
            topk_slot, token_indices = torch.where(expert_assignment_mask[expert_ix, :])
            # No token routed to this expert: it contributes nothing, so skip the transfers and matmuls
            if token_indices.numel() == 0:
                continue

            # Get hidden states for tokens that use this expert - shape of num_assigned_tokens x D
            # Moved to the expert's device once and reused for both projections (previously transferred twice)
            tokens_for_expert = hidden_state[token_indices, :].to(expert_device)

            # Get expert output, multiply by routing weights
            expert_output = expert.down_proj(expert.act_fn(expert.gate_proj(tokens_for_expert)) * expert.up_proj(tokens_for_expert)) # Shape = num_assigned_tokens x D
            expert_output = expert_output.to(main_device) * routing_weights[token_indices, topk_slot].unsqueeze(1) # For each num_assigned_tokens, multiplies it by the corresponding weight in topk_slot for that token_index

            # Accumulate: a token receives one contribution from each of its TOP_K experts
            mlp_output.index_add_(0, token_indices, expert_output.to(hidden_state.dtype))

        mlp_output = mlp_output.reshape(B, N, D) # Convert back from BN x D -> B x N x D

        ### Post-MLP Sum to Residual Stream ###
        hidden_state = residual + mlp_output
                
    # RMS Norm
    hidden_state = model.model.norm(hidden_state)
    
    # LM head
    output_logits = model.lm_head(hidden_state)


# Argmax token at every position of the first sequence; the last position is the next-token prediction
output_ids = torch.argmax(output_logits[0, :, :], dim = 1)
print(tokenizer.decode(output_ids[-1]))
