In [4]:
import torch
import torchvision
import torchvision.transforms as transforms

In [5]:
#load data
#this might take a while as it will download the dataset from internet
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
data_train = torchvision.datasets.MNIST('./', download=True, train=True, transform = transform)
data_test = torchvision.datasets.MNIST('./', download=True, train=False, transform = transform)

In [13]:
# create the data loaders
num_minibatches = 16
trainloader = torch.utils.data.DataLoader(data_train, batch_size=num_minibatches, shuffle=True, num_workers=2)
testloader = torch.utils.data.DataLoader(data_test, batch_size=num_minibatches, shuffle=False, num_workers=2)

In [14]:
# defining the network
import torch.nn as nn
import torch.nn.functional as F

# input_size = 1 x 28 x 28
input_size = data_train[0][0].shape.numel()
# print(input_size)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        
        # trim it from `input_size` down to 10, over 3 layers
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
    
    def forward(self, x):
        # get 3 channel image and transform it to `input_size` channel data
        x = x.view(-1, input_size)
        # use the ReLU activation function for the first 2 FC layers
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        # use softmax for the last layer to get normalized 0-1 numbers
        # set dim=1 to go alongside all channels in the final layer
        x = F.softmax(self.fc3(x), dim=1)
        return x

In [15]:
# define a network to operate on, a loss function, and an optimizer
import torch.optim as optim

net = Net()
criterion = nn.CrossEntropyLoss()
# use SGD to get to the right network weights
optimizer = optim.SGD(net.parameters(), lr = 0.001, momentum=0.9)

In [16]:
# train the network
num_epochs = 10

# train `num_epochs` times
for epoch in range(num_epochs):
    running_loss = 0.0
    print(f"-Epoch {epoch} started-")
    
    # go over all minibatches each time
    for minibatch_num, data in enumerate(trainloader, 0):
        
        # get current minibatch data
        inputs, labels = data
        
        # zero out previous net weights: start from fresh state
        optimizer.zero_grad()
        
        # do the forward pass and collect the loss from forward pass
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        
        # do the backward pass
        loss.backward()
        
        # go on to next iteration in optimizer
        optimizer.step()
        
        
        # print stats just for data logging
        running_loss += loss.item()
        # at every 600 minibatches, print running loss
        # (printing every minibatch (3600) would have too much logging)
        checkpt_check_num = 600
        if minibatch_num % checkpt_check_num == checkpt_check_num - 1:
            print(f'Minibatch {minibatch_num}\t loss: {running_loss/checkpt_check_num}')
            # reset running loss to keep it averaged over the last
            # `checkpt_check_num` batches
            running_loss = 0.0
    if (epoch == 0):
        print(minibatch_num)

print("\n-Training completed-")

            
            
# save the model at the end
path = './mnist_net.pth'
torch.save(net.state_dict(), path)
print("Data saved")

-Epoch 0 started-
Minibatch 599	 loss: 2.295587578614553
Minibatch 1199	 loss: 2.251006892522176
Minibatch 1799	 loss: 2.0962278123696647
Minibatch 2399	 loss: 1.9719869295756023
Minibatch 2999	 loss: 1.825321408311526
Minibatch 3599	 loss: 1.7341150905688603
3749
-Epoch 1 started-
Minibatch 599	 loss: 1.688332492907842
Minibatch 1199	 loss: 1.6687026009956996
Minibatch 1799	 loss: 1.6645892107486724
Minibatch 2399	 loss: 1.656424816250801
Minibatch 2999	 loss: 1.6523048456509908
Minibatch 3599	 loss: 1.6465963554382324
-Epoch 2 started-
Minibatch 599	 loss: 1.6374105807145436
Minibatch 1199	 loss: 1.6397119824091593
Minibatch 1799	 loss: 1.638065475622813
Minibatch 2399	 loss: 1.6351583735148112
Minibatch 2999	 loss: 1.6293977467219034
Minibatch 3599	 loss: 1.6347827968994777
-Epoch 3 started-
Minibatch 599	 loss: 1.629635539650917
Minibatch 1199	 loss: 1.634259360631307
Minibatch 1799	 loss: 1.6304224667946499
Minibatch 2399	 loss: 1.6232998019456863
Minibatch 2999	 loss: 1.607082519

In [17]:
# test network on test data

correct = 0
total = 0
with torch.no_grad():
    for i, data in enumerate(testloader, 0):
        # get current minibatch data
        images, labels = data
        
        # get the network predictions
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        
        # add to logging
        num_datapts_in_batch = labels.size(0)
        total += num_datapts_in_batch
        num_correct = (predicted == labels).sum().item()
        print(f"Minibatch {i}\t{num_correct} of {num_datapts_in_batch} correct")
        correct += num_correct

print('Accuracy of the network on the test set: %d %%' % (100 * correct / total))

Minibatch 0	15 of 16 correct
Minibatch 1	16 of 16 correct
Minibatch 2	15 of 16 correct
Minibatch 3	15 of 16 correct
Minibatch 4	15 of 16 correct
Minibatch 5	14 of 16 correct
Minibatch 6	16 of 16 correct
Minibatch 7	14 of 16 correct
Minibatch 8	16 of 16 correct
Minibatch 9	14 of 16 correct
Minibatch 10	16 of 16 correct
Minibatch 11	16 of 16 correct
Minibatch 12	14 of 16 correct
Minibatch 13	16 of 16 correct
Minibatch 14	15 of 16 correct
Minibatch 15	13 of 16 correct
Minibatch 16	15 of 16 correct
Minibatch 17	16 of 16 correct
Minibatch 18	14 of 16 correct
Minibatch 19	15 of 16 correct
Minibatch 20	13 of 16 correct
Minibatch 21	14 of 16 correct
Minibatch 22	13 of 16 correct
Minibatch 23	16 of 16 correct
Minibatch 24	16 of 16 correct
Minibatch 25	16 of 16 correct
Minibatch 26	16 of 16 correct
Minibatch 27	14 of 16 correct
Minibatch 28	14 of 16 correct
Minibatch 29	14 of 16 correct
Minibatch 30	15 of 16 correct
Minibatch 31	13 of 16 correct
Minibatch 32	16 of 16 correct
Minibatch 33	12 of 1

Minibatch 279	15 of 16 correct
Minibatch 280	16 of 16 correct
Minibatch 281	14 of 16 correct
Minibatch 282	15 of 16 correct
Minibatch 283	15 of 16 correct
Minibatch 284	16 of 16 correct
Minibatch 285	14 of 16 correct
Minibatch 286	14 of 16 correct
Minibatch 287	15 of 16 correct
Minibatch 288	15 of 16 correct
Minibatch 289	14 of 16 correct
Minibatch 290	15 of 16 correct
Minibatch 291	16 of 16 correct
Minibatch 292	16 of 16 correct
Minibatch 293	15 of 16 correct
Minibatch 294	16 of 16 correct
Minibatch 295	14 of 16 correct
Minibatch 296	14 of 16 correct
Minibatch 297	14 of 16 correct
Minibatch 298	16 of 16 correct
Minibatch 299	15 of 16 correct
Minibatch 300	13 of 16 correct
Minibatch 301	15 of 16 correct
Minibatch 302	15 of 16 correct
Minibatch 303	16 of 16 correct
Minibatch 304	13 of 16 correct
Minibatch 305	13 of 16 correct
Minibatch 306	15 of 16 correct
Minibatch 307	14 of 16 correct
Minibatch 308	16 of 16 correct
Minibatch 309	14 of 16 correct
Minibatch 310	14 of 16 correct
Minibatc

Accuracy of the network on the test set: 93 %
