In [None]:
import torch
import random
import numpy as np
from PIL import Image
import torchvision
# Seed every random number generator used in this notebook with the same
# value so that a fresh "Restart & Run All" starts from the same RNG state.
SEED = 0
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    _seed_fn(SEED)

# Ask cuDNN to use deterministic algorithms only.
torch.backends.cudnn.deterministic = True

In [None]:
import torchvision.datasets

In [None]:
# Random augmentation pipeline, attached to the training split only.
# NOTE(review): torchvision applies `transform` when samples are fetched
# through the dataset (indexing / DataLoader). Code that reads the raw
# `.data` tensor directly bypasses it, so confirm the augmentation is
# actually reaching the model.
transforms = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(p=0.05),
    torchvision.transforms.RandomVerticalFlip(p=0.05),
    torchvision.transforms.RandomRotation([90, 180])
])

MNIST_train = torchvision.datasets.MNIST('./', download=True, train=True, transform=transforms)
# The held-out split must not be randomly flipped/rotated: evaluation should
# be repeatable and measured on unmodified images.
MNIST_test = torchvision.datasets.MNIST('./', download=True, train=False)

In [None]:
# Pull the raw image tensors and labels out of the datasets; pixel values are
# divided by 255 so the images become floating point.
X_train, y_train = MNIST_train.data / 255, MNIST_train.targets
X_test, y_test = MNIST_test.data / 255, MNIST_test.targets

In [None]:
def _rescale_to_unit_range(images):
    """Min-max scale a tensor: (x - min) / (max - min), using its own global min/max.

    There is no guard for a constant tensor (max == min divides by zero).
    """
    lowest = images.min()
    return (images - lowest) / (images.max() - lowest)


# Each split is rescaled independently with its own minimum and maximum.
X_train = _rescale_to_unit_range(X_train)
X_test = _rescale_to_unit_range(X_test)

In [None]:
# Sanity check: display the largest pixel value of the test images.
X_test.max()

In [None]:
import matplotlib.pyplot as plt
# Show the first training image together with its label as a quick visual check.
plt.imshow(X_train[0])
plt.show()
print(y_train[0])

In [None]:
# Insert a channel axis of size 1 in front of the two trailing (height, width)
# axes and make sure the tensors are float32.
# reshape is used instead of unsqueeze(1) so the cell is idempotent: running it
# a second time leaves the shape unchanged instead of stacking another
# singleton axis (which would feed 5-D input to Conv2d).
X_train = X_train.reshape(-1, 1, *X_train.shape[-2:]).float()
X_test = X_test.reshape(-1, 1, *X_test.shape[-2:]).float()

In [None]:
class MNISTNet(torch.nn.Module):
    """Small LeNet-style CNN producing 10 class scores per image.

    Layout: two stacked 3x3 convolutions -> Tanh -> BatchNorm -> 2x2 max-pool,
    repeated twice (6 then 16 channels), followed by three fully connected
    layers (120 -> 84 -> 10). The first pair of convolutions uses padding=1,
    the second pair uses no padding.

    ``fc1`` expects 16 * 5 * 5 flattened features, which is what a
    single-channel 28x28 input yields after the two conv/pool stages.
    """

    def __init__(self):
        super().__init__()

        # A single Tanh and a single MaxPool2d instance back every act*/pool*
        # attribute; neither module holds parameters.
        shared_tanh = torch.nn.Tanh()
        shared_pool = torch.nn.MaxPool2d(kernel_size=2, stride=2)

        # Stage 1: spatial size is kept by the padded convolutions.
        self.conv1_1 = torch.nn.Conv2d(in_channels=1, out_channels=6, kernel_size=3, padding=1)
        self.conv1_2 = torch.nn.Conv2d(in_channels=6, out_channels=6, kernel_size=3, padding=1)
        self.act1 = shared_tanh
        self.bn1 = torch.nn.BatchNorm2d(num_features=6)
        self.pool1 = shared_pool

        # Stage 2: unpadded convolutions, each trims 2 pixels per spatial axis.
        self.conv2_1 = torch.nn.Conv2d(in_channels=6, out_channels=16, kernel_size=3, padding=0)
        self.conv2_2 = torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, padding=0)
        self.act2 = shared_tanh
        self.bn2 = torch.nn.BatchNorm2d(num_features=16)
        self.pool2 = shared_pool

        # Classifier head.
        self.fc1 = torch.nn.Linear(5 * 5 * 16, 120)
        self.act3 = shared_tanh

        self.fc2 = torch.nn.Linear(120, 84)
        self.act4 = shared_tanh

        self.fc3 = torch.nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch of images."""
        # Stage 1: conv -> conv (no activation in between) -> tanh -> BN -> pool.
        features = self.conv1_1(x)
        features = self.conv1_2(features)
        features = self.pool1(self.bn1(self.act1(features)))

        # Stage 2: same pattern with the second set of layers.
        features = self.conv2_1(features)
        features = self.conv2_2(features)
        features = self.pool2(self.bn2(self.act2(features)))

        # Flatten everything except the batch axis.
        flat = features.view(features.size(0), -1)

        hidden = self.act3(self.fc1(flat))
        hidden = self.act4(self.fc2(hidden))
        return self.fc3(hidden)

In [None]:
def train(net, X_train, y_train, X_test, y_test,
          epochs=50, batch_size=500, lr=1.0e-4):
    """Train ``net`` with Adam and cross-entropy loss, evaluating after each epoch.

    The model is moved to ``cuda:0`` when CUDA is available, otherwise it stays
    on the CPU. Training mini-batches are moved to that device one at a time;
    the whole test set is moved once up front and evaluated in a single
    forward pass per epoch.

    Args:
        net: the ``torch.nn.Module`` to train (moved to the device).
        X_train, y_train: training inputs and integer class targets, indexable
            along the first axis.
        X_test, y_test: evaluation inputs and targets.
        epochs: number of passes over the training data (default 50).
        batch_size: mini-batch size (default 500).
        lr: Adam learning rate (default 1.0e-4).

    Returns:
        ``(test_accuracy_history, test_loss_history)``: two lists with one
        0-dimensional CPU tensor per epoch.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    net = net.to(device)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    test_accuracy_history = []
    test_loss_history = []

    X_test = X_test.to(device)
    y_test = y_test.to(device)

    for epoch in range(epochs):
        # Evaluation at the end of the previous epoch switched to eval mode.
        net.train()

        # Visit the training samples in a fresh random order every epoch.
        order = np.random.permutation(len(X_train))
        for start_index in range(0, len(X_train), batch_size):
            optimizer.zero_grad()

            batch_indexes = order[start_index:start_index + batch_size]

            X_batch = X_train[batch_indexes].to(device)
            y_batch = y_train[batch_indexes].to(device)

            preds = net(X_batch)

            loss_value = loss(preds, y_batch)
            loss_value.backward()

            optimizer.step()

        net.eval()
        # No gradients are needed for evaluation; skipping graph construction
        # avoids holding activations for the entire test set in memory.
        with torch.no_grad():
            test_preds = net(X_test)
            test_loss = loss(test_preds, y_test).cpu()
            accuracy = (test_preds.argmax(dim=1) == y_test).float().mean().cpu()

        test_loss_history.append(test_loss)
        test_accuracy_history.append(accuracy)
        # Report progress against the real number of epochs.
        print(f'{epoch + 1}/{epochs}: {accuracy}')

    return test_accuracy_history, test_loss_history


In [None]:
# Build a freshly initialised network and fit it; keep the per-epoch test
# metrics returned by train() for plotting.
net = MNISTNet()

accuracies, losses = train(
    net=net,
    X_train=X_train,
    y_train=y_train,
    X_test=X_test,
    y_test=y_test,
)

In [None]:
import matplotlib.pyplot as plt
# Test-set loss recorded at the end of every epoch.
plt.plot(losses)
plt.xlabel('epoch')
plt.ylabel('test loss')
plt.title('Test loss per epoch')
plt.show()

In [None]:
# Test-set accuracy recorded at the end of every epoch.
plt.plot(accuracies)
plt.xlabel('epoch')
plt.ylabel('test accuracy')
plt.title('Test accuracy per epoch')
plt.show()

In [None]:
# Persist only the learned parameters/buffers (the state dict), not the whole
# module object.
weights_path = './mnist_net_99.pt'
torch.save(net.state_dict(), weights_path)