# Train Branched Finetuning

This notebook branches from the main trajectory at specific epochs and continues training with different learning rates.

**Experiment Design:**
- Load the main-trajectory checkpoint for branching epoch `i` (`base_model.pt` when `i = 0`, otherwise `main_epoch{i}.pt`)
- Train 5 variants with different learning rates
- Each variant trains for 10 epochs at a constant learning rate (no scheduler)
- Save checkpoints at every epoch

**Outputs:**
- `checkpoints/branch_e{i}_lr{j}_e{k}.pt` where:
  - `i` = branching epoch from main trajectory
  - `j` = learning-rate variant ID (1-5, the `id` field of `LR_VARIANTS`)
  - `k` = epochs trained after branching (1-10)

## Setup Environment

In [None]:
# Switch between a local checkout (True) and a Colab runtime (False).
LOCAL = True

# Every other path hangs off the repository root, so only the root differs
# between the two environments.
PROJECT_ROOT = (
    "/Users/Yang/Desktop/research-model-merge"
    if LOCAL
    else "/content/research-model-merge"
)
ROOT_DIR = f"{PROJECT_ROOT}/playground/merge_soup-resnet18-cifar10"
DATA_DIR = f"{PROJECT_ROOT}/datasets"

if not LOCAL:
    # Shared Google Drive folder; this name is only defined on Colab.
    DRIVE_DIR = "/content/drive/MyDrive/research-model_merge-shared/merge_soup-resnet18-cifar10"

### Mount Google Drive (Colab only)

In [None]:
# Only the Colab configuration mounts Google Drive; local runs skip this cell's work.
if not LOCAL:
    from google.colab import drive as colab_drive

    colab_drive.mount('/content/drive')

### Clone Repository and Install Dependencies (Colab only)

In [None]:
if not LOCAL:
    !rm -rf research-model-merge
    !git clone https://github.com/nbzy1995/research-model-merge.git /content/research-model-merge
    !pip install --quiet --upgrade pip
    !pip install -q -r research-model-merge/requirements.txt
    print("✅ Repository cloned and dependencies installed!")

### Import Libraries

In [None]:
import os
import sys
import subprocess
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt
import pandas as pd

# Make the repository root and this experiment's folder importable.
# Each missing entry is pushed to the front of sys.path in turn, so when both
# are inserted ROOT_DIR ends up ahead of PROJECT_ROOT.
for _import_root in (PROJECT_ROOT, ROOT_DIR):
    if _import_root not in sys.path:
        sys.path.insert(0, _import_root)

from datasets.cifar10 import CIFAR10
from utils import create_cifar10_resnet18, train_model, load_checkpoint

### Check Device Information

In [None]:
# Report the interpreter / PyTorch / CUDA setup and pick DEVICE for training.
print("🔍 System Information:")
# Read the version of the running kernel directly. Spawning `python --version`
# reports whatever `python` is on PATH, which may be a different interpreter
# than this kernel, and raises if no such executable exists.
print(f"Python version: {sys.version.split()[0]}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"CUDA version: {torch.version.cuda}")
    DEVICE = torch.device("cuda")
else:
    # No CUDA device: warn (the wording depends on the environment) and fall back to CPU.
    if LOCAL:
        print("⚠️ No GPU available! Training will be slow on CPU.")
    else:
        print("❌ No GPU available! Please enable GPU runtime in Colab.")
        print("Runtime > Change runtime type > Hardware accelerator > GPU")
    DEVICE = torch.device("cpu")

## Prepare Dataset

In [None]:
# Build the CIFAR10 wrapper once and pull out the three loaders it exposes.
dataset = CIFAR10(data_location=DATA_DIR, batch_size=256, num_workers=2)

train_loader, val_loader, test_loader = (
    dataset.train_loader,
    dataset.val_loader,
    dataset.test_loader,
)

# Sizes are taken from the train/val samplers and from the test dataset itself.
print(f"✅ Dataset loaded:")
print(f"   Train samples: {len(dataset.train_sampler)}")
print(f"   Val samples: {len(dataset.val_sampler)}")
print(f"   Test samples: {len(dataset.test_dataset)}")

## Branching Configuration

In [None]:
# Epochs of the main trajectory from which a branch is started
BRANCHING_EPOCHS = [0, 2, 5, 8]

# One entry per learning-rate variant: 'id' is the 1-based position in the list,
# 'lr' is the rate itself and 'name' joins the two as "lr<id>_<rate>".
LR_VARIANTS = [
    {'id': variant_id, 'lr': rate, 'name': f'lr{variant_id}_{rate}'}
    for variant_id, rate in enumerate([0.1, 0.05, 0.01, 0.005, 0.001], start=1)
]

# Hyperparameters shared by every branch
BRANCH_EPOCHS = 10  # number of epochs trained after the branching point
WEIGHT_DECAY = 1e-4
MOMENTUM = 0.9

# Checkpoints live under the experiment folder locally and under Drive on Colab.
# The conditional is lazy, so DRIVE_DIR is only read when LOCAL is False.
CHECKPOINT_DIR = f"{ROOT_DIR if LOCAL else DRIVE_DIR}/checkpoints"

print(f"Branching Configuration:")
print(f"  Branching Epochs: {BRANCHING_EPOCHS}")
print(f"  LR Variants: {len(LR_VARIANTS)}")
print(f"  Epochs per Branch: {BRANCH_EPOCHS}")
print(f"  Total Models to Train: {len(BRANCHING_EPOCHS) * len(LR_VARIANTS)}")
print(f"  Checkpoint Dir: {CHECKPOINT_DIR}")

## Train Branched Variants

For each branching point, we load the checkpoint from the main trajectory and train 5 variants with different learning rates.

In [None]:
# One record per (branching epoch, learning-rate variant) run; the summary and
# plotting cells further down read this list.
all_results = []

for branch_epoch in BRANCHING_EPOCHS:
    print(f"\n{'#'*80}")
    print(f"BRANCHING FROM EPOCH {branch_epoch}")
    print(f"{'#'*80}\n")

    # Epoch 0 starts from the base model file; any other epoch starts from the
    # matching main-trajectory file.
    source_file = "base_model.pt" if branch_epoch == 0 else f"main_epoch{branch_epoch}.pt"
    checkpoint_path = os.path.join(CHECKPOINT_DIR, source_file)

    print(f"Loading checkpoint: {checkpoint_path}")

    # A missing starting checkpoint skips the whole branching epoch instead of failing.
    if not os.path.exists(checkpoint_path):
        print(f"❌ Checkpoint not found: {checkpoint_path}")
        print(f"   Please run train_main_trajectory.ipynb first!")
        continue

    for variant in LR_VARIANTS:
        lr_id, lr, lr_name = variant['id'], variant['lr'], variant['name']

        print(f"\n{'='*80}")
        print(f"Training Branch: epoch{branch_epoch} -> {lr_name}")
        print(f"{'='*80}\n")

        # Every variant gets a freshly built model initialised from the branch checkpoint.
        model = create_cifar10_resnet18(num_classes=10)
        model.load_state_dict(load_checkpoint(checkpoint_path, device=DEVICE))
        model = model.to(DEVICE)

        # SGD with this variant's learning rate; momentum and weight decay are shared.
        optimizer = optim.SGD(
            model.parameters(),
            lr=lr,
            momentum=MOMENTUM,
            weight_decay=WEIGHT_DECAY,
        )
        criterion = nn.CrossEntropyLoss()

        # "{epoch}" is left as a literal placeholder in the template handed to train_model.
        checkpoint_template = f"branch_e{branch_epoch}_lr{lr_id}_e{{epoch}}.pt"

        history = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            lr_scheduler=None,  # constant learning rate, no scheduler
            criterion=criterion,
            epochs=BRANCH_EPOCHS,
            device=DEVICE,
            checkpoint_dir=CHECKPOINT_DIR,
            checkpoint_name_template=checkpoint_template,
            log_interval=20,
            save_epoch_0=False,
        )

        # Keep the full history plus the last and the highest validation accuracy.
        val_acc_curve = history['val_acc']
        run_record = {
            'branch_epoch': branch_epoch,
            'lr_id': lr_id,
            'lr': lr,
            'lr_name': lr_name,
            'history': history,
            'final_val_acc': val_acc_curve[-1],
            'best_val_acc': max(val_acc_curve),
        }
        all_results.append(run_record)

        print(f"\n✅ Completed: epoch{branch_epoch} -> {lr_name}")
        print(f"   Final Val Acc: {100*run_record['final_val_acc']:.2f}%")
        print(f"   Best Val Acc: {100*run_record['best_val_acc']:.2f}%")

print(f"\n{'='*80}")
print(f"ALL BRANCHED TRAINING COMPLETED!")
print(f"{'='*80}")

## Save Training Summary

In [None]:
# Flatten each run into one table row. The identifying fields are copied as-is
# and the two accuracies are rendered as percentage strings with two decimals.
summary_data = [
    {
        **{field: run[field] for field in ('branch_epoch', 'lr_id', 'lr', 'lr_name')},
        'final_val_acc': f"{100*run['final_val_acc']:.2f}%",
        'best_val_acc': f"{100*run['best_val_acc']:.2f}%",
    }
    for run in all_results
]

df_summary = pd.DataFrame(summary_data)

# Write the table next to the checkpoints
summary_path = os.path.join(CHECKPOINT_DIR, "branched_training_summary.csv")
df_summary.to_csv(summary_path, index=False)

# Echo the same table to the cell output between two rules
print("\n" + "="*80)
print("Branched Training Summary")
print("="*80)
print(df_summary.to_string(index=False))
print("="*80)
print(f"\n✅ Summary saved to {summary_path}")

## Visualize Training Curves by Branching Epoch

In [None]:
# X axis shared by every curve: epochs trained after branching, numbered from 1.
# It does not depend on the branching epoch, so it is built once up front.
epochs_range = range(1, BRANCH_EPOCHS + 1)

# Create one two-panel figure per branching epoch
for branch_epoch in BRANCHING_EPOCHS:
    # Filter results for this branching epoch
    branch_results = [r for r in all_results if r['branch_epoch'] == branch_epoch]

    # Branching epochs with no recorded runs get no figure.
    if not branch_results:
        continue

    fig, axes = plt.subplots(1, 2, figsize=(16, 5))

    # Plot 1: Training and Validation Loss (combined)
    for r in branch_results:
        # Plot train loss with solid line
        axes[0].plot(epochs_range, r['history']['train_loss'],
                    linestyle='-', marker='o', alpha=0.7, label=f"{r['lr_name']} (train)")
        # Plot val loss with dashed line
        axes[0].plot(epochs_range, r['history']['val_loss'],
                    linestyle='--', marker='s', alpha=0.7, label=f"{r['lr_name']} (val)")
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title(f'Training and Validation Loss (Branched from Epoch {branch_epoch})')
    axes[0].legend(fontsize=8, ncol=2)
    axes[0].grid(True)

    # Plot 2: Validation Accuracy (stored values are scaled by 100 for display)
    for r in branch_results:
        axes[1].plot(epochs_range, [100*x for x in r['history']['val_acc']],
                    marker='o', label=r['lr_name'])
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Validation Accuracy (%)')
    axes[1].set_title(f'Validation Accuracy (Branched from Epoch {branch_epoch})')
    axes[1].legend()
    axes[1].grid(True)

    plt.tight_layout()
    plot_path = os.path.join(CHECKPOINT_DIR, f"branch_e{branch_epoch}_curves.png")
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure once it has been saved and shown, so figures created
    # in this loop do not pile up in memory.
    plt.close(fig)

    print(f"✅ Plot saved: {plot_path}")

## Experiment Summary

In [None]:
# Closing report: how many runs were recorded, where the checkpoints are, and what to run next.
_rule = "="*80

print("\n" + _rule)
print("Branched Finetuning Experiment Complete!")
print(_rule)

print(f"\nTotal models trained: {len(all_results)}")
print(f"Branching epochs: {BRANCHING_EPOCHS}")
print(f"LR variants per branch: {len(LR_VARIANTS)}")
print(f"\nCheckpoints saved in: {CHECKPOINT_DIR}")
# The doubled braces print literal {placeholders} describing the file-name pattern.
print(f"\nCheckpoint naming format: branch_e{{branch_epoch}}_lr{{lr_id}}_e{{epoch}}.pt")
print(_rule)

print(f"\n✅ Ready for soup analysis!")
print(f"   Next: Run analyze_branching_soups.ipynb")
print(_rule)