In [1]:
# Imports
import torch
import torch.nn as nn
import torch.optim as optim  #optimization algorithms
import torch.nn.functional as F  #relu tanh ect..
import torch.utils.data as tdata  #DataLoader
# pip install torchvision
import torchvision.datasets as datasets  #mnist
import torchvision.transforms as transforms



In [2]:
# Set device 
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Hyperparameters
input_size = 784
num_classes = 10
learning_rate = 0.001
batch_size = 64
num_epochs = 10

In [4]:
# Load Data
train_dataset = datasets.MNIST(root='/tmp/dataset/', train=True, transform=transforms.ToTensor(), download=True)  #downloads as numpy, so convert it
train_loader = tdata.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataset = datasets.MNIST(root='/tmp/dataset/', train=False, transform=transforms.ToTensor(), download=True)  #downloads as numpy, so convert it
test_loader = tdata.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [5]:
# Create Fully Connected Neural Network
class NN(nn.Module):
    def __init__(self, input_size, num_classes):
        super(NN, self).__init__()
        self.fc1 = nn.Linear(input_size, 50)
        self.fc2 = nn.Linear(50,50)
        self.fc3 = nn.Linear(50,50)
        self.fc4 = nn.Linear(50, num_classes)
    def forward(self, x):
        x = F.tanh(self.fc1(x))  #or relu
        x = F.tanh(self.fc2(x))
        x = F.tanh(self.fc3(x))
        x = self.fc4(x)
        return x

In [6]:
def sanity_test():   
    """
    Just check to see if the neural network functions.
    If this works, output will be: 
    torch.Size([64, 10])
    """ 
    DATA_SIZE = 784                 # 784 is just random data, 28x28pxl
    model = NN(DATA_SIZE, 10)       # 10 is the number of digits, 
    x = torch.randn(64, DATA_SIZE)  # 
    print(model(x).shape)           


sanity_test()

torch.Size([64, 10])




In [7]:
# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval() #change to evaluation mode, lock the network
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=device)
            y = y.to(device=device)
            x = x.reshape(x.shape[0], -1)

            scores = model(x)  #64x10
            #print(scores)  # 64 images x 10 scores
            _, predictions = scores.max(1)  # the prediction for a digit (0-9) is the max score over all the digits
            #print(predictions)  #we have 64 images so this is the prediction of the digit for each image as a list
            num_correct += (predictions == y).sum()  #number correct 
            num_samples += predictions.size(0)

            print(f'{num_correct} / {num_samples}, accuracy: {100.0*float(num_correct)/float(num_samples)}' )
    model.train()

In [8]:
# Initialize network
model = NN(input_size=input_size, num_classes=num_classes).to(device)

In [9]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Train Network
for epoch in range(num_epochs): #1 epoch - the network has seen all the images in the dataset
    for batch_idx, (data, targets) in enumerate(train_loader):  #data=images target=correct label for each image
        #put to cuda if we can
        data = data.to(device=device)
        targets = targets.to(device=device)

        ##print(data.shape) # torch.Size([64, 1, 28, 28])  64=#images, 1=channels (black/white), 28x28 pixels
        #flaten the data into a single dimension
        data = data.reshape(data.shape[0], -1)
        ##print(data.shape)

        #forward
        scores = model(data)
        loss = criterion(scores, targets)

        #backward
        optimizer.zero_grad()  #don't store from previous forward props
        loss.backward()

        # gradient decent or adam step or whatever...
        optimizer.step()  #update the weights found in backward

    print("**********************************************")
    check_accuracy(test_loader, model)
    print("**********************************************")

**********************************************
61 / 64, accuracy: 95.3125
124 / 128, accuracy: 96.875
182 / 192, accuracy: 94.79166666666667
243 / 256, accuracy: 94.921875
302 / 320, accuracy: 94.375
364 / 384, accuracy: 94.79166666666667
424 / 448, accuracy: 94.64285714285714
486 / 512, accuracy: 94.921875
547 / 576, accuracy: 94.96527777777777
609 / 640, accuracy: 95.15625
667 / 704, accuracy: 94.74431818181819
726 / 768, accuracy: 94.53125
786 / 832, accuracy: 94.47115384615384
848 / 896, accuracy: 94.64285714285714
908 / 960, accuracy: 94.58333333333333
968 / 1024, accuracy: 94.53125
1029 / 1088, accuracy: 94.57720588235294
1092 / 1152, accuracy: 94.79166666666667
1154 / 1216, accuracy: 94.90131578947368
1215 / 1280, accuracy: 94.921875
1272 / 1344, accuracy: 94.64285714285714
1332 / 1408, accuracy: 94.60227272727273
1395 / 1472, accuracy: 94.76902173913044
1454 / 1536, accuracy: 94.66145833333333
1517 / 1600, accuracy: 94.8125
1579 / 1664, accuracy: 94.89182692307692
1641 / 1728, 

5151 / 5376, accuracy: 95.81473214285714
5212 / 5440, accuracy: 95.80882352941177
5275 / 5504, accuracy: 95.83938953488372
5334 / 5568, accuracy: 95.79741379310344
5395 / 5632, accuracy: 95.7919034090909
5456 / 5696, accuracy: 95.78651685393258
5517 / 5760, accuracy: 95.78125
5580 / 5824, accuracy: 95.81043956043956
5639 / 5888, accuracy: 95.7710597826087
5703 / 5952, accuracy: 95.81653225806451
5763 / 6016, accuracy: 95.79454787234043
5824 / 6080, accuracy: 95.78947368421052
5885 / 6144, accuracy: 95.78450520833333
5947 / 6208, accuracy: 95.7957474226804
6009 / 6272, accuracy: 95.80676020408163
6068 / 6336, accuracy: 95.77020202020202
6131 / 6400, accuracy: 95.796875
6192 / 6464, accuracy: 95.79207920792079
6256 / 6528, accuracy: 95.83333333333333
6316 / 6592, accuracy: 95.8131067961165
6377 / 6656, accuracy: 95.80829326923077
6438 / 6720, accuracy: 95.80357142857143
6498 / 6784, accuracy: 95.78419811320755
6560 / 6848, accuracy: 95.79439252336448
6622 / 6912, accuracy: 95.80439814814

**********************************************
63 / 64, accuracy: 98.4375
126 / 128, accuracy: 98.4375
187 / 192, accuracy: 97.39583333333333
248 / 256, accuracy: 96.875
309 / 320, accuracy: 96.5625
371 / 384, accuracy: 96.61458333333333
430 / 448, accuracy: 95.98214285714286
492 / 512, accuracy: 96.09375
551 / 576, accuracy: 95.65972222222223
613 / 640, accuracy: 95.78125
676 / 704, accuracy: 96.02272727272727
740 / 768, accuracy: 96.35416666666667
801 / 832, accuracy: 96.27403846153847
864 / 896, accuracy: 96.42857142857143
927 / 960, accuracy: 96.5625
987 / 1024, accuracy: 96.38671875
1048 / 1088, accuracy: 96.32352941176471
1110 / 1152, accuracy: 96.35416666666667
1173 / 1216, accuracy: 96.46381578947368
1234 / 1280, accuracy: 96.40625
1296 / 1344, accuracy: 96.42857142857143
1359 / 1408, accuracy: 96.51988636363636
1420 / 1472, accuracy: 96.46739130434783
1480 / 1536, accuracy: 96.35416666666667
1542 / 1600, accuracy: 96.375
1606 / 1664, accuracy: 96.51442307692308
1669 / 1728, ac

5200 / 5376, accuracy: 96.72619047619048
5264 / 5440, accuracy: 96.76470588235294
5326 / 5504, accuracy: 96.76598837209302
5390 / 5568, accuracy: 96.80316091954023
5453 / 5632, accuracy: 96.82173295454545
5516 / 5696, accuracy: 96.83988764044943
5578 / 5760, accuracy: 96.84027777777777
5642 / 5824, accuracy: 96.875
5702 / 5888, accuracy: 96.84103260869566
5764 / 5952, accuracy: 96.84139784946237
5828 / 6016, accuracy: 96.875
5891 / 6080, accuracy: 96.89144736842105
5951 / 6144, accuracy: 96.85872395833333
6012 / 6208, accuracy: 96.84278350515464
6076 / 6272, accuracy: 96.875
6134 / 6336, accuracy: 96.81186868686869
6197 / 6400, accuracy: 96.828125
6260 / 6464, accuracy: 96.8440594059406
6321 / 6528, accuracy: 96.82904411764706
6383 / 6592, accuracy: 96.82949029126213
6447 / 6656, accuracy: 96.85997596153847
6508 / 6720, accuracy: 96.8452380952381
6570 / 6784, accuracy: 96.84551886792453
6630 / 6848, accuracy: 96.81658878504673
6691 / 6912, accuracy: 96.80266203703704
6754 / 6976, accur

**********************************************
64 / 64, accuracy: 100.0
125 / 128, accuracy: 97.65625
189 / 192, accuracy: 98.4375
251 / 256, accuracy: 98.046875
313 / 320, accuracy: 97.8125
376 / 384, accuracy: 97.91666666666667
439 / 448, accuracy: 97.99107142857143
503 / 512, accuracy: 98.2421875
565 / 576, accuracy: 98.09027777777777
627 / 640, accuracy: 97.96875
689 / 704, accuracy: 97.86931818181819
752 / 768, accuracy: 97.91666666666667
815 / 832, accuracy: 97.95673076923077
879 / 896, accuracy: 98.10267857142857
943 / 960, accuracy: 98.22916666666667
1006 / 1024, accuracy: 98.2421875
1068 / 1088, accuracy: 98.16176470588235
1131 / 1152, accuracy: 98.17708333333333
1193 / 1216, accuracy: 98.10855263157895
1257 / 1280, accuracy: 98.203125
1316 / 1344, accuracy: 97.91666666666667
1376 / 1408, accuracy: 97.72727272727273
1437 / 1472, accuracy: 97.62228260869566
1499 / 1536, accuracy: 97.59114583333333
1562 / 1600, accuracy: 97.625
1623 / 1664, accuracy: 97.5360576923077
1685 / 1728

5155 / 5312, accuracy: 97.04442771084338
5216 / 5376, accuracy: 97.02380952380952
5280 / 5440, accuracy: 97.05882352941177
5344 / 5504, accuracy: 97.09302325581395
5406 / 5568, accuracy: 97.09051724137932
5468 / 5632, accuracy: 97.08806818181819
5528 / 5696, accuracy: 97.0505617977528
5589 / 5760, accuracy: 97.03125
5650 / 5824, accuracy: 97.01236263736264
5711 / 5888, accuracy: 96.99388586956522
5774 / 5952, accuracy: 97.00940860215054
5838 / 6016, accuracy: 97.04122340425532
5901 / 6080, accuracy: 97.05592105263158
5965 / 6144, accuracy: 97.08658854166667
6026 / 6208, accuracy: 97.06829896907216
6089 / 6272, accuracy: 97.08227040816327
6152 / 6336, accuracy: 97.0959595959596
6214 / 6400, accuracy: 97.09375
6276 / 6464, accuracy: 97.09158415841584
6337 / 6528, accuracy: 97.07414215686275
6400 / 6592, accuracy: 97.0873786407767
6461 / 6656, accuracy: 97.0703125
6523 / 6720, accuracy: 97.06845238095238
6586 / 6784, accuracy: 97.08136792452831
6648 / 6848, accuracy: 97.07943925233644
671

**********************************************
62 / 64, accuracy: 96.875
124 / 128, accuracy: 96.875
186 / 192, accuracy: 96.875
246 / 256, accuracy: 96.09375
310 / 320, accuracy: 96.875
373 / 384, accuracy: 97.13541666666667
436 / 448, accuracy: 97.32142857142857
499 / 512, accuracy: 97.4609375
562 / 576, accuracy: 97.56944444444444
624 / 640, accuracy: 97.5
687 / 704, accuracy: 97.58522727272727
751 / 768, accuracy: 97.78645833333333
813 / 832, accuracy: 97.71634615384616
876 / 896, accuracy: 97.76785714285714
938 / 960, accuracy: 97.70833333333333
1001 / 1024, accuracy: 97.75390625
1064 / 1088, accuracy: 97.79411764705883
1125 / 1152, accuracy: 97.65625
1186 / 1216, accuracy: 97.53289473684211
1248 / 1280, accuracy: 97.5
1310 / 1344, accuracy: 97.4702380952381
1374 / 1408, accuracy: 97.58522727272727
1435 / 1472, accuracy: 97.48641304347827
1498 / 1536, accuracy: 97.52604166666667
1560 / 1600, accuracy: 97.5
1623 / 1664, accuracy: 97.5360576923077
1684 / 1728, accuracy: 97.453703703