In this example, we show how to use Neptune to fork new runs from previous ones. Forking is useful when:
1. You want to start a new run (a fork) from any previously saved checkpoint with a different set of hyperparameters.
2. Your training run becomes unstable and you'd like to restart from an earlier checkpoint.
3. A hardware failure interrupts training and you need to restart the run.

In [7]:
from neptune_scale import Run
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from tqdm.auto import trange, tqdm

import math

In [47]:
# Static configuration shared by the data pipeline and the model.
parameters = dict(
    batch_size=64,
    input_size=(1, 28, 28),  # shape of one MNIST image tensor
    n_classes=10,
    epochs=3,
    device=torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu"),
)

# Number of features per image once flattened (1 * 28 * 28).
input_size = math.prod(parameters["input_size"])

In [65]:
# Hyperparameters that may change between the parent run and its forks.
# NOTE(review): the DataLoader reads parameters["batch_size"], not the
# batch_size stored here.
updating_parameters = dict(
    learning_rate=0.05,
    epochs=10,
    batch_size=128,
)

In [26]:
class SimpleNN(nn.Module):
    """Two-layer fully connected classifier over flattened inputs.

    Args:
        input_size: Number of features per example after flattening
            (784 for a 1x28x28 MNIST image).
        n_classes: Number of output logits. Defaults to 10, the value that
            was previously hard-coded.
    """

    def __init__(self, input_size, n_classes=10):
        super().__init__()
        self.input_size = input_size
        self.fc1 = nn.Linear(input_size, 128)  # flattened input -> 128 hidden units
        self.fc2 = nn.Linear(128, n_classes)   # hidden units -> one logit per class

    def forward(self, x):
        """Return raw class logits with shape (batch, n_classes)."""
        # Flatten each example to `input_size` features. This previously used a
        # hard-coded 28*28, which ignored the constructor argument; `reshape`
        # also accepts non-contiguous inputs where `view` would fail.
        x = x.reshape(-1, self.input_size)
        x = torch.relu(self.fc1(x))  # ReLU after the first linear layer
        return self.fc2(x)           # unnormalized logits


# Loss used by the training loop; it takes raw logits and integer class targets.
criterion = nn.CrossEntropyLoss()


In [48]:
# Per-split preprocessing; only the training split is used in this notebook.
data_tfms = {"train": transforms.Compose([transforms.ToTensor()])}

# MNIST training split, downloaded under the "mnist" directory if missing.
trainset = datasets.MNIST(
    root="mnist",
    train=True,
    download=True,
    transform=data_tfms["train"],
)

# NOTE(review): batches use parameters["batch_size"]; the batch_size in
# updating_parameters is not read here.
trainloader = torch.utils.data.DataLoader(
    trainset,
    shuffle=True,
    num_workers=0,
    batch_size=parameters["batch_size"],
)

In [28]:
# Build the classifier and move its parameters to the configured device.
model = SimpleNN(input_size).to(parameters["device"])

In [62]:
# Checkpoint helpers: save and restore model + optimizer state so training can
# be resumed or forked from the end of any epoch.
import os


def save_checkpoint(epoch, step, model, optimizer, loss, checkpoint_dir='./checkpoints'):
    """Save model/optimizer state to ``checkpoint_dir/checkpoint_epoch_{epoch}.pth``.

    An existing file with the same name is overwritten.

    Args:
        epoch: Epoch number; used in the file name and stored in the file.
        step: Global training step reached when saving.
        model: Module whose ``state_dict()`` is saved.
        optimizer: Optimizer whose ``state_dict()`` is saved.
        loss: Loss value stored alongside the state (e.g. the epoch average).
        checkpoint_dir: Target directory, created if missing. Defaults to the
            previously hard-coded ``./checkpoints``.

    Returns:
        Path of the written checkpoint file.
    """
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}.pth')
    torch.save({
        'epoch': epoch,
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': loss,
    }, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch}")
    return checkpoint_path


def load_checkpoint(model, optimizer, checkpoint_path, map_location=None):
    """Restore ``model`` and ``optimizer`` in place from a saved checkpoint.

    Args:
        model: Module to load ``model_state_dict`` into.
        optimizer: Optimizer to load ``optimizer_state_dict`` into.
        checkpoint_path: File written by ``save_checkpoint``.
        map_location: Forwarded to ``torch.load``. Pass e.g. ``"cpu"`` to load
            a checkpoint saved on a GPU on a CPU-only machine. ``None`` keeps
            torch's default device placement.

    Returns:
        Tuple ``(epoch, step, loss)`` stored in the checkpoint.
    """
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    epoch = checkpoint['epoch']
    step = checkpoint['step']
    loss = checkpoint['loss']
    print(f"Checkpoint loaded from step {step}, epoch {epoch}, loss: {loss:.4f}")
    return epoch, step, loss

In [72]:
# Define a generic training loop with Neptune logging
def train(model: nn.Module, updating_parameters, trainloader, run: Run = None, step=0,
          optimizer=None, start_epoch=0):
    """Train ``model`` for ``updating_parameters["epochs"]`` epochs.

    Logs per-batch loss and epoch number to ``run`` (when given), prints each
    epoch's average loss, and saves a checkpoint after every epoch.

    Args:
        model: Module trained in place.
        updating_parameters: Dict providing "learning_rate" and "epochs".
        trainloader: Iterable of ``(data, target)`` batches.
        run: Neptune run to log metrics to; nothing is logged when None.
        step: Global step to continue counting from (e.g. the step stored in
            a checkpoint).
        optimizer: Optimizer to use, e.g. one restored with
            ``load_checkpoint``. When None, a new Adam with
            ``updating_parameters["learning_rate"]`` is created (the previous
            behaviour).
        start_epoch: Number of epochs already completed. Epochs are numbered
            from ``start_epoch + 1``, so a resumed or forked run does not
            restart at epoch 1 and overwrite earlier checkpoint files.

    Returns:
        The global step counter after the last batch.
    """
    if optimizer is None:
        optimizer = optim.Adam(model.parameters(), lr=updating_parameters["learning_rate"])

    # Batches must be on the same device as the model's parameters; otherwise
    # a model moved to CUDA would receive CPU tensors and fail.
    device = next(model.parameters()).device
    step_counter = step

    for epoch in range(start_epoch + 1, start_epoch + updating_parameters["epochs"] + 1):
        model.train()
        running_loss = 0.0
        n_batches = 0

        for data, target in trainloader:
            step_counter += 1
            n_batches += 1
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value

            # Log loss to Neptune
            if run is not None:
                run.log_metrics(
                    data={
                        "train/loss": loss_value,
                        "train/epoch": epoch,
                    },
                    step=step_counter,
                )

        # Average over the batches actually seen; max() avoids dividing by
        # zero when the loader yields nothing.
        avg_loss = running_loss / max(n_batches, 1)
        print(f"Epoch {epoch} average loss: {avg_loss:.4f}")

        # Save checkpoint after each epoch
        save_checkpoint(epoch, step_counter, model, optimizer, avg_loss)

    return step_counter

In [68]:
# Create the initial (parent) run that later runs will fork from

run = Run(experiment_name="head_of_forking")

run.log_configs({
    "parameters/initial/config/n_classes": parameters["n_classes"],
    "parameters/changing/config/epochs": updating_parameters["epochs"],
    "parameters/changing/config/lr": updating_parameters["learning_rate"],
    # Log the batch size the DataLoader actually uses. updating_parameters
    # also holds a batch_size, but trainloader is built from
    # parameters["batch_size"], so logging that value would misreport the run.
    "parameters/changing/config/batch_size": trainloader.batch_size,
})

run.add_tags(["forks", "notebook"])

print(run.get_experiment_url())

https://scale.neptune.ai/leo/pytorch-tutorial/runs/details?runIdentificationKey=head_of_forking&type=experiment


In [73]:
# Execute training loop for initial experiment. Pass the run so metrics are
# logged to it; without `run`, train() logs nothing and the forked run below
# would have no metric history to branch from.
train(model, updating_parameters, trainloader, run=run)


Epoch 1 average loss: 0.4472
Checkpoint saved at epoch 1
Epoch 2 average loss: 0.4196
Checkpoint saved at epoch 2
Epoch 3 average loss: 0.4100
Checkpoint saved at epoch 3
Epoch 4 average loss: 0.4055
Checkpoint saved at epoch 4
Epoch 5 average loss: 0.4272
Checkpoint saved at epoch 5
Epoch 6 average loss: 0.3910
Checkpoint saved at epoch 6
Epoch 7 average loss: 0.4181
Checkpoint saved at epoch 7
Epoch 8 average loss: 0.4615
Checkpoint saved at epoch 8
Epoch 9 average loss: 0.4319
Checkpoint saved at epoch 9
Epoch 10 average loss: 0.4553
Checkpoint saved at epoch 10


In [59]:
# Close the parent run; this waits until all queued operations are processed.
run.close()

2025-03-28 16:53:32,430 neptune:INFO: Waiting for all operations to be processed
2025-03-28 16:53:32,431 neptune:INFO: All operations were processed


In [74]:
# Restore model and optimizer state from the epoch-9 checkpoint. The returned
# `step` is the point at which the forked run below branches off.
checkpoint_path = os.path.join('./checkpoints', 'checkpoint_epoch_9.pth')

# NOTE(review): the restored optimizer state only takes effect if this
# optimizer is the one used for further training.
optimizer = optim.Adam(
    model.parameters(),
    lr=updating_parameters["learning_rate"],
)

epoch, step, loss = load_checkpoint(
    model,
    optimizer,
    checkpoint_path,
)

Checkpoint loaded from step 8442, epoch 9, loss: 0.4319


In [None]:
# Create a forked run that branches off the parent run at the checkpoint's step.
# Fill in the ID of the parent run (e.g. copied from the Neptune app). An empty
# string cannot identify a parent, so fail early with a clear message instead
# of creating a run that is not actually forked.
parent_run_id = ""
if not parent_run_id:
    raise ValueError("Set parent_run_id to the ID of the run you want to fork from.")

run = Run(
    experiment_name="forked_experiment",
    fork_run_id=parent_run_id,
    fork_step=step,
)