In [1]:
%load_ext autoreload
%autoreload 2

import sys
# Add the parent directory to the import search path — presumably so the
# `utils.*` imports in the following cells resolve (confirm the repo layout).
sys.path.append('..')

In [2]:
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from tqdm.auto import tqdm
import bitsandbytes as bnb
import torch
import json
import os

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

from peft import LoraConfig, get_peft_model


from utils.ademamix import AdEMAMix
from utils.config_utils import GenerationParams, PathConfig, DistillationParams
from utils.adapters import DoRAAdapter
from utils.torch_utils import (
    save_quant,
    load_quant,
    destruct_module_optimized,
    memory_cleanup,
    get_nonreasoning_dataset,
    load_weight,
    rsetattr,
    rgetattr,
    load_weights,
    rhasattr,
    count_parameters
)

In [3]:

class HealingDataset(Dataset):
    """Map-style dataset that tokenizes one text sample per access.

    Args:
        data: Indexable, sized collection of samples; each item is handed to
            ``tokenizer`` unchanged.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding="max_length", truncation=True, return_tensors="pt")`` that
            returns a mapping with ``"input_ids"`` and ``"attention_mask"``.
        max_length: Length each sample is padded / truncated to.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its ids and attention mask.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"`` tensors, each
            with the leading (batch) axis removed.
        """
        text = self.data[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # Remove only the leading axis (assumes the tokenizer returns tensors
        # shaped (1, seq_len) for a single text). A bare squeeze() would also
        # drop the sequence axis when it has length 1, producing 0-d tensors
        # that no longer collate into (batch, seq_len).
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
        }


def load_and_prepare_data(tokenizer, batch_size=8, max_length=512, num_workers=os.cpu_count(), train_sample_limit=None, val_sample_limit=None):
    """Build the training DataLoader from the dolphin-r1 "nonreasoning" split.

    Loads the ``train`` split of ``cognitivecomputations/dolphin-r1``
    (config ``nonreasoning``, cached under ``../dolphin-r1``), keeps rows whose
    ``overall_quality`` equals 5 or whose ``score`` is at least 0.18, renders
    each row's ``messages`` with the tokenizer's chat template and wraps the
    resulting strings in a ``HealingDataset``.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and used by
            ``HealingDataset`` for encoding.
        batch_size: Batch size of the returned loader.
        max_length: Padding / truncation length passed to ``HealingDataset``.
        num_workers: DataLoader worker count. ``None`` (which
            ``os.cpu_count()`` may return) is treated as 0.
        train_sample_limit: If given, keep at most this many filtered rows
            (taken from the start of the filtered dataset).
        val_sample_limit: Accepted for interface compatibility; currently
            unused because no validation loader is built.

    Returns:
        ``(train_loader, None)`` — the second element is always ``None``.
    """
    # os.cpu_count() can return None, which DataLoader does not accept.
    if num_workers is None:
        num_workers = 0

    dataset = load_dataset(
        "cognitivecomputations/dolphin-r1", "nonreasoning", cache_dir="../dolphin-r1"
    )["train"]

    def filter_function(example):
        # Keep a row when either quality signal is present and high enough.
        if example["overall_quality"] is not None and example["overall_quality"] == 5:
            return True
        if example["score"] is not None and example["score"] >= 0.18:
            return True
        return False

    dataset = dataset.filter(filter_function)

    # Apply the sample limit if provided, clamped to what is actually
    # available so an oversized limit does not index past the end.
    if train_sample_limit is not None:
        n_selected = min(train_sample_limit, len(dataset))
        train_dataset = dataset.select(range(n_selected))  # Use .select for efficiency
    else:
        train_dataset = dataset

    train_dataset = train_dataset["messages"]

    train_dataset = [
        tokenizer.apply_chat_template(elt, tokenize=False, add_generation_prompt=False)
        for elt in tqdm(train_dataset, desc="Preparing dataset train")
    ]
    train_dataset = HealingDataset(
        train_dataset, tokenizer, max_length=max_length
    )
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )
    return train_loader, None


## Configs

In [4]:
# Request the 'medium' float32 matmul precision mode from PyTorch.
torch.set_float32_matmul_precision('medium')

# ---- Run settings --------------------------------------------------------
device = "cuda:0"
model_name = "../deepseek_coder_v2_lite_instruct_awq"
n_epochs = 1
start_layer = 1
end_layer = 26
target_routed_expert = 8
target_active_expert = 4
dora_rank = 16
# Integer switch converted to a real boolean on the next line.
calibrate_merge = 1
calibrate_merge = (calibrate_merge == 1)
pruning_method = "act_cl"

# ---- Project paths -------------------------------------------------------
path_config = PathConfig(
    model_name=model_name,
    intermediate_states="../data/intermediate_states",
    expert_states="../data/expert_states",
    expert_activations="../data/expert_activations",
    distillation_logs="../distillation_logs",
    moe_states="../moe_states",
)

# ---- Distillation hyper-parameters ---------------------------------------
distillation_config = DistillationParams(
    n_epochs=n_epochs,
    target_routed_expert=target_routed_expert,
    target_active_expert=target_active_expert,
    eval_batches=16,
    gradient_accumulation_steps=4,
    learning_rate=3e-4,
    end_factor=0.2,
    calibrate_merge=calibrate_merge,
    skip_first_tokens=0,  # useful to avoid tuning on early tokens that carry less information
    pruning_method=pruning_method,  # topk , act_cl, state_cl
    dora_rank=dora_rank,
)

# ---- Checkpoint names ----------------------------------------------------
# Suffix encoding the pruning settings; appended to the base model path.
_run_suffix = (
    f"_{distillation_config.pruning_method}"
    f"_{distillation_config.target_routed_expert}a{distillation_config.target_active_expert}"
    f"_{distillation_config.calibrate_merge}"
    f"_{distillation_config.n_epochs}"
    "_unhealed"
)
unhealed_name = (model_name + _run_suffix).replace('_awq', '')

# Output name: last path component of the unhealed checkpoint, marker removed.
healed_name = unhealed_name.rsplit('/', 1)[-1].replace('_unhealed', '')

# The tokenizer is read from the unhealed checkpoint directory.
tokenizer = AutoTokenizer.from_pretrained(unhealed_name, trust_remote_code=True)

## Base Model

In [5]:
# bitsandbytes 4-bit options: fp4 quant type, double quantization, bfloat16
# for both compute and storage.
quant_config = BitsAndBytesConfig(
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="fp4",
    bnb_4bit_quant_storage=torch.bfloat16,
)

# Tensor-name -> shard mapping read from the base model's safetensors index.
index_path = model_name + "/model.safetensors.index.json"
with open(index_path, "r") as index_file:
    weight_map = json.load(index_file)['weight_map']

# Load the unhealed checkpoint onto the configured device.
model = AutoModelForCausalLM.from_pretrained(
    unhealed_name,
    device_map=device,
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
    quantization_config=quant_config,
)

# LoRA targets: the three projections of every remaining routed expert.
target_modules = []
for expert_idx in range(distillation_config.target_routed_expert):
    target_modules.extend(
        [
            f"mlp.experts.{expert_idx}.gate_proj",
            f"mlp.experts.{expert_idx}.up_proj",
            f"mlp.experts.{expert_idx}.down_proj",
        ]
    )

peft_config = LoraConfig(
    # use_dora=True,
    target_modules=target_modules,
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
)

model = get_peft_model(model, peft_config)

# In addition to the adapters, make every parameter whose name contains
# 'gate.weight' trainable.
for param_name, param in model.named_parameters():
    if 'gate.weight' in param_name:
        param.requires_grad = True

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
# Report frozen / non-frozen / total parameter counts of the PEFT-wrapped model.
count_parameters(model)

Parameter Type       Count     
Frozen Parameters    1,091,563,008
Non-Frozen Parameters 17,678,336
Total Parameters     1,109,241,344


In [None]:
from torch.utils.tensorboard import SummaryWriter
import torch
from utils.ademamix import AdEMAMix
from torch.optim.lr_scheduler import _LRScheduler
import math

class WarmupCosineAnnealingLR(_LRScheduler):
    """Linear warmup followed by cosine annealing down to ``min_lr``.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps over which the learning rate
            grows linearly from 0 to each group's base learning rate.
        total_steps: Scheduler step at which the cosine phase reaches its
            floor. Later steps stay at the floor.
        min_lr: Absolute learning-rate floor reached at ``total_steps``. For
            a parameter group whose base learning rate is already below
            ``min_lr`` the floor is that base rate, so the rate never rises.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0.0):
        # These must be set before the base-class constructor runs, because it
        # performs an initial step that calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super(WarmupCosineAnnealingLR, self).__init__(optimizer, last_epoch=-1)

    def get_lr(self):
        """Return the learning rate of every parameter group for the current step."""
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup phase (warmup_steps > 0 is guaranteed here).
            warmup_scale = step / self.warmup_steps
            return [base_lr * warmup_scale for base_lr in self.base_lrs]

        # Cosine annealing phase. Progress is clamped to [0, 1] so stepping
        # past total_steps keeps the floor instead of climbing back up, and a
        # zero-length decay window does not divide by zero.
        decay_steps = self.total_steps - self.warmup_steps
        if decay_steps <= 0:
            progress = 1.0
        else:
            progress = min(1.0, (step - self.warmup_steps) / decay_steps)
        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))

        # min_lr is an absolute rate: interpolate between base_lr and the
        # floor rather than scaling base_lr by min_lr.
        lrs = []
        for base_lr in self.base_lrs:
            floor = min(self.min_lr, base_lr)
            lrs.append(floor + (base_lr - floor) * cosine_decay)
        return lrs
            
# Assuming model, tokenizer, and load_and_prepare_data are defined elsewhere

# ---- Healing (fine-tuning) run -------------------------------------------
# Relies on `model`, `tokenizer`, `device`, `healed_name` and
# `load_and_prepare_data` from earlier cells.

max_length = 512  # padding / truncation length of every training sample

num_epochs = 1
num_sample = 16000  # upper bound on the number of training samples

batch_size = 4
gradient_accumulation_steps = 2
log_interval = gradient_accumulation_steps  # Log once per full accumulation group

lr = 3e-4
# Initialize the SummaryWriter
writer = SummaryWriter(log_dir=f'runs/{healed_name}')

train_loader, val_loader = load_and_prepare_data(
    tokenizer, batch_size=batch_size, max_length=max_length,
    train_sample_limit=num_sample, val_sample_limit=None
)

# Number of optimizer updates over the whole run. ceil() because a trailing,
# incomplete accumulation group is also applied (see the loop below); the
# epoch count is included so the schedule spans every epoch.
steps_per_epoch = math.ceil(len(train_loader) / gradient_accumulation_steps)
total_steps = max(1, steps_per_epoch * num_epochs)

optimizer = AdEMAMix(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=lr,
    betas=(0.9, 0.999, 0.9999),
    alpha=8.0  # batch size is small, so alpha is increased to smooth the gradient
)

scheduler = WarmupCosineAnnealingLR(
    optimizer,
    warmup_steps=0,
    total_steps=total_steps,
    min_lr=lr/50
)

# model=torch.compile(model)
model.train()  # Ensure the model is in training mode
optimizer.zero_grad()  # start from clean gradients, e.g. when the cell is re-run

try:
    for epoch in range(num_epochs):
        num_batches = len(train_loader)
        for i, encoding in enumerate(tqdm(train_loader)):
            input_ids = encoding['input_ids'].to(device)
            attention_mask = encoding['attention_mask'].to(device)

            # Causal-LM labels are the inputs themselves, except that padded
            # positions are set to -100 so the loss ignores them. Samples are
            # padded to max_length, so without this the model would also be
            # trained to predict padding.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            # Forward pass
            output = model(
                input_ids=input_ids,
                labels=labels,
                attention_mask=attention_mask,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False
            )

            # Scale the loss so the accumulated gradient is the mean over the
            # micro-batches of a group rather than their sum.
            loss = output.loss
            (loss / gradient_accumulation_steps).backward()

            # Update parameters and learning rate once per accumulation group,
            # and also on the last batch so leftover gradients are not dropped.
            is_last_batch = (i + 1) == num_batches
            if (i + 1) % gradient_accumulation_steps == 0 or is_last_batch:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Log the (unscaled) loss and the learning rate to TensorBoard
            if (i + 1) % log_interval == 0:
                global_step = epoch * num_batches + i
                writer.add_scalar('Loss/train', loss.item(), global_step)
                writer.add_scalar('Learning Rate', scheduler.get_last_lr()[0], global_step)
finally:
    # Close the writer even if training is interrupted, so buffered events
    # are flushed to disk.
    writer.close()

Preparing dataset train:   0%|          | 0/16000 [00:00<?, ?it/s]

  0%|          | 0/4000 [00:02<?, ?it/s]

In [8]:
# Persist the trained model under the relative directory named by `healed_name`.
# NOTE(review): `model` is PEFT-wrapped here, so this presumably writes only
# the adapter weights rather than a full checkpoint — confirm before reuse.
model.save_pretrained(healed_name)

In [9]:
from transformers import TextStreamer

# ---- Qualitative check: stream one sampled completion --------------------
# The model was left in training mode by the training cell; switch to eval
# mode so dropout (including the LoRA dropout) is disabled during generation.
model.eval()

streamer = TextStreamer(tokenizer, skip_prompt=True)
# Use the EOS token id as the padding id for generation.
model.generation_config.pad_token_id = model.generation_config.eos_token_id

prompt="""Given the following evidences:
- Henri IV was a famous king of france
- Kings love to hunt and to joust
- Hunting horse are always brown or camo
- Camo is a pattern of color used to hise in plain sight

Answer the following question:
- What was the color of Henri IV white horse?"""

messages = [{"role": "user", "content": prompt}]
# Render the chat template to text first, then tokenize it onto the device.
prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_tensor = tokenizer(prompt_text, return_tensors="pt").to(device)

out = model.generate(**input_tensor, streamer=streamer, temperature=0.3, repetition_penalty=1.01, max_new_tokens=512, do_sample=True)

The `seen_tokens` attribute is deprecated and will be removed in v4.41. Use the `cache_position` model input instead.
`get_max_cache()` is deprecated for all Cache classes. Use `get_max_cache_shape()` instead. Calling `get_max_cache()` will raise error from v4.48


 The color of Henri IV white horse was a brown, but it was not the perfect color in the eyes. However, the color was a mix of brown, red, and green, and it was slightly more than brown. The color of the horse was more than brown, but it was also slightly more than red, which is known as "brown" or "red."

The color of the horse was also slightly more than green, which is known as "green" or "green." The color of the horse was slightly more than brown, but it was also slightly more than red. The color of the horse was also slightly more than green, which is known as "green" or "green."

The color of the horse was also slightly more than brown, but it was also slightly more than red. The color of the horse was also slightly more than green, which is known as "green" or "green."

The color of the horse was also slightly more than brown, but it was also slightly more than red. The color of the horse was also slightly more than green, which is known as "green" or "green."

The color of the 

distil_output

Guacamole is a rich and flavorful creation of the Mexican cuisine. It's a mix of spices, fruits, and herbs. The perfect cocktail for the drink, its signature is: "Guacamole is a mouth filled with a warm palette of flavor. It's a mix of spices, fruits, and herbs, like avocado, guilla, sour, guel, guam, guz, and guza. The flavors of this dish are intense, rich, vibrant, and sometimes you can cook with something that will you make it a more special, unforgettable, memorable, or unforgettable. The perfect cocktail is paired with a mix of citrus, fruity, sour, and savory, with the essence of spicy, smoky, and earthy flavors. It's a mouth filled with a warm palette of flavor, the taste is unique to create it as an authentic dish.
I've got a little thing that I can cook with this dish. The ingredients are: a combination of tomatoes, potatoes, and carrots, sour, sour, sweet, and guol, gule, guza, guam, guz, and guza, or guel. It's a 

nodist_corr

 The poem contains a poetic reflection of an experience of experimentation, experimentation, and experimentation with an mundane and mundane life of life. Guacamole is a poem that describes a world of culture,