In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np

In [2]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

batch_size = 8  # mini-batch size shared by every DataLoader below

# Fixed seed so the train/validation split, weight initialisation and shuffling
# are reproducible across "Restart & Run All".
SEED = 42
torch.manual_seed(SEED)

# MNIST images as tensors. download=True is safe to leave on: torchvision only
# downloads when the files are not already present under root.
transform = transforms.ToTensor()
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Hold out VAL_SIZE training images for validation; the seeded generator makes
# the split identical on every run.
VAL_SIZE = 10000
train_dataset, validation_set = torch.utils.data.random_split(
    train_dataset,
    [len(train_dataset) - VAL_SIZE, VAL_SIZE],
    generator=torch.Generator().manual_seed(SEED),
)

# DataLoaders: only the training set is shuffled.
trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
valloader = torch.utils.data.DataLoader(validation_set, batch_size=batch_size, shuffle=False, num_workers=2)
testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)


In [3]:
class CNN(nn.Module):
    """Three conv/ReLU/max-pool stages followed by a dropout MLP head with 10 outputs.

    Spatial trace for a 1x28x28 input (an MNIST digit):
    28 -conv k3-> 26 -pool-> 13 -conv k3-> 11 -pool-> 5 -conv k2-> 4 -pool-> 2,
    so the flattened feature vector has 1024 * 2 * 2 = 4096 entries, matching fc1.
    The head returns raw logits (no softmax).
    """

    def __init__(self):
        super().__init__()

        # Feature extractor. Layers are created in this exact order so that the
        # module repr, state_dict keys and seeded weight init stay the same.
        self.conv1 = nn.Conv2d(1, 256, kernel_size=3)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(256, 512, kernel_size=3)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(512, 1024, kernel_size=2)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.flatten = nn.Flatten()

        # Classifier head: two hidden fully-connected layers, each with dropout.
        self.fc1 = nn.Linear(4096, 1024)
        self.drop1 = nn.Dropout(0.3)
        self.fc2 = nn.Linear(1024, 1024)
        self.drop2 = nn.Dropout(0.3)

        self.out = nn.Linear(1024, 10)

    def forward(self, x):
        """Map a batch of images to class logits."""
        conv_stages = (
            (self.conv1, self.pool1),
            (self.conv2, self.pool2),
            (self.conv3, self.pool3),
        )
        for conv, pool in conv_stages:
            x = pool(F.relu(conv(x)))

        x = self.flatten(x)

        for fc, drop in ((self.fc1, self.drop1), (self.fc2, self.drop2)):
            x = drop(F.relu(fc(x)))

        return self.out(x)

In [4]:
# Build the network and move its parameters to the selected device; Module.to
# returns the module itself, and the bare trailing name displays its summary.
cnn = CNN().to(device)
cnn

CNN(
  (conv1): Conv2d(1, 256, kernel_size=(3, 3), stride=(1, 1))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1))
  (pool2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv3): Conv2d(512, 1024, kernel_size=(2, 2), stride=(1, 1))
  (pool3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=4096, out_features=1024, bias=True)
  (drop1): Dropout(p=0.3, inplace=False)
  (fc2): Linear(in_features=1024, out_features=1024, bias=True)
  (drop2): Dropout(p=0.3, inplace=False)
  (out): Linear(in_features=1024, out_features=10, bias=True)
)

In [5]:
def train_epoch(network, optimizer, criterion, dataloader=None, device=None, log_every=500):
    """Run one training epoch, printing running loss/accuracy every ``log_every`` batches.

    Args:
        network: model to train; it is put in training mode (dropout active).
        optimizer: optimizer over ``network``'s parameters; stepped once per batch.
        criterion: loss called as ``criterion(outputs, labels)``. Assumed to use
            mean reduction (the ``nn.CrossEntropyLoss`` default) for the returned
            sample-weighted epoch loss to be meaningful.
        dataloader: iterable of ``(inputs, labels, ...)`` batches. Defaults to the
            notebook-level ``trainloader``.
        device: where each batch is moved. Defaults to the device of the
            network's first parameter.
        log_every: positive number of batches between progress lines.

    Returns:
        ``(mean_loss, accuracy_percent)`` over the whole epoch, with the loss
        weighted by batch length. Both are NaN if the loader yields no batches.
    """
    if dataloader is None:
        dataloader = trainloader
    if device is None:
        device = next(network.parameters()).device

    network.train(True)

    # Statistics for the current logging window.
    window_loss = 0.0
    window_correct = 0
    window_seen = 0
    # Statistics for the whole epoch (returned to the caller).
    epoch_loss = 0.0
    epoch_correct = 0
    epoch_seen = 0

    for batch_index, data in enumerate(dataloader):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()

        outputs = network(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Use the real batch length: the final batch may be smaller than the
        # configured batch size, so dividing by a constant skews accuracy.
        n = labels.size(0)
        num_correct = torch.sum(labels == torch.argmax(outputs, dim=1)).item()
        batch_loss = loss.item()

        window_loss += batch_loss
        window_correct += num_correct
        window_seen += n
        epoch_loss += batch_loss * n
        epoch_correct += num_correct
        epoch_seen += n

        if batch_index % log_every == log_every - 1:
            avg_loss_across_batches = window_loss / log_every
            avg_acc_across_batches = (window_correct / window_seen) * 100
            print('Batch{0}, Loss: {1:.3f}, Accuracy: {2:.1f}%'.format(batch_index+1, avg_loss_across_batches, avg_acc_across_batches))
            window_loss = 0.0
            window_correct = 0
            window_seen = 0

    if epoch_seen == 0:
        return float('nan'), float('nan')
    return epoch_loss / epoch_seen, (epoch_correct / epoch_seen) * 100

def validate_epoch(network, criterion, dataloader=None, device=None):
    """Evaluate ``network`` without gradient tracking and print loss/accuracy.

    Args:
        network: model to evaluate; it is left in eval mode (dropout disabled).
        criterion: loss called as ``criterion(outputs, labels)``. Assumed to use
            mean reduction (the ``nn.CrossEntropyLoss`` default) so that
            re-weighting by batch length gives a per-sample mean.
        dataloader: iterable of ``(inputs, labels, ...)`` batches. Defaults to the
            notebook-level ``valloader``.
        device: where each batch is moved. Defaults to the device of the
            network's first parameter.

    Returns:
        ``(mean_loss, accuracy_percent)`` over all samples; both NaN if the
        loader yields no batches.
    """
    if dataloader is None:
        dataloader = valloader
    if device is None:
        device = next(network.parameters()).device

    network.train(False)
    total_loss = 0.0
    total_correct = 0
    total_seen = 0

    # Single no_grad scope for the whole pass: inference only, no autograd graph.
    with torch.no_grad():
        for data in dataloader:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = network(inputs)

            # Weight by the real batch length so a short final batch counts correctly.
            n = labels.size(0)
            total_loss += criterion(outputs, labels).item() * n
            total_correct += torch.sum(labels == torch.argmax(outputs, dim=1)).item()
            total_seen += n

    if total_seen == 0:
        avg_loss = avg_accuracy = float('nan')
    else:
        avg_loss = total_loss / total_seen
        avg_accuracy = (total_correct / total_seen) * 100

    print('Val Loss: {0:.3f}, Val Accuracy: {1:.1f}%'.format(avg_loss, avg_accuracy))
    print('*****************************************')
    print()
    return avg_loss, avg_accuracy
            

In [10]:
# Cross-entropy on the raw logits returned by the CNN, optimised with Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(cnn.parameters(), lr=0.0001)

num_epochs = 2

# Alternate one training pass with one validation pass per epoch.
for epoch in range(num_epochs):
    train_epoch(cnn, optimizer, criterion)
    validate_epoch(cnn, criterion)

Batch500, Loss: 0.032, Accuracy: 99.2%
Batch1000, Loss: 0.033, Accuracy: 98.9%
Batch1500, Loss: 0.024, Accuracy: 99.3%
Batch2000, Loss: 0.039, Accuracy: 98.9%
Batch2500, Loss: 0.030, Accuracy: 99.1%
Batch3000, Loss: 0.034, Accuracy: 99.1%
Batch3500, Loss: 0.026, Accuracy: 99.3%
Batch4000, Loss: 0.029, Accuracy: 99.1%
Batch4500, Loss: 0.028, Accuracy: 99.2%
Batch5000, Loss: 0.026, Accuracy: 99.1%
Batch5500, Loss: 0.035, Accuracy: 99.1%
Batch6000, Loss: 0.026, Accuracy: 99.1%
Val Loss: 0.039, Val Accuracy: 98.9%
*****************************************

Batch500, Loss: 0.017, Accuracy: 99.5%
Batch1000, Loss: 0.015, Accuracy: 99.5%
Batch1500, Loss: 0.034, Accuracy: 99.2%
Batch2000, Loss: 0.019, Accuracy: 99.5%
Batch2500, Loss: 0.031, Accuracy: 99.0%
Batch3000, Loss: 0.020, Accuracy: 99.4%
Batch3500, Loss: 0.026, Accuracy: 99.2%
Batch4000, Loss: 0.031, Accuracy: 99.1%
Batch4500, Loss: 0.021, Accuracy: 99.4%
Batch5000, Loss: 0.025, Accuracy: 99.2%
Batch5500, Loss: 0.026, Accuracy: 99.2%
Ba

In [13]:
# Final held-out evaluation on the test split.
# Switch to eval mode explicitly (dropout off) instead of relying on
# validate_epoch having been the last thing to run on the model.
cnn.eval()
total_correct = 0
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for data in testloader:
        inputs, labels = data[0].to(device), data[1].to(device)
        outputs = cnn(inputs)
        total_correct += torch.sum(labels == torch.argmax(outputs, dim=1)).item()

test_accuracy = (total_correct / len(test_dataset)) * 100
print(f"Accuracy {test_accuracy:.2f}%")

Accuracy 99.22999999999999
