In [73]:
import torch
import torchvision
import math

In [74]:
# Load observations from the mnist dataset. The observations are divided into a training set and a test set
mnist_train = torchvision.datasets.MNIST('./data', train=True, download=True)
x_train = mnist_train.data.reshape(-1, 1, 28, 28).float()  # torch.functional.nn.conv2d argument must include channels (1)
y_train = torch.zeros((mnist_train.targets.shape[0], 10))  # Create output tensor
y_train[torch.arange(mnist_train.targets.shape[0]), mnist_train.targets] = 1  # Populate output

mnist_test = torchvision.datasets.MNIST('./data', train=False, download=True)
x_test = mnist_test.data.reshape(-1, 1, 28, 28).float()  # torch.functional.nn.conv2d argument must include channels (1)
y_test = torch.zeros((mnist_test.targets.shape[0], 10))  # Create output tensor
y_test[torch.arange(mnist_test.targets.shape[0]), mnist_test.targets] = 1  # Populate output

In [75]:
# Normalization of inputs
mean = x_train.mean()
std = x_train.std()
# Standard score = x - mean / std
x_train = (x_train - mean) / std
x_test = (x_test - mean) / std

# Divide training data into batches to speed up optimization
batches = 600
x_train_batches = torch.split(x_train, batches)
y_train_batches = torch.split(y_train, batches)

In [76]:
class ConvolutionalNeuralNetworkModel:
	def __init__(self):
		std = 1 / math.sqrt(1 * 5 * 5)
		# std = math.sqrt(6/(1 + 32)) # Xavier initializtion
		# 1 is the channel, 1 for greyscale and 3 for RGB 
		self.W1 = (torch.randn(32, 1, 5, 5) * std).requires_grad_()  # Shape: (out_channels, in_channels, image height, image width)
		self.b1 = (torch.randn(32) * std).requires_grad_()

		std = 1 / math.sqrt(32 * 5 * 5)
		# std = math.sqrt(6/(32 + 64)) # Xavier initializtion
		self.W2 = (torch.randn(64, 32, 5, 5) * std).requires_grad_()
		self.b2 = (torch.randn(64) * std).requires_grad_()

		std = 1 / math.sqrt(64 * 5 * 5)
		self.W3 = (torch.randn(16, 64, 4, 4) * std).requires_grad_()
		self.b3 = (torch.randn(16) * std).requires_grad_()

		std = 1 / math.sqrt(16 * 4 * 4)
		# std = math.sqrt(6/(64*7*7 + 10)) # Xavier initialization
		self.W4 = (torch.randn(16 * 4 * 4, 1024) * std).requires_grad_()  # Shape: (in_channels * image_pixels_after_max_pool, classes)
		self.b4 = (torch.randn(1024) * std).requires_grad_()

		std = 1 / math.sqrt(1024)
		self.W5 = (torch.randn(1024, 10) * std).requires_grad_()
		self.b5= (torch.randn(10) * std).requires_grad_()

	def parameters(self):
		return [self.W1, self.b1, self.W2, self.b2, self.W3, self.b3, self.W4, self.b4]

	def logits(self, x):
		x = torch.nn.functional.conv2d(x, self.W1, self.b1, padding=2)
		x = torch.nn.functional.relu(x)
		x = torch.nn.functional.max_pool2d(x, kernel_size=2)
		# x = torch.nn.functional.dropout(x, p=0.25)
		x = torch.nn.functional.conv2d(x, self.W2, self.b2, padding=2)
		x = torch.nn.functional.relu(x)
		x = torch.nn.functional.max_pool2d(x, kernel_size=2)
		# x = torch.nn.functional.dropout(x, p=0.10)
		x = torch.nn.functional.conv2d(x, self.W3, self.b3, padding=2)
		x = torch.nn.functional.relu(x)
		x = torch.nn.functional.max_pool2d(x, kernel_size=2)
		x = x.reshape(-1, 16 * 4 * 4) @ self.W4 + self.b4
		return x @ self.W5 + self.b5

	# Predictor
	def f(self, x):
		return torch.softmax(self.logits(x), dim=1)

	# Cross Entropy loss
	def loss(self, x, y):
		return torch.nn.functional.cross_entropy(self.logits(x), y.argmax(1))

	# Accuracy
	def accuracy(self, x, y):
		return torch.mean(torch.eq(self.f(x).argmax(1), y.argmax(1)).float())

In [77]:
model = ConvolutionalNeuralNetworkModel()

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [78]:
for epoch in range(5):
    for batch in range(len(x_train_batches)):
        model.loss(x_train_batches[batch], y_train_batches[batch]).backward()  # Compute loss gradients
        optimizer.step()  # Perform optimization by adjusting W and b,
        optimizer.zero_grad()  # Clear gradients for next step

    print("accuracy = %s" % model.accuracy(x_test, y_test).item())

# Accuracy: 0.9867
# Accuracy with only ReLU: 0.9890
# Accuracy with ReLU and dropout: 0.9854
# Accuracy with only dropout: 0.9842

accuracy = 0.9711999893188477
accuracy = 0.9828000068664551
accuracy = 0.9866999983787537
accuracy = 0.9890000224113464
accuracy = 0.9869999885559082
