# Deep Learning with PyTorch Step-by-Step: A Beginner's Guide

# Saving and Loading Models

In [None]:
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.utils.data.dataset import random_split
from torch.utils.tensorboard import SummaryWriter

import matplotlib.pyplot as plt
# Render matplotlib figures inline, directly below the cell that creates them
%matplotlib inline
# Apply the 'fivethirtyeight' style sheet to every plot drawn from here on
plt.style.use('fivethirtyeight')

In [None]:
from plots.chapter2 import *

## Helper Functions

In [None]:
def make_train_step(model, loss_fn, optimizer):
    """Build a function that performs one training step on a mini-batch.

    Args:
        model: the model to train; it is called as ``model(x)`` and must
            expose ``train()``.
        loss_fn: callable taking ``(predictions, targets)`` and returning a
            loss that supports ``backward()`` and ``item()``.
        optimizer: object exposing ``step()`` and ``zero_grad()``.

    Returns:
        A function ``perform_train_step(x, y)`` that runs the step and
        returns the value of ``loss.item()`` for that batch.
    """
    def perform_train_step(x, y):
        # The model must be in TRAIN mode before the forward pass
        model.train()

        # Forward pass and loss for this batch
        batch_loss = loss_fn(model(x), y)
        # Backward pass: computes the gradients of the loss
        batch_loss.backward()
        # Updates the parameters, then clears the gradients so they
        # do not leak into the next step
        optimizer.step()
        optimizer.zero_grad()

        return batch_loss.item()

    return perform_train_step

In [None]:
def mini_batch(device, data_loader, step):
    """Run ``step`` on every mini-batch and return the average loss.

    Args:
        device: target handed to ``.to()`` for each batch.
        data_loader: iterable yielding ``(x_batch, y_batch)`` pairs.
        step: callable ``step(x, y)`` returning the loss of one batch.

    Returns:
        ``np.mean`` of the per-batch losses.
    """
    # Each batch is sent to the device right before being handed to `step`
    batch_losses = [
        step(features.to(device), targets.to(device))
        for features, targets in data_loader
    ]
    return np.mean(batch_losses)

In [None]:
def make_val_step(model, loss_fn):
    """Build a function that evaluates the model on one mini-batch.

    Args:
        model: the model to evaluate; it is called as ``model(x)`` and
            must expose ``eval()``.
        loss_fn: callable taking ``(predictions, targets)`` and returning
            a loss that supports ``item()``.

    Returns:
        A function ``perform_val_step(x, y)`` that returns the value of
        ``loss.item()`` for that batch without updating any parameter.
    """
    def perform_val_step(x, y):
        # Sets model to EVAL mode
        model.eval()

        # No gradients are needed during evaluation (there is no backward
        # pass and no parameter update), so the forward pass and the loss
        # are computed with autograd tracking switched off. This is safe
        # even if the caller already wrapped the call in torch.no_grad().
        with torch.no_grad():
            # Step 1 - Computes our model's predicted output - forward pass
            yhat = model(x)
            # Step 2 - Computes the loss
            loss = loss_fn(yhat, y)

        return loss.item()

    return perform_val_step

## Data Generation

In [None]:
# Ground-truth intercept and slope of the line, and the number of points
true_b, true_w = 1, 2
N = 100

# Fixed seed so the synthetic dataset is the same on every run
np.random.seed(42)
# Features: N values drawn uniformly from [0, 1), as a column
x = np.random.rand(N, 1)
# Gaussian noise scaled down to a standard deviation of 0.1
epsilon = np.random.randn(N, 1) * .1
# Labels: points on the line plus the noise
y = true_b + true_w * x + epsilon

### Generating training and validation sets

In [None]:
# Random ordering of the row indices 0..N-1
idx = np.arange(N)
np.random.shuffle(idx)

# The first 80% of the shuffled indices go to training,
# everything after that cut goes to validation
train_idx, val_idx = np.split(idx, [int(N*.8)])

# Picks the rows of the features and of the labels for each set
x_train, x_val = x[train_idx], x[val_idx]
y_train, y_val = y[train_idx], y[val_idx]

## Full Pipeline

In [None]:
# Each script is executed inside this notebook's namespace (-i): it can read
# the variables defined above, and whatever it defines stays available to the
# following cells. Going by the script names, the order is data preparation,
# then model configuration, then training.
%run -i data_preparation/v2.py
%run -i model_configuration/v3.py
%run -i model_training/v5.py

## Saving Model

In [None]:
# Everything needed to resume training later goes into a single dictionary:
# the epoch count, the model and optimizer states, and both loss histories
checkpoint = {
    'epoch': n_epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'loss': losses,
    'val_loss': val_losses,
}

# Serializes the dictionary to a file in the current working directory
torch.save(checkpoint, 'model_checkpoint.pth')

In [None]:
# Once the checkpoint has been saved, model_checkpoint.pth should be listed
# by the search below, confirming that the file was written
!find . -type f -name model_checkpoint.pth

## Loading Model

### Going Back to the Untrained Model Stage

In [None]:
# Only data preparation and model configuration are re-run here (no training
# script). Both execute in this notebook's namespace (-i), so the names they
# define are rebound; presumably this replaces `model` and `optimizer` with
# fresh, untrained ones - confirm against the configuration script.
%run -i data_preparation/v2.py
%run -i model_configuration/v3.py

In [None]:
# Checking the parameters of the untrained model
# (before anything is loaded from the checkpoint)
print(model.state_dict())

### Loading the model state saved after training

In [None]:
# Reads back the checkpoint dictionary saved earlier
# NOTE(review): torch.load unpickles the file, so only load checkpoints from
# a trusted source. Recent PyTorch releases default to weights_only=True;
# confirm that this checkpoint, which also stores the loss histories, still
# loads on the installed version.
checkpoint = torch.load('model_checkpoint.pth')

# Restores the model parameters and the optimizer's internal state
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# Keeps the bookkeeping of the saved run under separate names
saved_epoch, saved_losses, saved_val_losses = (
    checkpoint['epoch'], checkpoint['loss'], checkpoint['val_loss']
)

print(model.train())  # always use TRAIN for resuming training

### Checking model parameters after loading model

In [None]:
# Checking the model parameters after load_state_dict: these are the values
# that were stored in the checkpoint
print(model.state_dict())

## Resuming Training

In [None]:
# Running the model training v5 script again (training for additional 200 epochs).
# With -i it runs in this notebook's namespace, so it works on the `model`
# and `optimizer` restored from the checkpoint above.
%run -i model_training/v5.py

In [None]:
# Checking the model parameters once more, now after the resumed training run
print(model.state_dict())

## Plotting Losses

In [None]:
# Draws the losses stored in the checkpoint together with the losses
# produced by the resumed training run
fig = plot_resumed_losses(
    saved_epoch, saved_losses, saved_val_losses,
    n_epochs, losses, val_losses,
)