# Checkpoint Testing Project

This project demonstrates the new checkpoint functionality for optimizer and learning rate scheduler states.

## Features Tested

- **Optimizer State Checkpointing**: Save and restore optimizer state including parameter-specific momentum and other state variables
- **LR Scheduler State Checkpointing**: Save and restore learning rate scheduler state to maintain proper learning rate schedules across training interruptions
- **Automatic Checkpoint Discovery**: Automatically find and resume from the most recent valid checkpoint
- **Modification-Time-Based Selection**: Select checkpoints by file modification time rather than by the step number in the directory name, so discovery does not depend on how directory names sort
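The discovery rule can be sketched with the standard library alone. This is a minimal illustration, not Forgather's implementation. It assumes checkpoints are directories named `checkpoint-<step>` and treats a checkpoint as valid only if it contains `training_state.pt` (the file the verification cell later inspects). The helper `find_latest_checkpoint` is hypothetical. The demo also shows why step names make a poor sort key: sorted as strings, `checkpoint-900` comes after `checkpoint-1000`.

```python
import os
import tempfile


def find_latest_checkpoint(checkpoint_dir):
    """Return the most recently modified valid checkpoint directory, or None.

    Hypothetical helper. "Valid" is assumed here to mean the directory
    contains training_state.pt; a real trainer may use a different test.
    """
    candidates = []
    for name in os.listdir(checkpoint_dir):
        path = os.path.join(checkpoint_dir, name)
        if not (name.startswith("checkpoint-") and os.path.isdir(path)):
            continue
        if not os.path.exists(os.path.join(path, "training_state.pt")):
            continue  # e.g. a save that was interrupted before completing
        candidates.append(path)
    if not candidates:
        return None
    # Choose by modification time, not by the step number in the name
    return max(candidates, key=os.path.getmtime)


# Demo: three fake checkpoints with controlled modification times
with tempfile.TemporaryDirectory() as root:
    for step, mtime, complete in [(900, 1000, True), (1000, 2000, True), (1100, 3000, False)]:
        path = os.path.join(root, f"checkpoint-{step}")
        os.makedirs(path)
        if complete:
            open(os.path.join(path, "training_state.pt"), "wb").close()
        os.utime(path, (mtime, mtime))  # set access and modification time

    lexicographic_last = sorted(os.listdir(root))[-1]
    latest = os.path.basename(find_latest_checkpoint(root))

print("Last by name:", lexicographic_last)  # checkpoint-900
print("Latest valid:", latest)              # checkpoint-1000
```

The newest directory, `checkpoint-1100`, is skipped because it lacks `training_state.pt`, so the sketch falls back to the newest complete checkpoint.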

## Configurations

1. **train.yaml**: Initial training with checkpointing enabled
2. **resume.yaml**: Resume training from the latest checkpoint

## Usage Instructions

### Step 1: Inspect the Initial Training Configuration
Display the project index and the `train.yaml` configuration that will create the checkpoints:

In [None]:
import forgather.nb.notebooks as nb
nb.display_model_project_index()

In [None]:
nb.display_config()

In [None]:
from forgather import Project

# Load the checkpointing project to display trainer args
proj = Project("train.yaml")
trainer_args = proj("trainer_args")
print("Trainer Arguments Configuration:")
for key, value in trainer_args.items():
    print(f"  {key}: {value}")

In [None]:
# Display configuration details for train.yaml
print("=== Initial Training Configuration (train.yaml) ===")
config = proj.environment.load("configs/train.yaml")
print("Configuration loaded successfully")
print("Uses project template with checkpointing enabled")

# Show key checkpointing-related trainer arguments
print("\n=== Key Checkpoint Settings ===")
trainer_args = proj("trainer_args")
checkpoint_settings = {
    'save_strategy': trainer_args.get('save_strategy'),
    'save_steps': trainer_args.get('save_steps'),
    'save_total_limit': trainer_args.get('save_total_limit'),
    'max_steps': trainer_args.get('max_steps'),
    'save_optimizer_state': trainer_args.get('save_optimizer_state'),
    'save_scheduler_state': trainer_args.get('save_scheduler_state'),
    'resume_from_checkpoint': trainer_args.get('resume_from_checkpoint')
}

for key, value in checkpoint_settings.items():
    print(f"  {key}: {value}")

### Step 2: Run Initial Training

Run the training from the command line (the notebook cell after the list below does the same from within this notebook):

```bash
cd examples/tiny_experiments/checkpointing
fgcli.py -t train.yaml -d 0
```

This will:
- Train for 500 steps
- Save checkpoints every 100 steps 
- Save optimizer and scheduler state with each checkpoint
- Create checkpoints in `output_models/default_model/checkpoints/`
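Which checkpoints end up on disk follows from simple arithmetic on these settings. The hypothetical helper below assumes a save at every multiple of `save_steps` up to and including `max_steps`, and that `save_total_limit` prunes the oldest checkpoints first. The limit of 3 in the example is made up, since this project's configured value isn't stated here.

```python
def expected_checkpoints(max_steps, save_steps, save_total_limit=None):
    """Return (saved_steps, retained_steps) for a fixed-interval save strategy.

    Hypothetical helper. Assumes a save at every multiple of save_steps up to
    and including max_steps, and that pruning removes the oldest checkpoints
    first. save_total_limit=None means no limit.
    """
    saved = list(range(save_steps, max_steps + 1, save_steps))
    if save_total_limit is None:
        return saved, list(saved)
    if save_total_limit < 1:
        raise ValueError("save_total_limit must be a positive integer or None")
    return saved, saved[-save_total_limit:]


saved, retained = expected_checkpoints(max_steps=500, save_steps=100, save_total_limit=3)
print("Saved at steps:", saved)   # [100, 200, 300, 400, 500]
print("Kept on disk:", retained)  # [300, 400, 500]
```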

In [None]:
training_script = proj()
training_script.run()

In [None]:
import os

# Check if checkpoints were created
checkpoint_dir = "output_models/default_model/checkpoints"
if os.path.exists(checkpoint_dir):
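    checkpoints = [
        d for d in os.listdir(checkpoint_dir)
        if d.startswith('checkpoint-') and os.path.isdir(os.path.join(checkpoint_dir, d))
    ]
    print(f"Found {len(checkpoints)} checkpoints:")
    # List oldest to newest by modification time; name order breaks once step numbers gain a digit
    for cp in sorted(checkpoints, key=lambda d: os.path.getmtime(os.path.join(checkpoint_dir, d))):
        cp_path = os.path.join(checkpoint_dir, cp)
        files = os.listdir(cp_path)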
        has_training_state = 'training_state.pt' in files
        print(f"  {cp}: {len(files)} files {'✓ Has training state' if has_training_state else '✗ No training state'}")
else:
    print("No checkpoints found. Run the initial training first.")
    print(f"Expected checkpoint directory: {checkpoint_dir}")

### Step 3: Resume Training Configuration

In [None]:
# Load the resume configuration to show trainer args
print("=== Resume Training Configuration (resume.yaml) ===")
proj_resume = Project("resume.yaml")
resume_trainer_args = proj_resume("trainer_args")

print("Configuration loaded successfully")
print("Uses project template with checkpoint resumption enabled")

# Show key resume-related trainer arguments
print("\n=== Key Resume Settings ===")
resume_settings = {
    'max_steps': resume_trainer_args.get('max_steps'),
    'resume_from_checkpoint': resume_trainer_args.get('resume_from_checkpoint'),
    'restore_optimizer_state': resume_trainer_args.get('restore_optimizer_state'),
    'restore_scheduler_state': resume_trainer_args.get('restore_scheduler_state'),
    'save_strategy': resume_trainer_args.get('save_strategy'),
    'save_steps': resume_trainer_args.get('save_steps')
}

for key, value in resume_settings.items():
    print(f"  {key}: {value}")

### Step 4: Run Resume Training

To resume training from the latest checkpoint:

```bash
cd examples/tiny_experiments/checkpointing
fgcli.py -t resume.yaml -d 0
```

This will:
- Automatically find the latest checkpoint by modification time
- Restore the model weights, optimizer state, and scheduler state
- Continue training from step 500 to step 800
- Continue the learning rate schedule where it left off instead of restarting it
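Schedule continuity depends on restoring the scheduler's step counter. The sketch below is a pure-Python stand-in, not Forgather's or PyTorch's scheduler: the hypothetical `WarmupCosineSchedule` stores its step count in `last_epoch` (borrowing the PyTorch name that the verification cell reads), and its hyperparameters are invented. It compares an uninterrupted 800-step run with a run stopped at step 500 and resumed either with or without the saved scheduler state.

```python
import math


class WarmupCosineSchedule:
    """Toy LR schedule: linear warmup, then cosine decay to zero.

    Hypothetical stand-in for a real scheduler. Borrowing the PyTorch name,
    its step counter lives in `last_epoch`; that counter is its only state.
    """

    def __init__(self, base_lr, warmup_steps, total_steps):
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.last_epoch = 0

    def current_lr(self):
        t = self.last_epoch
        if t < self.warmup_steps:
            return self.base_lr * (t + 1) / self.warmup_steps
        decay_steps = max(1, self.total_steps - self.warmup_steps)
        progress = min(1.0, (t - self.warmup_steps) / decay_steps)
        return 0.5 * self.base_lr * (1.0 + math.cos(math.pi * progress))

    def step(self):
        self.last_epoch += 1

    def state_dict(self):
        return {"last_epoch": self.last_epoch}

    def load_state_dict(self, state):
        self.last_epoch = state["last_epoch"]


def make_schedule():
    # Invented hyperparameters; total_steps matches the 800-step resume target
    return WarmupCosineSchedule(base_lr=1e-3, warmup_steps=100, total_steps=800)


def advance(schedule, steps):
    for _ in range(steps):
        schedule.step()
    return schedule.current_lr()


# Reference: 800 uninterrupted steps
uninterrupted_lr = advance(make_schedule(), 800)

# Stop at step 500, save the scheduler state, resume with it for 300 more steps
first_run = make_schedule()
advance(first_run, 500)
resumed = make_schedule()
resumed.load_state_dict(first_run.state_dict())
restored_lr = advance(resumed, 300)

# Stop at step 500 but resume with a fresh scheduler: the schedule restarts at 0
restarted_lr = advance(make_schedule(), 300)

print(f"uninterrupted: {uninterrupted_lr:.2e}")
print(f"restored:      {restored_lr:.2e}")
print(f"restarted:     {restarted_lr:.2e}")
```

With the state restored, the learning rate after step 800 matches the uninterrupted run exactly; with a fresh scheduler the schedule restarts at step 0, so the learning rate is still high when training ends.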

In [None]:
proj = Project("resume.yaml")
training_script = proj()
training_script.run()

### Step 5: Verify Checkpoint Functionality

In [None]:
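import os
import torch

# Examine the training state saved in the most recent checkpoint
checkpoint_dir = "output_models/default_model/checkpoints"
if os.path.exists(checkpoint_dir):
    checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith('checkpoint-')]
    if checkpoints:
        # Latest by modification time, matching the discovery rule described above;
        # sorting names as strings would rank checkpoint-900 after checkpoint-1000
        latest_checkpoint = max(
            checkpoints, key=lambda d: os.path.getmtime(os.path.join(checkpoint_dir, d))
        )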
        training_state_path = os.path.join(checkpoint_dir, latest_checkpoint, "training_state.pt")
        
        print(f"=== Examining {latest_checkpoint} ===")
        if os.path.exists(training_state_path):
            training_state = torch.load(training_state_path, map_location="cpu")
            print(f"Training state keys: {list(training_state.keys())}")
            
            if 'global_step' in training_state:
                print(f"Global step: {training_state['global_step']}")
            
            if 'optimizer' in training_state:
                opt_state = training_state['optimizer']
                print(f"Optimizer state keys: {list(opt_state.keys())}")
                if 'param_groups' in opt_state:
                    print(f"Learning rate: {opt_state['param_groups'][0].get('lr', 'N/A')}")
            
            if 'lr_scheduler' in training_state:
                sched_state = training_state['lr_scheduler']
                print(f"Scheduler state keys: {list(sched_state.keys())}")
                print(f"Last epoch: {sched_state.get('last_epoch', 'N/A')}")
        else:
            print(f"No training state found in {latest_checkpoint}")
    else:
        print("No checkpoints found")
else:
    print("Checkpoint directory does not exist")
    print(f"Expected: {checkpoint_dir}")
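Beyond printing keys, a few consistency checks can be scripted. The hypothetical `check_training_state` below assumes the key layout the cell above looks for (`global_step`, `optimizer`, `lr_scheduler`) and that the scheduler steps once per optimizer step, so `last_epoch` should equal `global_step`. The values in `example_state` are invented. In the notebook, it could be applied to the `training_state` loaded above.

```python
def check_training_state(state, expect_optimizer=True, expect_scheduler=True):
    """Return a list of problems found in a loaded training-state dict.

    Hypothetical checker. Assumes the key layout used by the cell above and
    that the scheduler is stepped once per optimizer step (an assumption),
    so lr_scheduler['last_epoch'] should equal global_step.
    """
    problems = []
    if "global_step" not in state:
        problems.append("missing global_step")
    if expect_optimizer and "param_groups" not in state.get("optimizer", {}):
        problems.append("missing optimizer state")
    if expect_scheduler:
        sched = state.get("lr_scheduler")
        if sched is None:
            problems.append("missing lr_scheduler state")
        elif "global_step" in state and sched.get("last_epoch") != state["global_step"]:
            problems.append(
                f"scheduler last_epoch {sched.get('last_epoch')} != global_step {state['global_step']}"
            )
    return problems


# Hand-built dict shaped like a saved training state (invented values)
example_state = {
    "global_step": 500,
    "optimizer": {"state": {}, "param_groups": [{"lr": 1e-4}]},
    "lr_scheduler": {"last_epoch": 500},
}
print(check_training_state(example_state))  # []
```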

## Expected Behavior

1. **Initial Training**: Creates checkpoints with both model weights and training state
2. **Resume Training**: 
   - Finds latest checkpoint by modification time
   - Loads model weights
   - Restores optimizer state (momentum, etc.)
   - Restores scheduler state (step count, learning rate schedule)
   - Continues training from the saved global step, with the learning rate schedule picking up where it left off
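Restoring optimizer state matters because optimizers such as SGD with momentum or Adam keep per-parameter buffers between steps. The toy optimizer below (hypothetical `MomentumSGD`, pure Python, minimizing the sum of (w_i - 3)^2 over two scalar parameters) illustrates the effect: resuming with the saved buffers reproduces the uninterrupted run exactly, while resuming with weights only does not.

```python
import copy


class MomentumSGD:
    """Toy SGD with momentum for a list of scalar parameters.

    Hypothetical stand-in for a real optimizer: the per-parameter momentum
    buffers are the state a checkpoint has to carry.
    """

    def __init__(self, lr=0.1, momentum=0.9):
        self.lr = lr
        self.momentum = momentum
        self.buffers = {}  # parameter index -> momentum buffer

    def step(self, params, grads):
        for i, g in enumerate(grads):
            buf = self.momentum * self.buffers.get(i, 0.0) + g
            self.buffers[i] = buf
            params[i] -= self.lr * buf

    def state_dict(self):
        return copy.deepcopy({"lr": self.lr, "momentum": self.momentum, "buffers": self.buffers})

    def load_state_dict(self, state):
        self.lr = state["lr"]
        self.momentum = state["momentum"]
        self.buffers = copy.deepcopy(state["buffers"])


def grad(params):
    # Gradient of f(w) = sum((w_i - 3)^2)
    return [2.0 * (w - 3.0) for w in params]


def train(params, opt, steps):
    for _ in range(steps):
        opt.step(params, grad(params))
    return params


# Uninterrupted run: 8 steps
reference = train([0.0, 10.0], MomentumSGD(), 8)

# Interrupted after 5 steps; weights and optimizer state are checkpointed
params, opt = [0.0, 10.0], MomentumSGD()
train(params, opt, 5)
checkpoint = {"params": list(params), "optimizer": opt.state_dict()}

# Resume with the optimizer state restored
resumed_opt = MomentumSGD()
resumed_opt.load_state_dict(checkpoint["optimizer"])
with_state = train(list(checkpoint["params"]), resumed_opt, 3)

# Resume with weights only: momentum buffers start again from zero
without_state = train(list(checkpoint["params"]), MomentumSGD(), 3)

print("reference:    ", reference)
print("with state:   ", with_state)
print("without state:", without_state)
```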

## Configuration Options

The checkpoint functionality can be controlled with these training arguments:

- `save_optimizer_state`: Save optimizer state in checkpoints
- `save_scheduler_state`: Save LR scheduler state in checkpoints  
- `restore_optimizer_state`: Restore optimizer state when resuming
- `restore_scheduler_state`: Restore scheduler state when resuming
- `resume_from_checkpoint`: `True` to automatically resume from the most recent checkpoint, or a string path to resume from a specific checkpoint
- `save_total_limit`: Maximum number of checkpoints to keep
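How these options interact on resume can be summarized as a small decision function. This is a sketch of the semantics listed above, not Forgather's code. `plan_resume` is hypothetical, missing `restore_*` flags are treated as `False`, and when auto-discovery finds no checkpoint the sketch starts fresh rather than raising an error; both of those choices are assumptions.

```python
def plan_resume(args, checkpoints):
    """Decide what a resume would load (hypothetical helper).

    args: dict of training arguments, using the names listed above
    checkpoints: dict mapping checkpoint path -> modification time
    """
    target = args.get("resume_from_checkpoint")
    if isinstance(target, str):
        path = target  # explicit path: used as given (not validated here)
    elif target is True:
        # Auto-find: most recently modified checkpoint, or None if there is none
        path = max(checkpoints, key=checkpoints.get) if checkpoints else None
    else:
        path = None  # False or None: start from scratch

    if path is None:
        return {"checkpoint": None, "restore": []}

    restore = ["model"]  # model weights are always loaded when resuming
    # Missing flags default to False in this sketch (an assumption)
    if args.get("restore_optimizer_state", False):
        restore.append("optimizer")
    if args.get("restore_scheduler_state", False):
        restore.append("scheduler")
    return {"checkpoint": path, "restore": restore}


checkpoints = {
    "output_models/default_model/checkpoints/checkpoint-400": 1000.0,
    "output_models/default_model/checkpoints/checkpoint-500": 2000.0,
}
args = {
    "resume_from_checkpoint": True,
    "restore_optimizer_state": True,
    "restore_scheduler_state": True,
}
print(plan_resume(args, checkpoints))
# {'checkpoint': 'output_models/default_model/checkpoints/checkpoint-500', 'restore': ['model', 'optimizer', 'scheduler']}
```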