### Optimal Transport Experiments

In [None]:
import sys
sys.path.append('../')

import numpy as np
import matplotlib.pylab as plt

def dist(x1, x2):
    """Return the mean squared difference between two tensors.

    The average runs over every element of ``x1 - x2`` (all rows and all
    coordinates), so the result is a 0-dim tensor.

    Args:
        x1: First tensor.
        x2: Second tensor; must be broadcastable against ``x1``.

    Returns:
        Scalar ``torch.Tensor`` equal to ``mean((x1 - x2) ** 2)``.
    """
    squared_error = torch.square(x1 - x2)
    return torch.mean(squared_error)

def graph(ax, x1, x2, title, lines=True):
    """Scatter two 2-D point sets on ``ax`` and optionally link paired points.

    Args:
        ax: Matplotlib-style axes; only ``plot`` and ``set_title`` are called.
        x1: Source points, read as ``x1[:, 0]`` (x) and ``x1[:, 1]`` (y).
            May be a NumPy array or a ``torch.Tensor``.
        x2: Target points, same layout as ``x1``.
        title: Title set on the axes.
        lines: When true, draw a segment from ``x1[i]`` to ``x2[i]`` for each
            index ``i`` (pairs stop at the shorter of the two inputs).
    """
    # Detach and move to host memory before converting: a bare ``.numpy()``
    # raises for tensors that require grad or live on a non-CPU device.
    if isinstance(x1, torch.Tensor):
        x1 = x1.detach().cpu().numpy()
    if isinstance(x2, torch.Tensor):
        x2 = x2.detach().cpu().numpy()

    # Plot points
    ax.plot(x1[:, 0], x1[:, 1], 'r+', label='Source')
    ax.plot(x2[:, 0], x2[:, 1], 'bx', label='Target')

    # Connect each source point to the target point at the same index.
    if lines:
        for p1, p2 in zip(x1, x2):
            ax.plot((p1[0], p2[0]), (p1[1], p2[1]), 'b-')
    ax.set_title(title)
    
def gen_points(n):
    """Draw ``n`` source and ``n`` target points from two fixed 2-D Gaussians.

    The source is centred at (0, 0) with identity covariance; the target is
    centred at (4, 4) with covariance ``[[1, -0.8], [-0.8, 1]]``.

    Args:
        n: Number of samples to draw from each distribution.

    Returns:
        Tuple ``(xs, xt)`` with the source samples first and the target
        samples second, as produced by ``ot.datasets.make_2D_samples_gauss``.
    """
    sample_gauss = ot.datasets.make_2D_samples_gauss

    source_mean, source_cov = np.array([0, 0]), np.array([[1, 0], [0, 1]])
    target_mean, target_cov = np.array([4, 4]), np.array([[1, -.8], [-.8, 1]])

    # Source is sampled before target, keeping the draw order fixed.
    source_points = sample_gauss(n, source_mean, source_cov)
    target_points = sample_gauss(n, target_mean, target_cov)
    return source_points, target_points

def get_emd_matrix(a, b, xs, xt):
    """Solve the exact optimal-transport problem between two point sets.

    Args:
        a: Weights (histogram) over the source samples ``xs``.
        b: Weights (histogram) over the target samples ``xt``.
        xs: Source samples.
        xt: Target samples.

    Returns:
        The transport plan returned by ``ot.emd`` for the normalised pairwise
        cost matrix between ``xs`` and ``xt``.
    """
    # Pairwise cost from ``ot.dist`` (squared Euclidean by default according
    # to POT's documentation), rescaled in place so the largest entry is 1.
    cost_matrix = ot.dist(xs, xt)
    cost_matrix /= cost_matrix.max()
    return ot.emd(a, b, cost_matrix)

def fix_points(xs, xt, emd_matrix):
    """Reorder ``xt`` so that row ``i`` is the point matched to ``xs[i]``.

    The column index of every nonzero entry of ``emd_matrix`` is collected in
    row-major order and used to index ``xt``.

    NOTE(review): this only yields a one-to-one pairing if the plan has
    exactly one nonzero entry per row (a scaled permutation matrix) - confirm
    against callers.

    Args:
        xs: Source samples; returned unchanged.
        xt: Target samples, indexable by an integer index array/tensor.
        emd_matrix: Transport plan, either a ``torch.Tensor`` or an array
            accepted by ``np.argwhere``.

    Returns:
        Tuple ``(xs, xt_reordered)``.
    """
    if isinstance(emd_matrix, torch.Tensor):
        nonzero_index = torch.nonzero(emd_matrix)
    else:
        nonzero_index = np.argwhere(emd_matrix)

    # Column 1 holds the target index of each (source, target) nonzero pair.
    target_order = nonzero_index[:, 1]
    return xs, xt[target_order]

def fast_fix_points_torch(xs, xt):
    """Pair torch samples by exact optimal transport with uniform weights.

    Args:
        xs: Source samples; ``xs.shape[0]`` sets the number of points.
        xt: Target samples (assumed to hold the same number of points as
            ``xs`` - TODO confirm with callers).

    Returns:
        Tuple ``(xs, xt_reordered)`` as returned by ``fix_points``.
    """
    num_samples = xs.shape[0]

    # Every sample carries the same mass 1/n on both sides.
    source_weights = torch.ones(num_samples) / num_samples
    target_weights = torch.ones(num_samples) / num_samples

    transport_plan = get_emd_matrix(source_weights, target_weights, xs, xt)
    return fix_points(xs, xt, transport_plan)

def fast_fix_points_numpy(xs, xt):
    """Pair NumPy samples by exact optimal transport with uniform weights.

    The second parameter was previously misspelled ``it``, so the body read
    a module-level ``xt`` instead of the caller's argument. It is now named
    ``xt`` and actually used; positional callers are unaffected.

    Args:
        xs: Source samples; ``xs.shape[0]`` sets the number of points.
        xt: Target samples (assumed to hold the same number of points as
            ``xs`` - TODO confirm with callers).

    Returns:
        Tuple ``(xs, xt_reordered)`` as returned by ``fix_points``.
    """
    num_samples = xs.shape[0]

    # Uniform distribution on samples, for both source and target.
    source_weights = ot.unif(num_samples)
    target_weights = ot.unif(num_samples)

    transport_plan = get_emd_matrix(source_weights, target_weights, xs, xt)
    return fix_points(xs, xt, transport_plan)


In [None]:
# Optimal Transport Experiments - Part 1
# The helpers called below reference ``ot`` and ``torch``. Import them here so
# this cell runs on a fresh kernel instead of relying on a later cell having
# been executed first.
import ot
import torch

# Sample random points
n = 64
xs, xt = gen_points(n)

# Graph the original pairings (xs[i] linked to xt[i] in sampling order)
fig, axs = plt.subplots(1, 2)
graph(axs[0], xs, xt, 'Original Mapping')

# Permute target points to get OT samples
xsc, xtc = fast_fix_points_numpy(xs, xt)

# Graph the better pairings
graph(axs[1], xsc, xtc, 'EMD Mapping')
plt.show()

In [None]:
# Optimal Transport Experiments - Part 2
import ot
import torch
from datasets.dist import GMM, Gaussian, Funnel

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

# Initialize source and target.
# ``torch.tensor`` is used instead of the legacy ``torch.Tensor`` constructor,
# which rejects non-CPU ``device`` arguments; both give float32 for this data.
mean = torch.tensor([0.0, 0.0], device=device)
covar = torch.tensor([[3, 0.25], [0.25, 4]], device=device)
source = Gaussian(mean, covar)
target = GMM(device, nmode=3)

# Sample from source and target
n = 64
xs = source.sample(n)
xt = target.sample(n)

# Generate plots and plot original mapping
fig, axs = plt.subplots(1, 3, figsize=(30,10))
graph(axs[0], xs, xt, 'Samples', lines=False)
graph(axs[1], xs, xt, f'Original Mapping, Dist={dist(xs, xt)}')

# Compute exact OT mapping and graph again
xsc, xtc = fast_fix_points_torch(xs, xt)
graph(axs[2], xsc, xtc, f'EMD Mapping, Dist={dist(xsc, xtc)}')
plt.show()

In [None]:
# Optimal Transport Experiments - Part 3
import torch
import ot
from model.ot import verlet_emd_reorder
from datasets.dist import GMM, Gaussian, Funnel, VerletGaussian, VerletGMM

# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

# Initialize source and target.
# ``torch.tensor`` is used instead of the legacy ``torch.Tensor`` constructor,
# which rejects non-CPU ``device`` arguments; both give float32 for this data.
mean = torch.tensor([0.0, 0.0], device=device)
covar = torch.tensor([[3, 0.25], [0.25, 4]], device=device)
# The ``q`` component uses the Gaussian / GMM above; the ``p`` component is a
# zero-mean, identity-covariance Gaussian on both the source and target side.
source_q = Gaussian(mean, covar)
source_p = Gaussian(torch.zeros_like(mean, device=device), torch.eye(2, device=device))
source = VerletGaussian(source_q, source_p)
target_q = GMM(device, nmode=3)
target_p = Gaussian(torch.zeros_like(mean, device=device), torch.eye(2, device=device))
target = VerletGMM(target_q, target_p)

# Sample from source and target
n = 64
xs = source.sample(n)
xt = target.sample(n)

# Generate plots and plot original mapping (only the ``q`` part is drawn; the
# reported distance uses the combined state from ``get_combined()``)
fig, axs = plt.subplots(1, 3, figsize=(30,10))
graph(axs[0], xs.q, xt.q, 'Samples', lines=False)
graph(axs[1], xs.q, xt.q, f'Original Mapping, Dist={dist(xs.get_combined(), xt.get_combined())}')

# Compute exact OT mapping and graph again
xsc, xtc = verlet_emd_reorder(xs, xt)
graph(axs[2], xsc.q, xtc.q, f'EMD Mapping, Dist={dist(xsc.get_combined(), xtc.get_combined())}')
plt.show()