In [3]:
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.transforms import Lambda
from torch.optim.lr_scheduler import StepLR
from torchsummary import summary
import numpy as np
import matplotlib.pyplot as plt
import os
import random

In [28]:
# Notebook-level settings.
BATCH_SIZE = 64
TEST_BATCH_SIZE = 1000
EPOCHS = 14
LR = 0.001
LOG_INTERVAL = 10
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed the Python, NumPy and PyTorch RNGs so random crops, shuffling, weight
# init and dropout repeat across Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [29]:
# class MNIST(datasets.MNIST):
    
#     def __init__(self, path, train=True, download=False, transform=None):
#         super().__init__(root=path, train=train, download=download, transform=transform)
        
    
#     def __getitem__(self, index):
#         x = self.data[index]
#         y = self.targets[index]
#         return x, y
    
#     def __len__(self):
#         return(len(self.data))


###############################
# !!! You don't need this !!! #
###############################

In [30]:
# def img_to_poly(data):
#     data = data.view(-1, 784)
#     data = list(data.unbind())
#     data = [torch.arange(784)[(x > 0.8)] for x in data]
#     data = [torch.Tensor(list(x) + (351 - len(x)) * [0]) for x in data]
#     data = torch.stack(data).contiguous()
#     return torch.reshape(data, (len(data) , 1, 351))

# transform=transforms.Compose([
#     transforms.Lambda(lambda x: img_to_poly(x))
# ]) 


###################################################
# !!! Let's write it as a class (encapsulate) !!! #
# It can also be modified, but not now.           #
# If you do not understand why I use __call__     #
# here, read about callable objects in Python.    #
###################################################
class TensorToPolygon:
    """Transform that encodes images as fixed-length lists of bright-pixel indices.

    Each image is flattened to ``num_pixels`` values. The flat indices of the
    pixels whose value is strictly greater than ``threshold`` are listed in
    ascending order and right-padded with zeros up to ``max_points`` entries.

    Note: the padding value 0 is also the flat index of the top-left pixel, so a
    genuine point at index 0 cannot be told apart from padding.

    Because it is a callable object (``__call__``), an instance can be placed
    directly inside ``transforms.Compose``.

    Args:
        num_pixels: values per image after flattening (784 = 28 * 28).
        max_points: output length per image; more bright pixels than this raises.
        threshold: a pixel becomes a point when its value is > threshold.
    """

    def __init__(self, num_pixels=784, max_points=351, threshold=0.8):
        self.num_pixels = num_pixels
        self.max_points = max_points
        self.threshold = threshold

    def __call__(self, data):
        """Encode ``data``, whose element count must be a multiple of num_pixels.

        Returns:
            float tensor of shape (num_images, 1, max_points).

        Raises:
            ValueError: if an image has more than ``max_points`` bright pixels.
        """
        if data.numel() == 0:
            # reshape(-1, ...) is ambiguous for empty tensors; return an empty batch.
            return torch.zeros(0, 1, self.max_points, device=data.device)

        rows = []
        for image in data.reshape(-1, self.num_pixels):
            points = torch.nonzero(image > self.threshold).flatten()
            count = points.numel()
            if count > self.max_points:
                raise ValueError(
                    "image has {} pixels above threshold {}, more than max_points={}".format(
                        count, self.threshold, self.max_points))
            # Pre-allocated zeros act as the padding after the real indices.
            row = torch.zeros(self.max_points, device=data.device)
            row[:count] = points.to(row.dtype)
            rows.append(row)
        return torch.stack(rows).unsqueeze(1)

In [31]:
# Polygon-encoded MNIST. Only the training split is augmented with random
# shifts, so we can first check whether the augmentation improves accuracy on
# the original (unshifted) test images.
# download=True lets Restart & Run All work on a machine without a local copy.
POLYGON_DATA_ROOT = '~/Developer/datasets'

dataset1 = datasets.MNIST(POLYGON_DATA_ROOT, train=True, download=True,
                          transform=transforms.Compose([
                              # pad by 2 and crop back to 28x28: a random shift of up to +/-2 pixels in each direction
                              transforms.RandomCrop(28, padding=2),
                              transforms.ToTensor(),
                              TensorToPolygon(),
                          ]))
dataset2 = datasets.MNIST(POLYGON_DATA_ROOT, train=False, download=True,
                          transform=transforms.Compose([
                              transforms.ToTensor(),
                              TensorToPolygon(),
                          ]))

In [32]:
# Shuffle the training split so every epoch sees the batches in a new order;
# the test split keeps a fixed order.
train_loader = torch.utils.data.DataLoader(
    dataset1,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    dataset2,
    batch_size=TEST_BATCH_SIZE,
    shuffle=False,
)

In [43]:
# Peek at one shuffled training batch. Element [0] is the batch of polygon
# tensors; element [1] would be the labels. TensorToPolygon returns (1, 1, 351)
# per image, so the stacked batch should be (BATCH_SIZE, 1, 1, 351).
next(iter(train_loader))[0]

tensor([[[[ 66.,  67.,  68.,  ...,   0.,   0.,   0.]]],


        [[[155., 156., 157.,  ...,   0.,   0.,   0.]]],


        [[[ 98.,  99., 102.,  ...,   0.,   0.,   0.]]],


        ...,


        [[[234., 235., 236.,  ...,   0.,   0.,   0.]]],


        [[[258., 259., 287.,  ...,   0.,   0.,   0.]]],


        [[[131., 132., 158.,  ...,   0.,   0.,   0.]]]])

In [75]:
# Settings read by train() and main() below. Values shared with the upper-case
# config cell are aliased to it, so there is one place to tune them.
batch_size = BATCH_SIZE
test_batch_size = TEST_BATCH_SIZE
epochs = EPOCHS
lr = LR
gamma = 0.7  # StepLR factor: the learning rate is multiplied by this each epoch
log_interval = LOG_INTERVAL
dry_run = False  # if True, train() stops after the first logged batch
save_model = False  # if True, main() writes the weights to mnist_cnn.pt


class Net(nn.Module):
    """Small two-convolution CNN classifier that returns raw class logits.

    Layers: conv3x3(1->32) -> ReLU -> conv3x3(32->64) -> ReLU -> maxpool2 ->
    dropout(0.25) -> flatten -> linear(->128) -> ReLU -> dropout(0.5) ->
    linear(128->num_classes).

    No softmax is applied; pair the output with ``nn.CrossEntropyLoss``.

    Args:
        input_size: (height, width) of the single-channel input images. The
            default (28, 28) gives 9216 flattened features.
        num_classes: number of output logits.

    Raises:
        ValueError: if ``input_size`` is too small to survive the conv/pool stack.
    """

    def __init__(self, input_size=(28, 28), num_classes=10):
        super().__init__()
        height, width = input_size
        # Each unpadded 3x3 conv trims 2 pixels per side length (4 in total);
        # the 2x2 max pool then halves what remains, rounding down.
        out_h, out_w = (height - 4) // 2, (width - 4) // 2
        if out_h <= 0 or out_w <= 0:
            raise ValueError(
                "input_size {} is too small for this network".format(input_size))
        flat_features = 64 * out_h * out_w

        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, 3, 1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Dropout(0.25)
        )

        self.linear = nn.Sequential(
            nn.Linear(flat_features, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        """Map a (batch, 1, height, width) tensor to (batch, num_classes) logits."""
        x = self.conv(x)
        x = torch.flatten(x, 1)
        return self.linear(x)

In [76]:
def train(model, device, train_loader, optimizer, epoch, log_every=None, quick_run=None):
    """Run one training epoch of ``model`` over ``train_loader``.

    Args:
        model: network whose output is fed to ``nn.CrossEntropyLoss`` (so it
            should produce raw class logits).
        device: device each batch is moved to (e.g. "cpu" or "cuda").
        train_loader: DataLoader yielding (data, target) batches.
        optimizer: optimizer over the model's parameters; stepped once per batch.
        epoch: epoch number, used only in the progress message.
        log_every: print progress every this many batches. When None, falls
            back to the notebook-level ``log_interval`` setting.
        quick_run: if True, stop after the first logged batch (a smoke test).
            When None, falls back to the notebook-level ``dry_run`` setting.
    """
    # The conditional expressions only read the notebook globals when no
    # explicit value was given, so callers can use this without them defined.
    interval = log_interval if log_every is None else log_every
    stop_early = dry_run if quick_run is None else quick_run

    model.train()
    criterion = nn.CrossEntropyLoss()
    n_total = len(train_loader.dataset)
    seen = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()
        # Count samples actually processed so the progress line is accurate.
        seen += len(data)
        if batch_idx % interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, seen, n_total,
                100. * seen / n_total, loss.item()), end="\r")
            if stop_early:
                break


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    criterion = nn.CrossEntropyLoss(reduction='sum')
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.1f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))



In [77]:
def main():
    """Train ``Net`` on normalized MNIST images, evaluating after every epoch.

    Reads the notebook-level settings ``batch_size``, ``test_batch_size``,
    ``epochs``, ``lr``, ``gamma``, ``save_model`` and ``DEVICE``, and uses
    ``Net``, ``train`` and ``test`` defined above.
    """
    device = DEVICE

    # Shuffle the training split every epoch; evaluation order does not matter.
    train_kwargs = {'batch_size': batch_size, 'shuffle': True}
    test_kwargs = {'batch_size': test_batch_size, 'shuffle': False}

    # (0.1307,) / (0.3081,) are presumably the usual MNIST mean / std.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('../data', train=True, download=True,
                              transform=transform)
    dataset2 = datasets.MNIST('../data', train=False,
                              transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, **train_kwargs)
    test_loader = torch.utils.data.DataLoader(dataset2, **test_kwargs)

    model = Net().to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Multiply the learning rate by gamma after every epoch.
    scheduler = StepLR(optimizer, step_size=1, gamma=gamma)
    for epoch in range(1, epochs + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)
        scheduler.step()

    if save_model:
        torch.save(model.state_dict(), "mnist_cnn.pt")


main()

RuntimeError: mat1 and mat2 shapes cannot be multiplied (64x9216 and 9217x128)

In [73]:
# Plain MNIST images with the same normalization main() uses; inspected below.
# NOTE(review): this rebinds the module-level dataset1/dataset2, replacing the
# polygon datasets built earlier (train_loader/test_loader still hold those).
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))]
)

dataset1 = datasets.MNIST(
    root='../data', train=True, download=True, transform=transform
)
dataset2 = datasets.MNIST(root='../data', train=False, transform=transform)

In [74]:
# Shape of the test split's stored image tensor (`.data` is presumably the raw
# images before `transform` is applied — confirm); recorded output: (10000, 28, 28).
dataset2.data.shape

torch.Size([10000, 28, 28])