In [1]:
import torch
import torchvision

In [2]:
class Fully_Connected_Network(torch.nn.Module):
    """Two-layer fully connected classifier: Linear(784, K) -> Sigmoid -> Linear(K, 10).

    The network returns raw logits (no softmax), which is what
    ``torch.nn.CrossEntropyLoss`` expects.

    Args:
        K: number of hidden neurons.
    """

    def __init__(self, K):
        # call base class constructor
        super(Fully_Connected_Network, self).__init__()

        # init components
        # NOTE: attribute names (W1, w2) are kept unchanged because they are the
        # state_dict keys of any saved checkpoint.
        self.W1 = torch.nn.Linear(784, K, bias=True)
        self.activation = torch.nn.Sigmoid()
        self.w2 = torch.nn.Linear(K, 10)

    def forward(self, x):
        """Return class logits for ``x``.

        Accepts already-flattened input whose last dimension is 784, e.g.
        ``(784,)`` or ``(N, 784)``, exactly as before. Input whose last
        dimension is *not* 784, such as an image batch ``(N, 1, 28, 28)``, is
        flattened from dim 1 onward first. Such input would otherwise be
        rejected by ``self.W1``.
        """
        if x.shape[-1] != self.W1.in_features:
            # e.g. (N, 1, 28, 28) -> (N, 784); only reached for shapes the
            # Linear layer could not consume directly.
            x = torch.flatten(x, 1)
        return self.w2(self.activation(self.W1(x)))

In [3]:
class CNN_Network(torch.nn.Module):
    """Two conv/pool stages followed by a linear classifier over 10 classes.

    Each stage is Conv2d(5x5, stride 1, padding 2) -> MaxPool2d(2) -> Sigmoid.
    For a 28x28 single-channel input the spatial size goes 28 -> 14 -> 7, so
    the classifier sees 7*7*64 features. The output is raw logits.

    Args:
        K: accepted so the constructor can be called the same way as
            ``Fully_Connected_Network(K)``; this network does not use it.
    """

    def __init__(self, K):
        super().__init__()

        # "same" 5x5 convolutions: spatial size is preserved, only pooling shrinks it
        conv_options = dict(kernel_size=(5, 5), stride=1, padding=2)
        self.c1 = torch.nn.Conv2d(in_channels=1, out_channels=32, **conv_options)  # 28x28 input
        self.c2 = torch.nn.Conv2d(in_channels=self.c1.out_channels, out_channels=64, **conv_options)
        # halves H and W: 28x28 -> 14x14 after c1, 14x14 -> 7x7 after c2
        self.pool = torch.nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.activation = torch.nn.Sigmoid()
        self.fc1 = torch.nn.Linear(7 * 7 * self.c2.out_channels, 10, bias=True)

    def forward(self, x):
        features = x
        for conv in (self.c1, self.c2):
            features = self.activation(self.pool(conv(features)))
        return self.fc1(features.flatten(start_dim=1))

In [4]:
# obtain MNIST datasets
use_CNN = True # set to False if the fully connected network should be used

# ToTensor() turns each 28x28 grayscale image into a (1, 28, 28) float tensor.
# The CNN consumes that directly. The fully connected network needs flat
# 784-dim vectors, so the flatten step must run *after* ToTensor (it cannot
# operate on the PIL image).
if use_CNN:
    transform = torchvision.transforms.ToTensor()
else:
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Lambda(torch.flatten),  # (1, 28, 28) -> (784,)
    ])
train_set = torchvision.datasets.MNIST(root="data/MNIST", train=True, download=True, transform=transform)
test_set = torchvision.datasets.MNIST(root="data/MNIST", train=False, download=True, transform=transform)

# loaders: shuffle only the training data; test order does not affect accuracy
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, batch_size=100)

In [5]:
# define variables
K = 10  # hidden neurons of Fully_Connected_Network (ignored by CNN_Network)
epochs = 100
lr = 0.01
momentum = 0.9
SEED = 0  # fixes weight init and train_loader's shuffle order so reruns are reproducible

torch.manual_seed(SEED)


def train_one_epoch(network, loader, loss, optimizer):
    """Run one full pass of gradient steps over ``loader``, updating ``network`` in place."""
    network.train()
    for x, t in loader:
        optimizer.zero_grad()
        J = loss(network(x), t)
        J.backward()
        optimizer.step()  # apply gradient step


def evaluate_accuracy(network, loader):
    """Return the fraction of samples in ``loader`` that ``network`` classifies correctly.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    network.eval()
    correct_class_count = 0
    sample_count = 0
    with torch.no_grad():
        for x, t in loader:
            y = network(x).argmax(dim=1)  # index of the largest logit = predicted class
            correct_class_count += (y == t).sum().item()
            sample_count += t.numel()
    if sample_count == 0:
        raise ValueError("loader yielded no samples")
    return correct_class_count / sample_count


# init network
network = CNN_Network(K) if use_CNN else Fully_Connected_Network(K)  # choose network
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=network.parameters(), lr=lr, momentum=momentum)

# training epochs: train on the full training set, then report test accuracy
for epoch in range(epochs):
    train_one_epoch(network, train_loader, loss, optimizer)
    accuracy = evaluate_accuracy(network, test_loader)
    print(f"Epoch {epoch+1}: accuracy test: {accuracy:1.4f}")

Epoch 1: accuracy test: 0.85
Epoch 2: accuracy test: 0.91
Epoch 3: accuracy test: 0.94
Epoch 4: accuracy test: 0.96
Epoch 5: accuracy test: 0.96
Epoch 6: accuracy test: 0.97
Epoch 7: accuracy test: 0.97
Epoch 8: accuracy test: 0.98
Epoch 9: accuracy test: 0.98
Epoch 10: accuracy test: 0.98
Epoch 11: accuracy test: 0.98
Epoch 12: accuracy test: 0.98
Epoch 13: accuracy test: 0.98
Epoch 14: accuracy test: 0.98
Epoch 15: accuracy test: 0.98
Epoch 16: accuracy test: 0.98
Epoch 17: accuracy test: 0.98
Epoch 18: accuracy test: 0.99
Epoch 19: accuracy test: 0.98
Epoch 20: accuracy test: 0.99
Epoch 21: accuracy test: 0.99
Epoch 22: accuracy test: 0.99
Epoch 23: accuracy test: 0.99
Epoch 24: accuracy test: 0.99
Epoch 25: accuracy test: 0.99
Epoch 26: accuracy test: 0.99
Epoch 27: accuracy test: 0.99
Epoch 28: accuracy test: 0.99
Epoch 29: accuracy test: 0.99
Epoch 30: accuracy test: 0.99
Epoch 31: accuracy test: 0.99
Epoch 32: accuracy test: 0.99
Epoch 33: accuracy test: 0.99
Epoch 34: accuracy 