In [None]:
# Enable IPython's autoreload extension in mode 2, so edited modules are
# re-imported automatically before each cell executes.
%load_ext autoreload
%autoreload 2

import sys
# Add the parent directory to the import search path — presumably so the
# project-local `utils` package imported in the next cell resolves; confirm
# against the repository layout.
sys.path.append('..')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from awq.modules.linear.gemm import WQLinear_GEMM
from torch.utils.tensorboard import SummaryWriter
from accelerate import init_empty_weights
from tqdm.auto import tqdm
from copy import deepcopy
import numpy as np
import shutil
import pickle
import torch
import json
import os

from utils.ademamix import AdEMAMix
from utils.config_utils import GenerationParams, PathConfig, DistillationParams
from utils.expert_merge_utils import calibrated_merge_experts, dequantize_GEMM
from utils.experts_gate_utils import create_gate
from utils.torch_utils import (
    save_quant,
    load_quant,
    destruct_module_optimized,
    memory_cleanup,
    get_nonreasoning_dataset,
    load_weight,
    rsetattr,
    rgetattr,
    load_weights,
    rhasattr,
    count_parameters
)

In [None]:
# Checkpoint directory of the source model and the device passed to the
# weight loader further down.
model_name = "../deepseek_v2_lite_awq"
device = "cuda:0"

# Filesystem layout used by the pipeline. In the visible cells only
# `moe_states` is read back (per-layer MoE state dicts).
path_config = PathConfig(
    model_name=model_name,
    intermediate_states="../data/intermediate_states",
    expert_states="../data/expert_states",
    expert_activations="../data/expert_activations",
    distillation_logs="../distillation_logs",
    moe_states="../moe_states",
)

# `target_routed_expert` / `target_active_expert` are copied into the model
# config (n_routed_experts / num_experts_per_tok) and also select which
# saved distillation run is loaded. The remaining fields are not read by the
# visible cells of this notebook.
distillation_config = DistillationParams(
    n_epochs=10,
    target_routed_expert=4,
    target_active_expert=2,
    eval_batches=16,
    gradient_accumulation_steps=1,
    learning_rate=6e-4,
    end_factor=0.05,
)

In [None]:
# Read the "weight_map" entry of the checkpoint's safetensors index file.
with open(f"{model_name}/model.safetensors.index.json", "r") as index_file:
    weight_map = json.load(index_file)["weight_map"]

# Keyword arguments common to the config loader and the model loader.
hf_load_kwargs = dict(
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)

# Config carrying the reduced expert counts; it is used later when the MoE
# blocks are re-initialized, not for the model instantiation below.
config = AutoConfig.from_pretrained(model_name, **hf_load_kwargs)
config.n_routed_experts = distillation_config.target_routed_expert
config.num_experts_per_tok = distillation_config.target_active_expert

# Instantiate the model inside accelerate's empty-weights context.
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        low_cpu_mem_usage=True,
        **hf_load_kwargs,
    )

# Freeze every parameter.
for param in model.parameters():
    param.requires_grad = False

model.train()
destruct_module_optimized(model)
memory_cleanup()

In [None]:
# Replace each layer's `mlp.experts` with an empty placeholder module so the
# original experts are dropped (avoids running out of memory).
for i in range(len(model.model.layers)):
    experts_attr = f"model.layers.{i}.mlp.experts"
    if not rhasattr(model, experts_attr):
        continue  # this layer has no experts attribute
    rsetattr(model, experts_attr, torch.nn.Module())

# Allocate uninitialized storage for the remaining modules on the CPU.
model = model.to_empty(device="cpu")

## Load non-expert weights

In [None]:
# Keep every checkpoint entry whose name contains neither '.experts.' nor
# 'gate.weight'; those are excluded from this loading pass.
target_modules = [
    tensor_name
    for tensor_name in weight_map
    if '.experts.' not in tensor_name and 'gate.weight' not in tensor_name
]

model = load_weights(model, model_name, weight_map, target_modules, device)

## Reinitialize the MoE blocks with the new number of experts

In [None]:
for layer_idx, layer in enumerate(tqdm(model.model.layers)):
    if not rhasattr(layer.mlp, "experts"):
        continue  # no experts attribute on this layer's MLP: leave it untouched

    # Preserve a copy of the already-loaded shared experts, re-run the MLP
    # constructor with the reduced-expert config, then put the copy back.
    kept_shared_experts = deepcopy(layer.mlp.shared_experts)
    layer.mlp.__init__(config)
    layer.mlp.shared_experts = kept_shared_experts

    # Restore this layer's saved state dict from the matching distillation run.
    export_path = path_config.moe_states+f"/distillat_{distillation_config.target_routed_expert}a{distillation_config.target_active_expert}/layer_{layer_idx}"
    layer_state = torch.load(export_path)
    layer.mlp.load_state_dict(layer_state)

In [None]:
# Call the project helper on the rebuilt model; as the last expression of the
# cell, whatever it returns is shown as the cell output.
count_parameters(model)

## Dequantize every WQLinear_GEMM layer

In [None]:
# Dequantize the model to bfloat16 via the project helper (presumably it
# replaces the WQLinear_GEMM modules — confirm in utils.expert_merge_utils).
# `params` is the helper's second return value; it is not used in the visible
# cells.
model, params = dequantize_GEMM(model, dtype=torch.bfloat16)

# Move to the device configured at the top of the notebook instead of a
# hard-coded 'cuda:0', so changing `device` affects the whole pipeline.
model.to(device, dtype=torch.bfloat16)

In [None]:
print('updating config')
# Reload the config from the source checkpoint. The original cell referenced
# `base_model`, which is never defined in this notebook and raised NameError
# on a fresh kernel; `model_name` is the checkpoint directory used everywhere
# else.
config = AutoConfig.from_pretrained(
    model_name,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

# Record the reduced expert counts so the saved config matches the model.
config.n_routed_experts = distillation_config.target_routed_expert
config.num_experts_per_tok = distillation_config.target_active_expert

model.config = config

print('Saving')
unhealed_name = model_name+f"_{distillation_config.target_routed_expert}a{distillation_config.target_active_expert}_unhealed"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

tokenizer.save_pretrained(unhealed_name)
model.save_pretrained(unhealed_name)

# Copy the custom modeling/configuration sources next to the saved weights so
# the output directory can be loaded with trust_remote_code=True.
for remote_code_file in ('modeling_deepseek.py', 'configuration_deepseek.py'):
    shutil.copy(
        os.path.join(model_name, remote_code_file),
        os.path.join(unhealed_name, remote_code_file),
    )