<a rel="license" href="http://creativecommons.org/licenses/by/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by/4.0/88x31.png" /></a><br /><span xmlns:dct="http://purl.org/dc/terms/" property="dct:title"><b>Computational Optimal Transport</b></span> by <a xmlns:cc="http://creativecommons.org/ns#" href="http://mate.unipv.it/gualandi" property="cc:attributionName" rel="cc:attributionURL">Stefano Gualandi</a> is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by/4.0/">Creative Commons Attribution 4.0 International License</a>. Based on a project at <a xmlns:dct="http://purl.org/dc/terms/" href="https://github.com/mathcoding/opt4ds" rel="dct:source">https://github.com/mathcoding/compopt</a>.

# Computing Wasserstein Barycenters
In this notebook, we discuss the fundamental challenges in computing Wasserstein Barycenters.

The exercises are based on the following references:

* Chapter 9 in [Computational Optimal Transport]()
* Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015). *Iterative Bregman projections for regularized transportation problems*. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
* Bouchet, P.Y., Gualandi, S. and Rousseau, L.M., 2020. *Primal heuristics for wasserstein barycenters*. CPAIOR 2020, Proceedings 17 (pp. 239-255), Springer.
* Auricchio, G., Bassetti, F., Gualandi, S. and Veneroni, M., 2019. *Computing Wasserstein barycenters via linear programming*. CPAIOR 2019 Proceedings 16 (pp. 355-363), Springer.

**IMPORTANT:** One of the best applications of OT is the study of gene expression profiles:

* Schiebinger, G., Shu, J., Tabaka, M., Cleary, B., Subramanian, V., Solomon, A., Gould, J., Liu, S., Lin, S., Berube, P. and Lee, L., 2019. [Optimal-transport analysis of single-cell gene expression identifies developmental trajectories in reprogramming](https://www.cell.com/cell/pdf/S0092-8674(19)30039-X.pdf). Cell, 176(4), pp.928-943.

We suggest watching the following research seminars on YouTube:

* [Statistical and Computational aspects of Wasserstein Barycenters - P. Rigollet @ MAD+ (8 Apr 2020)](https://www.youtube.com/watch?v=MMlR-PeMsgg)
* [Pavel Dvurechensky - "Wasserstein barycenters from the computational perspective" | MoCCA'20](https://www.youtube.com/watch?v=j93j9W4JgyI&t=245s)
* [Darina Dvinskikh - "Decentralized Algorithms for Wasserstein Barycenters" | MoCCA'20](https://www.youtube.com/watch?v=cC9xGHSx3B0)

## Wasserstein Barycenters in 1D
The following two exercises are from [POT examples](https://pythonot.github.io/auto_examples/index.html#wasserstein-barycenters).

### Using the optimal transportation plan for interpolation between two measures
Try playing with the following code by changing the value of *alpha*, which sets the barycenter weights `(1 - alpha, alpha)`.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D  # noqa
from matplotlib.collections import PolyCollection

In [None]:
# nb bins
n = 100

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m= mean, s= std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)

# creating matrix A containing all distributions
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization
M = ot.utils.dist0(n)
M /= M.max()

print(M.shape)

alpha = 0.5  # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])

# l2bary
bary_l2 = A.dot(weights)

# wasserstein
reg = 1e-3
bary_wass = ot.bregman.barycenter(A, M, reg, weights)

f, (ax1, ax2) = plt.subplots(2, 1, tight_layout=True, num=1)
ax1.plot(x, A, color="black")
ax1.set_title('Distributions')

ax2.plot(x, bary_l2, 'r', label='l2')
ax2.plot(x, bary_wass, 'g', label='Wasserstein')
ax2.set_title('Barycenters')

plt.legend()
plt.show()
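
In one dimension the Wasserstein barycenter has a closed form: its quantile function is the weighted average of the quantile functions of the input measures. As a cross-check that does not depend on POT, the sketch below applies this to equal-size samples, where averaging quantiles reduces to averaging sorted samples. The function name `barycenter_1d_samples`, the sample size, and the random seed are illustrative assumptions, not part of the original exercise.

```python
import numpy as np


def barycenter_1d_samples(samples, weights):
    """Unregularized 1D Wasserstein barycenter of equal-size samples.

    samples: sequence of k one-dimensional arrays, all of the same length
    weights: k nonnegative weights summing to 1
    Returns an array of samples drawn from the barycenter.
    """
    # Sorting each sample gives its empirical quantile function
    sorted_samples = np.sort(np.asarray(samples, dtype=np.float64), axis=1)
    # Weighted average of the quantile functions
    return np.asarray(weights, dtype=np.float64) @ sorted_samples


rng = np.random.default_rng(0)
s1 = rng.normal(20, 5, size=10_000)  # same mean/std as a1 above
s2 = rng.normal(60, 8, size=10_000)  # same mean/std as a2 above

alpha = 0.5
bary = barycenter_1d_samples([s1, s2], [1 - alpha, alpha])
print("barycenter mean:", bary.mean(), "std:", bary.std())
```

For two Gaussians the unregularized barycenter is again Gaussian, with the weighted average of the means and of the standard deviations, so for `alpha = 0.5` the printed values should be close to 40 and 6.5. The entropic barycenter plotted above is expected to be somewhat wider because of the regularization.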

To sweep over several values of alpha automatically and show the resulting 1D distributions in a single 3D plot, try the following code.

In [None]:
# Try 11 values of alpha between 0 and 1
n_alpha = 11
alpha_list = np.linspace(0, 1, n_alpha)

B_l2 = np.zeros((n, n_alpha))

B_wass = np.copy(B_l2)

# Compute the barycenter for each value of alpha
for i in range(n_alpha):
    alpha = alpha_list[i]
    weights = np.array([1 - alpha, alpha])
    B_l2[:, i] = A.dot(weights)
    B_wass[:, i] = ot.bregman.barycenter(A, M, reg, weights)

plt.figure(1)

# Plot the barycenter interpolations
cmap = plt.get_cmap('viridis')  # plt.cm.get_cmap was removed in Matplotlib 3.9
verts = []
zs = alpha_list
for i, z in enumerate(zs):
    ys = B_l2[:, i]
    verts.append(list(zip(x, ys)))

ax = plt.gcf().add_subplot(projection='3d')

poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
plt.title('Barycenter interpolation with l2')
plt.tight_layout()

Next, plot the Wasserstein Barycenters computed above with the following code.

In [None]:
plt.figure(1)
cmap = plt.get_cmap('viridis')  # plt.cm.get_cmap was removed in Matplotlib 3.9
verts = []
zs = alpha_list
for i, z in enumerate(zs):
    ys = B_wass[:, i]
    verts.append(list(zip(x, ys)))

ax = plt.gcf().add_subplot(projection='3d')

poly = PolyCollection(verts, facecolors=[cmap(a) for a in alpha_list])
poly.set_alpha(0.7)
ax.add_collection3d(poly, zs=zs, zdir='y')
ax.set_xlabel('x')
ax.set_xlim3d(0, n)
ax.set_ylabel('$\\alpha$')
ax.set_ylim3d(0, 1)
ax.set_zlabel('')
ax.set_zlim3d(0, B_l2.max() * 1.01)
plt.title('Barycenter interpolation with Wasserstein')
plt.tight_layout()

plt.show()
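
The two 3D plots differ in how the mass moves: the l2 barycenter is a mixture, so mass fades out at one location and fades in at the other, while the Wasserstein barycenter translates the mass. One way to quantify this, sketched below with NumPy only, is to compare the standard deviation of the mixture with those of its inputs. The helpers `gauss_hist` (which mimics `ot.datasets.make_1D_gauss`) and `hist_mean_std` are assumptions introduced for illustration.

```python
import numpy as np


def gauss_hist(n, m, s):
    """Discrete Gaussian histogram on bins 0..n-1, normalized to sum to 1."""
    x = np.arange(n, dtype=np.float64)
    h = np.exp(-(x - m) ** 2 / (2 * s ** 2))
    return h / h.sum()


def hist_mean_std(x, h):
    """Mean and standard deviation of a histogram h supported on positions x."""
    mean = np.sum(x * h)
    std = np.sqrt(np.sum((x - mean) ** 2 * h))
    return mean, std


n = 100
x = np.arange(n, dtype=np.float64)
a1 = gauss_hist(n, m=20, s=5)
a2 = gauss_hist(n, m=60, s=8)

alpha = 0.5
mix = (1 - alpha) * a1 + alpha * a2  # the l2 barycenter

for name, h in [("a1", a1), ("a2", a2), ("l2 barycenter", mix)]:
    mean, std = hist_mean_std(x, h)
    print(f"{name}: mean={mean:.2f}, std={std:.2f}")

mix_mean, mix_std = hist_mean_std(x, mix)
```

The mixture keeps the correct mean but its standard deviation (roughly 21 bins) is far larger than that of either input, because it is bimodal. Applying `hist_mean_std` to the columns of `B_wass` should instead give a spread that stays comparable to the inputs.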

### Compare Barycenters computed with Exact and Entropic Optimal Transport
The following code shows the practical difference between computing barycenters with exact Optimal Transport, solved as a linear program (`ot.lp.barycenter`), and with an entropic regularized solver (`ot.bregman.barycenter`).

In [None]:
import numpy as np
import matplotlib.pylab as pl
import ot
# necessary for 3d plot even if not used
from mpl_toolkits.mplot3d import Axes3D  # noqa
from matplotlib.collections import PolyCollection  # noqa

problems = []

n = 100  # nb bins

# bin positions
x = np.arange(n, dtype=np.float64)

# Gaussian distributions
a1 = ot.datasets.make_1D_gauss(n, m=20, s=5)  # m= mean, s= std
a2 = ot.datasets.make_1D_gauss(n, m=60, s=8)

# creating matrix A containing all distributions
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization
M = ot.utils.dist0(n)
M /= M.max()

alpha = 0.5  # 0<=alpha<=1
weights = np.array([1 - alpha, alpha])

# l2bary
bary_l2 = A.dot(weights)

# wasserstein
reg = 1e-3
ot.tic()
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
ot.toc()

ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights)
ot.toc()

problems.append([A, [bary_l2, bary_wass, bary_wass2]])

# STAIRS DATA
a1 = 1.0 * (x > 10) * (x < 50)
a2 = 1.0 * (x > 60) * (x < 80)

a1 /= a1.sum()
a2 /= a2.sum()

# creating matrix A containing all distributions
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization
M = ot.utils.dist0(n)
M /= M.max()

# l2bary
bary_l2 = A.dot(weights)

# wasserstein
reg = 1e-3
ot.tic()
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
ot.toc()

ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights)
ot.toc()

problems.append([A, [bary_l2, bary_wass, bary_wass2]])

# DIRAC DELTA DATA
a1 = np.zeros(n)
a2 = np.zeros(n)

a1[10] = .25
a1[20] = .5
a1[30] = .25
a2[80] = 1

a1 /= a1.sum()
a2 /= a2.sum()

# creating matrix A containing all distributions
A = np.vstack((a1, a2)).T
n_distributions = A.shape[1]

# loss matrix + normalization
M = ot.utils.dist0(n)
M /= M.max()

# l2bary
bary_l2 = A.dot(weights)

# wasserstein
reg = 1e-3
ot.tic()
bary_wass = ot.bregman.barycenter(A, M, reg, weights)
ot.toc()

ot.tic()
bary_wass2 = ot.lp.barycenter(A, M, weights)
ot.toc()

problems.append([A, [bary_l2, bary_wass, bary_wass2]])

pl.figure(1, (20, 6))
pl.clf()

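nbm = len(problems)  # number of test problems (one column per problem)
nbm2 = nbm // 2  # index of the middle column, where the titles are placed

for i in range(nbm):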
    A = problems[i][0]
    bary_l2 = problems[i][1][0]
    bary_wass = problems[i][1][1]
    bary_wass2 = problems[i][1][2]

    pl.subplot(2, nbm, 1 + i)
    for j in range(n_distributions):
        pl.plot(x, A[:, j])
    if i == nbm2:
        pl.title('Distributions')
    pl.xticks(())
    pl.yticks(())

    pl.subplot(2, nbm, 1 + i + nbm)

    pl.plot(x, bary_l2, 'r', label='L2 (Euclidean)')
    pl.plot(x, bary_wass, 'g', label='Reg Wasserstein')
    pl.plot(x, bary_wass2, 'b', label='LP Wasserstein')
    if i == nbm - 1:
        pl.legend()
    if i == nbm2:
        pl.title('Barycenters')

    pl.xticks(())
    pl.yticks(())
    
pl.show()
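
Part of the visible gap between the 'Reg Wasserstein' and 'LP Wasserstein' curves comes from the Gibbs kernel `exp(-M / reg)` used by the entropic solver, which acts as a Gaussian blur when `M` is a squared distance. The sketch below, using NumPy only, measures the width of that kernel for the normalized cost used above. It assumes that `ot.utils.dist0(n)` returns squared Euclidean distances between the integer bin positions (its default behaviour), so the cost matrix is rebuilt by hand here.

```python
import numpy as np

n = 100
reg = 1e-3

# Squared-distance cost between bin positions, normalized as in the cell above
x = np.arange(n, dtype=np.float64)
M = (x[:, None] - x[None, :]) ** 2
M /= M.max()

# Gibbs kernel used by entropic solvers
K = np.exp(-M / reg)

# Treat the central row as a probability vector and measure its spread
row = K[n // 2] / K[n // 2].sum()
center = np.sum(x * row)
kernel_std = np.sqrt(np.sum((x - center) ** 2 * row))

# exp(-(i-j)^2 / ((n-1)^2 * reg)) is a Gaussian with std (n-1) * sqrt(reg / 2)
predicted_std = (n - 1) * np.sqrt(reg / 2)

print("measured kernel std (bins):", kernel_std)
print("predicted kernel std (bins):", predicted_std)
```

With `reg = 1e-3` the kernel is a little over two bins wide, which is consistent with the entropic barycenter of the Dirac example appearing as smooth bumps rather than spikes. A smaller `reg` narrows the kernel, at the price of slower and numerically more delicate iterations.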

## Implementing and testing the Iterative Bregman Projection algorithm
Consider the 10 double-centric ellipses, loaded from a Matlab `.mat` file with the following snippet.

In [None]:
import scipy.io  # loadmat lives in the scipy.io submodule

def ReadEggs():
    imported_ellipses = scipy.io.loadmat('../data/ellipses.mat')
    ellipses=imported_ellipses['ellipses']
    return ellipses

def PlotEggs(ellipses):
    fig, axs = plt.subplots(2, 5, figsize=(20, 8))  # 2 rows, 5 columns

    for i in range(2):  # Rows
        for j in range(5):  # Columns
            axs[i, j].imshow(ellipses[i*5+j], cmap='binary')  # image number i*5+j
            axs[i, j].set_xticks([])  # Remove x-axis ticks
            axs[i, j].set_yticks([])  # Remove y-axis ticks

    plt.tight_layout()  # Adjust subplots to fit into the figure area.
    plt.show()  # Display the figure

ellipses = ReadEggs()
PlotEggs(ellipses)

**QUESTION:** Which empirical distribution best represents these 10 samples?
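
To treat each image as an empirical distribution over the pixel grid, its intensities have to be nonnegative and sum to one. A minimal sketch of that normalization on a synthetic stack follows. The shape `(n_images, height, width)` is an assumption about how `ellipses.mat` is laid out, the random binary images are placeholders for the real data, and `normalize_images` is a hypothetical helper.

```python
import numpy as np


def normalize_images(stack):
    """Scale every image in a (n_images, height, width) stack to unit total mass."""
    stack = np.asarray(stack, dtype=np.float64)
    if np.any(stack < 0):
        raise ValueError("images must be nonnegative to be read as distributions")
    mass = stack.sum(axis=(1, 2), keepdims=True)
    if np.any(mass == 0):
        raise ValueError("an empty image cannot be normalized")
    return stack / mass


# Placeholder data: 10 sparse binary images of size 60x60
rng = np.random.default_rng(0)
fake_images = (rng.random((10, 60, 60)) > 0.9).astype(np.float64)

A = normalize_images(fake_images)
print(A.shape, A.sum(axis=(1, 2)))
```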

You can try the POT library, using either the function [convolutional barycenter 2d](https://pythonot.github.io/gen_modules/ot.bregman.html#ot.bregman.convolutional_barycenter2d) or the [convolutional barycenter debiased](https://pythonot.github.io/gen_modules/ot.bregman.html#ot.bregman.convolutional_barycenter2d_debiased), as in the following script (try the two functions separately, with the debiased one as the second option):

In [None]:
def ConvBary(A, reg=0.004):
    # NOTE: `reg` is currently unused; the loop below scans a fixed list of values
    fig, axs = plt.subplots(3, 5, figsize=(20, 8))  # 3 rows, 5 columns
    for i in range(2):  # First two rows: the input images
        for j in range(5):  # Columns
            axs[i, j].imshow(A[i*5+j], cmap='binary')  # image number i*5+j
            axs[i, j].set_xticks([])  # Remove x-axis ticks
            axs[i, j].set_yticks([])  # Remove y-axis ticks

    for j, r in enumerate([0.1, 0.01, 0.005, 0.002, 0.001]):
        
        B = ot.bregman.convolutional_barycenter2d(A, r)
        
        # COMMENT THE PREVIOUS LINE AND UNCOMMENT THE FOLLOWING LINE TO USE THE DEBIASED VERSION
        #B = ot.bregman.convolutional_barycenter2d_debiased(A, r)

        axs[2, j].imshow(B, cmap='binary')  # Example plot
        axs[2, j].set_xticks([])  # Remove x-axis tick labels
        axs[2, j].set_yticks([])  # Remove y-axis tick labels
        axs[2, j].set_title(f'reg={r}')

    plt.tight_layout()  # Adjust subplots to fit into the figure area.
    plt.show()

ConvBary(ellipses)

### Implement your own Entropic Wasserstein Barycenter algorithm
Following the suggestions given during the lecture, implement your own algorithm to compute a Wasserstein Barycenter.

Possible options:

1. Implement the Iterative Bregman Projection as explained during the class (but exploiting the 2D grid structure of the cost matrix)
2. Implement any of the iterative pairing heuristics explained in the slides
3. **IMPORTANT:** Design your own new algorithm, run extensive computational tests that provide evidence of its performance, and write a paper.
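
For option 1, the 2D grid structure can be exploited because, with a squared Euclidean cost, the Gibbs kernel factorizes across rows and columns. It can then be applied to an image with two small matrix products instead of one product with an N x N matrix. The sketch below checks this identity on a small grid. The row-major pixel index `i*n + j`, the unnormalized cost, and the function names are assumptions chosen for illustration.

```python
import numpy as np


def kernel_1d(n, reg):
    """1D Gibbs kernel exp(-(i - v)^2 / reg) on a grid of n points."""
    idx = np.arange(n, dtype=np.float64)
    return np.exp(-(idx[:, None] - idx[None, :]) ** 2 / reg)


def apply_kernel_separable(U, K1):
    """Apply the 2D Gibbs kernel to an n x n image U using only the 1D kernel."""
    return K1 @ U @ K1.T


n, reg = 6, 4.0
K1 = kernel_1d(n, reg)

# Dense N x N kernel for the cost (i - v)^2 + (j - w)^2 with pixel index i*n + j
K_full = np.kron(K1, K1)

rng = np.random.default_rng(0)
U = rng.random((n, n))

dense = (K_full @ U.ravel()).reshape(n, n)  # O(n^4) work
fast = apply_kernel_separable(U, K1)        # O(n^3) work

max_diff = np.abs(dense - fast).max()
print("max difference between dense and separable products:", max_diff)
```

For 60x60 images this avoids forming a 3600 x 3600 kernel altogether; every `K @ v` in a scaling iteration can be replaced by a call such as `apply_kernel_separable`.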

In [None]:
def WB(A, C, reg=0.05):
    # Prepare cost matrix
    l = A.shape[0]
    n = A.shape[1]
    N = n**2

    # TODO: complete with your code
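
As a starting point for filling in `WB`, here is a minimal sketch of Iterative Bregman Projections for 1D histograms with a dense kernel, following the scaling form described in Benamou et al. (2015). It is not the 2D solution asked for above: the function name `ibp_barycenter_1d`, the fixed iteration count, the test data, and the absence of log-domain stabilization are all simplifying assumptions.

```python
import numpy as np


def ibp_barycenter_1d(A, M, reg, weights, n_iter=2000):
    """Entropic barycenter of the columns of A (shape (n, k)) by Bregman projections.

    M is the (n, n) cost matrix and weights holds k nonnegative values summing
    to 1. Requires n_iter >= 1 and strictly positive histograms.
    """
    weights = np.asarray(weights, dtype=np.float64)
    K = np.exp(-M / reg)  # Gibbs kernel
    v = np.ones_like(A)   # one scaling vector per input histogram
    for _ in range(n_iter):
        # Projection 1: each plan diag(u_k) K diag(v_k) gets row marginal A[:, k]
        u = A / (K @ v)
        KTu = K.T @ u
        # Column marginals of the current plans
        col = v * KTu
        # Projection 2: the common marginal is their weighted geometric mean
        b = np.exp(np.log(col) @ weights)
        v = b[:, None] / KTu
    return b


# Two discrete Gaussians on 100 bins, as in the 1D examples
n = 100
x = np.arange(n, dtype=np.float64)
a1 = np.exp(-(x - 20) ** 2 / (2 * 5 ** 2))
a2 = np.exp(-(x - 60) ** 2 / (2 * 8 ** 2))
A = np.vstack((a1 / a1.sum(), a2 / a2.sum())).T

M = (x[:, None] - x[None, :]) ** 2
M /= M.max()

b = ibp_barycenter_1d(A, M, reg=5e-3, weights=[0.5, 0.5])
print("total mass:", b.sum(), "mean:", np.sum(x * b))
```

The total mass approaches 1 as the iterations converge, so printing it is a cheap convergence check. A 2D version would keep the same two projections and only change how `K @ v` and `K.T @ u` are evaluated.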

To save time, you can precompute the cost matrix for 60x60 images once, save it as a binary file, and load it when necessary.

Look at the following code snippet.

In [None]:
def GenerateCostMatrix(n, p=2):
    N = n**2
    C = np.zeros((N, N))
    for i in range(n):
        for j in range(n):
            for v in range(n):
                for w in range(n):
                    if p == 1:
                        C[i*n+j, v*n+w] = np.abs(i-v) + np.abs(j-w)
                    elif p == 2:
                        C[i*n+j, v*n+w] = (i-v)**2 + (j-w)**2
    return C

if False:
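    # NOTE: this writes 'n60p2s.npy' to the current directory, whereas the
    # else-branch loads '../data/n60p2.npy'; adjust the paths to match your setup.
    C = GenerateCostMatrix(60, p=2)
    np.save('n60p2s.npy', C)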
else:
    C = np.load('../data/n60p2.npy')
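
The four nested Python loops in `GenerateCostMatrix` perform n^4 assignments, which is slow for n = 60. Because the cost between pixels (i, j) and (v, w) is the sum of a row term and a column term, the same matrix can be assembled from a 1D cost with two Kronecker products. The sketch below is an assumed alternative, not the course's reference implementation: `generate_cost_matrix_kron` is a hypothetical name, it raises an error for unsupported `p` instead of returning zeros, and the result is checked against a direct loop on a small grid.

```python
import numpy as np


def generate_cost_matrix_kron(n, p=2):
    """Cost between pixels of an n x n grid, with pixel (i, j) at index i*n + j."""
    idx = np.arange(n, dtype=np.float64)
    d = np.abs(idx[:, None] - idx[None, :])
    if p == 1:
        D = d          # |i - v|
    elif p == 2:
        D = d ** 2     # (i - v)^2
    else:
        raise ValueError("only p=1 and p=2 are supported")
    ones = np.ones((n, n))
    # kron(D, ones)[i*n+j, v*n+w] = D[i, v]; kron(ones, D)[i*n+j, v*n+w] = D[j, w]
    return np.kron(D, ones) + np.kron(ones, D)


def generate_cost_matrix_loops(n, p=2):
    """Reference implementation with explicit loops, used only for checking."""
    C = np.zeros((n * n, n * n))
    for i in range(n):
        for j in range(n):
            for v in range(n):
                for w in range(n):
                    if p == 1:
                        C[i * n + j, v * n + w] = abs(i - v) + abs(j - w)
                    else:
                        C[i * n + j, v * n + w] = (i - v) ** 2 + (j - w) ** 2
    return C


n_small = 5
checks = {
    p: np.array_equal(
        generate_cost_matrix_kron(n_small, p), generate_cost_matrix_loops(n_small, p)
    )
    for p in (1, 2)
}
print(checks)
```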

In [None]:
C.shape