In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

In [2]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from awq.modules.linear.gemm import WQLinear_GEMM
from torch.utils.tensorboard import SummaryWriter
from accelerate import init_empty_weights
from tqdm.auto import tqdm
from copy import deepcopy
import shutil
import numpy as np
import argparse
import pickle
import torch
import json
import os

from utils.ademamix import AdEMAMix
from utils.config_utils import GenerationParams, PathConfig, DistillationParams
from utils.experts_merge_utils import dequantize_GEMM
from utils.torch_utils import (
    destruct_module_optimized,
    memory_cleanup,
    rsetattr,
    load_weights,
    rhasattr,
)

In [3]:
device = "cuda:0"
model_name = "../deepseek_coder_v2_lite_instruct_awq"
n_epochs = 0
start_layer = 0
end_layer = 0
target_routed_expert = 4
target_active_expert = 4
dora_rank = 16
calibrate_merge= True
pruning_method= "multiplex"

path_config = PathConfig(
    model_name = model_name,
    intermediate_states = "../data/intermediate_states",
    expert_states = "../data/expert_states",
    expert_activations = "../data/expert_activations",
    distillation_logs = "distillation_logs",
    moe_states="../moe_states"
)

distillation_config = DistillationParams(
    n_epochs= n_epochs,
    target_routed_expert = target_routed_expert,
    target_active_expert = target_active_expert,
    eval_batches=16,
    gradient_accumulation_steps= 4,
    learning_rate= 3e-4,
    end_factor= 0.2,
    calibrate_merge=calibrate_merge,
    skip_first_tokens=0, ## useful to avoid tuning on early tokens that have less informations
    pruning_method=pruning_method, # topk , act_cl, state_cl
    dora_rank=dora_rank,
)

In [4]:
print('Loading model')

from patched_modules.configuration_deepseek_fused_v2 import DeepseekV2Config
from patched_modules.modeling_deepseek_fused_v2 import MultiplexedMOE


with open(f"{model_name}/model.safetensors.index.json", "r") as f:
    weight_map = json.load(f)["weight_map"]

with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        attn_implementation="flash_attention_2",
        low_cpu_mem_usage=True
    )

for name, parameter in model.named_parameters():
    parameter.requires_grad = False

config=AutoConfig.from_pretrained(
    model_name,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.float16,
)

config.n_multiplexed_routed_experts=distillation_config.target_routed_expert


model.train()
destruct_module_optimized(model)
memory_cleanup()

You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set your model on a GPU device in order to run your model.
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.


Loading model


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [5]:
for i in range(len(model.model.layers)):
    if rhasattr(model, f"model.layers.{i}.mlp.experts"):
        rsetattr(model, f"model.layers.{i}.mlp.experts", torch.nn.Module()) ## ensuring destruction of experts to avoid oom

model=model.to_empty(device="cpu")

## Load non expert weights

In [6]:
target_modules=[]
for elt in weight_map:
    if not('.experts.' in elt):
        if not('gate.weight' in elt):
            target_modules.append(elt)

model=load_weights(model, model_name, weight_map, target_modules, device)

  0%|          | 0/15653 [00:00<?, ?it/s]

## Reinitialize experts with new number of layers

In [9]:
for layer_idx, layer in enumerate(tqdm(model.model.layers)):
    if rhasattr(layer.mlp, "experts"):
        shared=deepcopy(layer.mlp.shared_experts)
        
        export_path=path_config.moe_states+f"/distillat_multiplex_{distillation_config.target_routed_expert}/layer_{layer_idx}"
        state_dict = torch.load(export_path)
        
        new_state_dict = {}
        for key in state_dict.keys():
            new_key = key.replace('multiplexed_experts', 'experts')
            new_state_dict[new_key] = state_dict[key]
        
        layer.mlp=MultiplexedMOE(config)
        layer.mlp.shared_experts=deepcopy(shared)
        layer.mlp.load_state_dict(new_state_dict)
        layer.mlp.shared_experts=shared

  0%|          | 0/27 [00:00<?, ?it/s]

FileNotFoundError: [Errno 2] No such file or directory: '../moe_states/distillat_multiplex_4/layer_4'

In [None]:
count_parameters(model)

## Dequant every WQLinear_GEMM layers

In [None]:
model, params = dequantize_GEMM(model, dtype=torch.bfloat16)
model.to('cuda:0', dtype=torch.bfloat16)

In [None]:
print('updating config')

model.config=config

print('Saving')
unhealed_name=model_name+f"_fused_{distillation_config.target_routed_expert}_unhealed"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

tokenizer.save_pretrained(unhealed_name)
model.save_pretrained(unhealed_name)

shutil.copy(os.path.join('../patched_modules/', 'modeling_deepseek_fused_v2.py'), os.path.join(unhealed_name, 'modeling_deepseek.py'))
shutil.copy(os.path.join('../patched_modules/', 'configuration_deepseek_fused_v2.py'), os.path.join(unhealed_name, 'configuration_deepseek.py'))