In [2]:
import itertools
from IPython.display import Image
from IPython import display
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [3]:
trn_dataset = datasets.MNIST(
    root="dataset",
    download=False,
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081))
    ]))

val_dataset = datasets.MNIST(
    root="dataset",
    download=False,
    train=False,
    transform= transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307, ), (0.3081))
    ]))

In [4]:
batch_size = 64
trn_loader = torch.utils.data.DataLoader(trn_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)

val_loader = torch.utils.data.DataLoader(val_dataset,
                                         batch_size=batch_size,
                                         shuffle=True)

In [5]:
class CNNClassifier(nn.Module):
    
    def __init__(self):
        super(CNNClassifier, self).__init__()
        conv1 = nn.Conv2d(1, 6, 5, 1)
        pool1 = nn.MaxPool2d(2)
        conv2 = nn.Conv2d(6, 16, 5, 1)
        pool2 = nn.MaxPool2d(2)
        
        self.conv_module = nn.Sequential(
            conv1,
            nn.ReLU(),
            pool1,
            conv2,
            nn.ReLU(),
            pool2
        )
        
        fc1 = nn.Linear(16*4*4, 120)
        fc2 = nn.Linear(120, 84)
        fc3 = nn.Linear(84, 10)
        
        self.fc_module = nn.Sequential(
            fc1,
            nn.ReLU(),
            fc2,
            nn.ReLU(),
            fc3
        )
        
    def forward(self, x):
        out = self.conv_module(x)
        dim = 1
        for d in out.size()[1:]:
            dim = dim * d
        out = out.view(-1, dim)
        out = self.fc_module(out)
        return F.softmax(out, dim=1)

In [6]:
cnn = CNNClassifier()

criterion = nn.CrossEntropyLoss()
learning_rate = 1e-3
optimizer = optim.Adam(cnn.parameters(), lr=learning_rate)

num_epochs = 2
num_batchs = len(trn_loader)

trn_loss_list = []
val_loss_list = []
for epoch in range(num_epochs):
    trn_loss = 0.0
    for i, data in enumerate(trn_loader):
        x, label = data
        
        optimizer.zero_grad()
        
        model_output = cnn(x)
        
        loss = criterion(model_output, label)
        
        loss.backward()
        
        optimizer.step()
        
        trn_loss += loss.item()
        del loss
        del model_output
        
        if (i+1) % 100 == 0:
            with torch.no_grad():
                val_loss = 0.0
                for j, val in enumerate(val_loader):
                    val_x, val_label = val
                    val_output = cnn(val_x)
                    v_loss = criterion(val_output, val_label)
                    val_loss += v_loss
                    
                print(f"epoch: {epoch+1}/{num_epochs} | step: {i+1}/{num_batchs} | trn loss: {trn_loss / 100:.4f} | val loss {val_loss/len(val_loader):.4f}")
                
                trn_loss_list.append(trn_loss/100)
                val_loss_list.append(val_loss/len(val_loader))
                trn_loss = 0.0

epoch: 1/2 | step: 100/938 | trn loss: 1.9246 | val loss 1.7249
epoch: 1/2 | step: 200/938 | trn loss: 1.7353 | val loss 1.7060
epoch: 1/2 | step: 300/938 | trn loss: 1.6914 | val loss 1.6272
epoch: 1/2 | step: 400/938 | trn loss: 1.6399 | val loss 1.6122
epoch: 1/2 | step: 500/938 | trn loss: 1.6206 | val loss 1.6073
epoch: 1/2 | step: 600/938 | trn loss: 1.6055 | val loss 1.6027
epoch: 1/2 | step: 700/938 | trn loss: 1.5982 | val loss 1.5891
epoch: 1/2 | step: 800/938 | trn loss: 1.5863 | val loss 1.5948
epoch: 1/2 | step: 900/938 | trn loss: 1.5909 | val loss 1.6003
epoch: 2/2 | step: 100/938 | trn loss: 1.5833 | val loss 1.5781
epoch: 2/2 | step: 200/938 | trn loss: 1.5903 | val loss 1.5783
epoch: 2/2 | step: 300/938 | trn loss: 1.5871 | val loss 1.5780
epoch: 2/2 | step: 400/938 | trn loss: 1.5841 | val loss 1.5753
epoch: 2/2 | step: 500/938 | trn loss: 1.5796 | val loss 1.5748
epoch: 2/2 | step: 600/938 | trn loss: 1.5835 | val loss 1.5781
epoch: 2/2 | step: 700/938 | trn loss: 1