In [1]:
## Auto reload
%load_ext autoreload
%autoreload 2

from datasets import load_dataset
from tqdm.auto import tqdm
import numpy as np
import thunder
import pickle
import torch
import os
        
# Use the 'medium' internal-precision setting for float32 matrix multiplications.
torch.set_float32_matmul_precision('medium')

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# bfloat16 only when CUDA is present and reports bf16 support; float32 otherwise.
dtype = torch.float32
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = torch.bfloat16
print(f"Using dtype: {dtype}")

Using device: cuda
Using dtype: torch.bfloat16


In [2]:
from utils.train_utils import create_model_tokenizer, batch_tokenize, destruct_module_optimized
from utils.lightning_utils import prepare_model_and_data

## Config

In [3]:
# Runtime placement and precision.
# NOTE(review): these hard-coded values replace whatever `device` / `dtype`
# were auto-detected in the first cell — confirm this override is intended.
device="cuda"
dtype=torch.bfloat16

# model and dataset
model_id = 'unsloth/mistral-7b-instruct-v0.3'
dataset_id = 'roneneldan/TinyStories'
hidden_size = 512
intermediate_size = 2048
max_position_embeddings = 512
text_field='text'

## Dataset pre processing
token_sample_estimate=10000  # number of training rows sampled to estimate token-length quantiles
quantile_treshold=0.95
override_quantile_max_length=512 ## If quantile_max_length > override_quantile_max_length, cap it


# train hp
epochs=2
batch_size=8
max_length=256 # Enough to cover about 80% of the dataset
learning_rate=2e-4
min_lr=1e-6
betas=(0.9, 0.999, 0.9999)
alpha=5
weight_decay=1e-4
# Warmup is currently disabled (an earlier value of 64 was immediately overwritten by 0).
warmup_steps=0


## Max length in tiny stories
# Quantile : 0.1 Length : 142
# Quantile : 0.25 Length : 164
# Quantile : 0.5 Length : 191
# Quantile : 0.75 Length : 227
# Quantile : 0.9 Length : 306
# Quantile : 0.95 Length : 423
# Quantile : 0.99 Length : 605

## Size the dataset, gradient_accumulation etc. based on token budget
token_budget = 1_000_000_000   # total number of tokens to train on
batch_token_budget=100_000     # target number of tokens per optimizer step

# Micro-batches accumulated per optimizer step. Floored at 1 so that a single
# micro-batch larger than `batch_token_budget` cannot yield 0 and a
# ZeroDivisionError in the step computation below.
gradient_accumulation_steps=max(1, batch_token_budget // (batch_size * max_length))
# Sequences consumed per optimizer step.
global_batch_size= gradient_accumulation_steps * batch_size

# Optimizer steps needed to spend the token budget (rounded up by the leading +1).
num_train_steps = 1 + token_budget // (max_length * global_batch_size)

# The step count spans all epochs, so each epoch only needs 1/epochs of the samples.
num_train_samples = num_train_steps * global_batch_size // epochs
num_test_samples = batch_size * 64

print('gradient_accumulation_steps', gradient_accumulation_steps)
print('num_train_steps', num_train_steps)
print('num_train_samples', num_train_samples)
print('num_test_samples', num_test_samples)

gradient_accumulation_steps 48
num_train_steps 10173
num_train_samples 1953216
num_test_samples 512


## Pre-tokenize and pickle the dataset (run once each time you change the tokenizer or the dataset)

In [4]:
# tokenizer, embed_tokens, lm_head, norm, vocab_size, hidden_size = create_model_tokenizer(model_id, load_model=False)
# dataset = load_dataset(dataset_id)

# sample = list(dataset['train'].select(range(token_sample_estimate))[text_field])

# tokenized = batch_tokenize(
#     tokenizer,
#     sample,
#     padding=None,
#     batch_size=256,
#     max_length=None
# )

# lens = [len(x) for x in tokenized]
# quantiles_proba = [0.1, 0.25, 0.5, 0.75, 0.9, 0.99, quantile_treshold]
# quantiles = list(map(lambda x:x+1, np.quantile(lens, q=quantiles_proba).tolist()))
# quantile_max_length=int(quantiles[-1]) + 1

# for q,l in zip(quantiles_proba, quantiles):
#     print(f"Quantile : {q} Length : {int(l)}")

# print(f'Calculated max_length for quantile {quantile_treshold} is {quantile_max_length}')

# if quantile_max_length > override_quantile_max_length:
#     quantile_max_length = override_quantile_max_length

# train_set = batch_tokenize(tokenizer, list(dataset['train'][text_field]), batch_size=512, max_length=max_length)
# val_set = batch_tokenize(tokenizer, list(dataset['validation'][text_field]), batch_size=64, max_length=max_length)

# with open('/home/golympie/tokenized_dataset.pickle', 'wb') as f:
#     pickle.dump((train_set, val_set), f)

## Load model and dataset

In [5]:
# Build the tokenizer plus the embedding, LM-head and final-norm modules
# (the helper is called with load_model=False).
# NOTE(review): this rebinds `hidden_size` from the config cell to whatever the helper returns.
(
    tokenizer,
    embed_tokens,
    lm_head,
    norm,
    vocab_size,
    hidden_size,
) = create_model_tokenizer(model_id, load_model=False)

# Read back the two pre-tokenized splits from the pickle file.
# NOTE(review): unpickling can execute arbitrary code — only load a file you produced yourself.
# NOTE(review): absolute, user-specific path; adjust it when running on another machine.
with open('/home/golympie/tokenized_dataset.pickle', 'rb') as pickle_file:
    train_set, test_set = pickle.load(pickle_file)

# Keep only the configured number of rows and the first `max_length` columns of each split.
train_set = train_set[:num_train_samples, :max_length]
test_set = test_set[:num_test_samples, :max_length]

# train_set = train_set.to('cuda')
# test_set = test_set.to('cuda')

In [6]:
## Partial train function
def train(
    model,
    run_name,
    do_compile=True,
):
    """Train `model` with the notebook-level configuration and save its weights.

    Everything except the three arguments is read from notebook globals
    (`train_set`, `test_set`, `epochs`, `batch_size`, `max_length`,
    `learning_rate`, `betas`, `alpha`, `weight_decay`, `warmup_steps`,
    `min_lr`, `gradient_accumulation_steps`).

    Args:
        model: the module to train.
        run_name: used as the logger name and as the file stem of the saved
            state dict (`states/{run_name}.pth`).
        do_compile: when True, the model is wrapped with `torch.compile`
            (dynamic shapes enabled) before training.

    Returns:
        None. Side effects: runs `trainer.fit`, writes `states/{run_name}.pth`,
        and hands the model to `destruct_module_optimized`.
    """
    # Keep a handle on the un-compiled module. `torch.compile` returns a wrapper
    # module around it, so a state dict taken from the wrapper has prefixed keys
    # and would not load back into a plain (un-compiled) model.
    raw_model = model

    if do_compile:
        model = torch.compile(
            model,
            # fullgraph=True,
            dynamic=True,
            # mode='max-autotune'
        )

    lightning_model, trainer, train_loader, val_loader = prepare_model_and_data(
        model,
        train_data=train_set,
        val_data=test_set,
        epochs=epochs,
        batch_size=batch_size,
        max_length=max_length,
        learning_rate=learning_rate,
        betas=betas,
        alpha=alpha,
        weight_decay=weight_decay,
        warmup_steps=warmup_steps,
        min_lr=min_lr,
        gradient_accumulation_steps=gradient_accumulation_steps,
        log_dir='runs',
        log_name=run_name,
        checkpoint_dir='checkpoints',
        checkpoint_every_n_steps=200,
    )
    trainer.fit(
        lightning_model,
        train_loader,
        val_loader
    )

    # Save the model state under the original module's parameter names, so the
    # file has the same key layout whether or not the run was compiled.
    os.makedirs("states", exist_ok=True)
    torch.save(raw_model.state_dict(), f"states/{run_name}.pth")

    # Tear the trained module down once.
    # NOTE(review): presumably this frees the model's memory before the next
    # run — confirm against utils.train_utils.
    destruct_module_optimized(model)


## Import modules

In [7]:
from modules.archi_modules import StackedMixinForCausalLM, count_parameters
from modules.positionnal_modules import NaivePositionnalEmbedding

from modules.mixin_modules import (
    RNNMixin,
    LSTMMixin,
    MultiScaleRetentionMixin,
    Mamba2Mixin,
    RWKV6Mixin,
    GroupedQuerySelfAttentionMixin,
    MultiHeadLatentAttentionMixin,
)

from modules.ffn_modules import FFN, SparseMoeFFN

## STACK 8 - MLP

In [8]:
# Depth used by every model built in this section.
# NOTE(review): the section heading says "STACK 8" while this is 4 — confirm which is intended.
num_layers = 4
# Feed-forward module passed as `ffn_module` to the GQA, RetNet, Mamba and RWKV models below.
ffn_module = FFN(hidden_size, intermediate_size)

### GQA

In [None]:
%%time

# Sequence mixer for this run: grouped-query self-attention
# (num_attention_heads=8, num_key_value_heads=4).
gqsa_mixin = GroupedQuerySelfAttentionMixin(
    hidden_size,
    num_attention_heads=8,
    num_key_value_heads=4,
)
gqsa_positions = NaivePositionnalEmbedding(
    hidden_size,
    max_position_embeddings=max_position_embeddings,
)

# Assemble the causal LM from the tokenizer-side modules, the mixer and the section's FFN.
gqsa = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=gqsa_mixin,
    ffn_module=ffn_module,
    positionnal_module=gqsa_positions,
)

# Report the parameter counts, then train with compilation enabled under run name 'gqsa'.
count_parameters(gqsa)
train(gqsa, run_name='gqsa', do_compile=True)

### Retentive Network

In [None]:
%%time
# Sequence mixer for this run: multi-scale retention
# (num_attention_heads=8, num_key_value_heads=1).
retnet_mixin = MultiScaleRetentionMixin(
    hidden_size,
    num_attention_heads=8,
    num_key_value_heads=1,
)
retnet_positions = NaivePositionnalEmbedding(
    hidden_size,
    max_position_embeddings=max_position_embeddings,
)

# Assemble the causal LM from the tokenizer-side modules, the mixer and the section's FFN.
retnet = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=retnet_mixin,
    ffn_module=ffn_module,
    positionnal_module=retnet_positions,
)

# Report the parameter counts, then train under run name 'retnet'.
# Unlike the neighbouring runs, this one is trained without torch.compile.
count_parameters(retnet)
train(retnet, run_name='retnet', do_compile=False)

### Mamba

In [None]:
%%time

# Sequence mixer for this run: Mamba2 (num_attention_heads=8).
mamba_mixin = Mamba2Mixin(hidden_size=hidden_size, num_attention_heads=8)

# Assemble the causal LM; no positional module is passed for this architecture.
mamba = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=mamba_mixin,
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_position_embeddings=max_position_embeddings)
)

# Report the parameter counts, then train with compilation enabled under run name 'mamba2'.
count_parameters(mamba)
train(mamba, run_name='mamba2', do_compile=True)

### RWKV

In [None]:
%%time

# Sequence mixer for this run: RWKV6 (num_attention_heads=8).
rwkv_mixin = RWKV6Mixin(hidden_size=hidden_size, num_attention_heads=8)
rwkv_positions = NaivePositionnalEmbedding(
    hidden_size,
    max_position_embeddings=max_position_embeddings,
)

# Assemble the causal LM from the tokenizer-side modules, the mixer and the section's FFN.
rwkv = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=rwkv_mixin,
    ffn_module=ffn_module,
    positionnal_module=rwkv_positions,
)

# Report the parameter counts, then train with compilation enabled under run name 'rwkv6'.
count_parameters(rwkv)
train(rwkv, run_name='rwkv6', do_compile=True)

## Private - This section only works if you have access to the private modules

In [None]:
%%time

# Imported here rather than at the top because the `private` package is optional.
from private.hyper_modules import RecurrentHyperLayer#, ParallelScanRecurrentLayer

# This variant has no separate sequence mixer (mixin_module=None); the
# RecurrentHyperLayer is passed in the FFN slot instead.
rhn_layer = RecurrentHyperLayer(
    hidden_size=hidden_size,
    hyper_features=32,
    intermediate_size=intermediate_size,
    dora_rank=8,
    ema_steps=0,
)
rhn_positions = NaivePositionnalEmbedding(
    hidden_size,
    max_position_embeddings=max_position_embeddings,
)

rhn = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=None,
    ffn_module=rhn_layer,
    positionnal_module=rhn_positions,
)

# Report the parameter counts, then train with compilation enabled under run name 'rhn'.
count_parameters(rhn)
train(rhn, run_name='rhn', do_compile=True)

Using bfloat16 Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/golympie/miniconda3/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:658: Checkpoint directory /mnt/c/Users/gabol/Desktop/ArchiFactory/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Total parameters: 63,540,480
Embedding parameters: 16,777,216
Mixin parameters: 0
FFN parameters: 29,690,624
LM Head parameters: 16,809,984


/home/golympie/miniconda3/lib/python3.11/site-packages/pytorch_lightning/core/optimizer.py:259: Found unsupported keys in the lr scheduler dict: {'start_epoch'}. HINT: remove them from the output of `configure_optimizers`.

  | Name  | Type            | Params | Mode 
--------------------------------------------------
0 | model | OptimizedModule | 63.5 M | train
--------------------------------------------------
63.5 M    Trainable params
0         Non-trainable params
63.5 M    Total params
254.162   Total estimated model params size (MB)
78        Modules in train mode
0         Modules in eval mode


Sanity Checking: |             | 0/? [00:00<?, ?it/s]

W0810 22:02:30.677000 79840 site-packages/torch/_dynamo/exc.py:514] [7/0_1] Backend compiler exception
W0810 22:02:30.677000 79840 site-packages/torch/_dynamo/exc.py:514] [7/0_1]   Explanation: Backend compiler `inductor` failed with aten._local_scalar_dense.default
W0810 22:02:30.677000 79840 site-packages/torch/_dynamo/exc.py:514] [7/0_1] 
W0810 22:02:30.677000 79840 site-packages/torch/_dynamo/exc.py:514] [7/0_1]     While executing %item : [num_users=5] = call_method[target=item](args = (%getitem,), kwargs = {})
W0810 22:02:30.677000 79840 site-packages/torch/_dynamo/exc.py:514] [7/0_1]     GraphModule: class GraphModule(torch.nn.Module):
W0810 22:02:30.677000 79840 site-packages/torch/_dynamo/exc.py:514] [7/0_1]         def forward(self, s3: "Sym(s3)", s4: "Sym(s4)", L_x_t_: "f32[s3, 512][s4, 1]", L_self_modules_fused_hyper_linear_parameters_weight_: "f32[64000, 32][32, 1]", L_self_modules_fused_hyper_linear_parameters_bias_: "f32[64000][1]", L_h_hyper_: "f32[s3, 32][32, 1]", L_se

Training: |                    | 0/? [00:00<?, ?it/s]

W0810 22:02:40.575000 79840 site-packages/torch/_dynamo/convert_frame.py:964] [8/8] torch._dynamo hit config.recompile_limit (8)
W0810 22:02:40.575000 79840 site-packages/torch/_dynamo/convert_frame.py:964] [8/8]    function: 'forward' (/mnt/c/Users/gabol/Desktop/ArchiFactory/private/hyper_modules.py:65)
W0810 22:02:40.575000 79840 site-packages/torch/_dynamo/convert_frame.py:964] [8/8]    last reason: 8/7: tensor 'x' dtype mismatch. expected Float, actual BFloat16
W0810 22:02:40.575000 79840 site-packages/torch/_dynamo/convert_frame.py:964] [8/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W0810 22:02:40.575000 79840 site-packages/torch/_dynamo/convert_frame.py:964] [8/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html.


## STACK 8 - MoE

In [None]:
# Depth used by every model built in the MoE section.
num_layers = 6

# Sparse mixture-of-experts feed-forward module passed as `ffn_module` to the MoE models below.
# The second positional argument is 4x the hidden size.
moe_inner_size = hidden_size * 4
ffn_module = SparseMoeFFN(
    hidden_size,
    moe_inner_size,
    num_experts=8,
    num_experts_per_tok=2,
    norm_topk_prob=True,
)

### GQA MOE

In [None]:
%%time

# Sequence mixer for this run: grouped-query self-attention
# (num_attention_heads=9, num_key_value_heads=9).
# NOTE(review): 9 heads does not evenly divide a hidden size of 512 (the value set
# in the config cell) — confirm the mixin supports this or adjust the head count.
gqsa_moe_mixin = GroupedQuerySelfAttentionMixin(
    hidden_size,
    num_attention_heads=9,
    num_key_value_heads=9,
)

# Use the `max_position_embeddings` keyword, as every other NaivePositionnalEmbedding
# call in this notebook does; this cell previously passed `max_length=`.
gqsa_moe_positions = NaivePositionnalEmbedding(
    hidden_size,
    max_position_embeddings=max_position_embeddings,
)

# Assemble the causal LM from the tokenizer-side modules, the mixer and the section's MoE FFN.
gqsa_moe = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=gqsa_moe_mixin,
    ffn_module=ffn_module,
    positionnal_module=gqsa_moe_positions,
)

# Report the parameter counts, then train with compilation enabled under run name 'gqsa-moe'.
count_parameters(gqsa_moe)
train(gqsa_moe, run_name='gqsa-moe', do_compile=True)

### Retentive Network MOE

In [None]:
%%time
# Sequence mixer for this run: multi-scale retention
# (num_attention_heads=9, num_key_value_heads=3).
# NOTE(review): 9 heads does not evenly divide a hidden size of 512 (the value set
# in the config cell) — confirm the mixin supports this.
retnet_moe_mixin = MultiScaleRetentionMixin(
    hidden_size,
    num_attention_heads=9,
    num_key_value_heads=3,
)

# Assemble the causal LM; no positional module is passed for this run.
retnet_moe = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=retnet_moe_mixin,
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

# Report the parameter counts, then train with compilation enabled under run name 'retnet-moe'.
count_parameters(retnet_moe)
train(retnet_moe, run_name='retnet-moe', do_compile=True)

### Mamba MOE

In [None]:
%%time

# Sequence mixer for this run: Mamba2 (num_attention_heads=6).
# NOTE(review): 6 heads does not evenly divide a hidden size of 512 (the value set
# in the config cell) — confirm the mixin supports this.
mamba_moe_mixin = Mamba2Mixin(hidden_size=hidden_size, num_attention_heads=6)

# Assemble the causal LM; no positional module is passed for this run.
mamba_moe = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=mamba_moe_mixin,
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

# Report the parameter counts, then train with compilation enabled under run name 'mamba-moe'.
count_parameters(mamba_moe)
train(mamba_moe, run_name='mamba-moe', do_compile=True)

### RWKV MOE

In [None]:
%%time

# Sequence mixer for this run: RWKV6 (num_attention_heads=9).
# NOTE(review): 9 heads does not evenly divide a hidden size of 512 (the value set
# in the config cell) — confirm the mixin supports this.
rwkv_moe_mixin = RWKV6Mixin(hidden_size=hidden_size, num_attention_heads=9)

# Assemble the causal LM; no positional module is passed for this run.
rwkv_moe = StackedMixinForCausalLM(
    num_layers=num_layers,
    hidden_size=hidden_size,
    initializer_range=0.02,
    embedding_module=embed_tokens,
    lm_head_module=lm_head,
    final_norm_module=norm,
    freeze_lm_modules=False,
    vocab_size=vocab_size,
    mixin_module=rwkv_moe_mixin,
    ffn_module=ffn_module,
    # positionnal_module=NaivePositionnalEmbedding(hidden_size, max_length=max_length)
)

# Report the parameter counts, then train with compilation enabled under run name 'rwkv-moe'.
count_parameters(rwkv_moe)
train(rwkv_moe, run_name='rwkv-moe', do_compile=True)