## MNIST Pytorch 

In [11]:
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

In [13]:
parser = argparse.ArgumentParser(description='Mnist Arg Parser')
parser.add_argument('--train_epoch',type=int,default=32)

_StoreAction(option_strings=['--train_epoch'], dest='train_epoch', nargs=None, const=None, default=32, type=<class 'int'>, choices=None, help=None, metavar=None)

In [None]:
epochs = 20

In [2]:
# pin_memory: speed up the data transfer with pinned RAM (Only helpful for GPU)

# num_worker: asyncronous use of data transfer.

In [14]:
train_data = datasets.MNIST(root='data',train=True,download=False, transform=transforms.Compose([
    transforms.ToTensor(),transforms.Normalize((0.1307,),(0.3081,))
]))
test_data  = datasets.MNIST(root='data',train=False,download=False, transform=transforms.Compose([transforms.ToTensor()]))

train_data_loader = DataLoader(train_data,shuffle=True,batch_size=64,num_workers=2, pin_memory=True)
test_data_loader  = DataLoader(test_data,shuffle=False,batch_size=64,num_workers=2, pin_memory=True)

In [None]:
## Calculation of flatern layer
# ((image height - kernel size+1 )/2   * (image width - kernel size +1) /2     )* number of filters
# 
# 
#  nllloss  -- log_softmax
#
#################
#One way to cut the computation graph is to use .detach(), which you may use when passing on a hidden 
#state when training RNNs with truncated backpropagation-through-time. It's also handy when differentiating
#a loss where one component is the output of another network, but this other network shouldn't be optimised 
#with respect to the loss - examples include training a discriminator from a generator's outputs in GAN training,
#or training the policy of an actor-critic algorithm using the value function as a baseline (e.g. A2C). 
#Another technique for preventing gradient calculations that is efficient in GAN training (training the generator 
#from the discriminator) and typical in fine-tuning is to loop through a networks parameters and set param.requires_grad = False.

In [4]:
# network 
class Net(nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.Conv1   =  nn.Conv2d(1,10,kernel_size=5)
        self.Conv2   =  nn.Conv2d(10,20,kernel_size=5)
        self.conv_drop = nn.Dropout2d(.2)
        self.fc1     =  nn.Linear(320,128)
        self.fc2     =  nn.Linear(128,10)
    def forward(self,x):
        out =  F.relu(F.max_pool2d(self.Conv1(x),2))
        out =  F.relu(F.max_pool2d(self.conv_drop(self.Conv2(out)),2))
        #print('out1 shape ',out.size())
        out =  out.view(-1,320)
        out =  self.fc1(out)
        out =  self.fc2(out)
        out =  F.log_softmax(out,dim=1)
        #print('outLen ',len(out))
        return out
        
    

In [5]:
model = Net()

optimizer = torch.optim.SGD(lr=0.001,params=model.parameters())


In [17]:

#Training
model.train()

for epoch in range(epochs):
    epoch_loss = []
    for i,(images,labels) in enumerate(train_data_loader):
        optimizer.zero_grad()
        output = model(images)
        #print(len(output))
        loss = F.nll_loss(output,labels)
        torch.max(output)
        epoch_loss.append(loss.item()/64)
        loss.backward()
        optimizer.step()
        
    print('Epoch Loss', sum(epoch_loss))
        

Epoch Loss 1.3742467608681181
Epoch Loss 1.395789196511032
Epoch Loss 1.3499742877320386
Epoch Loss 1.3723979643691564
Epoch Loss 1.3350034433969995
Epoch Loss 1.3301743593765423
Epoch Loss 1.3098389453880372
Epoch Loss 1.2878030736756045
Epoch Loss 1.2910278917042888
Epoch Loss 1.2581438823108329
Epoch Loss 1.2549563892462174
Epoch Loss 1.2428965955332387
Epoch Loss 1.219821699996828
Epoch Loss 1.2083377115268377
Epoch Loss 1.2043094441105495
Epoch Loss 1.1814545951965556
Epoch Loss 1.1877205644486821
Epoch Loss 1.1746998113012523
Epoch Loss 1.1477801086075488
Epoch Loss 1.15787238865596


In [18]:
# Evaluating 

model.eval()
with torch.no_grad():
    
    correct = 0
    model.eval()
    for epoch in range(1):

        for i,(images,labels) in enumerate(test_data_loader):
            output = model(images)
            pred = output.argmax(dim=1,keepdim=True)
            correct += pred.eq(labels.view_as(pred)).sum().item()
    print("Accuracy {:2f}".format(correct/len(test_data_loader.dataset)*100))        

Accuracy 97.940000
