In [None]:
# Enable IPython's autoreload extension in mode 2 so edited modules are
# re-imported before each cell execution instead of requiring a kernel restart.
%load_ext autoreload
%autoreload 2

import sys
# Put the parent directory on the import path so the `utils.*` imports in the
# next cell can be resolved (presumably `utils/` lives one level up — confirm).
sys.path.append('..')

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

from torch.utils.tensorboard import SummaryWriter

from accelerate import init_empty_weights
from tqdm.auto import tqdm
from copy import deepcopy
import numpy as np
import argparse
import torch
import json
import os

from utils.ademamix import AdEMAMix
from utils.config_utils import PathConfig, DistillationParams

from utils.adapters import DoRAAdapter

from utils.experts_merge_utils import (
    dequantize_GEMM,
    prepare_distillat_topk,
    prepare_distillat_state_cl,
    prepare_distillat_act_cl,
    prepare_moe_for_distillation,
    halve_distilled_mlp,
    merge_and_unload,
    calibrated_dequant,
)

from utils.torch_utils import (
    load_quant,
    rsetattr,
    destruct_module_optimized,
    memory_cleanup,
    load_weights,
    WarmupCosineAnnealingLR
)

# Trade float32 matmul precision for speed; 'medium' is the least precise of
# the settings PyTorch accepts here (see the PyTorch documentation).
torch.set_float32_matmul_precision('medium')

In [None]:
# Checkpoint directory, relative to this notebook.
model_name = "../deepseek_coder_v2_lite_instruct_awq"

# Not referenced by any other cell visible in this notebook; kept unchanged.
n_batch=16
batch_size=8
max_length=512

# Device used by every later cell. This cell used to assign "cuda:1" first and
# then overwrite it with "cuda:0"; only the effective value is kept so the
# configuration no longer contradicts itself.
device="cuda:0"

# Locations of the cached tensors read by the training cell and of the
# directories written to (logs, exported MoE states).
path_config = PathConfig(
    model_name = model_name,
    intermediate_states = "../data/intermediate_states",
    expert_states = "../data/expert_states",
    expert_activations = "../data/expert_activations",
    distillation_logs = "distillation_logs",
    moe_states="moe_states"
)

# Distillation hyper-parameters. Note that the training cell overrides
# n_epochs, learning_rate, end_factor and gradient_accumulation_steps in place.
distillation_config = DistillationParams(
    n_epochs= 1,
    target_routed_expert = 16,
    target_active_expert = 4,
    eval_batches=16,
    gradient_accumulation_steps= 4,
    learning_rate= 8e-4,
    end_factor= 0.2,
    calibrate_merge=False,
    skip_first_tokens=32, ## useful to avoid tuning on early tokens, which carry less information
    pruning_method="topk", # topk , act_cl, state_cl
    dora_rank=16,
)

## Instantiate empty model

In [None]:
# Read the checkpoint index and keep only its "weight_map" entry, which the
# weight-loading cell passes to load_weights.
index_path = f"{model_name}/model.safetensors.index.json"
with open(index_path, "r") as index_file:
    weight_map = json.load(index_file)["weight_map"]

# Build the model skeleton inside accelerate's init_empty_weights() context.
pretrained_kwargs = dict(
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    low_cpu_mem_usage=True,
)
with init_empty_weights():
    model = AutoModelForCausalLM.from_pretrained(model_name, **pretrained_kwargs)

# Freeze every parameter of the base model.
for _, weight in model.named_parameters():
    weight.requires_grad = False

model.train()
destruct_module_optimized(model)
memory_cleanup()

## Load layer weights

In [None]:
# Index of the decoder layer whose MoE block is distilled in this notebook.
layer_idx = 22

# Name fragment selecting this layer's weights, passed to load_weights below.
target_modules = [f".layers.{layer_idx}."]

# Allocate (uninitialised) storage for the layer on the target device, then
# fill it from the checkpoint.
model.model.layers[layer_idx].to_empty(device=device)
model = load_weights(model, model_name, weight_map, target_modules, device)

In [None]:
# Work on private copies of the layer's MoE block and its post-attention norm
# so the loaded model itself is left untouched.
source_layer = model.model.layers[layer_idx]
distilled_mlp = deepcopy(source_layer.mlp).to(device)
layer_norm = deepcopy(source_layer.post_attention_layernorm).to(device, dtype=torch.bfloat16)

# Replace every expert, in place, with the module returned by calibrated_dequant.
n_experts = len(distilled_mlp.experts)
for expert_idx in range(n_experts):
    distilled_mlp.experts[expert_idx] = calibrated_dequant(
        distilled_mlp.experts[expert_idx], layer_norm, path_config, layer_idx
    )

distilled_mlp.gate = distilled_mlp.gate.to(torch.bfloat16)

# destruct_module_optimized(model)
# memory_cleanup()

# distilled_mlp=halve_distilled_mlp(distilled_mlp, layer_norm, distillation_config, path_config, layer_idx, device)
# distilled_mlp=halve_distilled_mlp(distilled_mlp, layer_norm, distillation_config, path_config, layer_idx, device)

## Prepare model for fine-tuning and run distillation

In [None]:
from utils.grokfast import gradfilter_ema  # only referenced by the commented-out gradient filter below

# TensorBoard run directory: <distillation_logs>/distillat_<method>_<routed>a<active>/layer_<idx>
run_name = (
    f"distillat_{distillation_config.pruning_method}"
    f"_{distillation_config.target_routed_expert}a{distillation_config.target_active_expert}"
)
writer = SummaryWriter(log_dir=os.path.join(path_config.distillation_logs, run_name, f"layer_{layer_idx}"))

os.makedirs(path_config.moe_states, exist_ok=True)

# Notebook-level overrides of the values set in the configuration cell.
distillation_config.n_epochs=1
distillation_config.learning_rate=1e-4
distillation_config.end_factor=0.1
distillation_config.gradient_accumulation_steps=1

distilled_mlp, optimizer, scheduler, criterion = prepare_moe_for_distillation(distilled_mlp, distillation_config, path_config, layer_idx, device, dtype=torch.bfloat16)

# Every cached batch except the last `eval_batches` ones is used for training.
train_batches = len(os.listdir(os.path.join(path_config.expert_states, f"layer_{layer_idx}"))) - distillation_config.eval_batches

# Keep every copy of the "experts per token" setting in sync with the target.
if distilled_mlp.config.num_experts_per_tok != distillation_config.target_active_expert:
    print('updating active')
    distilled_mlp.config.num_experts_per_tok=distillation_config.target_active_expert
    distilled_mlp.num_experts_per_tok=distillation_config.target_active_expert

    distilled_mlp.gate.config.num_experts_per_tok=distillation_config.target_active_expert
    distilled_mlp.gate.top_k=distillation_config.target_active_expert

halve_every = 128  # every `halve_every` training steps, try to halve the expert count
grads = None  # state for the (currently disabled) gradfilter_ema call
eval_batches = distillation_config.eval_batches

patience = 2  # Number of epochs to wait for improvement
margin = 1e-4  # Minimum improvement required
best_loss = float('inf')
patience_counter = 0

# torch.amp.autocast takes a device *type* ("cuda"), not a device string with
# an index ("cuda:0"), so derive the type from the configured device.
# NOTE(review): no dtype is passed, so autocast uses its default for the device
# type, exactly as before — confirm this is intended for a bfloat16 module.
autocast_device_type = torch.device(device).type


def _distillation_loss(mlp, loss_fn, batch_idx):
    """Return the distillation loss of `mlp` on cached batch `batch_idx`.

    Loads `batch_<batch_idx>` from `path_config.expert_states` (layer input)
    and from `path_config.intermediate_states` (target), drops the first
    `skip_first_tokens` positions along dim 1, applies `layer_norm`, runs
    `mlp`, adds the un-normalised input back as a residual and returns
    `loss_fn(prediction, target)`.
    """
    skip = distillation_config.skip_first_tokens
    layer_dir = f"layer_{layer_idx}"
    batch_name = f"batch_{batch_idx}"

    hidden_states = load_quant(os.path.join(path_config.expert_states, layer_dir, batch_name)).to(device, dtype=torch.bfloat16)[:, skip:]
    outputs = load_quant(os.path.join(path_config.intermediate_states, layer_dir, batch_name)).to(device, dtype=torch.bfloat16)[:, skip:]

    residual = hidden_states
    pred = mlp(layer_norm(hidden_states)) + residual
    return loss_fn(pred, outputs)


# Training and evaluation loop; the writer is closed even if a step raises.
try:
    for epoch in range(distillation_config.n_epochs):
        distilled_mlp.train()
        # optimizer.train()
        progress_bar = tqdm(range(train_batches), desc=f"Calibrating merged expert, epoch {epoch}")

        for batch_idx in progress_bar:
            global_step = epoch * train_batches + batch_idx

            # Periodically merge the adapters, halve the experts and rebuild the
            # optimizer/scheduler until the target expert count is reached.
            if (global_step + 1) % halve_every == 0:
                if len(distilled_mlp.experts) > distillation_config.target_routed_expert:
                    print('halving')
                    distilled_mlp=merge_and_unload(distilled_mlp)

                    distilled_mlp=halve_distilled_mlp(distilled_mlp, layer_norm, distillation_config, path_config, layer_idx, device)
                    distilled_mlp, optimizer, scheduler, criterion = prepare_moe_for_distillation(distilled_mlp, distillation_config, path_config, layer_idx, device, dtype=torch.bfloat16)
                    # optimizer.train()

            with torch.amp.autocast(autocast_device_type):
                loss = _distillation_loss(distilled_mlp, criterion, batch_idx)

            loss.backward()
            if (global_step + 1) % distillation_config.gradient_accumulation_steps == 0:
                # grads = gradfilter_ema(model, grads=grads, alpha=0.95, lamb=5)
                optimizer.step()
                optimizer.zero_grad()

            # The scheduler advances once per batch, independent of accumulation.
            if scheduler is not None:
                scheduler.step()

            # Log the training loss (read it from the device only once).
            loss_value = loss.item()
            writer.add_scalar('Loss/train', loss_value, global_step)
            progress_bar.set_postfix(loss=loss_value)

        # Evaluation phase at the end of each epoch.
        # NOTE(review): the module is deliberately left in train() mode here, as
        # in the original cell; confirm whether eval() is safe for this MoE block.
        distilled_mlp.train()
        # optimizer.eval()
        eval_loss = 0

        with torch.no_grad():
            for batch_idx in range(train_batches, train_batches + eval_batches):
                eval_loss += _distillation_loss(distilled_mlp, criterion, batch_idx).item()

        eval_loss /= eval_batches
        writer.add_scalar('Loss/eval', eval_loss, epoch)
        print(f"Epoch {epoch + 1}/{distillation_config.n_epochs}, Evaluation Loss: {eval_loss}")

        # Early stopping on the evaluation loss.
        if best_loss - eval_loss > margin:
            best_loss = eval_loss
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            print(f"Early stopping triggered after epoch {epoch + 1}")
            break
finally:
    writer.close()

## Unmount adapter

In [None]:
# Scores noted by hand for each expert-merging variant that was tried
# (presumably the evaluation loss printed by the training cell — TODO confirm).
# Labels such as "slerp 0.8" are not valid Python identifiers, so the original
# `label = value` notes made this cell a SyntaxError; a dict keeps the same
# labels and values while letting the notebook run top to bottom.
merge_experiment_results = {
    "prune": 0.08564127795398235,
    "slerp 0.8": 0.09016684349626303,
    "slerp 0.5": 0.09877106221392751,
    "slerp 0.98": 0.084571722894907,
    "slerp 0.05": 0.08477236004546285,
    "sce": 0.0845571905374527,
}

In [None]:
# Rebind distilled_mlp to whatever merge_and_unload returns (presumably the
# module with its adapters folded into the base weights — confirm in
# utils.experts_merge_utils). This rebinding makes the cell non-idempotent:
# re-running it calls merge_and_unload on an already-processed module.
distilled_mlp=merge_and_unload(distilled_mlp)

In [None]:
# Display the module's repr as the cell output for a visual sanity check.
distilled_mlp

In [None]:
# `export_path` was never defined anywhere in the notebook, so this cell raised
# NameError on a fresh run. Build it under path_config.moe_states, mirroring
# the naming used for the TensorBoard log directory.
# NOTE(review): the directory layout and file name are assumptions — confirm
# against whatever code later loads these state dicts.
export_dir = os.path.join(
    path_config.moe_states,
    f"distillat_{distillation_config.pruning_method}"
    f"_{distillation_config.target_routed_expert}a{distillation_config.target_active_expert}",
)
os.makedirs(export_dir, exist_ok=True)
export_path = os.path.join(export_dir, f"layer_{layer_idx}")

torch.save(distilled_mlp.state_dict(), export_path)