# Train Main Trajectory

This notebook trains a single ResNet18 model on CIFAR-10 for 10 epochs.
This main trajectory will serve as the starting point for branched finetuning experiments.

**Outputs:**
- `checkpoints/base_model.pt` - Initial model (epoch 0)
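- `checkpoints/main_epoch{i}.pt` - Checkpoints at epochs 1-10
- `checkpoints/main_trajectory_curves.png` - Training curves (loss, validation accuracy, learning rate)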

## Setup Environment

In [1]:
import os  # needed for the optional path override; the main import cell runs later

# Toggle between a local checkout (True) and a Google Colab runtime (False).
LOCAL = True

if LOCAL:
    # Root of the local repository checkout. The default is the original
    # author's machine; set the RESEARCH_MODEL_MERGE_ROOT environment variable
    # to run from a different checkout without editing this cell.
    PROJECT_ROOT = os.environ.get(
        "RESEARCH_MODEL_MERGE_ROOT",
        "/Users/Yang/Desktop/research-model-merge",
    )
else:
    # on Colab: the repository is cloned to this location by the clone cell
    PROJECT_ROOT = "/content/research-model-merge"
    # Shared Google Drive folder; the checkpoint directory is placed here on Colab.
    DRIVE_DIR = "/content/drive/MyDrive/research-model_merge-shared/merge_soup-resnet18-cifar10"

# Both directories live inside the repository in either environment, so they
# are derived from PROJECT_ROOT instead of repeating the root in every literal.
ROOT_DIR = os.path.join(PROJECT_ROOT, "playground", "merge_soup-resnet18-cifar10")
DATA_DIR = os.path.join(PROJECT_ROOT, "datasets")

### Mount Google Drive (Colab only)

In [2]:
# Google Drive is only mounted on Colab; a local run skips this cell entirely.
if not LOCAL:
    from google.colab import drive

    drive_mount_point = '/content/drive'
    drive.mount(drive_mount_point)

### Clone Repository and Install Dependencies (Colab only)

In [3]:
if not LOCAL:
    # !rm -rf research-model-merge
    !git clone https://github.com/nbzy1995/research-model-merge.git /content/research-model-merge
    !pip install --quiet --upgrade pip
    !pip install -q -r research-model-merge/requirements.txt
    print("✅ Repository cloned and dependencies installed!")

### Import Libraries

In [4]:
import os
import sys
import subprocess

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import matplotlib.pyplot as plt

# Make the repository root and the experiment folder importable.
# Each entry is pushed to the front, so ROOT_DIR (inserted last) ends up
# ahead of PROJECT_ROOT on the search path.
for project_dir in (PROJECT_ROOT, ROOT_DIR):
    if project_dir not in sys.path:
        sys.path.insert(0, project_dir)

from datasets.cifar10 import CIFAR10
from utils import create_cifar10_resnet18, train_model, save_checkpoint

### Check Device Information

In [5]:
# Report the interpreter / PyTorch / accelerator in use and select DEVICE.
print("🔍 System Information:")
# Read the version from the running kernel. Spawning `python --version`
# reports whichever `python` is first on PATH (possibly a different
# interpreter than this kernel) and raises if no such executable exists.
print(f"Python version: Python {sys.version.split()[0]}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"GPU device: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
    print(f"CUDA version: {torch.version.cuda}")
    DEVICE = torch.device("cuda")
else:
    # No CUDA device: warn (with Colab-specific instructions when applicable)
    # and fall back to the CPU so the rest of the notebook still runs.
    if LOCAL:
        print("⚠️ No GPU available! Training will be slow on CPU.")
    else:
        print("❌ No GPU available! Please enable GPU runtime in Colab.")
        print("Runtime > Change runtime type > Hardware accelerator > GPU")
    DEVICE = torch.device("cpu")

🔍 System Information:
Python version: Python 3.11.5
PyTorch version: 2.3.0
CUDA available: False
⚠️ No GPU available! Training will be slow on CPU.


## Prepare Dataset

In [6]:
# Build the CIFAR-10 wrapper once; the three loaders below are read from it.
dataset = CIFAR10(data_location=DATA_DIR, batch_size=256, num_workers=2)

train_loader, val_loader, test_loader = (
    dataset.train_loader,
    dataset.val_loader,
    dataset.test_loader,
)

# Split sizes: train/val are measured on their samplers, test on the dataset.
print("✅ Dataset loaded:")
for split_name, sized_attr in (
    ("Train", "train_sampler"),
    ("Val", "val_sampler"),
    ("Test", "test_dataset"),
):
    print(f"   {split_name} samples: {len(getattr(dataset, sized_attr))}")

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
✅ Dataset loaded:
   Train samples: 49000
   Val samples: 1000
   Test samples: 10000


## Training Configuration

In [None]:
# Seed handed to the model factory so the initial weights are repeatable.
SEED = 42

# Optimisation hyperparameters
EPOCHS = 10
# NOTE(review): the saved outputs of this notebook show a learning rate of 0.1,
# not 0.01 — confirm which value is intended and re-run to refresh the outputs.
LEARNING_RATE = 0.01
MOMENTUM = 0.9
WEIGHT_DECAY = 1e-4
# Despite the name, this is used as the StepLR step_size in the scheduler cell.
WARMUP_EPOCHS = 5

# Checkpoints live inside the experiment folder locally and on Drive on Colab.
checkpoint_parent = ROOT_DIR if LOCAL else DRIVE_DIR
CHECKPOINT_DIR = f"{checkpoint_parent}/checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Echo the configuration so it is recorded alongside the run's output.
config_summary = {
    "Seed": SEED,
    "Epochs": EPOCHS,
    "Learning Rate": LEARNING_RATE,
    "Weight Decay": WEIGHT_DECAY,
    "Momentum": MOMENTUM,
    "Warmup Epochs": WARMUP_EPOCHS,
    "Checkpoint Dir": CHECKPOINT_DIR,
}
print("Training Configuration:")
for setting_name, setting_value in config_summary.items():
    print(f"  {setting_name}: {setting_value}")

Training Configuration:
  Seed: 42
  Epochs: 10
  Learning Rate: 0.1
  Weight Decay: 0.0001
  Momentum: 0.9
  Warmup Epochs: 5
  Checkpoint Dir: /Users/Yang/Desktop/research-model-merge/playground/merge_soup-resnet18-cifar10/checkpoints


## Create Model

In [8]:
# Instantiate the network (SEED is passed to the factory) and move it to DEVICE.
model = create_cifar10_resnet18(num_classes=10, seed=SEED).to(DEVICE)

# Tally parameter counts in a single pass over the model's parameters.
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print("✅ Model created:")
print(f"   Total parameters: {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")

# Store the untrained weights as the epoch-0 reference checkpoint.
base_model_path = os.path.join(CHECKPOINT_DIR, "base_model.pt")
save_checkpoint(model, base_model_path)
print(f"\n✅ Saved base model: {base_model_path}")

✅ Model created:
   Total parameters: 11,173,962
   Trainable parameters: 11,173,962

✅ Saved base model: /Users/Yang/Desktop/research-model-merge/playground/merge_soup-resnet18-cifar10/checkpoints/base_model.pt


## Setup Optimizer and Scheduler

In [9]:
# SGD with momentum and weight decay, configured from the constants set earlier.
sgd_settings = {
    "lr": LEARNING_RATE,
    "momentum": MOMENTUM,
    "weight_decay": WEIGHT_DECAY,
}
optimizer = optim.SGD(model.parameters(), **sgd_settings)

# StepLR multiplies the learning rate by gamma (0.1) every `step_size` epochs;
# WARMUP_EPOCHS is that step size here, not a warmup ramp.
lr_scheduler = StepLR(optimizer, step_size=WARMUP_EPOCHS, gamma=0.1)

# Classification loss handed to the training loop.
criterion = nn.CrossEntropyLoss()

print("✅ Optimizer and scheduler configured")

✅ Optimizer and scheduler configured


## Train Main Trajectory

In [None]:
banner = "=" * 80

print(f"\n{banner}")
print("Starting Main Trajectory Training")
print(f"{banner}\n")

# Collect the training arguments in one place before launching the run.
training_args = dict(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    criterion=criterion,
    epochs=EPOCHS,
    device=DEVICE,
    checkpoint_dir=CHECKPOINT_DIR,
    checkpoint_name_template="main_epoch{epoch}.pt",
    log_interval=20,
    save_epoch_0=False,  # Already saved as base_model.pt
)
# `history` is read by the plotting and summary cells.
history = train_model(**training_args)

print(f"\n{banner}")
print("Main Trajectory Training Completed!")
print(banner)


Starting Main Trajectory Training



Epoch 1/10 [Train]:   4%|▍         | 8/192 [00:51<19:03,  6.22s/it, loss=2.4792, lr=0.100000]

## Plot Training Curves

In [None]:
# Three side-by-side panels: loss curves, validation accuracy, learning rate.
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
loss_ax, acc_ax, lr_ax = axes

# Build the x-axis from what was actually recorded rather than from the EPOCHS
# constant, so the plot still works if EPOCHS was edited after training.
# All history series are assumed to have the same length (as before).
epochs_range = range(1, len(history['train_loss']) + 1)

# Plot 1: Training and Validation Loss (combined)
loss_ax.plot(epochs_range, history['train_loss'], 'b-o', label='Train Loss')
loss_ax.plot(epochs_range, history['val_loss'], 'r-o', label='Val Loss')

# Plot 2: Validation Accuracy, scaled by 100 for display
# (presumably stored as a fraction in [0, 1] — verify against train_model).
acc_ax.plot(epochs_range, [100*x for x in history['val_acc']], 'g-o', label='Val Accuracy')

# Plot 3: Learning Rate, on a log axis so step decays stay readable
lr_ax.plot(epochs_range, history['lr'], 'm-o', label='Learning Rate')
lr_ax.set_yscale('log')

# Shared decoration, applied after plotting so each legend picks up its labels.
panel_labels = (
    (loss_ax, 'Loss', 'Training and Validation Loss'),
    (acc_ax, 'Accuracy (%)', 'Validation Accuracy'),
    (lr_ax, 'Learning Rate', 'Learning Rate Schedule'),
)
for ax, y_label, panel_title in panel_labels:
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.set_title(panel_title)
    ax.legend()
    ax.grid(True)

plt.tight_layout()

# Single source of truth for the output path (used for saving and reporting).
curves_path = f"{CHECKPOINT_DIR}/main_trajectory_curves.png"
plt.savefig(curves_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"✅ Training curves saved to {curves_path}")

## Summary

In [None]:
# Print the final and best metrics recorded in `history`, then list the
# checkpoint file names this notebook is expected to have written.
rule = "=" * 80

print("\n" + rule)
print("Main Trajectory Training Summary")
print(rule)
print(f"Final Train Loss: {history['train_loss'][-1]:.4f}")
print(f"Final Val Loss: {history['val_loss'][-1]:.4f}")

val_acc = history['val_acc']
print(f"Final Val Accuracy: {100*val_acc[-1]:.2f}%")

# Best epoch is 1-based and refers to the first occurrence of the maximum.
best_val_acc = max(val_acc)
best_epoch = val_acc.index(best_val_acc) + 1
print(f"Best Val Accuracy: {100*best_val_acc:.2f}% (Epoch {best_epoch})")
print(rule)

print("\n✅ Checkpoints saved:")
print("   - base_model.pt (epoch 0)")
for epoch_idx in range(1, EPOCHS + 1):
    print(f"   - main_epoch{epoch_idx}.pt")
print(rule)