In [1]:
## Auto reload
%load_ext autoreload
%autoreload 2

from tqdm.auto import tqdm
import pickle
import torch
        
import os

# Allow reduced-precision float32 matmuls.
torch.set_float32_matmul_precision('medium')

# Pick the accelerator and compute dtype from what the machine actually offers.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")
dtype = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else torch.float32
print(f"Using dtype: {dtype}")

# The OpenMP variable is OMP_NUM_THREADS (plural); the previous key
# 'OMP_NUM_THREAD' is not a recognised name, so the setting was ignored.
# NOTE(review): this runs after `import torch`. OpenMP runtimes usually read the
# variable when they are loaded, so move this assignment above the torch import
# if it must also govern the current process — confirm on the target machine.
os.environ['OMP_NUM_THREADS'] = '16'

Using device: cuda
Using dtype: torch.bfloat16


## Helper dataset

In [2]:
from utils.train_utils import create_model, create_dataset, train_loop

# `device` and `dtype` are intentionally NOT reassigned here: the setup cell at
# the top of the notebook already derives them from the available hardware.
# Hard-coding "cuda" / bfloat16 in this cell overrode that detection and broke
# the notebook on machines without a CUDA device.

# Model and dataset identifiers
model_id = 'HuggingFaceTB/SmolLM2-135M-Instruct'
dataset_id = 'roneneldan/TinyStories'

# Training hyper-parameters
epochs = 1

num_samples = 1_000_000

bs_factor = 1  # not referenced by any other visible cell

batch_size = 32
max_length = 128

train_steps = num_samples // batch_size
# train_steps = 256  # uncomment for a short run
val_steps = 256

# Sample counts are whole multiples of the batch size.
num_train_samples = batch_size * train_steps
num_test_samples = batch_size * val_steps

# Load the tokenizer plus the embedding, LM head and final norm modules that
# every model built below receives.
tokenizer, embed_tokens, lm_head, norm, vocab_size, hidden_size = create_model(model_id)

# Load the raw train / test samples.
raw_train_set, raw_test_set = create_dataset(
    dataset_id,
    split="train",
    field="text",
    num_train_samples=num_train_samples,
    num_test_samples=num_test_samples,
)



In [3]:
# Location of the tokenized-dataset cache.
# NOTE(review): absolute, user-specific path — point it at a directory that
# exists on the machine running the notebook.
tokenized_cache_path = '/home/golympie/tokenized_dataset.pickle'


def batch_tokenize(tokenizer, texts, batch_size=256, max_length=512):
    """Tokenize `texts` chunk by chunk and return all `input_ids` as one tensor.

    Args:
        tokenizer: callable invoked as
            ``tokenizer(batch, padding='max_length', truncation=True,
            max_length=..., return_tensors='pt')`` whose result exposes
            ``['input_ids']`` (presumably a Hugging Face tokenizer — confirm).
        texts: sliceable sequence of strings.
        batch_size: number of texts handed to the tokenizer per call.
        max_length: padding / truncation length forwarded to the tokenizer.

    Returns:
        The per-chunk `input_ids` tensors concatenated along dimension 0.
    """
    chunks = []
    for start in tqdm(range(0, len(texts), batch_size)):
        batch = texts[start:start + batch_size]
        encoded = tokenizer(batch, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
        chunks.append(encoded['input_ids'])
    return torch.cat(chunks, dim=0)


if os.path.exists(tokenized_cache_path):
    # SECURITY: pickle.load can execute arbitrary code; only load a cache file
    # that this notebook produced itself.
    with open(tokenized_cache_path, 'rb') as f:
        train_set, test_set = pickle.load(f)
else:
    # Cache miss: rebuild it with the recipe that was previously commented out.
    # Texts are cut to `max_length * 5` characters before tokenization
    # (presumably to bound tokenizer work on long stories — confirm the factor).
    train_texts = [elt[:max_length * 5] for elt in tqdm(raw_train_set)]
    test_texts = [elt[:max_length * 5] for elt in tqdm(raw_test_set)]

    train_set = batch_tokenize(tokenizer, train_texts, batch_size=256, max_length=max_length)
    test_set = batch_tokenize(tokenizer, test_texts, batch_size=64, max_length=max_length)

    with open(tokenized_cache_path, 'wb') as f:
        pickle.dump((train_set, test_set), f)


# Move both splits to the configured device instead of a hard-coded 'cuda'.
train_set = train_set.to(device)
test_set = test_set.to(device)

In [4]:
## Train helper bound to the notebook-wide configuration
def train(module, run_name, do_compile=False):
    """Run `train_loop` for one model using the notebook's shared settings.

    Args:
        module: the model handed to `train_loop` as its first argument.
        run_name: identifier of the run, forwarded unchanged.
        do_compile: forwarded unchanged as the third positional argument.

    Returns:
        Whatever `train_loop` returns.
    """
    # Arguments that differ from one experiment to the next.
    per_run = (module, run_name, do_compile)
    # Notebook globals, looked up at call time so later re-assignments are seen.
    shared = (tokenizer, device, dtype, train_set, test_set, epochs, batch_size)
    return train_loop(*per_run, *shared)


## Import modules

In [5]:
from modules.archi_modules import StackedMixinForCausalLM, count_parameters
from modules.positionnal_modules import NaivePositionnalEmbedding

from modules.mixin_modules import (
    RNNMixin,
    LSTMMixin,
    MultiScaleRetentionMixin,
    Mamba2Mixin,
    RWKV6Mixin,
    GroupedQuerySelfAttentionMixin,
    MultiHeadLatentAttentionMixin,
)

from modules.ffn_modules import FFN, SparseMoeFFN

## STACK 8 - MLP

In [6]:
# Depth and feed-forward block used by the model cells that follow.
num_layers = 8
# Second argument is four times the hidden size.
# NOTE(review): this single FFN instance is passed to every model below; whether
# StackedMixinForCausalLM copies it per layer / per model is not visible here.
ffn_module = FFN(hidden_size, 4 * hidden_size)

### GQA

In [None]:
%%time

gqsa = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=True,
    vocab_size=vocab_size,
    mixin_module=GroupedQuerySelfAttentionMixin(hidden_size, num_attention_heads=9, num_key_value_heads=9),
    ffn_module=ffn_module,
    positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(gqsa)
train(gqsa,run_name='gqsa', do_compile=True)

### Retentive Network

In [None]:
%%time
retnet = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=True,
    vocab_size=vocab_size,
    mixin_module=MultiScaleRetentionMixin(hidden_size, num_attention_heads=9, num_key_value_heads=3),
    ffn_module=ffn_module,
    positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(retnet)
train(retnet,run_name='retnet', do_compile=True)

### Mamba

In [None]:
%%time

mamba = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=True,
    vocab_size=vocab_size,
    mixin_module=Mamba2Mixin(hidden_size = hidden_size, num_attention_heads=6),
    ffn_module=ffn_module,
    positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(mamba)
train(mamba,run_name='mamba', do_compile=True)

### RWKV

In [None]:
%%time

rwkv = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=True,
    vocab_size=vocab_size,
    mixin_module=RWKV6Mixin(hidden_size = hidden_size, num_attention_heads=9),
    ffn_module=ffn_module,
    positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(rwkv)
train(rwkv,run_name='rwkv', do_compile=True)

## STACK 6 - MoE

In [None]:
# Depth and sparse mixture-of-experts feed-forward block for the cells that follow.
# This rebinds `num_layers` and `ffn_module`, so re-running an earlier model cell
# after this one picks up these values instead of the earlier ones.
num_layers = 6
ffn_module = SparseMoeFFN(
    hidden_size,
    4 * hidden_size,  # second argument is four times the hidden size
    num_experts=8,
    num_experts_per_tok=2,
    norm_topk_prob=True,
)

### GQA MOE

In [None]:
%%time

gqsa_moe = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=GroupedQuerySelfAttentionMixin(hidden_size, num_attention_heads=9, num_key_value_heads=9),
    ffn_module=ffn_module,
    positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(gqsa_moe)
train(gqsa_moe,run_name='gqsa-moe', do_compile=True)

### Retentive Network MOE

In [None]:
%%time
retnet_moe = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=MultiScaleRetentionMixin(hidden_size, num_attention_heads=9, num_key_value_heads=3),
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(retnet_moe)
train(retnet_moe,run_name='retnet-moe', do_compile=True)

### Mamba MOE

In [None]:
%%time

mamba_moe = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=Mamba2Mixin(hidden_size = hidden_size, num_attention_heads=6),
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(mamba_moe)
train(mamba_moe,run_name='mamba-moe', do_compile=True)

### RWKV MOE

In [None]:
%%time

rwkv_moe = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=RWKV6Mixin(hidden_size = hidden_size, num_attention_heads=9),
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(rwkv_moe)
train(rwkv_moe,run_name='rwkv-moe', do_compile=True)