In [1]:
## Auto reload: mode 2 re-imports edited modules before each cell executes, so
## changes to the project's utils/ and modules/ packages are picked up without
## restarting the kernel.
%load_ext autoreload
%autoreload 2

from tqdm.auto import tqdm
import torch
        
# Select the 'high' float32 matmul precision mode for the whole session.
torch.set_float32_matmul_precision('high')

# Use the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# bfloat16 only on a CUDA device that reports support for it; float32 otherwise.
# The bf16 query is skipped entirely when no CUDA device was selected.
if device == 'cuda' and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
else:
    dtype = torch.float32
print(f"Using dtype: {dtype}")

Using device: cuda
Using dtype: torch.bfloat16


## Model, dataset and training helper

In [2]:
from utils.train_utils import create_model, create_dataset, train_loop

# `device` and `dtype` come from the auto-detection in the first cell. They are
# deliberately NOT overridden here: forcing "cuda" / bfloat16 breaks the notebook
# on machines without a GPU or without bf16 support.

# model and dataset
model_id = 'HuggingFaceTB/SmolLM2-135M-Instruct'
dataset_id = 'roneneldan/TinyStories'

# training hyper-parameters
epochs = 1

# Scales the batch size up and the step count down by the same factor, so the
# total number of training samples is unchanged as long as bs_factor divides 32768.
bs_factor = 2

batch_size = 32 * bs_factor
# Sequence length forwarded to train_loop and to the positional embeddings.
max_length = 128

train_steps = 32768 // bs_factor
# train_steps = 256  # much shorter alternative run
val_steps = 256

# Number of samples requested from the dataset: one per batch slot per step.
num_train_samples = batch_size * train_steps
num_test_samples = batch_size * val_steps

# Components returned by create_model for `model_id`. The tokenizer, embedding,
# LM head and norm are forwarded to train_loop later; hidden_size sizes every
# module built in this notebook.
(
    tokenizer,
    embed_tokens,
    lm_head,
    norm,
    vocab_size,
    hidden_size,
) = create_model(model_id)

# Raw samples taken from the "text" field of the "train" split.
# NOTE(review): presumably both returned sets are carved out of that one split —
# confirm in utils.train_utils.create_dataset.
raw_train_set, raw_test_set = create_dataset(
    dataset_id,
    split="train",
    field="text",
    num_train_samples=num_train_samples,
    num_test_samples=num_test_samples,
)



In [3]:
import os
import pickle
from pathlib import Path

# Location of the cached (train_set, test_set) pair. Set TOKENIZED_DATASET_PATH
# to relocate it; the default is the path this notebook originally hard-coded.
# NOTE: the cache is not keyed on max_length / sample counts — delete the file
# after changing those settings.
tokenized_cache_path = Path(
    os.environ.get('TOKENIZED_DATASET_PATH', '/home/golympie/tokenized_dataset.pickle')
)


def batch_tokenize(tokenizer, texts, batch_size=256, max_length=512):
    """Tokenize `texts` chunk by chunk and return the stacked `input_ids`.

    Args:
        tokenizer: callable tokenizer; invoked with padding='max_length',
            truncation=True and return_tensors='pt'.
        texts: sequence of strings to tokenize.
        batch_size: number of texts passed to the tokenizer per call.
        max_length: padding / truncation length given to the tokenizer.

    Returns:
        The per-chunk `input_ids` tensors concatenated along dim 0.
    """
    chunks = []
    for start in tqdm(range(0, len(texts), batch_size)):
        encoded = tokenizer(
            texts[start:start + batch_size],
            padding='max_length',
            truncation=True,
            max_length=max_length,
            return_tensors='pt',
        )
        chunks.append(encoded['input_ids'])
    return torch.cat(chunks, dim=0)


if tokenized_cache_path.exists():
    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only load a cache that you produced yourself.
    with open(tokenized_cache_path, 'rb') as f:
        train_set, test_set = pickle.load(f)
else:
    # Character-level pre-truncation before tokenizing (5 characters per target
    # token) — presumably to bound tokenizer work; the tokenizer still truncates
    # to max_length tokens.
    truncated_train = [elt[:max_length * 5] for elt in tqdm(raw_train_set)]
    truncated_test = [elt[:max_length * 5] for elt in tqdm(raw_test_set)]

    train_set = batch_tokenize(tokenizer, truncated_train, batch_size=256, max_length=max_length)
    test_set = batch_tokenize(tokenizer, truncated_test, batch_size=64, max_length=max_length)

    # Saving is best-effort: a failed write must not discard the tokenized data
    # that is already in memory.
    try:
        tokenized_cache_path.parent.mkdir(parents=True, exist_ok=True)
        with open(tokenized_cache_path, 'wb') as f:
            pickle.dump((train_set, test_set), f)
    except OSError as err:
        print(f"Could not write tokenized dataset cache to {tokenized_cache_path}: {err}")

In [4]:
# Size of the first dimension of the tokenized training set. It should equal
# num_train_samples (batch_size * train_steps) when the cached file matches the
# current configuration — verify after changing hyper-parameters.
len(train_set)

1048576

In [5]:
## Partial train function
def train(module, run_name, do_compile=False):
    """Run `train_loop` on `module` with the notebook-wide configuration.

    Only the model, the run name and the compile flag change between
    experiments. Every other argument is read from the notebook's global
    variables at call time, so re-running an earlier cell changes later calls.

    Args:
        module: the model to train.
        run_name: label for this run, forwarded to `train_loop`.
        do_compile: forwarded to `train_loop`; defaults to False.

    Returns:
        Whatever `train_loop` returns.
    """
    # Everything is passed positionally: the order below must match
    # train_loop's signature.
    per_run_args = (module, run_name, do_compile)
    shared_args = (
        tokenizer,
        device,
        dtype,
        train_set,
        test_set,
        epochs,
        batch_size,
        max_length,
        embed_tokens,
        lm_head,
        norm,
    )
    return train_loop(*per_run_args, *shared_args)


## Import modules

In [6]:
from modules.archi_modules import StackedMixinBlock, count_parameters
from modules.positionnal_modules import NaivePositionnalEmbedding

from modules.mixin_modules import (
    RNNMixin,
    LSTMMixin,
    MultiScaleRetentionMixin,
    Mamba2Mixin,
    RWKV6Mixin,
    GroupedQuerySelfAttentionMixin,
    MultiHeadLatentAttentionMixin,
)

from modules.ffn_modules import FFN, SparseMoeFFN

## STACK 4 - MLP

In [7]:
# Depth used by every stack in this section, and the dense feed-forward module
# (second argument: 4 * hidden_size) handed to each StackedMixinBlock below.
# NOTE(review): this single FFN instance is passed to every model in the
# section — confirm StackedMixinBlock copies it rather than sharing its weights
# between runs.
num_layers = 6
ffn_module = FFN(hidden_size, hidden_size * 4)

### LSTM

In [None]:
%%time

lstm = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=LSTMMixin(hidden_size),
    ffn_module=ffn_module,
)

count_parameters(lstm)
train(lstm, run_name='lstm', do_compile=False)

### GQA

In [None]:
%%time

gqsa = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=GroupedQuerySelfAttentionMixin(hidden_size, num_attention_heads=9, num_key_value_heads=3),
    ffn_module=ffn_module,
    positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(gqsa)
train(gqsa,run_name='gqsa', do_compile=False)

### MHLA

In [11]:
%%time

mhla = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=MultiHeadLatentAttentionMixin(hidden_size, num_attention_heads=9),
    ffn_module=ffn_module,
    positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(mhla)
train(mhla,run_name='mhla', do_compile=False)

Total parameters: 30,053,664
Mixin parameters: 6,088,608
FFN parameters: 23,891,328
Using 16 workers for DataLoader.


Epochs:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/16384 [00:00<?, ?it/s]

CPU times: user 40.5 s, sys: 3.36 s, total: 43.9 s
Wall time: 45 s


KeyboardInterrupt: 

### Retentive Network

In [None]:
%%time
retnet = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=MultiScaleRetentionMixin(hidden_size, num_attention_heads=9, num_key_value_heads=3),
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(retnet)
train(retnet,run_name='retnet', do_compile=False)

### Mamba

In [None]:
%%time

mamba = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=Mamba2Mixin(hidden_size = hidden_size, num_attention_heads=3),
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(mamba)
train(mamba,run_name='mamba', do_compile=False)

### RWKV

In [None]:
%%time

rwkv = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=RWKV6Mixin(hidden_size = hidden_size, num_attention_heads=3),
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(rwkv)
train(rwkv,run_name='rwkv', do_compile=False)

## STACK 8 - MoE

In [None]:
# From this cell on, `ffn_module` is rebound to a sparse mixture-of-experts FFN
# (4 experts, 1 expert per token, normalised top-k probabilities). Re-running a
# dense-section cell after this one would therefore pick up the MoE module.
num_layers = 6
moe_options = dict(
    num_experts=4,
    num_experts_per_tok=1,
    norm_topk_prob=True,
)
ffn_module = SparseMoeFFN(hidden_size, hidden_size * 4, **moe_options)

### LSTM MOE

In [None]:
%%time

lstm_moe = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=LSTMMixin(hidden_size),
    ffn_module=ffn_module,
)

count_parameters(lstm_moe)
train(lstm_moe, run_name='lstm-moe', do_compile=False)

### GQA MOE

In [None]:
%%time

gqsa_moe = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=GroupedQuerySelfAttentionMixin(hidden_size, num_attention_heads=9, num_key_value_heads=3),
    ffn_module=ffn_module,
    positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(gqsa_moe)
train(gqsa_moe,run_name='gqsa-moe', do_compile=False)

### MHLA MOE

In [None]:
%%time

mhla_moe = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=MultiHeadLatentAttentionMixin(hidden_size, num_attention_heads=9),
    ffn_module=ffn_module,
    positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(mhla_moe)
train(mhla_moe,run_name='mhla-moe', do_compile=False)

### Retentive Network MOE

In [None]:
%%time
retnet_moe = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=MultiScaleRetentionMixin(hidden_size, num_attention_heads=9, num_key_value_heads=3),
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(retnet_moe)
train(retnet_moe,run_name='retnet-moe', do_compile=False)

### Mamba MOE

In [None]:
%%time

mamba_moe = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=Mamba2Mixin(hidden_size = hidden_size, num_attention_heads=3),
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(mamba_moe)
train(mamba_moe,run_name='mamba-moe', do_compile=False)

### RWKV MOE

In [None]:
%%time

rwkv_moe = StackedMixinBlock(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    mixin_module=RWKV6Mixin(hidden_size = hidden_size, num_attention_heads=3),
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

count_parameters(rwkv_moe)
train(rwkv_moe,run_name='rwkv-moe', do_compile=False)