In [187]:
import numpy as np
import torch
import torch.nn as nn
import torch.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.autograd import Variable
import lovely_tensors
lovely_tensors.monkey_patch()

In [189]:
train_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data', 
                                                          download=True, 
                                                          train=True,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ])), 
                                           batch_size=16, 
                                           shuffle=True)

# download and transform test dataset
test_loader = torch.utils.data.DataLoader(datasets.MNIST('../mnist_data', 
                                                          download=True, 
                                                          train=False,
                                                          transform=transforms.Compose([
                                                              transforms.ToTensor(), # first, convert image to PyTorch tensor
                                                              transforms.Normalize((0.1307,), (0.3081,)) # normalize inputs
                                                          ])), 
                                           batch_size=16, 
                                           shuffle=True)

In [190]:
class MNISTCLassifierA(nn.Module):
    def __init__(self):
        super(MNISTCLassifierA, self).__init__()
        self.conv1 = nn.Conv2d(1, 64, kernel_size=5, stride=5)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=5, stride=5)
        self.dropout1 = nn.Dropout2d(p=0.25)
        self.batch_norm = nn.BatchNorm2d(64)
        self.fc1 = nn.Linear(64, 128)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        # print(x
        x = nn.ReLU()(self.conv1(x))
        # x = nn.ReLU(self.conv1(x))
        # print(x.shape)
        x = nn.ReLU()(self.batch_norm(self.conv2(x)))
        # print("SHAPE B4",x.shape)
        x = self.dropout1(x)
        x = x.view(-1, 64)
        x = nn.ReLU()(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return nn.Softmax()(x)


In [197]:
clf = MNISTCLassifierA()
opt = optim.SGD(clf.parameters(), lr=0.01, momentum=0.5)
ce_loss = torch.nn.CrossEntropyLoss()

loss_history = []
acc_history = []

def train(epoch):
    clf.train() # set model in training mode (need this because of dropout)
    
    # dataset API gives us pythonic batching 
    for batch_id, (data, label) in enumerate(train_loader):
        data = Variable(data)
        target = Variable(label)
        
        # forward pass, calculate loss and backprop!
        opt.zero_grad()
        preds = clf(data)
        loss = ce_loss(preds, target)
        loss.backward()
        loss_history.append(loss.item())
        opt.step()
        
        if batch_id % 100 == 0:
            print(loss.item())

In [198]:
def test(epoch):
    clf.eval() # set model in inference mode (need this because of dropout)
    test_loss = 0
    correct = 0
    
    with torch.no_grad():
        for data, target in test_loader:
            data = Variable(data) 
            target = Variable(target)

            output = clf(data)
            test_loss += ce_loss(output, target).item()
            pred = output.data.max(1)[1] # get the index of the max log-probability
            correct += pred.eq(target.data).cpu().sum()

    test_loss = test_loss
    test_loss /= len(test_loader) # loss function already averages over batch size
    accuracy = 100. * correct / len(test_loader.dataset)
    acc_history.append(accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        accuracy))

In [199]:
for epoch in range(0, 3):
    print("Epoch %d" % epoch)
    train(epoch)
    test(epoch)

Epoch 0
2.305891275405884
2.2945823669433594
2.2299280166625977
2.177058219909668
2.1105549335479736
2.16782546043396
2.10642671585083
1.9610320329666138
2.090048313140869
1.949540138244629
2.0403521060943604
1.9811627864837646
2.0823163986206055
1.872231364250183
1.875415563583374
1.9539867639541626
1.9562400579452515
1.8272626399993896
1.8443777561187744
1.7755625247955322
1.756260871887207
1.6205592155456543
1.8064687252044678
1.767888069152832
1.7833431959152222
1.9090477228164673
1.669666051864624
1.7814490795135498
1.654571294784546
1.8014686107635498
1.658176302909851
1.7200509309768677
1.821246862411499
1.745455265045166
1.8819537162780762
1.607461929321289
1.643349051475525
1.805793285369873

Test set: Average loss: 1.6022, Accuracy: 8889/10000 (89%)

Epoch 1
1.64676034450531
1.5832107067108154
1.5542699098587036
1.541042447090149
1.776611089706421
1.7186245918273926
1.6118402481079102
1.4926457405090332
1.624924898147583
1.5144617557525635
1.5301281213760376
1.591982245445251

KeyboardInterrupt: 