## Ressources

The pipeline was optimized for the following config:
- Storage : 2To SSD @512Mo/s
- RAM : 128go DDR4 @3600
- CPU : Ryzen 9 3950X 16@32 cores
- GPU : 2x RTX 3090, aggregated 48gb DDR6X Vram

Hence i can  not guarantee that it will work properly on more frugal hardware.
Plus the GPU are not NVLink unified, so some optimisation involve manual allocation to one or the other GPU.

## Compute the number of parameters for different pruning size

In [1]:
def calc_num_parameters(
    n_routed_experts,
    num_experts_per_tok,
):
    num_hidden_layers=61
    first_k_dense_replace = 3
    num_moe_layer = num_hidden_layers - first_k_dense_replace
    
    hidden_size=7168
    intermediate_size=18432
    moe_intermediate_size=2048
    
    
    num_heads  = 128
    q_lora_rank = 1536
    qk_nope_head_dim = 128
    qk_rope_head_dim = 64
    kv_lora_rank = 512
    v_head_dim=128
    
    n_shared_experts=1
    
    vocab_size = 129280
    
    gate_size = n_routed_experts * hidden_size
    
    mlp_weights = 3 * hidden_size * intermediate_size
    moe_mlp_weights = 3 * hidden_size * moe_intermediate_size
    
    moe_total_weight = n_routed_experts * moe_mlp_weights
    moe_active_weight = num_experts_per_tok * moe_mlp_weights
    
    q_head_dim = qk_nope_head_dim + qk_rope_head_dim
    q_a_proj = hidden_size * q_lora_rank + q_lora_rank * q_head_dim
    kv_a_proj_with_mqa = hidden_size * (kv_lora_rank  + qk_rope_head_dim) + kv_lora_rank * (num_heads * (q_head_dim - qk_rope_head_dim  + v_head_dim))
    o_proj_weight = num_heads * v_head_dim * hidden_size
    attention_weight = q_a_proj + 2 * kv_a_proj_with_mqa + o_proj_weight
    
    base_weight_per_moe_layer = attention_weight + n_shared_experts * moe_mlp_weights + gate_size
    base_weight_per_mlp_layer = attention_weight + mlp_weights
    
    base_model_weight = base_weight_per_moe_layer * num_moe_layer + base_weight_per_mlp_layer * first_k_dense_replace + 2 * vocab_size * hidden_size
    
    total_expert_weight = moe_total_weight * num_moe_layer
    active_expert_weight = moe_active_weight * num_moe_layer
    
    active_model_weight = active_expert_weight + num_moe_layer + base_model_weight
    total_model_weight = total_expert_weight + num_moe_layer + base_model_weight
    
    print(f"{n_routed_experts} @ {num_experts_per_tok} => {int(round(total_model_weight/1e9,0))}B @ {int(round(active_model_weight/1e9,0))}B parameters")

In [2]:
n_routed_experts=256
num_experts_per_tok=8
calc_num_parameters(n_routed_experts, num_experts_per_tok)

256 @ 8 => 670B @ 37B parameters


In [None]:
p = [
    (256,8),
    (22,8),
    (16,8),
    (8,8),
]

for elt in p:
    calc_num_parameters(*elt)

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

import torch
from accelerate import init_empty_weights
from accelerate.utils import load_offloaded_weight
import json
from accelerate import load_checkpoint_in_model, dispatch_model
from datasets import load_dataset
import numpy as np
import gc
import _pickle as pickle
import os
# from Distiller import MOEDistiller, count_parameters

from transformers import BitsAndBytesConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, AutoConfig

import torch

from memory_utils import load_module_weights_and_freeze_optimized, load_weight_cached, destruct_module_optimized
from Distiller import load_model_config,create_empty_model,create_empty_layer, create_empty_layer_fp8

from liger_kernel.transformers import apply_liger_kernel_to_llama

from copy import deepcopy
from Distiller import MOEDistillerV3
import os
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
import torch
import pickle
from pathlib import Path


import os
from modeling_deepseek import _prepare_4d_causal_attention_mask
import bitsandbytes as bnb

base_path = '/home/golympie/data/'
PICKLE_DIR = "intermediate_states"

# Create directory for pickle files
os.makedirs(PICKLE_DIR, exist_ok=True)

# # ## Load
weight_map, config = load_model_config("deepseek_v3")
weight_file = weight_map['model.embed_tokens.weight']


apply_liger_kernel_to_llama()

model_name = "DeepSeek-V3"
offload_folder = model_name+'_offload/'
output_directory = model_name+'_runner_output/'

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V3", trust_remote_code=True)

def memory_cleanup():
    """Perform thorough memory cleanup"""
    gc.collect()
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.synchronize()


## Save layers to disk for easier loading

In [None]:
# os.makedirs("layers", exist_ok=True)
# model = create_empty_model(config)

In [None]:
# ## Embed
# model.model.embed_tokens = load_module_weights_and_freeze_optimized(
#     model.model.embed_tokens,
#     f"model.embed_tokens",
#     weight_map,
#     "deepseek_v3",
#     max_workers=32,
#     fp8_format="e4m3",
# )

# torch.save(model.model.embed_tokens.state_dict(), 'layers/embed_tokens.pt')

In [None]:
# ## End norm
# model.model.norm = load_module_weights_and_freeze_optimized(
#     model.model.norm,
#     f"model.norm",
#     weight_map,
#     "deepseek_v3",
#     max_workers=32,
#     fp8_format="e4m3",
# )

# torch.save(model.model.norm.state_dict(), 'layers/norm.pt')

In [None]:
# ## Lm head
# model.lm_head = load_module_weights_and_freeze_optimized(
#     model.lm_head,
#     f"lm_head",
#     weight_map,
#     "deepseek_v3",
#     max_workers=32,
#     fp8_format="e4m3",
# )

# torch.save(model.lm_head.state_dict(), 'layers/lm_head.pt')

In [None]:
# destruct_module_optimized(model)
# memory_cleanup()

In [None]:
# ## Layers
# for i in tqdm(range(61,62)):
    
#     layer = create_empty_layer(config, layer_idx=i)
#     layer = load_module_weights_and_freeze_optimized(
#         layer,
#         f"model.layers.{i}",
#         weight_map,
#         "deepseek_v3",
#         max_workers=16,
#         fp8_format="e4m3",
#     )
#     memory_cleanup()

#     torch.save(layer.state_dict(), f'./layers/layer_{i}.pt')
#     destruct_module_optimized(layer)

## Create dataset

The dataset is a small selection of the excellent dolphin r1 dataset because it contains both non reasoning and reasoning sample.

Feel free to change the dataset or to scale the approach as you feel (will take longer but with better result)

In [None]:
import os
import torch
import pickle
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter
from dataclasses import dataclass
from transformers import AutoTokenizer
from typing import Optional, List

from Distiller import DistillationConfig

@dataclass
class PathConfig:
    model_name: str = "deepseek"
    base_dir: str = "distillation_runs"
    checkpoint_dir: str = "layers"
    intermediate_dir: str = "intermediate_states"
    log_dir: str = "distillation_logs"
    
    def __post_init__(self):
        # Create all necessary directories
        for dir_name in [self.base_dir, self.log_dir, self.intermediate_dir]:
            os.makedirs(dir_name, exist_ok=True)
            
    def get_layer_path(self, layer_idx: int) -> str:
        return os.path.join(self.checkpoint_dir, f"layer_{layer_idx}.ckpt")
    
    def get_intermediate_path(self, layer_idx: int, batch_idx: int) -> str:
        os.makedirs(os.path.join(self.intermediate_dir, f"layer_{layer_idx}"), exist_ok=True)
        return os.path.join(self.intermediate_dir, f"layer_{layer_idx}",f"batch{batch_idx}.pt")
    
    def get_distillation_path(self, n_experts: int, n_active: int) -> str:
        return os.path.join(self.base_dir, f"{self.model_name}_{n_experts}@{n_active}")

def save_intermediate_state(path_config: PathConfig, layer_idx: int, batch_idx: int, state: torch.Tensor):
    """Save intermediate layer output to a file in FP8 format"""
    # Downcast to torch.float8_e4m3fn
    fp8_tensor = state.to(torch.float8_e4m3fn)
    torch.save(fp8_tensor, path_config.get_intermediate_path(layer_idx, batch_idx))

def load_intermediate_state(path_config: PathConfig, layer_idx: int, batch_idx: int) -> torch.Tensor:
    """Load intermediate layer output from a file and upcast from FP8"""
    fp8_tensor = torch.load(path_config.get_intermediate_path(layer_idx, batch_idx))
    # Upcast to torch.bfloat16
    return fp8_tensor.to(torch.bfloat16)

@dataclass
class DistillationParams:
    n_batch: int = 512
    batch_size: int = 4
    max_length: int = 512
    n_epoch: int = 1
    gradient_accumulation_steps: int = 4
    calibration_batches: int = 64
    learning_rate: float = 1e-4
    temperature: float = 1.0
    lora_rank: int = 16
    lora_alpha: int = 16
    max_workers: int = 16
    fp8_format: str = "e4m3"

In [None]:
# MoE configurations
MOE_CONFIGS = [
    (16, 8),
    # (22, 8),
    # (8, 8),
]

params = DistillationParams()
path_config = PathConfig()

batch_size=4
n_batch=2048
n_sample=params.batch_size * params.n_batch

gradient_accumulation_steps=8

calibration = load_dataset(
    'cognitivecomputations/dolphin-r1',
    "nonreasoning",
    cache_dir="../dolphin-r1"
)

calibration = calibration['train']

def filter_function(example):
    if example['overall_quality'] is not None:
        if example['overall_quality'] == 5:
            return True
    if example['score'] is not None:
        if example['score'] >= 0.2:
            return True
    return False


calibration = calibration.filter(filter_function)

position_ids = torch.arange(
    0,
    params.max_length,
    dtype=torch.long,
    device="cuda",
).unsqueeze(0)

data=calibration['messages'][:n_sample]
train_dataset = [tokenizer.apply_chat_template(elt, tokenize=False, add_generation_prompt=False) for elt in tqdm(data)]

## Distill layers one by one

Now the hard work, loading layers one by one to the GPUs and distill them.

Note that a full layer in fp8 should take about 14 go vram during inference with batch size 2.
The distiller makes a repartition of the pruned experts on remaining space. With the default config of this notebook, the first distillat is on cuda:0 and the other are in cuda 1.

To make it work properly I had to make a custom implementation of FP8Linear layer, optimized for numerical stability, as the default bnb one was producing Nan outputs very frequently. The current implementation is a bit naive, and could probably be improved with custom triton kernels.

The new layer was required as i am running the pipeline on Ampere GPU, and Deepseek released kernels can only work with Ada lovelace and plus generations.

The logic is hardcoded in the Distiller file, and will need to be updated.

With the current config, expect peak total vram usage of about 46 go. The pipeline should take about 24 hour on my setup to run, with a large part dedicated to disk write / read operations and quantization optimisations. If you have a larger vram, you can probably speed it up a bit :)

In [None]:
# Define your configuration and batch size
accumulate_steps = 16  # Accumulate gradients over this many batches
device = "cuda"  # Specify the device

# Initialize embedding layer (moved outside the loop)
embed_tokens = torch.nn.Embedding(
    config.vocab_size,
    config.hidden_size,
    config.pad_token_id,
    device=device
)
embed_tokens.weight.requires_grad = False
embed_tokens.load_state_dict(torch.load("layers/embed_tokens.pt"))
embed_tokens.to(device) # move the embedding layer to cuda device

# Process each batch
all_embeddings = []

for batch_idx in tqdm(range(0, params.n_batch, accumulate_steps), desc="Processing embeddings"):
    # Collect multiple batches
    batches = []
    for i in range(accumulate_steps):
        current_batch_idx = batch_idx + i
        if current_batch_idx >= params.n_batch:
            break  # Stop if we reach the end of the dataset

        batch_start = current_batch_idx * params.batch_size
        batch_end = (current_batch_idx + 1) * params.batch_size
        batches.append(train_dataset[batch_start:batch_end])

    # Concatenate batches
    concatenated_batch = [item for sublist in batches for item in sublist]

    # Tokenize the concatenated batch
    inputs = tokenizer(
        concatenated_batch,
        max_length=params.max_length,
        padding="max_length",
        truncation=True,
        return_tensors='pt'
    ).to(device)

    # Compute embeddings for accumulated batch
    with torch.no_grad():
        embeddings = embed_tokens(inputs['input_ids']).to('cpu', dtype=torch.bfloat16)
    # Split back into original batch sizes and save
    current_idx = 0
    for i in range(accumulate_steps):
        current_batch_idx = batch_idx + i
        if current_batch_idx >= params.n_batch:
            break

        split_size = params.batch_size
        batch_embeddings = embeddings[current_idx * params.max_length : (current_idx + 1) * params.max_length]
        
        save_intermediate_state(path_config, -1, current_batch_idx, batch_embeddings)

        current_idx += 1

    if batch_idx % 100 == 0:
        memory_cleanup()

# Cleanup
destruct_module_optimized(embed_tokens)
memory_cleanup()

In [None]:
# Initialize tensorboard writer
writer = SummaryWriter(path_config.log_dir)
# Create distillation configs for each MoE configuration
distillation_configs = {
    f"{n_experts}@{n_active}": DistillationConfig(
        adapter_type="dora",
        adapter_rank=params.lora_rank,
        adapter_alpha=params.lora_alpha,
        learning_rate=params.learning_rate,
        temperature=params.temperature,
        total_steps=params.n_batch,
        gradient_accumulation_steps=params.gradient_accumulation_steps
    )
    for n_experts, n_active in MOE_CONFIGS
}

# Process each layer
for layer_idx in range(61):
    # Create and load layer
    print(f'Loading layer {layer_idx}')
    layer = create_empty_layer_fp8(config, layer_idx=layer_idx)
    layer = layer.load_state_dict(torch.load(f'layers/layer_{layer_idx}.pt'), assign=True)
    print('layer loaded')
    memory_cleanup()


    # Store intermediate states
    intermediate_states = {}
    
    if "DeepseekV3MLP" in str(layer.mlp.__class__):
        # Process standard MLP layer
        for batch_idx in tqdm(range(params.n_batch), desc=f"Processing MLP Layer {layer_idx}"):
            
            prev_state = load_intermediate_state(path_config, layer_idx-1, batch_idx)
            
            with torch.amp.autocast('cuda'):
                new_state = layer.forward(
                    hidden_states=prev_state.to('cuda:0'),
                    position_ids=position_ids,
                )[0].detach()
                
            save_intermediate_state(path_config, layer_idx, batch_idx, new_state)
            intermediate_states[batch_idx] = new_state.cpu()
            
            if batch_idx % 100 == 0:
                memory_cleanup()
                
        destruct_module_optimized(layer)
        
    else:
        # Process MoE layer
        distiller = MOEDistillerV3(layer, layer_idx, model_name=path_config.model_name)
        
        # Prepare calibration data
        calibration_batches = [
            load_intermediate_state(path_config, layer_idx-1, idx).to('cuda:0', dtype=torch.bfloat16)
            for idx in range(min(params.calibration_batches, params.n_batch))
        ]
        
        # Create output directories for each MoE configuration
        for n_experts, n_active in MOE_CONFIGS:
            os.makedirs(path_config.get_distillation_path(n_experts, n_active), exist_ok=True)
        
        # Calibrate distiller
        distiller.calibrate(
            calibration_batches,
            position_ids,
            MOE_CONFIGS,
            distillation_configs[f"{MOE_CONFIGS[0][0]}@{MOE_CONFIGS[0][1]}"]  # Use first config as default
        )
        
        memory_cleanup()
        
        # Training loop
        progress_bar = tqdm(range(params.n_batch * params.n_epoch), desc=f"Training Layer {layer_idx}")
        for batch_idx in progress_bar:
            prev_state = load_intermediate_state(
                path_config, 
                layer_idx-1, 
                batch_idx % params.n_epoch,
            )
            
            new_state, losses = distiller.step(
                prev_state.to('cuda:0', dtype=torch.bfloat16),
                attention_mask=None,
                position_ids=position_ids,
            )
            
            save_intermediate_state(path_config, layer_idx, batch_idx, new_state)
            intermediate_states[batch_idx] = new_state.cpu()
            
            # Log losses
            for loss_name, loss_value in losses.items():
                writer.add_scalar(
                    f"layer_{layer_idx}/{loss_name}",
                    loss_value,
                    batch_idx
                )
            
            progress_bar.set_postfix(**losses)
            memory_cleanup()
        
        # Save and cleanup
        distiller.save_distillats()
        destruct_module_optimized(layer)
        
        for distillat in distiller.distillats:
            destruct_module_optimized(distillat["moe"])
            
        memory_cleanup()

# Cleanup intermediate files
for layer_idx in range(-1, 61):  # Include embedding layer (-1)
    for batch_idx in range(params.n_batch):
        os.remove(path_config.get_intermediate_path(layer_idx, batch_idx))
os.rmdir(path_config.intermediate_dir)

writer.close()