## Resources

The pipeline was optimized for the following config:
- Storage: 2 TB SSD @ 512 MB/s
- RAM: 128 GB DDR4 @ 3600
- CPU: Ryzen 9 3950X (16 cores / 32 threads)
- GPU: 2x RTX 3090, 48 GB GDDR6X VRAM in total

Hence I cannot guarantee that it will work properly on more frugal hardware.
Also, the GPUs are not linked with NVLink, so some optimisations involve manually allocating work to one GPU or the other.

## Compute the number of parameters for different pruning sizes

In [None]:
def calc_num_parameters(
    n_routed_experts,
    num_experts_per_tok,
):
    """Print total and active parameter counts for a DeepSeek-V3-shaped model.

    All architecture dimensions are hard-coded below; only the MoE routing
    configuration varies. Counts cover the linear projections, the router
    gates and the two vocabulary matrices (the ``2 * vocab_size * hidden_size``
    term); normalisation weights and biases are not counted.

    Args:
        n_routed_experts: Number of routed experts kept in each MoE layer.
        num_experts_per_tok: Number of routed experts activated per token.

    Returns:
        None. The result is printed as
        ``"<experts> @ <per token> => <total>B @ <active>B parameters"``,
        with both counts rounded to the nearest billion.
    """
    # --- Layer layout ---
    num_hidden_layers = 61
    first_k_dense_replace = 3  # the first layers use a dense MLP instead of MoE
    num_moe_layer = num_hidden_layers - first_k_dense_replace

    # --- MLP dimensions ---
    hidden_size = 7168
    intermediate_size = 18432
    moe_intermediate_size = 2048

    # --- Attention dimensions ---
    num_heads = 128
    q_lora_rank = 1536
    qk_nope_head_dim = 128
    qk_rope_head_dim = 64
    kv_lora_rank = 512
    v_head_dim = 128

    n_shared_experts = 1

    vocab_size = 129280

    # Router: one score per routed expert.
    gate_size = n_routed_experts * hidden_size

    # Each MLP has three projections (gate, up, down).
    mlp_weights = 3 * hidden_size * intermediate_size
    moe_mlp_weights = 3 * hidden_size * moe_intermediate_size

    moe_total_weight = n_routed_experts * moe_mlp_weights
    moe_active_weight = num_experts_per_tok * moe_mlp_weights

    # Low-rank attention: every down-projection (`*_a_proj`) is followed by an
    # up-projection (`*_b_proj`) that fans out to all heads.
    q_head_dim = qk_nope_head_dim + qk_rope_head_dim
    q_a_proj = hidden_size * q_lora_rank
    q_b_proj = q_lora_rank * num_heads * q_head_dim
    kv_a_proj_with_mqa = hidden_size * (kv_lora_rank + qk_rope_head_dim)
    kv_b_proj = kv_lora_rank * num_heads * (qk_nope_head_dim + v_head_dim)
    o_proj_weight = num_heads * v_head_dim * hidden_size
    attention_weight = (
        q_a_proj + q_b_proj + kv_a_proj_with_mqa + kv_b_proj + o_proj_weight
    )

    # Weights that are used for every token regardless of routing.
    base_weight_per_moe_layer = (
        attention_weight + n_shared_experts * moe_mlp_weights + gate_size
    )
    base_weight_per_mlp_layer = attention_weight + mlp_weights

    base_model_weight = (
        base_weight_per_moe_layer * num_moe_layer
        + base_weight_per_mlp_layer * first_k_dense_replace
        + 2 * vocab_size * hidden_size
    )

    total_expert_weight = moe_total_weight * num_moe_layer
    active_expert_weight = moe_active_weight * num_moe_layer

    # Only parameter counts are summed here (the layer count itself is not a
    # parameter and must not be added to the totals).
    active_model_weight = active_expert_weight + base_model_weight
    total_model_weight = total_expert_weight + base_model_weight

    print(f"{n_routed_experts} @ {num_experts_per_tok} => {int(round(total_model_weight/1e9,0))}B @ {int(round(active_model_weight/1e9,0))}B parameters")

In [None]:
# Reference point: 256 routed experts with 8 active per token
# (the largest configuration considered in this notebook).
n_routed_experts, num_experts_per_tok = 256, 8
calc_num_parameters(n_routed_experts, num_experts_per_tok)

In [None]:
# Candidate configurations as (n_routed_experts, num_experts_per_tok) pairs,
# from the reference size down to the most aggressive pruning.
p = [(256, 8), (22, 6), (16, 4), (8, 2), (4, 1)]

# Print the parameter budget of every candidate.
for elt in p:
    calc_num_parameters(elt[0], elt[1])

## Imports

In [1]:
%load_ext autoreload
%autoreload 2

import torch
from accelerate import init_empty_weights
from accelerate.utils import load_offloaded_weight
import json
from accelerate import load_checkpoint_in_model, dispatch_model
from tqdm.auto import tqdm
from datasets import load_dataset
import numpy as np
import gc
import _pickle as pickle
# from Distiller import MOEDistiller, count_parameters

from transformers import BitsAndBytesConfig
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, AutoConfig

import torch

from memory_utils import load_module_weights_and_freeze_optimized, load_weight_cached, destruct_module_optimized
from Distiller import load_model_config,create_empty_model,create_empty_layer

from liger_kernel.transformers import apply_liger_kernel_to_llama

# NOTE(review): this patches the Llama model classes with Liger kernels;
# whether it has any effect on the DeepSeek layers used below is not visible
# from this notebook — confirm it is still needed.
apply_liger_kernel_to_llama()

# Every output location is derived from the model name.
model_name = "DeepSeek-V3"
offload_folder = f"{model_name}_offload/"
output_directory = f"{model_name}_runner_output/"

# The tokenizer is fetched from the Hub; `trust_remote_code` lets it run the
# repository's custom code.
tokenizer = AutoTokenizer.from_pretrained(
    "deepseek-ai/DeepSeek-V3",
    trust_remote_code=True,
)

def memory_cleanup():
    """Free unreferenced Python objects and release cached CUDA memory.

    Runs the garbage collector, empties the CUDA caching allocator and, when a
    CUDA device is available, waits for outstanding GPU work to finish.
    """
    gc.collect()
    torch.cuda.empty_cache()
    if not torch.cuda.is_available():
        return
    torch.cuda.synchronize()


## Create dataset

The dataset is a small selection from the excellent dolphin-r1 dataset, chosen because it contains both non-reasoning and reasoning samples (the cell below loads the `nonreasoning` subset).

Feel free to change the dataset or to scale up the approach as you see fit (it will take longer but give better results).

In [2]:
# --- Batching configuration ---
batch_size = 2

n_batch = 64 // batch_size
n_sample = batch_size * n_batch

gradient_accumulation_steps = 8

# --- Calibration data ---
calibration = load_dataset(
    'cognitivecomputations/dolphin-r1',
    "nonreasoning",
    cache_dir="../dolphin-r1"
)

max_length = 512
calibration = calibration['train']

# Every sample is padded/truncated to `max_length` later on, so a single
# (1, max_length) position tensor is shared by all batches.
position_ids = torch.arange(
    0,
    max_length,
    dtype=torch.long,
    device="cuda",
).unsqueeze(0)

# Slice the rows first and pick the column afterwards: this only reads the
# `n_sample` rows that are needed instead of materialising the whole
# `messages` column before slicing it.
data = calibration[:n_sample]['messages']

# Render each conversation to plain text with the model's chat template;
# tokenisation happens later, batch by batch.
train_dataset = [tokenizer.apply_chat_template(elt, tokenize=False, add_generation_prompt=False) for elt in tqdm(data)]

  0%|          | 0/64 [00:00<?, ?it/s]

## Run the embedding layer

Here we run the embedding layer on the dataset to build the first intermediate representations.
This is not very VRAM-intensive, but depending on the size of the dataset it can quickly eat RAM, as the tensors are very large.

I considered quantising the representations to save memory, but was afraid of lowering the quality too much, especially in the lower layers of the model.

In [3]:
import os
from deepseek_v3.modeling_deepseek import _prepare_4d_causal_attention_mask
import bitsandbytes as bnb

base_path = '/home/golympie/data/'

## Load
weight_map, config = load_model_config("deepseek_v3")
weight_file = weight_map['model.embed_tokens.weight']

# Build a frozen embedding table on the GPU and fill it with the checkpoint
# weights.
embed_tokens = torch.nn.Embedding(
    config.vocab_size,
    config.hidden_size,
    config.pad_token_id,
    device="cuda",
)
embed_tokens.weight.requires_grad = False
embed_tokens.weight.copy_(
    load_weight_cached(
        'model.embed_tokens.weight',
        weight_file,
        "deepseek_v3",
        "cuda:0",
    )
)

# One slot per batch; `intermediate` is filled below, `attention_masks` is
# only allocated here.
intermediate = [0] * n_batch
attention_masks = [0] * n_batch

for batch_idx in tqdm(range(n_batch)):
    offset = batch_idx * batch_size
    batch = train_dataset[offset : offset + batch_size]

    # Pad/truncate every sample to exactly `max_length` tokens.
    inputs = tokenizer(
        batch,
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors='pt',
    ).to('cuda')
    embed = embed_tokens(inputs['input_ids'])

    # Park the result in host memory as bfloat16.
    intermediate[batch_idx] = embed.to('cpu', dtype=torch.bfloat16)

# The embedding table is no longer needed once every batch is embedded.
destruct_module_optimized(embed_tokens)
memory_cleanup()

  0%|          | 0/32 [00:00<?, ?it/s]

## Distill layers one by one

Now the hard work: loading the layers one by one onto the GPUs and distilling them.

Note that a full layer in FP8 should take about 14 GB of VRAM during inference with batch size 2.
The distiller spreads the pruned experts over the remaining space. With the default config of this notebook, the first distillate is on cuda:0 and the others are on cuda:1.

To make it work properly I had to write a custom implementation of the FP8Linear layer, optimized for numerical stability, as the default bnb one was producing NaN outputs very frequently. The current implementation is a bit naive and could probably be improved with custom Triton kernels.

The new layer was required because I am running the pipeline on Ampere GPUs, and the kernels released by DeepSeek only work on Ada Lovelace and later generations.

The logic is hardcoded in the Distiller file, and will need to be updated.

With the current config, expect a peak total VRAM usage of about 46 GB. The pipeline should take about 24 hours to run on my setup, with a large part spent on disk read/write operations and quantization optimisations. If you have more VRAM, you can probably speed it up a bit :)

In [None]:
from copy import deepcopy
from Distiller import MOEDistillerV3
from copy import deepcopy
from Distiller import MOEDistillerV3, count_parameters

# Pruning targets as (n_routed_experts, num_experts_per_tok) pairs; they are
# passed to the distiller for every MoE layer.
mlp_params = [
    (22, 6),
    (16, 4),
    (8, 2),
    (4, 1),
]

# Directory setup does not depend on the layer index, so do it once up front
# instead of on every iteration.
checkpoint_path = "layers/"
distilled_checkpoint_path = "distilled_layers/"
os.makedirs(distilled_checkpoint_path, exist_ok=True)

for elt in mlp_params:
    save_directory = model_name + f"_{elt[0]}@{elt[1]}"
    os.makedirs(save_directory, exist_ok=True)

# Walk through the decoder layers in order, feeding each layer the hidden
# states produced by the previous one. The bound comes from the model config
# rather than a hard-coded count that can exceed the number of layers.
for i in range(config.num_hidden_layers):
    layer = create_empty_layer(config, layer_idx=i)
    layer = load_module_weights_and_freeze_optimized(
        layer,
        f"model.layers.{i}",
        weight_map,
        "deepseek_v3",
        max_workers=32,
        fp8_format="e4m3",
    )
    memory_cleanup()

    if "DeepseekV3MLP" in str(layer.mlp.__class__):
        # Dense layer: nothing to distill, just push every batch through it
        # and store the new hidden states back on the CPU.
        for batch_idx in tqdm(range(n_batch)):
            # NOTE(review): `torch.amp.autocast('cuda')` defaults to float16
            # while the stored hidden states are bfloat16 — confirm intended.
            with torch.amp.autocast('cuda'):
                intermediate[batch_idx] = layer.forward(
                    hidden_states=intermediate[batch_idx].to('cuda:0'),
                    position_ids=position_ids,
                )[0].detach().to('cpu')
                if batch_idx % 100 == 0:
                    memory_cleanup()
        destruct_module_optimized(layer)

    else:
        # MoE layer: calibrate the distiller, then train it batch by batch.
        distiller = MOEDistillerV3(layer, i, model_name=model_name) # Example with accumulation

        # GPU copies of (at most) the first 128 batches of hidden states.
        calibration_batch = [elt.to('cuda:0', dtype=torch.bfloat16) for elt in intermediate[:128]]
        # calibration_attention_batch=[_prepare_4d_causal_attention_mask(elt,(batch_size, max_length),embed,0).to('cuda:0') for elt in attention_masks[:4]]

        distiller.calibrate(
            calibration_batch,
            position_ids,
            mlp_params,
            total_steps=n_batch,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=1e-4,
            temperature=2,
        )
        # Drop this cell's reference to the GPU copies so the cleanup below
        # can actually release them before training starts.
        del calibration_batch
        memory_cleanup()

        progress_bar = tqdm(range(n_batch), desc="Training")
        for batch_idx in progress_bar:
            new_hidden_state, losses = distiller.step(
                intermediate[batch_idx].to('cuda:0', dtype=torch.bfloat16),
                attention_mask=None,
                position_ids=position_ids,
            )
            # The step's output becomes the next layer's input.
            intermediate[batch_idx] = new_hidden_state.cpu()
            progress_bar.set_postfix(**losses)
            memory_cleanup()

        distiller.save_distillats()  # Call the save function
        destruct_module_optimized(layer) # Destruct after saving
        memory_cleanup()

Processing modules: 100%|█████████████████████████████████████████████| 17/17 [00:08<00:00,  2.00it/s]
Updating module structure: 100%|██████████████████████████████████| 14/14 [00:00<00:00, 114912.44it/s]


  0%|          | 0/32 [00:00<?, ?it/s]

Processing modules: 100%|█████████████████████████████████████████████| 17/17 [00:08<00:00,  2.00it/s]
Updating module structure: 100%|███████████████████████████████████| 14/14 [00:00<00:00, 87122.04it/s]


  0%|          | 0/32 [00:00<?, ?it/s]

Processing modules: 100%|█████████████████████████████████████████████| 17/17 [00:08<00:00,  2.07it/s]
Updating module structure: 100%|███████████████████████████████████| 14/14 [00:00<00:00, 67494.55it/s]


  0%|          | 0/32 [00:00<?, ?it/s]

Processing modules:  36%|███████████████▏                          | 469/1300 [02:09<04:02,  3.43it/s]

In [None]:
# Scratch cell: repeats the MoE distillation steps for the `layer` / `i`
# currently held in the notebook state, using only the first 16 batches.
distiller = MOEDistillerV3(layer, i, model_name=model_name) # Example with accumulation
        
# GPU copies of the first 16 batches of hidden states, used for calibration.
calibration_batch=[elt.to('cuda:0', dtype=torch.bfloat16) for elt in intermediate[:16]]
# calibration_attention_batch=[_prepare_4d_causal_attention_mask(elt,(batch_size, max_length),embed,0).to('cuda:0') for elt in attention_masks[:4]]

# Pruning targets as (n_routed_experts, num_experts_per_tok) pairs.
mlp_params = [
    (22,6),
    (16,4),
    (8,2),
    (4,1),
]


# Make sure one "<model_name>_<experts>@<per token>" directory exists per target.
for elt in mlp_params:
    save_directory=model_name+f"_{elt[0]}@{elt[1]}"
    os.makedirs(save_directory, exist_ok=True)

# NOTE(review): `total_steps` is `n_batch` while the loop below only runs 16
# steps — confirm this mismatch is intended.
distiller.calibrate(
    calibration_batch,
    position_ids,
    mlp_params,
    total_steps=n_batch,
    gradient_accumulation_steps=gradient_accumulation_steps,
    learning_rate=1e-4,
    temperature=2,
)
memory_cleanup()

progress_bar = tqdm(range(16), desc="Training")
for batch_idx in progress_bar:
    new_hidden_state, losses = distiller.step(
        intermediate[batch_idx].to('cuda:0', dtype=torch.bfloat16),
        attention_mask=None,
        position_ids=position_ids,
    )
    # Overwrites the stored hidden states for this batch with the step output.
    intermediate[batch_idx]=new_hidden_state.cpu()
    progress_bar.set_postfix(**losses)
    memory_cleanup()

distiller.save_distillats()  # Call the save function
destruct_module_optimized(layer) # Destruct after saving
memory_cleanup()

In [None]:
layer = create_empty_layer(config, layer_idx=i)
layer = load_module_weights_and_freeze_optimized(
    layer,
    f"model.layers.{i}",
    weight_map,
    "deepseek_v3",
    max_workers=32,
    fp8_format="e4m3",
)
memory_cleanup()

In [None]:
# Scratch cell: manually trigger a garbage-collection / CUDA cache cleanup.
memory_cleanup()

In [None]:
# Scratch cell: grab the gate projection of the first expert of the first
# distillat for manual inspection.
x = distiller.distillats[0]['moe'].experts[0].gate_proj

In [None]:
# Scratch cell: replace `x` with the result of its `merge_and_unload()`.
# NOTE(review): presumably `x` is an adapter-wrapped (e.g. LoRA) module and
# this folds the adapter into the base weights — confirm against the Distiller.
x = x.merge_and_unload()