In [1]:
import torch
import scipy
from scipy import ndimage
import matplotlib.pyplot as plt
%matplotlib inline

import torch,torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import numpy as np

from torchvision.datasets import MNIST
import torchvision.transforms as transforms

In [2]:
class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) dimension into one.

    An input of shape ``(N, d1, d2, ...)`` is returned as ``(N, d1*d2*...)``.
    ``Tensor.view`` is used, so no copy is made and the input must be
    viewable in that shape (``view`` raises otherwise).
    """

    def forward(self, x):
        batch_size = x.size(0)
        return x.view(batch_size, -1)
    
class Empty(nn.Module):
    """Identity module: ``forward`` hands back exactly the object it receives.

    Can stand in wherever an ``nn.Module`` is required but no transformation
    should be applied.
    """

    def forward(self, x):
        # No copy and no computation: the very same object is returned.
        return x
    
def get_convnet(nonlinearity, dropout=0, device=torch.device("cpu")):
    """Build a LeNet-style classifier for single-channel 32x32 images.

    Layout::

        conv(1->6, 5x5) -> act -> maxpool(2)
        conv(6->16, 5x5) -> act -> maxpool(2)
        flatten -> linear(400->120) -> act -> linear(120->84) -> act
        dropout -> linear(84->10)

    The 400 inputs of the first linear layer equal 16 * 5 * 5, which is what
    a 1 x 32 x 32 input produces after the two conv/pool stages.

    Args:
        nonlinearity: activation module. The *same instance* is placed at
            all four activation positions, so it should be stateless
            (e.g. ``nn.ReLU()``).
        dropout: drop probability of the single dropout layer in front of
            the output layer.
        device: device the model is moved to before being returned.

    Returns:
        ``nn.Sequential`` emitting 10 unnormalised scores per input.
    """
    # Modules are instantiated in the same order as the layer order, so
    # weight initialisation draws from the RNG in a fixed sequence.
    feature_stages = [
        nn.Conv2d(1, 6, 5), nonlinearity, nn.MaxPool2d(2),
        nn.Conv2d(6, 16, 5), nonlinearity, nn.MaxPool2d(2),
    ]
    classifier_head = [
        Flatten(),
        nn.Linear(400, 120), nonlinearity,
        nn.Linear(120, 84), nonlinearity,
        nn.Dropout(dropout),
        nn.Linear(84, 10),
    ]
    model = nn.Sequential(*feature_stages, *classifier_head)
    return model.to(device)

def train_convnet(network, X, Y, num_iters=1e6, threshold=1e-3, device=torch.device("cpu")):
    """Train ``network`` on ``(X, Y)`` with full-batch SGD.

    Each step feeds all of ``X`` through the network. It then applies
    log-softmax over the last dimension, computes the NLL loss against
    ``Y``, and takes one SGD step (lr=0.01, momentum=0.9,
    weight_decay=1e-4).

    Training ends after ``num_iters`` steps. It ends earlier as soon as the
    loss *rises* by more than ``threshold`` between consecutive steps. That
    rising loss is recorded in ``history``, but no update is applied for it.

    Args:
        network: ``nn.Module`` mapping ``X`` to per-class scores. It is NOT
            moved to ``device`` here; the caller must place it there.
        X: inputs (anything ``torch.as_tensor`` accepts), cast to float32.
        Y: integer class labels, cast to int64.
        num_iters: maximum number of steps. Floats such as the default
            ``1e6`` are accepted and truncated to an int.
        threshold: tolerated step-to-step loss increase before stopping.
        device: device the data tensors are placed on.

    Returns:
        ``(network, history)``: the trained network (left in eval mode) and
        the per-step loss values as Python floats.
    """
    network.train()
    X = torch.as_tensor(X, dtype=torch.float32, device=device)
    Y = torch.as_tensor(Y, dtype=torch.long, device=device)
    opt = torch.optim.SGD(network.parameters(), lr=0.01, weight_decay=1e-4, momentum=0.9)
    history = []
    # range() rejects floats, so cast: the default 1e6 is a float.
    for step in range(int(num_iters)):
        Y_logp = F.log_softmax(network(X), dim=-1)
        loss = F.nll_loss(Y_logp, Y)
        history.append(loss.item())
        # Divergence guard: stop once the loss goes up by more than `threshold`.
        if step > 0 and history[-1] - history[-2] > threshold:
            break
        opt.zero_grad()
        loss.backward()
        opt.step()
    network.eval()
    return network, history



def get_minst_data():
    """Load the entire MNIST training split as one pair of tensors.

    Preprocessing:
    - pad 2 pixels per side (28x28 -> 32x32);
    - convert to a tensor;
    - normalise with mean 0.5 and std 1.0, so that zero-valued
      (background and padding) pixels become -0.5.

    The dataset is downloaded into the relative directory ``MNIST`` on first
    use. Samples are shuffled without a fixed seed, so the order differs
    between calls.

    Returns:
        ``(imags, labels)``: all training images stacked into one batch, and
        their labels, in the same (shuffled) order.
    """
    trans = transforms.Compose([
        transforms.Pad(2),
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (1.0,)),
    ])
    mnist = MNIST('MNIST', download=True, transform=trans)

    # A single batch spanning the whole dataset; derived from the dataset
    # size rather than hard-coded.
    data_loader = torch.utils.data.DataLoader(mnist,
                                              batch_size=len(mnist),
                                              shuffle=True)
    imags, labels = next(iter(data_loader))
    return imags, labels

In [None]:
# Load the pretrained classifier from 'convnet.p' onto the CPU, then fetch the
# padded, normalised, shuffled MNIST training set.
# The loaded object is called and .train()-ed later, so it is presumably a
# whole pickled nn.Module rather than a state_dict -- confirm.
# NOTE(review): torch.load unpickles arbitrary objects; only load trusted files.
# On PyTorch >= 2.6 torch.load defaults to weights_only=True, which rejects
# full-module pickles; pass weights_only=False there (trusted file only).
convnet = torch.load('convnet.p', map_location='cpu')
imags, labels = get_minst_data()

In [None]:
# Monte-Carlo dropout experiment: rotate one image of a "1" through a range of
# angles and, for each angle, run many stochastic forward passes to see how
# the predicted class distribution spreads.
plt.figure(figsize=(20, 10))
# Deliberately train mode: keeps dropout active so every forward pass samples
# a different dropout mask.
convnet.train()
# Sorting the boolean mask moves indices of label-1 samples to the end; [-3]
# takes one of them.
IMAGE = imags[int((labels==1).sort()[1][-3])] # choose image of 1

size_experiment = 100              # stochastic forward passes per angle
angles = np.linspace(0, 110, 12)   # rotation angles in degrees

# data_probas[digit, angle_index, pass_index] = softmax probability of `digit`.
data_probas = np.zeros((10, len(angles), size_experiment))

n_cols = 6
n_rows = int(np.ceil(len(angles) / n_cols))
base_image = IMAGE.detach().numpy()[0]   # drop the channel dimension

with torch.no_grad():  # inference only; no autograd graph is needed
    for i_angle, angle in enumerate(angles):
        # cval=-0.5 fills the corners uncovered by the rotation with the value
        # that get_minst_data's normalisation gives to background pixels.
        image = ndimage.rotate(base_image, angle, reshape=False, cval=-0.5)
        plt.subplot(n_rows, n_cols, i_angle + 1)
        plt.imshow(image)
        plt.title('{:.0f} deg'.format(angle))
        batch = torch.tensor(image)[None, None]   # add batch and channel dims

        for num_experiment in range(size_experiment):
            # Row 0 of the softmax output holds the probability of each class,
            # indexed by class label.
            probs = F.softmax(convnet(batch), dim=-1)[0].numpy()
            data_probas[:, i_angle, num_experiment] = probs

In [None]:
# Spread of the MC-dropout softmax probabilities at each rotation angle for
# digits 1, 7 and 5 (one horizontal tick per stochastic forward pass).
# Number of passes per angle, taken from the data rather than hard-coded.
n_samples = data_probas.shape[2]
plt.figure(figsize=(15,7))
for i, angle in enumerate(angles):
    # Same draw order as before (1, then 7, then 5) so overlap/z-order is unchanged.
    for digit, color in ((1, 'green'), (7, 'orange'), (5, 'blue')):
        plt.scatter([angle] * n_samples, data_probas[digit, i], alpha=0.5, c=color, marker='_')

plt.xlabel('rotation angle (degrees)')
plt.ylabel('softmax probability')
plt.title('green = 1, blue = 5, orange = 7');