In [None]:
from torch_utils import load_intermediate_state
from configs import PathConfig

# Project path configuration object; passed to the load_*_state helpers below.
path_config=PathConfig()

In [None]:
# Sanity check of the cached per-layer states: for each layer index, load
# batch 0 and print (layer, max, mean, min, number of NaN entries).
#
# `torch` is imported in this cell because the notebook's main import cell
# comes later; without it a fresh "Restart Kernel -> Run All" raises
# NameError on the first torch.* call below.
import torch

# NOTE(review): 61 is hard-coded; presumably the number of decoder layers of
# the model whose states were cached -- confirm against the model config.
for layer_idx in range(61):
    x = load_intermediate_state(path_config, layer_idx, 0, batch_size=8)
    print(layer_idx, float(torch.max(x)), float(torch.mean(x)), float(torch.min(x)), torch.isnan(x).sum().item())

In [None]:
%load_ext autoreload
%autoreload 2
import gc
import os
import pickle

import numpy as np
import torch

from datasets import load_dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoConfig
from liger_kernel.transformers import apply_liger_kernel_to_llama
from configs import GenerationParams, PathConfig, DistillationParams
from torch_utils import save_intermediate_state, save_midlayer_state, load_intermediate_state, load_midlayer_state, destruct_module_optimized, memory_cleanup
from modeling_deepseek import DeepseekV3DecoderLayer, DeepseekV3MoE, DeepseekV3ForCausalLM
import torch
import json
from accelerate import init_empty_weights
import functools
from safetensors import safe_open

import concurrent
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor
from threading import Semaphore
from model_utils import rsetattr, rgetattr, load_model_config, load_weight, map_device, assign_device, get_dataset, get_device_map 
from Distiller import MOEDistillerLightningModule, prepare_distilled_moe

In [None]:
# ---- Experiment settings (edit here) --------------------------------------
layer_idx = 3
n_routed_experts, n_active_experts = 4, 1
learning_rate, end_factor = 8e-4, 0.1
lora_rank = lora_alpha = 16
device = "cuda:0"
weights_location = "deepseek_v3/"

# ---- Distillation hyper-parameters ----------------------------------------
# All values are forwarded as keyword arguments to DistillationParams; the
# tunables defined above are passed through unchanged.
params = DistillationParams(
    # data / batching
    n_epochs=1,
    n_batch=128,
    n_train_batch=116,
    batch_size=16,
    max_length=512,
    calibration_batches=16,
    # optimisation
    gradient_accumulation_steps=1,
    learning_rate=learning_rate,
    end_factor=end_factor,
    temperature=1.0,
    # adapter settings
    lora_type="dora",
    lora_rank=lora_rank,
    lora_alpha=lora_alpha,
    # runtime
    max_workers=8,
    fp8_format="e4m3",
    distiller_device=device,
)

# ---- Model skeleton --------------------------------------------------------
# load_model_config returns the weight map and the model config for
# `weights_location`.
weight_map, config = load_model_config(weights_location)
path_config = PathConfig()

# Instantiate the architecture under accelerate's init_empty_weights context,
# so no real weights are loaded in this cell; they are filled in afterwards.
with init_empty_weights():
    model = DeepseekV3ForCausalLM(config)

In [None]:
# Entries returned by get_device_map for this layer; the loop below iterates
# over them as weight names.
device_map = get_device_map(layer_idx, weight_map, device)

# The model was created under init_empty_weights in an earlier cell, so the
# selected layer is first given (uninitialised) storage on the target device.
model.model.layers[layer_idx] = model.model.layers[layer_idx].to_empty(device=device)

# Load every listed tensor and attach it to the model by its dotted name.
# memory_cleanup() runs on every 100th iteration, starting with the first.
for i, weight_name in enumerate(tqdm(device_map)):
    rsetattr(
        model,
        weight_name,
        load_weight(weights_location, weight_name, weight_map, device),
    )
    if not i % 100:
        memory_cleanup()

In [None]:
# Load the recorded expert activations for this layer.
# NOTE(review): pickle.load can execute arbitrary code contained in the file;
# only use it on activation dumps produced locally by a trusted run.
with open(f"{path_config.expert_activation_dir}/layer_{layer_idx}.pickle", "rb") as f:
    act = pickle.load(f)

# Rank experts from most to least frequently activated.
# `v` holds the distinct values found in `act` and `c` their counts.
# np.argsort(c) returns positions into `v`, not the values themselves, so the
# ordering is mapped back through `v`. Using the positions directly is only
# correct when every id 0..N-1 occurs at least once; if some expert never
# fired, positions and ids diverge and the wrong experts would be selected.
# (Assumes the entries of `act` are expert ids -- confirm against the code
# that writes these pickles.)
v, c = np.unique(act, return_counts=True)
selected_experts = v[np.flip(np.argsort(c))]

                  
# Re-create the path configuration (same call as in the earlier cells).
path_config = PathConfig()

# Build the distiller module for the selected layer, with the reduced
# routed/active expert counts configured above.
pl_model = MOEDistillerLightningModule(
    weight_map,
    path_config,
    params,
    layer_idx=layer_idx,
    n_routed_experts=n_routed_experts,
    n_active_experts=n_active_experts,
    weights_location=weights_location
)

# Build the distilled MoE from the loaded layer's original `mlp`, using the
# expert ranking in `selected_experts`, and attach it to the module.
# NOTE(review): the attribute is spelled `distillat`; confirm this is the name
# MOEDistillerLightningModule actually reads.
pl_model.distillat=prepare_distilled_moe(
    model.model.layers[layer_idx].mlp,
    selected_experts,
    n_routed_experts,
    n_active_experts,
    params,
    device=device
)

# The full model is torn down only after the distilled block has been built
# from it; presumably this releases the loaded weights -- confirm in torch_utils.
destruct_module_optimized(model)
memory_cleanup()
        

In [None]:
# Fetch the cached states of one batch for the selected layer.
batch_idx = 0

# Both loaders take the same positional key: (paths, layer, batch).
state_key = (path_config, layer_idx, batch_idx)

# Judging by the variable names, the mid-layer state is the block's input and
# the intermediate state its reference output -- confirm in torch_utils.
input_data = load_midlayer_state(*state_key, batch_size=params.batch_size)
output_data = load_intermediate_state(*state_key, batch_size=params.batch_size)

In [None]:
# Call the module's merge_and_save().
# NOTE(review): no training step is visible between building `pl_model` and
# this call -- confirm that saving the untrained distilled block is intended.
pl_model.merge_and_save()

In [None]:
# Bare expression: the notebook displays the repr of `pl_model` as this cell's output.
pl_model