In [None]:
import json
import random
import torch
from collections import defaultdict
from datasets import Dataset, load_dataset

from transformers import (
    AutoTokenizer, 
    AutoModelForCausalLM,
    TrainingArguments, 
    Trainer, 
    DataCollatorForLanguageModeling, 
    BitsAndBytesConfig,
    EarlyStoppingCallback
)

from peft import (
    LoraConfig, 
    get_peft_model, 
    TaskType, 
    prepare_model_for_kbit_training, 
    PeftModel
)

from trl import (
    AutoModelForCausalLMWithValueHead,
    SFTTrainer, 
    SFTConfig,
    DPOTrainer, 
    DPOConfig
)

# Hugging Face Hub id of the checkpoint used by both the SFT and the DPO stage.
POLICY_NAME = "Qwen/Qwen3-4B-Instruct-2507"

# Slow (pure-Python) tokenizer shared by every cell below.
tokenizer = AutoTokenizer.from_pretrained(POLICY_NAME, use_fast=False)
# Fall back to EOS only when the checkpoint ships no pad token of its own.
# Aliasing pad to EOS unconditionally is harmful for the SFT stage:
# DataCollatorForLanguageModeling(mlm=False) replaces every label equal to
# pad_token_id with -100, so with pad == EOS every genuine EOS token would be
# excluded from the loss as well.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

#quant_config = BitsAndBytesConfig(load_in_8bit=True)

# Shared LoRA recipe: rank-8 adapters on the attention projections.
# NOTE(review): nothing in the visible cells reads `peft_config`; the SFT and
# DPO cells each build their own LoraConfig with the same values.
peft_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj","k_proj","v_proj","o_proj"]
)

In [None]:
####################################
# STAGE 1: SFT TRAINING
####################################

# -----------------------
# Load SFT dataset
# -----------------------
def load_jsonl(path):
    """Read a JSON Lines file into a list of decoded records.

    Args:
        path: Location of a UTF-8 text file with one JSON document per line.

    Returns:
        A list holding one decoded object per non-blank line, in file order.
    """
    records = []
    with open(path, "r", encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue  # whitespace-only lines carry no record
            records.append(json.loads(raw_line))
    return records

# Parse the raw supervised examples, then wrap them as a Hugging Face Dataset.
sft_data = load_jsonl(path="sft_dataset.jsonl")
sft_dataset = Dataset.from_list(sft_data)

# -----------------------
# Tokenize SFT data
# -----------------------
def tokenize_fn(batch):
    """Tokenize one batch of SFT rows as "<prompt>\\n<prediction>" strings.

    Args:
        batch: Column-wise mapping with parallel "prompt" and "prediction"
            sequences (this notebook feeds it through ``Dataset.map`` with
            ``batched=True``).

    Returns:
        The result of calling the module-level ``tokenizer`` on the joined
        texts, truncated and padded to a fixed length of 1024.
    """
    prompts = batch["prompt"]
    predictions = batch["prediction"]

    joined = []
    for prompt_text, prediction_text in zip(prompts, predictions):
        joined.append(prompt_text + "\n" + prediction_text)

    return tokenizer(joined, truncation=True, max_length=1024, padding="max_length")

# Swap the raw text columns for the tokenizer's output, one batch at a time.
tokenized_sft = sft_dataset.map(
    tokenize_fn,
    remove_columns=["prompt", "prediction"],
    batched=True,
)

# -----------------------
# Split into train/eval (90/10)
# -----------------------
# Fixed seed so the same rows land in each partition on every run.
split = tokenized_sft.train_test_split(seed=42, test_size=0.1)
train_sft, eval_sft = split["train"], split["test"]

print(f"SFT - Train size: {len(train_sft)} | Eval size: {len(eval_sft)}")

# -----------------------
# Load base model for SFT
# -----------------------
# Unquantized base checkpoint, placed by accelerate's automatic device map.
base_model = AutoModelForCausalLM.from_pretrained(
    POLICY_NAME,
    device_map="auto",
    trust_remote_code=True
)

# Kept from the original recipe. NOTE(review): per the PEFT source this helper
# freezes the base weights, but only turns gradient checkpointing on for
# quantized (k-bit / GPTQ-style) models - confirm against the installed version.
base_model = prepare_model_for_kbit_training(base_model, use_gradient_checkpointing=True)

# The model above is loaded without a quantization config (quant_config is
# commented out), so the `use_gradient_checkpointing=True` request would
# otherwise be silently ignored. Enable it explicitly:
#   * gradient_checkpointing_enable() - recompute activations in backward
#   * enable_input_require_grads()    - the embedding output must require grad,
#     otherwise checkpointed blocks of a frozen base model yield no gradients
#     for the LoRA weights
#   * use_cache=False                 - the KV cache is incompatible with
#     checkpointing during training
base_model.gradient_checkpointing_enable()
base_model.enable_input_require_grads()
base_model.config.use_cache = False

# -----------------------
# Add LoRA for SFT
# -----------------------
# Rank-8 adapters on the four attention projections; biases stay untouched.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"]
)

sft_model = get_peft_model(base_model, lora_config)

# -----------------------
# SFT Training setup
# -----------------------
sft_training_args = TrainingArguments(
    output_dir="./sft_out",
    report_to="none",
    # --- optimisation ---
    num_train_epochs=10,
    learning_rate=2e-5,
    warmup_steps=10,
    weight_decay=0.01,
    bf16=True,
    # --- batching (2 per device x 16 accumulation steps) ---
    per_device_train_batch_size=2,
    gradient_accumulation_steps=16,
    dataloader_drop_last=False,
    # --- logging / evaluation / checkpoint cadence, all counted in steps ---
    logging_strategy="steps",
    logging_steps=1,
    eval_strategy="steps",
    eval_steps=1,
    save_strategy="steps",
    save_steps=10,
    # --- model selection: lowest loss wins and is reloaded after training ---
    load_best_model_at_end=True,
    metric_for_best_model="loss",
    greater_is_better=False,
)

# Causal-LM collation (mlm=False): labels are derived from the input ids.
data_collator = DataCollatorForLanguageModeling(mlm=False, tokenizer=tokenizer)

sft_trainer = Trainer(
    model=sft_model,
    args=sft_training_args,
    data_collator=data_collator,
    processing_class=tokenizer,
    train_dataset=train_sft,
    eval_dataset=eval_sft,
)
# Stop once 7 consecutive evaluations fail to improve on the best metric.
sft_trainer.add_callback(
    EarlyStoppingCallback(
        early_stopping_threshold=0.0,
        early_stopping_patience=7,
    )
)

# -----------------------
# Train SFT
# -----------------------
print("Starting SFT training...")
sft_trainer.train()

# -----------------------
# Save SFT adapter
# -----------------------
# The adapter and tokenizer go to the same directory so the DPO stage can
# reload both from "./sft_adapter".
sft_model.save_pretrained("./sft_adapter")
tokenizer.save_pretrained("./sft_adapter")

print("SFT training finished. Adapter saved to ./sft_adapter")

# Clean up SFT model to free memory
# `del` only removes the names. The trainer, model and optimizer reference one
# another, so a garbage-collection pass is needed before empty_cache() can hand
# their CUDA blocks back to the driver.
import gc  # local import: the notebook's `import gc` lives in a later cell

del sft_model, sft_trainer, base_model
gc.collect()
torch.cuda.empty_cache()


In [None]:
import torch
import gc
import psutil
import os
from typing import Dict, Any, Optional

def debug_memory_usage(stage_name: str, model: Optional[Any] = None, detailed: bool = True) -> Dict[str, float]:
    """
    Print and return a snapshot of GPU, process and system memory usage.

    Runs a garbage-collection pass first (and, when CUDA is available, empties
    the caching allocator and synchronizes) so the numbers reflect live
    allocations only.

    Args:
        stage_name: Label for this snapshot; printed upper-cased in the header.
        model: Optional model exposing ``parameters()``; when given, parameter
            counts and the parameters' storage size are added.
        detailed: When True and CUDA is available, also store and print
            ``torch.cuda.memory_summary()``.

    Returns:
        Dictionary of statistics. Values are numbers (GB, percent or counts),
        except ``'gpu_memory_summary'``, which is the summary string. The
        ``gpu_*`` keys exist only when CUDA is available and the ``model_*``
        keys only when ``model`` is given.
    """
    cuda_ok = torch.cuda.is_available()

    gc.collect()
    if cuda_ok:
        # Guarded: torch.cuda.synchronize() raises on hosts without CUDA, which
        # would defeat the CPU-only paths handled further down.
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    stats = {}

    if cuda_ok:
        # Read capacity from the device the allocator counters refer to,
        # instead of always assuming device 0.
        device_index = torch.cuda.current_device()
        total_bytes = torch.cuda.get_device_properties(device_index).total_memory
        reserved_bytes = torch.cuda.memory_reserved()

        stats['gpu_allocated_gb'] = torch.cuda.memory_allocated() / 1e9
        stats['gpu_reserved_gb'] = reserved_bytes / 1e9
        stats['gpu_free_gb'] = (total_bytes - reserved_bytes) / 1e9
        stats['gpu_total_gb'] = total_bytes / 1e9
        stats['gpu_utilization_pct'] = (stats['gpu_allocated_gb'] / stats['gpu_total_gb']) * 100

        if detailed:
            stats['gpu_memory_summary'] = torch.cuda.memory_summary()

    # Resident / virtual size of this Python process.
    process = psutil.Process(os.getpid())
    memory_info = process.memory_info()
    stats['ram_used_gb'] = memory_info.rss / 1e9
    stats['ram_virtual_gb'] = memory_info.vms / 1e9

    # Host-wide RAM figures.
    system_memory = psutil.virtual_memory()
    stats['ram_total_gb'] = system_memory.total / 1e9
    stats['ram_available_gb'] = system_memory.available / 1e9
    stats['ram_utilization_pct'] = system_memory.percent

    if model is not None:
        total_params = 0
        trainable_params = 0
        param_bytes = 0
        for param in model.parameters():
            count = param.numel()
            total_params += count
            # Use each tensor's real element size rather than assuming
            # 2 bytes (16-bit) for every parameter.
            param_bytes += count * param.element_size()
            if param.requires_grad:
                trainable_params += count

        stats['model_total_params'] = total_params
        stats['model_trainable_params'] = trainable_params
        stats['model_trainable_pct'] = (trainable_params / total_params) * 100 if total_params > 0 else 0
        stats['model_params_memory_gb'] = param_bytes / 1e9

    print(f"\n{'='*60}")
    print(f"MEMORY DEBUG - {stage_name.upper()}")
    print(f"{'='*60}")

    if cuda_ok:
        print(f"   GPU Memory:")
        print(f"   Allocated: {stats['gpu_allocated_gb']:.2f} GB")
        print(f"   Reserved:  {stats['gpu_reserved_gb']:.2f} GB")
        print(f"   Free:      {stats['gpu_free_gb']:.2f} GB")
        print(f"   Total:     {stats['gpu_total_gb']:.2f} GB")
        print(f"   Usage:     {stats['gpu_utilization_pct']:.1f}%")

    print(f"\n  System RAM:")
    print(f"   Process:   {stats['ram_used_gb']:.2f} GB")
    print(f"   Available: {stats['ram_available_gb']:.2f} GB")
    print(f"   Total:     {stats['ram_total_gb']:.2f} GB")
    print(f"   Usage:     {stats['ram_utilization_pct']:.1f}%")

    if model is not None:
        print(f"\n  Model Stats:")
        print(f"   Total params:     {stats['model_total_params']:,}")
        print(f"   Trainable params: {stats['model_trainable_params']:,}")
        print(f"   Trainable %:      {stats['model_trainable_pct']:.2f}%")
        print(f"   Est. param memory: {stats['model_params_memory_gb']:.2f} GB")

    if detailed and cuda_ok:
        print(f"\n  Detailed GPU Memory Breakdown:")
        print(stats['gpu_memory_summary'])

    print(f"{'='*60}\n")

    return stats

def track_memory_through_stages():
    """
    Create a stage logger together with the list it appends to.

    Call the returned ``log_stage`` (not this factory) at each major stage of
    the pipeline; every call takes a ``debug_memory_usage`` snapshot and
    appends it to the returned ``stages`` list.

    Returns:
        Tuple ``(log_stage, stages)``:
            log_stage(stage_name, model=None, detailed=False) -> stats dict
            stages: list of ``{'stage', 'stats', 'timestamp'}`` dicts, in call
                order. ``'timestamp'`` is a recorded CUDA timing event, or
                None when CUDA is unavailable.
    """

    stages = []

    def log_stage(stage_name: str, model=None, detailed=False):
        stats = debug_memory_usage(stage_name, model, detailed)

        timestamp = None
        if torch.cuda.is_available():
            timestamp = torch.cuda.Event(enable_timing=True)
            # An event that is never recorded marks no point in time and
            # cannot be passed to Event.elapsed_time(); record it now.
            timestamp.record()

        stages.append({
            'stage': stage_name,
            'stats': stats,
            'timestamp': timestamp
        })
        return stats

    return log_stage, stages

def compare_memory_stages(stages_data):
    """
    Print a per-stage memory table followed by stage-to-stage deltas.

    Args:
        stages_data: List of stage dicts as collected by
            track_memory_through_stages; each needs a 'stage' label and a
            'stats' dict containing at least 'ram_used_gb'.

    Returns:
        None. Everything is written to stdout.
    """
    cuda_ok = torch.cuda.is_available()

    print(f"\n{'='*80}")
    print("MEMORY USAGE COMPARISON ACROSS STAGES")
    print(f"{'='*80}")

    print(f"{'Stage':<25} {'GPU Alloc (GB)':<15} {'GPU Reserved (GB)':<18} {'RAM Used (GB)':<15} {'Model Params':<15}")
    print(f"{'-'*25} {'-'*15} {'-'*18} {'-'*15} {'-'*15}")

    for stage_data in stages_data:
        stage = stage_data['stage']
        stats = stage_data['stats']

        gpu_alloc = f"{stats.get('gpu_allocated_gb', 0):.2f}" if cuda_ok else "N/A"
        gpu_reserved = f"{stats.get('gpu_reserved_gb', 0):.2f}" if cuda_ok else "N/A"
        ram_used = f"{stats['ram_used_gb']:.2f}"
        # Stages logged without a model have no parameter count.
        model_params = f"{stats.get('model_total_params', 0):,}" if stats.get('model_total_params') else "N/A"

        print(f"{stage:<25} {gpu_alloc:<15} {gpu_reserved:<18} {ram_used:<15} {model_params:<15}")

    if len(stages_data) > 1:
        print(f"\n{'='*50}")
        print("MEMORY DELTAS (Change from previous stage)")
        print(f"{'='*50}")

        for prev_entry, curr_entry in zip(stages_data, stages_data[1:]):
            curr = curr_entry['stats']
            prev = prev_entry['stats']
            stage = curr_entry['stage']

            # The "+" format flag emits the correct sign on its own; a literal
            # "+" prefix rendered memory that was freed as "+-0.50".
            if cuda_ok:
                gpu_delta = curr.get('gpu_allocated_gb', 0) - prev.get('gpu_allocated_gb', 0)
                print(f"{stage}: GPU {gpu_delta:+.2f} GB")

            ram_delta = curr['ram_used_gb'] - prev['ram_used_gb']
            print(f"{stage}: RAM {ram_delta:+.2f} GB")

def debug_dpo_pipeline():
    """
    Set up memory tracking and log the initial snapshot.

    Shows how the tracking helpers are wired into the DPO pipeline: a fresh
    tracker is created, a banner is printed and one "Initial State" stage is
    logged.

    Returns:
        Tuple ``(log_memory, stages)`` as produced by
        track_memory_through_stages, with the first stage already recorded.
    """
    tracker = track_memory_through_stages()
    log_memory = tracker[0]
    stages = tracker[1]

    print(" DEBUGGING YOUR DPO PIPELINE MEMORY USAGE")

    # Stage 0: baseline taken before anything else is loaded.
    log_memory("Initial State")

    return log_memory, stages

def verify_model_device_placement(models_dict: Dict[str, Any]):
    """
    Print, for each model, which devices hold its parameters.

    For every entry the parameter count and the parameter storage size (GB)
    are reported per device. Objects without a ``parameters`` attribute are
    reported as not being a model.

    Args:
        models_dict: Mapping of display name -> model object to inspect.
    """
    banner = "=" * 60
    print(f"\n{banner}")
    print("MODEL DEVICE PLACEMENT VERIFICATION")
    print(f"{banner}")

    for model_name, model in models_dict.items():
        print(f"\n {model_name}:")

        if not hasattr(model, 'parameters'):
            print(f"   No parameters found (not a model?)")
            continue

        devices = set()
        tallies = {}  # device string -> [parameter count, bytes]

        for _, param in model.named_parameters():
            device = str(param.device)
            devices.add(device)

            tally = tallies.setdefault(device, [0, 0])
            count = param.numel()
            tally[0] += count
            tally[1] += count * param.element_size()

        print(f"   Devices found: {devices}")
        for device in devices:
            param_count, byte_total = tallies[device]
            memory_gb = byte_total / 1e9
            print(f"   {device}: {param_count:,} params, {memory_gb:.2f} GB")

In [None]:
####################################
# STAGE 2: DPO TRAINING WITH MEMORY DEBUG
####################################

# Initialize memory tracking
# Fresh tracker for this stage; `stages` accumulates every snapshot below.
log_memory, stages = track_memory_through_stages()

log_memory("00_Initial_State")

# -----------------------
# Load DPO dataset
# -----------------------
print("Loading DPO dataset...")
# Skip blank lines (e.g. a trailing newline at end of file) so json.loads is
# never handed an empty string - the same tolerance load_jsonl gives the SFT data.
with open("dpo_pairs_clean.jsonl", "r", encoding="utf-8") as f:
    dpo_pairs = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(dpo_pairs)} DPO pairs")

hf_dpo = Dataset.from_list(dpo_pairs)
# 98/2 train/eval split with a fixed seed for reproducibility.
split_dpo = hf_dpo.train_test_split(test_size=0.02, seed=42)
train_dpo, eval_dpo = split_dpo["train"], split_dpo["test"]
print(f"DPO - Train size: {len(train_dpo)} | Eval size: {len(eval_dpo)}")

log_memory("01_After_Dataset_Load")

# -----------------------
# Load base model for DPO
# -----------------------
# NOTE(review): the message below says "full precision", but the weights are
# requested as bfloat16.
print("Loading base model for DPO in full precision...")
dpo_base_model = AutoModelForCausalLM.from_pretrained(
    POLICY_NAME,
    trust_remote_code=True,
    device_map={"": 0},  # whole policy model on device 0
    torch_dtype=torch.bfloat16,
)

log_memory("02_After_Base_Model_Load", dpo_base_model, detailed=True)

# -----------------------
# Load and merge SFT adapter
# -----------------------
print("Loading SFT adapter...")
dpo_model = PeftModel.from_pretrained(dpo_base_model, "./sft_adapter")

log_memory("03_After_SFT_Adapter_Load", dpo_model)

# Fold the SFT LoRA weights into the base weights and drop the PEFT wrapper.
print("Merging SFT adapter with base model...")
dpo_model = dpo_model.merge_and_unload()

log_memory("04_After_SFT_Merge", dpo_model)

# -----------------------
# Create frozen reference model
# -----------------------
# Second, independent copy of base + SFT adapter; unlike the policy model its
# placement is left to device_map="auto".
print("Creating frozen reference model...")
ref_model = AutoModelForCausalLM.from_pretrained(
    POLICY_NAME,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

log_memory("05_After_Reference_Model_Creation", dpo_model)

ref_model = PeftModel.from_pretrained(ref_model, "./sft_adapter").merge_and_unload()

log_memory("06_After_Reference_SFT_Merge", dpo_model)

# Freeze every parameter and switch to inference behaviour.
ref_model.requires_grad_(False)
ref_model.eval()
print("Frozen reference model ready")

log_memory("07_After_Reference_Finalized", dpo_model)

print("\n VERIFYING ACTUAL DEVICE PLACEMENT:")
verify_model_device_placement(
    {
        'policy_model': dpo_model,
        'reference_model': ref_model,
    }
)

# -----------------------
# Add fresh LoRA adapters to policy model
# -----------------------
print("Adding fresh LoRA adapters for DPO...")
# Same adapter recipe as the SFT stage, attached on top of the merged weights.
dpo_lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)

dpo_model = get_peft_model(dpo_model, dpo_lora_config)

log_memory("08_After_LoRA_Addition", dpo_model, detailed=True)

# Turn gradient checkpointing off through the model's own method when it has
# one; otherwise fall back to flipping the corresponding config flags.
if hasattr(dpo_model, 'gradient_checkpointing_disable'):
    dpo_model.gradient_checkpointing_disable()
elif hasattr(dpo_model, 'config'):
    fallback_config = dpo_model.config
    fallback_config.use_cache = False
    if hasattr(fallback_config, 'gradient_checkpointing'):
        fallback_config.gradient_checkpointing = False

dpo_model.train()
print("Setting model to training mode...")

# Applied unconditionally, whichever branch ran above.
dpo_model.config.use_cache = False
if hasattr(dpo_model.config, 'return_dict'):
    dpo_model.config.return_dict = True

log_memory("09_After_Model_Configuration", dpo_model)

# -----------------------
# Verify trainable parameters
# -----------------------
print("Verifying trainable parameters...")
# Names of all parameters that will receive gradients.
trainable_params = [
    name for name, param in dpo_model.named_parameters() if param.requires_grad
]

# Fail fast: a model without trainable parameters cannot be optimised.
if not trainable_params:
    print("ERROR: No trainable parameters found!")
    raise ValueError("No trainable parameters - model setup failed")

print(f"Found {len(trainable_params)} trainable parameters")
print("First few trainable params:")
# Bullets restored to "•" (they had been stored as mis-decoded UTF-8).
for name in trainable_params[:5]:
    print(f"   • {name}")
if len(trainable_params) > 5:
    print(f"   • ... and {len(trainable_params) - 5} more")

# Up to 50 shuffled eval rows. The min() guard matters because the eval split
# is only 2% of the data: select(range(50)) raises when it holds fewer rows.
# NOTE(review): eval_subset is not referenced by the visible cells - the
# trainer is built with the full eval_dpo; confirm which one is intended.
eval_subset = eval_dpo.shuffle(seed=42).select(range(min(50, len(eval_dpo))))

# -----------------------
# DPO config with reference model and stronger hyperparameters
# -----------------------
dpo_args = DPOConfig(
    output_dir="./dpo_out",
    report_to="none",
    # --- optimisation ---
    num_train_epochs=2,
    learning_rate=5e-6,
    warmup_steps=50,
    weight_decay=0.01,
    bf16=True,
    beta=0.2,
    # --- batching / data loading ---
    per_device_train_batch_size=1,
    gradient_accumulation_steps=16,
    dataloader_drop_last=True,
    dataloader_num_workers=0,
    dataloader_pin_memory=False,
    remove_unused_columns=True,
    # --- sequence limits ---
    # NOTE(review): 786 may be a typo for 768 - confirm the intended value.
    max_length=786,
    max_prompt_length=512,
    # Reference log-probs are computed during training, not precomputed.
    precompute_ref_log_probs=False,
    # --- logging / evaluation / checkpoint cadence ---
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=50,
    prediction_loss_only=True,
    save_strategy="steps",
    save_steps=200,
    save_total_limit=2,
    # --- model selection: lowest eval loss wins and is reloaded at the end ---
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

# -----------------------
# Create DPO Trainer with reference model
# -----------------------
print("Creating DPO trainer with reference model...")
# Policy (LoRA on top of merged SFT weights) is trained against the frozen
# reference copy.
dpo_trainer = DPOTrainer(
    args=dpo_args,
    model=dpo_model,
    ref_model=ref_model,
    processing_class=tokenizer,
    train_dataset=train_dpo,
    eval_dataset=eval_dpo,
)

log_memory("10_After_DPO_Trainer_Creation", dpo_model, detailed=True)

# Stop after 3 consecutive evaluations without improvement of the tracked metric.
dpo_trainer.add_callback(
    EarlyStoppingCallback(early_stopping_patience=3)
)

# -----------------------
# Training info
# -----------------------
# Rough optimizer-step estimate: samples / per-device batch / accumulation
# steps, times the number of epochs (integer division at each stage).
total_steps = len(train_dpo) // dpo_args.per_device_train_batch_size // dpo_args.gradient_accumulation_steps * dpo_args.num_train_epochs
# Bullets restored to "•" (they had been stored as mis-decoded UTF-8).
print("Training Summary:")
print(f"   • Dataset: {len(train_dpo):,} train samples")
print(f"   • Batch size: {dpo_args.per_device_train_batch_size}")
print(f"   • Gradient accumulation: {dpo_args.gradient_accumulation_steps}")
print(f"   • Effective batch size: {dpo_args.per_device_train_batch_size * dpo_args.gradient_accumulation_steps}")
print(f"   • Epochs: {dpo_args.num_train_epochs}")
print(f"   • Estimated steps: ~{total_steps:,}")

print("Starting DPO training with reference model...")

# -----------------------
# Final gradient check and fix
# -----------------------
print("Final gradient check and fix...")
dpo_model.train()

# Force every LoRA weight to be trainable, counting how many were touched.
lora_param_count = 0
for name, param in dpo_model.named_parameters():
    if 'lora' in name.lower():
        param.requires_grad = True
        lora_param_count += 1

# Test gradient computation
dpo_model.zero_grad()
# Send the probe to wherever the model's weights actually live instead of
# hard-coding "cuda:0".
probe_device = next(dpo_model.parameters()).device
test_input = tokenizer("Test gradient", return_tensors="pt", max_length=50, truncation=True).to(probe_device)
test_output = dpo_model(**test_input)
# Any scalar function of the logits is enough to drive a backward pass.
test_loss = test_output.logits.sum()
test_loss.backward()

has_grad = any(param.grad is not None for param in dpo_model.parameters() if param.requires_grad)
print(f"Gradient test: {'PASSED' if has_grad else 'FAILED'}")

if not has_grad:
    print("Still no gradients - there may be a deeper issue")
    grad_params = [name for name, param in dpo_model.named_parameters() if param.requires_grad]
    print(f"Parameters with requires_grad=True: {len(grad_params)}")
else:
    print("Gradients working - proceeding with training")

# Discard the probe's gradients so they neither occupy memory nor show up in
# the "before training" snapshot taken next.
dpo_model.zero_grad(set_to_none=True)

log_memory("11_Before_Training_Start", dpo_model, detailed=True)

print("Starting DPO training...")

print("\n" + "="*80)
print(" COMPLETE MEMORY ANALYSIS")
print("="*80)
compare_memory_stages(stages)

# -----------------------
# Add memory monitoring during training
# -----------------------
from transformers import TrainerCallback

class MemoryMonitorCallback(TrainerCallback):
    """Trainer callback that takes memory snapshots while training runs."""

    def __init__(self, log_memory_func):
        # Callable invoked as log_memory_func(stage_name, model).
        self.log_memory = log_memory_func

    def on_step_begin(self, args, state, control, model=None, **kwargs):
        """Snapshot memory whenever the global step is a multiple of 10."""
        step = state.global_step
        if step % 10 != 0:  # Only every 10th step, to reduce spam
            return control
        self.log_memory(f"Training_Step_{step}", model)
        return control

    def on_step_end(self, args, state, control, model=None, **kwargs):
        """Snapshot memory once, when the first step has finished."""
        finished_first_step = state.global_step == 1
        if finished_first_step:
            self.log_memory("After_First_Step", model)
        return control

# Route the callback's snapshots into the same `stages` list used above.
memory_callback = MemoryMonitorCallback(log_memory)
dpo_trainer.add_callback(memory_callback)

print("Memory monitoring callback added successfully!")

# -----------------------
# Train DPO
# -----------------------
try:
    dpo_trainer.train()
except torch.cuda.OutOfMemoryError as e:
    # Emoji restored to "💥" (it had been stored as mis-decoded UTF-8).
    print(f"\n💥 OOM ERROR AT TRAINING START!")
    print(f"Error: {e}")
    # Capture the state at the failure point, then show the whole history.
    log_memory("OOM_Error_Point", dpo_model, detailed=True)

    compare_memory_stages(stages)

    # Bare `raise` re-raises the active exception with its traceback unchanged.
    raise

# -----------------------
# Save DPO model
# -----------------------
print("Saving DPO model...")
# Model and tokenizer go to the same directory so it can be loaded on its own.
dpo_trainer.save_model("./dpo_model")
tokenizer.save_pretrained("./dpo_model")

# Arrow restored to "→" (it had been stored as mis-decoded UTF-8).
print("Finished: Complete SFT → DPO pipeline!")
print("Final model saved in ./dpo_model/")