In [None]:
%load_ext autoreload
%autoreload 2

import argparse
import os
import pickle
import shutil

import numpy as np
import pytorch_lightning as pl
import torch
import torch.multiprocessing as mp
from accelerate import init_empty_weights
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

from configs import DistillationParams, PathConfig
from Distiller import (
    IntermediateStateDataset,
    prepare_distilled_moe,
    MOEDistillerLightningModule,
    OptimizerModeCallback,
)
from fp8_linear import FP8Linear
from model_utils import (
    rsetattr,
    rgetattr,
    load_model_config,
    load_weight,
    map_device,
    assign_device,
    get_dataset,
    get_device_map,
)
from modeling_deepseek import (
    DeepseekV3DecoderLayer,
    DeepseekV3MoE,
    DeepseekV3ForCausalLM,
)
from torch_utils import memory_cleanup, destruct_module_optimized, count_parameters

## Instantiate empty model

In [None]:
# Project path configuration; its `base_dir` attribute is read further down to
# locate the per-layer MoE checkpoint files.
path_config = PathConfig()

In [None]:
# Identifier handed to `load_model_config` / `load_weight` to find the original
# checkpoint (presumably a local directory name — confirm against model_utils).
weights_location = "deepseek_v3"

weight_map, config = load_model_config(weights_location)

# Build the architecture under `init_empty_weights` so that constructing the
# model does not allocate real storage for its parameters.
with init_empty_weights():
    model = DeepseekV3ForCausalLM(config)

# Freeze every parameter: nothing in this notebook needs gradients.
model.requires_grad_(False)

# NOTE(review): `destruct_module_optimized` is a project helper whose exact
# effect is not visible here; it is followed by an explicit memory cleanup.
destruct_module_optimized(model)
memory_cleanup()

In [None]:
device = "cpu"

# Names of the weights to load from the original checkpoint: everything except
# expert weights, router gates (".gate.") and anything under index 61.
device_map = [
    key
    for key in weight_map
    if "experts" not in key and ".gate." not in key and ".61." not in key
]

# Load each selected weight onto `device` and attach it to the model by its
# dotted name; run a memory cleanup every 100 weights.
for step, weight_name in enumerate(tqdm(device_map)):
    rsetattr(
        model,
        weight_name,
        load_weight(weights_location, weight_name, weight_map, device),
    )
    if step % 100 == 0:
        memory_cleanup()

In [None]:
# Target MoE shape: total routed experts per layer and experts active per token.
n_routed_experts = 8
n_active_experts = 4
model_name = f"{weights_location}_{n_routed_experts}a{n_active_experts}"

# The config must reflect the reduced expert counts *before* the new
# DeepseekV3MoE blocks are constructed from it below.
model.config.n_routed_experts = n_routed_experts
model.config.num_experts_per_tok = n_active_experts

# Directory holding one `layer_<idx>.pt` state dict per MoE layer.
# Note the "@" separator here versus the "a" used in `model_name`.
distilled_layers_dir = os.path.join(
    path_config.base_dir,
    f"{weights_location}_{n_routed_experts}@{n_active_experts}",
)

# Layers 3..60 get their MLP replaced (presumably the MoE layers of the
# original architecture — confirm against the model config).
for layer_idx in tqdm(range(3, 61)):
    # Build the replacement block without allocating real parameter storage.
    with init_empty_weights():
        model.model.layers[layer_idx].mlp = DeepseekV3MoE(model.config)

    # `map_location` makes the load independent of the device the checkpoint
    # was saved from (e.g. a CUDA checkpoint opened on a CPU-only machine).
    layer_state = torch.load(
        os.path.join(distilled_layers_dir, f"layer_{layer_idx}.pt"),
        map_location=device,
    )
    # `assign=True` swaps the placeholder parameters for the loaded tensors
    # instead of copying into them.
    model.model.layers[layer_idx].mlp.load_state_dict(layer_state, assign=True)
    model.model.layers[layer_idx].mlp = model.model.layers[layer_idx].mlp.to(device)
    print(f"Layer {layer_idx} pruning")
    count_parameters(model)

In [None]:
# Convert every FP8Linear submodule via its `to_linear(device)` method.
# The dotted names are collected first so the module tree is not modified while
# `named_modules()` is still traversing it, and so tqdm knows the total.
# Only names are stored (not modules), so each replaced FP8Linear can be freed
# as soon as it has been swapped out.
fp8_module_names = [
    name for name, module in model.named_modules() if isinstance(module, FP8Linear)
]

for name in tqdm(fp8_module_names):
    rsetattr(model, name, model.get_submodule(name).to_linear(device))

In [None]:
# Output directory for the exported model. Derived from `model_name` (defined
# with the expert counts above) so the two can never drift apart, and actually
# used for the save call.
model_path = model_name
model.save_pretrained(model_path)

In [None]:
# Place the custom model code next to the saved weights in `model_path`.
shutil.copy(
    "configuration_deepseek.py",
    os.path.join(model_path, "configuration_deepseek.py"),
)

# Explicit UTF-8 so the result does not depend on the platform's default encoding.
with open("modeling_deepseek.py", "r", encoding="utf-8") as src_file:
    modeling_source = src_file.read()

# The exported model no longer contains FP8Linear modules, so drop the import
# and point every remaining reference at nn.Linear. The import line must be
# removed first, otherwise the second replace would rewrite it.
modeling_source = modeling_source.replace("from fp8_linear import FP8Linear", "")
modeling_source = modeling_source.replace("FP8Linear", "nn.Linear")

with open(
    os.path.join(model_path, "modeling_deepseek.py"), "w", encoding="utf-8"
) as dst_file:
    dst_file.write(modeling_source)