In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
import torchvision
import numpy as np
import cv2

In [0]:
#### Load the MNIST dataset #####
train_data = datasets.MNIST('~/data/mnist/', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST('~/data/mnist/', train=False, transform=transforms.ToTensor(), download=True)
data_loader = torch.utils.data.DataLoader(train_data, batch_size=128, 
            shuffle=True, num_workers=4, drop_last=True)

In [0]:
class Model(nn.Module):
  def __init__(self):
    super(Model, self).__init__()
    self.conv1 = nn.Conv2d(1, 16, 3, stride=1, padding=1)
    self.conv2 = nn.Conv2d(16, 8, 3, stride=1, padding=1)
    self.conv3 = nn.Conv2d(8, 8, 3, stride=1, padding=1)
    self.pool1 = nn.MaxPool2d(2, padding=0)
    self.pool2 = nn.MaxPool2d(2, padding=1)
    self.fc = nn.Linear(128, 10, bias=False)
  def forward(self, image):
    conv1 = self.conv1(image)
    relu1 = F.relu(conv1) #28x28x16
    pool1 = self.pool1(relu1) #14x14x16
    conv2 = self.conv2(pool1) #14x14x8
    relu2 = F.relu(conv2)
    pool2 = self.pool1(relu2) #7x7x8
    conv3 = self.conv3(pool2) #7x7x8
    relu3 = F.relu(conv3)
    pool3 = self.pool2(relu3) #4x4x8
    pool3 = pool3.view([image.size(0), 8*4*4]).cuda()
    fc = self.fc(pool3)
    logits = F.sigmoid(fc)
    return logits

In [0]:
model = Model()
model.cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)


In [29]:
epochs = 15
for epoch in range(epochs): 
  print('Epoch {}\{}'.format(epoch,epochs-1))
  tot_loss = 0.0
  correct = 0
  scheduler.step()
  for data in data_loader:
    inputs, labels = data
    inputs, labels = Variable(inputs).cuda(), Variable(labels).cuda()
    optimizer.zero_grad()
    logits = model(inputs)
    _, preds = torch.max(logits.data,1)
    loss = criterion(logits, labels)
    loss.backward()
    optimizer.step()
    tot_loss += loss.data[0]
    correct += torch.sum(preds == labels.data)
  print('Epoch loss: ', tot_loss/len(train_data))
  print('Epoch acc: ', correct/len(train_data))
    
    

Epoch 0\14
Epoch loss:  0.011562374889850617
Epoch acc:  0.9753
Epoch 1\14
Epoch loss:  0.011558146832386653
Epoch acc:  0.9759166666666667
Epoch 2\14
Epoch loss:  0.011556019711494446
Epoch acc:  0.9757833333333333
Epoch 3\14
Epoch loss:  0.011553491445382437
Epoch acc:  0.9761333333333333
Epoch 4\14
Epoch loss:  0.011549216602245966
Epoch acc:  0.9767
Epoch 5\14
Epoch loss:  0.011547996417681377
Epoch acc:  0.97675
Epoch 6\14
Epoch loss:  0.01154741065899531
Epoch acc:  0.97685
Epoch 7\14
Epoch loss:  0.011547359971205394
Epoch acc:  0.9767333333333333
Epoch 8\14
Epoch loss:  0.011547131723165512
Epoch acc:  0.9768166666666667
Epoch 9\14
Epoch loss:  0.011546660949786505
Epoch acc:  0.9768333333333333
Epoch 10\14
Epoch loss:  0.011546688608328502
Epoch acc:  0.9769
Epoch 11\14
Epoch loss:  0.011545861105124155
Epoch acc:  0.9770666666666666
Epoch 12\14
Epoch loss:  0.011545821313063304
Epoch acc:  0.9770333333333333
Epoch 13\14
Epoch loss:  0.011545942836999893
Epoch acc:  0.97696666