In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
#load data
#this might take a while as it will download the dataset from internet
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
data_train = torchvision.datasets.MNIST('./', download=True, train=True, transform = transform)
data_test = torchvision.datasets.MNIST('./', download=True, train=False, transform = transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100.1%

Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw


113.5%

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz
Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100.4%

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


180.4%

Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw
Processing...
Done!


In [129]:
# create the data loaders
num_minibatches = 32
trainloader = torch.utils.data.DataLoader(data_train, batch_size=num_minibatches, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(data_test, batch_size=num_minibatches, shuffle=False, num_workers=2)

In [130]:
# defining the network
import torch.nn as nn
import torch.nn.functional as F

# input_size = 1 x 28 x 28
input_size = data_train[0][0].shape.numel()
# print(input_size)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        # trim it from 784 down to 10, over 3 layers
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
    
    def forward(self, x):
        # get 3 channel image and transform it to 784 channel data
        x = x.view(-1, input_size)
        # use the ReLU activation function for the first 2 FC layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # use softmax for the last layer to get normalized 0-1 numbers
        # set dim=1 to go alongside all channels in the final layer
        x = F.softmax(self.fc3(x), dim=1)
        return x

In [131]:
# define a network to operate on, a loss function, and an optimizer
import torch.optim as optim

net = Net()
criterion = nn.CrossEntropyLoss()
# use SGD to get to the right network weights
optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum=0.9)

In [132]:
# train the network
num_epochs = 10

# train `num_epochs` times
for epoch in range(num_epochs):
    running_loss = 0.0
    print(f"-Epoch {epoch} started-")
    
    # go over all minibatches each time
    for i, data in enumerate(trainloader, 0):
        
        # get current minibatch data
        inputs, labels = data
        
        # zero out previous net weights: start from fresh state
        optimizer.zero_grad()
        
        # do the forward pass and collect the loss from forward pass
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        
        # do the backward pass
        loss.backward()
        
        # go on to next iteration in optimizer
        optimizer.step()
        
        
        # print stats just for data logging
        running_loss += loss.item()
        # at every 300 minibatches, print running loss
        # (printing every minibatch would have too much logging)
        checkpt_check_num = 300
        if i % checkpt_check_num == checkpt_check_num - 1:
            print(f'Minibatch {i}\t loss: {running_loss/checkpt_check_num}')
            # reset running loss to keep it averaged over the last
            # `checkpt_check_num` batches
            running_loss = 0.0

print("\n-Training completed-")

            
            
# save the model at the end
path = './mnist_net.pth'
torch.save(net.state_dict(), path)
print("Data saved")

-Epoch 0 started-
Minibatch 299	 loss: 2.3005217051506044
Minibatch 599	 loss: 2.2950608078638712
Minibatch 899	 loss: 2.284980450471242
Minibatch 1199	 loss: 2.2552730282147726
Minibatch 1499	 loss: 2.202246105670929
Minibatch 1799	 loss: 2.1274591251214345
-Epoch 1 started-
Minibatch 299	 loss: 2.008053803841273
Minibatch 599	 loss: 1.9446010212103526
Minibatch 899	 loss: 1.9119017255306243
Minibatch 1199	 loss: 1.9027746431032817
Minibatch 1499	 loss: 1.8998741865158082
Minibatch 1799	 loss: 1.8795991428693135
-Epoch 2 started-
Minibatch 299	 loss: 1.8464713827768962
Minibatch 599	 loss: 1.8367745089530945
Minibatch 899	 loss: 1.8358353984355926
Minibatch 1199	 loss: 1.8199246442317962
Minibatch 1499	 loss: 1.8166569391886394
Minibatch 1799	 loss: 1.8119802391529083
-Epoch 3 started-
Minibatch 299	 loss: 1.7969027614593507
Minibatch 599	 loss: 1.7539490914344789
Minibatch 899	 loss: 1.7141488369305928
Minibatch 1199	 loss: 1.6900486210982004
Minibatch 1499	 loss: 1.675754518508911
M

In [133]:
# test network on test data

correct = 0
total = 0
with torch.no_grad():
    for i, data in enumerate(testloader, 0):
        # get current minibatch data
        images, labels = data
        
        # get the network predictions
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        
        # add to logging
        num_datapts_in_batch = labels.size(0)
        total += num_datapts_in_batch
        num_correct = (predicted == labels).sum().item()
        print(f"Minibatch {i}\t{num_correct} of {num_datapts_in_batch} correct")
        correct += num_correct

print('Accuracy of the network on the test set: %d %%' % (100 * correct / total))

Minibatch 0	29 of 32 correct
Minibatch 1	25 of 32 correct
Minibatch 2	31 of 32 correct
Minibatch 3	27 of 32 correct
Minibatch 4	25 of 32 correct
Minibatch 5	26 of 32 correct
Minibatch 6	27 of 32 correct
Minibatch 7	28 of 32 correct
Minibatch 8	29 of 32 correct
Minibatch 9	27 of 32 correct
Minibatch 10	25 of 32 correct
Minibatch 11	25 of 32 correct
Minibatch 12	27 of 32 correct
Minibatch 13	28 of 32 correct
Minibatch 14	26 of 32 correct
Minibatch 15	25 of 32 correct
Minibatch 16	24 of 32 correct
Minibatch 17	27 of 32 correct
Minibatch 18	27 of 32 correct
Minibatch 19	24 of 32 correct
Minibatch 20	29 of 32 correct
Minibatch 21	27 of 32 correct
Minibatch 22	26 of 32 correct
Minibatch 23	27 of 32 correct
Minibatch 24	27 of 32 correct
Minibatch 25	30 of 32 correct
Minibatch 26	28 of 32 correct
Minibatch 27	29 of 32 correct
Minibatch 28	28 of 32 correct
Minibatch 29	25 of 32 correct
Minibatch 30	27 of 32 correct
Minibatch 31	30 of 32 correct
Minibatch 32	27 of 32 correct
Minibatch 33	26 of 3