<a href="https://colab.research.google.com/github/myDSMLProjects/PyTorch-Fundamentals/blob/master/PyTorch_CNN_Saving_model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

In [18]:
class CNN(nn.Module):
  def __init__(self, input_size=1, num_classes=10):
    super(CNN, self).__init__()

    self.conv1 = nn.Conv2d(in_channels=input_size, out_channels=8, kernel_size=3, stride=1, padding=1)
    self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
    self.conv2 = nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3, stride=1, padding=1)
    self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
    self.fc1 = nn.Linear(in_features=16*7*7, out_features=num_classes)

  def forward(self, x):
    x = F.relu(self.conv1(x))
    x = self.pool1(x)
    x = F.relu(self.conv2(x))
    x = self.pool2(x)
    x = x.reshape(x.shape[0], -1)
    x = self.fc1(x)
    return x

def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
  print('=>Saving checkpoint')
  torch.save(state, filename)

def load_checkpoint(checkpoint):
  print("=>Loading checkpoint")
  model.load_state_dict(checkpoint['state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer'])

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [15]:
# Hyperparameters
learning_rate = 1e-3
batch_size = 1000
num_epochs=10
load_model=True

In [5]:
train_dataset = datasets.FashionMNIST(root='data/', download=True, train=True, transform=transforms.ToTensor())
test_dataset = datasets.FashionMNIST(root='data/', download=True, train=False, transform=transforms.ToTensor())

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Processing...



Done!


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [16]:
# Initilaize the network
model = CNN().to(device)
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(lr=learning_rate, params=model.parameters())
if load_model:
  load_checkpoint(torch.load('my_checkpoint.pth.tar'))

=>Loading checkpoint


In [19]:
# Load the model
for epoch in range(num_epochs):
  total_loss = 0
  total_correct = 0

  if epoch%3==0:
    checkpoint = {'state_dict': model.state_dict(), 'optimizer':optimizer.state_dict()}
    save_checkpoint(checkpoint)

  for batch_idx, (images, labels) in enumerate(train_loader):
    images = images.to(device)
    labels = labels.to(device)

    predicted_labels = model(images)
    loss = criterion(predicted_labels, labels)
    total_loss+=loss.item()
    total_correct+=(predicted_labels.argmax(dim=1)==labels).sum()

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  print(f"Epoch: {epoch+1}/{num_epochs} \t Accuracy: {100*total_correct.item()/len(train_dataset):.2f} \t Loss: {total_loss:.2f}")

=>Saving checkpoint
Epoch: 1/10 	 Accuracy: 88.58 	 Loss: 19.693710386753082
Epoch: 2/10 	 Accuracy: 88.69 	 Loss: 19.45617577433586
Epoch: 3/10 	 Accuracy: 88.81 	 Loss: 19.217567533254623
=>Saving checkpoint
Epoch: 4/10 	 Accuracy: 88.86 	 Loss: 19.03289246559143
Epoch: 5/10 	 Accuracy: 89.08 	 Loss: 18.846978098154068
Epoch: 6/10 	 Accuracy: 89.17 	 Loss: 18.580260187387466
=>Saving checkpoint
Epoch: 7/10 	 Accuracy: 89.26 	 Loss: 18.475908398628235
Epoch: 8/10 	 Accuracy: 89.25 	 Loss: 18.365311473608017
Epoch: 9/10 	 Accuracy: 89.35 	 Loss: 18.170067995786667
=>Saving checkpoint
Epoch: 10/10 	 Accuracy: 89.50 	 Loss: 17.952729895710945
