In [None]:
# Re-import edited project modules (e.g. `utils.*`) automatically before each cell runs,
# so changes to the helpers are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

import sys
# Make the parent directory importable so the `utils.*` imports below resolve.
sys.path.append('..')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

from torch.utils.tensorboard import SummaryWriter

from accelerate import init_empty_weights
from tqdm.auto import tqdm
from copy import deepcopy
import numpy as np
import argparse
import torch
import json
import os

from utils.ademamix import AdEMAMix
from utils.config_utils import PathConfig, DistillationParams

from utils.adapters import DoRAAdapter

from utils.experts_merge_utils import (
    dequantize_GEMM,
    prepare_distillat_topk,
    prepare_distillat_state_cl,
    prepare_distillat_act_cl,
    prepare_moe_for_distillation,
    halve_distilled_mlp,
    merge_and_unload,
    calibrated_dequant,
    build_affinity_matrix,
    expert_clustering,
    cooccurrence_matrix,
    group_items_by_affinity
)

from utils.torch_utils import (
    load_quant,
    rsetattr,
    destruct_module_optimized,
    memory_cleanup,
    load_weights,
    WarmupCosineAnnealingLR
)

from utils.multiplex import MultiplexedMOE
import pickle

# Let float32 matmuls use faster reduced-precision internal math (bfloat16 when a fast
# kernel for it exists); trades some accuracy for throughput.
torch.set_float32_matmul_precision('medium')

In [None]:
# Runtime device and checkpoint location ------------------------------------
device = "cuda:0"
model_name = "../deepseek_coder_v2_lite_instruct_awq"

# Batch geometry settings (not referenced by the cells below).
n_batch, batch_size, max_length = 16, 8, 512

# Input/output directories; `expert_states` and `expert_activations` are read below.
path_config = PathConfig(
    model_name=model_name,
    intermediate_states="../data/intermediate_states",
    expert_states="../data/expert_states",
    expert_activations="../data/expert_activations",
    distillation_logs="distillation_logs",
    moe_states="moe_states",
)

# Distillation / compression settings.
distillation_config = DistillationParams(
    n_epochs=1,
    target_routed_expert=2,  # used below as the number of expert groups
    target_active_expert=2,
    eval_batches=16,
    gradient_accumulation_steps=4,
    learning_rate=8e-4,
    end_factor=0.2,
    calibrate_merge=False,
    skip_first_tokens=0,  # leading positions dropped from each batch (early tokens carry little context)
    pruning_method="topk",  # one of: "topk", "act_cl", "state_cl"
    dora_rank=16,
)

## Instantiate empty model

In [None]:
# The sharded-checkpoint index maps each tensor name to the safetensors shard holding it.
index_path = os.path.join(model_name, "model.safetensors.index.json")
with open(index_path, "r") as index_file:
    weight_map = json.load(index_file)["weight_map"]

# Build the architecture without allocating real weights; individual layers are
# materialized explicitly in the next section.
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        low_cpu_mem_usage=True,
    )

# Freeze every base-model parameter in one call.
model.requires_grad_(False)
model.train()

# Project helpers from utils.torch_utils — presumably release modules/memory not needed
# for single-layer work; see their implementation.
destruct_module_optimized(model)
memory_cleanup()

## Load layer weights

In [None]:
# Decoder layer whose MLP/MoE block is processed in the cells below.
layer_idx = 22
layer_filter = f".layers.{layer_idx}."
target_modules = [layer_filter]

# Allocate (uninitialized) storage for just this layer on the device, then populate it
# via load_weights — presumably only tensors whose names match `target_modules`.
model.model.layers[layer_idx].to_empty(device=device)
model = load_weights(model, model_name, weight_map, target_modules, device)

In [None]:
source_layer = model.model.layers[layer_idx]

# Private copy of the layer's MLP block; its router gate is cast to bfloat16.
distilled_mlp = deepcopy(source_layer.mlp).to(device)
distilled_mlp.gate = distilled_mlp.gate.to(torch.bfloat16)

# Private bfloat16 copy of the norm that feeds the MLP in this layer.
layer_norm = deepcopy(source_layer.post_attention_layernorm).to(device, dtype=torch.bfloat16)

In [None]:
# Freeze every parameter of the copied MoE block.
distilled_mlp.requires_grad_(False)

# Expert routing statistics recorded for this layer.
# NOTE(review): pickle.load can execute arbitrary code — only load files produced locally.
activations_path = os.path.join(path_config.expert_activations, f"layer_{layer_idx}.pickle")
with open(activations_path, "rb") as f:
    (top_k_output, top_k_weight) = pickle.load(f)

top_k_output = top_k_output.detach().to(torch.int64).cpu().numpy()
if top_k_output.size == 0:
    raise ValueError(f"No expert activations recorded in {activations_path}")

# Size the matrix by the largest expert id seen rather than by the number of distinct
# ids: if some expert never fires, len(np.unique(...)) is smaller than max_id + 1 and the
# highest ids would not fit in the matrix.
# NOTE(review): assumes top_k_output holds 0-based expert ids and that the second
# argument of cooccurrence_matrix is the expert count — TODO confirm.
n_experts = int(top_k_output.max()) + 1
affinity_matrix = cooccurrence_matrix(top_k_output, n_experts)

# Min-max scale to [0, 1]; skip the division when the matrix is constant (range 0),
# which would otherwise produce NaNs.
affinity_min = affinity_matrix.min()
affinity_range = affinity_matrix.max() - affinity_min
affinity_matrix = affinity_matrix - affinity_min
if affinity_range > 0:
    affinity_matrix = affinity_matrix / affinity_range

# Number of cached hidden-state batches consumed by the calibration loop below.
train_batches = 2048

# Experts per group (floor division). NOTE(review): how leftover experts are handled when
# the count is not divisible by target_routed_expert is up to MultiplexedMOE — TODO confirm.
group_size = affinity_matrix.shape[0] // distillation_config.target_routed_expert

In [None]:
# Calibration hyper-parameters for the multiplexed MoE.
lr = 3e-4
merge_method = "slerp"

multiplex = MultiplexedMOE(distilled_mlp, distillation_config.target_routed_expert)
multiplex.multiplex(affinity_matrix, group_size, train_batches, learning_rate=lr, device=device, merge_method=merge_method)
multiplex = multiplex.train()


def load_hidden_states(batch_idx):
    """Load cached layer-input batch `batch_idx` as bfloat16 on `device`.

    The first `distillation_config.skip_first_tokens` positions along dim 1 are dropped.
    Reads the notebook globals `path_config`, `layer_idx`, `device` and
    `distillation_config` at call time.
    """
    batch_path = os.path.join(path_config.expert_states, f"layer_{layer_idx}", f"batch_{batch_idx}")
    states = load_quant(batch_path).to(device, dtype=torch.bfloat16)
    return states[:, distillation_config.skip_first_tokens:]


writer = SummaryWriter(log_dir=f'multiplex_runs/{lr}_{merge_method}_addmask')
try:
    progress_bar = tqdm(range(train_batches), desc=f"Calibrating multiplexage {lr}")
    for batch_idx in progress_bar:
        hidden_states = load_hidden_states(batch_idx)
        loss = multiplex.train_step(hidden_states, layer_norm, temperature=1)
        loss_value = loss.item()  # one device->host copy per step instead of two
        progress_bar.set_postfix(loss=loss_value)
        writer.add_scalar('Loss/train', loss_value, batch_idx)
finally:
    # Flush and close the event file even if a step raises mid-run.
    writer.close()

# Sanity check on the first two sequences of batch 0: compare the multiplexed forward,
# `forward_origin` (presumably the pre-multiplex routing — see utils.multiplex) and
# `distilled_mlp`. No gradients are needed, so avoid building autograd graphs.
with torch.no_grad():
    hidden_states = layer_norm(load_hidden_states(0))[:2]
    x1 = multiplex(hidden_states)
    x2 = multiplex.forward_origin(hidden_states)
    x3 = distilled_mlp(hidden_states)

print(
    torch.nn.functional.smooth_l1_loss(x1, x2, reduction='mean'),
    torch.nn.functional.smooth_l1_loss(x1, x3, reduction='mean'),
    torch.nn.functional.smooth_l1_loss(x2, x3, reduction='mean')
)

In [None]:
# Finalize the calibrated multiplexed MoE after training.
# NOTE(review): exact effect (e.g. whether merged weights are written back into
# `distilled_mlp`, which is what gets saved next) is defined in utils.multiplex — TODO confirm.
multiplex.set_ready()

In [None]:
# Define the export location here so the cell also works on a fresh kernel (no other
# cell sets `export_path`): one file per layer under the configured MoE-state directory.
# NOTE(review): assumes PathConfig exposes the `moe_states` value it was constructed with.
export_path = os.path.join(path_config.moe_states, f"layer_{layer_idx}_multiplexed.pt")
os.makedirs(os.path.dirname(export_path) or ".", exist_ok=True)
torch.save(distilled_mlp.state_dict(), export_path)