In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# re-imported automatically when their source files change.
%load_ext autoreload
%autoreload 2

import sys
# Put the parent directory on the import path; the `utils` and
# `patched_modules` packages imported below are resolved from there.
sys.path.append('..')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from awq.modules.linear.gemm import WQLinear_GEMM
from torch.utils.tensorboard import SummaryWriter
from accelerate import init_empty_weights
from tqdm.auto import tqdm
from copy import deepcopy
import shutil
import numpy as np
import argparse
import pickle
import torch
import json
import os

from utils.ademamix import AdEMAMix
from utils.config_utils import GenerationParams, PathConfig, DistillationParams
from utils.torch_utils import (
    destruct_module_optimized,
    memory_cleanup,
    rsetattr,
    load_weights,
    rhasattr,
    count_parameters
)

In [None]:
# Device on which the model is assembled.
device = "cpu"

# Checkpoint directory to convert, and the identifier of its base model.
model_name = "../deepseek_v3_awq"
base_model = "unsloth/DeepSeek-V3-bf16"

# DeepSeek-V2-Lite alternative for the two values above:
# model_name = "../deepseek_v2_lite_chat_awq"
# base_model = "deepseek-ai/DeepSeek-V2-Lite-Chat"

# Distillation / fusion hyper-parameters.
n_epochs = 0
start_layer = 0
end_layer = 0
target_routed_expert = 4
target_active_expert = target_routed_expert  # unused in multiplex
dora_rank = 4
calibrate_merge = True
pruning_method = "fused"

# Input/output locations used by the pipeline.
path_config = PathConfig(
    model_name=model_name,
    intermediate_states="../data/intermediate_states",
    expert_states="../data/expert_states",
    expert_activations="../data/expert_activations",
    distillation_logs="distillation_logs",
    moe_states="../moe_states",
)

distillation_config = DistillationParams(
    n_epochs=n_epochs,
    target_routed_expert=target_routed_expert,
    target_active_expert=target_active_expert,
    eval_batches=16,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    end_factor=0.2,
    calibrate_merge=calibrate_merge,
    skip_first_tokens=0,  # useful to avoid tuning on early tokens, which carry less information
    pruning_method=pruning_method,  # other values listed by the author: topk, act_cl, state_cl
    dora_rank=dora_rank,
)

In [None]:
print('Loading model')

from patched_modules.configuration_deepseek import DeepseekV3Config
from patched_modules.modeling_deepseek_fused import FusedMOE

# DeepSeek-V2 variant of the two imports above:
# from patched_modules.configuration_deepseek_fused_v2 import DeepseekV2Config as DeepseekV3Config
# from patched_modules.modeling_deepseek_fused_v2 import FusedMOE

from utils.torch_utils import convert_meta_model_to_awq

# Configuration of the checkpoint in `model_name`; it is used both to build
# the empty model and for the AWQ conversion below.
config = AutoConfig.from_pretrained(
    model_name,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)

# Read the checkpoint index and keep its "weight_map" entry for the
# weight-loading cells further down.
with open(f"{model_name}/model.safetensors.index.json", "r") as index_file:
    checkpoint_index = json.load(index_file)
weight_map = checkpoint_index["weight_map"]

# Instantiate the architecture from the config alone, inside accelerate's
# init_empty_weights context.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

model = convert_meta_model_to_awq(model, config, device)

# Freeze every parameter of the converted model.
for frozen_param in model.parameters():
    frozen_param.requires_grad = False

# Start from the configuration stored next to the patched modelling code and
# flatten it to a dict so the fused-expert fields can be added.
fused_config_dict = AutoConfig.from_pretrained(
    '../patched_modules/',
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to_dict()

# Alternative starting point (base model configuration):
# fused_config_dict = AutoConfig.from_pretrained(
#     base_model,
#     trust_remote_code=True,
#     attn_implementation="flash_attention_2",
#     torch_dtype=torch.float16,
# ).to_dict()

fused_config_dict.update(
    {
        # Module paths the exported checkpoint advertises for remote-code loading.
        'auto_map': {
            'AutoConfig': 'configuration_deepseek.DeepseekV3Config',
            'AutoModel': 'modeling_deepseek.DeepseekV3Model',
            'AutoModelForCausalLM': 'modeling_deepseek.DeepseekV3ForCausalLM',
        },
        # DeepSeek-V2 variant of the mapping above:
        # 'auto_map': {
        #     'AutoConfig': 'configuration_deepseek.DeepseekV2Config',
        #     'AutoModel': 'modeling_deepseek.DeepseekV2Model',
        #     'AutoModelForCausalLM': 'modeling_deepseek.DeepseekV2ForCausalLM',
        # },
        'n_fused_experts': distillation_config.target_routed_expert,
        'fused_expert_dora_rank': distillation_config.dora_rank,
        'fused_expert_method': "mixture",
    }
)

# From here on `config` is the fused-model configuration (used by FusedMOE
# and attached to the model before saving).
config = DeepseekV3Config(**fused_config_dict)

# Switch to training mode, then run the project's teardown/cleanup helpers.
model.train()
destruct_module_optimized(model)
memory_cleanup()

In [None]:
# Replace every `mlp.experts` container with an empty module so the original
# experts are destroyed before storage is allocated (avoids OOM).
for layer_idx, _ in enumerate(model.model.layers):
    experts_attr = f"model.layers.{layer_idx}.mlp.experts"
    if not rhasattr(model, experts_attr):
        continue
    rsetattr(model, experts_attr, torch.nn.Module())

# Give the remaining parameters and buffers (uninitialised) storage on CPU.
model = model.to_empty(device="cpu")

## Load non-expert weights

In [None]:
# Keep every tensor name from the checkpoint index except those containing
# '.experts.' or 'gate.weight'.
target_modules = [
    tensor_name
    for tensor_name in weight_map
    if '.experts.' not in tensor_name and 'gate.weight' not in tensor_name
]

model = load_weights(model, model_name, weight_map, target_modules, "cpu")

## Reinitialize experts with the new number of fused experts

In [None]:
import gc

for layer_idx, layer in enumerate(tqdm(model.model.layers)):
    # Layers whose MLP has no `experts` attribute are left untouched.
    if not rhasattr(layer.mlp, "experts"):
        continue

    # Keep a handle on the shared experts that were already loaded. No copy
    # is needed here: the old `layer.mlp` is discarded below and this very
    # object is re-attached, unchanged, at the end of the iteration.
    shared = layer.mlp.shared_experts

    export_path = os.path.join(
        path_config.moe_states,
        f"distillat_fused_{distillation_config.target_routed_expert}",
        f"layer_{layer_idx}",
    )
    # map_location="cpu": the model is assembled on CPU, so tensors that were
    # saved from a GPU must not be restored onto that GPU.
    state_dict = torch.load(export_path, map_location="cpu")

    # Rename saved keys to the names used by FusedMOE: 'fused_experts' becomes
    # 'experts' and the '_orig_mod.' fragment is dropped.
    new_state_dict = {
        key.replace('fused_experts', 'experts').replace('_orig_mod.', ''): tensor
        for key, tensor in state_dict.items()
    }

    layer.mlp = FusedMOE(config)

    # load_state_dict() writes into whatever sub-modules are attached, so it
    # is given a throwaway copy of the shared experts; the original object is
    # put back right after, keeping its previously loaded weights.
    layer.mlp.shared_experts = deepcopy(shared)
    layer.mlp.load_state_dict(new_state_dict)
    layer.mlp.shared_experts = shared

    # Drop the per-layer temporaries before asking for memory back.
    del shared, state_dict, new_state_dict
    gc.collect()
    memory_cleanup()

# model=model.to(device, dtype=torch.bfloat16)

In [None]:
# Move the assembled model to `device` (set to "cpu" in the configuration
# cell), then call the project's memory_cleanup helper.
model=model.to(device)
memory_cleanup()

In [None]:
# Parameter report before dequantization; the same call is repeated after the
# WQLinear_GEMM modules have been replaced.
count_parameters(model)

## Dequantize every WQLinear_GEMM layer

In [None]:
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.triton.gemm import awq_gemm_triton, awq_dequantize_triton

def dequantize_WQLinear_GEMM(wq_linear, destruct=True, dtype=torch.bfloat16, device='cpu', compute_device='cuda:0'):
    """Convert an AWQ ``WQLinear_GEMM`` module into a dense ``torch.nn.Linear``.

    The packed weight is moved to ``compute_device``, dequantized with
    ``awq_dequantize_triton`` and transposed before being used as the
    ``nn.Linear`` weight.

    Args:
        wq_linear: Quantized module. It is moved to ``compute_device`` in
            place, and additionally emptied onto the meta device when
            ``destruct`` is true.
        destruct: If true, release the quantized module's storage once the
            dense weight has been extracted.
        dtype: dtype of the returned layer's parameters.
        device: Device of the returned layer's parameters.
        compute_device: Device on which the Triton dequantization runs
            (presumably must be a CUDA device; defaults to the previously
            hard-coded ``'cuda:0'``).

    Returns:
        A ``torch.nn.Linear`` with trainable, contiguous parameters on
        ``device`` in ``dtype``.
    """
    has_bias = wq_linear.bias is not None

    # Build the shell on the meta device: every parameter is replaced below,
    # so allocating and randomly initialising a dense weight would be wasted.
    linear = torch.nn.Linear(
        wq_linear.in_features,
        wq_linear.out_features,
        bias=has_bias,
        device="meta",
        dtype=dtype,
    )

    wq_linear = wq_linear.to(compute_device)

    # `.T` only yields a strided view. Make it contiguous while still on the
    # compute device, so the stored weight is a plain dense tensor (safetensors
    # serialization refuses non-contiguous tensors).
    dense_weight = awq_dequantize_triton(
        wq_linear.qweight,
        wq_linear.scales,
        wq_linear.qzeros,
    ).T.contiguous().to(device=device, dtype=dtype)
    linear.weight = torch.nn.Parameter(dense_weight, requires_grad=True)

    if has_bias:
        linear.bias = torch.nn.Parameter(
            wq_linear.bias.detach().to(device=device, dtype=dtype),
            requires_grad=True,
        )

    if destruct:
        # Swap the quantized tensors for meta placeholders to free their memory.
        wq_linear.to_empty(device="meta")
        del wq_linear
    return linear

# Collect the targets first: the loop below replaces modules inside `model`,
# and named_modules() is a lazy generator over that same module tree. The
# snapshot also gives tqdm a known total.
quantized_modules = [
    (name, module)
    for name, module in model.named_modules()
    if isinstance(module, WQLinear_GEMM)
]

for position, (name, module) in enumerate(tqdm(quantized_modules), start=1):
    rsetattr(model, name, dequantize_WQLinear_GEMM(module, destruct=True, dtype=torch.bfloat16, device='cpu'))
    # Reclaim memory after every 100 converted layers.
    if position % 100 == 0:
        memory_cleanup()

# Final pass for the layers converted since the last periodic cleanup.
memory_cleanup()

In [None]:
# Same parameter report as before, now that the WQLinear_GEMM modules have
# been replaced by dense layers.
count_parameters(model)

In [None]:
# Bare expression: the output-head weight tensor is rendered as the cell output.
model.lm_head.weight

In [None]:
print('updating config')

# Attach the fused configuration so it is what gets written with the weights.
model.config = config

print('Saving')

# NOTE(review): machine-specific absolute export root; adjust before running
# on another machine.
export_root = "/home/golympie/"
unhealed_name = f"{model_name}_fused_{distillation_config.target_routed_expert}_unhealed"
unhealed_name = unhealed_name.replace('_awq', '').replace("../", export_root)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
tokenizer.save_pretrained(unhealed_name)
model.save_pretrained(unhealed_name)

# Ship the patched modelling/configuration code with the checkpoint, under the
# file names referenced by the config's `auto_map`.
patched_dir = '../patched_modules/'
for source_file, exported_file in (
    ('modeling_deepseek_fused.py', 'modeling_deepseek.py'),
    ('configuration_deepseek.py', 'configuration_deepseek.py'),
):
    shutil.copy(os.path.join(patched_dir, source_file), os.path.join(unhealed_name, exported_file))

# DeepSeek-V2 variants of the copies above:
# shutil.copy(os.path.join('../patched_modules/', 'modeling_deepseek_fused_v2.py'), os.path.join(unhealed_name, 'modeling_deepseek.py'))
# shutil.copy(os.path.join('../patched_modules/', 'configuration_deepseek_fused_v2.py'), os.path.join(unhealed_name, 'configuration_deepseek.py'))

# shutil.copy(os.path.join('../patched_modules/', 'modeling_deepseek_fused_v2.py'), os.path.join(unhealed_name, 'modeling_deepseek_fused_v2.py'))
# shutil.copy(os.path.join('../patched_modules/', 'configuration_deepseek_fused_v2.py'), os.path.join(unhealed_name, 'configuration_deepseek_fused_v2.py'))