# How to restart a run from checkpoint

<a target="_blank" href="https://colab.research.google.com/github/neptune-ai/examples/blob/main/how-to-guides/restart-run-from-checkpoint/notebooks/neptune_save_restart_run_from_checkpoint.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open in Colab"/>
</a>
<a target="_blank" href="https://github.com/neptune-ai/examples/blob/main/how-to-guides/restart-run-from-checkpoint/notebooks/neptune_save_restart_run_from_checkpoint.ipynb">
  <img alt="Open in GitHub" src="https://img.shields.io/badge/Open_in_GitHub-blue?logo=github&labelColor=black">
</a>
<a target="_blank" href="https://app.neptune.ai/o/common/org/showroom/runs/details?viewId=standard-view&detailsTab=metadata&shortId=SHOW-32776&type=run"> 
  <img alt="Explore in Neptune" src="https://neptune.ai/wp-content/uploads/2024/01/neptune-badge.svg">
</a>
<a target="_blank" href="https://docs.neptune.ai/tutorials/restarting_from_checkpoint/">
  <img alt="View tutorial in docs" src="https://neptune.ai/wp-content/uploads/2024/01/docs-badge-2.svg">
</a>


## Introduction
Resuming ML experiments from checkpoints ensures that training progress is not lost when a server is disrupted or a job fails.

By the end of this guide, you'll know how to save checkpoints to Neptune and resume an experiment from a saved checkpoint.

## Before you start

This notebook example lets you try out Neptune as an anonymous user, with zero setup.

If you want to see the example logged to your own workspace instead:

  1. Create a Neptune account. [Register &rarr;](https://neptune.ai/register)
  1. Create a Neptune project that you will use for tracking metadata. For instructions, see [Creating a project](https://docs.neptune.ai/setup/creating_project) in the Neptune docs.

## Install Neptune and dependencies

In [None]:
!pip install neptune torch torchvision

## Import dependencies

In [None]:
import neptune
from neptune.utils import stringify_unsupported
from torch import load as torch_load
from torch import save as torch_save
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from typing import Dict, Any

In [None]:
# (Neptune) Set environment variables
%env NEPTUNE_PROJECT=common/showroom
%env NEPTUNE_API_TOKEN=ANONYMOUS
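
The `%env` magics only work inside a notebook. If you run this example as a plain Python script, one possible equivalent is to set the same two variables with `os.environ` before initializing the run. This is a minimal sketch, not part of the original notebook:

```python
import os

# Same values as the %env lines above. setdefault keeps any value
# that is already exported in the shell.
os.environ.setdefault("NEPTUNE_PROJECT", "common/showroom")
os.environ.setdefault("NEPTUNE_API_TOKEN", "ANONYMOUS")

print(os.environ["NEPTUNE_PROJECT"])
```

Using `setdefault` rather than plain assignment means a project or token that you exported yourself takes precedence over the anonymous defaults.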

## Utils

In [None]:
# (Neptune) Fetch and load checkpoints
def load_checkpoint(run: neptune.Run, epoch: int):
    checkpoint_name = f"epoch_{epoch}"
    ext = run["checkpoints"][checkpoint_name].fetch_extension()
    run["checkpoints"][checkpoint_name].download()  # Download the checkpoint
    checkpoint = torch_load(f"{checkpoint_name}.{ext}")  # Load the checkpoint
    return checkpoint


# (Neptune) Save and log checkpoints while training
def save_checkpoint(
    run: neptune.Run, model: nn.Module, optimizer: optim.Optimizer, epoch: int, loss: torch.Tensor
):
    checkpoint = {
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "epoch": epoch,
        "loss": loss.item(),
    }
    checkpoint_name = f"checkpoint-{epoch}-{loss.item():.2f}.pth"
    torch_save(checkpoint, checkpoint_name)  # Save the checkpoint locally
    run[f"checkpoints/epoch_{epoch}"].upload(checkpoint_name)  # Upload to Neptune


def train(
    run: neptune.Run,
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    parameters: Dict[str, Any],
    start_epoch: int = 0,
):
    for epoch in range(start_epoch, parameters["num_epochs"]):
        for i, (x, y) in enumerate(dataloader, 0):
            x, y = x.to(parameters["device"]), y.to(parameters["device"])
            optimizer.zero_grad()
            outputs = model(x)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, y)
            acc = (torch.sum(preds == y.data)) / len(x)

            # (Neptune) Log metrics
            run["metrics"]["batch/loss"].append(loss.item())
            run["metrics"]["batch/acc"].append(acc.item())

            loss.backward()
            optimizer.step()

        if epoch % parameters["ckpt_frequency"] == 0:
            # (Neptune) Log checkpoints
            save_checkpoint(run, model, optimizer, epoch, loss)
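
The checkpoint condition in `train` can be checked in isolation. The sketch below uses a hypothetical helper, `checkpoint_epochs` (not part of the notebook), to list the zero-based epochs at which `epoch % ckpt_frequency == 0` holds for a given epoch range:

```python
from typing import List


def checkpoint_epochs(start_epoch: int, num_epochs: int, ckpt_frequency: int) -> List[int]:
    """Return the zero-based epochs at which train() would call save_checkpoint()."""
    return [
        epoch for epoch in range(start_epoch, num_epochs) if epoch % ckpt_frequency == 0
    ]


print(checkpoint_epochs(0, 1, 5))   # [0]
print(checkpoint_epochs(0, 12, 5))  # [0, 5, 10]
print(checkpoint_epochs(6, 12, 5))  # [10]
```

Because the check uses the zero-based epoch index, epoch 0 always produces a checkpoint, and the final epoch is saved only when its index is a multiple of `ckpt_frequency`.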

## Hyperparameters for training

In [None]:
parameters = {
    "lr": 1e-2,
    "batch_size": 128,
    "input_size": 32 * 32 * 3,
    "n_classes": 10,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "num_epochs": 1,
    "ckpt_frequency": 5,
}

## Load model and dataset

### Model

In [None]:
class Model(nn.Module):
    def __init__(self, input_size: int, hidden_dim: int, n_classes: int):
        super().__init__()
        self.seq_model = nn.Sequential(
            nn.Linear(input_size, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(hidden_dim // 2, n_classes),
        )

    def forward(self, input):
        x = input.view(-1, 32 * 32 * 3)
        return self.seq_model(x)


model = Model(parameters["input_size"], parameters["input_size"], parameters["n_classes"]).to(
    parameters["device"]
)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=parameters["lr"])

### Data

In [None]:
data_dir = "data/CIFAR10"
compressed_ds = "./data/CIFAR10/cifar-10-python.tar.gz"
data_tfms = {
    "train": transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )
}

trainset = datasets.CIFAR10(data_dir, transform=data_tfms["train"], download=True)
trainloader = DataLoader(trainset, batch_size=parameters["batch_size"], shuffle=True, num_workers=0)

## Save checkpoint

In [None]:
# (Neptune) Initialize a new run
run = neptune.init_run()

In [None]:
# (Neptune) Log hyperparameters
run["parameters"] = stringify_unsupported(parameters)

In [None]:
train(run, model, trainloader, criterion, optimizer, parameters)

In [None]:
run_id = run["sys/id"].fetch()  # Get the run id to use downstream

## Stop logging

Once you are done logging, stop tracking the run.

In [None]:
run.stop()

## Fetch and load checkpoints from Neptune  

In [None]:
# (Neptune) Initialize existing run
run = neptune.init_run(
    with_id=run_id,  # Replace this with the ID of the run you want to restart
)

In [None]:
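# (Neptune) Fetch hyperparameters logged by the original run
parameters = run["parameters"].fetch()
parameters["num_epochs"] = 2  # Raise the epoch limit so that training can continue
run["parameters"] = stringify_unsupported(parameters)  # Log the updated hyperparameters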

# (Neptune) Fetch and load checkpoint
checkpoints = run.get_structure()["checkpoints"]
epochs = [
    int(checkpoint.split("_")[-1]) for checkpoint in checkpoints
]  # Fetch the epochs of the checkpoints
epochs.sort()  # Sort the epochs
epoch = epochs[-1]  # Fetch the last epoch
checkpoint = load_checkpoint(run, epoch)  # Load the checkpoint

# Load model and optimizer state
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
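
The epoch numbers are parsed to integers before picking the latest one because checkpoint names sort differently as strings. The sketch below illustrates this with a hypothetical helper, `latest_checkpoint_epoch`, and made-up names that follow the `epoch_<n>` pattern used by `save_checkpoint`:

```python
from typing import Iterable


def latest_checkpoint_epoch(names: Iterable[str]) -> int:
    """Return the highest epoch among names shaped like 'epoch_<int>'."""
    return max(int(name.split("_")[-1]) for name in names)


names = ["epoch_0", "epoch_10", "epoch_5"]  # made-up example keys
print(latest_checkpoint_epoch(names))  # 10
print(sorted(names)[-1])               # epoch_5 (string order, not the latest)
```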

## Resume training from checkpoint

In [None]:
# The checkpoint was saved after its epoch completed, so resume from the next epoch
train(
    run, model, trainloader, criterion, optimizer, parameters, start_epoch=checkpoint["epoch"] + 1
)
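
Which epochs run after a restart depends on the `start_epoch` passed to `train`. Since `save_checkpoint` is called after an epoch's batches have finished, passing the saved epoch unchanged repeats that epoch, while passing the following epoch continues after it. The sketch below compares the two with a hypothetical helper, `epochs_to_run`, and an assumed saved epoch of 0:

```python
from typing import List


def epochs_to_run(start_epoch: int, num_epochs: int) -> List[int]:
    """Return the epochs that train() iterates over for the given arguments."""
    return list(range(start_epoch, num_epochs))


saved_epoch = 0  # hypothetical value of checkpoint["epoch"]
num_epochs = 2   # as set in parameters["num_epochs"] before resuming

print(epochs_to_run(saved_epoch, num_epochs))      # [0, 1]
print(epochs_to_run(saved_epoch + 1, num_epochs))  # [1]
```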

## Stop logging

Once you are done logging, stop tracking the run.

In [None]:
run.stop()

## Conclusion

You learned how to save checkpoints to Neptune, fetch and load them, and resume training from a saved checkpoint.

Visit our docs for more tutorials and guides on how to use Neptune: https://docs.neptune.ai