# Checkpoint Manager Test

Quick test to verify checkpoint save/load functionality.

In [None]:
import os
import torch
from pathlib import Path
from datetime import datetime

# Status line confirming the import cell ran (check-mark emoji repaired from mojibake).
print("✅ Imports loaded")

## CheckpointManager Class

In [None]:
class CheckpointManager:
    """Save, locate, and restore training checkpoints inside one directory.

    Every save writes ``checkpoint_epoch_<epoch>.pt`` and overwrites
    ``checkpoint_latest.pt``; ``checkpoint_best.pt`` is written only when
    requested. Models, optimizers, and schedulers may each be passed either as
    a single object exposing ``state_dict()`` / ``load_state_dict()`` or as a
    ``{name: object}`` dict, in which case the state is stored as
    ``{name: state_dict}``.
    """

    def __init__(self, checkpoint_dir, save_every=5):
        """Remember the target directory and create it if it is missing.

        Args:
            checkpoint_dir: Directory (str or Path) that holds the checkpoints.
            save_every: Stored on the instance; no method of this class reads it.
        """
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(exist_ok=True, parents=True)
        self.save_every = save_every

    @staticmethod
    def _collect_state(obj):
        """Return ``obj.state_dict()``, or ``{name: state_dict}`` when *obj* is a dict."""
        if isinstance(obj, dict):
            return {name: member.state_dict() for name, member in obj.items()}
        return obj.state_dict()

    @staticmethod
    def _apply_state(obj, state):
        """Load *state* into *obj*, matching entries by name when *obj* is a dict."""
        if isinstance(obj, dict):
            for name, member in obj.items():
                member.load_state_dict(state[name])
        else:
            obj.load_state_dict(state)

    @staticmethod
    def _epoch_sort_key(path):
        """Order ``checkpoint_epoch_<N>.pt`` files by the integer ``N``.

        Files whose suffix is not an integer sort before every numbered file.
        """
        suffix = path.stem.rsplit('_', 1)[-1]
        try:
            return (int(suffix), path.name)
        except ValueError:
            return (float('-inf'), path.name)

    def save_checkpoint(self, epoch, model, optimizer, scheduler=None, metrics=None, extra_state=None, is_best=False):
        """Write a checkpoint for *epoch* and refresh the "latest" copy.

        Args:
            epoch: Epoch number; stored in the checkpoint and used in the file name.
            model: Object with ``state_dict()`` or a ``{name: object}`` dict.
            optimizer: Object with ``state_dict()`` or a ``{name: object}`` dict.
            scheduler: Optional; same single-or-dict convention. Omitted from
                the checkpoint when ``None``.
            metrics: Optional dict stored under ``'metrics'`` (``{}`` if falsy).
            extra_state: Optional dict stored under ``'extra_state'`` (``{}`` if falsy).
            is_best: When true, also write ``checkpoint_best.pt``.

        Returns:
            Path of the per-epoch checkpoint file.
        """
        checkpoint = {
            'epoch': epoch,
            'timestamp': datetime.now().isoformat(),
            'metrics': metrics or {},
            'extra_state': extra_state or {},
            'model_state_dict': self._collect_state(model),
            'optimizer_state_dict': self._collect_state(optimizer),
        }
        if scheduler is not None:
            checkpoint['scheduler_state_dict'] = self._collect_state(scheduler)

        checkpoint_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
        torch.save(checkpoint, checkpoint_path)

        latest_path = self.checkpoint_dir / 'checkpoint_latest.pt'
        torch.save(checkpoint, latest_path)

        if is_best:
            best_path = self.checkpoint_dir / 'checkpoint_best.pt'
            torch.save(checkpoint, best_path)

        return checkpoint_path

    def load_checkpoint(self, checkpoint_path=None):
        """Load a checkpoint onto the CPU.

        Args:
            checkpoint_path: File to load. When ``None``, ``checkpoint_latest.pt``
                is used if it exists; otherwise the ``checkpoint_epoch_*.pt``
                file with the highest epoch number is used.

        Returns:
            The loaded checkpoint dict, or ``None`` when no file is found.
        """
        if checkpoint_path is None:
            latest_path = self.checkpoint_dir / 'checkpoint_latest.pt'
            if latest_path.exists():
                checkpoint_path = latest_path
            else:
                candidates = list(self.checkpoint_dir.glob('checkpoint_epoch_*.pt'))
                if not candidates:
                    return None
                # Compare epochs numerically: a plain name sort would rank
                # "epoch_9" above "epoch_10".
                checkpoint_path = max(candidates, key=self._epoch_sort_key)

        checkpoint_path = Path(checkpoint_path)
        if not checkpoint_path.exists():
            return None

        # NOTE: torch.load is pickle-based; only load checkpoint files you trust.
        return torch.load(checkpoint_path, map_location='cpu')

    def restore_training_state(self, checkpoint, model, optimizer, scheduler=None):
        """Load checkpointed state into the given objects.

        Args:
            checkpoint: Dict as returned by ``load_checkpoint``.
            model: Object or ``{name: object}`` dict to receive model state.
            optimizer: Object or ``{name: object}`` dict to receive optimizer state.
            scheduler: Optional; restored only when the checkpoint contains
                ``'scheduler_state_dict'``.

        Returns:
            The epoch to resume from, i.e. ``checkpoint['epoch'] + 1``.

        Raises:
            KeyError: If a required entry (or a named sub-entry) is missing.
        """
        self._apply_state(model, checkpoint['model_state_dict'])
        self._apply_state(optimizer, checkpoint['optimizer_state_dict'])

        if scheduler is not None and 'scheduler_state_dict' in checkpoint:
            self._apply_state(scheduler, checkpoint['scheduler_state_dict'])

        return checkpoint['epoch'] + 1

# Status line confirming the class cell ran (check-mark emoji repaired from mojibake).
print("✅ CheckpointManager defined")

## Test 1: Basic Save/Load

In [None]:
# Setup: scratch directory shared by all test cells below.
test_dir = Path('/tmp/test_checkpoints')
# parents=True mirrors what CheckpointManager itself does when creating the directory.
test_dir.mkdir(exist_ok=True, parents=True)

manager = CheckpointManager(test_dir, save_every=1)

# Create dummy model with an optimizer and scheduler so all three state kinds
# can be checkpointed.
model = torch.nn.Linear(10, 10)
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)

print("✅ Test setup complete")

In [None]:
# Save checkpoint for epoch 5, including scheduler state and metrics.
print("\n📝 Saving checkpoint...")
checkpoint_path = manager.save_checkpoint(
    epoch=5,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    metrics={'loss': 0.123, 'accuracy': 0.95}
)

print(f"✅ Checkpoint saved: {checkpoint_path}")

# Verify files exist: both the per-epoch file and the rolling "latest" copy.
assert (test_dir / 'checkpoint_epoch_5.pt').exists(), "Epoch checkpoint not found"
assert (test_dir / 'checkpoint_latest.pt').exists(), "Latest checkpoint not found"
print("✅ Checkpoint files verified")

In [None]:
# Load checkpoint: with no path given, the manager resolves the file itself.
print("\n📖 Loading checkpoint...")
checkpoint = manager.load_checkpoint()

# The loaded dict must round-trip the epoch and metrics that were saved.
assert checkpoint is not None, "Checkpoint not loaded"
assert checkpoint['epoch'] == 5, f"Wrong epoch: {checkpoint['epoch']}"
assert checkpoint['metrics']['loss'] == 0.123, "Wrong loss value"
assert checkpoint['metrics']['accuracy'] == 0.95, "Wrong accuracy value"

print("✅ Checkpoint loaded successfully")
print(f"   Epoch: {checkpoint['epoch']}")
print(f"   Loss: {checkpoint['metrics']['loss']}")
print(f"   Accuracy: {checkpoint['metrics']['accuracy']}")

## Test 2: Resume Training State

In [None]:
# Create new model/optimizer (simulating restart)
new_model = torch.nn.Linear(10, 10)
new_optimizer = torch.optim.Adam(new_model.parameters())
new_scheduler = torch.optim.lr_scheduler.StepLR(new_optimizer, step_size=1)

# Restore state from the checkpoint dict loaded in the previous cell.
print("\n🔄 Restoring training state...")
start_epoch = manager.restore_training_state(
    checkpoint,
    model=new_model,
    optimizer=new_optimizer,
    scheduler=new_scheduler
)

# Training resumes at the epoch after the saved one.
assert start_epoch == 6, f"Wrong start epoch: {start_epoch}"
# The freshly initialised model must now carry the checkpointed parameters,
# not just report the right epoch.
assert all(
    torch.equal(tensor, checkpoint['model_state_dict'][key])
    for key, tensor in new_model.state_dict().items()
), "Model weights not restored"
print("✅ Training state restored")
print(f"   Next epoch: {start_epoch}")

## Test 3: Best Checkpoint

In [None]:
# Save best checkpoint: is_best=True should additionally write checkpoint_best.pt.
print("\n💎 Saving best checkpoint...")
manager.save_checkpoint(
    epoch=10,
    model=model,
    optimizer=optimizer,
    metrics={'loss': 0.050, 'accuracy': 0.98},
    is_best=True
)

assert (test_dir / 'checkpoint_best.pt').exists(), "Best checkpoint not found"
print("✅ Best checkpoint saved")

# Verify best checkpoint by reading the file directly; map_location='cpu'
# matches how CheckpointManager.load_checkpoint loads files.
best_checkpoint = torch.load(test_dir / 'checkpoint_best.pt', map_location='cpu')
assert best_checkpoint['epoch'] == 10, "Wrong best epoch"
assert best_checkpoint['metrics']['loss'] == 0.050, "Wrong best loss"
print(f"✅ Best checkpoint verified (epoch {best_checkpoint['epoch']}, loss {best_checkpoint['metrics']['loss']})")

## Test 4: Multiple Models

In [None]:
# Test with multiple models: dicts of models/optimizers are stored per name.
print("\n🔀 Testing multiple models...")

models = {
    'encoder': torch.nn.Linear(10, 5),
    'decoder': torch.nn.Linear(5, 10)
}

optimizers = {
    'encoder': torch.optim.Adam(models['encoder'].parameters()),
    'decoder': torch.optim.Adam(models['decoder'].parameters())
}

# Save
manager.save_checkpoint(
    epoch=15,
    model=models,
    optimizer=optimizers,
    metrics={'loss': 0.030}
)

# Load
checkpoint = manager.load_checkpoint()
# Fail with a clear message (as in Test 1) rather than a TypeError on None.
assert checkpoint is not None, "Checkpoint not loaded"
assert 'encoder' in checkpoint['model_state_dict'], "Encoder state not found"
assert 'decoder' in checkpoint['model_state_dict'], "Decoder state not found"

print("✅ Multiple models test passed")

## Test Summary

In [None]:
# Final report; when the notebook is run top to bottom this cell is reached
# only after the assertions in the earlier cells have passed.
print("\n" + "="*60)
print("✅ ALL TESTS PASSED!")
print("="*60)

print("\n✅ Verified functionality:")
print("   • Save checkpoint")
print("   • Load checkpoint")
print("   • Resume training state")
print("   • Best checkpoint tracking")
print("   • Multiple models support")
print("   • Optimizer state persistence")
print("   • Scheduler state persistence")
print("   • Metrics tracking")

# Cleanup: remove the scratch directory; ignore_errors keeps a missing
# directory from raising.
import shutil
shutil.rmtree(test_dir, ignore_errors=True)
print("\n🧹 Test directory cleaned up")