In [1]:
%load_ext autoreload
%autoreload 2
    
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback, SpikeDetection
from pytorch_lightning.loggers import TensorBoardLogger 
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    get_linear_schedule_with_warmup,
    HqqConfig
)
from peft import LoraConfig, get_peft_model
# from ademamix import AdEMAMix
import bitsandbytes as bnb
from bitsandbytes.optim.ademamix import AdEMAMix8bit as AdEMAMix
from bitsandbytes.optim.adamw import AdamW8bit as AdamW

from datasets import load_dataset
import os
import numpy as np
from collections import deque
import gc
import argparse
from tqdm.auto import tqdm

from torch_utils import memory_cleanup, count_parameters
from liger_kernel.transformers import apply_liger_kernel_to_llama

from pytorch_lightning.strategies import FSDPStrategy
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp
from torch.distributed.fsdp import MixedPrecision

from functools import partial
from fsdp_utils import fsdp_hqq_dora_model_for_causal_lm, get_wrapping_policy
from torch.distributed.fsdp.api import BackwardPrefetch, CPUOffload, ShardingStrategy

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch.distributed as dist
from model_utils import rsetattr

class HealingDataset(Dataset):
    """Map-style dataset that tokenizes one text per item on access.

    Args:
        data: Indexable collection of texts (``data[idx]`` is passed to the tokenizer).
        tokenizer: Callable tokenizer returning a mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors when called with ``return_tensors="pt"``.
        max_length: Length every item is padded/truncated to.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of texts in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize ``data[idx]`` and return its ids and attention mask.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"``, each with the
            tokenizer's leading batch dimension removed.
        """
        text = self.data[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # squeeze(0) drops only the batch dimension added by return_tensors="pt".
        # A bare squeeze() would also collapse the sequence dimension when
        # max_length == 1, producing 0-d tensors.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
        }

def load_and_prepare_data(tokenizer, batch_size=8, max_length=512, num_workers=os.cpu_count(), train_sample_limit=None, val_sample_limit=None):
    """Build train/validation DataLoaders from the dolphin-r1 "nonreasoning" split.

    Each row's ``"messages"`` entry is rendered to text with the tokenizer's chat
    template and wrapped in a ``HealingDataset``.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and callable tokenization.
        batch_size: Batch size used for both loaders.
        max_length: Sequence length forwarded to ``HealingDataset``.
        num_workers: Worker count for both loaders.
        train_sample_limit: If given, train on the first ``train_sample_limit`` rows;
            otherwise on the whole split.
        val_sample_limit: If given, validate on the ``val_sample_limit`` rows that
            follow the training rows (starting at row 0 when ``train_sample_limit``
            is None); otherwise on the whole split.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Note:
        Validation rows overlap the training rows whenever either limit is None.
    """
    dataset = load_dataset(
        "cognitivecomputations/dolphin-r1", "nonreasoning", cache_dir="../dolphin-r1"
    )["train"]

    # Apply sample limits if provided (.select avoids materializing the rows).
    if train_sample_limit is not None:
        train_rows = dataset.select(range(train_sample_limit))
    else:
        train_rows = dataset

    if val_sample_limit is not None:
        # Previously this added None + int when only val_sample_limit was set.
        val_start = 0 if train_sample_limit is None else train_sample_limit
        val_rows = dataset.select(range(val_start, val_start + val_sample_limit))
    else:
        val_rows = dataset

    def _render_conversations(conversations, desc):
        """Render each conversation to a single string via the chat template."""
        return [
            tokenizer.apply_chat_template(elt, tokenize=False, add_generation_prompt=False)
            for elt in tqdm(conversations, desc=desc)
        ]

    train_texts = _render_conversations(train_rows["messages"], "Preparing dataset train")
    # The validation bar used to be mislabelled as "train".
    val_texts = _render_conversations(val_rows["messages"], "Preparing dataset val")

    train_dataset = HealingDataset(train_texts, tokenizer, max_length=max_length)
    val_dataset = HealingDataset(val_texts, tokenizer, max_length=max_length)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    return train_loader, val_loader

## Load Base Model

In [None]:
# --- Process-level settings -------------------------------------------------
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
os.environ["TORCHDYNAMO_DISABLE_GRAPH_CAPTURE"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

torch.set_float32_matmul_precision('medium')
# torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(True)

# --- Run configuration ------------------------------------------------------
weights_location="deepseek_v3"
n_routed_experts=8
n_active_experts=4
epochs=1
batch_size=1
max_length=32

learning_rate=3e-5
train_sample_limit=32000
val_sample_limit=512
warmup_steps=128
compilation=False
checkpoint_every_n_steps=1024
accumulate_grad_batches=1

# The weights were moved to a faster disk to speed up loading. The root can be
# overridden with AI_TOOLBOX_DIR; the default keeps the original location.
model_root = os.environ.get("AI_TOOLBOX_DIR", "/home/golympie/ai-toolbox")
model_name=f"{model_root}/{weights_location}_{n_routed_experts}a{n_active_experts}"

log_name=f"{weights_location}_{n_routed_experts}a{n_active_experts}"
log_dir="pl_logs"

tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-V3", trust_remote_code=True
)

# Load and prepare data
train_loader, val_loader = load_and_prepare_data(
    tokenizer, batch_size=batch_size, max_length=max_length,
    train_sample_limit=train_sample_limit, val_sample_limit=val_sample_limit
)

memory_cleanup()

# Total steps for the scheduler, counted in optimizer updates: with gradient
# accumulation one update consumes `accumulate_grad_batches` batches (rounded
# up so a trailing partial group still counts as a step).
steps_per_epoch = (len(train_loader) + accumulate_grad_batches - 1) // accumulate_grad_batches
total_steps = steps_per_epoch * epochs

# Adapter targets: the three projections of every routed expert, plus the router gate.
target_modules = [
    f"mlp.experts.{expert_idx}.{proj}"
    for expert_idx in range(n_routed_experts)
    for proj in ("gate_proj", "up_proj", "down_proj")
]
target_modules.append('mlp.gate.weight')

model=fsdp_hqq_dora_model_for_causal_lm(
    model_name,
    target_modules=target_modules,
    lora_rank=4,
    lora_alpha=4,
    lora_dropout=0.1,
    n_workers=8
)

memory_cleanup()
count_parameters(model)

# Captured before optional compilation wraps the model.
config_dict = model.config.to_dict()

if compilation:
    print('compile model')
    model = torch.compile(model)

memory_cleanup()

Preparing dataset train:   0%|          | 0/32000 [00:00<?, ?it/s]

Preparing dataset train:   0%|          | 0/512 [00:00<?, ?it/s]

Loading HF Model to CPU


Loading checkpoint shards:   0%|          | 0/22 [00:00<?, ?it/s]

Quantizing linear layers


Quantizing modules:   0%|          | 0/3537 [00:00<?, ?it/s]

## Splitting layers between cuda:0 and cuda:1

In [None]:
from hqq.core.quantize import BaseQuantizeConfig, HQQBackend, HQQLinear

In [None]:
# Stage the whole model on CPU before redistributing it across the two GPUs.
# NOTE(review): cuda_1_layer is not used in this cell and is redefined by the
# GPU split cell; kept so the notebook namespace is unchanged.
cuda_1_layer = [f"layers.{i}" for i in range(10,42)]

# Iterate the generator directly (no list()) so each old tensor can be freed
# as soon as it has been replaced.
for name, parameter in tqdm(model.named_parameters()):
    # nn.Parameter defaults to requires_grad=True; pass the original flag so
    # frozen weights are not silently turned into trainable ones.
    rsetattr(
        model,
        name,
        torch.nn.Parameter(parameter.to('cpu'), requires_grad=parameter.requires_grad),
    )

# HQQLinear keeps its device tag and scale/zero tensors outside the parameter
# list, so they are moved explicitly.
for name, module in tqdm(model.named_modules()):
    if isinstance(module, HQQLinear):
        module.device="cpu"
        module.meta['scale']= module.meta['scale'].to('cpu')
        module.meta['zero']= module.meta['zero'].to('cpu')

model=model.to('cpu')
memory_cleanup()

In [None]:
# Extra cleanup pass after the CPU staging cell. What memory_cleanup() releases
# is defined in torch_utils (not visible in this notebook).
memory_cleanup()

In [None]:
# Decoder layers 10..44 go to cuda:1; everything else stays on cuda:0.
cuda_1_layer = [f"layers.{i}" for i in range(10,45)]


def _device_for(name, cuda_1_prefixes=frozenset(cuda_1_layer)):
    """Return the target device for a dotted parameter/module name.

    Matches whole dotted components ("layers", "<idx>") instead of substrings,
    so e.g. "layers.100" is not mistaken for "layers.10".
    """
    parts = name.split(".")
    for head, tail in zip(parts, parts[1:]):
        if f"{head}.{tail}" in cuda_1_prefixes:
            return "cuda:1"
    return "cuda:0"


# Iterate the generator directly (no list()) so each old tensor can be freed
# as soon as it has been replaced.
for name, parameter in tqdm(model.named_parameters()):
    # nn.Parameter defaults to requires_grad=True; pass the original flag so
    # frozen weights are not silently turned into trainable ones.
    rsetattr(
        model,
        name,
        torch.nn.Parameter(
            parameter.to(_device_for(name)), requires_grad=parameter.requires_grad
        ),
    )

# HQQLinear keeps its device tag and scale/zero tensors outside the parameter
# list, so they are moved to the same device as the owning layer.
for name, module in tqdm(model.named_modules()):
    if isinstance(module, HQQLinear):
        target_device = _device_for(name)
        module.device = target_device
        module.meta['scale'] = module.meta['scale'].to(target_device)
        module.meta['zero'] = module.meta['zero'].to(target_device)
memory_cleanup()

In [None]:
# Extra cleanup pass after the GPU split cell. What memory_cleanup() releases
# is defined in torch_utils (not visible in this notebook).
memory_cleanup()

In [None]:
# Smoke test: one forward pass with a loss through the split model.
text="this is a very interesting text, mmmmmm"

# NOTE: this rebinds the notebook-level max_length set in the config cell.
max_length=256
encoding = tokenizer(
    [text],
    max_length=max_length,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
).to('cuda:0')

# encoding['attention_mask']=None

# The input is padded to max_length, so most positions are padding. Label those
# positions -100 (the default ignore index of the cross-entropy loss) so the
# loss reflects the real tokens only.
labels = encoding['input_ids'].clone()
labels[encoding['attention_mask'] == 0] = -100

output=model(
    input_ids=encoding['input_ids'],
    labels=labels,
    attention_mask=encoding['attention_mask'],
    use_cache=False,
    output_attentions=False,
    output_hidden_states=False
)

loss=output.loss

# Drop the full output object and keep only the loss.
del output
memory_cleanup()