In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
#load data
#this might take a while as it will download the dataset from internet
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
data_train = torchvision.datasets.MNIST('./', download=True, train=True, transform = transform)
data_test = torchvision.datasets.MNIST('./', download=True, train=False, transform = transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw


113.5%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!


In [6]:
# create the data loaders
trainloader = torch.utils.data.DataLoader(data_train, batch_size=4, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(data_test, batch_size=4, shuffle=False, num_workers=2)

In [7]:
# defining classes: for MNIST, numbers 0-9
classes = tuple(range(10))

In [63]:
# defining the network
import torch.nn as nn
import torch.nn.functional as F

# input_size = 1 x 28 x 28
input_size = data_train[0][0].shape.numel()
# print(input_size)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        # get data from input channels and upscale it to 784
        # now, trim it from 784 down to 10, progressively
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
    
    def forward(self, x):
        # get 3 channel image and transform it to 784 channel data
        x = x.view(-1, input_size)
        # use the ReLU activation function for the first 2 FC layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # use softmax for the last layer to get normalized 0-1 numbers
        x = F.softmax(self.fc3(x), dim=0)
        return x

In [64]:
# define a network to operate on, a loss function, and an optimizer
import torch.optim as optim

net = Net()
criterion = nn.CrossEntropyLoss()
# use SGD to get to the right network weights
optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum=0.9)

In [67]:
# train the network
num_epochs = 10

# train `num_epochs` times
for epoch in range(num_epochs):
    running_loss = 0.0
    print(f"-Epoch {epoch} started-")
    
    # go over all minibatches each time
    for i, data in enumerate(trainloader, 0):
        
        # get current minibatch data
        inputs, labels = data
        
        # zero out previous net weights: start from fresh state
        optimizer.zero_grad()
        
        # do the forward pass and collect the loss from forward pass
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        
        # do the backward pass
        loss.backward()
        
        # go on to next iteration in optimizer
        optimizer.step()
        
        
        # print stats just for data logging
        running_loss += loss.item()
        # at every 2000 data pts in the current minibatch, print the loss
        if i % 2000 == 1999:
            print(f'Minibatch {i // 2000}\t loss: {running_loss/2000}')
            running_loss = 0.0
print("\n-Training completed-")

            
            
# save the model at the end
path = './mnist_net.pth'
torch.save(net.state_dict(), path)
print("Data saved")

-Epoch 0 started-
Minibatch 0	 loss: 1.7689617218971252
Minibatch 1	 loss: 1.760445522069931
Minibatch 2	 loss: 1.758530782699585
Minibatch 3	 loss: 1.7608513691425323
Minibatch 4	 loss: 1.7555492604970933
Minibatch 5	 loss: 1.7630649082660674
Minibatch 6	 loss: 1.7554638268351554
-Epoch 1 started-
Minibatch 0	 loss: 1.7493897597789765
Minibatch 1	 loss: 1.7563895573616028
Minibatch 2	 loss: 1.758571307182312
Minibatch 3	 loss: 1.7499641982316971
Minibatch 4	 loss: 1.7602580374479293
Minibatch 5	 loss: 1.7508675047159195
Minibatch 6	 loss: 1.7589246668815612
-Epoch 2 started-
Minibatch 0	 loss: 1.7519245581626892
Minibatch 1	 loss: 1.749681302845478
Minibatch 2	 loss: 1.7527376952767373
Minibatch 3	 loss: 1.7476373903751374
Minibatch 4	 loss: 1.7583470113277435
Minibatch 5	 loss: 1.7586359758377075
Minibatch 6	 loss: 1.7494117301702499
-Epoch 3 started-
Minibatch 0	 loss: 1.7532571882605552
Minibatch 1	 loss: 1.7480087239146234
Minibatch 2	 loss: 1.7539122959971427
Minibatch 3	 loss: 1

In [68]:
# test network on test data

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        # get current minibatch data
        images, labels = data
        
        # get the network predictions
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        
        # add to logging
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the test set: %d %%' % (100 * correct / total))

Accuracy of the network on the test set: 87 %
