In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig

from torch.utils.tensorboard import SummaryWriter

from accelerate import init_empty_weights
from tqdm.auto import tqdm
from copy import deepcopy
import numpy as np
import argparse
import torch
import json
import os

# from bitsandbytes.optim.ademamix import AdEMAMix8bit as AdEMAMix
# from utils.ademamix import AdEMAMix
from utils.config_utils import PathConfig, DistillationParams

from utils.adapters import DoRAAdapter

from utils.experts_merge_utils import (
    dequantize_GEMM,
    prepare_distillat_topk,
    prepare_distillat_state_cl,
    prepare_distillat_act_cl,
    prepare_moe_for_distillation,
    halve_distilled_mlp,
    merge_and_unload,
    calibrated_dequant,
    build_affinity_matrix,
    expert_clustering,
    cooccurrence_matrix,
    group_items_by_affinity,
)

from utils.torch_utils import (
    load_quant,
    rsetattr,
    destruct_module_optimized,
    memory_cleanup,
    load_weights,
    WarmupCosineAnnealingLR,
    count_parameters,
    convert_meta_model_to_awq
)

from utils.fused import FusedMOE
import pickle

# Allow float32 matmuls to use reduced-precision ('medium' = bfloat16-level) internal
# computation for speed; the model itself runs in bf16 / AWQ anyway.
torch.set_float32_matmul_precision('medium')

In [None]:
# ---- Run configuration ------------------------------------------------------
# Other values tried earlier: model "../deepseek_v3_awq", device "cuda:1".
model_name = "../deepseek_v2_lite_chat_awq"
device = "cuda:0"

# Batch geometry settings.
n_batch, batch_size, max_length = 16, 4, 512

# Input/output locations: cached activation directories plus log / MoE-state destinations.
path_config = PathConfig(
    model_name=model_name,
    intermediate_states="../data/intermediate_states",
    expert_states="../data/expert_states",
    expert_activations="../data/expert_activations",
    distillation_logs="distillation_logs",
    moe_states="moe_states",
)

# Distillation hyperparameters (some are overridden locally in the calibration cell).
distillation_config = DistillationParams(
    n_epochs=1,
    target_routed_expert=8,
    target_active_expert=2,
    eval_batches=16,
    gradient_accumulation_steps=4,
    learning_rate=8e-4,
    end_factor=0.2,
    calibrate_merge=False,
    skip_first_tokens=32,    # skip early positions: they carry less context/information to tune on
    pruning_method="fused",  # other options: topk, act_cl, state_cl
    dora_rank=8,
)

## Instantiate empty model

In [None]:
# Build the model skeleton without allocating real weights, then convert it
# (presumably to AWQ-quantized modules) on `device`.
config = AutoConfig.from_pretrained(
    model_name,
    trust_remote_code=True,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
)

# Parameter-name -> shard-file map; handed to `load_weights` when a layer is loaded.
index_path = f"{model_name}/model.safetensors.index.json"
with open(index_path, "r") as index_file:
    weight_map = json.load(index_file)["weight_map"]

with init_empty_weights(include_buffers=True):
    model = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

model = convert_meta_model_to_awq(model, config, device)

# Freeze every base-model parameter; only the fused MoE built later is trained.
model.requires_grad_(False)

model.train()
destruct_module_optimized(model)
memory_cleanup()

## Load layer weights

In [None]:
# Decoder layer whose MoE block is compressed in this notebook.
layer_idx = 25

# Load only the weights whose names match this layer's prefix.
target_modules = [f".layers.{layer_idx}."]
model = load_weights(model, model_name, weight_map, target_modules, device)

# Place the MoE block on the configured device (previously hard-coded to 'cuda:0',
# which silently ignored the `device` setting).
target_layer = model.model.layers[layer_idx]
target_layer.mlp = target_layer.mlp.to(device)

In [None]:
# Sanity-check one cached input batch for this layer (first `skip_first_tokens` positions dropped).
# Show a compact summary rather than dumping the tensor itself.
hidden_states = load_quant(os.path.join(path_config.expert_states, f"layer_{layer_idx}", "batch_0")).to(device, dtype=torch.bfloat16)[:, distillation_config.skip_first_tokens:]
hidden_states.shape, hidden_states.dtype, bool(hidden_states.isnan().any())

In [None]:
# Expert co-occurrence (affinity) matrix, used to group experts that tend to be routed together.
# NOTE: pickle.load executes arbitrary code; only load files this pipeline produced itself.
with open(os.path.join(path_config.expert_activations, f"layer_{layer_idx}.pickle"), "rb") as f:
    (top_k_output, top_k_weight) = pickle.load(f)

top_k_output = top_k_output.detach().to(torch.int64).cpu().numpy()

# Number of routed experts. `len(np.unique(ids))` was wrong whenever an expert was never
# selected (too small) or corrupted ids slipped through (too large), which changed the
# matrix size and therefore `group_size`. DeepSeek configs expose the true count.
n_experts = getattr(config, "n_routed_experts", None)
if n_experts is None:
    # Unknown expert count: keep the old sanity bound and size the matrix from the data.
    top_k_output[(top_k_output < 0) | (top_k_output > 512)] = 0
    n_experts = int(top_k_output.max()) + 1
else:
    # Out-of-range ids are treated as corrupted entries and remapped to expert 0.
    top_k_output[(top_k_output < 0) | (top_k_output >= n_experts)] = 0

affinity_matrix = cooccurrence_matrix(top_k_output, n_experts)

# Min-max normalise to [0, 1]; a constant matrix becomes all zeros instead of NaN.
affinity_min = affinity_matrix.min()
affinity_span = affinity_matrix.max() - affinity_min
affinity_matrix = (affinity_matrix - affinity_min) / (affinity_span if affinity_span > 0 else 1)

# Presumably the number of original experts merged into each of the `target_routed_expert` fused experts.
group_size = affinity_matrix.shape[0] // distillation_config.target_routed_expert

In [None]:
# Hyperparameters for this fused-MoE calibration run. `lr` and the gradient
# accumulation deliberately override the values in `distillation_config`.
lr = 2e-4
n_epoch = 2000
gradient_accumulation_step = 1

# Options for manual sweeps (loop over these to compare configurations).
merge_methods = ['sce', 'slerp', 'mean', 'greedy']
adapter_types = ['mixture']
ranks = [8, 64, 512]

# Entries cached on disk for this layer (previously computed and then silently overwritten).
available_batches = len(os.listdir(os.path.join(path_config.expert_states, f"layer_{layer_idx}")))

# Batches [0, train_batches - eval_batches) are trained on; the last `eval_batches` are held out.
eval_batches = distillation_config.eval_batches
train_batches = 32
assert eval_batches < train_batches <= available_batches, (
    f"need {train_batches} cached batches ({eval_batches} for eval), found {available_batches}"
)

# Configuration trained below.
rank = 8
merge_method = 'slerp'
adapter_type = 'mixture'

# Work on copies of this layer's MoE block and post-attention norm so the base model stays untouched.
distilled_mlp = deepcopy(model.model.layers[layer_idx].mlp).to(device)
layer_norm = deepcopy(model.model.layers[layer_idx].post_attention_layernorm).to(device, dtype=torch.bfloat16)
# The gate is cast to bf16 separately from the rest of the copied block.
distilled_mlp.gate = distilled_mlp.gate.to(torch.bfloat16)

# Freeze the copied block; trainable parameters are re-enabled selectively after fusing.
for param in distilled_mlp.parameters():
    param.requires_grad = False

# Fuse the experts into groups guided by the affinity matrix, then prepare for training.
fused_moe = FusedMOE(distilled_mlp)
fused_moe.fuse(affinity_matrix, group_size, train_batches, learning_rate=lr, device=device, merge_method=merge_method, rank=rank, adapter_type=adapter_type, low_vram=False)
# NOTE(review): only `train_batches - eval_batches` batches are trained per epoch; if this
# argument is the total step count for an LR schedule it overcounts — confirm.
fused_moe.train_mode(lr, train_batches * n_epoch)
fused_moe.set_ready()
fused_moe.train()

# destruct_module_optimized(model)
memory_cleanup()

# Optional: fused_moe = torch.compile(fused_moe, dynamic=True)

# Decide which fused-MoE parameters receive gradients. Parameters under `gate.` are frozen;
# adapter projections (qa/qb), scaling factors and the fused layer weight are trained.
# A trainable marker takes precedence if a name matches both rules.
trainable_markers = ('qa_weights', 'qb_weights', 'scaling_factor', 'fused_layer.weight')
for name, params in fused_moe.named_parameters():
    if any(marker in name for marker in trainable_markers):
        params.requires_grad = True
    elif 'gate.' in name:
        params.requires_grad = False


# Calibrate the fused MoE against cached layer outputs, evaluating on held-out batches each epoch.
writer = SummaryWriter(log_dir=f'multiplex_runs/fused_lin_{lr}_{merge_method}_{rank}_{adapter_type}_{distillation_config.target_routed_expert}')

run_tag = f"fused_{lr}_{merge_method}_{rank}_{adapter_type}"
# Batches [0, n_train_batches) train; [n_train_batches, train_batches) evaluate.
n_train_batches = train_batches - eval_batches
# Epoch-0 batch index at which adapter scaling factors get frozen.
# NOTE: unreachable while n_train_batches <= 128.
freeze_scaling_at_batch = 128


def _load_layer_batch(root, batch_idx, end=None):
    """Load cached batch `batch_idx` of layer `layer_idx` from `root` as bf16 on `device`,
    keeping token positions [skip_first_tokens, end)."""
    tensor = load_quant(os.path.join(root, f"layer_{layer_idx}", f"batch_{batch_idx}"))
    return tensor.to(device, dtype=torch.bfloat16)[:, distillation_config.skip_first_tokens:end]


try:
    for epoch in tqdm(range(n_epoch)):
        fused_moe.train()
        progress_bar = tqdm(range(n_train_batches), desc=f"Calibrating {run_tag}")
        for batch_idx in progress_bar:
            if epoch == 0 and batch_idx == freeze_scaling_at_batch:
                # Iterate explicitly: the previous code tested the stale `name`/`params`
                # left over from the requires_grad loop, i.e. only the last parameter.
                for name, params in fused_moe.named_parameters():
                    if 'scaling_factor' in name:
                        params.requires_grad = False

            hidden_states = _load_layer_batch(path_config.expert_states, batch_idx)
            output = _load_layer_batch(path_config.intermediate_states, batch_idx)

            if output.isnan().any():
                # Some cached targets are numerically unstable; skip them entirely rather
                # than logging a stale (or undefined) loss from a previous batch.
                continue

            loss = fused_moe.train_step(hidden_states, layer_norm, temperature=1, output=output, gradient_accumulation_step=gradient_accumulation_step)
            loss_value = loss.item()
            progress_bar.set_postfix(loss=loss_value)
            writer.add_scalar('Loss/train', loss_value, batch_idx + epoch * n_train_batches)

        memory_cleanup()

        # Evaluation phase: held-out batches, truncated to the first 256 positions.
        eval_progress_bar = tqdm(range(n_train_batches, train_batches), desc=f"Evaluating {run_tag}")
        eval_losses = []
        fused_moe.eval()
        with torch.no_grad():  # no autograd graph needed for evaluation
            for batch_idx in eval_progress_bar:
                hidden_states = _load_layer_batch(path_config.expert_states, batch_idx, end=256)
                output = _load_layer_batch(path_config.intermediate_states, batch_idx, end=256)
                if output.isnan().any():
                    # Skip unstable targets instead of scoring a stale prediction.
                    continue

                residual = hidden_states.clone()
                pred = fused_moe.forward(layer_norm(hidden_states)) + residual

                local_loss = torch.nn.functional.smooth_l1_loss(pred, output, reduction='mean')
                eval_losses.append(local_loss.item())
                eval_progress_bar.set_postfix(loss=local_loss.item())

        if eval_losses:
            median_eval_loss = torch.tensor(eval_losses).median().item()
            writer.add_scalar('Loss/eval', median_eval_loss, epoch)
        memory_cleanup()
finally:
    # Flush TensorBoard events even if training is interrupted.
    writer.close()

# `fused_moe` is intentionally kept alive: the following cells inspect it and tear it down.


In [None]:
# Inspect the fused MoE on one cached batch outside the training range.
inspect_batch_idx = 4000
inspect_positions = slice(distillation_config.skip_first_tokens, 256)

hidden_states = load_quant(os.path.join(path_config.expert_states, f"layer_{layer_idx}", f"batch_{inspect_batch_idx}")).to(device, dtype=torch.bfloat16)[:, inspect_positions]
output = load_quant(os.path.join(path_config.intermediate_states, f"layer_{layer_idx}", f"batch_{inspect_batch_idx}")).to(device, dtype=torch.bfloat16)[:, inspect_positions]
true = output
residual = hidden_states.clone()
hidden_states = layer_norm(hidden_states)

# Reset `pred` so a NaN batch never leaves a stale prediction from an earlier cell around.
pred = None
if output.isnan().any():  # sometimes there is numerical instability in the cached targets
    print(f"batch_{inspect_batch_idx}: NaN in cached targets, prediction skipped")
else:
    with torch.no_grad():
        pred = fused_moe.forward(hidden_states) + residual

In [None]:
# Prediction (fused MoE output + residual) for the inspected batch.
pred

In [None]:
# Cached reference output for the same batch, to compare with `pred`.
true

In [None]:
# Module repr: structure of the fused MoE.
fused_moe

In [None]:
# Tear down the fused MoE (presumably frees its device memory) once inspection is done.
destruct_module_optimized(fused_moe)

In [None]:
# Load one cached batch (inputs and reference outputs) for ad-hoc inspection.
batch_name = f"batch_{34}"
hidden_states, output = (
    load_quant(os.path.join(root, f"layer_{layer_idx}", batch_name)).to(device, dtype=torch.bfloat16)[:, distillation_config.skip_first_tokens:]
    for root in (path_config.expert_states, path_config.intermediate_states)
)

In [None]:
# Peak input activation of the loaded batch (a stray trailing comma used to wrap it in a 1-tuple).
hidden_states.max()

In [None]:
# Tear down the base model, the copied MoE block and the fused MoE.
# NOTE(review): fused_moe may already have been destructed by an earlier cell — confirm a second call is safe.
destruct_module_optimized(model)
destruct_module_optimized(distilled_mlp)
destruct_module_optimized(fused_moe)

In [None]:
# Run the project's memory-cleanup helper after the modules have been torn down.
memory_cleanup()

In [None]:
# NOTE(review): earlier cells call destruct_module_optimized(fused_moe); confirm this is meant to run on a live module.
fused_moe.set_ready()

In [None]:
# Summarise the fused MoE state dict as {name: shape} instead of dumping every tensor.
{key: tuple(value.shape) if hasattr(value, "shape") else value for key, value in fused_moe.state_dict().items()}