In [None]:
import torch 
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
from torch.autograd import Variable
import numpy as np
from matplotlib import pyplot as plt

# Seed torch's RNGs so results (e.g. weight initialization, data shuffling)
# are reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Fall back to the CPU when no CUDA device is available instead of crashing.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Only ToTensor is applied; no normalization step is configured here.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# Both splits share the same root/transform; download=True lets torchvision
# fetch the files into ./data when they are missing.
mnist_kwargs = dict(root='./data', download=True, transform=transform)
train_data = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_data = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

print(train_data)
print(test_data)

# Peek at the training sample at index 1, reshaped to a 28x28 image.
sample_image = train_data[1][0]
plt.imshow(np.asarray(sample_image.reshape(28, 28)))


In [None]:
BATCH_SIZE = 32

# Reshuffle the training set every epoch so mini-batches are not always drawn
# in the same fixed order. Evaluation order does not matter, so the test
# loader keeps the default sequential sampling.
train_loader = torch.utils.data.DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
# One sample per batch during evaluation.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=1)

print(train_loader)

In [None]:
class MyConvNet(nn.Module):
  """Small CNN classifier for 1x28x28 images with 10 output classes.

  Two conv-conv-maxpool stages (28 -> 14 -> 7 spatial) are followed by two
  fully connected layers. ``forward`` returns unnormalized logits, which is
  what ``nn.CrossEntropyLoss`` expects (it applies log-softmax internally).
  Use ``predict_proba`` when class probabilities are needed.
  """

  def __init__(self):
    super().__init__()

    # 3x3 convolutions with stride 1 and padding 1 keep the spatial size.
    self.conv1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
    self.conv3 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
    self.conv4 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
    # Two 2x2 max-pools reduce 28x28 to 7x7, with 128 channels after conv4.
    self.fc1 = nn.Linear(7*7*128, 256)
    self.fc2 = nn.Linear(256, 10)
    self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)

  def forward(self, x):
    """Map a (batch, 1, 28, 28) tensor to (batch, 10) class logits."""
    x = F.relu(self.conv1(x))
    x = F.relu(self.conv2(x))
    x = self.maxpool(x)
    x = F.relu(self.conv3(x))
    x = F.relu(self.conv4(x))
    x = self.maxpool(x)
    # flatten(1) keeps the batch dimension intact; an input of the wrong
    # spatial size then fails loudly in fc1 instead of being re-batched.
    x = torch.flatten(x, 1)
    x = F.relu(self.fc1(x))
    # No softmax here: CrossEntropyLoss expects raw logits, and applying
    # softmax first squashes the gradients (double softmax).
    return self.fc2(x)

  def predict_proba(self, x):
    """Return class probabilities: softmax of the logits along dim 1."""
    return F.softmax(self(x), dim=1)


In [None]:
net = MyConvNet().to(device)
print(net)

# Smoke test: push one training batch through the untrained network to check
# that the layer shapes line up. Call the module itself (not .forward) so any
# hooks run, compute the output once, and skip autograd since no backward
# pass follows.
sample_batch = next(iter(train_loader))[0].to(device)
with torch.no_grad():
  sample_out = net(sample_batch)
print('Size of the output tensor:', sample_out.size())
print(sample_out)

In [None]:
def train(model, train_loader, EPOCHS = 6, lossF = None, lr = 4e-4,
          weight_decay = 1e-3, log_interval = 200):
  """Train ``model`` in place with Adam, printing running loss and accuracy.

  Args:
    model: ``nn.Module`` mapping a batch of inputs to per-class scores of
      shape (batch, classes). Batches are moved to the device of the model's
      first parameter.
    train_loader: iterable of ``(inputs, targets)`` batches; its ``len()`` and
      ``.dataset`` are used for the progress message.
    EPOCHS: number of passes over ``train_loader``.
    lossF: loss callable ``lossF(output, targets)``; defaults to
      ``nn.CrossEntropyLoss()``, which expects unnormalized logits.
    lr, weight_decay: Adam hyper-parameters.
    log_interval: positive int; a progress line is printed every
      ``log_interval`` batches.

  Returns:
    None. ``model`` is left in training mode.
  """
  if lossF is None:
    lossF = nn.CrossEntropyLoss()

  # A fresh optimizer per call: Adam state does not carry over between calls.
  optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
  # Follow the model's device rather than a global setting.
  device = next(model.parameters()).device

  model.train()

  for epoch in range(EPOCHS):
    correct = 0
    seen = 0  # samples processed so far this epoch (last batch may be short)
    for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
      X_batch = X_batch.to(device)
      y_batch = y_batch.to(device)

      # Clear gradients *before* backward so stale gradients left on the
      # parameters by earlier code never leak into the first step.
      optim.zero_grad()
      output = model(X_batch)
      loss = lossF(output, y_batch)
      loss.backward()
      optim.step()

      predicted = output.detach().argmax(dim=1)
      # Kept as a tensor to avoid a host/device sync on every batch.
      correct += (predicted == y_batch).sum()
      seen += y_batch.size(0)

      if batch_idx % log_interval == 0:
        print('Epoch : {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t Accuracy:{:.3f}%'.format(
            epoch, batch_idx * len(X_batch), len(train_loader.dataset),
            100. * batch_idx / len(train_loader), loss.item(),
            float(correct) * 100 / seen))



def test(model, test_loader):
    
    correct = 0 
    
    for X,y in test_loader:
      X = Variable(X).to(device) 
      y = Variable(y).to(device) 
      output = model(X)
      predictions = torch.max(output.data,axis =1).indices
      correct += (predictions == y).sum()
  
    print("Test accuracy:{:.3f}% ".format( float(correct * 100) / (len(test_loader))))


In [None]:
# Train the network for five passes over the training set.
train(model=net, train_loader=train_loader, EPOCHS=5)

In [None]:
# Report accuracy on the held-out test split.
test(model=net, test_loader=test_loader)

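Visualization of features learned by the model: for each digit class, an input image is optimized, starting from all zeros, until the network assigns it to that class.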

In [None]:
def train_im(model, digit = 7, iters = 1000, lossF = None, shape = (1, 1, 28, 28),
             lr = 4e-4, weight_decay = 1e-3):
  """Optimize an input image so that ``model`` classifies it as ``digit``.

  Starting from an all-zero tensor of ``shape``, Adam updates only the image
  to minimize ``lossF(model(image), target)``. The model's parameters are
  neither updated nor given any accumulated ``.grad``.

  Args:
    model: ``nn.Module`` mapping a ``shape``-shaped batch to class scores of
      shape (batch, classes).
    digit: target class index, used for every image in the batch.
    iters: number of optimization steps.
    lossF: loss callable ``lossF(output, target)``; defaults to
      ``nn.CrossEntropyLoss()`` (expects unnormalized logits).
    shape: shape of the image tensor to optimize; dim 0 is the batch size.
      The default is a single 1x28x28 image.
    lr, weight_decay: Adam hyper-parameters for the image.

  Returns:
    The optimized image: a leaf tensor of ``shape`` on the model's device,
    with ``requires_grad=True``.
  """
  if lossF is None:
    lossF = nn.CrossEntropyLoss()

  device = next(model.parameters()).device
  im = torch.zeros(shape, device=device, requires_grad=True)
  target = torch.full((im.shape[0],), digit, dtype=torch.long, device=device)
  optimizer = torch.optim.Adam([im], lr=lr, weight_decay=weight_decay)

  was_training = model.training
  model.eval()
  try:
    for _ in range(iters):
      loss = lossF(model(im), target)
      # Differentiate w.r.t. the image only. A plain loss.backward() would
      # also accumulate gradients into the model's parameters on every
      # iteration, polluting any later optimizer step on the model.
      im.grad = torch.autograd.grad(loss, im)[0]
      optimizer.step()
  finally:
    model.train(was_training)

  return im

In [None]:
# One titled panel per class: the input each digit's optimization produced.
fig, axes = plt.subplots(2, 5, figsize=(12, 5))
for digit, ax in enumerate(axes.flat):
  digit_im = train_im(net, digit=digit, iters=1000)
  ax.imshow(digit_im.detach().cpu().view(28, 28).numpy())
  ax.set_title('digit {}'.format(digit))
  ax.axis('off')
fig.suptitle('Inputs optimized to be classified as each digit')
plt.show()