<a href="https://colab.research.google.com/github/martinpius/PYTORCH/blob/main/Saving_and_loading_Models_in_pytorch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [45]:
from google.colab import drive
drive.mount("/content/drive", force_remount = True)
try:
  COLAB = True
  import torch
  print(f"You are on Google CoLaB with torch version: {torch.__version__}")
except Exception as e:
  print(f"{type(e)}: {e}\n>>>please load your drive properly...")
  COLAB = False
#Assigning GPU device when available:
if torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')
def time_fmt(t:float = 123.817)->float:
  h = int(t / (60 * 60))
  m = int(t % (60 * 60) / 60)
  s = int(t % 60)
  return f"{h}: {m:>02}: {s:>05.2f}"
print(f">>>time testing\tplease wait...\ntime elapse:\t{time_fmt()}")

Mounted at /content/drive
You are on Google CoLaB with torch version: 1.8.1+cu101
>>>time testing	please wait...
time elapse:	0: 02: 03.00


In [46]:
#In this short tutorial we are going to learn how to create a simple deep learning model in torch,
#saving the model into the device of our choice and then loading back the model whenerver needed:


In [47]:
#For demo, we train a simple cnn to classify mnist images:

In [48]:
#Importing neccessary libraries and modules from torch
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.datasets as datasets
from torchvision.transforms import transforms
from tqdm import tqdm
import time, datetime, sys, os

In [49]:
#Model's hyperparameters:
batch_size = 64
epochs = 10
in_channels = 1
num_classes = 10
learning_rate = 1e-3
load_model = True

In [50]:
#Creating the simple cnn using the following class:
class CNNBLOCK(nn.Module):
  def __init__(self, in_channels = 1, num_classes = 10, *args, **kwargs):
    super(CNNBLOCK, self).__init__(*args, **kwargs)
    self.conv1 = nn.Conv2d(in_channels = in_channels, 
                           out_channels = 8, 
                           kernel_size = (3,3),
                           stride = (1,1), 
                           padding = (1,1))
    self.maxpool = nn.MaxPool2d(kernel_size = (2,2), stride = (2,2))
    self.conv2 = nn.Conv2d(in_channels = 8, 
                           out_channels = 16, 
                           kernel_size = (3,3), 
                           stride = (1,1), 
                           padding = (1,1))
    self.fc1 = nn.Linear(in_features = 16*7*7, out_features = 256)
    self.fc2 = nn.Linear(in_features = 256, out_features = num_classes)

  def forward(self, x):
    x = F.relu(self.conv1(x))
    x = self.maxpool(x)
    x = F.relu(self.conv2(x))
    x = self.maxpool(x)
    x = x.reshape(x.shape[0], -1)
    x = F.relu(self.fc1(x))
    return self.fc2(x)


In [51]:
#Instantiating our model class:
model = CNNBLOCK().to(device)

In [52]:
#Loading the data from torch:
train_data = datasets.MNIST(root = 'train_dataset/', train = True, transform = transforms.ToTensor(), download = True)
test_data = datasets.MNIST(root = 'test_dataset/', train = False, transform = transforms.ToTensor(), download = True)
train_loader = DataLoader(dataset = train_data, shuffle = True, batch_size = batch_size)
test_loader = DataLoader(dataset = test_data, shuffle = True, batch_size = batch_size)
x_train_batch, y_train_batch = next(iter(train_loader))
print(f"x_batch_shape: {x_train_batch.shape}, y_batch_shape: {y_train_batch.shape}")

x_batch_shape: torch.Size([64, 1, 28, 28]), y_batch_shape: torch.Size([64])


In [53]:
def save_checkpoint(state, filename = 'my_checkpoint.pth.tar'):
  print(">>>saving model checkpoint:")
  torch.save(state, filename)

def load_checkpoint(checkpoint):
  print(">>>>Loading model's checkpoint:")
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_dict'])


In [54]:
#Getting loss and optimizer objects
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.RMSprop(model.parameters(), lr = learning_rate)
#Loading checkpoint if any
if load_model:
  load_checkpoint(torch.load('my_checkpoint.pth.tar'))

>>>>Loading model's checkpoint:


In [55]:
#training loop:
tic = time.time()
for epoch in range(epochs):
  losses = []
  print(f">>>train starts for epoch: {epoch}\n>>>train on progress:\tplease wait while the model is training...")
  if epoch %2  == 0:
    checkpoint = {'state_dict': model.state_dict(),'optimizer_dict': optimizer.state_dict()}
    save_checkpoint(checkpoint)
  for idx, (data, target) in enumerate(tqdm(train_loader)):
    #assign the data to the gpu device if available
    data = data.to(device = device)
    target = target.to(device = device)
    #forwad pass
    preds = model(data)
    train_loss = loss_fn(preds, target)
    losses.append(train_loss.item())
    #backward pass
    optimizer.zero_grad()
    train_loss.backward()
    #gradient discent using RMSprop with lr of 1e-3
    optimizer.step()
  mean_loss = sum(losses)/len(losses)
  print(f">>>At epoch {epoch+1}: the mean loss is: {float(mean_loss):.4f}")
#Monitoring metrics for the train and validation data
def _evaluate_model_(loader, model):
  if loader.dataset.train:
    print(f"Checking metrics for the training set\n>>>please wait\tchecking on progress...")
  else:
    print(f"Checking metrics for the validation set\n>>>please wait\checking on progress...")
  num_correct = 0
  num_examples = 0
  model.eval()
  #no need to recompute grads
  with torch.no_grad():
    for x, y in loader:
      x = x.to(device)
      y = y.to(device)
      preds = model(x)
      _, predictions = preds.max(1) #maximum proba in a set of 10 classes
      num_correct+=(predictions ==y).sum()
      num_examples+=predictions.size(0)
  model.train()
  return num_correct/num_examples
toc = time.time()
print(f"Accuracy for the training data is {float(_evaluate_model_(train_loader, model))*100:.2f}")
print(f"Accuracy for the validation data is {float(_evaluate_model_(test_loader, model))*100:.2f}")
print(f"\n>>>training and evaluation time is : {time_fmt(toc - tic)}")


  2%|▏         | 18/938 [00:00<00:05, 173.14it/s]

>>>train starts for epoch: 0
>>>train on progress:	please wait while the model is training...
>>>saving model checkpoint:


100%|██████████| 938/938 [00:05<00:00, 176.59it/s]
  2%|▏         | 19/938 [00:00<00:05, 172.09it/s]

>>>At epoch 1: the mean loss is: 0.0025
>>>train starts for epoch: 1
>>>train on progress:	please wait while the model is training...


100%|██████████| 938/938 [00:05<00:00, 163.71it/s]
  2%|▏         | 15/938 [00:00<00:06, 141.00it/s]

>>>At epoch 2: the mean loss is: 0.0021
>>>train starts for epoch: 2
>>>train on progress:	please wait while the model is training...
>>>saving model checkpoint:


100%|██████████| 938/938 [00:05<00:00, 172.13it/s]
  1%|▏         | 13/938 [00:00<00:07, 126.37it/s]

>>>At epoch 3: the mean loss is: 0.0020
>>>train starts for epoch: 3
>>>train on progress:	please wait while the model is training...


100%|██████████| 938/938 [00:05<00:00, 160.16it/s]
  2%|▏         | 19/938 [00:00<00:04, 186.20it/s]

>>>At epoch 4: the mean loss is: 0.0023
>>>train starts for epoch: 4
>>>train on progress:	please wait while the model is training...
>>>saving model checkpoint:


100%|██████████| 938/938 [00:05<00:00, 164.61it/s]
  2%|▏         | 15/938 [00:00<00:06, 143.92it/s]

>>>At epoch 5: the mean loss is: 0.0021
>>>train starts for epoch: 5
>>>train on progress:	please wait while the model is training...


100%|██████████| 938/938 [00:05<00:00, 158.50it/s]
  2%|▏         | 18/938 [00:00<00:05, 173.63it/s]

>>>At epoch 6: the mean loss is: 0.0022
>>>train starts for epoch: 6
>>>train on progress:	please wait while the model is training...
>>>saving model checkpoint:


100%|██████████| 938/938 [00:06<00:00, 152.74it/s]
  2%|▏         | 15/938 [00:00<00:06, 141.71it/s]

>>>At epoch 7: the mean loss is: 0.0025
>>>train starts for epoch: 7
>>>train on progress:	please wait while the model is training...


100%|██████████| 938/938 [00:05<00:00, 158.09it/s]
  2%|▏         | 18/938 [00:00<00:05, 178.78it/s]

>>>At epoch 8: the mean loss is: 0.0026
>>>train starts for epoch: 8
>>>train on progress:	please wait while the model is training...
>>>saving model checkpoint:


100%|██████████| 938/938 [00:05<00:00, 175.51it/s]
  2%|▏         | 19/938 [00:00<00:04, 184.32it/s]

>>>At epoch 9: the mean loss is: 0.0025
>>>train starts for epoch: 9
>>>train on progress:	please wait while the model is training...


100%|██████████| 938/938 [00:05<00:00, 175.04it/s]


>>>At epoch 10: the mean loss is: 0.0023
Checking metrics for the training set
>>>please wait	checking on progress...
Accuracy for the training data is 99.91
Checking metrics for the validation set
>>>please wait\checking on progress...
Accuracy for the validation data is 99.07

>>>training and evaluation time is : 0: 00: 56.00
