In [None]:
# Automatically reload edited project modules (e.g. utils/, patched_modules/) before executing code.
%load_ext autoreload
%autoreload 2

import sys
# Make the parent directory importable so the project packages used below (`utils`, `patched_modules`) resolve.
sys.path.append('..')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from awq.modules.linear.gemm import WQLinear_GEMM
from torch.utils.tensorboard import SummaryWriter
from accelerate import init_empty_weights
from tqdm.auto import tqdm
from copy import deepcopy
import shutil
import numpy as np
import argparse
import pickle
import torch
import json
import os

from utils.ademamix import AdEMAMix
from utils.config_utils import GenerationParams, PathConfig, DistillationParams
from utils.experts_merge_utils import dequantize_GEMM
from utils.torch_utils import (
    destruct_module_optimized,
    memory_cleanup,
    rsetattr,
    load_weights,
    rhasattr,
    count_parameters
)

In [None]:
# --- Run configuration -------------------------------------------------------
device = "cuda:1"

# Checkpoint being converted (its "_awq" tag is stripped from the output name at save time)
# and the upstream model whose config is used as the template for the fused config.
model_name = "../deepseek_coder_v2_lite_instruct_awq"
base_model = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"

# Fused-MoE / distillation hyper-parameters.
n_epochs = 0
start_layer = 0
end_layer = 0
target_routed_expert = 16
target_active_expert = target_routed_expert  ## unused in multiplex
dora_rank = 8
calibrate_merge = True
pruning_method = "fused"

# Where intermediate artefacts and the per-layer fused MoE checkpoints live.
path_config = PathConfig(
    model_name=model_name,
    intermediate_states="../data/intermediate_states",
    expert_states="../data/expert_states",
    expert_activations="../data/expert_activations",
    distillation_logs="distillation_logs",
    moe_states="../moe_states",
)

distillation_config = DistillationParams(
    n_epochs=n_epochs,
    target_routed_expert=target_routed_expert,
    target_active_expert=target_active_expert,
    eval_batches=16,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    end_factor=0.2,
    calibrate_merge=calibrate_merge,
    # Leading tokens to skip during tuning; early tokens carry less information.
    skip_first_tokens=0,
    # Other accepted values noted by the author: topk, act_cl, state_cl.
    pruning_method=pruning_method,
    dora_rank=dora_rank,
)

In [None]:
print('Loading model')

# Patched DeepSeek-V2 config/modeling classes providing the fused-expert MoE block.
from patched_modules.configuration_deepseek_fused_v2 import DeepseekV2Config
from patched_modules.modeling_deepseek_fused_v2 import FusedMOE


# `weight_map` from the checkpoint's safetensors index; consumed by load_weights later.
with open(f"{model_name}/model.safetensors.index.json", "r") as index_file:
    weight_map = json.load(index_file)["weight_map"]

# Build the model skeleton without allocating real weight storage.
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        low_cpu_mem_usage=True,
    )

# Freeze every parameter of the model.
model.requires_grad_(False)

# Start from the upstream model's config and add the fused-MoE settings.
config_dict = AutoConfig.from_pretrained(
    base_model,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
).to_dict()

# The saved model loads its remote code from modeling_deepseek.py / configuration_deepseek.py,
# which the save step fills with the patched fused-v2 implementations.
config_dict['auto_map'] = {
    'AutoConfig': 'configuration_deepseek.DeepseekV2Config',
    'AutoModel': 'modeling_deepseek.DeepseekV2Model',
    'AutoModelForCausalLM': 'modeling_deepseek.DeepseekV2ForCausalLM',
}
config_dict.update(
    n_fused_experts=distillation_config.target_routed_expert,
    fused_expert_dora_rank=distillation_config.dora_rank,
    fused_expert_method="mixture",
)

config = DeepseekV2Config(**config_dict)


model.train()
# NOTE(review): project helpers; presumably release the empty-weight module's tensors and
# free cached memory — confirm in utils/torch_utils.
destruct_module_optimized(model)
memory_cleanup()

In [None]:
# Display the patched DeepseekV2Config to check the n_fused_experts / fused_expert_* overrides.
config

In [None]:
# Replace the routed experts of every MoE layer with an empty module, so no storage is
# materialized for them; the MoE blocks are rebuilt from the fused checkpoints later.
expert_paths = [f"model.layers.{i}.mlp.experts" for i in range(len(model.model.layers))]
for expert_path in expert_paths:
    if not rhasattr(model, expert_path):
        continue  # dense (non-MoE) layer
    rsetattr(model, expert_path, torch.nn.Module())  ## ensuring destruction of experts to avoid oom

# Give the remaining parameters real (uninitialized) CPU storage; load_weights fills them next.
model = model.to_empty(device="cpu")

## Load non-expert weights

In [None]:
# Every checkpoint tensor except routed-expert weights and `gate.weight` entries (presumably the
# MoE routers); those belong to the MoE blocks, which are rebuilt from the fused checkpoints.
target_modules = [
    tensor_name
    for tensor_name in weight_map
    if '.experts.' not in tensor_name and 'gate.weight' not in tensor_name
]

model = load_weights(model, model_name, weight_map, target_modules, device)

## Reinitialize MoE layers with the new (fused) number of experts

In [None]:
for layer_idx, layer in enumerate(tqdm(model.model.layers)):
    # Only MoE layers still expose an `experts` attribute (stubbed with an empty module
    # earlier); dense MLP layers are left untouched.
    if rhasattr(layer.mlp, "experts"):
        # Keep the shared experts already loaded from the original checkpoint.
        shared = deepcopy(layer.mlp.shared_experts)

        export_path = os.path.join(
            path_config.moe_states,
            f"distillat_fused_{distillation_config.target_routed_expert}",
            f"layer_{layer_idx}",
        )
        # Deserialize on the host regardless of the device the tensors were saved from;
        # load_state_dict copies the values into the module's own parameters anyway.
        state_dict = torch.load(export_path, map_location="cpu")

        # Rename checkpoint keys to FusedMOE's naming: `fused_experts` -> `experts`, and drop
        # the `_orig_mod.` prefix (the prefix torch.compile wrappers add — presumably the
        # checkpoint was saved from a compiled module).
        new_state_dict = {
            key.replace('fused_experts', 'experts').replace('_orig_mod.', ''): tensor
            for key, tensor in state_dict.items()
        }

        layer.mlp = FusedMOE(config)

        # load_state_dict is strict by default, so a throwaway copy of the shared experts is
        # attached to receive any `shared_experts.*` entries in the checkpoint (presumably
        # present — confirm against the export code) ...
        layer.mlp.shared_experts = deepcopy(shared)
        layer.mlp.load_state_dict(new_state_dict)
        # ... and is then swapped back for the original shared experts, discarding the
        # checkpoint's copy of them.
        layer.mlp.shared_experts = shared

In [None]:
# Parameter count with the fused MoE blocks installed, before dequantization
# (NOTE(review): exact output format depends on utils.torch_utils.count_parameters).
count_parameters(model)

## Dequantize every WQLinear_GEMM layer

In [None]:
# Replace the WQLinear_GEMM layers with dense bfloat16 weights, then move everything to `device`.
target_dtype = torch.bfloat16
model, params = dequantize_GEMM(model, dtype=target_dtype)
model.to(device, dtype=target_dtype)

In [None]:
# Parameter count after dequantization; compare with the count taken before dequantizing.
count_parameters(model)

In [None]:
print('updating config')
# Attach the patched config (auto_map + fused-expert settings) so save_pretrained writes it.
model.config = config

print('Saving')
# Output directory "<model_name>_fused_<N>_unhealed" with the "_awq" tag removed,
# since the saved weights are dequantized.
unhealed_name = (
    model_name + f"_fused_{distillation_config.target_routed_expert}_unhealed"
).replace('_awq', '')

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

tokenizer.save_pretrained(unhealed_name)
model.save_pretrained(unhealed_name)

# Ship the patched remote-code files twice: under the module names referenced by
# config.auto_map, and under their original fused_v2 names (presumably so imports by
# those names inside the files still resolve).
patched_dir = '../patched_modules/'
remote_code_files = (
    ('modeling_deepseek_fused_v2.py', 'modeling_deepseek.py'),
    ('configuration_deepseek_fused_v2.py', 'configuration_deepseek.py'),
    ('modeling_deepseek_fused_v2.py', 'modeling_deepseek_fused_v2.py'),
    ('configuration_deepseek_fused_v2.py', 'configuration_deepseek_fused_v2.py'),
)
for source_file, target_file in remote_code_files:
    shutil.copy(os.path.join(patched_dir, source_file), os.path.join(unhealed_name, target_file))