In [1]:
import html
import json
import os
from dataclasses import  asdict
from datetime import datetime

import torch
import wandb
from dotenv import load_dotenv
from transformers import AutoTokenizer

from config import ModelConf, TrainConf
from helpers.dataset import load_shard_as_dataloader
from helpers.memory import check_memory, profile_memory
from helpers.moe_utils import check_cosine_similarity
from moe import OlmoeModel
from train import train


# Baseline GPU memory report before the model is built: the recorded output lists
# allocated / reserved / total memory for each visible CUDA device.
check_memory()

Device 0: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 1: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 2: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB

Device 3: NVIDIA H200
  Allocated: 0.00 GB
  Reserved: 0.00 GB
  Total: 139.83 GB



In [None]:
# Run configuration: model hyper-parameters, training defaults and the RNG seed.
# NOTE(review): the output recorded for this cell is a TypeError saying that
# ModelConf.__init__() does not accept 'main_device', while the model-construction
# cell reads model_conf.main_device -- confirm that config.ModelConf declares this field.
seed = 1234
train_conf = TrainConf()  # all training settings left at their defaults

model_conf = ModelConf(
    # Sizes (field meanings are defined in config.ModelConf)
    D=768,
    H=8,
    I=512,
    n_layers=10,
    max_position_embeddings=2048,
    # Expert / routing settings
    n_experts=30,
    n_shared_experts=2,
    top_k=4,
    norm_topk_prob=False,
    gate_orthogonal=False,
    # Device placement
    main_device='cuda:0',
)

TypeError: ModelConf.__init__() got an unexpected keyword argument 'main_device'

In [1]:
""" 
Let's load the model
- Set the default_device to specify where all the non-expert layers live (the experts are moved on model init)
- Set the default_dtype to specify the model dtype, all params will be in this dtype except for this explicitly specified differently in class definition
  - In the default OlMoE, RMSNorm is required to be f32 whereas all other params are bf16. 
"""
# torch.set_default_device(conf.main_device) # This is buggy, don't use
torch.set_default_dtype(torch.bfloat16)
torch.set_float32_matmul_precision('medium') # See https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html 
torch.manual_seed(seed)

model = OlmoeModel(
    model_conf,
    primary_device = model_conf.main_device, # Where to store dense layers and shared experts
    expert_device_map = ['cuda:0'] * model_conf.n_experts #=, here let's test them with all of them on cuda:0
)
model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained('allenai/OLMoE-1B-7B-0924', add_eos_token = False, add_bos_token = False)
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
check_memory()

NameError: name 'torch' is not defined

In [None]:
"""
Setup a Wandb run for logging. Choose a run name and notes for the run!
"""
RUN_NAME = 'test-01 -single-gpu -experts-32 -topk-4 -forward-slow'
RUN_NOTES = 'Baseline test with routing orthogonal initialization and no gate update'

load_dotenv('./../../secrets.env')
wandb.login(key = os.getenv('WANDB_API_KEY'))
run = wandb.init(
    project = 'interpretable-moes', 
    name = RUN_NAME,
    notes = RUN_NOTES,
    config = {**asdict(model_conf), **asdict(train_conf)}
)

# (Optional) Also log various info as a wandb media object.
additional_log_notes = {
    'run_name': RUN_NAME,
    'notes': RUN_NOTES,
    'created_at': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
    'total_model_params': sum(p.numel() for p in model.parameters()),
    'available_cuda_gpus': [torch.cuda.get_device_properties(i).name for i in range(torch.cuda.device_count())],
    'model_conf': asdict(model_conf),
    'train_conf': asdict(train_conf)
}

wandb.log({'conf': wandb.Html(f"<pre style='font-size:12px;'>{json.dumps(additional_log_notes, indent = 2)}</pre>")})



In [5]:


# Build the validation dataloader from the validation shard file.
# The tokenizer's EOS id is handed to the helper as the separator id
# (the 'eos_seperator_id' spelling is the helper's own keyword name).
val_loader_options = dict(
    batch_size = 32,
    seq_len = 2048,
    eos_seperator_id = tokenizer.eos_token_id,
)
val_dl = load_shard_as_dataloader('./../../data/val_shard.json', tokenizer, **val_loader_options)

In [None]:
# Run training and always close the W&B run afterwards.
# Without the guard, an exception or a kernel interrupt inside train() would skip
# wandb.finish() and leave the run open.
try:
    train(model, tokenizer, train_conf, model_conf, val_dl, seed)
except BaseException:
    # BaseException also covers KeyboardInterrupt (interrupting the kernel mid-training).
    # A non-zero exit code marks the run as failed; the error is then re-raised unchanged.
    wandb.finish(exit_code = 1)
    raise
else:
    wandb.finish()

In [7]:
# Print, layer by layer, the result of the cosine-similarity check on the MoE gate weights.
for block in model.layers:
    gate_weight = block.moe.gate.weight
    similarity = check_cosine_similarity(gate_weight)
    print(similarity)

tensor([[ 1.0040e+00, -1.9381e-02,  7.0575e-02, -1.0232e-02, -1.3055e-02,
          1.3327e-02, -7.4502e-04,  6.5123e-02,  4.5616e-02,  6.4997e-02,
          1.3804e-02, -1.4292e-02,  2.1619e-02,  9.0053e-02,  9.5576e-02,
          7.1145e-02,  4.4628e-02,  1.5899e-02, -2.4944e-03,  8.0756e-03,
          8.8617e-02, -1.3800e-02,  5.9052e-02,  2.0120e-02,  5.0409e-02,
          1.4031e-02, -1.1387e-02,  6.8987e-03,  3.9767e-02, -4.6113e-02],
        [-1.9381e-02,  1.0006e+00,  3.7543e-02,  5.6117e-02, -1.7157e-02,
         -1.8685e-02, -8.1010e-02, -1.1259e-02,  3.0969e-02,  2.0211e-02,
          6.0370e-04, -4.9318e-03,  2.6778e-02,  4.9485e-03,  2.7560e-02,
          8.0993e-02,  2.9810e-02,  4.5340e-02,  7.5180e-02,  3.9464e-02,
         -8.2291e-03,  2.4975e-02,  2.4285e-02, -2.2968e-02,  3.6457e-02,
          1.6802e-03, -2.2392e-02,  2.6965e-02,  7.1705e-03, -2.3014e-02],
        [ 7.0575e-02,  3.7543e-02,  9.9427e-01,  3.3195e-02,  5.8373e-02,
          6.6809e-02,  1.3236e-03, -