In [1]:
%load_ext autoreload
%autoreload 2

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from modeling_deepseek import DeepseekV3ForCausalLM
from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from datasets import load_dataset
from memory_utils import count_parameters, memory_cleanup
from transformers import BitsAndBytesConfig
from ademamix import AdEMAMix
from tqdm.auto import tqdm
import json

from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter 

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import CosineAnnealingLR
from bitsandbytes.optim import AdEMAMix8bit

from liger_kernel.transformers import apply_liger_kernel_to_llama

apply_liger_kernel_to_llama()

# from Distiller import count_parameters
torch.set_float32_matmul_precision('medium')

n_experts=4
n_active_experts=1

model_name = f"DeepSeek-V3-{n_experts}@{n_active_experts}-unhealed"

## Distribution strategy, with base weights on one gpu, and experts on the other

In [2]:
with open(f"{model_name}/model.safetensors.index.json") as f:
    weights_map=json.load(f)['weight_map']

device_map={}

for elt in weights_map:
    if ".layers." in elt:
        device_map[elt]="cuda:1"
    else:
        device_map[elt]="cuda:0"

In [3]:
bnb_config = BitsAndBytesConfig(
    # load_in_8bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='fp4',
    bnb_4bit_quant_storage=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

model = DeepseekV3ForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map=device_map, ## This should distribute automatically on all cpu
    quantization_config=bnb_config,
)

if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token



model.train()
count_parameters(model)
memory_cleanup()

for name, parameter in model.named_parameters():
    parameter.trainable=False

The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.
DeepseekV3ForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.


Loading checkpoint shards:   0%|          | 0/11 [00:00<?, ?it/s]

Parameter Type       Count     
Frozen Parameters    6,343,540,736
Non-Frozen Parameters 1,856,027,880
Total Parameters     8,199,568,616


## Pushing Linear layer as 4bit

## Add lora layer on top of the experts and gate, freeze everything else

In [4]:
target_modules=[]

## Adapt only non shared experts
for i in range(n_experts):
    target_modules.append(f"mlp.experts.{i}.gate_proj")
    target_modules.append(f"mlp.experts.{i}.up_proj")
    target_modules.append(f"mlp.experts.{i}.down_proj")

target_modules.append('mlp.gate.weight') ## Add all gate at once

lora_config = LoraConfig(
    # use_dora=True,
    r=16,  # Rank of the low-rank matrices
    lora_alpha=16,  # Scaling factor
    target_modules=target_modules,
    lora_dropout=0.1,  # Dropout rate
    bias="none",  # Whether to add bias
    task_type="CAUSAL_LM",  # Task type (Causal Language Modeling),
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)
count_parameters(model)

Parameter Type       Count     
Frozen Parameters    8,199,568,616
Non-Frozen Parameters 102,629,376
Total Parameters     8,302,197,992


## Loading the dataset for post training

In [5]:
n_train=65536
# n_train=4096
n_val=256

dolphin_r1 = load_dataset(
    'cognitivecomputations/dolphin-r1',
    "nonreasoning",
    split=f"train[:{n_train+n_val}]",
    cache_dir="../dolphin-r1"
)

max_length=64

train_dataset = dolphin_r1.select_columns(['messages']).select(list(range(n_train)))
val_dataset = dolphin_r1.select_columns(['messages']).select(list(range(n_train, n_train+n_val)))

def tokenize_function(examples):
    formatted = tokenizer.apply_chat_template(examples['messages'], tokenize=False,add_generation_prompt=False)
    data = tokenizer(formatted, truncation=True, max_length=max_length, padding=True)
    return data

train_dataset = train_dataset.map(tokenize_function, batched=True, remove_columns=['messages']).with_format("torch")
val_dataset = val_dataset.map(tokenize_function, batched=True, remove_columns=['messages']).with_format("torch")

Map:   0%|          | 0/256 [00:00<?, ? examples/s]

## Set training hyperparameters

In [6]:


num_train_epochs = 1

per_device_train_batch_size = 2
per_device_eval_batch_size = 2

gradient_accumulation_steps = 8 # Added gradient accumulation steps

learning_rate = 1e-4

weight_decay = 0.01
logging_steps = 5
seed = 3407

output_dir = "outputs"

In [7]:
# Format datasets
num_epochs=1


# model=torch.compile(model)
train_dataset.set_format("torch", columns=["input_ids", "attention_mask"])
val_dataset.set_format("torch", columns=["input_ids", "attention_mask"])

# Create dataloaders
train_dataloader = DataLoader(train_dataset, batch_size=per_device_train_batch_size, shuffle=True)

# Set up model and move to device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Initialize optimizer
optimizer = AdEMAMix8bit(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

# Initialize scheduler
num_training_steps = len(train_dataloader) * num_epochs // gradient_accumulation_steps # Adjusted num_training_steps
scheduler = CosineAnnealingLR(optimizer, T_max=num_training_steps, eta_min=1e-6)

# Initialize TensorBoard writer
writer = SummaryWriter(f'runs/{model_name}')

# Training loop
model.train()
global_step = 0


for epoch in range(num_epochs):
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}")
    running_loss = 0.0

    optimizer.zero_grad() # Initialize gradients to zero at the beginning of each accumulation step

    for batch_idx, batch in enumerate(progress_bar):
        # Move batch to device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=input_ids
        )
        loss = outputs.loss
        loss = loss / gradient_accumulation_steps # Normalize loss for gradient accumulation
        loss.backward()

        # Accumulate gradients and step optimizer every gradient_accumulation_steps
        if (batch_idx + 1) % gradient_accumulation_steps == 0:
            optimizer.step()
            scheduler.step() # Step the scheduler after optimizer step
            optimizer.zero_grad() # Reset gradients after optimizer step


        # Update progress bar
        running_loss += loss.item() * gradient_accumulation_steps # Revert loss normalization for correct average loss
        avg_loss = running_loss / (batch_idx + 1)
        current_lr = scheduler.get_last_lr()[0]

        progress_bar.set_postfix({
            'loss': f'{avg_loss:.4f}',
            'lr': f'{current_lr:.2e}'
        })

        # Log metrics to TensorBoard
        writer.add_scalar('Training/Loss', loss.item() * gradient_accumulation_steps, global_step) # Revert loss normalization for logging
        writer.add_scalar('Training/Learning_Rate', current_lr, global_step)

        # Log every 10 steps (after accumulation steps)
        if batch_idx % 10 == 0:
            print(f'Step {batch_idx}, Loss: {loss.item() * gradient_accumulation_steps:.4f}, LR: {current_lr:.2e}') # Revert loss normalization for printing

        global_step += 1

    # Log epoch-level metrics
    writer.add_scalar('Training/Epoch_Loss', avg_loss, epoch)

# Close TensorBoard writer
writer.close()

model.save_pretrained(f'{model_name}-healing-lora')

Epoch 1:   0%|                                                                                                                                     | 1/32768 [00:01<10:18:12,  1.13s/it, loss=9.8752, lr=1.00e-04]

Step 0, Loss: 9.8752, LR: 1.00e-04


Epoch 1:   0%|                                                                                                                                     | 11/32768 [00:09<8:03:38,  1.13it/s, loss=9.7584, lr=1.00e-04]

Step 10, Loss: 10.2058, LR: 1.00e-04


Epoch 1:   0%|                                                                                                                                     | 21/32768 [00:18<7:52:59,  1.15it/s, loss=9.7222, lr=1.00e-04]

Step 20, Loss: 9.5637, LR: 1.00e-04


Epoch 1:   0%|▏                                                                                                                                    | 31/32768 [00:27<8:17:30,  1.10it/s, loss=9.3383, lr=1.00e-04]

Step 30, Loss: 7.3174, LR: 1.00e-04


Epoch 1:   0%|▏                                                                                                                                    | 41/32768 [00:36<8:03:35,  1.13it/s, loss=9.0567, lr=1.00e-04]

Step 40, Loss: 7.9860, LR: 1.00e-04


Epoch 1:   0%|▏                                                                                                                                    | 51/32768 [00:45<7:58:03,  1.14it/s, loss=8.8257, lr=1.00e-04]

Step 50, Loss: 8.4123, LR: 1.00e-04


Epoch 1:   0%|▏                                                                                                                                    | 61/32768 [00:54<8:01:18,  1.13it/s, loss=8.6373, lr=1.00e-04]

Step 60, Loss: 7.5023, LR: 1.00e-04


Epoch 1:   0%|▎                                                                                                                                    | 71/32768 [01:02<7:50:23,  1.16it/s, loss=8.5579, lr=1.00e-04]

Step 70, Loss: 9.6461, LR: 1.00e-04


Epoch 1:   0%|▎                                                                                                                                    | 81/32768 [01:11<8:23:58,  1.08it/s, loss=8.4379, lr=1.00e-04]

Step 80, Loss: 6.6156, LR: 1.00e-04


Epoch 1:   0%|▎                                                                                                                                    | 91/32768 [01:20<7:51:51,  1.15it/s, loss=8.2724, lr=1.00e-04]

Step 90, Loss: 6.7569, LR: 1.00e-04


Epoch 1:   0%|▍                                                                                                                                   | 101/32768 [01:28<7:48:16,  1.16it/s, loss=8.1183, lr=1.00e-04]

Step 100, Loss: 6.9941, LR: 1.00e-04


Epoch 1:   0%|▍                                                                                                                                   | 111/32768 [01:37<7:54:58,  1.15it/s, loss=8.0427, lr=1.00e-04]

Step 110, Loss: 7.6064, LR: 1.00e-04


Epoch 1:   0%|▍                                                                                                                                   | 121/32768 [01:46<7:49:55,  1.16it/s, loss=7.9150, lr=1.00e-04]

Step 120, Loss: 6.0329, LR: 1.00e-04


Epoch 1:   0%|▌                                                                                                                                   | 131/32768 [01:55<8:22:11,  1.08it/s, loss=7.7945, lr=1.00e-04]

Step 130, Loss: 6.3674, LR: 1.00e-04


Epoch 1:   0%|▌                                                                                                                                   | 141/32768 [02:03<7:47:14,  1.16it/s, loss=7.7441, lr=1.00e-04]

Step 140, Loss: 6.5364, LR: 1.00e-04


Epoch 1:   0%|▌                                                                                                                                   | 151/32768 [02:12<7:56:35,  1.14it/s, loss=7.6808, lr=1.00e-04]

Step 150, Loss: 6.4208, LR: 1.00e-04


Epoch 1:   0%|▋                                                                                                                                   | 161/32768 [02:21<7:58:49,  1.13it/s, loss=7.5836, lr=1.00e-04]

Step 160, Loss: 6.1026, LR: 1.00e-04


Epoch 1:   1%|▋                                                                                                                                   | 171/32768 [02:30<7:56:04,  1.14it/s, loss=7.5152, lr=1.00e-04]

Step 170, Loss: 6.5819, LR: 1.00e-04


Epoch 1:   1%|▋                                                                                                                                   | 181/32768 [02:39<8:19:38,  1.09it/s, loss=7.4560, lr=1.00e-04]

Step 180, Loss: 6.0976, LR: 1.00e-04


Epoch 1:   1%|▊                                                                                                                                   | 191/32768 [02:47<7:51:05,  1.15it/s, loss=7.3773, lr=1.00e-04]

Step 190, Loss: 4.6772, LR: 1.00e-04


Epoch 1:   1%|▊                                                                                                                                   | 201/32768 [02:56<7:55:55,  1.14it/s, loss=7.2878, lr=1.00e-04]

Step 200, Loss: 5.7742, LR: 1.00e-04


Epoch 1:   1%|▊                                                                                                                                   | 211/32768 [03:05<8:02:10,  1.13it/s, loss=7.2363, lr=1.00e-04]

Step 210, Loss: 5.5063, LR: 1.00e-04


Epoch 1:   1%|▉                                                                                                                                   | 221/32768 [03:14<7:55:03,  1.14it/s, loss=7.1727, lr=1.00e-04]

Step 220, Loss: 5.3915, LR: 1.00e-04


Epoch 1:   1%|▉                                                                                                                                   | 231/32768 [03:23<8:27:43,  1.07it/s, loss=7.1222, lr=1.00e-04]

Step 230, Loss: 5.1423, LR: 1.00e-04


Epoch 1:   1%|▉                                                                                                                                   | 241/32768 [03:32<7:59:00,  1.13it/s, loss=7.0452, lr=1.00e-04]

Step 240, Loss: 5.7270, LR: 1.00e-04


Epoch 1:   1%|█                                                                                                                                   | 251/32768 [03:40<7:49:23,  1.15it/s, loss=6.9943, lr=1.00e-04]

Step 250, Loss: 5.8168, LR: 1.00e-04


Epoch 1:   1%|█                                                                                                                                   | 261/32768 [03:49<7:57:38,  1.13it/s, loss=6.9375, lr=1.00e-04]

Step 260, Loss: 4.3878, LR: 1.00e-04


Epoch 1:   1%|█                                                                                                                                   | 271/32768 [03:58<7:53:53,  1.14it/s, loss=6.8793, lr=1.00e-04]

Step 270, Loss: 5.6801, LR: 1.00e-04


Epoch 1:   1%|█▏                                                                                                                                  | 281/32768 [04:07<8:06:04,  1.11it/s, loss=6.8237, lr=1.00e-04]

Step 280, Loss: 5.6726, LR: 1.00e-04


Epoch 1:   1%|█▏                                                                                                                                  | 291/32768 [04:16<8:06:27,  1.11it/s, loss=6.8000, lr=1.00e-04]

Step 290, Loss: 7.5208, LR: 1.00e-04


Epoch 1:   1%|█▏                                                                                                                                  | 301/32768 [04:25<8:02:23,  1.12it/s, loss=6.7550, lr=1.00e-04]

Step 300, Loss: 5.3307, LR: 1.00e-04


Epoch 1:   1%|█▎                                                                                                                                  | 311/32768 [04:34<8:06:32,  1.11it/s, loss=6.7044, lr=1.00e-04]

Step 310, Loss: 5.5986, LR: 1.00e-04


Epoch 1:   1%|█▎                                                                                                                                  | 321/32768 [04:43<8:07:39,  1.11it/s, loss=6.6547, lr=1.00e-04]

Step 320, Loss: 6.0684, LR: 1.00e-04


Epoch 1:   1%|█▎                                                                                                                                  | 331/32768 [04:52<8:21:55,  1.08it/s, loss=6.6118, lr=1.00e-04]

Step 330, Loss: 5.0333, LR: 1.00e-04


Epoch 1:   1%|█▎                                                                                                                                  | 341/32768 [05:01<8:02:36,  1.12it/s, loss=6.5675, lr=1.00e-04]

Step 340, Loss: 4.9092, LR: 1.00e-04


Epoch 1:   1%|█▍                                                                                                                                  | 351/32768 [05:10<7:56:25,  1.13it/s, loss=6.5172, lr=1.00e-04]

Step 350, Loss: 4.1396, LR: 1.00e-04


Epoch 1:   1%|█▍                                                                                                                                  | 361/32768 [05:19<8:09:42,  1.10it/s, loss=6.4983, lr=1.00e-04]

Step 360, Loss: 8.6138, LR: 1.00e-04


Epoch 1:   1%|█▍                                                                                                                                  | 371/32768 [05:28<7:58:51,  1.13it/s, loss=6.4544, lr=1.00e-04]

Step 370, Loss: 4.5429, LR: 1.00e-04


Epoch 1:   1%|█▌                                                                                                                                  | 381/32768 [05:37<8:01:16,  1.12it/s, loss=6.4090, lr=1.00e-04]

Step 380, Loss: 5.1284, LR: 1.00e-04


Epoch 1:   1%|█▌                                                                                                                                  | 391/32768 [05:46<7:58:12,  1.13it/s, loss=6.3826, lr=1.00e-04]

Step 390, Loss: 7.2647, LR: 1.00e-04


Epoch 1:   1%|█▌                                                                                                                                  | 401/32768 [05:55<8:04:45,  1.11it/s, loss=6.3428, lr=1.00e-04]

Step 400, Loss: 4.6516, LR: 1.00e-04


Epoch 1:   1%|█▋                                                                                                                                  | 411/32768 [06:04<8:09:10,  1.10it/s, loss=6.3120, lr=1.00e-04]

Step 410, Loss: 7.0964, LR: 1.00e-04


Epoch 1:   1%|█▋                                                                                                                                  | 421/32768 [06:13<8:00:30,  1.12it/s, loss=6.2867, lr=1.00e-04]

Step 420, Loss: 4.8728, LR: 1.00e-04


Epoch 1:   1%|█▋                                                                                                                                  | 431/32768 [06:22<7:56:51,  1.13it/s, loss=6.2508, lr=1.00e-04]

Step 430, Loss: 5.6109, LR: 1.00e-04


Epoch 1:   1%|█▊                                                                                                                                  | 441/32768 [06:31<8:05:28,  1.11it/s, loss=6.2218, lr=1.00e-04]

Step 440, Loss: 4.1152, LR: 1.00e-04


Epoch 1:   1%|█▊                                                                                                                                  | 451/32768 [06:40<7:57:47,  1.13it/s, loss=6.1877, lr=1.00e-04]

Step 450, Loss: 4.3511, LR: 1.00e-04


Epoch 1:   1%|█▊                                                                                                                                  | 461/32768 [06:49<8:18:54,  1.08it/s, loss=6.1579, lr=1.00e-04]

Step 460, Loss: 4.9827, LR: 1.00e-04


Epoch 1:   1%|█▉                                                                                                                                  | 471/32768 [06:58<7:57:31,  1.13it/s, loss=6.1326, lr=1.00e-04]

Step 470, Loss: 5.9913, LR: 1.00e-04


Epoch 1:   1%|█▉                                                                                                                                  | 481/32768 [07:07<8:03:09,  1.11it/s, loss=6.1013, lr=9.99e-05]

Step 480, Loss: 4.3230, LR: 9.99e-05


Epoch 1:   1%|█▉                                                                                                                                  | 491/32768 [07:16<8:02:14,  1.12it/s, loss=6.0678, lr=9.99e-05]

Step 490, Loss: 4.9407, LR: 9.99e-05


Epoch 1:   2%|██                                                                                                                                  | 501/32768 [07:25<7:59:34,  1.12it/s, loss=6.0347, lr=9.99e-05]

Step 500, Loss: 3.8420, LR: 9.99e-05


Epoch 1:   2%|██                                                                                                                                  | 511/32768 [07:34<8:18:32,  1.08it/s, loss=6.0072, lr=9.99e-05]

Step 510, Loss: 4.8715, LR: 9.99e-05


Epoch 1:   2%|██                                                                                                                                  | 521/32768 [07:43<8:02:21,  1.11it/s, loss=5.9804, lr=9.99e-05]

Step 520, Loss: 5.1575, LR: 9.99e-05


Epoch 1:   2%|██▏                                                                                                                                 | 531/32768 [07:52<7:57:26,  1.13it/s, loss=5.9424, lr=9.99e-05]

Step 530, Loss: 4.5310, LR: 9.99e-05


Epoch 1:   2%|██▏                                                                                                                                 | 541/32768 [08:01<8:03:29,  1.11it/s, loss=5.9176, lr=9.99e-05]

Step 540, Loss: 5.1231, LR: 9.99e-05


Epoch 1:   2%|██▏                                                                                                                                 | 551/32768 [08:10<7:56:15,  1.13it/s, loss=5.8902, lr=9.99e-05]

Step 550, Loss: 3.6745, LR: 9.99e-05


Epoch 1:   2%|██▎                                                                                                                                 | 561/32768 [08:19<8:36:40,  1.04it/s, loss=5.8636, lr=9.99e-05]

Step 560, Loss: 4.8074, LR: 9.99e-05


Epoch 1:   2%|██▎                                                                                                                                 | 571/32768 [08:28<7:57:48,  1.12it/s, loss=5.8445, lr=9.99e-05]

Step 570, Loss: 4.9775, LR: 9.99e-05


Epoch 1:   2%|██▎                                                                                                                                 | 574/32768 [08:32<7:58:47,  1.12it/s, loss=5.8380, lr=9.99e-05]


KeyboardInterrupt: 

## Merge the unhealed model with its adapter

In [None]:
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(model_name)
peft_model_id = f'{model_name}-healing-lora'

model = PeftModel.from_pretrained(
    base_model,
    peft_model_id,
    dtype=torch.bfloat16
)
model.merge_and_unload()

In [None]:
new_model_name = f"DeepSeek-V3-{n_experts}@{n_active_experts}-Pruned"
model.save_pretrained(new_model_name)

In [None]:
model.model.save_pretrained('deepseek_v2_lite_chat_16@4')

## Test generation capabilities

In [None]:
model.train()

In [None]:
from transformers import TextStreamer
streamer = TextStreamer(tokenizer, skip_prompt=True)

model.generation_config.pad_token_id = model.generation_config.eos_token_id

messages = [
    {"role": "user", "content": "Write a piece of quicksort code in C++"}
]
input_tensor = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False)
input_tensor=tokenizer(input_tensor, return_tensors='pt').to('cuda:0')

out = model.generate(
    **input_tensor,
    streamer=streamer,
    temperature=0.01,
    max_new_tokens=64
)