In [None]:
import os

import torch
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration

In [None]:
# Root directory handed to `Accelerator(project_dir=...)` by the cells below.
project_dir = "./checkpoint_test/"

In [None]:
# `save_state()` is called below without an explicit output directory, which is
# only valid when automatic checkpoint naming is enabled. With it, checkpoints
# are written to `<project_dir>/checkpoints/checkpoint_<n>`.
accelerator = Accelerator(
    project_config=ProjectConfiguration(
        project_dir=project_dir,
        automatic_checkpoint_naming=True,
    )
)

my_scheduler = torch.optim.lr_scheduler.StepLR(my_optimizer, step_size=1, gamma=0.99)
my_model, my_optimizer, my_training_dataloader = accelerator.prepare(my_model, my_optimizer, my_training_dataloader)

# Register the LR scheduler: it was not passed through `prepare()`, so it has to
# be registered explicitly to be included in `save_state()` / `load_state()`.
accelerator.register_for_checkpointing(my_scheduler)

# Save the starting state (first automatic checkpoint, i.e. `checkpoint_0`)
accelerator.save_state()

device = accelerator.device
my_model.to(device)

# Perform training
for epoch in range(num_epochs):
    for batch in my_training_dataloader:
        my_optimizer.zero_grad()
        inputs, targets = batch
        inputs = inputs.to(device)
        targets = targets.to(device)
        outputs = my_model(inputs)
        loss = my_loss_function(outputs, targets)
        accelerator.backward(loss)
        my_optimizer.step()
    # The learning rate is decayed once per epoch, not once per batch.
    my_scheduler.step()

# Restore the starting state. The path is built with `os.path.join` so the
# trailing slash in `project_dir` does not produce a doubled separator, and it
# points at the `checkpoints` folder used by automatic checkpoint naming.
accelerator.load_state(os.path.join(project_dir, "checkpoints", "checkpoint_0"))

In [None]:
from accelerate import Accelerator

# Rebuild the accelerator, hand it the dataloader, then restore the saved state.
accelerator = Accelerator(project_dir=project_dir)
train_dataloader = accelerator.prepare(train_dataloader)
accelerator.load_state("my_state")

# The checkpoint is assumed to have been written 100 batches into an epoch, so
# the resumed epoch must not replay those batches.
skipped_dataloader = accelerator.skip_first_batches(train_dataloader, 100)

# Only the first (resumed) epoch uses the shortened loader; the second epoch
# goes back to iterating over the full `train_dataloader`.
for epoch_loader in (skipped_dataloader, train_dataloader):
    for batch in epoch_loader:
        # Do something
        pass

# How to Use Checkpoints

In [None]:
from accelerate import Accelerator

# Initialize the model, optimizer, and scheduler
model = ...  # Your model here
optimizer = ...  # Your optimizer here
scheduler = ...  # Your scheduler here

# Initialize Accelerator
accelerator = Accelerator()

# Register for checkpointing. `register_for_checkpointing` accepts the objects
# positionally (`*objects`); passing them as keyword arguments raises a
# TypeError. Each object must provide `state_dict` and `load_state_dict`.
accelerator.register_for_checkpointing(model, optimizer, scheduler)

# Frequency (in steps within an epoch) for saving checkpoints
save_frequency = 1000

# Training loop
for epoch in range(num_epochs):
    for step, batch in enumerate(train_dataloader):
        # Training step
        outputs = model(batch)
        loss = compute_loss(outputs, batch)
        accelerator.backward(loss)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        # Save checkpoint based on frequency. Note that step 0 of every epoch
        # also satisfies this condition, so each epoch starts with a save.
        if step % save_frequency == 0:
            # `save_state` treats this name as an output directory and writes
            # the state files inside it.
            checkpoint_name = f'checkpoint_epoch_{epoch}_step_{step}.pth'
            accelerator.save_state(checkpoint_name)
