In [None]:
%load_ext autoreload
%autoreload 2

from memory_utils import load_module_weights_and_freeze_optimized, memory_cleanup
from accelerate import init_empty_weights
from deepseek_v3.modeling_deepseek import DeepseekV3ForCausalLM, DeepseekV3MoE
import torch
import json
from tqdm.auto import tqdm
from transformers import AutoConfig, AutoModelForCausalLM

from fp8_linear import FP8Linear
import shutil

## Instantiate empty model

In [None]:
# Local directory holding the source checkpoint (config, index, shards, tokenizer).
model_name = "deepseek_v3"

config = AutoConfig.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)

# Build the module tree without materialising weights; they are loaded shard by
# shard in the next cell.
with init_empty_weights():
    model = DeepseekV3ForCausalLM(config)

# This notebook only exports weights, so freeze every parameter in one call.
model.requires_grad_(False)

## Load non expert weights in model

This step should use about 16 GB of VRAM (the base model without the routed experts is about 17B parameters).

In [None]:
# The sharded-checkpoint index: "weight_map" maps each tensor name to its shard file.
# Explicit UTF-8 so decoding does not depend on the platform's locale default.
with open(f"{model_name}/model.safetensors.index.json", 'r', encoding="utf-8") as f:
    weight_map = json.load(f)['weight_map']

# Router gates and routed experts are skipped: they are replaced by distilled
# versions later in the notebook, so loading the originals would only waste VRAM.
ignore = ["gate.weight", ".experts."]
device = "cuda:0"

model = load_module_weights_and_freeze_optimized(
    model,
    module_path=None,
    weight_map=weight_map,
    model_name=model_name,
    max_workers=16,
    ignore=ignore,
    device=device
)

memory_cleanup()

from memory_utils import count_parameters
count_parameters(model)

## Cast quantized layers back to bfloat16, and offload them to cpu

Casting the whole model to bf16 on the GPU would push VRAM usage to around 28 GB, a bit too much for my tiny RTX 3090, so each converted layer is offloaded to the CPU.

In [None]:
# Convert every FP8Linear into the module returned by its `to_linear()` and park
# it on the CPU. Targets are collected up front so the module tree is not
# mutated while the lazy `named_modules()` generator is still walking it.
fp8_layer_names = [
    name for name, module in model.named_modules() if isinstance(module, FP8Linear)
]

for name in tqdm(fp8_layer_names, desc="Replacing FP8Linear"):
    parent_name, _, child_name = name.rpartition('.')
    # get_submodule resolves dotted paths, including numeric ModuleList indices;
    # an empty parent_name (top-level child) returns the model itself.
    parent = model.get_submodule(parent_name)
    fp8_module = getattr(parent, child_name)
    setattr(parent, child_name, fp8_module.to_linear().to('cpu'))
memory_cleanup()

## Load distilled experts weights and save to disk

VRAM usage depends on the size of the distilled experts. You can keep some of them on the GPU and offload the rest to the CPU.

You shall witness the magic moment of parameter shrinking as the layers are progressively updated

In [None]:
import os

In [None]:
from copy import deepcopy

# `config` is the very object held by `model.config`, so the MoE sizes set below
# are what `save_pretrained` writes into the exported config.json.
config = model.config
source_folder = model_name

# Device on which the template MoE block is staged while checkpoints are loaded.
expert_staging_device = 'cuda:1'

# (n_routed_experts, num_experts_per_tok) pairs to export. Each pair expects its
# distilled checkpoints in "DeepSeek-V3_{experts}@{top_k}/".
mlp_params = [
    # (22,6),
    # (16,4),
    # (8,2),
    (4,1),
]

# Files copied from the source checkpoint next to the exported weights.
# configuration_deepseek.py is included because the remote-code modeling file is
# normally paired with it (verify for your checkpoint); missing files only warn.
files_to_copy = [
    "configuration_deepseek.py",
    "modeling_deepseek.py",
    "tokenizer.json",
    "tokenizer_config.json",
]

# MoE layers are derived from the config instead of being hard-coded: every layer
# from first_k_dense_replace onward, stepping by moe_layer_freq. This is expected
# to reproduce the previously hard-coded range(3, 61) for DeepSeek-V3 — TODO confirm.
moe_layer_freq = getattr(config, "moe_layer_freq", 1)
moe_layer_indices = [
    idx
    for idx in range(config.first_k_dense_replace, config.num_hidden_layers)
    if idx % moe_layer_freq == 0
]
assert moe_layer_indices, "config describes no MoE layers"

for specs in mlp_params:
    n_experts, top_k = specs
    config.n_routed_experts = n_experts
    config.num_experts_per_tok = top_k
    base_path = f"DeepSeek-V3_{n_experts}@{top_k}/"

    # Template block built from the updated config. Each layer's distilled gate and
    # experts are loaded into it, then deep-copied into the model as bf16 CPU modules.
    mlp = DeepseekV3MoE(config).to(expert_staging_device)
    for layer_index in tqdm(moe_layer_indices):
        # map_location: don't depend on the device the checkpoint was saved from.
        # weights_only: restrict unpickling to tensors and plain containers.
        mlp.gate.load_state_dict(torch.load(
            base_path + f"gate_layer_{layer_index}.ckpt",
            map_location=expert_staging_device,
            weights_only=True,
        ))
        mlp.experts.load_state_dict(torch.load(
            base_path + f"experts_layer_{layer_index}.ckpt",
            map_location=expert_staging_device,
            weights_only=True,
        ))

        # Only `gate` and `experts` are swapped; other submodules of the layer's MoE
        # (e.g. shared experts) keep the weights loaded earlier.
        # NOTE(review): any attributes the surrounding MoE module cached from the old
        # config are not refreshed here — the saved config.json carries the new sizes;
        # verify before running inference on this in-memory model.
        layer_mlp = model.model.layers[layer_index].mlp
        layer_mlp.gate = deepcopy(mlp.gate).to(dtype=torch.bfloat16, device="cpu")
        layer_mlp.experts = deepcopy(mlp.experts).to(dtype=torch.bfloat16, device="cpu")

        count_parameters(model)
        memory_cleanup()

    model_checkpoint_name = f"DeepSeek-V3-{n_experts}@{top_k}-unhealed"
    os.makedirs(model_checkpoint_name, exist_ok=True)
    model.save_pretrained(model_checkpoint_name)

    for file in files_to_copy:
        source_path = os.path.join(source_folder, file)
        destination_path = os.path.join(model_checkpoint_name, file)

        if os.path.exists(source_path):
            shutil.copy2(source_path, destination_path)  # copy2 also preserves file metadata
            print(f"Copied '{file}' from '{source_folder}' to '{model_checkpoint_name}'")
        else:
            print(f"Warning: File '{file}' not found in '{source_folder}'")