In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

In [2]:
from torch.utils.data import DataLoader, Dataset
from datasets import load_dataset
from tqdm.auto import tqdm
import torch
import json
import os

from transformers import (
    AutoTokenizer,
    AutoConfig,
    AutoModelForCausalLM,
    BitsAndBytesConfig,
)

from peft import LoraConfig, get_peft_model


from utils.ademamix import AdEMAMix
from utils.config_utils import GenerationParams, PathConfig, DistillationParams
from utils.expert_merge_utils import calibrated_merge_experts, dequantize_GEMM
from utils.experts_gate_utils import create_gate
from utils.torch_utils import (
    save_quant,
    load_quant,
    destruct_module_optimized,
    memory_cleanup,
    get_nonreasoning_dataset,
    load_weight,
    rsetattr,
    rgetattr,
    load_weights,
    rhasattr,
    count_parameters
)

In [3]:

class HealingDataset(Dataset):
    """Map-style dataset that tokenizes one text sample per access.

    Args:
        data: Indexable, sized collection of text samples.
        tokenizer: Callable invoked as
            ``tokenizer(text, max_length=..., padding="max_length",
            truncation=True, return_tensors="pt")``. It must return a mapping
            with ``"input_ids"`` and ``"attention_mask"`` tensors that carry a
            leading batch axis of size 1.
        max_length: Length every sample is padded / truncated to.
    """

    def __init__(self, data, tokenizer, max_length=512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of text samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Tokenize sample ``idx`` and return its ids and attention mask.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"``, each with the
            leading batch axis removed.
        """
        text = self.data[idx]
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        # squeeze(0) removes only the batch axis. A bare squeeze() would also
        # collapse the sequence axis when max_length == 1 and hand 0-d tensors
        # to the collate function.
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
        }


def load_and_prepare_data(tokenizer, batch_size=8, max_length=512, num_workers=os.cpu_count(), train_sample_limit=None, val_sample_limit=None):
    """Build the training DataLoader from the dolphin-r1 "nonreasoning" split.

    Rows are kept when ``overall_quality == 5`` or ``score >= 0.18``. The
    ``messages`` column of the kept rows is rendered to text with the
    tokenizer's chat template and wrapped in a ``HealingDataset``.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and usable by
            ``HealingDataset``.
        batch_size: Batch size of the returned loader.
        max_length: Pad / truncate length forwarded to ``HealingDataset``.
        num_workers: DataLoader worker count; ``None`` is treated as 0.
        train_sample_limit: If given, keep at most this many filtered rows
            (taken from the start of the filtered dataset).
        val_sample_limit: Accepted for interface compatibility; currently
            unused because no validation loader is built.

    Returns:
        ``(train_loader, None)``; the second element is a placeholder for a
        validation loader.
    """
    dataset = load_dataset(
        "cognitivecomputations/dolphin-r1", "nonreasoning", cache_dir="../dolphin-r1"
    )["train"]

    def filter_function(example):
        quality = example["overall_quality"]
        score = example["score"]
        return (quality is not None and quality == 5) or (
            score is not None and score >= 0.18
        )

    dataset = dataset.filter(filter_function)

    # Apply the sample limit if provided. Clamp it to the filtered size so a
    # limit larger than the dataset does not raise on out-of-range indices.
    if train_sample_limit is not None:
        keep = max(0, min(train_sample_limit, len(dataset)))
        train_dataset = dataset.select(range(keep))  # .select avoids materializing rows
    else:
        train_dataset = dataset

    train_dataset = train_dataset["messages"]

    train_dataset = [
        tokenizer.apply_chat_template(elt, tokenize=False, add_generation_prompt=False)
        for elt in tqdm(train_dataset, desc="Preparing dataset train")
    ]
    train_dataset = HealingDataset(
        train_dataset, tokenizer, max_length=max_length
    )

    # os.cpu_count() (the default) may return None, which DataLoader rejects.
    if num_workers is None:
        num_workers = 0

    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True
    )
    return train_loader, None


## Configs

In [4]:
# Trade some float32 matmul precision for speed.
torch.set_float32_matmul_precision("medium")

# Directory of the base checkpoint; the derived "unhealed" checkpoint name is built from it below.
model_name = "../deepseek_v2_lite_awq"

# Filesystem locations used by the project helpers.
path_config = PathConfig(
    model_name=model_name,
    intermediate_states="../data/intermediate_states",
    expert_states="../data/expert_states",
    expert_activations="../data/expert_activations",
    distillation_logs="../distillation_logs",
    moe_states="../moe_states",
)

# Distillation hyper-parameters; the two target_* values also determine the
# checkpoint name and the LoRA targets used further down.
distillation_config = DistillationParams(
    n_epochs=10,
    target_routed_expert=4,
    target_active_expert=2,
    eval_batches=16,
    gradient_accumulation_steps=1,
    learning_rate=6e-4,
    end_factor=0.05,
)

# "<model_name>_<routed>a<active>_unhealed"
unhealed_name = (
    f"{model_name}"
    f"_{distillation_config.target_routed_expert}"
    f"a{distillation_config.target_active_expert}"
    "_unhealed"
)

tokenizer = AutoTokenizer.from_pretrained(unhealed_name, trust_remote_code=True)

In [5]:
# config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)

In [6]:
# distillation_config

## Base Model

In [7]:
# bitsandbytes 4-bit settings passed to from_pretrained below.
quant_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="fp4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_quant_storage=torch.bfloat16,
)

# Tensor-name -> shard-file mapping of the *base* checkpoint.
index_path = model_name + "/model.safetensors.index.json"
with open(index_path, "r") as index_file:
    weight_map = json.load(index_file)["weight_map"]

# Load the pruned ("unhealed") checkpoint onto the first GPU.
model = AutoModelForCausalLM.from_pretrained(
    unhealed_name,
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="cuda:0",
    trust_remote_code=True,
)

# LoRA adapters go on the three projections of every remaining routed expert,
# in the order gate_proj, up_proj, down_proj per expert.
target_modules = [
    f"mlp.experts.{expert_idx}.{projection}"
    for expert_idx in range(distillation_config.target_routed_expert)
    for projection in ("gate_proj", "up_proj", "down_proj")
]

peft_config = LoraConfig(
    # use_dora=True,
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=target_modules,
)

model = get_peft_model(model, peft_config)

# Also train every parameter whose name contains "gate.weight" (presumably
# the MoE router gates -- confirm). "gate_proj.weight" does not contain that
# substring, so the expert projections themselves stay frozen.
for param_name, param in model.named_parameters():
    if "gate.weight" in param_name:
        param.requires_grad = True

In [8]:
# Report frozen / non-frozen / total parameter counts for the PEFT-wrapped model.
count_parameters(model)

Parameter Type       Count     
Frozen Parameters    866,643,456
Non-Frozen Parameters 8,839,168 
Total Parameters     875,482,624


In [9]:
from torch.utils.tensorboard import SummaryWriter
import torch
from utils.ademamix import AdEMAMix
from torch.optim.lr_scheduler import _LRScheduler
import math

class WarmupCosineAnnealingLR(_LRScheduler):
    """Linear warmup followed by cosine decay down to an absolute LR floor.

    For ``step < warmup_steps`` each learning rate is
    ``base_lr * step / warmup_steps``. Afterwards it follows half a cosine
    from ``base_lr`` down to ``min_lr``, reaching ``min_lr`` at
    ``total_steps`` and staying there for any later step.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps spent in linear warmup.
        total_steps: Scheduler step at which the floor is reached.
        min_lr: Absolute learning-rate floor (same units as the optimizer's
            ``lr``), not a fraction of the base rate.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, min_lr=0.0):
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch=-1)

    def get_lr(self):
        """Return the learning rate of every param group for ``last_epoch``."""
        step = self.last_epoch

        if step < self.warmup_steps:
            # Linear warmup phase (warmup_steps > 0 is guaranteed here).
            scale = step / self.warmup_steps
            return [base_lr * scale for base_lr in self.base_lrs]

        # Cosine annealing phase.
        decay_span = self.total_steps - self.warmup_steps
        if decay_span <= 0:
            # Nothing to anneal over: avoid dividing by zero, sit at the floor.
            progress = 1.0
        else:
            # Clamp so the LR stays at the floor instead of climbing back up
            # the cosine when stepping past total_steps.
            progress = min(1.0, (step - self.warmup_steps) / decay_span)

        cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
        # Interpolate between base_lr and the absolute floor min_lr.
        return [
            self.min_lr + (base_lr - self.min_lr) * cosine_decay
            for base_lr in self.base_lrs
        ]
            
# `model`, `tokenizer` and `load_and_prepare_data` come from earlier cells;
# `WarmupCosineAnnealingLR` is defined at the top of this cell.

total_steps = 32  # NOTE(review): not referenced anywhere in this cell
max_length = 128

num_epochs = 1
num_sample = 16000  # NOTE(review): not referenced anywhere in this cell

batch_size = 161
gradient_accumulation_steps = 2
log_interval = gradient_accumulation_steps  # Log once per optimizer update

lr = 8e-4
# Initialize the SummaryWriter
writer = SummaryWriter(log_dir='runs/experiment_1')

train_loader, val_loader = load_and_prepare_data(
    tokenizer, batch_size=batch_size, max_length=max_length,
    train_sample_limit=None, val_sample_limit=None
)

optimizer = AdEMAMix(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=lr,
    betas=(0.9, 0.999, 0.9999),
    alpha=8.0  # batch size is small, so alpha is increased to smooth the gradient
)

# One scheduler step is taken per optimizer update, so the schedule horizon is
# the number of updates over *all* epochs (a trailing partial accumulation
# window still produces one update).
batches_per_epoch = len(train_loader)
updates_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)

scheduler = WarmupCosineAnnealingLR(
    optimizer,
    warmup_steps=0,
    total_steps=updates_per_epoch * num_epochs,
    min_lr=lr/50
)

# model=torch.compile(model)
model.train()  # Ensure the model is in training mode
optimizer.zero_grad()  # Drop gradients left over from a previous run of this cell

try:
    for epoch in range(num_epochs):
        for i, encoding in enumerate(tqdm(train_loader)):
            input_ids = encoding['input_ids'].to("cuda:0")
            attention_mask = encoding['attention_mask'].to("cuda:0")

            # Samples are padded to max_length. Label -100 is ignored by the
            # causal-LM loss, so padding positions do not contribute to it.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            # Forward pass
            output = model(
                input_ids=input_ids,
                labels=labels,
                attention_mask=attention_mask,
                use_cache=False,
                output_attentions=False,
                output_hidden_states=False
            )

            # Scale the loss so the accumulated gradient is the mean over the
            # accumulation window rather than the sum.
            loss = output.loss
            (loss / gradient_accumulation_steps).backward()

            # Update model parameters and learning rate; also flush a trailing
            # partial window so its gradients do not leak into the next epoch.
            is_last_batch = (i + 1) == batches_per_epoch
            if (i + 1) % gradient_accumulation_steps == 0 or is_last_batch:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Log the (unscaled) loss and learning rate to TensorBoard
            if (i + 1) % log_interval == 0:
                global_step = epoch * batches_per_epoch + i
                writer.add_scalar('Loss/train', loss.item(), global_step)
                writer.add_scalar('Learning Rate', scheduler.get_last_lr()[0], global_step)
finally:
    # Close the writer even when training is interrupted or fails
    writer.close()

Preparing dataset train:   0%|          | 0/26509 [00:00<?, ?it/s]

  0%|          | 0/415 [00:01<?, ?it/s]

In [None]:
# Show the model's repr as the cell output.
model

In [None]:
# Load the state dict stored inside the unhealed checkpoint directory onto the
# GPU. assign=True makes the module adopt the loaded tensors rather than
# copying into the existing parameters.
# NOTE(review): torch.load unpickles the file -- only use with trusted checkpoints.
state_dict_path = os.path.join(unhealed_name, 'state_dict.pt')
model.load_state_dict(
    torch.load(state_dict_path, map_location="cuda:0"),
    assign=True,
)

In [13]:
from transformers import TextStreamer

# Stream generated tokens to the cell output as they are produced, without
# echoing the prompt.
streamer = TextStreamer(tokenizer, skip_prompt=True)

model.generation_config.pad_token_id = model.generation_config.eos_token_id

# The training cell leaves the model in train() mode; switch to eval() so the
# LoRA dropout (lora_dropout=0.1) is disabled while sampling.
model.eval()

messages = [{"role": "user", "content": "Write a poem about guacamole. Structure it as a jsx class"}]
# Render the chat prompt as text first, then tokenize it (same two-step path
# the training data goes through).
prompt_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_tensor = tokenizer(prompt_text, return_tensors="pt").to("cuda:0")

out = model.generate(**input_tensor, streamer=streamer, temperature=0.5, repetition_penalty=1.02, max_new_tokens=512, do_sample=True)

 In the Mexican culture, the guacamole is a delicious and colorful treat that combines the ingredients of sweet, sweet, and sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet sweet 

KeyboardInterrupt: 