# Resume Base Model Training from Checkpoint

This notebook helps you resume training from an existing base model checkpoint with configurable parameters.

## Features
- Resume from existing checkpoint
- Configurable training parameters
- Configurable input/output directories
- Automatic checkpoint compatibility fixing
- Progress monitoring


## 1. Configuration Parameters

Customize all parameters below before starting training.


In [None]:
# ============================================================================
# CONFIGURABLE PARAMETERS
# ============================================================================
# Notebook-wide settings: the cells below read these module-level constants
# directly, so re-run this cell (and the cells after it) after any change.

# --- Input/Output Directories ---
CHECKPOINT_TO_RESUME = "checkpoints/base/base_model_last.ckpt"  # Path to existing checkpoint
OUTPUT_DIR = "checkpoints/base_training"  # Where to save new checkpoints
CONFIG_DIR = "configs"  # Directory containing config files (data/ and training/ subfolders are read)

# --- Training Configuration ---
# NUM_EPOCHS is handed to the Trainer as `max_epochs`.
# NOTE(review): that looks like a total epoch ceiling rather than "extra epochs
# after the checkpoint" - confirm it exceeds the epoch stored in the checkpoint.
NUM_EPOCHS = 3  # Number of epochs to train
LEARNING_RATE = 0.001  # Learning rate (passed to the model constructor)
# NOTE(review): BATCH_SIZE is not referenced by any other cell shown in this
# notebook - presumably the batch size comes from the dataset YAML; verify.
BATCH_SIZE = 128  # Batch size (if overriding config)

# --- Dataset Configuration ---
# Known names map to a YAML file under CONFIG_DIR/data; any other value falls
# back to "<lower-cased name>.yaml".
DATASET = "CIFAR"  # Options: 'CIFAR', 'CIFAR100', 'CIFAR100_TREES', 'CIFAR100_ANIMALS', 'CIFAR100_VEHICLES', 'IMAGENETTE'

# --- Model Architecture ---
# NOTE(review): these presumably have to match the architecture the checkpoint
# was trained with, otherwise loading its weights will fail - confirm.
D_MODEL = 128  # Model dimension
NUM_HEADS = 2  # Number of attention heads
DROPOUT = 0.1  # Dropout rate
D_FF = 256  # Feed-forward dimension
IMG_SIZE = 32  # Input image size
DENOISING_STEPS = 500  # Number of denoising steps

# --- Training Settings ---
EXPERIMENT_NAME = "diffit_base_resumed"  # Name for this training run
SAVE_EVERY_N_EPOCHS = 1  # Save checkpoint every N epochs
KEEP_LAST_N_CHECKPOINTS = 3  # Keep only last N checkpoints
GRADIENT_CLIP_VAL = 1.0  # Gradient clipping value
LOG_EVERY_N_STEPS = 25  # Log metrics every N steps

# --- Hardware Settings ---
# These three values are forwarded unchanged to pl.Trainer.
ACCELERATOR = "auto"  # Options: 'auto', 'gpu', 'cpu', 'tpu'
DEVICES = "auto"  # Options: 'auto', 1, 2, [0,1], etc.
PRECISION = "32-true"  # Options: '32-true', '16-mixed', 'bf16-mixed'

# --- Advanced Options ---
VERBOSE = True  # Print detailed progress information (forwarded to the checkpoint callback)
FIX_CHECKPOINT_COMPATIBILITY = True  # Automatically fix checkpoint compatibility issues


## 2. Setup and Imports


In [None]:
import sys
from pathlib import Path
import yaml
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor

# Add project to path
# The current working directory is treated as the project root, so the notebook
# is expected to be started from the directory that contains the `diffit` package.
project_root = Path.cwd()
if str(project_root) not in sys.path:
    # Prepend (rather than append) so the working-directory copy is found first.
    sys.path.insert(0, str(project_root))

# Project-local imports; placed after the sys.path tweak so they can resolve
# from the working directory. ResumeBaseModelCallback is imported here but is
# not used by the cells visible in this notebook.
from diffit.training.base_checkpoint_callbacks import (
    BaseModelCheckpointCallback,
    ResumeBaseModelCallback
)
from diffit.training.data import DiffiTDataModule
from diffit.models.unet import UShapedNetwork

# Echo the environment so a run's output records which versions produced it.
print("‚úÖ Imports successful!")
print(f"üìÅ Project root: {project_root}")
print(f"üîß PyTorch Lightning version: {pl.__version__}")
print(f"üîß PyTorch version: {torch.__version__}")


## 3. Verify Checkpoint and Configuration


In [None]:
# Pre-flight check: confirm that the checkpoint and the YAML configs exist
# before any expensive work starts. Missing required files abort the cell.
print("üîç Verifying checkpoint and configuration...\n")
print("=" * 70)

# --- Checkpoint (required) ---
checkpoint_path = Path(CHECKPOINT_TO_RESUME)
if not checkpoint_path.exists():
    print(f"‚ùå ERROR: Checkpoint not found at {CHECKPOINT_TO_RESUME}")
    print("   Please update CHECKPOINT_TO_RESUME variable")
    raise FileNotFoundError(f"Checkpoint not found: {CHECKPOINT_TO_RESUME}")

file_size_mb = checkpoint_path.stat().st_size / (1024 * 1024)
print(f"‚úÖ Checkpoint found: {CHECKPOINT_TO_RESUME}")
print(f"   Size: {file_size_mb:.2f} MB")

# --- Dataset config (required) ---
# Known dataset names map to a specific YAML file; anything else falls back
# to "<lower-cased dataset name>.yaml".
dataset_config_map = {
    'CIFAR': 'cifar10.yaml',
    'CIFAR100': 'cifar100.yaml',
    'CIFAR100_TREES': 'cifar100_trees_only.yaml',
    'CIFAR100_ANIMALS': 'cifar100_animals_only.yaml',
    'CIFAR100_VEHICLES': 'cifar100_vehicles_only.yaml',
    'IMAGENETTE': 'imagenette.yaml'
}
data_config_file = dataset_config_map.get(DATASET, f"{DATASET.lower()}.yaml")
data_config_path = Path(CONFIG_DIR, "data", data_config_file)

if not data_config_path.exists():
    print(f"\n‚ùå ERROR: Dataset config not found at {data_config_path}")
    raise FileNotFoundError(f"Dataset config not found: {data_config_path}")
print(f"\n‚úÖ Dataset config found: {data_config_path}")

# --- Training config (optional) ---
# Its absence is only a warning: the notebook parameters are used instead.
training_config_path = Path(CONFIG_DIR, "training", "base_training.yaml")
if not training_config_path.exists():
    print(f"‚ö†Ô∏è  Warning: Training config not found at {training_config_path}")
    print("   Will use notebook parameters only")
else:
    print(f"‚úÖ Training config found: {training_config_path}")

print("\n" + "=" * 70)


## 4. Load Configuration Files


In [None]:
# Load the optional training YAML and the required dataset YAML, overlay the
# notebook parameters on the training settings, and print the result.
#
# Raises:
#     ValueError: if the dataset config file is empty.
print("üìã Loading configuration files...\n")

# Load training config (if exists). Values set in the notebook take
# precedence over anything read from the file (see the update() below).
training_config = {}
if training_config_path.exists():
    with open(training_config_path, 'r', encoding='utf-8') as f:
        # safe_load returns None for an empty file; normalise to a dict so
        # the lookup below cannot fail with AttributeError.
        config_data = yaml.safe_load(f) or {}
    # A bare `training:` key parses as None, so fall back to {} there too.
    training_config = config_data.get('training') or {}
    print(f"‚úÖ Loaded training config from {training_config_path}")

# Override with notebook parameters
# NOTE(review): BATCH_SIZE is not applied here or anywhere else in the visible
# cells - confirm whether it should override the dataset config.
training_config.update({
    'num_epochs': NUM_EPOCHS,
    'learning_rate': LEARNING_RATE,
    'dataset': DATASET,
    'accelerator': ACCELERATOR,
    'devices': DEVICES,
    'precision': PRECISION,
    'gradient_clip_val': GRADIENT_CLIP_VAL,
    'log_every_n_steps': LOG_EVERY_N_STEPS
})

# Load data config
with open(data_config_path, 'r', encoding='utf-8') as f:
    data_config = yaml.safe_load(f)
if data_config is None:
    # Fail here with a clear message instead of passing None on to the
    # data module and getting an obscure error later.
    raise ValueError(f"Dataset config is empty: {data_config_path}")

print(f"‚úÖ Loaded dataset config from {data_config_path}")

# Display final configuration
print("\n" + "=" * 70)
print("üìä Final Training Configuration:")
print("=" * 70)
print(f"Epochs:              {training_config['num_epochs']}")
print(f"Learning Rate:       {training_config['learning_rate']}")
print(f"Dataset:             {training_config['dataset']}")
print(f"Accelerator:         {training_config['accelerator']}")
print(f"Devices:             {training_config['devices']}")
print(f"Precision:           {training_config['precision']}")
print(f"Gradient Clip:       {training_config['gradient_clip_val']}")
print("\nModel Architecture:")
print(f"D Model:             {D_MODEL}")
print(f"Num Heads:           {NUM_HEADS}")
print(f"Dropout:             {DROPOUT}")
print(f"Feed Forward Dim:    {D_FF}")
print(f"Image Size:          {IMG_SIZE}")
print(f"Denoising Steps:     {DENOISING_STEPS}")
print("=" * 70)


## 5. Setup Data Module


In [None]:
# Build the data module from the dataset YAML parsed in the previous cell.
print("üì¶ Setting up data module...\n")

data_module = DiffiTDataModule(data_config)

# Report which dataset and which config file the module was built from.
for status_line in (
    f"‚úÖ Data module created for {DATASET}",
    f"   Config: {data_config_path.name}",
):
    print(status_line)


## 6. Initialize Model


In [None]:
# Instantiate the network with the notebook's architecture settings and
# report its parameter count.
print("ü§ñ Initializing model...\n")

# Determine device
# Honour an explicit ACCELERATOR = 'cpu': the Trainer is configured with the
# same setting, so handing the model a CUDA device in that case would leave
# the two disagreeing. Every other ACCELERATOR value keeps the previous
# behaviour (CUDA when available, otherwise CPU).
device = torch.device(
    'cuda' if ACCELERATOR != 'cpu' and torch.cuda.is_available() else 'cpu'
)
print(f"üñ•Ô∏è  Device: {device}")
if device.type == 'cuda':
    # Only GPU index 0 is reported here, whatever DEVICES selects.
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.2f} GB")

# Initialize model
# NOTE(review): how UShapedNetwork uses `device` is not visible from this
# notebook - confirm it is only used to place tensors the model creates.
model = UShapedNetwork(
    learning_rate=LEARNING_RATE,
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    dropout=DROPOUT,
    d_ff=D_FF,
    img_size=IMG_SIZE,
    device=device,
    denoising_steps=DENOISING_STEPS
)

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("\n‚úÖ Model initialized successfully!")
print(f"   Total parameters:     {total_params:,}")
print(f"   Trainable parameters: {trainable_params:,}")
# Size estimate assumes 4 bytes per parameter (FP32), as the label says.
print(f"   Model size:           ~{total_params * 4 / 1024**2:.2f} MB (FP32)")


## 7. Fix Checkpoint Compatibility (if needed)


In [None]:
# Inspect the checkpoint and, when enabled, patch in the keys Lightning needs
# to resume from it. The patched copy is written to a temporary file; the
# original checkpoint is never modified.
print("üîß Checking checkpoint compatibility...\n")

# Load checkpoint to inspect
# NOTE(review): torch.load unpickles the file, so only resume from checkpoints
# you trust. Recent PyTorch releases default to weights_only=True, which may
# reject checkpoints holding non-tensor objects - confirm against the
# installed version.
checkpoint = torch.load(CHECKPOINT_TO_RESUME, map_location='cpu')
starting_epoch = checkpoint.get('epoch', 0)

print(f"üìä Checkpoint info:")
print(f"   Current epoch: {starting_epoch}")
print(f"   Will resume from epoch: {starting_epoch + 1}")

# Fix checkpoint compatibility if needed
checkpoint_fixed = False
checkpoint_to_use = CHECKPOINT_TO_RESUME

if FIX_CHECKPOINT_COMPATIBILITY:
    if 'pytorch-lightning_version' not in checkpoint:
        print(f"\nüîß Adding PyTorch Lightning version: {pl.__version__}")
        checkpoint['pytorch-lightning_version'] = pl.__version__
        checkpoint_fixed = True

    # Add empty optimizer and scheduler states if missing
    if 'optimizer_states' not in checkpoint:
        print("üîß Adding missing optimizer/scheduler states (weights-only checkpoint)")
        checkpoint['optimizer_states'] = []
        checkpoint['lr_schedulers'] = []
        checkpoint_fixed = True
    elif 'lr_schedulers' not in checkpoint:
        # Lightning's resume path reads both keys when restoring training
        # state, so a checkpoint with optimizer state but no scheduler entry
        # needs the scheduler key as well.
        print("üîß Adding missing scheduler states")
        checkpoint['lr_schedulers'] = []
        checkpoint_fixed = True

    if checkpoint_fixed:
        # Save fixed checkpoint to temporary location (current working
        # directory); the cleanup cell removes it after training.
        temp_checkpoint = "temp_fixed_checkpoint.ckpt"
        torch.save(checkpoint, temp_checkpoint)
        checkpoint_to_use = temp_checkpoint
        print(f"‚úÖ Fixed checkpoint saved to: {temp_checkpoint}")
    else:
        print("‚úÖ Checkpoint is compatible, no fixes needed")
else:
    print("‚ö†Ô∏è  Checkpoint compatibility fixing is disabled")

print(f"\nüì• Will resume from checkpoint: {checkpoint_to_use}")


In [None]:
# Wire up the checkpoint callback, the learning-rate monitor and the Trainer.
print("‚öôÔ∏è  Setting up trainer and callbacks...\n")

# Checkpoint writer for this run; run_number=None asks the callback to create
# a new run (per the project's own comment on this argument).
checkpoint_callback = BaseModelCheckpointCallback(
    base_dir=OUTPUT_DIR,
    experiment_name=EXPERIMENT_NAME,
    run_number=None,
    monitor='train_loss',
    mode='min',
    save_every_n_epochs=SAVE_EVERY_N_EPOCHS,
    keep_last_n=KEEP_LAST_N_CHECKPOINTS,
    verbose=VERBOSE,
)
# Hand the callback the epoch number read from the loaded checkpoint.
checkpoint_callback.set_starting_epoch(starting_epoch)

# Logs the learning rate at every optimisation step.
lr_monitor = LearningRateMonitor(logging_interval='step')
callbacks = [checkpoint_callback, lr_monitor]

print("\n".join([
    "‚úÖ Checkpoint callback configured:",
    f"   Output directory: {OUTPUT_DIR}",
    f"   Experiment name: {EXPERIMENT_NAME}",
    f"   Save every: {SAVE_EVERY_N_EPOCHS} epoch(s)",
    f"   Keep last: {KEEP_LAST_N_CHECKPOINTS} checkpoint(s)",
]))

# Trainer: hardware and logging settings come straight from the notebook
# parameters; logs are rooted at the callback's logs directory.
trainer = pl.Trainer(
    default_root_dir=str(checkpoint_callback.logs_dir),
    callbacks=callbacks,
    max_epochs=NUM_EPOCHS,
    accelerator=ACCELERATOR,
    devices=DEVICES,
    precision=PRECISION,
    gradient_clip_val=GRADIENT_CLIP_VAL,
    log_every_n_steps=LOG_EVERY_N_STEPS,
)

print("\n‚úÖ Trainer configured successfully!")
print(f"   Logs directory: {checkpoint_callback.logs_dir}")


## 9. Start Training

⚠️ **Warning**: This will start the training process. Make sure all parameters are correctly configured above.


In [None]:
# Run the training loop, resuming from the (possibly patched) checkpoint.
print("\n" + "=" * 70)
print("üöÄ STARTING TRAINING")
print("=" * 70)
print(f"Resuming from epoch {starting_epoch} ‚Üí Training until epoch {NUM_EPOCHS}")
print(f"Checkpoint: {checkpoint_to_use}")
print(f"Output: {OUTPUT_DIR}")
print("=" * 70 + "\n")

# NUM_EPOCHS is passed to the Trainer as `max_epochs`, which is a ceiling on
# the total epoch count, not a number of additional epochs. If the checkpoint
# is already at or past that ceiling, fit() has nothing left to do, so tell
# the user before they wait for a run that does no work.
if isinstance(starting_epoch, int) and starting_epoch >= NUM_EPOCHS:
    print(
        f"WARNING: checkpoint epoch ({starting_epoch}) is not below NUM_EPOCHS "
        f"({NUM_EPOCHS}). max_epochs is a total epoch limit, so the Trainer may "
        "stop immediately or refuse to resume. Increase NUM_EPOCHS to train further.\n"
    )

# Start training
trainer.fit(model, datamodule=data_module, ckpt_path=checkpoint_to_use)

print("\n" + "=" * 70)
print("‚úÖ TRAINING COMPLETED!")
print("=" * 70)
print(f"üìÅ New checkpoints saved to: {checkpoint_callback.run_dir}")
print("=" * 70)


## 10. Cleanup and Summary


In [None]:
# Remove the temporary patched checkpoint (if one was written) and print a
# summary of the run.
import os

# Clean up temporary checkpoint if it was created
if checkpoint_to_use == "temp_fixed_checkpoint.ckpt":
    try:
        os.remove("temp_fixed_checkpoint.ckpt")
    except FileNotFoundError:
        # Already gone (e.g. this cell was run twice) - nothing to clean up.
        pass
    else:
        print("üßπ Cleaned up temporary checkpoint file")

# Display summary
print("\n" + "=" * 70)
print("üìä TRAINING SUMMARY")
print("=" * 70)
print(f"Original checkpoint:  {CHECKPOINT_TO_RESUME}")
print(f"Starting epoch:       {starting_epoch}")
print(f"Final epoch:          {NUM_EPOCHS}")
# Clamp at zero: when the checkpoint is already at or past NUM_EPOCHS no
# epochs were trained, and a negative count would be misleading.
print(f"Epochs trained:       {max(0, NUM_EPOCHS - starting_epoch)}")
print(f"Output directory:     {checkpoint_callback.run_dir}")
print(f"Logs directory:       {checkpoint_callback.logs_dir}")
print("=" * 70)

print("\nüí° Next steps:")
print("1. Check the output directory for new checkpoints")
print("2. View training logs in TensorBoard (if installed)")
print("3. Use the best checkpoint for inference or fine-tuning")
print("\n‚úÖ All done!")


## Optional: Inspect Checkpoint Contents


In [None]:
# Optional: list the top-level entries of the checkpoint loaded earlier.
# The model weights and other containers are summarised by length; every
# other entry is printed as-is.
print("üîç Checkpoint contents:")
print("="*70)
for key, value in checkpoint.items():
    if key == 'state_dict':
        summary = f"{len(value)} model weights"
    elif isinstance(value, (list, dict)):
        summary = f"{type(value).__name__} (len={len(value)})"
    else:
        summary = f"{value}"
    print(f"  {key}: {summary}")
print("="*70)
