In [1]:
import torch
import random
import numpy as np

from datasets import load_dataset
from tokenizers import ByteLevelBPETokenizer
from tokenizers.processors import BertProcessing
from transformers import PretrainedConfig, Trainer, TrainingArguments

from modules import TransformerClassifier, MLP, MoE, MoDE


def seed_all(seed: int = 0) -> None:
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed to every seeding call. Defaults to 0.
    """
    # Seed the three library-level generators in turn.
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    # The explicit CUDA seeding call is only made when CUDA is available.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def get_gpu_device() -> torch.device:
    """Select the torch device to run on.

    Preference order is CUDA, then the MPS backend (Apple silicon), then CPU.
    A warning is printed when falling back to the CPU.

    Returns:
        The selected ``torch.device``.
    """
    # CUDA GPU takes priority.
    if torch.cuda.is_available():
        return torch.device("cuda")

    # The `mps` backend attribute is checked with hasattr before it is used.
    has_mps = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if has_mps:
        return torch.device("mps")

    # No accelerator found: warn and fall back to the CPU.
    print("Warning: No GPU found, using CPU instead.")
    return torch.device("cpu")

# Seed the RNGs with the default seed (0) before any model is constructed.
seed_all()

## Configs

In [2]:
# Hyper-parameters shared by every model variant below. `vocab_size` and
# `max_position_embeddings` are also read by the tokenizer-training cell.
base_config = {
    "vocab_size": 5000,
    "max_position_embeddings": 256,
    "num_attention_heads": 8,
    "num_hidden_layers": 4,
    "hidden_dropout_prob": 0.1,
    "hidden_size": 128,
    "intermediate_size": 512,
    "num_labels": 2,
    "bias": True,
}

# Baseline: MLP feed-forward class, no experts.
standard_config = PretrainedConfig(**base_config, ff_cls=MLP, mh_moe=False)

# MoE feed-forward class with 4 experts.
moe_config = PretrainedConfig(
    **base_config,
    num_experts=4, capacity_factor=2.0, num_experts_per_token=1,
    ff_cls=MoE, mh_moe=False,
)

# MoDE feed-forward class.
mode_config = PretrainedConfig(
    **base_config,
    num_experts=5,  # add 1 no-op expert
    capacity_factor=2.0, num_experts_per_token=1,
    ff_cls=MoDE, mh_moe=False,
)

# Same as `moe_config` but with `mh_moe=True` (used by the "Train MH-MoE" cell).
mh_moe_config = PretrainedConfig(
    **base_config,
    num_experts=4, capacity_factor=2.0, num_experts_per_token=1,
    ff_cls=MoE, mh_moe=True,
)

# Same as `mode_config` but with `mh_moe=True` (used by the "Train MH-MoDE" cell).
mh_mode_config = PretrainedConfig(
    **base_config,
    num_experts=5,  # add 1 no-op expert
    capacity_factor=2.0, num_experts_per_token=1,
    ff_cls=MoDE, mh_moe=True,
)

## Tokenizer Training

In [3]:
# Load the IMDB dataset.
dataset = load_dataset('imdb')

# Train a byte-level BPE tokenizer on the texts of the 'train' split, with the
# vocabulary size taken from the shared model config.
tokenizer = ByteLevelBPETokenizer()
tokenizer.train_from_iterator(
    dataset['train']['text'],
    vocab_size=base_config['vocab_size'],
    min_frequency=2,
    special_tokens=["<s>", "</s>", "<pad>"],
)

# BertProcessing takes the separator pair first and the classifier pair second.
tokenizer.post_processor = BertProcessing(
    ("</s>", tokenizer.token_to_id("</s>")),
    ("<s>", tokenizer.token_to_id("<s>")),
)

# Truncate and pad every encoding to the model's maximum sequence length.
tokenizer.enable_truncation(
    max_length=base_config['max_position_embeddings'],
)
tokenizer.enable_padding(
    length=base_config['max_position_embeddings'],
    pad_token="<pad>",
    pad_id=tokenizer.token_to_id("<pad>"),
)

# NOTE(review): these two attributes are plain assignments on the tokenizer
# object and are not read anywhere in the visible cells; presumably kept for
# code that expects them. Confirm before removing.
tokenizer.pad_token = "<pad>"
tokenizer.model_max_length = base_config['max_position_embeddings']


def tokenize(row):
    """Encode one dataset row's ``'text'`` field into token ids.

    Args:
        row: Mapping with a ``'text'`` entry.

    Returns:
        A dict with a single ``'input_ids'`` key holding the encoded ids.
        No attention mask is returned.
    """
    encoding = tokenizer.encode(row['text'])
    return dict(input_ids=encoding.ids)


# Add an 'input_ids' column by encoding each example of the loaded dataset.
tokenized_dataset = dataset.map(tokenize)






## Train Vanilla Transformer

In [5]:
# Train and evaluate the baseline transformer (`standard_config`) on the
# tokenized 'train' / 'test' splits.
training_args = TrainingArguments(
    output_dir='./results_vanilla',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.001,
    # Per-run subdirectory so this run's logs are not mixed with those of the
    # other training cells in this notebook.
    logging_dir='./logs/vanilla',
)

model = TransformerClassifier(standard_config)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
)

trainer.train()
# Last expression of the cell, so the evaluation metrics dict is displayed.
trainer.evaluate()

  0%|          | 0/9375 [00:00<?, ?it/s]

{'loss': 0.7554, 'grad_norm': 11.673501014709473, 'learning_rate': 5e-05, 'epoch': 0.16}
{'loss': 0.7101, 'grad_norm': 9.985132217407227, 'learning_rate': 4.71830985915493e-05, 'epoch': 0.32}
{'loss': 0.6871, 'grad_norm': 3.6002533435821533, 'learning_rate': 4.436619718309859e-05, 'epoch': 0.48}
{'loss': 0.6654, 'grad_norm': 4.413424015045166, 'learning_rate': 4.154929577464789e-05, 'epoch': 0.64}
{'loss': 0.6496, 'grad_norm': 7.207448482513428, 'learning_rate': 3.8732394366197184e-05, 'epoch': 0.8}
{'loss': 0.643, 'grad_norm': 5.3875555992126465, 'learning_rate': 3.5915492957746486e-05, 'epoch': 0.96}
{'loss': 0.6326, 'grad_norm': 10.878851890563965, 'learning_rate': 3.3098591549295775e-05, 'epoch': 1.12}
{'loss': 0.6077, 'grad_norm': 6.362067222595215, 'learning_rate': 3.028169014084507e-05, 'epoch': 1.28}
{'loss': 0.5738, 'grad_norm': 5.495521068572998, 'learning_rate': 2.746478873239437e-05, 'epoch': 1.44}
{'loss': 0.591, 'grad_norm': 2.6917409896850586, 'learning_rate': 2.46478873

  0%|          | 0/3125 [00:00<?, ?it/s]

{'eval_loss': 0.5624286532402039,
 'eval_runtime': 53.9324,
 'eval_samples_per_second': 463.543,
 'eval_steps_per_second': 57.943,
 'epoch': 3.0}

## Train MoE

In [None]:
# Train and evaluate the MoE variant (`moe_config`) on the tokenized
# 'train' / 'test' splits.
training_args = TrainingArguments(
    output_dir='./results_moe',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.001,
    # Per-run subdirectory so this run's logs are not mixed with those of the
    # other training cells in this notebook.
    logging_dir='./logs/moe',
)

model_moe = TransformerClassifier(moe_config)
trainer = Trainer(
    model=model_moe,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
)

trainer.train()
# Last expression of the cell, so the evaluation metrics dict is displayed.
trainer.evaluate()

## Train MoDE

In [None]:
# Train and evaluate the MoDE variant (`mode_config`) on the tokenized
# 'train' / 'test' splits.
training_args = TrainingArguments(
    output_dir='./results_mode',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.001,
    # Per-run subdirectory so this run's logs are not mixed with those of the
    # other training cells in this notebook.
    logging_dir='./logs/mode',
)

model_mode = TransformerClassifier(mode_config)
trainer = Trainer(
    model=model_mode,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
)

trainer.train()
# Last expression of the cell, so the evaluation metrics dict is displayed.
trainer.evaluate()

## Train MH-MoE

In [None]:
# Train and evaluate the MH-MoE variant (`mh_moe_config`) on the tokenized
# 'train' / 'test' splits.
training_args = TrainingArguments(
    output_dir='./results_mh-moe',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.001,
    # Per-run subdirectory so this run's logs are not mixed with those of the
    # other training cells in this notebook.
    logging_dir='./logs/mh-moe',
)

model_mh_moe = TransformerClassifier(mh_moe_config)
trainer = Trainer(
    model=model_mh_moe,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
)

trainer.train()
# Last expression of the cell, so the evaluation metrics dict is displayed.
trainer.evaluate()

## Train MH-MoDE

In [None]:
# Train and evaluate the MH-MoDE variant (`mh_mode_config`) on the tokenized
# 'train' / 'test' splits.
training_args = TrainingArguments(
    output_dir='./results_mh-mode',
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.001,
    # Per-run subdirectory so this run's logs are not mixed with those of the
    # other training cells in this notebook.
    logging_dir='./logs/mh-mode',
)

model_mh_mode = TransformerClassifier(mh_mode_config)
trainer = Trainer(
    model=model_mh_mode,
    args=training_args,
    train_dataset=tokenized_dataset['train'],
    eval_dataset=tokenized_dataset['test'],
)

trainer.train()
# Last expression of the cell, so the evaluation metrics dict is displayed.
trainer.evaluate()