# Training BALM-MoE

## Import Packages

In [1]:
from dataclasses import asdict, dataclass, field, fields
from enum import Enum

class StrEnum(str, Enum):
    """Enum whose members are also ``str`` instances.

    ``str(member)`` yields the member's value instead of the default
    ``ClassName.MEMBER`` text produced by ``Enum``.
    """

    def __str__(self) -> str:
        # The value of a member is the plain string it was declared with.
        text = self.value
        return text

from typing import Optional, Tuple, List, Dict, Any, Iterable, Union

from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import Optimizer
from torch.utils.data import DataLoader
from transformers import TrainingArguments, Trainer, DataCollatorForLanguageModeling, RobertaTokenizer

In [2]:
from balm.config import BalmConfig, BalmMoEConfig
from balm.data import load_dataset, DataCollator
from balm.models import (
    BalmForMaskedLM,
    BalmModel,
    BalmMoEForMaskedLM,
)
from balm.tokenizer import Tokenizer
from balm.train import Trainer

## Load Tokenizer

In [3]:
# Build the tokenizer from the vocabulary file at ./balm/vocab.json.
tokenizer = Tokenizer(vocab="./balm/vocab.json")

## Load and Clean Training Data

In [4]:
def remove_sep(txt):
    """Return ``txt`` with every ``</s>`` replaced by ``<cls><cls>``.

    Args:
        txt: The input string (one raw record of the text dataset).

    Returns:
        A new string in which each occurrence of ``</s>`` has been swapped
        for the two-token sequence ``<cls><cls>``. Strings without ``</s>``
        come back unchanged.
    """
    # Splitting on the separator and re-joining with the replacement is
    # equivalent to a global, non-overlapping substitution.
    pieces = txt.split("</s>")
    return "<cls><cls>".join(pieces)


# Text files backing the "train" and "eval" splits.
data_files = dict(
    train="/training-data/jaffe_lc-coherence/paired/LC-coherence_90-5-5/train.txt",
    eval="/training-data/jaffe_lc-coherence/paired/LC-coherence_90-5-5/eval.txt",
)

# Load both splits as plain text; remove_sep is handed to the loader as its
# preprocessing function.
dataset = load_dataset(
    "text",
    data_files=data_files,
    preprocess_fn=remove_sep,
)

In [5]:
# Show the loaded dataset (splits, row counts, columns) as the cell output.
dataset

DatasetDict
-----------
  train
    num_rows: 1202270
    columns: ['text']
  eval
    num_rows: 66792
    columns: ['text']

In [1]:
# Identifier for this training run; reused further down to build the filename
# of the saved model weights.
run_name = "balmMoE_expertchoice_1shared_altern_052924"

## Tokenize Dataset

In [6]:
# Run the tokenizer over the "text" field of each record (padding and
# truncation enabled, max_length=320) and ask map() to remove the "text" column.
tokenized_dataset = dataset.map(
    lambda record: tokenizer(record["text"], padding=True, truncation=True, max_length=320),
    remove_columns="text",
)

Encoding:   0%|          | 0/1202270 [00:00<?, ?it/s]

Encoding:   0%|          | 0/66792 [00:00<?, ?it/s]

## Load Collator and Data

In [7]:
# Batch collator configured for masked-LM training (mlm=True, mlm_probability=0.15).
collator = DataCollator(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)

In [8]:
# Training batches: 32 examples each, reshuffled at every epoch.
train_dataloader = DataLoader(
    tokenized_dataset["train"],
    batch_size=32,
    shuffle=True,
)

# Evaluation batches: same batch size, but served in dataset order. Shuffling
# adds nothing for evaluation and makes per-batch results non-repeatable
# between runs, so it is turned off here.
eval_dataloader = DataLoader(
    tokenized_dataset["eval"],
    batch_size=32,
    shuffle=False,
)

## Model

In [9]:
# Hyperparameters for the BALM-MoE model, grouped by concern.
config = BalmMoEConfig(
    # vocabulary and model dimensions
    vocab_size=tokenizer.vocab_size,
    embed_dim=960,
    ffn_dim=3840,
    num_layers=6,
    num_heads=20,
    # expert layout
    num_experts=16,
    num_shared_experts=1,
    expert_capacity=128,
    alternate_sparsity=True,
    # router selection and router loss coefficients
    expert_choice_router=True,
    router_z_loss_coef=0.01,
    router_aux_loss_coef=0.01,
)
    

In [10]:
# Instantiate the masked-LM MoE model from the configuration built above.
model = BalmMoEForMaskedLM(config=config)

In [11]:
# Total number of elements across all model parameters, reported in millions.
model_size = sum(
    weights.numel() for weights in model.parameters()
)
print(f"Model size: {model_size/1e6:.2f}M")

Model size: 305.22M


## Trainer

In [12]:
# Configure the training run.
#
# The output directory and the wandb run name are both derived from `run_name`
# so that the checkpoint directory, the wandb run and the saved weights file
# all carry the same identifier. (The wandb name used to be a separate
# hard-coded string, "balmMoE_expertchoiceBig_1shared_altern_052924", that had
# drifted from `run_name`; change `run_name` if that is the intended name.)
trainer = Trainer(
    model=model,
    data_collator=collator,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["eval"],
    output_dir=f"./training_runs/{run_name}",
    # schedule
    max_steps=500000,
    warmup_steps=30000,
    learning_rate=16e-4,
    per_device_train_batch_size=32,
    # logging / evaluation cadence (in steps)
    logging_steps=100,
    eval_steps=25000,
    # experiment tracking
    use_wandb=True,
    wandb_project="balm_moe",
    run_name=run_name,
)


In [None]:
# wandb is imported here because a later cell calls wandb.finish().
import wandb
# Manual login/init is left disabled; the Trainer was created with use_wandb=True.
#wandb.login()
#wandb.init(project = 'balm_moe', name='balmMoE_expertchoice_1shared_0altern_052924')
# Start training.
trainer.train()

[34m[1mwandb[0m: Currently logged in as: [33mbrineylab[0m. Use [1m`wandb login --relogin`[0m to force relogin


Training:   0%|          | 0/500000 [00:00<?, ?step/s]

  batch = {k: torch.tensor(v) for k, v in examples.items()}


step 100   | loss: 2.7781 | MLM loss: 2.7041 | router z-loss: 0.0739 | lr: 0.000005
step 200   | loss: 2.4280 | MLM loss: 2.3845 | router z-loss: 0.0435 | lr: 0.000011
step 300   | loss: 2.1152 | MLM loss: 2.0943 | router z-loss: 0.0208 | lr: 0.000016
step 400   | loss: 1.9178 | MLM loss: 1.9097 | router z-loss: 0.0080 | lr: 0.000021
step 500   | loss: 1.6520 | MLM loss: 1.6475 | router z-loss: 0.0044 | lr: 0.000027
step 600   | loss: 1.3631 | MLM loss: 1.3591 | router z-loss: 0.0040 | lr: 0.000032
step 700   | loss: 1.1150 | MLM loss: 1.1109 | router z-loss: 0.0041 | lr: 0.000037
step 800   | loss: 0.9257 | MLM loss: 0.9211 | router z-loss: 0.0046 | lr: 0.000043
step 900   | loss: 0.7853 | MLM loss: 0.7810 | router z-loss: 0.0043 | lr: 0.000048
step 1000  | loss: 0.7495 | MLM loss: 0.7449 | router z-loss: 0.0046 | lr: 0.000053
step 1100  | loss: 0.6385 | MLM loss: 0.6345 | router z-loss: 0.0040 | lr: 0.000059
step 1200  | loss: 0.5610 | MLM loss: 0.5572 | router z-loss: 0.0039 | lr: 0

In [16]:
import os  # needed only by this cell, to create the output directory

# Take the model back from the trainer so the saved weights are the ones it holds.
model = trainer.model

# torch.save does not create missing parent directories; make sure ./models
# exists so a finished training run cannot fail at the final save.
os.makedirs("./models", exist_ok=True)

# Persist only the state dict; the filename carries the run name and the
# parameter count in millions.
torch.save(model.state_dict(), f'./models/{run_name}_{model_size/1e6:.2f}Mp.pth')

In [17]:
# Finish the active Weights & Biases run.
wandb.finish()