In [1]:
%pylab inline

Populating the interactive namespace from numpy and matplotlib


In [2]:
import re
import torch

import collections

import numpy as np
import scipy.io
import scipy.ndimage

import PIL

import logging
logging.getLogger("PIL").setLevel(logging.INFO)

import common.plotting
import torch 
import torch.nn as nn
import torchvision
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.autograd import Variable
import tensorflow

import os
import ot
import itertools
import datetime
import sys
from random import randint
from pyriemann.utils.distance import *

# Reuse the PyTorch data directory as the TORCH_MODEL_ZOO location.
# Falls back to '../data' when PYTORCH_DATA_PATH is unset (the same default the
# data-loading cell uses) instead of aborting the import cell with a KeyError.
os.environ['TORCH_MODEL_ZOO'] = os.environ.get('PYTORCH_DATA_PATH', '../data')

In [3]:
# We strongly recommend training using CUDA on lab computers.
# The flag follows what PyTorch can actually see: on a machine without a GPU
# it becomes False, so the `.cuda()` calls guarded by it are skipped instead
# of raising.
CUDA = torch.cuda.is_available()

# Messages passed to `log` are appended to this file.
LOG_FILE = 'stdout.txt'

# Start every run with an empty log.
if os.path.exists(LOG_FILE):
    os.remove(LOG_FILE)

def to_np(x):
    """Return the contents of a tensor or ``Variable`` as a NumPy array.

    A ``Variable`` is first unwrapped to its ``.data`` tensor; the tensor is
    then moved to the CPU and converted with ``.numpy()``.
    """
    tensor = x.data if isinstance(x, Variable) else x
    return tensor.cpu().numpy()

def to_variable(x, **kwargs):
    """Wrap a NumPy array in a ``Variable``.

    The array is converted with ``torch.from_numpy`` and moved to the GPU
    when the module-level ``CUDA`` flag is set.  Extra keyword arguments are
    forwarded to the ``Variable`` constructor.
    """
    tensor = torch.from_numpy(x)
    return Variable(tensor.cuda() if CUDA else tensor, **kwargs)

def log(text):
    """Prefix ``text`` with the current time, append it to ``LOG_FILE`` and print it.

    Both the file and stdout are flushed so the message is visible
    immediately during long-running cells.
    """
    stamped = '%s | %s' % (datetime.datetime.now(), text)

    with open(LOG_FILE, 'a') as log_handle:
        log_handle.write(stamped + '\n')
        log_handle.flush()

    print(stamped)
    sys.stdout.flush()

class Reshape(nn.Module):
    """Layer that views its input with a new shape, keeping the first dimension.

    The constructor arguments are the target sizes of every dimension after
    the first; ``forward`` views ``x`` as ``(x.size(0),) + shape``.
    """

    def __init__(self, *args):
        super(Reshape, self).__init__()
        # Target sizes of the trailing dimensions, stored as a tuple.
        self.shape = args

    def forward(self, x):
        batch = x.size(0)
        return x.view(batch, *self.shape)

In [4]:
import ot

def ground_matrix(n):
    """Build the ground-cost matrix for an ``n`` x ``n`` grid.

    Every grid cell ``[i, j]`` is listed in row-major order and the list is
    passed twice to ``ot.dist`` with the ``'sqeuclidean'`` metric, i.e. the
    cost is taken between every pair of cells.
    """
    cells = np.array([[row, col] for row in range(n) for col in range(n)])
    return ot.dist(cells, cells, 'sqeuclidean')

def sqeuclidean_wasserstein_distance(x, y):
    """Compute ``ot.emd2`` between corresponding samples of ``x`` and ``y``.

    Both inputs are converted with ``to_np`` and flattened per sample.  The
    ground cost is ``ground_matrix(x.shape[1])``, so each sample is assumed
    to be a square grid whose side is ``x.shape[1]``.

    Returns a ``torch.Tensor`` holding one value per sample of ``x``.
    """
    x_np, y_np = to_np(x), to_np(y)
    cost = ground_matrix(x_np.shape[1])

    n_pairs = x_np.shape[0]
    x_flat = x_np.reshape(n_pairs, -1)
    y_flat = y_np.reshape(y_np.shape[0], -1)

    return torch.Tensor([ot.emd2(x_flat[k], y_flat[k], cost) for k in range(n_pairs)])

def kl(x, y):
    """Summed ``kl_div`` between ``x`` and ``y`` after dropping dimension 1.

    Dimension 1 of each argument is squeezed (a no-op unless it has size 1)
    and the result is passed to ``torch.nn.functional.kl_div`` with
    ``size_average=False``, so the element-wise terms are summed, not averaged.
    NOTE(review): ``kl_div`` presumably expects ``x`` as log-probabilities and
    ``y`` as probabilities - confirm against callers.
    """
    return torch.nn.functional.kl_div(
        torch.squeeze(x, 1),
        torch.squeeze(y, 1),
        size_average=False,
    )

In [5]:
log('start')
samples = 100000

# Root directory of the on-disk cache; each sample count gets its own
# sub-directory holding images.pt, labels.pt and distances.pt.
patch='/pio/scratch/1/i233123/data_mnist'
cache_dir = '%s/%d' % (patch, samples)
try:
    log('read images, labels, distances from files')
    images = torch.load('%s/images.pt' % cache_dir)
    labels = torch.load('%s/labels.pt' % cache_dir)
    distances = torch.load('%s/distances.pt' % cache_dir)
except Exception as e:
    # Best effort: any failure to read the cache triggers a rebuild.  The
    # cause is logged so a real error (e.g. a corrupted file) is not mistaken
    # for a cache that simply does not exist yet.
    log('error, calculating new data: %r' % (e,))
    data_path = os.environ.get('PYTORCH_DATA_PATH', '../data')

    dataset = torchvision.datasets.MNIST(data_path, train=True, download=True)
    imagesList = dataset.train_data
    imagesList = imagesList.squeeze(1).double()
    # Divide every image by its total intensity so each one sums to 1.
    imagesList = imagesList.div(imagesList.sum(1).sum(1).unsqueeze(1).unsqueeze(1))
    labelsList = dataset.train_labels

    # Draw `samples` distinct (first, second) index pairs; the set removes
    # duplicate pairs, but a pair may contain the same image twice.
    indexes = set()
    while len(indexes) < samples:
        indexes.add((randint(0, imagesList.size(0)-1), randint(0, imagesList.size(0)-1)))
    # Flatten the pairs to [a0, b0, a1, b1, ...] for index_select.
    indexes = [j for i in list(indexes) for j in i]

    images = torch.index_select(imagesList, 0, torch.LongTensor(indexes))
    labels = torch.index_select(labelsList, 0, torch.LongTensor(indexes))
    # Regroup consecutive entries into pairs: (samples, 2, H, W) and (samples, 2).
    images = images.view(-1, 2, images.size(1), images.size(2))
    labels = labels.view(-1, 2)
    distances = sqeuclidean_wasserstein_distance(images[:, 0], images[:, 1])

    images = images.float()
    distances = distances.float()

    # Create the cache directory (and its parent) in one step if needed.
    if not os.path.isdir(cache_dir):
        os.makedirs(cache_dir)

    torch.save(images, '%s/images.pt' % cache_dir)
    torch.save(labels, '%s/labels.pt' % cache_dir)
    torch.save(distances, '%s/distances.pt' % cache_dir)

log('%d samples count' % images.size(0))
log('distances sum: %.2f' % distances.sum())
log('distances min: %.2f' % distances.min())
log('distances max: %.2f' % distances.max())

# Consecutive slices of the pair set: the first 70% for training, the next
# 20% for validation and the final 10% for testing.
n_train, n_valid = (int(images.size(0) * fraction) for fraction in (0.7, 0.9))

train_images, valid_images, test_images = (
    images[:n_train], images[n_train:n_valid], images[n_valid:])

train_labels, valid_labels, test_labels = (
    labels[:n_train], labels[n_train:n_valid], labels[n_valid:])

train_distances, valid_distances, test_distances = (
    distances[:n_train], distances[n_train:n_valid], distances[n_valid:])

log('finish')

2018-01-09 23:34:36.173482 | start
2018-01-09 23:34:36.189326 | read images, labels, distances from files
2018-01-09 23:34:36.364430 | 100000 samples count
2018-01-09 23:34:36.379456 | distances sum: 1139728.84
2018-01-09 23:34:36.390840 | distances min: 0.00
2018-01-09 23:34:36.402711 | distances max: 63.68
2018-01-09 23:34:36.414516 | finish


In [6]:
# Per label pair, accumulate the total distance and the number of pairs.
# Rows are indexed by the larger label of a pair and columns by the smaller
# one, so only entries with row >= column are ever written.
distances_matrix_sum, distances_matrix_count = np.zeros((10, 10)), np.zeros((10, 10))

for pair_index in range(samples):
    pair_labels = labels[pair_index]
    hi, lo = max(pair_labels), min(pair_labels)
    distances_matrix_sum[hi][lo] += distances[pair_index]
    distances_matrix_count[hi][lo] += 1
    
# np.set_printoptions(precision=2)
# print(distances_matrix_sum)
# print(distances_matrix_count)
# print(distances_matrix_sum / distances_matrix_count)

# from common.plotting import plot_mat
# for i in range(10):
#     d = sqeuclidean_wasserstein_distance(images[i][0].unsqueeze(0), images[i][1].unsqueeze(0))[0]
#     lol = to_np(images[i])
#     lol = np.array([lol[0], lol[1]])
#     lol = np.expand_dims(lol, axis=1)
#     plot_mat(lol, cmap='gray')
#     plt.title("Distance: %.2f" % d)
#     show()

In [None]:
def compute_error_rate(model, allX, allY, batch_size=200):
    """Return the mean squared error of ``model`` over a data set.

    Despite the name, this is a regression metric: the squared differences
    between ``model(x)`` and ``allY`` are summed batch by batch and divided
    by the number of samples, ``allX.size(0)``.

    Args:
        model: callable applied to each batch of inputs.
        allX: tensor of inputs, sliced along dimension 0.
        allY: tensor of targets, sliced in step with ``allX``.
        batch_size: number of samples evaluated per forward pass
            (defaults to 200, the previously hard-coded value).

    Returns:
        The summed squared error divided by ``allX.size(0)``.
    """
    squared_error = 0.0
    for start in range(0, allX.size(0), batch_size):
        x = Variable(allX[start:start + batch_size])
        y = Variable(allY[start:start + batch_size])
        if CUDA:
            x = x.cuda()
            y = y.cuda()
        # `.data` drops the autograd history: only the values are needed here.
        diff = (model(x) - y).data
        squared_error = squared_error + torch.sum(diff ** 2)
    return squared_error / allX.size(0)

In [None]:
log('start')

class CNN(nn.Module):
    """Siamese encoder that outputs a squared embedding distance per image pair.

    Both images of a pair go through the same ``encoder`` (three valid
    convolutions followed by two fully connected layers, each with a ReLU),
    giving a 50-dimensional embedding.  ``forward`` returns the squared
    Euclidean distance between the two embeddings of each pair.

    The first linear layer expects 2000 input features, which corresponds to
    28x28 inputs (28 -> 26 -> 24 -> 20 after the three convolutions, times 5
    channels).
    """

    def __init__(self):
        super(CNN, self).__init__()
        # The nesting of Sequential containers is kept as-is so parameter
        # names in saved state dicts stay the same.
        self.encoder = nn.Sequential(
            nn.Sequential(
                nn.Conv2d(1, 20, kernel_size=3, padding=0),
                nn.ReLU()),
            nn.Sequential(
                nn.Conv2d(20, 10, kernel_size=3, padding=0),
                nn.ReLU()),
            nn.Sequential(
                nn.Conv2d(10, 5, kernel_size=5, padding=0),
                nn.ReLU()),
            Reshape(-1),
            nn.Sequential(
                nn.Linear(2000, 100),
                nn.ReLU()),
            nn.Sequential(
                nn.Linear(100, 50),
                nn.ReLU())
        )

        # Decoder draft, left disabled by the author.
#         self.decoder = nn.Sequential(
#             nn.Sequential(
#                 nn.Linear(50, 100),
#                 nn.ReLU()),
#             nn.Sequential(
#                 nn.Linear(100, 5 * 28 * 28),
#                 nn.ReLU()),
#             Reshape(5, 28, 28),
#             nn.Sequential(
#                 nn.Conv2d(5, 1, kernel_size=1, padding=0),
#                 nn.ReLU()),
#             nn.Sequential(
#                 nn.Conv2d(10, 20, kernel_size=3, padding=0),
#                 nn.ReLU()),
#             nn.Sequential(
#                 nn.Conv2d(20, 1, kernel_size=3, padding=0),
#                 nn.ReLU()),
#             nn.Softmax2d()
#         )

    def forward(self, data):
        """Return the squared embedding distance for every pair in ``data``.

        Args:
            data: batch of image pairs; index 0 / 1 along dimension 1 selects
                the first / second image of each pair.

        Returns:
            1-D tensor with one value per pair.
        """
        # Re-insert a channel dimension of size 1 for the convolutions.
        encoder1 = self.encoder(data[:, 0].unsqueeze(1))
        encoder2 = self.encoder(data[:, 1].unsqueeze(1))

        # Row-wise squared norm of the difference.  This equals the diagonal
        # of difference @ difference.T without building the full
        # batch x batch matrix.
        encoder_difference = encoder1 - encoder2
        encoder_factor = (encoder_difference * encoder_difference).sum(1)

        # NOTE: an earlier draft also allocated all-zero `kl_factor1`,
        # `kl_factor2` and `wasserstein_factor` placeholders with a hard-coded
        # `.cuda()` for the disabled objective
        #     kl_factor1 + (encoder_factor - wasserstein_factor).pow(2) + kl_factor2
        # They never contributed to the output and broke CPU-only runs, so
        # they are not created here.
        return encoder_factor
      
# Early stopping with an expanding horizon: training starts with a budget of
# one epoch, and every time the validation error improves the budget is
# pushed out to at least `epoch * patience_expansion + 1` epochs.
num_epochs = 1
patience_expansion = 1.5
best_value_error = 1000000.0
learning_rate = 0.0001
batch_size = 100
epoch = 0
best_params = None  # CPU copies of the weights with the best validation error

cnn = CNN()
if CUDA:
    cnn.cuda()

criterion = nn.MSELoss()
optimizer = torch.optim.Adam(cnn.parameters(), lr=learning_rate)

try:
    while epoch < num_epochs:
        epoch += 1
        # One pass over the training pairs in fixed order, batch by batch.
        for i in range(0, train_images.size(0), batch_size):
            optimizer.zero_grad()
            x = Variable(train_images[i:i+batch_size])
            y = Variable(train_distances[i:i+batch_size])
            if CUDA:
                x = x.cuda()
                y = y.cuda()
            outputs = cnn(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()

        value_error = compute_error_rate(cnn, valid_images, valid_distances)
        if (value_error < best_value_error):
            best_value_error = value_error
            num_epochs = int(np.maximum(num_epochs, epoch * patience_expansion + 1))
            # Snapshot the raw weight values (detached from autograd) on the CPU.
            best_params = [p.data.clone().cpu() for p in cnn.parameters()]

        log('epoch [%d/%d]' % (epoch, num_epochs))        
        log('validation set errors: %.2f' % value_error)
        print('')

        # Checkpoint the current (not necessarily best) weights every epoch.
        torch.save(cnn.state_dict(), 'cnn.pkl')
        torch.save(optimizer.state_dict(), 'optimizer.pkl')
        with open('config.txt', 'w') as config_file:
            config_file.write("samples            %d\n" % (samples))
            config_file.write("num_epochs         %d\n" % (num_epochs))
            config_file.write("epoch              %d\n" % (epoch))
            config_file.write("batch_size         %d\n" % (batch_size))
            config_file.write("patience_expansion %.2f\n" % (patience_expansion))
            config_file.write("best_value_error   %.2f\n" % (best_value_error))
            # %g keeps small rates readable; %.2f wrote 0.0001 as "0.00".
            config_file.write("learning_rate      %g\n" % (learning_rate))

except KeyboardInterrupt:
    # Interrupting the kernel is the supported way to stop training early;
    # evaluation below still runs with the best weights found so far.
    log('training interrupted at epoch %d' % epoch)

# Restore the best weights by copying them into the live parameters.
# (Assigning the list to `cnn.parameters` would only shadow the method and
# leave the network's weights untouched.)
if best_params is not None:
    for param, best in zip(cnn.parameters(), best_params):
        param.data.copy_(best)

# Test the Model
# Switch to eval mode; the layers used here (conv, linear, ReLU) behave the
# same in both modes, so this is for good practice only.
cnn.eval()
log('test set errors: %.2f' % compute_error_rate(cnn, test_images, test_distances))

torch.save(cnn.state_dict(), 'cnn.pkl')
torch.save(optimizer.state_dict(), 'optimizer.pkl')

log('finish')

2018-01-09 23:34:38.319422 | start
2018-01-09 23:34:46.560819 | epoch [1/2]
2018-01-09 23:34:46.573431 | validation set errors: 54.31

2018-01-09 23:34:52.559497 | epoch [2/4]
2018-01-09 23:34:52.571184 | validation set errors: 20.75

2018-01-09 23:34:58.468679 | epoch [3/5]
2018-01-09 23:34:58.480284 | validation set errors: 18.31

