In [None]:
import torch
import torch.optim as optim
from bitsandbytes.optim.ademamix import AdEMAMix

# A single tiny layer is enough to compare how much per-parameter state each optimizer keeps.
model = torch.nn.Linear(10, 1)

# One optimizer per algorithm under comparison, all built on the same parameters with
# their default hyperparameters (construction order: AdamW, SGD, Adam, AdEMAMix).
adamw, sgd, adam, ademamix = (
    optimizer_cls(model.parameters())
    for optimizer_cls in (optim.AdamW, optim.SGD, optim.Adam, AdEMAMix)
)


# Function to calculate memory footprint
def calculate_memory_footprint(optimizer):
    """Return the bytes held by an optimizer's parameters plus its per-parameter tensor state.

    Parameters are read from ``optimizer.param_groups`` (not from a global model), so the
    result covers exactly the tensors this optimizer manages. A parameter that appears more
    than once is counted once.

    Note: PyTorch-style optimizers allocate their state lazily on the first ``step()``.
    Before that ``optimizer.state`` is empty and only parameter memory is counted.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts with ``"params"``) and
            ``state`` (mapping parameter -> dict), e.g. a ``torch.optim.Optimizer``.

    Returns:
        int: total size in bytes.
    """
    seen_ids = set()
    total_memory = 0
    for group in optimizer.param_groups:
        for param in group["params"]:
            if id(param) in seen_ids:
                continue
            seen_ids.add(id(param))
            total_memory += param.element_size() * param.nelement()
            # .get() avoids inserting empty entries into a defaultdict-backed state.
            for value in optimizer.state.get(param, {}).values():
                # Non-tensor state entries (plain numbers, None) hold no tensor memory.
                if isinstance(value, torch.Tensor):
                    total_memory += value.element_size() * value.nelement()
    return total_memory


# Calculate memory footprints.
# Optimizers allocate their state (momentum buffers, moment estimates, ...) lazily on the
# first step(). Measured right after construction, every optimizer would report the same
# number (parameter bytes only), so take one throw-away step with each before measuring.
# NOTE(review): bitsandbytes optimizers may require CUDA tensors, hence the move to the GPU
# when one is available. Module.to() updates the existing Parameter objects in place, so the
# optimizers built above keep valid references.
_device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(_device)


def _warm_up(optimizer):
    """Run one dummy forward/backward/step so ``optimizer`` allocates its state."""
    optimizer.zero_grad()
    dummy_batch = torch.randn(4, model.in_features, device=_device)
    model(dummy_batch).sum().backward()
    optimizer.step()
    optimizer.zero_grad()


for _optimizer in (adamw, sgd, adam, ademamix):
    _warm_up(_optimizer)

adamw_memory = calculate_memory_footprint(adamw)
sgd_memory = calculate_memory_footprint(sgd)
adam_memory = calculate_memory_footprint(adam)
ademamix_memory = calculate_memory_footprint(ademamix)

# Absolute footprints in bytes (parameters + optimizer state).
absolute_memory_footprints = {
    "AdamW": adamw_memory,
    "SGD": sgd_memory,
    "Adam": adam_memory,
    "AdEMAMix": ademamix_memory,
}

# Footprints relative to SGD. With its default momentum=0, SGD keeps no per-parameter
# tensor state, so these ratios read as "multiples of the parameter memory".
relative_memory_footprints = {
    name: footprint / sgd_memory for name, footprint in absolute_memory_footprints.items()
}

print(absolute_memory_footprints)
print(relative_memory_footprints)

In [None]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    ModelCheckpoint,
    EarlyStopping,
    Callback,
    SpikeDetection,
)
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    get_linear_schedule_with_warmup,
)
from peft import LoraConfig, get_peft_model
from ademamix import AdEMAMix
from datasets import load_dataset
import os
import numpy as np
from collections import deque
import gc
import argparse
from tqdm.auto import tqdm

from torch_utils import memory_cleanup, count_parameters
from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch transformers' Llama modeling code with Liger's fused kernels.
# NOTE(review): this targets the Llama architecture, while the model loaded below is a
# DeepSeek-V3 derivative, so it is unclear whether any patched kernel is used — verify.
apply_liger_kernel_to_llama()

# Allow float32 matmuls to use reduced-precision internal arithmetic ("medium") for speed.
torch.set_float32_matmul_precision("medium")

In [None]:
# Pre-fetch dolphin-r1 ("nonreasoning" subset) into the local cache directory.
# NOTE(review): `dataset` is not read by the visible cells; load_and_prepare_data reloads it.
dataset = load_dataset(
    path="cognitivecomputations/dolphin-r1",
    name="nonreasoning",
    cache_dir="../dolphin-r1",
)

In [None]:
class HealingDataset(Dataset):
    """Map-style dataset that tokenizes raw text lazily, one example per ``__getitem__``.

    Args:
        data: indexable sequence of strings (here: chat-templated conversations).
        tokenizer: Hugging Face-style tokenizer callable. It must have a pad token,
            because every item is padded with ``padding="max_length"``.
        max_length: fixed sequence length of every returned item.

    Each item is a dict of 1-D tensors of length ``max_length``:
    ``input_ids``, ``attention_mask`` and ``labels``. ``labels`` is ``input_ids`` with
    padding positions (``attention_mask == 0``) set to ``IGNORE_INDEX`` so the loss
    skips them.
    """

    # Default ``ignore_index`` of torch.nn.CrossEntropyLoss.
    IGNORE_INDEX = -100

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.data[idx],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # squeeze(0) drops only the batch dimension added by return_tensors="pt". A bare
        # squeeze() would also collapse the sequence dimension when max_length == 1.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        labels = input_ids.masked_fill(attention_mask == 0, self.IGNORE_INDEX)
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


def _format_conversations(tokenizer, conversations, desc):
    """Render each conversation (a list of chat messages) to one string via the chat template."""
    return [
        tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=False)
        for conversation in tqdm(conversations, desc=desc)
    ]


def load_and_prepare_data(
    tokenizer,
    batch_size=8,
    max_length=512,
    num_workers=os.cpu_count(),
    train_sample_limit=None,
    val_sample_limit=None,
):
    """Load dolphin-r1 ("nonreasoning"), chat-format it and wrap it in train/val DataLoaders.

    Rows are taken in order from the ``train`` split:
      * train = rows ``[0, train_end)``. ``train_end`` is ``train_sample_limit`` clamped
        to the split size, or the whole split when no limits are given.
      * val = the ``val_sample_limit`` rows right after the training rows. If only
        ``val_sample_limit`` is given, the last ``val_sample_limit`` rows are held out and
        training uses the rest, so the two sets never overlap.
      * Without ``val_sample_limit``, validation uses the whole split and therefore
        overlaps training.

    Args:
        tokenizer: Hugging Face tokenizer with a chat template and a pad token.
        batch_size: DataLoader batch size for both loaders.
        max_length: fixed token length of every sample (pad/truncate).
        num_workers: DataLoader worker processes. ``None`` (which ``os.cpu_count()`` may
            return) is treated as 0, i.e. loading in the main process.
        train_sample_limit: optional maximum number of training rows.
        val_sample_limit: optional number of validation rows.

    Returns:
        (train_loader, val_loader): shuffled training loader and ordered validation loader.
    """
    dataset = load_dataset("cognitivecomputations/dolphin-r1", "nonreasoning", cache_dir="../dolphin-r1")["train"]
    n_rows = len(dataset)

    if train_sample_limit is not None:
        train_end = min(train_sample_limit, n_rows)
    elif val_sample_limit is not None:
        # Only a validation size was requested: hold out the tail so train/val stay disjoint.
        train_end = max(n_rows - val_sample_limit, 0)
    else:
        train_end = n_rows

    train_dataset = dataset if train_end == n_rows else dataset.select(range(train_end))

    if val_sample_limit is not None:
        val_end = min(train_end + val_sample_limit, n_rows)
        val_dataset = dataset.select(range(train_end, val_end))
    else:
        val_dataset = dataset

    train_texts = _format_conversations(tokenizer, train_dataset["messages"], "Preparing dataset train")
    val_texts = _format_conversations(tokenizer, val_dataset["messages"], "Preparing dataset val")

    train_dataset = HealingDataset(train_texts, tokenizer, max_length=max_length)
    val_dataset = HealingDataset(val_texts, tokenizer, max_length=max_length)

    if num_workers is None:
        num_workers = 0

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True,
    )

    return train_loader, val_loader

In [None]:
weights_location = "deepseek_v3"
n_routed_experts = 8
n_active_experts = 4

# Local copy of the pruned checkpoint (moved to a faster disk to speed up loading).
model_name = f"/home/golympie/ai-toolbox/{weights_location}_{n_routed_experts}a{n_active_experts}"

# NOTE(review): neither load_in_4bit nor load_in_8bit is set, so this config alone does not
# make explicit which quantization transformers applies. Consider passing load_in_4bit=True
# so the bnb_4bit_* options below are unambiguously used.
bnb_config = BitsAndBytesConfig(
    # load_in_8bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_quant_storage=torch.bfloat16,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
    quantization_config=bnb_config,
)

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)
# HealingDataset pads every sample to max_length, which needs a pad token. Fall back to
# EOS, the same convention used when loading the model for the plain-PyTorch run.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# LoRA targets: gate/up/down projections of every routed expert.
target_modules = [
    module_name
    for i in range(n_routed_experts)
    for module_name in (
        f"mlp.experts.{i}.gate_proj",
        f"mlp.experts.{i}.up_proj",
        f"mlp.experts.{i}.down_proj",
    )
]
# NOTE(review): PEFT matches target_modules against *module* names, but "mlp.gate.weight"
# is a parameter name, so this entry likely matches nothing and the router gate is not
# adapted. Check the trainable-parameter count below.
target_modules.append("mlp.gate.weight")

lora_config = LoraConfig(
    use_dora=True,
    r=8,  # rank of the low-rank update matrices
    lora_alpha=8,  # LoRA scaling numerator (scale = lora_alpha / r)
    target_modules=target_modules,
    lora_dropout=0.1,
    bias="none",  # do not train bias terms
    task_type="CAUSAL_LM",
)

# Wrap the model with the LoRA/DoRA adapters.
model = get_peft_model(model, lora_config)
count_parameters(model)

In [None]:
class LightningModel(pl.LightningModule):
    """PyTorch Lightning wrapper for fine-tuning ("healing") a causal LM.

    On top of plain next-token training it adds:
      * loss-spike damping: a step whose loss exceeds ``spike_threshold`` times an
        exponential moving average (EMA) of recent losses is rescaled down to the EMA;
      * a per-optimizer-step learning-rate schedule: linear warmup, then cosine decay
        to ``min_lr``.

    Args:
        model: Hugging Face causal LM (here PEFT-wrapped). ``model.config.to_dict()`` is
            stored in the hyperparameters.
        tokenizer: kept as an attribute; not used by the training math.
        learning_rate: peak learning rate, reached at the end of warmup.
        warmup_steps: optimizer steps of linear warmup from 0 (0 disables warmup).
        total_steps: optimizer steps over which the schedule runs (warmup included).
        min_lr: learning rate at the end of the cosine decay.
        compilation: wrap the model with ``torch.compile``.
        spike_threshold: loss/EMA ratio above which a step's loss is rescaled.
        ema_span: EMA span N, giving ``alpha = 2 / (N + 1)``.
    """

    def __init__(
        self,
        model,
        tokenizer,
        learning_rate=2e-5,
        warmup_steps=100,
        total_steps=1000,
        min_lr=1e-6,
        compilation=True,
        spike_threshold=1.5,
        ema_span=10,
    ):
        super().__init__()
        self.config_dict = model.config.to_dict()

        self.model = model

        if compilation:
            print("compile model")
            self.model = torch.compile(self.model)

        self.tokenizer = tokenizer
        self.learning_rate = learning_rate
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr  # floor of the cosine decay

        # Loss EMA used for spike damping (None until the first training step).
        self.loss_ema = None
        self.ema_alpha = 2.0 / (ema_span + 1)
        self.spike_threshold = spike_threshold

        # The model is already saved in checkpoints as a submodule; keep it (and the
        # tokenizer) out of the hyperparameters.
        self.save_hyperparameters(ignore=["model", "tokenizer", "config"])
        self.hparams.update({"model_config": self.config_dict})
        memory_cleanup()

    def forward(self, input_ids, attention_mask=None, labels=None):
        return self.model(input_ids, attention_mask=attention_mask, labels=labels)

    def training_step(self, batch, batch_idx):
        # Prefer dataset-provided labels (padding masked to -100); otherwise predict the inputs.
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch.get("labels", batch["input_ids"]),
        )
        loss = outputs.loss
        # .item() forces a device->host sync, so read the value only once.
        original_loss = loss.item()

        # The EMA is updated with the current loss *before* the spike test, so a spike
        # partially raises the reference it is compared against.
        if self.loss_ema is None:
            self.loss_ema = original_loss
        else:
            self.loss_ema = (1 - self.ema_alpha) * self.loss_ema + self.ema_alpha * original_loss

        # Damp the loss when it exceeds spike_threshold (default 1.5) x EMA.
        clipped = original_loss > self.spike_threshold * self.loss_ema
        if clipped:
            # Rescale instead of replacing so gradients still flow; the value becomes the EMA.
            loss = loss * (self.loss_ema / original_loss)
        # Logged as 0.0/1.0 so it can be reduced and synced like the other scalar metrics.
        self.log("loss_clipped", float(clipped), on_step=True, prog_bar=True, sync_dist=True)

        self.log(
            "train_loss",
            loss,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "train_loss_ema",
            self.loss_ema,
            on_step=True,
            on_epoch=True,
            prog_bar=True,
            sync_dist=True,
        )
        self.log(
            "train_loss_original",
            original_loss,
            on_step=True,
            prog_bar=True,
            sync_dist=True,
        )

        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self(
            batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch.get("labels", batch["input_ids"]),
        )
        loss = outputs.loss
        self.log("val_loss", loss, on_epoch=True, prog_bar=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        """AdEMAMix with a single per-step LambdaLR: linear warmup, then cosine decay to ``min_lr``.

        One scheduler expresses the whole schedule. Chaining a LambdaLR and a
        CosineAnnealingLR on the same optimizer lets them overwrite each other's learning
        rate, and ``start_epoch`` is not a key Lightning supports in a scheduler config.
        """
        import math

        optimizer = AdEMAMix(self.parameters(), lr=self.learning_rate)

        warmup_steps = max(0, int(self.warmup_steps))
        decay_steps = max(1, self.total_steps - warmup_steps)
        # LambdaLR multiplies the base lr, so express min_lr as a fraction of it.
        min_ratio = self.min_lr / self.learning_rate if self.learning_rate else 0.0

        def lr_lambda(step):
            if step < warmup_steps:
                return float(step) / float(max(1, warmup_steps))
            # Clamp so the lr stays at min_lr (instead of rising again) past total_steps.
            progress = min(1.0, (step - warmup_steps) / decay_steps)
            return min_ratio + (1.0 - min_ratio) * 0.5 * (1.0 + math.cos(math.pi * progress))

        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "step",
            },
        }

In [None]:
# ---- Lightning run configuration ----
epochs = 1
batch_size = 4  # DataLoader batch size
max_length = 512  # tokens per sample (padded/truncated)

learning_rate = 3e-5
train_sample_limit = 32000
val_sample_limit = 512
warmup_steps = 128  # optimizer steps of linear LR warmup
compilation = True  # wrap the model in torch.compile

# Names read by the training cell that were previously undefined, which raised NameError on
# a fresh kernel. The values are placeholders; adjust as needed.
accumulate_grad_batches = 1
checkpoint_every_n_steps = 500
log_dir = "logs"
log_name = "healing"

In [None]:
import math

train_loader, val_loader = load_and_prepare_data(
    tokenizer,
    batch_size=batch_size,
    max_length=max_length,
    train_sample_limit=train_sample_limit,
    val_sample_limit=val_sample_limit,
)

num_devices = 2

# The LR scheduler steps once per *optimizer* step ("interval": "step"), so total_steps must
# count optimizer steps, not batches. This assumes Lightning's default distributed sampler:
# each of the `num_devices` ranks then sees ~1/num_devices of the batches, and every
# `accumulate_grad_batches` of them make one optimizer step.
batches_per_device = math.ceil(len(train_loader) / num_devices)
total_steps = max(1, math.ceil(batches_per_device / accumulate_grad_batches) * epochs)

pl_model = LightningModel(
    model,
    tokenizer,
    learning_rate=learning_rate,
    warmup_steps=warmup_steps,
    total_steps=total_steps,
    compilation=compilation,
)

# Callbacks
checkpoint_dir = os.path.join("checkpoints_full", log_name)
os.makedirs(checkpoint_dir, exist_ok=True)  # also creates the parent directory

checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="{epoch}-{step}-{train_loss:.4f}",
    save_on_train_epoch_end=False,  # allow saving in the middle of an epoch
    every_n_train_steps=checkpoint_every_n_steps,
    save_top_k=3,  # keep the 3 checkpoints with the best (lowest) train_loss
    monitor="train_loss",
    save_last=True,  # additionally keep the most recent checkpoint
)

logger = TensorBoardLogger(log_dir, name=log_name)

# NOTE(review): Lightning restricts which strategies can run in an interactive (Jupyter)
# session, and the model was loaded with device_map="auto" plus bitsandbytes quantization.
# Confirm that strategy="fsdp" works in this setup, or run this cell as a script.
trainer = pl.Trainer(
    max_epochs=epochs,
    accelerator="auto",
    devices=num_devices,
    strategy="fsdp",
    callbacks=[
        checkpoint_callback,
    ],
    accumulate_grad_batches=accumulate_grad_batches,
    logger=logger,
    precision="bf16-mixed",
    log_every_n_steps=1,
)

trainer.fit(pl_model, train_loader, val_loader)

In [None]:
%load_ext autoreload
%autoreload 2

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from modeling_deepseek import DeepseekV3ForCausalLM
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from memory_utils import count_parameters, memory_cleanup
from transformers import BitsAndBytesConfig
from ademamix import AdEMAMix
from tqdm.auto import tqdm
import json

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
from bitsandbytes.optim import AdEMAMix8bit

from liger_kernel.transformers import apply_liger_kernel_to_llama

# Patch transformers' Llama modeling code with Liger's fused kernels.
# NOTE(review): this targets the Llama architecture only; the model below is loaded through
# DeepseekV3ForCausalLM, so it is unclear whether any Liger kernel is used — verify.
apply_liger_kernel_to_llama()

# from Distiller import count_parameters
# Allow float32 matmuls to use reduced-precision internal arithmetic ("medium") for speed.
torch.set_float32_matmul_precision("medium")

# Pruned-model configuration: number of experts kept, and (presumably) experts active per token.
n_experts = 4
n_active_experts = 1

# Local directory of the pruned, not-yet-healed checkpoint.
model_name = f"DeepSeek-V3-{n_experts}@{n_active_experts}-unhealed-v0.1"

## Device placement: decoder-layer weights (including the experts) go to one GPU; all remaining weights (e.g. embeddings, LM head) go to the other

In [None]:
# Build an explicit per-tensor device map from the checkpoint's safetensors index:
# tensors whose name contains ".layers." (decoder layers) -> cuda:1, everything else -> cuda:0.
with open(f"{model_name}/model.safetensors.index.json") as index_file:
    weights_map = json.load(index_file)["weight_map"]

device_map = {
    tensor_name: ("cuda:1" if ".layers." in tensor_name else "cuda:0")
    for tensor_name in weights_map
}

In [None]:
# NOTE(review): neither load_in_4bit nor load_in_8bit is set, so this config alone does not
# make explicit which quantization transformers applies. Consider passing load_in_4bit=True
# so the bnb_4bit_* options below are unambiguously used.
bnb_config = BitsAndBytesConfig(
    # load_in_8bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_quant_storage=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = DeepseekV3ForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=device_map,  # explicit per-tensor GPU placement built in the previous cell
    quantization_config=bnb_config,
)

# Padding during tokenization needs a pad token; fall back to EOS.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


model.train()
count_parameters(model)
memory_cleanup()

# Freeze every base weight. `requires_grad` is the flag autograd reads; setting a
# `.trainable` attribute on a Parameter has no effect in PyTorch. The LoRA adapters added
# in a later cell are new parameters and are not affected by this loop.
for parameter in model.parameters():
    parameter.requires_grad_(False)

## Pushing Linear layer as 4bit

## Add LoRA adapters on top of the experts and the router gate; freeze everything else

In [None]:
# LoRA targets: gate/up/down projections of every routed (non-shared) expert.
target_modules = [
    module_name
    for i in range(n_experts)
    for module_name in (
        f"mlp.experts.{i}.gate_proj",
        f"mlp.experts.{i}.up_proj",
        f"mlp.experts.{i}.down_proj",
    )
]
# NOTE(review): PEFT matches target_modules against *module* names, but "mlp.gate.weight" is
# a parameter name, so this entry likely matches nothing and the router gate is not adapted.
target_modules.append("mlp.gate.weight")

# Plain LoRA (DoRA via use_dora=True is left disabled here).
lora_config = LoraConfig(
    r=16,  # rank of the low-rank update matrices
    lora_alpha=16,  # LoRA scaling numerator (scale = lora_alpha / r)
    target_modules=target_modules,
    lora_dropout=0.1,
    bias="none",  # do not train bias terms
    task_type="CAUSAL_LM",
)

# Wrap the model with the LoRA adapters and report the parameter counts.
model = get_peft_model(model, lora_config)
count_parameters(model)

## Load the dataset for post-training

In [None]:
n_train = 65536
# n_train = 4096  # smaller setting for quick debugging runs
n_val = 256

# Load only the first n_train + n_val rows of the split.
dolphin_r1 = load_dataset(
    "cognitivecomputations/dolphin-r1",
    "nonreasoning",
    split=f"train[:{n_train+n_val}]",
    cache_dir="../dolphin-r1",
)

max_length = 64
# Label value skipped by the loss (default ignore_index of torch.nn.CrossEntropyLoss).
IGNORE_INDEX = -100

# Disjoint split: first n_train rows for training, the next n_val rows for validation.
messages_only = dolphin_r1.select_columns(["messages"])
train_dataset = messages_only.select(range(n_train))
val_dataset = messages_only.select(range(n_train, n_train + n_val))


def tokenize_function(examples):
    """Chat-template and tokenize a batch of conversations to exactly ``max_length`` tokens.

    ``padding="max_length"`` gives every row the same length. With ``padding=True``, each
    ``map`` batch would be padded only to its own longest sample, and the default DataLoader
    collate could fail to stack rows from different map batches.

    Also adds ``labels``: ``input_ids`` with padding positions set to ``IGNORE_INDEX`` so
    the loss does not train on padding.
    """
    formatted = tokenizer.apply_chat_template(examples["messages"], tokenize=False, add_generation_prompt=False)
    data = tokenizer(formatted, truncation=True, max_length=max_length, padding="max_length")
    data["labels"] = [
        [token if keep else IGNORE_INDEX for token, keep in zip(ids, mask)]
        for ids, mask in zip(data["input_ids"], data["attention_mask"])
    ]
    return data


train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=["messages"]).with_format("torch")
val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=["messages"]).with_format("torch")

## Set training hyperparameters

In [None]:
num_train_epochs = 1

per_device_train_batch_size = 2
per_device_eval_batch_size = 2

gradient_accumulation_steps = 8  # micro-batches per optimizer step

learning_rate = 1e-4

weight_decay = 0.01
logging_steps = 5
seed = 3407
# Apply the seed (it was previously defined but unused). torch.manual_seed seeds all
# devices, making DataLoader shuffling and dropout reproducible. LoRA weights were
# initialised in an earlier cell and are not affected.
torch.manual_seed(seed)

output_dir = "outputs"

In [None]:
import math

# Format datasets
num_epochs = num_train_epochs

# model=torch.compile(model)
# Return `labels` (padding masked to -100) as well, when the tokenization cell produced them.
model_columns = [c for c in ("input_ids", "attention_mask", "labels") if c in train_dataset.column_names]
train_dataset.set_format("torch", columns=model_columns)
val_dataset.set_format("torch", columns=model_columns)

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=per_device_train_batch_size, shuffle=True)

# Inputs go to cuda:0, where the non-".layers." weights (e.g. the embeddings) are placed.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# A model loaded with an explicit device_map is already spread over several GPUs; .to()
# would collapse it onto one device. Only move models that were not dispatched that way.
if not getattr(model, "hf_device_map", None):
    model = model.to(device)

# Optimize only the trainable (LoRA) parameters.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdEMAMix8bit(trainable_params, lr=learning_rate, weight_decay=weight_decay)

# One optimizer step per full accumulation window, plus one for a trailing partial window.
steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
num_training_steps = steps_per_epoch * num_epochs
scheduler = CosineAnnealingLR(optimizer, T_max=max(1, num_training_steps), eta_min=1e-6)

writer = SummaryWriter(f"runs/{model_name}")

model.train()
global_step = 0  # counts micro-batches; used as the TensorBoard x-axis
avg_loss = float("nan")

try:
    for epoch in range(num_epochs):
        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")
        running_loss = 0.0
        num_batches = len(train_dataloader)

        optimizer.zero_grad()

        for batch_idx, batch in enumerate(progress_bar):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device) if "labels" in batch else input_ids

            outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
            # Scale so the accumulated gradient is the mean over the accumulation window.
            loss = outputs.loss / gradient_accumulation_steps
            loss.backward()

            # Step every `gradient_accumulation_steps` micro-batches, and also on the last
            # batch so a trailing partial window is not dropped. That window's gradient is
            # slightly down-weighted, since its loss was divided by the full window size.
            is_last_batch = batch_idx + 1 == num_batches
            if (batch_idx + 1) % gradient_accumulation_steps == 0 or is_last_batch:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # One host sync per micro-batch; undo the accumulation scaling for reporting.
            step_loss = loss.item() * gradient_accumulation_steps
            running_loss += step_loss
            avg_loss = running_loss / (batch_idx + 1)
            current_lr = scheduler.get_last_lr()[0]

            progress_bar.set_postfix({"loss": f"{avg_loss:.4f}", "lr": f"{current_lr:.2e}"})

            writer.add_scalar("Training/Loss", step_loss, global_step)
            writer.add_scalar("Training/Learning_Rate", current_lr, global_step)

            # Print every 10 micro-batches.
            if batch_idx % 10 == 0:
                print(f"Step {batch_idx}, Loss: {step_loss:.4f}, LR: {current_lr:.2e}")

            global_step += 1

        writer.add_scalar("Training/Epoch_Loss", avg_loss, epoch)
finally:
    # Flush TensorBoard events even if training is interrupted.
    writer.close()

model.save_pretrained(f"{model_name}-healing-lora")

## Merge the unhealed model with its adapter

In [None]:
from transformers import AutoModelForCausalLM
from peft import PeftModel

# Reload the unquantized base checkpoint in bf16, matching the training-time torch_dtype.
# trust_remote_code matches how this checkpoint was loaded for training.
base_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, trust_remote_code=True)
peft_model_id = f"{model_name}-healing-lora"

model = PeftModel.from_pretrained(base_model, peft_model_id)
# merge_and_unload() returns the base model with the LoRA weights folded in. Keep the
# return value: otherwise `model` stays a PeftModel and save_pretrained() writes only the
# adapter instead of the merged weights.
model = model.merge_and_unload()

In [None]:
# Save `model` under a directory name that records the expert configuration.
new_model_name = f"DeepSeek-V3-{n_experts}@{n_active_experts}-Pruned"
model.save_pretrained(save_directory=new_model_name)

In [None]:
# NOTE(review): the directory name "deepseek_v2_lite_chat_16@4" does not match this notebook's
# DeepSeek-V3 {n_experts}@{n_active_experts} configuration and looks like a leftover from an
# earlier experiment — confirm before running. `model.model` is a submodule of `model`; which
# one depends on whether `model` is still a PeftModel at this point.
model.model.save_pretrained("deepseek_v2_lite_chat_16@4")

## Test generation capabilities

In [None]:
# Switch to inference mode before the generation test; train() would keep dropout active.
model.eval()

In [None]:
from transformers import TextStreamer

# Stream only the newly generated text, not the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True)

model.generation_config.pad_token_id = model.generation_config.eos_token_id

messages = [{"role": "user", "content": "Write a piece of quicksort code in C++"}]
# Tokenize through the chat template in one call, so special tokens come from the template
# alone. Send the inputs to wherever the model lives: after the merge cell it is loaded
# without a device_map, so it is not necessarily on cuda:0.
input_tensor = tokenizer.apply_chat_template(
    messages,
    add_generation_prompt=True,
    return_tensors="pt",
    return_dict=True,
).to(model.device)

# temperature=0.01 approximated greedy decoding. Request greedy decoding explicitly so the
# result does not depend on the checkpoint's generation_config sampling defaults.
out = model.generate(**input_tensor, streamer=streamer, do_sample=False, max_new_tokens=64)