# Testing Optimal Transport Solver (OTS) on Images Benchmark
**GPU-only implementation.**

In [None]:
import os, sys
sys.path.append("..")

import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 

import numpy as np
import torch
import gc

import pandas as pd

from sklearn.decomposition import PCA
from src.tools import unfreeze, freeze
from src.resnet2 import ResNet_G, weights_init_G, ResNet_D
from src.icnn import View
from torch import nn
import src.map_benchmark as mbm
from src.unet import UNet
import torch.nn.functional as F

from tqdm import tqdm_notebook
from IPython.display import clear_output

## Changeable Config

In [None]:
# Number of samples drawn per optimisation step (and passed to the benchmark).
BATCH_SIZE = 64
# Index of the CUDA device that all tensors/modules in this notebook live on.
GPU_DEVICE = 7

# Number of updates of the map T performed for every single update of D.
T_ITERS = 10
# Total number of outer training iterations.
MAX_ITER = 25000

# The notebook is GPU-only: fail early with a readable message instead of a
# bare assert (which carries no explanation).
if not torch.cuda.is_available():
    raise RuntimeError('This notebook requires a CUDA-capable GPU.')
torch.cuda.set_device(GPU_DEVICE)

# Fixed seeds for torch and numpy.
torch.manual_seed(0x000000)
np.random.seed(0x000000)

# Directory where the best checkpoint of T is stored.
OUT_PATH = '../checkpoints/IPM_vs_OTS/'
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(OUT_PATH, exist_ok=True)

## Benchmark Loading

In [None]:
# Load the 'Early' variant of the CelebA64 benchmark from src.map_benchmark.
benchmark = mbm.CelebA64Benchmark(batch_size=BATCH_SIZE, which='Early')

# Roles are swapped on purpose: the benchmark's *output* sampler provides the
# source samples X, and its *input* sampler provides the target samples Y.
X_sampler, Y_sampler = benchmark.output_sampler, benchmark.input_sampler

## Fixed images for plotting

In [None]:
# Draw 10 fixed target samples; gradient tracking is switched on only for the
# push call on the next line.
Y_fixed = Y_sampler.sample(10).requires_grad_(True)
# Paired source samples, obtained by pushing Y_fixed through the benchmark
# potential and detaching the result from the autograd graph.
# NOTE(review): presumably `push` differentiates the potential w.r.t. its
# input, which is why Y_fixed needs requires_grad — confirm in src.map_benchmark.
X_fixed = benchmark.output_sampler.potential.push(Y_fixed).detach()
# Gradient tracking is switched off again on the fixed batch.
Y_fixed.requires_grad_(False)

## Initializing networks

In [None]:
# Generator (transport map T): View(3, 64, 64) -> UNet -> View(64*64*3),
# built and moved to the current CUDA device in one expression.
# NOTE(review): View presumably reshapes flat vectors to images and back —
# confirm in src.icnn.
T = nn.Sequential(
    View(3, 64, 64),
    UNet(n_channels=3, n_classes=3),
    View(64 * 64 * 3),
).cuda()

In [None]:
# Potential D: a View(3, 64, 64) layer followed by the ResNet discriminator.
_potential_layers = [View(3, 64, 64), ResNet_D()]
D = nn.Sequential(*_potential_layers)

def weights_init_D(m):
    """Initialise a single sub-module of the potential; meant for ``Module.apply``.

    Dispatch is done on the module's class name:

    * names containing ``'Conv'``: Kaiming-normal initialisation of ``weight``
      (``mode='fan_out'``, ``nonlinearity='leaky_relu'``); the bias is left as is.
    * names containing ``'BatchNorm'``: ``weight`` is set to 1 and ``bias`` to 0.
    * everything else is left untouched.

    Modules that match by name but have no ``weight``/``bias`` parameter
    (e.g. ``BatchNorm2d(affine=False)``, where both are ``None``) are skipped
    instead of raising.

    Args:
        m: the module to initialise in place.
    """
    classname = m.__class__.__name__
    # getattr with a default covers both a missing attribute and a None parameter.
    weight = getattr(m, 'weight', None)
    bias = getattr(m, 'bias', None)

    if 'Conv' in classname:
        if weight is not None:
            nn.init.kaiming_normal_(weight, mode='fan_out', nonlinearity='leaky_relu')
    elif 'BatchNorm' in classname:
        if weight is not None:
            nn.init.constant_(weight, 1)
        if bias is not None:
            nn.init.constant_(bias, 0)

# Initialise every sub-module of D, then move it to the current CUDA device
# (Module.apply returns the module itself, so the calls can be chained).
D = D.apply(weights_init_D).cuda()

## Plotting

In [None]:
def _to_displayable(flat_batch):
    """Turn a batch into a CPU array of clipped ``(N, 64, 64, 3)`` images."""
    # x * 0.5 + 0.5 followed by clipping to [0, 1]
    # (presumably the data lives in [-1, 1] — TODO confirm).
    images = flat_batch.detach().reshape(-1, 3, 64, 64).mul(.5).add(.5).clip(0, 1)
    # Channels-last layout is what imshow expects.
    return images.permute(0, 2, 3, 1).cpu().numpy()


def plot():
    """Plot the fixed samples, their images under ``T`` and the paired targets.

    Reads the notebook globals ``X_fixed``, ``Y_fixed`` and ``T``. The figure
    has three rows: the inputs ``x``, the mapped samples ``T(x)`` and the
    reference samples ``y``, with one column per fixed sample.

    Returns:
        (fig, axes): the matplotlib figure and its 3 x len(X_fixed) axes array.
    """
    n_cols = len(X_fixed)
    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(
        3, n_cols, figsize=(n_cols*2+.1, 3*2), squeeze=False
    )

    # Plotting needs no gradients; skip building an autograd graph through T.
    with torch.no_grad():
        T_X_fixed = T(X_fixed)

    rows = (X_fixed, T_X_fixed, Y_fixed)
    for row_axes, batch in zip(axes, rows):
        for ax, image in zip(row_axes, _to_displayable(batch)):
            ax.imshow(image)
            # Hide x-axis entirely; keep the y-axis (tick-less) so row labels show.
            ax.get_xaxis().set_visible(False)
            ax.set_yticks([])

    axes[0, 0].set_ylabel(r'$x$', fontsize=30)
    axes[1, 0].set_ylabel(r'$\hat{T}(x)$', fontsize=30)
    axes[2, 0].set_ylabel(r'$y$', fontsize=30)

    fig.tight_layout(h_pad=0.01, w_pad=0.01)
    # Release temporaries created by the forward pass.
    gc.collect(); torch.cuda.empty_cache()
    return fig, axes

# Render the fixed samples once, before the training loop further down runs.
fig, axes = plot()

## Evaluation

In [None]:
def evaluate(size=2**14):
    """Estimate the L2-UVP (in percent) of the current map ``T``.

    Fresh pairs are generated from the benchmark: ``Y`` is sampled from
    ``Y_sampler`` and ``X`` is obtained by pushing ``Y`` through the benchmark
    potential. The metric is ``100 * mean(||Y - T(X)||^2) / Y_sampler.var``,
    where the squared norm sums over dimension 1 and the mean runs over all
    evaluated samples.

    Args:
        size: total number of samples to evaluate on. They are processed in
            chunks of at most ``BATCH_SIZE``; previously only 10 samples were
            drawn per chunk, so far fewer than ``size`` samples were used.

    Returns:
        The L2-UVP estimate; NaN if ``size`` is not positive.
    """
    squared_error_total = 0.0
    n_evaluated = 0
    for start in tqdm_notebook(range(0, size, BATCH_SIZE)):
        # The last chunk may be smaller when size is not a multiple of BATCH_SIZE.
        n_batch = min(BATCH_SIZE, size - start)
        # NOTE(review): requires_grad is presumably needed by `push` — confirm.
        Y = Y_sampler.sample(n_batch).requires_grad_(True)
        X = benchmark.output_sampler.potential.push(Y).detach()
        with torch.no_grad():
            squared_error_total += (Y - T(X)).square().sum(dim=1).sum().item()
        n_evaluated += n_batch

    if n_evaluated == 0:
        # Nothing to average over (matches the NaN the empty mean used to give).
        return float('nan')
    return 100 * (squared_error_total / n_evaluated) / Y_sampler.var

## Training Loss

In [None]:
def Wasserstein2Loss(D, T, X, Y):
    """Return the scalar saddle-point objective shared by ``T`` and ``D``.

    The value is ``mse(X, T(X)) - mean(D(T(X))) + mean(D(Y))``. The training
    loop minimises it with respect to ``T`` and minimises its negation with
    respect to ``D``.

    Args:
        D: callable potential, applied to ``T(X)`` and to ``Y``.
        T: callable map, applied to ``X``.
        X: batch of source samples.
        Y: batch of target samples.
    """
    mapped = T(X)
    # mse_loss uses its default reduction; the trailing mean is kept as in
    # the original expression.
    transport_cost = F.mse_loss(X, mapped).mean()
    mapped_score = D(mapped).mean()
    target_score = D(Y).mean()
    return transport_cost - mapped_score + target_score

## Main Training

In [None]:
# One Adam optimiser per network, both with lr = 1e-4 and betas = (0, 0.9).
T_opt = torch.optim.Adam(params=T.parameters(), lr=1e-4, betas=(0., 0.9))
D_opt = torch.optim.Adam(params=D.parameters(), lr=1e-4, betas=(0., 0.9))

# Lowest L2-UVP seen so far; the training loop checkpoints T when it improves.
best_L2_UVP = float('inf')

In [None]:
for iteration in tqdm_notebook(range(MAX_ITER)):
    # ------------------------------------------------------------------
    # Outer maximisation: one update of the potential D with T frozen.
    # D minimises the negated objective, i.e. maximises the objective.
    # ------------------------------------------------------------------
    freeze(T)
    unfreeze(D)
    X = X_sampler.sample(BATCH_SIZE)
    Y = Y_sampler.sample(BATCH_SIZE)
    D_loss = -Wasserstein2Loss(D, T, X, Y)
    D_opt.zero_grad()
    D_loss.backward()
    D_opt.step()
    del D_loss
    gc.collect()
    torch.cuda.empty_cache()

    # ------------------------------------------------------------------
    # Inner minimisation: T_ITERS updates of the map T with D frozen,
    # each on a freshly sampled batch.
    # ------------------------------------------------------------------
    freeze(D)
    unfreeze(T)
    for it in range(T_ITERS):
        X = X_sampler.sample(BATCH_SIZE)
        Y = Y_sampler.sample(BATCH_SIZE)
        T_loss = Wasserstein2Loss(D, T, X, Y)
        T_opt.zero_grad()
        T_loss.backward()
        T_opt.step()
        del T_loss
        gc.collect()
        torch.cuda.empty_cache()

    # Report, checkpoint and plot only on every 50th iteration.
    if iteration % 50 != 0:
        continue

    clear_output(wait=True)
    print(f'Iteration: {iteration}')

    # NOTE(review): T was just passed to unfreeze(); if that also puts the
    # module in train mode, evaluation and plotting run in train mode — confirm
    # against src.tools whether this is intended.
    current_L2_UVP = evaluate()
    if current_L2_UVP < best_L2_UVP:
        # New best score: overwrite the stored weights of T.
        best_L2_UVP = current_L2_UVP
        torch.save(T.state_dict(), os.path.join(OUT_PATH, 'T.pt'))

    print(f'Current L2-UVP: {current_L2_UVP}')
    print(f'Best L2-UVP: {best_L2_UVP}')

    fig, axes = plot()
    plt.show()