In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from dotenv import load_dotenv
import wandb
import math
from helpers.memory import check_memory, profile_memory
from helpers.logging import get_gradient_stats
from helpers.moe_utils import check_cosine_similarity
from helpers.dataset import load_shard_as_dataloader_mp
from dataclasses import dataclass, asdict
import time
from collections import defaultdict
import os
import glob 
import json
from datetime import datetime
from transformers import AutoTokenizer

from config import ModelConf, TrainConf, OrthoMappingConf
from moe import OlmoeModel
from train import train


check_memory()

Device 0: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 1: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 2: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 3: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB



In [2]:
RUN_NAME = 'test'
RUN_NOTES = 'Baseline test without cosine loss'
save_dir = 'test'

model_conf = ModelConf(
    D = 768, 
    H = 12,
    I = 768*4,
    n_experts = 16,
    n_shared_experts = 0,
    top_k = 2,
    norm_topk_prob = False,
    n_layers = 12,
    max_position_embeddings = 1024,
    main_device = 'cuda:0'
)

train_conf = TrainConf(
    micro_batch_size = 8,
    accumulation_steps = 2,
    seq_len = 1024, 
    use_lflb = False
)

or_conf = OrthoMappingConf(
    is_gate_orthogonal_init = False,
    is_freeze_gate_weights = False,
    router_cos_loss_coef = 0,
    expert_cos_loss_coef = 0.01

)

seed = 1234

In [3]:
""" 
Let's load the model
- Set the default_device to specify where all the non-expert layers live (the experts are moved on model init)
- Set the default_dtype to specify the model dtype, all params will be in this dtype except for this explicitly specified differently in class definition
  - In the default OlMoE, RMSNorm is required to be f32 whereas all other params are bf16. 
"""
# torch.set_default_device(conf.main_device) # This is buggy, don't use
torch.set_default_dtype(torch.bfloat16)
torch.set_float32_matmul_precision('medium') # See https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html 
torch.manual_seed(seed)

model = OlmoeModel(
    model_conf,
    or_conf,
    primary_device = model_conf.main_device, # Where to store dense layers and shared experts
    expert_device_map = [model_conf.main_device] * model_conf.n_experts #=, here let's test them with all of them on cuda:0
)
model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
check_memory()

tokenizer_config.json:   0%|          | 0.00/5.37k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.12M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/65.0 [00:00<?, ?B/s]

Total parameters: 1,464,718,080
Device 0: NVIDIA H200
  Allocated: 2.77 GB
  Reserved: 2.98 GB
  Total: 139.83 GB

Device 1: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 2: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 3: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB



In [4]:
""" 
Let's load a forward pass with a batch size of 2, to make sure the model is able to run
- If you have multiple working forward methods, this is a good chance to test them for equality
"""
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)
prompt = ['I am a dog and I like to eat. My favorite food is', 'My cat is']
inputs = tokenizer(prompt, truncation = True, max_length = 128, padding = 'max_length', return_tensors = 'pt').to(model_conf.main_device)

with torch.no_grad():
    output = model(inputs['input_ids'], inputs['attention_mask'], moe_method = 'forward_slow', use_lflb = True, use_checkpointing = False)
    # output = model(inputs['input_ids'], inputs['attention_mask'], moe_method = 'forward_fast', use_lflb = True, use_checkpointing = False)
    # output = model(inputs['input_ids'], inputs['attention_mask'], moe_method = 'forward_async', use_lflb = True, use_checkpointing = False)

output_ids = torch.argmax(output['logits'][:, :, :], dim = 2)
for i in range(output_ids.size(0)):
    idx = inputs["attention_mask"].sum(dim = -1)[i].item() - 1 # get length of attention mask to find the last non-mask output token ix
    print(tokenizer.decode(output_ids[i, idx], skip_special_tokens=True))

W0226 14:14:48.270000 143420 torch/_dynamo/variables/tensor.py:869] [0/0] Graph break from `Tensor.item()`, consider setting:
W0226 14:14:48.270000 143420 torch/_dynamo/variables/tensor.py:869] [0/0]     torch._dynamo.config.capture_scalar_outputs = True
W0226 14:14:48.270000 143420 torch/_dynamo/variables/tensor.py:869] [0/0] or:
W0226 14:14:48.270000 143420 torch/_dynamo/variables/tensor.py:869] [0/0]     env TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1
W0226 14:14:48.270000 143420 torch/_dynamo/variables/tensor.py:869] [0/0] to include these operations in the captured graph.
W0226 14:14:48.270000 143420 torch/_dynamo/variables/tensor.py:869] [0/0] 
W0226 14:14:48.270000 143420 torch/_dynamo/variables/tensor.py:869] [0/0] Graph break: from user code at:
W0226 14:14:48.270000 143420 torch/_dynamo/variables/tensor.py:869] [0/0]   File "/workspace/interpretable-moes/experiments_cli/current/moe.py", line 689, in forward
W0226 14:14:48.270000 143420 torch/_dynamo/variables/tensor.py:869] [0/0]   

torch.Size([256, 16, 768])
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3579, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_143420/2449217069.py", line 12, in <module>
    output = model(inputs['input_ids'], inputs['attention_mask'], moe_method = 'forward_slow', use_lflb = True, use_checkpointing = False)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", 

In [None]:
# reload expert_cos_loss_func
import importlib
from helpers import moe_utils
importlib.reload(moe_utils)
from helpers.moe_utils import expert_cos_loss_func
start_time = time.time()
mean_loss, layer_losses = expert_cos_loss_func(model, model_conf)
compute_time = time.time() - start_time

# validate results
print(f"running time: {compute_time:.4f} seconds")
print(f"average loss: {mean_loss.item():.4f}")
print("layer losses:", [f"{l.item():.4f}" for l in layer_losses])

# basic assertions
assert isinstance(mean_loss, torch.Tensor)
assert len(layer_losses) == model_conf.n_layers
assert all(isinstance(l, torch.Tensor) for l in layer_losses)
assert mean_loss == torch.mean(torch.stack(layer_losses))

In [15]:
from helpers import moe_utils
import importlib
importlib.reload(moe_utils)
from helpers.moe_utils import gap_loss_func

def test_gap_loss_func():
    
    torch.manual_seed(42)
    
    # 测试基本情况（无attention_mask）
    gate_logits = [
        torch.tensor([[10., 0., 0.], [0., 10., 0.]]),  # 层1：明确选择专家0和1
        torch.tensor([[0., 10., 0.], [0., 0., 10.]])   # 层2：明确选择专家1和2
    ]
    loss = gap_loss_func(gate_logits, top_k=1, attention_mask=None)
    print(loss)
    # assert torch.allclose(loss, torch.tensor(0.0), atol=1e-4)

    # 测试有attention_mask的情况
    mask = torch.tensor([[1, 0], [1, 1]])  # 第二个样本的第一个token被mask
    gate_logits = [
        torch.tensor([[1., 1., 1.], [1., 1., 1.], [1., 1., 1.], [1., 1., 1.]]),  # 模拟batch_size=2, seq_len=2
    ]
    loss_masked = gap_loss_func(gate_logits, top_k=1, attention_mask=mask)
    
    # 计算期望值：有效token数为3（mask掉1个）
    routing_weights = torch.softmax(torch.cat(gate_logits), -1)
    top2_vals, _ = torch.topk(routing_weights, k=2, dim=-1)
    expected_loss = -torch.log(top2_vals[:,0] - top2_vals[:,1] + 1e-8)[[0,2,3]].mean()
    assert torch.allclose(loss_masked, expected_loss), "掩码处理不正确"

    # 测试均匀分布时的损失值
    uniform_logits = [torch.zeros(4, 3)]  # 均匀分布logits
    uniform_loss = gap_loss_func(uniform_logits, top_k=1, attention_mask=None)
    print(uniform_loss)

    # 测试不同top_k值
    specific_logits = [torch.tensor([[4.,3.]])]  # top_k=2时差距应为1.0
    loss_k2 = gap_loss_func(specific_logits, top_k=1, attention_mask=None)
    expected = -math.log(1.0 + 1e-8)
    print(loss_k2)
    print(expected)



test_gap_loss_func()


tensor(0.0001)
tensor(18.4207)
tensor(0.7719)
-9.999999889225291e-09


NameError: name 'uniform_loss' is not defined

# some data
Default setting: 


    1.02 vs 0.0027
    D 768-> 768*4 : 0.530 vs 0.001
    n_experts 30-> 120 : 4.305 vs 0.010

    D_768 -> 768/4 : 2.166 vs 0.0054
    n_experts 30-> 8: 0.261 vs 0.00089

    D 768-> 768*4 and n_experts 30-> 120: 2.151 vs 0.0052
    D 768-> 768/4 and n_experts 30-> 8:  0.476 vs 0.0018

linear dependency on # of experts
$O(1/\sqrt{d})$ dependency on dimension

For reference, the lm loss is ~ 11, and the aux loss is 2~4

In [5]:
val_dl = load_shard_as_dataloader_mp(
    './../../data/val_shard.json',
    tokenizer,
    batch_size = 32,
    seq_len = 2048,
    eos_seperator_id = tokenizer.eos_token_id
)

In [None]:
load_dotenv('./../../secrets.env')
wandb.login(key = os.getenv('WANDB_API_KEY'))
run = wandb.init(
    project = 'interpretable-moes', 
    name = RUN_NAME,
    notes = RUN_NOTES,
    config = {**asdict(model_conf), **asdict(train_conf), **asdict(or_conf)}
)

In [None]:
import train
importlib.reload(train)

train.train(model, tokenizer, train_conf, model_conf, or_conf, val_dl, seed, save_dir = 'test')
wandb.finish()