In [1]:
# All notebook imports live here (stdlib -> third-party -> project-local) so that
# Restart & Run All never depends on an import buried in a later cell.

# Standard library
import glob
import importlib
import json
import math
import os
import time
from collections import defaultdict
from dataclasses import dataclass, asdict
from datetime import datetime

# Third-party
import torch
import torch.nn.functional as F
import wandb
from dotenv import load_dotenv
from torch import nn
from transformers import AutoTokenizer

# Project-local (original relative order kept)
from helpers.memory import check_memory, profile_memory
from helpers.logging import get_gradient_stats
from helpers.moe_utils import check_cosine_similarity
from helpers.dataset import load_shard_as_dataloader_mp
from config import ModelConf, TrainConf, OrthoMappingConf
from moe import OlmoeModel
from train import train


# Baseline per-device GPU memory report before anything is allocated
check_memory()



Device 0: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 1: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 2: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 3: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB



In [35]:
RUN_NAME = 'test'
save_dir = 'test'

model_conf = ModelConf(
    D = 768,
    H = 12,
    I = 3072,
    n_experts = 16,
    n_shared_experts = 0,
    top_k = 2,
    norm_topk_prob = False,
    n_layers = 12,
    max_position_embeddings = 1024,
    main_device = 'cuda:2'
)

train_conf = TrainConf(
    micro_batch_size = 8,
    accumulation_steps = 2,
    seq_len = 1024,
    use_lflb = False
)

or_conf = OrthoMappingConf(
    is_gate_orthogonal_init = False,
    is_freeze_gate_weights = False,
    router_cos_loss_coef = 0,
    expert_cos_loss_coef = 0.01
)

seed = 1234

# Training sequences must fit inside the model's positional range.
assert train_conf.seq_len <= model_conf.max_position_embeddings, (
    f"train seq_len ({train_conf.seq_len}) exceeds max_position_embeddings "
    f"({model_conf.max_position_embeddings})"
)

# RUN_NOTES is logged to wandb. Derive it from the actual cosine-loss coefficients
# so the notes cannot contradict the config. A non-zero expert_cos_loss_coef means
# the cosine loss IS active.
if or_conf.router_cos_loss_coef or or_conf.expert_cos_loss_coef:
    RUN_NOTES = (
        'Baseline test with cosine loss '
        f'(router_cos_loss_coef = {or_conf.router_cos_loss_coef}, '
        f'expert_cos_loss_coef = {or_conf.expert_cos_loss_coef})'
    )
else:
    RUN_NOTES = 'Baseline test without cosine loss'

In [36]:
""" 
Let's load the model
- Set the default_device to specify where all the non-expert layers live (the experts are moved on model init)
- Set the default_dtype to specify the model dtype, all params will be in this dtype except for this explicitly specified differently in class definition
  - In the default OlMoE, RMSNorm is required to be f32 whereas all other params are bf16. 
"""
# torch.set_default_device(conf.main_device) # This is buggy, don't use
torch.set_default_dtype(torch.bfloat16)
torch.set_float32_matmul_precision('medium') # See https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html 
torch.manual_seed(seed)

model = OlmoeModel(
    model_conf,
    or_conf,
    primary_device = model_conf.main_device, # Where to store dense layers and shared experts
    expert_device_map = [model_conf.main_device] * model_conf.n_experts #=, here let's test them with all of them on cuda:0
)
model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
check_memory()

Total parameters: 1,464,718,080
Device 0: NVIDIA H200
  Allocated: 34.41 GB
  Reserved: 35.88 GB
  Total: 139.83 GB

Device 1: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 2: NVIDIA H200
  Allocated: 2.77 GB
  Reserved: 2.98 GB
  Total: 139.83 GB

Device 3: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB



In [25]:
# Sanity-check and time expert_cos_loss_func on the current model.
# Reload helpers.moe_utils first so edits to it are picked up without restarting the kernel.
import importlib
from helpers import moe_utils
importlib.reload(moe_utils)
from helpers.moe_utils import expert_cos_loss_func


def _sync_all_cuda():
    """Block until queued kernels on every visible GPU have finished (no-op without CUDA)."""
    if torch.cuda.is_available():
        for device_idx in range(torch.cuda.device_count()):
            torch.cuda.synchronize(device_idx)


# CUDA kernels run asynchronously, so synchronize around the call. Without this the
# timer only measures kernel-launch overhead. The model lives on model_conf.main_device,
# which need not be the current (default) GPU, so every device is synchronized.
_sync_all_cuda()
start_time = time.perf_counter()
mean_loss, layer_losses = expert_cos_loss_func(model, model_conf)
_sync_all_cuda()
compute_time = time.perf_counter() - start_time

# validate results
print(f"running time: {compute_time:.4f} seconds")
print(f"average loss: {mean_loss.item():.4f}")
print("layer losses:", [f"{l.item():.4f}" for l in layer_losses])

# basic assertions
assert isinstance(mean_loss, torch.Tensor)
assert len(layer_losses) == model_conf.n_layers
assert all(isinstance(l, torch.Tensor) for l in layer_losses)
assert torch.isfinite(mean_loss).all(), "expert cosine loss is not finite"
# Compare with a tolerance. Exact equality on floating-point (likely bf16) values is
# brittle whenever the mean is reduced in a different order or dtype.
torch.testing.assert_close(mean_loss, torch.stack(layer_losses).mean(), check_dtype = False)

running time: 0.0434 seconds
average loss: 2.9807
layer losses: ['3.0513', '3.0290', '2.9517', '2.9092', '3.0907', '2.9186', '2.9560', '2.9285', '3.0082', '3.0686', '2.9153', '2.9413']


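## Scaling measurements
Default setting (D = 768, n_experts = 30):


    1.02 vs 0.0027
    D 768 -> 768*4: 0.530 vs 0.001
    n_experts 30 -> 120: 4.305 vs 0.010

    D 768 -> 768/4: 2.166 vs 0.0054
    n_experts 30 -> 8: 0.261 vs 0.00089

    D 768 -> 768*4 and n_experts 30 -> 120: 2.151 vs 0.0052
    D 768 -> 768/4 and n_experts 30 -> 8: 0.476 vs 0.0018

**TODO**: state what the two numbers in each row measure.

- Roughly linear dependency on the number of experts:
  - ×4 experts gives ×4.2 (1.02 → 4.305).
  - 30 → 8 experts gives ×0.26 (1.02 → 0.261).
- $O(1/\sqrt{D})$ dependency on the model dimension $D$:
  - ×4 $D$ gives ×0.52 (1.02 → 0.530).
  - ÷4 $D$ gives ×2.1 (1.02 → 2.166).

For reference, the LM loss is ~11 and the aux loss is 2–4.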

In [26]:
# Validation batches must stay within the model's positional range (max_position_embeddings).
# The previous 2048 exceeded it (1024) and also differed from the training length.
# Validate at the training seq_len so train and val losses are comparable.
val_seq_len = train_conf.seq_len
assert val_seq_len <= model_conf.max_position_embeddings, (
    f"val seq_len ({val_seq_len}) exceeds max_position_embeddings "
    f"({model_conf.max_position_embeddings})"
)

val_dl = load_shard_as_dataloader_mp(
    './../../data/val_shard.json',
    tokenizer,
    batch_size = 32,
    seq_len = val_seq_len,
    eos_seperator_id = tokenizer.eos_token_id  # (sic) parameter name as defined by the helper
)

In [32]:
# Log in to wandb with the API key from the secrets file and start a run. The run config
# records the three config dataclasses plus the seed, so the run can be reproduced from
# its logged config alone.
import warnings

load_dotenv('./../../secrets.env')
wandb.login(key = os.getenv('WANDB_API_KEY'))

run_config = {}
for conf_dict in (asdict(model_conf), asdict(train_conf), asdict(or_conf)):
    # Merging flat dicts silently overwrites shared field names; make that visible.
    clashes = run_config.keys() & conf_dict.keys()
    if clashes:
        warnings.warn(f"wandb config keys {sorted(clashes)} appear in more than one config; later values win")
    run_config.update(conf_dict)
run_config['seed'] = seed  # the seed used for torch.manual_seed and passed to train()

run = wandb.init(
    project = 'interpretable-moes',
    name = RUN_NAME,
    notes = RUN_NOTES,
    config = run_config
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33myuanbo096[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin




In [37]:
# Reload train.py so edits are picked up without restarting the kernel.
# Import it under a distinct name: a bare `import train` would rebind the global `train`
# (the function imported in the first cell) to the module.
import importlib
import train as train_module
importlib.reload(train_module)
train = train_module.train  # keep `train` bound to the freshly reloaded function

try:
    train(model, tokenizer, train_conf, model_conf, or_conf, val_dl, seed, save_dir = save_dir)
except BaseException:
    # Training can crash (e.g. a CUDA error) or be interrupted. Close the wandb run anyway
    # and mark it failed instead of leaving it open, then re-raise.
    wandb.finish(exit_code = 1)
    raise
wandb.finish()

Found 1946 shards.

=== Loading shard ./../../data/train_shard_0.json (index 0) ===


InternalTorchDynamoError: RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.



You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
