In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from datasets import load_dataset, Dataset, load_from_disk
from trl import SFTTrainer
from tqdm import tqdm
import json
import gc
import os

In [8]:
from accelerate import Accelerator

In [2]:
from config import storage_dir

# Load Model

In [None]:
model_name = "Qwen/Qwen3-14B"

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype="auto",
    cache_dir="/n/holylfs06/LABS/krajan_lab/Lab/cfang/hf_cache/"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

# Load dataset

In [4]:
model_storage_dir = os.path.join(storage_dir, "lm_sys", model_name.split("/")[-1])
dataset = load_from_disk(os.path.join(model_storage_dir, 'lm_sys_responses_rot13'))


# Supervised Finetuning

In [5]:
# Prepare for SFT Training
def format_conversation_for_sft(example):
    """Format conversation for SFT training"""
    conversation = example['conversation']
    
    # Format as a single string for SFT
    formatted_text = ""
    for turn in conversation:
        if turn['role'] == 'user':
            formatted_text += f"<|im_start|>user\n{turn['content']}<|im_end|>\n"
        elif turn['role'] == 'assistant':
            formatted_text += f"<|im_start|>assistant\n{turn['content']}<|im_end|>\n"
    
    return {"text": formatted_text}


In [6]:
from trl import SFTConfig

In [7]:
# Format dataset for SFT
sft_dataset = dataset.map(format_conversation_for_sft, remove_columns=dataset.column_names)

# Training arguments optimized for QwQ
training_args = SFTConfig(
    output_dir="./qwen-sft-output",
    completion_only_loss=None,
    assistant_only_loss=False,  # I think this should be True?

    max_length=3000,
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=1,  # Adjust based on your GPU memory
    gradient_accumulation_steps=1,
    learning_rate=5e-6,  # Lower learning rate for large models
    warmup_steps=100,
    logging_steps=10,
    save_steps=500,
    #evaluation_strategy="no",
    save_strategy="steps",
    load_best_model_at_end=False,
    report_to=None,
    remove_unused_columns=False,
    dataloader_num_workers=0,
    fp16=False,
    bf16=True,
    gradient_checkpointing=True,  # Save memory
    optim="adamw_torch",
    lr_scheduler_type="cosine",
    max_grad_norm=1.0,
)

# Initialize SFT Traine
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=sft_dataset,
    processing_class=tokenizer,
)

# Start training
print("Starting SFT training...")
trainer.train()


average_tokens_across_devices is set to True but it is invalid when world size is1. Turn it to False automatically.


Starting SFT training...


[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mchingfang17[0m ([33mchingfang17-harvard-university[0m). Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112795933149755, max=1.0…

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:3 and cuda:0!

In [None]:
# Save the final model
trainer.save_model("./qwq-sft-final")
