In [2]:
import torch as tch
import torchvision.datasets as dt
import torchvision.transforms as trans
import torch.nn as nn
import matplotlib.pyplot as plt
from time import time

In [3]:
# Load both MNIST splits (downloading into ./datasets on first use). Each image
# is converted to a tensor by ToTensor; the splits are built train-first.
train, test = (
    dt.MNIST(root="./datasets", train=is_train, transform=trans.ToTensor(), download=True)
    for is_train in (True, False)
)

# Report the size of each split.
for caption, split in (("No. of Training examples: ", train), ("No. of Test examples: ", test)):
    print(caption, len(split))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./datasets\MNIST\raw\train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting ./datasets\MNIST\raw\train-images-idx3-ubyte.gz to ./datasets\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./datasets\MNIST\raw\train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting ./datasets\MNIST\raw\train-labels-idx1-ubyte.gz to ./datasets\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./datasets\MNIST\raw\t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting ./datasets\MNIST\raw\t10k-images-idx3-ubyte.gz to ./datasets\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./datasets\MNIST\raw\t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting ./datasets\MNIST\raw\t10k-labels-idx1-ubyte.gz to ./datasets\MNIST\raw

No. of Training examples:  60000
No. of Test examples:  10000


In [4]:
# Mini-batches of 30 training images; shuffle=True reshuffles the order on every pass.
train_batch = tch.utils.data.DataLoader(dataset=train, batch_size=30, shuffle=True)

In [5]:
# Network dimensions: 784 inputs = one flattened 28x28 MNIST image,
# 490 hidden units, 10 outputs = one per digit class.
# NOTE(review): `input` shadows the Python builtin `input()`; the name is kept
# because the model-definition cell refers to it.
input = 784
hidden = 490
output = 10

# Seed torch's global RNG so the weight initialisation in the next cell (and,
# presumably, the shuffle order of `train_batch`, which was created without its
# own generator) is reproducible across Restart & Run All.
SEED = 0
tch.manual_seed(SEED);

In [6]:
# Two-layer MLP classifier: a LeakyReLU hidden layer followed by a LogSoftmax
# head, so the model emits per-class log-probabilities for each input row.
_mlp_layers = (
    nn.Linear(input, hidden),    # flattened pixels -> hidden units
    nn.LeakyReLU(),
    nn.Linear(hidden, output),   # hidden units -> class scores
    nn.LogSoftmax(dim=1),        # normalise across the class dimension
)
model = nn.Sequential(*_mlp_layers)

In [7]:
# NLLLoss pairs with the model's LogSoftmax head: it expects log-probabilities.
lossfn = nn.NLLLoss()

# Smoke test: push one training batch through the untrained model and make
# sure forward + backward run before starting the full training loop.
images, labels = next(iter(train_batch))
images = images.view(images.shape[0], -1)  # flatten each image into one row

logps = model(images)
loss = lossfn(logps, labels)
loss.backward()

# Invariants: one row of log-probabilities per image, each row a valid
# distribution (LogSoftmax over dim=1), and a finite scalar loss.
assert logps.shape[0] == images.shape[0], "model must return one row per image"
assert tch.allclose(logps.exp().sum(dim=1), tch.ones(images.shape[0]), atol=1e-5), \
    "rows of exp(logps) must sum to 1"
assert tch.isfinite(loss), "smoke-test loss is not finite"

# Drop the smoke-test gradients so nothing stale is left on the parameters.
model.zero_grad()

In [8]:
# Train with SGD + momentum and report the mean per-batch loss after each epoch.
optimize = tch.optim.SGD(model.parameters(), lr=0.003, momentum=0.9)
time_start = time()
epochs = 18
model.train()  # no dropout/batch-norm here, but keeps the mode explicit
for num in range(epochs):
    running_loss = 0.0
    for images, labels in train_batch:
        images = images.view(images.shape[0], -1)  # flatten each image into one row
        optimize.zero_grad()
        # Named `logps`, not `output`: `output` is the layer-size constant from
        # the config cell, and overwriting it with a tensor would break any
        # re-run of the model-definition cell.
        logps = model(images)
        loss = lossfn(logps, labels)
        loss.backward()
        optimize.step()
        running_loss += loss.item()
    # Once per epoch, after every batch has been processed.
    print("Epoch Number : {} = Loss : {}".format(num, running_loss / len(train_batch)))
Elapsed = (time() - time_start) / 60
print("\nTraining Time (in minutes) : ", Elapsed)

Epoch Number : 0 = Loss : 0.5149400651156902
Epoch Number : 1 = Loss : 0.261456840605475
Epoch Number : 2 = Loss : 0.20588867816049605
Epoch Number : 3 = Loss : 0.16964873825758695
Epoch Number : 4 = Loss : 0.1434834775705822
Epoch Number : 5 = Loss : 0.12429279719106853
Epoch Number : 6 = Loss : 0.10908355080941692
Epoch Number : 7 = Loss : 0.09697999537643046
Epoch Number : 8 = Loss : 0.08723836344201118
Epoch Number : 9 = Loss : 0.07917423069826328
Epoch Number : 10 = Loss : 0.07214489371958188
Epoch Number : 11 = Loss : 0.06623679360805546
Epoch Number : 12 = Loss : 0.060786034525139254
Epoch Number : 13 = Loss : 0.05600704051565845
Epoch Number : 14 = Loss : 0.05210975646332372
Epoch Number : 15 = Loss : 0.04836869774857769
Epoch Number : 16 = Loss : 0.045035426611895676
Epoch Number : 17 = Loss : 0.04181636443955358

Training Time (in minutes) :  3.022224660714467


In [9]:
# Measure classification accuracy on the held-out test split.
# Images are scored in batches (order does not matter, so no shuffling). The
# predicted class is the argmax of the log-probabilities: exp() is monotonic,
# so this is the same class as the argmax of the probabilities.
test_batch = tch.utils.data.DataLoader(test, batch_size=1000)
model.eval()
correct = 0
total = 0  # not named `all`: a global `all` would shadow the builtin all()
with tch.no_grad():
    for images, labels in test_batch:
        logps = model(images.view(images.shape[0], -1))
        prediction = logps.argmax(dim=1)
        correct += (prediction == labels).sum().item()
        total += labels.shape[0]

print("Number Of Images Tested : ", total)
print("Model Accuracy : ", (correct / total))

Number Of Images Tested :  10000
Model Accuracy :  0.9777


In [10]:
# Persist the trained model object (pickled whole module) next to the notebook.
# NOTE(review): loading a pickled module needs compatible torch/class code; saving
# model.state_dict() instead would be more portable — confirm what consumers expect.
tch.save(obj=model, f='./mnist_model.pt')