In [None]:
%pylab inline

In [None]:
import re
import torch

import collections

import numpy as np
import scipy.io
import scipy.ndimage

import PIL

import logging
logging.getLogger("PIL").setLevel(logging.INFO)

import common.plotting
import torch 
import torch.nn as nn
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
import tensorflow

import os
import ot
import itertools
import datetime
import sys
from random import randint
from pyriemann.utils.distance import *

# Reuse the data directory as the model-zoo cache, but only when it is configured.
# The data-loading cell falls back to '../data' when PYTORCH_DATA_PATH is unset;
# an unconditional lookup here would raise KeyError before that fallback could run.
if 'PYTORCH_DATA_PATH' in os.environ:
    os.environ['TORCH_MODEL_ZOO'] = os.environ['PYTORCH_DATA_PATH']

In [None]:
# We strongly recommend training using CUDA on lab computers.
# Use the GPU whenever one is available; on CPU-only machines fall back to the CPU
# instead of failing on the first `.cuda()` call.
CUDA = torch.cuda.is_available()

def to_np(x):
    """Return the contents of a tensor (or autograd Variable) as a host NumPy array.

    A Variable is first unwrapped to its underlying ``.data`` tensor; the tensor is
    then copied to the CPU (a no-op if it already lives there) and converted.
    """
    tensor = x.data if isinstance(x, Variable) else x
    return tensor.cpu().numpy()

def to_variable(x, **kwargs):
    """Wrap a NumPy array in an autograd Variable.

    The array is turned into a tensor with ``torch.from_numpy`` and moved to the
    GPU when the module-level ``CUDA`` flag is set. Extra keyword arguments are
    forwarded unchanged to the ``Variable`` constructor.
    """
    tensor = torch.from_numpy(x)
    tensor = tensor.cuda() if CUDA else tensor
    return Variable(tensor, **kwargs)

def log(text):
    """Print ``text`` prefixed by the current local timestamp, then flush stdout.

    Flushing right away makes progress lines show up immediately in the notebook,
    even during long-running cells.
    """
    timestamp = datetime.datetime.now()
    print('%s | %s' % (timestamp, text))
    sys.stdout.flush()

class Reshape(nn.Module):
    """Layer that reshapes every sample to a fixed shape while keeping dim 0 (the batch).

    Usable inside ``nn.Sequential``; ``Reshape(-1)`` flattens each sample.
    """

    def __init__(self, *args):
        super(Reshape, self).__init__()
        # Target per-sample shape; a -1 entry is inferred by `view`.
        self.shape = args

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, *self.shape)

In [None]:
import ot

def ground_matrix(n):
    """Return the squared-Euclidean cost matrix between all cells of an n x n grid.

    Cells are enumerated row-major, so row/column ``k`` of the result refers to
    grid coordinate ``(k // n, k % n)``. The ``(n * n, n * n)`` matrix is produced
    by ``ot.dist(..., 'sqeuclidean')`` and serves as the ground cost for optimal
    transport between flattened n x n images.
    """
    coords = np.array([[row, col] for row in range(n) for col in range(n)])
    return ot.dist(coords, coords, 'sqeuclidean')

def sqeuclidean_wasserstein_distance(x, y):
    """Optimal-transport cost between paired images under a squared-Euclidean ground cost.

    Args:
        x, y: batches of images (tensor or Variable) indexed along dim 0; ``x[k]`` is
            compared with ``y[k]``. Only ``x.shape[1]`` sizes the pixel grid passed to
            ``ground_matrix``, so images are assumed square (H == W).
            NOTE(review): confirm callers never pass non-square images.
            ``ot.emd2`` expects each pair of histograms to carry equal total mass.

    Returns:
        A float ``torch.Tensor`` of shape ``(x.shape[0],)`` holding one ``ot.emd2``
        cost per pair.
    """
    x_hist = to_np(x)
    y_hist = to_np(y)
    cost = ground_matrix(x_hist.shape[1])
    # Flatten each image into a 1-D histogram whose bins line up with the cost matrix.
    x_hist = x_hist.reshape(x_hist.shape[0], -1)
    y_hist = y_hist.reshape(y_hist.shape[0], -1)
    costs = [ot.emd2(x_hist[k], y_hist[k], cost) for k in range(x_hist.shape[0])]
    return torch.Tensor(costs)

def kl(x, y):
    """Summed KL-divergence term computed by ``torch.nn.functional.kl_div`` after squeezing dim 1.

    ``kl_div`` treats its first argument (``x`` here) as log-probabilities and its
    second (``y``) as probabilities. ``size_average=False`` sums over every element
    instead of averaging.
    NOTE(review): not called anywhere in this section; confirm callers pass log-probs as ``x``.
    """
    log_probs = x.squeeze(1)
    target = y.squeeze(1)
    return torch.nn.functional.kl_div(log_probs, target, size_average=False)

In [None]:
import random

log('start')
data_path = os.environ.get('PYTORCH_DATA_PATH', '../data')

# Number of distinct ordered (i, j) image pairs to sample, and the seed that makes
# the sample (and therefore the train/valid/test split built from it) reproducible.
N_PAIRS = 1000000
PAIR_SEED = 0
random.seed(PAIR_SEED)  # seeds the generator behind `randint`

dataset = torchvision.datasets.MNIST(data_path, train=True, download=True)
imagesList = dataset.train_data
imagesList = imagesList.squeeze(1).float()
# Normalise each image to unit total mass so it can be used as a histogram by ot.emd2.
imagesList = imagesList.div(imagesList.sum(1).sum(1).unsqueeze(1).unsqueeze(1))
labelsList = dataset.train_labels

n_images = imagesList.size(0)
# Sampling stops only once N_PAIRS distinct pairs exist; guard against an endless loop.
assert N_PAIRS <= n_images * n_images, 'N_PAIRS exceeds the number of distinct ordered pairs'
indexes = set()
while len(indexes) < N_PAIRS:
    # Ordered pairs; self-pairs (i, i) can occur.
    indexes.add((randint(0, n_images - 1), randint(0, n_images - 1)))
# Flatten to [i0, j0, i1, j1, ...]; the next cell regroups consecutive entries into pairs.
indexes = [j for i in list(indexes) for j in i]
log('finish')

In [None]:
log('start')

# Gather both members of every sampled pair; `indexes` is flat: [a0, b0, a1, b1, ...].
images = imagesList.index_select(0, torch.LongTensor(indexes))
labels = labelsList.index_select(0, torch.LongTensor(indexes))
# Regroup consecutive entries: images -> (n_pairs, 2, H, W), labels -> (n_pairs, 2).
images = images.view(-1, 2, images.size(1), images.size(2))
labels = labels.view(-1, 2)

# Regression targets: optimal-transport cost between the two images of each pair.
distances = sqeuclidean_wasserstein_distance(images[:, 0], images[:, 1])
for summary in (distances.sum(), distances.min(), distances.max()):
    print(summary)

# 70% / 20% / 10% split by position; n_valid is the END index of the validation slice.
n_train = int(images.size(0) * 0.7)
n_valid = int(images.size(0) * 0.9)

train_images, train_labels, train_distances = (
    images[:n_train], labels[:n_train], distances[:n_train])

valid_images, valid_labels, valid_distances = (
    images[n_train:n_valid], labels[n_train:n_valid], distances[n_train:n_valid])

test_images, test_labels, test_distances = (
    images[n_valid:], labels[n_valid:], distances[n_valid:])

log('finish')

In [None]:
def compute_error_rate(model, x, y, batch_size=100):
    """Return the mean squared error of ``model`` on ``(x, y)`` as a log string.

    Despite the name, this reports an MSE, not a classification error rate.

    Args:
        model: callable (e.g. an ``nn.Module``) mapping a slice of ``x`` along dim 0
            to predictions that can be subtracted from the matching slice of ``y``.
        x: inputs (tensor or Variable), batched along dim 0.
        y: targets, batched along dim 0, with the same length as ``x``.
        batch_size: rows evaluated per call to ``model`` (default 100, the value
            previously hard-coded).

    Returns:
        ``'mse = <value>'`` with two decimals, where the summed squared error is
        divided by ``x.size(0)``; ``'mse = nan'`` when ``x`` is empty.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got %r' % (batch_size,))
    n = x.size(0)
    if n == 0:
        # Nothing to evaluate; report NaN instead of dividing by zero.
        return 'mse = %.2f' % float('nan')
    squared_error = 0.0
    for start in range(0, n, batch_size):
        stop = start + batch_size
        # Call the module itself (not .forward) so registered hooks run.
        outputs = model(x[start:stop])
        diff = (outputs - y[start:stop]).data
        # float() gives a host scalar whether torch.sum returns a float or a 0-dim tensor.
        squared_error += float(torch.sum(diff ** 2))
    return 'mse = %.2f' % (squared_error / n)

In [None]:
log('start')

# Training hyper-parameters.
num_epochs, learning_rate, batch_size = 1000, 0.01, 100

class CNN(nn.Module):
    """Siamese CNN that predicts a distance between the two images of a pair.

    Both images go through the same ``encoder``; the prediction is the squared
    Euclidean distance between their 50-dimensional embeddings.

    ``nn.Linear(2000, 100)`` fixes the input resolution. The three unpadded
    convolutions (kernels 3, 3, 5) shrink 28x28 to 20x20, and 5 channels * 20 * 20
    = 2000, so every image must be 28x28.
    """

    def __init__(self):
        super(CNN, self).__init__()
        self.encoder = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(1, 20, kernel_size=3, padding=0),
                nn.ReLU()),
            nn.Sequential(
                nn.Conv2d(20, 10, kernel_size=3, padding=0),
                nn.ReLU()),
            nn.Sequential(
                nn.Conv2d(10, 5, kernel_size=5, padding=0),
                nn.ReLU()),
            Reshape(-1),
            nn.Sequential(
                nn.Linear(2000, 100),
                nn.ReLU()),
            nn.Sequential(
                nn.Linear(100, 50),
                nn.ReLU())
        )
        # TODO: a decoder with a KL reconstruction term (and an explicit Wasserstein
        # factor) was prototyped here but never enabled; only the encoder is trained.

    def forward(self, data):
        """Predict the squared embedding distance for each pair in ``data``.

        Args:
            data: ``(batch, 2, H, W)`` tensor/Variable; ``data[:, 0]`` and
                ``data[:, 1]`` are the two images of each pair.

        Returns:
            ``(batch,)`` tensor of squared Euclidean distances between embeddings.
        """
        embedding1 = self.encoder(data[:, 0].unsqueeze(1))
        embedding2 = self.encoder(data[:, 1].unsqueeze(1))
        difference = embedding1 - embedding2
        # Row-wise squared norm: same values as diag(D @ D.T), without building a
        # batch x batch matrix. No device-specific tensors are created, so this works
        # on both CPU and GPU.
        return (difference * difference).sum(1)
        
cnn = CNN()
if CUDA:
    cnn.cuda()

# Loss and optimizer: regress the predicted embedding distance onto the target distances.
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)


def _as_variable(tensor):
    """Wrap ``tensor`` in an autograd Variable, moved to the GPU when CUDA is set."""
    wrapped = Variable(tensor)
    return wrapped.cuda() if CUDA else wrapped


train_images_variable = _as_variable(train_images)
train_labels_variable = _as_variable(train_labels)
train_distances_variable = _as_variable(train_distances)

valid_images_variable = _as_variable(valid_images)
valid_labels_variable = _as_variable(valid_labels)
valid_distances_variable = _as_variable(valid_distances)

test_images_variable = _as_variable(test_images)
test_labels_variable = _as_variable(test_labels)
test_distances_variable = _as_variable(test_distances)

# Train the model on consecutive mini-batches in stored order (no shuffling).
for epoch in range(num_epochs):
    for i in range(0, train_images.size(0), batch_size):
        # Forward + Backward + Optimize
        optimizer.zero_grad()
        outputs = cnn(train_images_variable[i:i + batch_size])
        loss = criterion(outputs, train_distances_variable[i:i + batch_size])
        loss.backward()
        optimizer.step()
    log('epoch [%d/%d]' % (epoch + 1, num_epochs))
    log('train set errors:      ' + compute_error_rate(cnn, train_images_variable, train_distances_variable))
    log('validation set errors: ' + compute_error_rate(cnn, valid_images_variable, valid_distances_variable))
    print('')

# Test the model. eval() changes nothing for this architecture (no dropout or
# batch norm), but it follows the usual train/eval protocol.
cnn.eval()
log('test set errors:       ' + compute_error_rate(cnn, test_images_variable, test_distances_variable))

# Save only the learned parameters.
torch.save(cnn.state_dict(), 'cnn.pkl')

log('finish')