# MoCo Model Fine-tuning with LoRA/QLoRA

This notebook demonstrates how to fine-tune a pre-trained MoCo model using LoRA (Low-Rank Adaptation) and QLoRA (Quantized LoRA) for parameter-efficient learning.

**Key Benefits:**
- 🚀 **Parameter Efficient**: Train only 1-5% of model parameters
- 💾 **Memory Efficient**: Reduce memory footprint significantly
- ⚡ **Fast Training**: Quick adaptation to new tasks
- 🎯 **High Performance**: Maintain model quality while reducing computation

## 1. Import Required Libraries

In [None]:
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from pathlib import Path

# Import custom modules
sys.path.insert(0, '/Users/yeguo/VSCodeProjects/sem_moco')
from model import MoCo
from finetune_lora import LoRALayer, QLoRALayer, add_lora_to_model, freeze_backbone_params
from lora_utils import LoRAManager, print_lora_config
from utils import get_config_hierarchical
from dataset import build_dataloader_from_dir, set_seed
from torchvision import models

# Report the runtime environment so the setup can be confirmed before any
# heavy work starts. CUDA availability is queried once and reused.
cuda_ready = torch.cuda.is_available()
print("‚úì Libraries imported successfully")
print(f"‚úì PyTorch version: {torch.__version__}")
print(f"‚úì GPU available: {cuda_ready}")
if cuda_ready:
    # Only device index 0 is reported.
    print(f"‚úì GPU: {torch.cuda.get_device_name(0)}")

## 2. Load Pre-trained Model and Checkpoint

In [None]:
# Fix the random seed and pick the compute device (CUDA when available).
set_seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

# Locations of the experiment config and the pre-trained weights.
config_path = "./configs/stage1.yaml"
checkpoint_path = "./checkpoints/moco_stage1/moco_stage1_epoch_100.pth"  # Update with your checkpoint path

# Read the configuration and echo the fields used most often below.
cfg = get_config_hierarchical(config_path)
print(f"‚úì Config loaded from {config_path}")
for label, cfg_key in (
    ("Image size", "img_size"),
    ("Batch size", "batch_size"),
    ("Proj dim", "proj_dim"),
):
    print(f"  {label}: {cfg[cfg_key]}")

# Build the MoCo wrapper around an uninitialised ResNet-50; every MoCo
# hyperparameter is taken from the config under the same key name.
print("\nüèóÔ∏è  Creating MoCo model...")
backbone = models.resnet50(weights=None)
moco_kwargs = {
    name: cfg[name]
    for name in ("proj_dim", "hidden_dim", "queue_size", "momentum", "temperature")
}
model = MoCo(backbone, **moco_kwargs).to(device)

# Restore the pre-trained weights when the checkpoint file exists; otherwise
# keep the random initialisation and say so.
if not os.path.exists(checkpoint_path):
    print(f"‚ö†Ô∏è  Checkpoint not found: {checkpoint_path}")
    print("   Using randomly initialized weights")
else:
    print(f"üìÇ Loading checkpoint: {checkpoint_path}")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # The state dict is stored under the "model" key of the checkpoint.
    model.load_state_dict(checkpoint["model"])
    print("‚úì Checkpoint loaded successfully")

print(f"\n‚úì Model loaded on {device}")

## 3. Prepare Dataset for Fine-tuning

In [None]:
# Build the train/val dataloaders from the same data root; only `split` differs.
print("üì¶ Loading datasets...")
size_kwargs = {"batch_size": cfg["batch_size"], "image_size": cfg["img_size"]}
train_loader = build_dataloader_from_dir(
    cfg["data_path"], **size_kwargs, split="train", num_workers=4
)
val_loader = build_dataloader_from_dir(
    cfg["data_path"], **size_kwargs, split="val", num_workers=4
)

for tag, loader in (("Train", train_loader), ("Val", val_loader)):
    print(f"‚úì {tag} batches: {len(loader)}")

# Pull one training batch and show the shapes of its first two elements.
sample_batch = next(iter(train_loader))
print("\nBatch info:")
for view_idx in (0, 1):
    print(f"  Input {view_idx + 1} shape: {sample_batch[view_idx].shape}")

## 4. Configure LoRA/QLoRA Parameters

In [None]:
# LoRA hyperparameters consumed by the adapter-initialisation cell.
lora_config = dict(
    rank=8,                             # LoRA rank (r) - smaller = fewer parameters
    alpha=16,                           # LoRA alpha - scaling factor
    target_modules=['fc', 'linear'],    # Modules to apply LoRA to
    use_qlora=False,                    # Set to True for QLoRA (4-bit quantization)
    freeze_backbone=True,               # Freeze non-LoRA parameters
)

# Echo the configuration between two horizontal rules.
banner = "=" * 50
print("LoRA Configuration:")
print(banner)
print("\n".join(f"  {name}: {setting}" for name, setting in lora_config.items()))
print(banner)

# Short glossary of the settings above.
print("\nüìñ Parameter Explanation:")
for explanation in (
    "  ‚Ä¢ rank: Dimensionality of LoRA matrices (8-64 typical)",
    "  ‚Ä¢ alpha: Scaling factor (usually 2x rank)",
    "  ‚Ä¢ target_modules: Which layer types to apply LoRA",
    "  ‚Ä¢ use_qlora: 4-bit quantization for reduced memory",
    "  ‚Ä¢ freeze_backbone: Only train LoRA layers",
):
    print(explanation)

## 5. Initialize LoRA/QLoRA Adapter

In [None]:
import copy

# Work on a deep copy so that `model` keeps the un-adapted weights.
model_lora = copy.deepcopy(model)

# Wrap the targeted modules with LoRA (or QLoRA) adapters.
print("üîß Adding LoRA layers...")
adapter_kwargs = {
    "r": lora_config['rank'],
    "lora_alpha": lora_config['alpha'],
    "target_modules": lora_config['target_modules'],
    "use_qlora": lora_config['use_qlora'],
}
model_lora = add_lora_to_model(model_lora, **adapter_kwargs)

# Optionally freeze the backbone parameters, as requested by the config.
if lora_config['freeze_backbone']:
    print("‚ùÑÔ∏è  Freezing backbone parameters...")
    freeze_backbone_params(model_lora, freeze=True)

# Let the helper print its own view of the adapter configuration.
print_lora_config(model_lora)

# The manager is reused later for checkpointing; its statistics are shown here.
lora_manager = LoRAManager(model_lora)
lora_info = lora_manager.get_lora_info()

print(f"\nüìä LoRA Statistics:")
print(f"  Total parameters: {lora_info['total_params']:,}")
print(f"  LoRA parameters: {lora_info['total_lora_params']:,}")
print(f"  LoRA ratio: {lora_info['lora_ratio']:.2f}%")
print(f"  Number of LoRA layers: {len(lora_info['lora_layers'])}")

## 6. Set Up Training Arguments

In [None]:
from torch.optim.lr_scheduler import CosineAnnealingLR

# Hyperparameters for the fine-tuning run.
training_config = dict(
    learning_rate=1e-4,
    epochs=5,
    batch_size=cfg['batch_size'],
    weight_decay=1e-5,
    warmup_steps=100,
    save_every_epochs=1,
    use_amp=True,  # Automatic Mixed Precision
)

# Echo the hyperparameters between two horizontal rules.
separator = "=" * 50
print("Training Configuration:")
print(separator)
print("\n".join(f"  {option}: {setting}" for option, setting in training_config.items()))
print(separator)

# Only parameters whose name contains 'lora' are handed to the optimizer.
lora_params = [
    tensor
    for param_name, tensor in model_lora.named_parameters()
    if 'lora' in param_name
]
optimizer = torch.optim.AdamW(
    lora_params,
    lr=training_config['learning_rate'],
    weight_decay=training_config['weight_decay'],
)

# Loss applied to the (logits, labels) pair returned by the model.
criterion = nn.CrossEntropyLoss()

# Cosine decay over the configured number of epochs; the scheduler is stepped
# once per epoch by the training loop.
scheduler = CosineAnnealingLR(
    optimizer,
    T_max=training_config['epochs'],
    eta_min=1e-6,
)

print(f"\n‚úì Optimizer created with {len(lora_params)} parameters")
print(f"‚úì Learning rate scheduler configured")

## 7. Fine-tune the Model

In [None]:
# Training loop: optimise the LoRA parameters for `training_config['epochs']`
# epochs, recording the mean training loss (and, optionally, the mean
# validation loss) of every epoch in `train_losses` / `val_losses`.
print("üöÄ Starting LoRA fine-tuning...")
print("="*60)

# Mixed precision is only enabled when requested AND the run is on CUDA.
# Previously autocast was hard-wired to device_type='cuda' and a CUDA
# GradScaler was created even when `device` was the CPU.
use_amp = bool(training_config['use_amp']) and device.type == 'cuda'
run_validation = True  # Set to False to skip validation

model_lora.train()
# With enabled=False the scaler is a pass-through: scale() returns the loss
# unchanged, step() just calls optimizer.step(), and update() does nothing.
# This lets one code path serve both the AMP and full-precision cases.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Per-epoch metrics, read later by the plotting and checkpointing cells.
train_losses = []
val_losses = []

for epoch in range(training_config['epochs']):
    print(f"\nüìç Epoch {epoch+1}/{training_config['epochs']}")

    # ---- Training phase ----
    epoch_loss = 0.0
    num_batches = 0

    for batch_idx, (im_q, im_k) in enumerate(train_loader):
        im_q = im_q.to(device)
        im_k = im_k.to(device)

        optimizer.zero_grad()

        # Forward pass; autocast does nothing when use_amp is False.
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            logits, labels = model_lora(im_q, im_k)
            loss = criterion(logits, labels)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_batches += 1

        # Print the running mean loss every 10 batches.
        if (batch_idx + 1) % 10 == 0:
            avg_loss = epoch_loss / num_batches
            print(f"  Batch {batch_idx+1}/{len(train_loader)} - Loss: {avg_loss:.4f}")

    # An empty loader would otherwise surface as a bare ZeroDivisionError.
    if num_batches == 0:
        raise RuntimeError("train_loader yielded no batches; cannot compute an epoch loss")

    avg_train_loss = epoch_loss / num_batches
    train_losses.append(avg_train_loss)

    # The scheduler is stepped once per epoch.
    scheduler.step()

    print(f"  ‚úì Epoch {epoch+1} - Train Loss: {avg_train_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.2e}")

    # ---- Validation phase (optional, every epoch) ----
    if run_validation:
        # NOTE(review): if MoCo.forward updates internal state (e.g. a key
        # queue), these validation passes may modify it too - confirm against
        # the model implementation.
        model_lora.eval()
        val_loss = 0.0
        num_val_batches = 0

        with torch.no_grad():
            for im_q, im_k in val_loader:
                im_q = im_q.to(device)
                im_k = im_k.to(device)

                with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                    logits, labels = model_lora(im_q, im_k)
                    loss = criterion(logits, labels)

                val_loss += loss.item()
                num_val_batches += 1

        if num_val_batches:
            avg_val_loss = val_loss / num_val_batches
            val_losses.append(avg_val_loss)
            print(f"  ‚úì Validation Loss: {avg_val_loss:.4f}")
        else:
            # Report the empty loader instead of dividing by zero.
            print("  Validation skipped: val_loader yielded no batches")

        # Back to training mode for the next epoch.
        model_lora.train()

print("\n" + "="*60)
print("‚úÖ Training complete!")

In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch losses; the validation curve is drawn only when
# validation results exist.
fig, ax = plt.subplots(1, 1, figsize=(10, 6))

curves = [(train_losses, 'b-o', 'Training Loss')]
if val_losses:
    curves.append((val_losses, 'r-s', 'Validation Loss'))

for series, line_fmt, legend_label in curves:
    # Epochs are numbered from 1 on the x axis.
    epochs_axis = range(1, len(series)+1)
    ax.plot(epochs_axis, series, line_fmt, label=legend_label, linewidth=2)

ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Loss', fontsize=12)
ax.set_title('LoRA Fine-tuning: Training Progress', fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# Report the lowest loss reached by each curve.
print(f"\nFinal Results:")
best_train = min(train_losses)
print(f"  Best train loss: {best_train:.4f}")
if val_losses:
    best_val = min(val_losses)
    print(f"  Best val loss: {best_val:.4f}")

## 8. Save and Load Fine-tuned Model

In [None]:
# Write the LoRA weights, with a small metadata record, under ./checkpoints/lora.
output_dir = "./checkpoints/lora"
os.makedirs(output_dir, exist_ok=True)

# NOTE: this rebinds `checkpoint_path`, which previously held the base
# (stage-1) checkpoint path, to the LoRA checkpoint file.
checkpoint_path = os.path.join(output_dir, "moco_lora_finetuned.pth")

print("üíæ Saving LoRA checkpoint...")
run_metadata = {
    'lora_rank': lora_config['rank'],
    'lora_alpha': lora_config['alpha'],
    'num_epochs': training_config['epochs'],
    # The last recorded losses, or None when nothing was recorded.
    'final_train_loss': train_losses[-1] if train_losses else None,
    'final_val_loss': val_losses[-1] if val_losses else None,
}
lora_manager.save_lora_checkpoint(checkpoint_path, metadata=run_metadata)

print(f"‚úì Checkpoint saved to {checkpoint_path}")

# Size estimate: parameter count times 4 bytes, expressed in MiB.
# NOTE(review): the factor 4 presumably assumes 32-bit parameters - confirm
# when QLoRA/quantised layers are in use.
base_param_count = sum(p.numel() for p in model.parameters())
lora_param_count = sum(
    p.numel() for name, p in model_lora.named_parameters() if 'lora' in name
)
base_model_size = base_param_count * 4 / (1024**2)
lora_size = lora_param_count * 4 / (1024**2)

print(f"\nüìä Checkpoint Sizes:")
print(f"  Base model: {base_model_size:.2f} MB")
print(f"  LoRA only: {lora_size:.2f} MB")
print(f"  Space saved: {100 * (1 - lora_size/base_model_size):.1f}%")

In [None]:
# Round-trip check: rebuild the model from scratch, restore the base weights,
# re-attach the LoRA adapters and load the saved LoRA checkpoint into them.
print("\nüîÑ Loading LoRA checkpoint into fresh model...")

# Fresh MoCo model with the same hyperparameters as the original one.
fresh_model = MoCo(
    models.resnet50(weights=None),
    proj_dim=cfg["proj_dim"],
    hidden_dim=cfg["hidden_dim"],
    queue_size=cfg["queue_size"],
    momentum=cfg["momentum"],
    temperature=cfg["temperature"],
).to(device)

# `checkpoint_path` points at the LoRA file here, so the base checkpoint is
# located relative to that file's directory. The path is computed once from
# the directory; the earlier `str.replace` on the filename returned the LoRA
# path unchanged whenever the filename differed.
base_checkpoint_path = os.path.join(
    os.path.dirname(checkpoint_path),
    os.pardir,
    "moco_stage1",
    "moco_stage1_epoch_100.pth",
)

if os.path.exists(base_checkpoint_path):
    base_checkpoint = torch.load(base_checkpoint_path, map_location=device)
    fresh_model.load_state_dict(base_checkpoint["model"])
else:
    # Previously this case was silent, so the adapters ended up on top of
    # random weights without any notice.
    print(f"Warning: base checkpoint not found: {base_checkpoint_path}")
    print("   LoRA weights will be applied on top of randomly initialized weights")

# Re-create the adapter layers with the same settings used for training so
# the saved LoRA weights have matching modules to load into.
fresh_model = add_lora_to_model(
    fresh_model,
    r=lora_config['rank'],
    lora_alpha=lora_config['alpha'],
    target_modules=lora_config['target_modules'],
    use_qlora=lora_config['use_qlora']
)

# Load the LoRA weights; the manager returns the metadata stored with them.
fresh_lora_manager = LoRAManager(fresh_model)
metadata = fresh_lora_manager.load_lora_checkpoint(checkpoint_path)

print(f"‚úì Model loaded successfully")
print(f"‚úì Metadata: {metadata}")

# Set to evaluation mode
fresh_model.eval()
print(f"‚úì Model in evaluation mode")

## 9. Evaluate Fine-tuned Model

In [None]:
# Evaluation: Compare base model vs LoRA fine-tuned model
# Print the section banner that precedes the evaluation output.
print("üìä Model Evaluation")
print("="*60)

def evaluate_model(model, val_loader, device, model_name="Model", loss_fn=None):
    """Return the mean loss of ``model`` over all batches of ``val_loader``.

    The model is switched to evaluation mode and left in that mode; gradients
    are disabled for the whole pass.

    Args:
        model: Callable taking ``(im_q, im_k)`` and returning
            ``(logits, labels)``; must provide ``eval()``.
        val_loader: Iterable of ``(im_q, im_k)`` pairs whose elements
            support ``.to(device)``.
        device: Device each batch is moved to before the forward pass.
        model_name: Unused; kept so existing call sites keep working.
        loss_fn: Callable ``(logits, labels) -> loss`` whose result supports
            ``.item()``. Defaults to the module-level ``criterion``, which is
            what this function always used before.

    Returns:
        The sum of per-batch losses divided by the number of batches.

    Raises:
        ValueError: If ``val_loader`` yields no batches.
    """
    if loss_fn is None:
        # Fall back to the notebook-wide loss so old call sites are unchanged.
        loss_fn = criterion

    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for im_q, im_k in val_loader:
            im_q = im_q.to(device)
            im_k = im_k.to(device)

            logits, labels = model(im_q, im_k)
            total_loss += loss_fn(logits, labels).item()
            num_batches += 1

    # An empty loader used to surface as a bare ZeroDivisionError.
    if num_batches == 0:
        raise ValueError("val_loader yielded no batches; cannot compute an average loss")

    return total_loss / num_batches

# Evaluate the base model. The base model is bound to `model`; the previous
# `base_model.eval()` referenced a name that is never defined and raised
# NameError before any evaluation ran.
print(f"\nüîç Evaluating base model...")
model.eval()
base_loss = evaluate_model(model, val_loader, device, "Base Model")
print(f"  Base model val loss: {base_loss:.4f}")

# Evaluate the LoRA fine-tuned model on the same validation data.
print(f"\nüîç Evaluating LoRA fine-tuned model...")
lora_loss = evaluate_model(model_lora, val_loader, device, "LoRA Model")
print(f"  LoRA model val loss: {lora_loss:.4f}")

# Relative change in validation loss, in percent; positive means the LoRA
# model's loss is lower than the base model's.
improvement = ((base_loss - lora_loss) / base_loss * 100)
print(f"\nüìà Comparison:")
print(f"  Base model loss: {base_loss:.4f}")
print(f"  LoRA model loss: {lora_loss:.4f}")
print(f"  Improvement: {improvement:.2f}%")

if improvement > 0:
    print(f"  ‚úì LoRA fine-tuning improved performance!")
else:
    print(f"  ‚ÑπÔ∏è  Further tuning may be needed")

In [None]:
# Closing report: recap the run's settings, saved artifacts and follow-ups.
rule = "=" * 60

print("\n" + rule)
print("‚úÖ LoRA Fine-tuning Summary")
print(rule)

# Headline numbers taken from the configuration dicts and LoRA statistics.
print("\nüìä Key Metrics:")
metric_lines = [
    f"  LoRA Rank: {lora_config['rank']}",
    f"  LoRA Alpha: {lora_config['alpha']}",
    f"  Trainable Parameters: {lora_info['total_lora_params']:,}",
    f"  Efficiency: {lora_info['lora_ratio']:.2f}% of total parameters",
    f"  Training Time: ~{training_config['epochs']} epochs",
    f"  Learning Rate: {training_config['learning_rate']}",
]
for metric_line in metric_lines:
    print(metric_line)

# Where the LoRA checkpoint was written and how large it is.
print("\nüíæ Saved Artifacts:")
print(f"  LoRA Checkpoint: {checkpoint_path}")
print(f"  Size: {lora_size:.2f} MB (vs {base_model_size:.2f} MB base)")

# Suggested follow-up work.
print("\nüöÄ Next Steps:")
for step_line in (
    "  1. Export the LoRA weights to ONNX format",
    "  2. Deploy the model with minimal overhead",
    "  3. Fine-tune on additional tasks by loading this checkpoint",
    "  4. Merge LoRA weights with base model for inference",
):
    print(step_line)

# Quick reference of the LoRA manager helpers.
print("\nüìö Useful Commands:")
for command_line in (
    "  ‚Ä¢ lora_manager.merge_lora_weights() - Merge LoRA into base model",
    "  ‚Ä¢ lora_manager.save_lora_checkpoint() - Save LoRA weights",
    "  ‚Ä¢ lora_manager.load_lora_checkpoint() - Load LoRA weights",
    "  ‚Ä¢ lora_manager.get_lora_info() - Get LoRA statistics",
):
    print(command_line)

print("\n" + rule)