# Calculation of the discrete OT maps for colored MNIST

## 1. Imports

In [1]:
import os, sys
sys.path.append("..")

import math
import matplotlib
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline 

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchvision
import gc
import pdb

from src import distributions
import torch.nn.functional as F
from ema_pytorch import EMA

from src.resnet2 import ResNet_D
# from src.resnet_generator import ResnetGenerator
from src.cunet import CUNet
from src.improved_diffusion import UNetModel

from src.tools import unfreeze, freeze
from src.tools import load_dataset, get_sde_pushed_loader_stats
from src.fid_score import calculate_frechet_distance
from src.tools import weights_init_D
from src.plotters import plot_random_sde_images, plot_fixed_sde_images, plot_fixed_sde_trajectories, plot_random_sde_trajectories, plot_several_fixed_sde_trajectories, plot_several_random_sde_trajectories

from collections import defaultdict
from copy import deepcopy
import json

from tqdm import tqdm_notebook as tqdm
from IPython.display import clear_output

import wandb
from src.tools import fig2data, fig2img # for wandb

# Some datasets ship PNG files with large text chunks that Pillow rejects by
# default; raise Pillow's per-chunk text limit to LARGE_ENOUGH_NUMBER MiB so the
# data loaders can open them.
from PIL import PngImagePlugin

LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * 1024 * 1024

In [2]:
# Run the Python garbage collector, then release PyTorch's cached but unused
# CUDA memory (a no-op when CUDA has not been initialised).
gc.collect()
torch.cuda.empty_cache()

## 2. Pairwise distance calculation

In [None]:
from ot.bregman import sinkhorn, sinkhorn_stabilized
from ot.lp import emd
from ot.stochastic import averaged_sgd_entropic_transport, c_transform_entropic, plan_dual_entropic, solve_dual_entropic
import warnings
# Show every warning each time it is raised (not just the first occurrence),
# so solver warnings such as Sinkhorn non-convergence are never hidden.
warnings.simplefilter(action="always")

# Source (X) and target (Y) datasets; both are read from the same root folder.
DATASET1 = 'MNIST-colored_2'
DATASET1_PATH = '/home/data/MNIST'
DATASET2 = 'MNIST-colored_3'
DATASET2_PATH = '/home/data/MNIST'

IMG_SIZE = 32     # passed to load_dataset as img_size
BATCH_SIZE = 100  # passed to load_dataset; X/Y below are sliced directly
N = 1000          # number of test images taken from each dataset

In [27]:
# Build the samplers for both datasets on the CPU with shuffling disabled.
# load_dataset returns a pair of samplers; only the second (test) one of each
# pair is used in the cells below.
X_sampler, X_test_sampler = load_dataset(
    DATASET1,
    DATASET1_PATH,
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False,
    device="cpu",
)
Y_sampler, Y_test_sampler = load_dataset(
    DATASET2,
    DATASET2_PATH,
    img_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    shuffle=False,
    device="cpu",
)

# Release cached CUDA memory, collect garbage, then clear this cell's output.
torch.cuda.empty_cache()
gc.collect()
clear_output()

# Take the first N test samples of each dataset; [0] keeps only the first
# element of the returned tuple (presumably the images, with labels second).
# NOTE(review): relies on the dataset supporting slice indexing, as
# torch.utils.data.TensorDataset does — confirm in src.tools.load_dataset.
X = X_test_sampler.loader.dataset[:N][0]
Y = Y_test_sampler.loader.dataset[:N][0]

In [131]:
# Pairwise squared Euclidean cost M[i, j] = ||X[i] - Y[j]||^2 between the
# flattened images. Rows are computed one source image at a time so peak memory
# is O(len(Y) * image_size) rather than O(len(X) * len(Y) * image_size).
# Sizes come from the data itself, so this also works if a dataset yields fewer
# than N samples or images of a different rank.
n_x, n_y = len(X), len(Y)
X_flat = X.flatten(start_dim=1)
Y_flat = Y.flatten(start_dim=1)

M = np.zeros((n_x, n_y))
for i in tqdm(range(n_x)):
    M[i] = ((X_flat[i] - Y_flat) ** 2).sum(dim=1)

# Uniform weights on the source and target samples.
a = np.ones(n_x) / n_x
b = np.ones(n_y) / n_y

# Promote the cost to extended precision. np.float128 exists only where the C
# long double is wider than 64 bits (e.g. x86-64 Linux), where it is an alias
# of np.longdouble; elsewhere (e.g. Windows) np.float128 raises AttributeError.
# np.longdouble is the same type where float128 exists and degrades to float64
# otherwise.
# NOTE(review): the extra range is presumably there so Sinkhorn's Gibbs kernel
# exp(-M / epsilon) does not underflow to 0 for small epsilon; on platforms
# where longdouble is float64 that protection is lost — confirm if needed.
M = np.array(M, dtype=np.longdouble)

  and should_run_async(code)
Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for i in tqdm(range(N)):


HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1000.0), HTML(value='')))




## 3. Discrete OT calculation

### Discrete OT mapping calculation ($\epsilon = 0$)

In [None]:
# Exact (unregularised, epsilon = 0) OT plan between the weights a and b under
# cost M, solved with POT's network-simplex solver `emd`; the result is a
# len(a) x len(b) coupling matrix.
mapping = emd(a=a, b=b, M=M)

In [230]:
# Save the exact OT plan in .npy format. Because np.save receives an open file
# handle, the file name is used as-is (no ".npy" suffix is appended).
fname = "../discrete_transport_mapping/eps_0"
# Create the output folder first so the cell does not fail with
# FileNotFoundError after the expensive OT solve.
os.makedirs(os.path.dirname(fname), exist_ok=True)
with open(fname, "wb") as f:
    np.save(f, mapping)

  and should_run_async(code)


### Discrete entropic OT mapping calculation

### $\epsilon = 1$

In [None]:
# Entropic OT (Sinkhorn) with regularisation epsilon = 1.
# Cost and regulariser are both multiplied by the same positive factor `scale`
# (one over the number of values in a 3 x IMG_SIZE x IMG_SIZE image), so
# distance / reg_normed == M / epsilon: the scaled problem has the same
# minimiser as the unscaled one, only the magnitudes of the numbers change.
epsilon = 1
reg = epsilon

scale = 1 / (3 * IMG_SIZE * IMG_SIZE)
reg_normed = scale * reg
distance = scale * M

mapping = sinkhorn(
    a,
    b,
    distance,
    reg=reg_normed,
    numItermax=100000,
    verbose=True,
    warn=True,
)

In [42]:
# Save the entropic (epsilon = 1) plan in .npy format. Because np.save receives
# an open file handle, the file name is used as-is (no ".npy" suffix added).
fname = "../discrete_transport_mapping/eps_1"
# Create the output folder first so the cell does not fail with
# FileNotFoundError after the Sinkhorn solve.
os.makedirs(os.path.dirname(fname), exist_ok=True)
with open(fname, "wb") as f:
    np.save(f, mapping)

  and should_run_async(code)


### $\epsilon = 10$

In [None]:
# Entropic OT (Sinkhorn) with regularisation epsilon = 10.
# Cost and regulariser are both multiplied by the same positive factor `scale`
# (one over the number of values in a 3 x IMG_SIZE x IMG_SIZE image), so
# distance / reg_normed == M / epsilon: the scaled problem has the same
# minimiser as the unscaled one, only the magnitudes of the numbers change.
epsilon = 10
reg = epsilon

scale = 1 / (3 * IMG_SIZE * IMG_SIZE)
reg_normed = scale * reg
distance = scale * M

mapping = sinkhorn(
    a,
    b,
    distance,
    reg=reg_normed,
    numItermax=100000,
    verbose=True,
    warn=True,
)

In [None]:
# Save the entropic (epsilon = 10) plan in .npy format. Because np.save receives
# an open file handle, the file name is used as-is (no ".npy" suffix added).
fname = "../discrete_transport_mapping/eps_10"
# Create the output folder first so the cell does not fail with
# FileNotFoundError after the Sinkhorn solve.
os.makedirs(os.path.dirname(fname), exist_ok=True)
with open(fname, "wb") as f:
    np.save(f, mapping)