In [None]:
import numpy as np
import os
from configs import DistillationParams, PathConfig
import torch
from Distiller import IntermediateStateDataset, prepare_distilled_moe, MOEDistillerLightningModule
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping  
from pytorch_lightning.loggers import TensorBoardLogger
import torch.multiprocessing as mp
from ademamix import AdEMAMix
import argparse

from torch_utils import memory_cleanup, destruct_module_optimized
from model_utils import rsetattr, rgetattr, load_model_config, load_weight, map_device, assign_device, get_dataset, get_device_map
from accelerate import init_empty_weights
from modeling_deepseek import DeepseekV3DecoderLayer, DeepseekV3MoE, DeepseekV3ForCausalLM
from tqdm.auto import tqdm
import pickle

In [None]:
# Force the "spawn" start method for torch.multiprocessing
# (presumably so worker processes do not inherit a forked CUDA context — confirm).
mp.set_start_method('spawn', force=True)

# Expose only physical GPU 1 to this process. Inside the process it is then
# addressed as "cuda:0". Must be set before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# The caching-allocator split limit is configured through PYTORCH_CUDA_ALLOC_CONF;
# assigning `torch.backends.cuda.max_split_size_mb` only creates an unused module
# attribute and has no effect. `setdefault` keeps any value already exported by
# the user. Like CUDA_VISIBLE_DEVICES, this must be set before the first CUDA
# allocation.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512")

# Allow reduced-precision float32 matmuls.
torch.set_float32_matmul_precision('medium')

# Project path configuration (dataset, log and checkpoint directories are read from it below).
path_config = PathConfig()

In [None]:
# --- Distilled MoE layout (forwarded to the Lightning module and to
#     prepare_distilled_moe further down) ---
n_routed_experts, n_active_experts = 8, 2

# --- Optimisation settings (forwarded to DistillationParams) ---
learning_rate = 0.0008
end_factor = 0.1

# --- Adapter settings (forwarded to DistillationParams) ---
lora_rank = lora_alpha = 16

# --- Placement and weights ---
# Device string used both for loading the layer weights and as distiller_device.
device = "cuda:0"
# Directory handed to load_model_config / load_weight.
weights_location = 'deepseek_v3/'

# --- Layer selection ---
# min_layer / max_layer are not referenced in the cells shown here.
min_layer, max_layer = 3, 61
# Index of the decoder layer processed by this notebook run.
layer_idx = 3

In [None]:
# Integer GPU index taken from the "<type>:<index>" device string.
dev = int(device.split(":")[1])
print('****************************************')

# Bundle every distillation setting into one DistillationParams object.
# All arguments are passed by keyword, so their grouping here is purely cosmetic.
params = DistillationParams(
    # schedule / data volume
    n_epochs=10,
    n_batch=256,
    n_train_batch=232,
    calibration_batches=16,
    batch_size=8,
    max_length=512,
    gradient_accumulation_steps=1,
    max_workers=8,
    # optimisation (values come from the configuration cell)
    learning_rate=learning_rate,
    end_factor=end_factor,
    temperature=1.0,
    # adapter
    lora_type="dora",
    lora_rank=lora_rank,
    lora_alpha=lora_alpha,
    # numerics / placement
    fp8_format="e4m3",
    distiller_device=device,
)

In [None]:
# Load the weight map and the model config from the weights directory.
# (No tokenizer is loaded in this cell.)
weight_map, config = load_model_config(weights_location)
# Create empty model: under init_empty_weights the model is built without
# allocating real parameter storage.
with init_empty_weights():
    model = DeepseekV3ForCausalLM(config)

# NOTE(review): presumably strips/releases whatever the skeleton still holds so
# that only the layer loaded later occupies memory — confirm in torch_utils.
destruct_module_optimized(model)
memory_cleanup()

In [None]:
print("Loading dataset...")
# Intermediate states recorded for `layer_idx`, covering batches 0 .. params.n_batch
# (assumes the dataset length equals params.n_batch — random_split raises otherwise).
full_dataset = IntermediateStateDataset(path_config, layer_idx, 0, params.n_batch)

train_size = params.n_train_batch
val_size = params.n_batch - params.n_train_batch

# Seed the split explicitly: without a generator the train/validation membership
# changes on every kernel restart, which makes val_loss (used for checkpointing
# and early stopping) incomparable between runs.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(
    full_dataset,
    [train_size, val_size],
    generator=split_generator,
)
print("Dataset loaded and split.")

In [None]:
# Options common to both loaders. batch_size=1 presumably means each dataset
# item is already a full pre-batched sample — TODO confirm in IntermediateStateDataset.
_loader_kwargs = dict(
    batch_size=1,
    num_workers=4,
    persistent_workers=True,
    prefetch_factor=4,
)

# Training data is reshuffled each epoch; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)

In [None]:
with torch.device(device):
    # Materialise only the target decoder layer on `device`; iterating
    # `device_map` yields the names of the weights to load for it.
    device_map = get_device_map(layer_idx, weight_map, device)
    model.model.layers[layer_idx] = model.model.layers[layer_idx].to_empty(device=device)

    for i, weight_name in enumerate(tqdm(device_map)):
        rsetattr(model, weight_name, load_weight(weights_location, weight_name, weight_map, device))
        # Periodic cleanup keeps peak memory down while streaming weights in.
        if i % 100 == 0:
            memory_cleanup()

    model.model.layers[layer_idx] = model.model.layers[layer_idx].to(device)
    memory_cleanup()

    # SECURITY: pickle.load can execute arbitrary code contained in the file.
    # Only load activation files produced by this project's own pipeline.
    with open(f"{path_config.expert_activation_dir}/layer_{layer_idx}.pickle", "rb") as f:
        act = pickle.load(f)

    # Rank experts from most to least frequently activated.
    # `np.argsort(c)` yields positions into `v`, not expert ids, so the ranking
    # must be mapped back through `v`; otherwise the selection is shifted as soon
    # as any expert id is missing from `act`.
    # (assumes `act` holds routed-expert ids — confirm against the code that
    # writes the pickle)
    v, c = np.unique(act, return_counts=True)
    selected_experts = v[np.flip(np.argsort(c))]

    pl_model = MOEDistillerLightningModule(
        weight_map,
        path_config,
        params,
        layer_idx=layer_idx,
        n_routed_experts=n_routed_experts,
        n_active_experts=n_active_experts,
        weights_location=weights_location,
    )

    # Build the distilled MoE from the original layer's MLP and attach it to the
    # Lightning module.
    pl_model.distillat = prepare_distilled_moe(
        model.model.layers[layer_idx].mlp,
        selected_experts,
        n_routed_experts,
        n_active_experts,
        params,
        device=device
    )

    # The original layer is no longer needed once the distilled MoE exists.
    destruct_module_optimized(model)
    memory_cleanup()

In [None]:
# TensorBoard run directory is keyed by layer index and expert layout.
logger = TensorBoardLogger(
    save_dir=path_config.log_dir,
    name=f"lightning_logs_layer_{layer_idx}_{n_routed_experts}a{n_active_experts}",
)

# Stop once val_loss has failed to improve by at least 0.001 for 2 consecutive
# validation checks; a message is printed when that happens.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    mode="min",
    min_delta=0.001,
    patience=2,
    verbose=True,
)

# Keep only the single checkpoint with the lowest val_loss.
checkpoint_callback = ModelCheckpoint(
    dirpath=path_config.checkpoint_dir,
    filename=f"moe_distiller_layer_{layer_idx}_{n_routed_experts}a{n_active_experts}",
    monitor="val_loss",
    mode="min",
    save_top_k=1,
    verbose=True,
)

print("Logger and checkpoint callback setup.")

# GPU index (parsed from `device`) that the trainer is pinned to.
print(dev)
trainer = pl.Trainer(
    max_epochs=params.n_epochs,
    accelerator="cuda",
    # Follow the configured `device` instead of a hard-coded index so the trainer
    # and the distilled module (built on `device`) always share one GPU.
    devices=[dev],
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    precision="bf16-mixed",
    gradient_clip_val=1.0,
    accumulate_grad_batches=params.gradient_accumulation_steps,
    enable_progress_bar=True,
    log_every_n_steps=1,
    # strategy="ddp"  # Add strategy for multi-gpu training
)