# Saving and Restoring Models

* Once you have trained a model, you can save it and reload it later.
* You may save just the weights or the whole model, including its architecture.
* Saving the model together with its architecture uses a bit more space, but it is well worth it.


In [None]:
# Import the standard libraries that you will use most of the time
import torch 
import torch.nn as nn
import matplotlib.pyplot as plt
from torchinfo import summary
from torchvision import transforms, datasets
print("PyTorch version:", torch.__version__)

# --- MNIST data pipeline ---

# ToTensor converts each PIL image into a PyTorch tensor
data_transform = transforms.Compose([transforms.ToTensor()])

# Training and test splits, stored under ./data
mnist_train = datasets.MNIST(
    root='./data',
    train=True,
    transform=data_transform,
    download=True,
)
mnist_test = datasets.MNIST(
    root='./data',
    train=False,
    transform=data_transform,
    download=True,
)

# Mini-batches of 64 samples; only the training split is shuffled
train_dataloader = torch.utils.data.DataLoader(
    mnist_train, batch_size=64, shuffle=True,
)
test_dataloader = torch.utils.data.DataLoader(
    mnist_test, batch_size=64, shuffle=False,
)


# Bring back the MNIST classifier model
class NeuralNet(nn.Module):
    """Fully connected MNIST classifier: 784 -> 128 -> 10.

    ``forward`` returns raw class scores (logits), not probabilities.
    ``nn.CrossEntropyLoss`` applies log-softmax itself, so feeding it
    already-softmaxed outputs would normalise the scores twice and weaken
    the gradients.  ``argmax`` over the logits selects the same class as
    ``argmax`` over the probabilities; apply ``self.softmax`` to the output
    explicitly when probabilities are needed.
    """

    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.fc1     = nn.Linear(28 * 28, 128)
        self.relu    = nn.ReLU()
        self.dropout = nn.Dropout(0.2)
        self.fc2     = nn.Linear(128, 10)
        # Not used in forward(); kept so callers can turn logits into
        # probabilities with model.softmax(model(x)).
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return one row of 10 logits per input image."""
        x = self.flatten(x)  # Flatten everything after the batch dimension
        x = self.relu(self.fc1(x))
        x = self.dropout(x)  # Only active while the module is in train() mode
        return self.fc2(x)
    
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the network, then move it onto the selected device
model = NeuralNet()
model = model.to(device)

# Report the model summary for an input of size (1, 28, 28)
summary(model, input_size=(1, 28, 28))


In [None]:
def train_model(epochs, optimizer):
    """Train the global ``model`` and report progress after every epoch.

    Relies on names defined at notebook level: ``model``, ``device``,
    ``CEloos``, ``train_dataloader``, ``test_dataloader``, ``mnist_train``
    and ``mnist_test``.

    Args:
        epochs: Number of full passes over ``train_dataloader``.
        optimizer: Optimizer used to update the model parameters.

    Prints one line per epoch with the average training loss per sample and
    the train/test accuracy.  Returns nothing.  The model is left in eval
    mode after the last epoch's test pass.
    """
    # Iterate over #epochs
    for epoch in range(epochs):
        # Keep track of network progress
        train_loss    = 0.0
        train_correct = 0
        test_correct  = 0

        # Training pass: make sure dropout is active
        model.train()
        for images, labels in train_dataloader:
            images, labels = images.to(device), labels.to(device)

            # Compute model prediction and loss
            pred_labels = model(images)
            loss        = CEloos(pred_labels, labels)

            # Backpropagate
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # The loss is treated as a per-batch average (the default
            # reduction of nn.CrossEntropyLoss), so weight it by the batch
            # size to make the per-sample average below exact.
            train_loss    += loss.item() * labels.size(0)
            # Count number of correct predictions (compared on the device)
            train_correct += (pred_labels.argmax(dim=1) == labels).sum().item()

        # Test loop (once per epoch): dropout off, no gradient tracking
        model.eval()
        with torch.no_grad():
            for images, labels in test_dataloader:
                images, labels = images.to(device), labels.to(device)
                test_correct += (model(images).argmax(dim=1) == labels).sum().item()

        # Compute accuracy (train & test)
        train_acc = train_correct / len(mnist_train)
        test_acc  = test_correct  / len(mnist_test)
        print('Epoch [{}/{}], Loss: {:.4f}, Train Acc: {:.2f}%, Test Acc: {:.2f}%'
              .format(epoch + 1, epochs, train_loss / len(mnist_train), 100 * train_acc, 100 * test_acc))

In [None]:
# Criterion and optimizer used for training
# (train_model looks the criterion up under the global name CEloos)
CEloos = nn.CrossEntropyLoss()
optimizer_adam = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [None]:
# Train
# `num_epochs` is not assigned anywhere else in this notebook, so define it
# here before it is used.  NOTE(review): 5 is a placeholder - adjust as needed.
num_epochs = 5
train_model(num_epochs, optimizer_adam)

## Save

In [None]:
import os

# torch.save does not create missing folders, so make sure the target exists
os.makedirs("model", exist_ok=True)

# Two alternatives
# 1) Only the model weights (state_dict)
torch.save(model.state_dict(), "model/model.pt")

# 2) Model + optimizer + ... (training state, enough to resume training)
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer_adam.state_dict(),
            }, "model/training_state")

## Restore

In [None]:
# Load the training-state checkpoint.  map_location remaps the stored tensors
# onto the current `device`, so a checkpoint written on a GPU machine can
# still be restored on a CPU-only one.
checkpoint = torch.load("model/training_state", map_location=device)

# Restore the model weights and the optimizer state stored under these keys
model.load_state_dict(checkpoint['model_state_dict'])
optimizer_adam.load_state_dict(checkpoint['optimizer_state_dict'])

## Test

In [None]:
# Sample and plot
# Fetch one test batch once and keep its first five images with their labels
images, targets = next(iter(test_dataloader))
samples = images[:5]
labels  = targets[:5]

# Inference: turn dropout off and skip gradient tracking, then put the model
# back into whichever mode it was in before this cell ran.
was_training = model.training
model.eval()
with torch.no_grad():
    p = model(samples.to(device)).argmax(dim=1).cpu()
model.train(was_training)

# NOTE(review): plot_tensor is not defined in this file - it has to be
# provided by an earlier cell or notebook before this loop can run.
for sample, label, pred in zip(samples, labels, p):
    plt.figure(figsize=(1,1))
    plot_tensor(sample.squeeze(), label.squeeze(), pred.item())
    plt.show()