In [23]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

import matplotlib.pyplot as plt
import numpy as np
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch import optim
import torch.nn as nn
import torch.nn.functional as F

In [2]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,)) ])
trainset = datasets.MNIST('MNIST_data/', download=True, train=True, transform=transform )
trainloader = torch.utils.data.DataLoader(trainset, batch_size=16, shuffle=True)
testset    = datasets.MNIST('MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=16)

In [24]:
class Network(nn.Module):

    def __init__(self):
        super(Network, self).__init__()
        self.conv1 = nn.Conv2d(1, 8, (3,3), 1, (1,1))
        self.conv2 = nn.Conv2d(8, 16, (3,3), 1, (1,1))
        self.pool = nn.MaxPool2d((2,2), (2,2))
        self.cf1 = nn.Linear(16*7*7, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)
        x = self.cf1(x)

        return x
    
model = Network()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

In [26]:
epochs = 3
best_acc = 0
for e in range(epochs):
    loss_train, loss_test = 0, 0
    
    model.train()
    for i, (images, labels) in enumerate(iter(trainloader)):

        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss_train+=loss.item()
        loss.backward()
        optimizer.step()

    model.eval()
    with torch.no_grad():
        correct, total = 0, 0
        for i, (images, labels) in enumerate(iter(testloader)):

            test_output = model(images)
            loss = criterion(test_output, labels)
            loss_test += loss.item()

            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            correct += (pred_y == labels).float().sum()
            total += len(labels)
            accuracy = (pred_y == labels).sum().item() / float(labels.size(0))

    acc = (correct/total).float()*100
    print(f'epoch {e+1} done: accuracy of {acc:.2f}, train_loss:{loss_train/len(trainloader.dataset)} and test_loss: {loss_test/len(testloader.dataset)}')

    model.train()
    if e == 0:
        diff_loss = 1000

    if abs(loss_test-loss_train) < diff_loss + diff_loss/90 and acc > best_acc:
        diff_loss = abs(loss_test-loss_train)
        best_acc = acc
        torch.save(model.state_dict(), 'model_cnn_mnist.pth')
        epoch_best = e

epoch 1 done: accuracy of 96.14, train_loss:0.00694112148287104 and test_loss: 0.008385878305119695
1000 0 83.85878305119695 416.4672889722624 tensor(96.1400) 0
epoch 2 done: accuracy of 96.95, train_loss:0.006874228710403652 and test_loss: 0.006922774167919124
332.60850592106544 1 69.22774167919124 412.45372262421915 tensor(96.9500) tensor(96.1400)
epoch 3 done: accuracy of 96.54, train_loss:0.007104952316032177 and test_loss: 0.006684733371458424
332.60850592106544 2 66.84733371458424 426.2971389619306 tensor(96.5400) tensor(96.1400)
