# Optimal Transport and Wasserstein Distances

The goal of this class is to introduce computational optimal transport, and implement some applications of optimal transport in machine learning.

In the first part of the class, you will familiarize yourself with optimal transport and learn to compute optimal transport distances (also called Wasserstein distances).

In the second part of the class, you will use optimal transport as a geometric tool in machine learning, with an application to color transfer between images.

In this class, you will need to install the package ``POT``:
* Install with pip: ``pip install pot``
* Install with conda: ``conda install -c conda-forge pot``

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imread
from mpl_toolkits.mplot3d import Axes3D
import ot

## 1. Computational Optimal Transport: Linear Programming and Sinkhorn Algorithm

Optimal Transport is a theory that allows us to compare two (weighted) point clouds $(X, a)$ and $(Y, b)$, where $X \in \mathbb{R}^{n \times d}$ and $Y \in \mathbb{R}^{m \times d}$ are the locations of the $n$ (resp. $m$) points in dimension $d$, and $a \in \mathbb{R}^n$, $b \in \mathbb{R}^m$ are the non-negative weights. We ask that the weights of each cloud sum to one, i.e. $\sum_{i=1}^n a_i = \sum_{j=1}^m b_j = 1$.

The basic idea of Optimal Transport is to "transport" the mass located at points $X$ to the mass located at points $Y$.

Let us denote by $\mathcal{U}(a,b) = \left\{ P \in \mathbb{R}^{n \times m} \,|\, P \geq 0, \sum_{j=1}^m P_{ij} = a_i, \sum_{i=1}^n P_{ij} = b_j\right\}$ the set of non-negative matrices whose rows sum to $a$ and whose columns sum to $b$.

If $P \in \mathcal{U}(a,b)$, the quantity $P_{ij} \geq 0$ should be regarded as the mass transported from point $X_i$ to point $Y_j$. For this reason, it is called a "transport plan".

For any transport plan $P \in \mathcal{U}(a,b)$, we define its cost $K_C(P) := \langle C, P \rangle = \sum_{ij} C_{ij }P_{ij}$ where $C \in \mathbb{R}^{n \times m}$. The value $C_{ij}$ should be regarded as the cost, or price, we must pay for transporting a unit of mass from point $X_i$ to point $Y_j$, and is often chosen as $C_{ij} = \|X_i - Y_j\|^2$.
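
To make these definitions concrete, here is a minimal sketch on toy data. The weights, the cost values and the helper name `is_transport_plan` are all made up for illustration. It checks that the independent coupling $P_{ij} = a_i b_j$ belongs to $\mathcal{U}(a,b)$ and evaluates its cost $K_C(P)$.

```python
import numpy as np

# Toy data (hypothetical values): 3 source points, 2 target points.
a_toy = np.array([0.5, 0.25, 0.25])   # source weights, sum to one
b_toy = np.array([0.6, 0.4])          # target weights, sum to one
C_toy = np.array([[2.0, 4.0],
                  [1.0, 1.0],
                  [1.0, 5.0]])        # C[i, j]: cost of moving one unit from i to j

def is_transport_plan(P, a, b):
    """Check the three conditions defining U(a, b)."""
    return bool(
        (P >= 0).all()
        and np.allclose(P.sum(axis=1), a)   # rows sum to a
        and np.allclose(P.sum(axis=0), b)   # columns sum to b
    )

# Independent coupling: always feasible, but in general not optimal.
P_toy = np.outer(a_toy, b_toy)

cost_toy = float(np.sum(C_toy * P_toy))   # K_C(P) = <C, P>
print(is_transport_plan(P_toy, a_toy, b_toy), cost_toy)
```

Any other matrix passing the same check is also an admissible plan; the optimization problem introduced next picks the cheapest one.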

In "Optimal Transport", there is the word _Optimal_. Indeed, we want to find a transport plan $P \in \mathcal{U}(a,b)$ that will minimize the cost $K_C$. In other words, we want to solve
$$
    \min_{P \in \mathcal{U}(a,b)} K_C(P) = \min_{P \in \mathcal{U}(a,b)} \sum_{ij} C_{ij }P_{ij}
$$

This problem is a Linear Program: the objective function is linear, and the constraints are linear. We can thus solve this problem using classical Linear Programming algorithms, such as the simplex algorithm.
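
As an illustration of the linear programming formulation, the sketch below solves a tiny 2 x 2 instance with a generic LP solver (`scipy.optimize.linprog`) rather than a dedicated optimal transport solver. The toy weights and costs are made up, and flattening $P$ row by row is just one possible way to encode the constraints.

```python
import numpy as np
from scipy.optimize import linprog

# Toy problem (hypothetical values).
a_toy = np.array([0.5, 0.5])
b_toy = np.array([0.25, 0.75])
C_toy = np.array([[0.0, 1.0],
                  [1.0, 0.0]])
n, m = C_toy.shape

# P is flattened row by row, so P[i, j] sits at index i * m + j.
row_sums = np.kron(np.eye(n), np.ones((1, m)))   # encodes sum_j P_ij = a_i
col_sums = np.kron(np.ones((1, n)), np.eye(m))   # encodes sum_i P_ij = b_j
A_eq = np.vstack([row_sums, col_sums])
b_eq = np.concatenate([a_toy, b_toy])

# Minimize <C, P> subject to the marginal constraints and P >= 0.
res = linprog(C_toy.ravel(), A_eq=A_eq, b_eq=b_eq, bounds=(0, None), method="highs")
P_lp = res.x.reshape(n, m)
print(P_lp)
print(res.fun)
```

This encoding uses $n + m$ equality constraints on $nm$ variables, which is why generic LP solvers quickly become impractical as the number of points grows.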

If $P^*$ is a solution to the Optimal Transport problem, we will say that $P^*$ is an optimal transport plan between $(X, a)$ and $(Y, b)$, and that $K_C(P^*)$ is the optimal transport distance (or Wasserstein distance) between $(X, a)$ and $(Y, b)$.

### 1.a. Computing Optimal "Croissant" Transport

We will solve the Bakery/Café problem of transporting croissants from a number of bakeries to cafés in a city (in this case, Manhattan). We did a quick Google Maps search in Manhattan for bakeries and cafés:

![bak.png](https://remi.flamary.com/cours/otml/bak.png)

From this search we extracted their positions, and we generated fictional production and sales numbers (both of which sum to the same value).

We have access to the positions of the bakeries, ```bakery_pos```, and their respective production, ```bakery_prod```, which together describe the source distribution. The cafés where the croissants are sold are likewise defined by their positions ```cafe_pos``` and their sales ```cafe_prod```. For fun we also provide a map ```Imap``` that illustrates the position of these shops in the city.

In [None]:
# Load the data
data = np.load('data/manhattan.npz')

bakery_pos = data['bakery_pos']
bakery_prod = data['bakery_prod']
cafe_pos = data['cafe_pos']
cafe_prod = data['cafe_prod']
Imap = data['Imap']

print('Bakery production: {}'.format(bakery_prod))
print('Cafe sale: {}'.format(cafe_prod))
print('Total croissants : {}'.format(cafe_prod.sum()))

In [None]:
plt.figure(figsize=(8,8))
plt.imshow(Imap, interpolation='bilinear') # plot the map
plt.scatter(bakery_pos[:,0], bakery_pos[:,1], s=5*bakery_prod, c='r', edgecolors='k', label='Bakeries')
plt.scatter(cafe_pos[:,0], cafe_pos[:,1], s=5*cafe_prod, c='b', edgecolors='k', label='Cafés')
plt.legend(fontsize=20)
plt.axis('off')
plt.title('Manhattan Bakeries and Cafés', fontsize=25);

Let us now compute the cost matrix $C \in \mathbb{R}^{n \times m}$.

In [None]:
C = # TODO
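
One possible way to build such a matrix for the squared Euclidean cost $C_{ij} = \|X_i - Y_j\|^2$ is NumPy broadcasting. The sketch below uses toy points and a hypothetical helper name, `squared_euclidean_cost`; it is only an illustration, not the expected answer to the exercise.

```python
import numpy as np

def squared_euclidean_cost(X, Y):
    """Return C with C[i, j] = ||X[i] - Y[j]||^2 for X of shape (n, d), Y of shape (m, d)."""
    diff = X[:, None, :] - Y[None, :, :]   # shape (n, m, d): all pairwise differences
    return (diff ** 2).sum(axis=-1)        # sum over coordinates, shape (n, m)

# Toy points (hypothetical values).
X_toy = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
Y_toy = np.array([[1.0, 1.0], [2.0, 0.0]])
C_pairwise = squared_euclidean_cost(X_toy, Y_toy)
print(C_pairwise)
```

POT also ships `ot.dist`, whose default metric is the squared Euclidean distance, so `ot.dist(X_toy, Y_toy)` should give the same matrix.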

We can now compute the optimal transport plan to transport the croissants from the bakeries to the cafés, using function ``ot.emd``.

In [None]:
optimal_plan = # TODO
print(optimal_plan)
optimal_cost = # TODO
print(optimal_cost)

In [None]:
plt.figure(figsize=(8,8))
plt.imshow(Imap, interpolation='bilinear')
plt.scatter(bakery_pos[:,0], bakery_pos[:,1], s=5*bakery_prod, c='r', edgecolors='k', label='Bakeries')
plt.scatter(cafe_pos[:,0], cafe_pos[:,1], s=5*cafe_prod, c='b', edgecolors='k', label='Cafés')
for i in range(optimal_plan.shape[0]):      # loop over bakeries
    for j in range(optimal_plan.shape[1]):  # loop over cafés
        plt.plot([bakery_pos[i,0], cafe_pos[j,0]], [bakery_pos[i,1], cafe_pos[j,1]], c='k', lw=0.1*optimal_plan[i,j], alpha=0.8)
plt.legend(fontsize=20)
plt.axis('off')
plt.title('Manhattan Bakeries and Cafés', fontsize=25);

### 1.b. Entropy Regularized Optimal Transport

In real applications, and especially in machine learning, we often have to deal with a very large number of points. Exact linear programming solvers, whose running time grows roughly cubically with the number of points, then become too slow.

That is why, in practice, people minimize another criterion given by
$$
    \min_{P \in \mathcal{U}(a,b)} \langle C, P \rangle + \epsilon \sum_{ij} P_{ij} [ \log(P_{ij}) - 1].
$$
When $\epsilon$ is sufficiently small, a solution to the above problem (often referred to as "Entropy-regularized Optimal Transport") is a good approximation of a true optimal transport plan.

To solve this problem, note that the optimality conditions imply that a solution $P_\epsilon^*$ is necessarily of the form $P_\epsilon^* = \text{diag}(u) \, K \, \text{diag}(v)$, where $K = \exp(-C/\epsilon)$ is computed entrywise and $u \in \mathbb{R}^n$, $v \in \mathbb{R}^m$ are two non-negative vectors.
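
A short sketch of where this form comes from. The dual variables $f \in \mathbb{R}^n$ and $g \in \mathbb{R}^m$ are notation introduced here for the two marginal constraints. Writing the Lagrangian of the regularized problem,
$$
    \mathcal{L}(P, f, g) = \langle C, P \rangle + \epsilon \sum_{ij} P_{ij} [ \log(P_{ij}) - 1] - \sum_{i} f_i \Big( \sum_{j} P_{ij} - a_i \Big) - \sum_{j} g_j \Big( \sum_{i} P_{ij} - b_j \Big),
$$
and setting its derivative with respect to $P_{ij}$ to zero gives
$$
    C_{ij} + \epsilon \log(P_{ij}) - f_i - g_j = 0
    \quad \Longrightarrow \quad
    P_{ij} = e^{f_i/\epsilon} \, e^{-C_{ij}/\epsilon} \, e^{g_j/\epsilon} = u_i K_{ij} v_j,
$$
with $u_i = e^{f_i/\epsilon}$ and $v_j = e^{g_j/\epsilon}$.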

$P_\epsilon^*$ must satisfy the constraints, i.e. $P_\epsilon^* \in \mathcal{U}(a,b)$, so that
$$
    P_\epsilon^* 1_m = a \text{  and  } (P_\epsilon^*)^T 1_n = b
$$
which can be rewritten as
$$
    u \odot (Kv) = a \text{  and  } v \odot (K^T u) = b
$$

Sinkhorn's algorithm then alternates between solving these two equations (the divisions being entrywise), and reads
$$
    u \leftarrow \frac{a}{Kv} \text{  and  } v \leftarrow \frac{b}{K^T u}
$$

In [None]:
def sinkhorn(a, b, C, epsilon=0.1, max_iters=100):
    """Run Sinkhorn's algorithm"""
    
    return # TODO
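
The two updates above can be sketched in a few lines of NumPy. This is one possible minimal version, under a separate hypothetical name (`sinkhorn_sketch`) so that it does not replace the function to be written in the exercise; it runs a fixed number of iterations, with no convergence test and no numerical stabilization.

```python
import numpy as np

def sinkhorn_sketch(a, b, C, epsilon=0.1, max_iters=100):
    """Alternate the two scaling updates and return diag(u) K diag(v)."""
    K = np.exp(-C / epsilon)        # Gibbs kernel, computed entrywise
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(max_iters):
        u = a / (K @ v)             # enforce the row marginals
        v = b / (K.T @ u)           # enforce the column marginals
    return u[:, None] * K * v[None, :]

# Toy problem (hypothetical values).
a_toy = np.array([0.5, 0.5])
b_toy = np.array([0.25, 0.75])
C_toy = np.array([[0.0, 1.0],
                  [1.0, 0.0]])
P_sinkhorn = sinkhorn_sketch(a_toy, b_toy, C_toy, epsilon=0.5)
print(P_sinkhorn.sum(axis=1), P_sinkhorn.sum(axis=0))
```

Note that for very small $\epsilon$ the entries of $K$ can underflow to zero, which makes the divisions ill-defined; rescaling the cost matrix (for instance dividing it by its maximum) helps keep $C/\epsilon$ in a reasonable range.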

We first show that this algorithm is consistent with classical optimal transport, using the "croissant" transport example.

In [None]:
plan_diff = []
distance_diff = []
for epsilon in np.linspace(0.01, 1, 100):
    optimal_plan_sinkhorn = sinkhorn(bakery_prod, cafe_prod, C/C.max(), epsilon)
    optimal_cost_sinkhorn = np.sum(optimal_plan_sinkhorn*C)
    plan_diff.append(np.linalg.norm(optimal_plan_sinkhorn-optimal_plan)/cafe_prod.sum())
    distance_diff.append(100*np.abs(optimal_cost_sinkhorn-optimal_cost)/optimal_cost)

In [None]:
plt.figure(figsize=(16,5))
plt.loglog(np.linspace(0.01, 1, 100), plan_diff, lw=4)
plt.xlabel(r'Regularization Strength $\epsilon$', fontsize=25)  # raw string: avoids the invalid escape '\e'
plt.ylabel(r'$||P^* - P_\epsilon^*||$', fontsize=25)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.show()

In [None]:
plt.figure(figsize=(16,5))
plt.loglog(np.linspace(0.01, 1, 100), distance_diff, lw=4)
plt.xlabel(r'Regularization Strength $\epsilon$', fontsize=25)  # raw string: avoids the invalid escape '\e'
plt.ylabel('Error in %', fontsize=25)
plt.xticks(fontsize=20)
plt.yticks(fontsize=20)
plt.show()

Let us now compare the running times of Sinkhorn's algorithm and of the classical optimal transport solver on a larger dataset.

In [None]:
n = 1000
m = 1000
d = 2

X = np.random.randn(n,d)
Y = np.random.randn(m,d)

a = np.ones(n) / n  # uniform weights summing to one
b = np.ones(m) / m

C = np.zeros((n,m))
for i in range(n):
    for j in range(m):  # columns index the m target points
        C[i,j] = np.linalg.norm(X[i] - Y[j])**2

In [None]:
%timeit ot.emd(a,b,C)

In [None]:
%timeit sinkhorn(a,b,C)

We see that Sinkhorn's algorithm is faster. Even more interesting, it only involves matrix-vector products and entrywise operations, so it can be parallelized on GPUs, giving further acceleration.

## 2. Application of Optimal Transport in Machine Learning: Color Transfer

We will now use optimal transport in color transfer. We are given two pictures, and the goal is to transfer the color of the first one to the other.

In [None]:
I1 = imread('./data/klimt.jpg').astype(np.float64) / 256
I2 = imread('./data/schiele.jpg').astype(np.float64) / 256

def showImage(I,myPreferredFigsize=(8,8)):
    plt.figure(figsize=myPreferredFigsize)
    plt.imshow(I)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

showImage(I1)
showImage(I2)

These are two beautiful paintings by Gustav Klimt and Egon Schiele, respectively. We will now treat them as empirical distributions in the color space.

In [None]:
def im2mat(I):
    """Converts an image to a matrix (one pixel per line)"""
    return # TODO

def mat2im(X, shape):
    """Converts back a matrix to an image"""
    return # TODO
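
As a hint on the shapes involved, here is a minimal sketch on a tiny synthetic array, using hypothetical helper names (`image_to_matrix`, `matrix_to_image`). It assumes the image is stored as an array of shape (height, width, channels); a plain reshape is then enough to go back and forth.

```python
import numpy as np

def image_to_matrix(I):
    """Flatten an (H, W, 3) image into an (H*W, 3) matrix, one pixel per row."""
    return I.reshape(-1, I.shape[-1])

def matrix_to_image(X, shape):
    """Reshape an (H*W, 3) matrix back into an image of the given shape."""
    return X.reshape(shape)

# Tiny synthetic "image": 2 x 3 pixels with 3 channels.
I_toy = np.arange(2 * 3 * 3, dtype=np.float64).reshape(2, 3, 3) / 18
X_pixels = image_to_matrix(I_toy)
I_back = matrix_to_image(X_pixels, I_toy.shape)
print(X_pixels.shape, I_back.shape)
```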

X1 = im2mat(I1)
X2 = im2mat(I2)

We will need to plot the distributions in the color space, using the following function:

In [None]:
def showImageAsPointCloud(X, myPreferredFigsize=(8,8)):
    fig = plt.figure(figsize=myPreferredFigsize)
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlim(0,1)
    ax.scatter(X[:,0], X[:,1], X[:,2], c=X, s=50, marker='o', alpha=1.0)
    ax.set_xlabel('R',fontsize=22)
    ax.set_xticklabels([])
    ax.set_ylim(0,1)
    ax.set_ylabel('G',fontsize=22)
    ax.set_yticklabels([])
    ax.set_zlim(0,1)
    ax.set_zlabel('B',fontsize=22)
    ax.set_zticklabels([])
    ax.grid(False)  # hide the grid (the string 'off' is not a valid boolean in recent matplotlib)
    plt.show()

It is unlikely that our solver, however efficient, can handle distributions this large (about 1M x 1M entries for the coupling). We will use the mini-batch k-means procedure from sklearn to subsample those distributions. Write the code that performs this subsampling (you can choose 50 clusters to get a good approximation of the image).
__Note that computing the centroids can take some time.__

In [None]:
from sklearn.cluster import MiniBatchKMeans

nbsamples = 50

kmeans1 = # TODO
X1_sampled = # TODO
showImageAsPointCloud(X1_sampled)


kmeans2 = # TODO
X2_sampled = # TODO
showImageAsPointCloud(X2_sampled)
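
For orientation, here is a sketch of mini-batch k-means subsampling with sklearn's `MiniBatchKMeans`, run on synthetic random "pixels" rather than on the paintings. The number of clusters, `n_init` and `random_state` are arbitrary choices made here, and weighting each centroid by the fraction of pixels in its cluster is one option among others (uniform weights are another).

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

rng = np.random.default_rng(0)
X_pixels = rng.random((2000, 3))   # synthetic stand-in for an (H*W, 3) pixel matrix

n_clusters = 10
kmeans_toy = MiniBatchKMeans(n_clusters=n_clusters, n_init=3, random_state=0).fit(X_pixels)

centroids = kmeans_toy.cluster_centers_     # (n_clusters, 3): the subsampled colors
labels = kmeans_toy.predict(X_pixels)       # cluster index of every pixel
weights = np.bincount(labels, minlength=n_clusters) / len(labels)
print(centroids.shape, weights.sum())
```

Keeping the `labels` array around is useful afterwards, since it records which centroid each original pixel belongs to.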

Let us now compute the optimal transport plan between X1_sampled and X2_sampled using ``ot.emd``.

In [None]:
C = # TODO
optimal_transport_plan = # TODO

Since only the cluster centroids have been transported, we need a simple way of transporting all the pixels of the original image. We will apply a simple strategy: the new value of each pixel is the new position of the centroid of its cluster.

In [None]:
X1_transformed = # TODO
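
The text does not fix how a centroid's "new position" is obtained from the transport plan. One common choice, assumed in the sketch below, is the barycentric projection: each source centroid moves to the plan-weighted average of the target centroids. All arrays here are made-up toy values.

```python
import numpy as np

# Toy stand-ins (hypothetical values): 3 source centroids, 2 target centroids.
centroids_src = np.array([[0.1, 0.1, 0.1], [0.5, 0.5, 0.5], [0.9, 0.9, 0.9]])
centroids_tgt = np.array([[0.0, 0.2, 0.4], [1.0, 0.8, 0.6]])
plan = np.array([[0.3, 0.0],
                 [0.2, 0.2],
                 [0.0, 0.3]])            # a feasible plan; row i sums to the weight of centroid i
labels = np.array([0, 0, 1, 2, 1, 2])    # cluster index of each of 6 pixels

# Barycentric projection: plan-weighted mean of the targets, one row per source centroid.
moved_centroids = (plan @ centroids_tgt) / plan.sum(axis=1, keepdims=True)

# Every pixel takes the new position of the centroid of its cluster.
X_moved = moved_centroids[labels]
print(X_moved)
```

With this strategy, all pixels of a cluster receive the same color, so the result is a quantized image with at most as many colors as clusters.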

In [None]:
I1_transformed = mat2im(X1_transformed, I1.shape)

In [None]:
showImage(I1)
showImage(I1_transformed)

In [None]:
X2_transformed = # TODO

In [None]:
I2_transformed = mat2im(X2_transformed, I2.shape)

In [None]:
showImage(I2)
showImage(I2_transformed)