In [None]:
%load_ext autoreload
%autoreload 2
    
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback, SpikeDetection
from pytorch_lightning.loggers import TensorBoardLogger 
from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    get_linear_schedule_with_warmup,
    HqqConfig
)
from peft import LoraConfig, get_peft_model
# from ademamix import AdEMAMix
import bitsandbytes as bnb

from datasets import load_dataset
import os
import numpy as np
from collections import deque
import gc
import argparse
from tqdm.auto import tqdm

from torch_utils import memory_cleanup, count_parameters
from liger_kernel.transformers import apply_liger_kernel_to_llama

from pytorch_lightning.strategies import FSDPStrategy
import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp
from torch.distributed.fsdp import MixedPrecision

from functools import partial
from fsdp_utils import fsdp_hqq_dora_model_for_causal_lm, get_wrapping_policy
from torch.distributed.fsdp.api import BackwardPrefetch, CPUOffload, ShardingStrategy

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
import torch.distributed as dist
from model_utils import rsetattr

class HealingDataset(Dataset):
    """Map-style dataset that tokenizes raw text samples on access.

    Every sample is tokenized lazily in ``__getitem__``, padded or truncated
    to exactly ``max_length`` tokens, and returned as 1-D tensors.

    Args:
        data: Indexable collection of text samples (must support ``len()``
            and integer indexing).
        tokenizer: Callable following the Hugging Face tokenizer call
            convention; it must return a mapping holding ``"input_ids"`` and
            ``"attention_mask"`` tensors with a leading batch dimension.
        max_length: Sequence length every sample is padded/truncated to.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of text samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its ids and attention mask.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"``, each a 1-D
            tensor of length ``max_length``.
        """
        encoding = self.tokenizer(
            self.data[idx],
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Drop only the leading batch axis added by return_tensors="pt".
        # A bare .squeeze() would also remove the sequence axis when
        # max_length == 1 and hand a 0-d tensor to the collate function.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
        }

def load_and_prepare_data(tokenizer, batch_size=8, max_length=512, num_workers=os.cpu_count(), train_sample_limit=None, val_sample_limit=None):
    """Build train/validation DataLoaders from the dolphin-r1 ``nonreasoning`` split.

    Conversations are rendered to plain text with the tokenizer's chat
    template and wrapped in ``HealingDataset``, which tokenizes on access.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and the usual
            call interface.
        batch_size: Batch size for both loaders.
        max_length: Token length every sample is padded/truncated to.
        num_workers: Number of DataLoader worker processes.
        train_sample_limit: If given, train on the first N samples only;
            otherwise on the whole split.
        val_sample_limit: If given, validate on the N samples that follow the
            training slice; otherwise on the whole split.

    Returns:
        Tuple ``(train_loader, val_loader)``.
    """
    dataset = load_dataset(
        "cognitivecomputations/dolphin-r1", "nonreasoning", cache_dir="../dolphin-r1"
    )["train"]

    # Apply sample limits if provided (.select avoids copying the data).
    if train_sample_limit is not None:
        train_split = dataset.select(range(train_sample_limit))
    else:
        train_split = dataset

    if val_sample_limit is not None:
        # Without a train limit there is no slice to skip, so start at 0
        # instead of failing on range(None, ...). In that case the validation
        # samples necessarily overlap the (full) training set.
        val_start = train_sample_limit if train_sample_limit is not None else 0
        val_split = dataset.select(range(val_start, val_start + val_sample_limit))
    else:
        val_split = dataset

    def _render(conversations, desc):
        # Turn each message list into a single chat-formatted string.
        return [
            tokenizer.apply_chat_template(elt, tokenize=False, add_generation_prompt=False)
            for elt in tqdm(conversations, desc=desc)
        ]

    train_texts = _render(train_split["messages"], "Preparing dataset train")
    # The validation pass used to be labelled "train" as well.
    val_texts = _render(val_split["messages"], "Preparing dataset val")

    train_dataset = HealingDataset(train_texts, tokenizer, max_length=max_length)
    val_dataset = HealingDataset(val_texts, tokenizer, max_length=max_length)

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=True
    )

    return train_loader, val_loader

## Load Base Model

In [None]:
# Process-level environment settings, written before any model is loaded.
# NOTE(review): torch is already imported at this point; confirm each variable
# is still read late enough to take effect.
os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
# NOTE(review): presumably disables TorchDynamo graph capture — confirm this
# variable name is actually honoured by the installed PyTorch version.
os.environ["TORCHDYNAMO_DISABLE_GRAPH_CAPTURE"] = "1"
# Caching-allocator option: caps the size of blocks the allocator may split.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

# Trade float32 matmul precision for speed.
torch.set_float32_matmul_precision('medium')
# torch.backends.cuda.enable_flash_sdp(True)
# torch.backends.cuda.enable_mem_efficient_sdp(True)

# Identity of the pruned checkpoint; these three values build `model_name`
# and `log_name` below, and `n_routed_experts` also drives the adapter targets.
weights_location="deepseek_v3"
n_routed_experts=8
n_active_experts=4
epochs=1
# `batch_size` and `max_length` are reassigned in the training cell further down.
batch_size=1
max_length=32

# The settings below are not referenced again in the visible part of the
# notebook (the training cell defines its own values).
learning_rate=3e-5
train_sample_limit=32000
val_sample_limit=512
warmup_steps=128
compilation=False
checkpoint_every_n_steps=1024
accumulate_grad_batches=1


# Local checkpoint directory; the model was moved to a faster disk to speed up loading.
# NOTE(review): absolute, user-specific path — adjust per machine.
model_name=f"/home/golympie/ai-toolbox/{weights_location}_{n_routed_experts}a{n_active_experts}"

log_name=f"{weights_location}_{n_routed_experts}a{n_active_experts}"
log_dir="pl_logs"

# The tokenizer comes from the upstream DeepSeek-V3 repository, not from the
# local checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-V3", trust_remote_code=True
)

# Project helper from torch_utils; presumably frees cached memory before the
# model is loaded — confirm in torch_utils.memory_cleanup.


memory_cleanup()

In [None]:
# bitsandbytes 4-bit settings: FP4 quantization with double quantization,
# bfloat16 compute and bfloat16 storage.
# NOTE(review): `load_in_4bit=True` is not passed explicitly — confirm the
# installed transformers version still applies 4-bit quantization here.
quant_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_quant_storage=torch.bfloat16,
)


# Read the checkpoint index to enumerate every weight name in the model.
import json
with open(model_name+"/model.safetensors.index.json", "r") as f:
    weight_map=json.loads(f.read())['weight_map']

# quant_config  = HqqConfig(nbits=4, group_size=64)

# Manual two-GPU placement, keyed by weight name:
#   * lm_head                      -> cuda:1
#   * embeddings and final norm    -> cuda:0
#   * everything else              -> cuda:0 if its index < 30, else cuda:1
device_map={}
for elt in weight_map:
    if "lm_head" in elt:
        device_map[elt]="cuda:1"
    elif "model.embed_tokens" in elt:
        device_map[elt]="cuda:0"
    elif "model.norm" in elt:
        device_map[elt]="cuda:0"
    else:
        # Assumes the name looks like `model.layers.<i>....` so that the third
        # dot-separated field is the layer index — any other key would raise
        # here. TODO confirm against the checkpoint index.
        i = int(elt.split('.')[2])
        if i < 30:
            device_map[elt]="cuda:0"
        else:
            device_map[elt]="cuda:1"

# Load the quantized model with the placement computed above.
model=AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map=device_map,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    offload_buffers=True,
    quantization_config=quant_config,
)

## Adding DoRA layers

In [None]:
class DORALayer(nn.Module):
    """Low-rank adapter (same as LoRA) that also returns the adapted weight norm.

    Intended to be wrapped as a single FSDP unit.

    Args:
        in_features: Input width of the adapted linear layer.
        out_features: Output width of the adapted linear layer.
        lora_rank: Rank of the low-rank update.
        device: Device the adapter weights are created on.
        dtype: Dtype of the adapter weights.
        *args, **kwargs: Accepted for call compatibility and ignored.
    """

    def __init__(self, in_features, out_features, lora_rank, device, dtype, *args, **kwargs):
        super().__init__()
        # Init LoRA layers.
        # 1/sqrt(rank) as a plain Python float: the scale is then applied in
        # `dtype` precision instead of being rounded through a hard-coded
        # bfloat16 tensor first.
        std_dev = float(lora_rank) ** -0.5

        # A ~ N(0, 1/rank). Drawn before nn.Linear is built so the random
        # stream is consumed in the same order as before.
        lora_A_init = torch.randn(lora_rank, in_features, device=device, dtype=dtype) * std_dev
        self.lora_A = nn.Linear(in_features, lora_rank, bias=False, device=device, dtype=dtype)
        self.lora_A.weight = nn.Parameter(lora_A_init)

        # B starts at zero so the adapter is a no-op at initialisation.
        self.lora_B = nn.Linear(lora_rank, out_features, bias=False, device=device, dtype=dtype)
        nn.init.zeros_(self.lora_B.weight)

    def forward(self, x, frozen_weight):
        """Apply the low-rank update and compute the adapted weight's norms.

        Args:
            x: Input activations whose last dimension is ``in_features``.
            frozen_weight: Dense base weight of shape
                ``(out_features, in_features)``.

        Returns:
            Tuple ``(output, column_norm)``: ``output`` is ``B(A(x))`` and
            ``column_norm`` is the detached L2 norm of each row of
            ``frozen_weight + B @ A`` (shape ``(out_features,)``).
        """
        output = self.lora_B(self.lora_A(x))
        adapted_weight = frozen_weight + self.lora_B.weight @ self.lora_A.weight
        # Detached: the norm only rescales the output, no gradient flows through it.
        column_norm = adapted_weight.norm(p=2, dim=1).detach()
        return output, column_norm

class MagnitudeLayer(nn.Module):
    """Trainable per-feature magnitude vector applied to the last dimension.

    Kept as a dedicated module because FSDP doesn't work with
    nn.ParameterDict: https://github.com/pytorch/pytorch/issues/79605
    """

    def __init__(self, vector_data, device, dtype):
        super().__init__()
        initial_magnitude = vector_data.to(device=device, dtype=dtype)
        self.magnitude = nn.Parameter(initial_magnitude)

    def forward(self, x):
        """Scale ``x`` by the magnitude vector, broadcast over the two leading dims."""
        per_feature_scale = self.magnitude.view(1, 1, -1)
        return x * per_feature_scale
        
class BNBDORA(nn.Module):
    """DoRA adapter wrapped around a bitsandbytes 4-bit linear layer.

    The wrapped ``base_layer`` must carry a ``dora_scale`` attribute; it seeds
    the trainable magnitude vector and is cleared from the base layer afterwards.

    Args:
        base_layer: Quantized linear layer to adapt.
        lora_rank: Rank of the low-rank update.
        *args, **kwargs: Forwarded to ``DORALayer``.
    """

    def __init__(self, base_layer, lora_rank, *args, **kwargs):
        super().__init__()
        self.base_layer = base_layer
        first_param = next(base_layer.parameters())
        dtype = getattr(base_layer, "compute_dtype", first_param.dtype)
        device = first_param.device

        # Trainable magnitude, seeded from the scale stored on the base layer.
        initial_magnitude = self.base_layer.dora_scale.clone().to(dtype=dtype)
        self.magnitude_layer = MagnitudeLayer(initial_magnitude, device, dtype)
        # The stored scale is no longer needed once it has been copied.
        self.base_layer.dora_scale = None
        torch.cuda.empty_cache()

        # Low-rank direction update.
        self.dora_layer = DORALayer(
            base_layer.in_features,
            base_layer.out_features,
            lora_rank,
            device,
            dtype,
            *args,
            **kwargs,
        )

    def forward(self, x, *args, **kwargs):
        """Return ``m * ((W @ x + BA @ x) / ||W + BA||)`` for input ``x``."""
        # Clone defensively before the in-place accumulation below.
        result = self.base_layer(x, *args, **kwargs).clone()

        cast_needed = not torch.is_autocast_enabled()
        if cast_needed:
            expected_dtype = result.dtype
            x = x.to(self.dora_layer.lora_A.weight.dtype)

        # m * (W + AB / ||W + AB||) @ X == m * ((W @ X + AB @ X) / ||W + AB||)
        quantized_weight = self.base_layer.weight
        frozen_weight = bnb.functional.dequantize_4bit(
            quantized_weight.data, quantized_weight.quant_state
        )
        output, column_norm = self.dora_layer(x, frozen_weight)
        if cast_needed:
            output = output.to(expected_dtype)

        result += output
        # Normalise to a unit-norm direction, then rescale by the learned magnitude.
        result = result / column_norm.view(1, 1, -1)
        return self.magnitude_layer(result)

class LORA(nn.Module):
    """LoRA adapter around a (possibly 4-bit quantized) linear layer.

    Computes ``base_layer(x) + (lora_alpha / lora_rank) * B(A(dropout(x)))``
    where ``A`` and ``B`` are the two linears stored in ``lora_AB``.

    Args:
        base_layer: Linear layer to adapt; must expose ``in_features`` and
            ``out_features``.
        lora_rank: Rank of the low-rank update.
        lora_alpha: Numerator of the scaling factor.
        lora_dropout: Dropout probability applied to the adapter input.
    """

    def __init__(self, base_layer, lora_rank, lora_alpha, lora_dropout):
        super().__init__()
        self.base_layer = base_layer
        first_param = next(base_layer.parameters())
        dtype = getattr(base_layer, "compute_dtype", first_param.dtype)
        device = first_param.device

        factory_kwargs = dict(bias=False, device=device, dtype=dtype)
        adapter_in = nn.Linear(base_layer.in_features, lora_rank, **factory_kwargs)
        adapter_out = nn.Linear(lora_rank, base_layer.out_features, **factory_kwargs)
        # A zeroed second projection makes the adapter a no-op at initialisation.
        adapter_out.weight.data.zero_()

        self.lora_AB = nn.Sequential(adapter_in, adapter_out)

        self.lora_alpha = lora_alpha
        self.lora_dropout = nn.Dropout(lora_dropout)
        self.scaling = self.lora_alpha / lora_rank

    def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Run the base layer and add the scaled low-rank update."""
        # Defensive clone (per Tim Dettmers, for 4-bit layers): backprop can
        # fail on a manipulated view in some cases. Newer PyTorch versions may
        # not need this, but that would require extensive testing to confirm.
        result = self.base_layer(x, *args, **kwargs).clone()

        if torch.is_autocast_enabled():
            output = self.lora_AB(self.lora_dropout(x))
        else:
            # Outside autocast, run the adapter in its own dtype and cast back.
            expected_dtype = result.dtype
            x = x.to(self.lora_AB[0].weight.dtype)
            output = self.lora_AB(self.lora_dropout(x)).to(expected_dtype)

        output = output * self.scaling
        result += output

        return result

In [None]:
# Adapter targets: the three projections of every routed expert MLP.
target_modules = [
    f"mlp.experts.{i}.{proj}"
    for i in range(n_routed_experts)
    for proj in ("gate_proj", "up_proj", "down_proj")
]

# target_modules.append('mlp.gate.weight')

lora_rank = 16

# Snapshot the module list first: the loop replaces entries of the very module
# tree it would otherwise be iterating lazily.
for name, module in tqdm(list(model.named_modules())):
    if not isinstance(module, bnb.nn.Linear4bit):
        continue
    # Suffix match rather than substring match: once a projection is wrapped,
    # its quantized layer lives at `<target>.base_layer`, which a substring
    # test would match again and double-wrap if this cell is re-run.
    if name.endswith(tuple(target_modules)):
        # To use BNBDORA instead, the base layer needs a `dora_scale` first:
        # dora_scale = module.weight.norm(p=2, dim=1).to(dtype=torch.bfloat16)
        # rsetattr(
        #     model,
        #     name+".dora_scale",
        #     dora_scale
        # )
        rsetattr(
            model,
            name,
            LORA(module, lora_rank, lora_rank, lora_dropout=0.1)
        )


# Train only the adapter weights (and the router gate); freeze everything else.
trainable_markers = ['lora_AB', 'lora_A', 'lora_B', 'magnitude', 'mlp.gate.weight']
for name, params in tqdm(list(model.named_parameters())):
    params.requires_grad = any(marker in name for marker in trainable_markers)

In [None]:
# Sanity check after freezing: report the model's parameter counts via the
# project helper (presumably trainable vs. total — confirm in torch_utils).
count_parameters(model)

## Splitting layers between cuda:0 and cuda:1

In [None]:
def load_and_prepare_data(tokenizer, batch_size=8, max_length=512, num_workers=os.cpu_count(), train_sample_limit=None, val_sample_limit=None):
    """Build the training DataLoader from quality-filtered dolphin-r1 samples.

    Only samples with ``overall_quality == 5`` or ``score >= 0.2`` are kept.
    Conversations are rendered with the tokenizer's chat template and wrapped
    in ``HealingDataset``, which tokenizes on access.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and the usual
            call interface.
        batch_size: Batch size of the training loader.
        max_length: Token length every sample is padded/truncated to.
        num_workers: Number of DataLoader worker processes.
        train_sample_limit: If given, keep at most this many filtered samples.
        val_sample_limit: Accepted for signature compatibility; currently
            ignored because no validation loader is built.

    Returns:
        Tuple ``(train_loader, None)``; the second slot is a placeholder for
        the validation loader.
    """
    dataset = load_dataset(
        "cognitivecomputations/dolphin-r1", "nonreasoning", cache_dir="../dolphin-r1"
    )["train"]

    def filter_function(example):
        # Keep a sample if either quality signal is present and high enough.
        if example["overall_quality"] is not None and example["overall_quality"] == 5:
            return True
        if example["score"] is not None and example["score"] >= 0.2:
            return True
        return False

    dataset = dataset.filter(filter_function)

    # Apply the sample limit if provided. Filtering may leave fewer rows than
    # requested, so clamp to the available count instead of letting .select
    # fail on out-of-range indices.
    if train_sample_limit is not None:
        n_train = min(train_sample_limit, len(dataset))
        train_split = dataset.select(range(n_train))
    else:
        train_split = dataset

    train_texts = [
        tokenizer.apply_chat_template(elt, tokenize=False, add_generation_prompt=False)
        for elt in tqdm(train_split["messages"], desc="Preparing dataset train")
    ]

    train_dataset = HealingDataset(
        train_texts, tokenizer, max_length=max_length
    )

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )

    return train_loader, None


In [None]:
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
import torch
from ademamix import AdEMAMix
from torch.optim.lr_scheduler import _LRScheduler
import math

class WarmupCosineAnnealingLR(_LRScheduler):
    """Linear warmup followed by cosine annealing down to ``min_lr``.

    For ``step < warmup_steps`` the learning rate grows linearly from 0 to the
    optimizer's base rate. It then follows a half cosine from the base rate to
    ``min_lr``, reached at ``total_steps``, and stays at ``min_lr`` afterwards.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps spent in linear warmup.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        min_lr: Final learning rate, as an absolute value (not a fraction of
            the base rate).
    """

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0.0):
        # Must be set before the base constructor runs: it calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch=-1)

    def get_lr(self):
        """Return the learning rate of every parameter group for the current step."""
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup phase
            return [base_lr * (step / self.warmup_steps) for base_lr in self.base_lrs]

        # Cosine annealing phase. Guard the denominator (total_steps may be
        # <= warmup_steps) and clamp progress so the rate does not climb back
        # up if training runs past total_steps.
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        # Interpolate between the base rate and the absolute floor min_lr.
        return [
            self.min_lr + (base_lr - self.min_lr) * cosine_decay
            for base_lr in self.base_lrs
        ]
            
# Relies on `model`, `tokenizer` and `load_and_prepare_data` from earlier cells.

max_length = 128

num_epochs = 1
num_sample = 8192  # NOTE: not used below; the whole filtered dataset is loaded

batch_size = 4
gradient_accumulation_steps = 1
log_interval = gradient_accumulation_steps  # Log once per optimizer step

lr = 5e-3
# Initialize the SummaryWriter
writer = SummaryWriter(log_dir='runs/experiment_1')

train_loader, val_loader = load_and_prepare_data(
    tokenizer, batch_size=batch_size, max_length=max_length,
    train_sample_limit=None, val_sample_limit=None
)

optimizer = AdEMAMix(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=lr,
    betas=(0.9, 0.999, 0.9999),
    alpha=8.0  # batch size is small, so alpha is increased to smooth the gradient
)

# The scheduler advances once per optimizer step, over all epochs. A trailing
# partial accumulation window still triggers a step, hence the ceil.
steps_per_epoch = math.ceil(len(train_loader) / gradient_accumulation_steps)
total_steps = num_epochs * steps_per_epoch

scheduler = WarmupCosineAnnealingLR(
    optimizer,
    warmup_steps=128,
    total_steps=total_steps,
    min_lr=lr/100
)


model.train()  # Ensure the model is in training mode
optimizer.zero_grad()  # Start from clean gradients, also when the cell is re-run

try:
    for epoch in range(num_epochs):
        for i, encoding in enumerate(tqdm(train_loader)):
            input_ids = encoding['input_ids'].to("cuda:0")
            attention_mask = encoding['attention_mask'].to("cuda:0")

            # Causal-LM labels are the inputs themselves, but every sample is
            # padded to max_length: mask padding with -100 (the ignore index
            # of the cross-entropy loss) so the model is not trained to
            # predict pad tokens.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            # Forward pass
            output = model(
                input_ids=input_ids,
                labels=labels,
                attention_mask=attention_mask,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False
            )

            # Scale the loss so gradients accumulated over several
            # micro-batches average rather than sum.
            loss = output.loss
            (loss / gradient_accumulation_steps).backward()

            # Update model parameters and learning rate; also flush a trailing
            # partial window so no gradient leaks into the next epoch.
            is_last_batch = (i + 1) == len(train_loader)
            if (i + 1) % gradient_accumulation_steps == 0 or is_last_batch:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Log the (unscaled) loss and learning rate to TensorBoard
            if (i + 1) % log_interval == 0:
                global_step = epoch * len(train_loader) + i
                writer.add_scalar('Loss/train', loss.item(), global_step)
                writer.add_scalar('Learning Rate', scheduler.get_last_lr()[0], global_step)
finally:
    # Close the writer even if training is interrupted, so logs are flushed.
    writer.close()

In [None]:
from transformers import TextStreamer

# Stream generated tokens to the cell output as they are produced.
streamer = TextStreamer(tokenizer, skip_prompt=True)

model.generation_config.pad_token_id = model.generation_config.eos_token_id

# The training cell leaves the model in train mode; switch to eval so the LoRA
# dropout layers are inactive while generating.
model.eval()

messages = [{"role": "user", "content": "Write a piece of quicksort code in C++"}]
prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_tensor = tokenizer(prompt_text, return_tensors="pt").to("cuda:0")

out = model.generate(**input_tensor, streamer=streamer, temperature=0.01, max_new_tokens=64, do_sample=True)

In [None]:
# Wait for all queued work on both GPUs (devices 0 and 1) to finish.
for device_index in (0, 1):
    torch.cuda.synchronize(device_index)