3. Implement check points in PyTorch by saving model state_dict, optimizer state_dict, epochs
and loss during training so that the training can be resumed at a later point. Also, illustrate
the use of check point to save the best found parameters during training.
Use the original source program MNIST_CNN_Checkpoint.py
Use state_dict to save model parameters and Optimizer information.
Step 1: Re-run the MNIST program by appending the following command at the end. Make sure
checkpoints folder exists in the current working directory.

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

###Train for initial number of epochs and save checkpoint

In [2]:
class CNNClassifier(nn.Module):
    def __init__(self):
        super(CNNClassifier, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=2),
            nn.Conv2d(64, 128, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=2),
            nn.Conv2d(128, 64, kernel_size=3),
            nn.ReLU(),
            nn.MaxPool2d((2, 2), stride=2)
        )
        self.classification_head = nn.Sequential(
            nn.Linear(64, 20, bias=True),
            nn.ReLU(),
            nn.Linear(20, 10, bias=True)
        )

    def forward(self, x):
        features = self.net(x)
        return self.classification_head(features.view(features.shape[0], -1))

In [3]:
def train(epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(DEVICE), target.to(DEVICE)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % LOG_INTERVAL == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

In [4]:
def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(DEVICE), target.to(DEVICE)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.data.max(1, keepdim=True)[1]
            correct += pred.eq(target.data.view_as(pred)).sum()
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))
    return test_loss

In [5]:
EPOCHS = 2
BATCH_SIZE_TRAIN = 64
BATCH_SIZE_TEST = 1000
LR = 0.01
LOG_INTERVAL = 100
RANDOM_SEED = 1
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(RANDOM_SEED)

<torch._C.Generator at 0x7a75fc1b55d0>

In [6]:
train_dataset = MNIST('data/', train=True, download=True, transform=ToTensor())
test_dataset = MNIST('data/', train=False, download=True, transform=ToTensor())

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE_TRAIN, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE_TEST, shuffle=True)

In [7]:
model = CNNClassifier().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LR)

In [8]:
for epoch in range(1, EPOCHS + 1):
    train(epoch)
    avg_loss = test()

    check_point ={
      "last_loss": avg_loss,
      "last_epoch": epoch + 1,
      "model_state": model.state_dict(),
      "optimizer_state": optimizer.state_dict()
    }

    torch.save(check_point,"./checkpoints/checkpoint.pt")


Test set: Avg. loss: 0.0021, Accuracy: 2820/10000 (28%)


Test set: Avg. loss: 0.0004, Accuracy: 8845/10000 (88%)



To continue training in the same notebook, delete the variables. Otherwise we will have to create a seperate script to load the checkpoints and datasets

In [9]:
del model
del criterion
del optimizer
del check_point
del EPOCHS

###Use the checkpoint and resume the training


In [10]:
NEW_EPOCHS = 5

In [11]:
check_point = torch.load("./checkpoints/checkpoint.pt")

In [12]:
model = CNNClassifier().to(DEVICE)
model.load_state_dict(check_point["model_state"])

<All keys matched successfully>

In [13]:
criterion = nn.CrossEntropyLoss()

In [14]:
optimizer = optim.SGD(model.parameters(), lr=LR)
optimizer.load_state_dict(check_point["optimizer_state"])

In [15]:
EPOCHS = check_point["last_epoch"]

In [16]:
for epoch in range(EPOCHS, NEW_EPOCHS+1):
    train(epoch)
    avg_loss = test()

    check_point ={
      "last_loss": avg_loss,
      "last_epoch": epoch + 1,
      "model_state": model.state_dict(),
      "optimizer_state": optimizer.state_dict()
    }

    torch.save(check_point,"./checkpoints/checkpoint.pt")


Test set: Avg. loss: 0.0003, Accuracy: 9092/10000 (91%)


Test set: Avg. loss: 0.0002, Accuracy: 9439/10000 (94%)


Test set: Avg. loss: 0.0002, Accuracy: 9272/10000 (93%)

