## Solving regularized Optimal Transport using Sinkhorn Iterations

This short Colab provides a lightweight interface for computing the
optimal transport map used in the debiasing step of this [paper](https://openreview.net/forum?id=5NxJuc0T1P).

We consider the simpler case of data obtained by solving the one-dimensional Kuramoto-Sivashinsky equation, and we provide a few simple statistical metrics for quick evaluation.



### Downloading dependencies

The only non-trivial dependency is ott-jax, a flexible toolbox for solving optimal transport problems in JAX.

In [None]:
!pip install ott-jax

In [None]:
import jax
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp

import h5py
import numpy as np
import matplotlib.pyplot as plt

import ott
from ott.tools import plot, sinkhorn_divergence
from ott.geometry import costs, pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

### Downloading the data from the Google Cloud bucket

The data was generated using [jax-cfd](https://github.com/google/jax-cfd), and uploaded to a Google Cloud bucket. The file contains both high- and low-resolution datasets.

We use gsutil to download the data. If you are running this notebook in Colab, it should already be installed; otherwise, follow these [instructions](https://cloud.google.com/storage/docs/gsutil_install).

In [None]:
!gsutil cp gs://gresearch/swirl_dynamics/downscaling/KS_finite_volumes_vs_pseudo_spectral.hdf5 .

In [None]:
file_name = 'KS_finite_volumes_vs_pseudo_spectral.hdf5'

# Open read-only: the notebook never writes to the file.
with h5py.File(file_name, 'r') as f1:
    # Trajectories computed with finite volumes (low fidelity).
    u_lr = f1['u_fd'][()]
    # Trajectories computed with pseudo-spectral methods (high fidelity).
    u_hr = f1['u_sp'][()]
    # Time stamps for the trajectories.
    t = f1['t'][()]
    # Grid on which the trajectories are computed: 512 equispaced points with
    # periodic boundary conditions.
    x = f1['x'][()]

In [None]:
# Plotting the low-res data.
# We choose which trajectory we want to plot.
plot_idx = 1
# Spatial downsampling factor.
ds_x = 4

# Define domain in time and space.
t_ = t
x_ = jnp.concatenate([x, jnp.array(x[-1] + x[1] - x[0]).reshape((-1,))])[::ds_x]
print(f"Shape of the spatial domain: {x_.shape}")

# Plots the low-resolution data.
fig = plt.figure(figsize=(14, 4))
plt.imshow(u_lr[plot_idx, :, :].T)
plt.xlabel("time")
plt.ylabel("x")
plt.show()

# Plot the first stored time step of the same trajectory.
fig = plt.figure()
plt.plot(x_, u_lr[plot_idx, 0, :])
plt.xlabel("x")
plt.ylabel("u")
plt.show()

In [None]:
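# Define the low-resolution, high-fidelity data by simple spatial sub-sampling
# of the pseudo-spectral trajectories.
u_lr_hf = u_hr[:, :, ::ds_x]
# Grid matching the sub-sampled data (x_ is already sub-sampled by ds_x, so
# sub-sampling it again would not match u_lr_hf).
x_lr_hf = x[::ds_x]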
u_lr_lf = u_lr

print(f"Shape of the low-resolution high-fidelity data {u_lr_hf.shape}")
print(f"Shape of the low-resolution grid {x_lr_hf.shape}")
print(f"Shape of the low-resolution low-fidelity data {u_lr_lf.shape}")

In [None]:
# Plot the marginal histograms at one spatial location, pooled over all
# trajectories and times.
spatial_idx_x = 1

plt.figure(figsize=(9, 6))
plt.hist(u_lr[:, :, spatial_idx_x].flatten(), 
         bins=100,
         alpha=0.5,
         density=True,
         label='Finite Volumes')
plt.hist(u_lr_hf[:, :, spatial_idx_x].flatten(),
         bins=100,
         alpha=0.5,
         density=True,
         label='Pseudo Spectral')
plt.legend()
plt.title("Histograms for the high- and low-fidelity solutions")
plt.show()
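The histograms above can also be summarized by a single number. A minimal sketch, assuming a plain NumPy helper (`wasserstein_1d` is a hypothetical name, not part of the notebook or of ott-jax), computes the 1-Wasserstein distance between the two marginals through the quantile formula $W_1 = \int_0^1 |F_p^{-1}(t) - F_q^{-1}(t)|\,dt$, evaluated on a grid of quantile levels so that the two sample sets may have different sizes:

```python
import numpy as np


def wasserstein_1d(samples_p: np.ndarray, samples_q: np.ndarray,
                   n_quantiles: int = 1000) -> float:
    """Approximates W1 between two 1D empirical distributions.

    Uses W1 = int_0^1 |F_p^{-1}(t) - F_q^{-1}(t)| dt with the midpoint rule on
    n_quantiles quantile levels; inputs are flattened, so any shape works.
    """
    levels = (np.arange(n_quantiles) + 0.5) / n_quantiles
    quantiles_p = np.quantile(np.ravel(samples_p), levels)
    quantiles_q = np.quantile(np.ravel(samples_q), levels)
    return float(np.mean(np.abs(quantiles_p - quantiles_q)))
```

For example, `wasserstein_1d(u_lr_lf[:, :, spatial_idx_x], u_lr_hf[:, :, spatial_idx_x])` quantifies the gap between the two histograms. If SciPy is available, `scipy.stats.wasserstein_distance` computes the exact value for empirical samples and can be used instead.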

In [None]:
# Further downsample in time and space.
time_subsample  = 1
space_subsample = 1

x_src = u_lr_lf[:,::time_subsample,::space_subsample]
x_trgt = u_lr_hf[:,::time_subsample,::space_subsample]

# Flatten the trajectory and time dimensions into a single sample dimension
# (this discards the time-step and trajectory information).
x_src = x_src.reshape((-1, x_src.shape[-1]))
x_trgt = x_trgt.reshape((-1, x_trgt.shape[-1]))

print(f'Total data set size source: {x_src.shape} target: {x_trgt.shape}')

# Define training and test data split.
train_split = 0.9
test_split  = 0.1

# Define sample sizes.
n_train = int(np.floor(x_src.shape[0]*train_split))
n_eval = int(np.floor(x_src.shape[0]*test_split))

# Divide samples.
x_src_train = x_src[:n_train,:]
x_trgt_train = x_trgt[:n_train,:]

x_src_valid = x_src[n_train:,:]
x_trgt_valid = x_trgt[n_train:,:]

print('Training data set size')
print(f"Shape of the source training data: {x_src_train.shape}")
print(f"Shape of the target training data: {x_trgt_train.shape}")

print('Validation data set size')
print(f"Shape of the source validation data: {x_src_valid.shape}")
print(f"Shape of the target validation data: {x_trgt_valid.shape}")

del u_hr
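Because the reshape keeps rows in trajectory-major order, the contiguous split above places the last trajectories (and possibly the tail of one trajectory) in the validation set. The toy sketch below, using a hypothetical helper `flatten_and_split` on a synthetic array rather than the KS data, makes this explicit and also shows an optional shuffled variant with a fixed seed, which is not part of the original notebook:

```python
from typing import Optional, Tuple

import numpy as np


def flatten_and_split(u: np.ndarray, train_frac: float = 0.9,
                      seed: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Flattens (trajectory, time, space) data to (samples, space) and splits it.

    seed=None mimics the contiguous split used in the notebook; an integer seed
    shuffles the rows first (a hypothetical variant, not used in the notebook).
    """
    samples = u.reshape((-1, u.shape[-1]))  # rows ordered trajectory-major
    if seed is not None:
        perm = np.random.default_rng(seed).permutation(samples.shape[0])
        samples = samples[perm]
    n_train = int(np.floor(samples.shape[0] * train_frac))
    return samples[:n_train], samples[n_train:]


# Toy data: 10 trajectories, 4 time steps, 3 grid points.
u_toy = np.arange(10 * 4 * 3, dtype=float).reshape(10, 4, 3)
train, valid = flatten_and_split(u_toy)
print(train.shape, valid.shape)          # (36, 3) (4, 3)
print(np.array_equal(valid, u_toy[9]))   # True: validation = last trajectory
```

The OT problem does not use the row pairing between source and target, so either split is valid. A shuffled split mixes time steps from all trajectories into the validation set, while the contiguous one holds out (mostly) whole trajectories, which is arguably the stricter test of generalization.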

In [None]:
# Compute the Sinkhorn divergence between two empirical distributions.
@jax.jit
def sinkhorn_loss(x: jax.Array, y: jax.Array, epsilon: float = 0.1) -> jax.Array:
    """Computes the Sinkhorn divergence between point clouds x and y.

    Both point clouds are treated as uniform empirical measures with weights a
    and b, using the squared Euclidean cost (the PointCloud default).
    """
    # We assume equal weights for all points.
    a = jnp.ones(len(x)) / len(x)
    b = jnp.ones(len(y)) / len(y)

    sdiv = sinkhorn_divergence.sinkhorn_divergence(
        pointcloud.PointCloud, x, y, epsilon=epsilon, a=a, b=b
    )

    # The first entry holds the divergence value.
    return sdiv[0]
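For reference, under the standard definition the quantity returned by `sinkhorn_loss` is the debiased Sinkhorn divergence between the uniform empirical measures $\alpha = \frac{1}{n}\sum_i \delta_{x_i}$ and $\beta = \frac{1}{m}\sum_j \delta_{y_j}$. We have not checked every convention of ott-jax's implementation, so treat the exact normalization as an assumption:

$$
\mathrm{OT}_\varepsilon(\alpha, \beta) = \min_{P \in U(a, b)} \sum_{i,j} P_{ij}\, c(x_i, y_j) + \varepsilon\, \mathrm{KL}\!\left(P \,\|\, a b^\top\right),
\qquad c(x, y) = \|x - y\|_2^2,
$$

$$
S_\varepsilon(\alpha, \beta) = \mathrm{OT}_\varepsilon(\alpha, \beta) - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(\alpha, \alpha) - \tfrac{1}{2}\,\mathrm{OT}_\varepsilon(\beta, \beta),
$$

where $U(a, b) = \{P \ge 0 : P\mathbf{1} = a,\ P^\top \mathbf{1} = b\}$ with $a_i = 1/n$ and $b_j = 1/m$. The two self-transport terms remove the entropic bias, so $S_\varepsilon(\alpha, \alpha) = 0$ even though $\mathrm{OT}_\varepsilon(\alpha, \alpha) > 0$ whenever the points of $\alpha$ are not all identical. Replacing the KL term with a plain negative-entropy term shifts each $\mathrm{OT}_\varepsilon$ by a constant depending only on the marginals, and these constants cancel in $S_\varepsilon$.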


## Sinkhorn Iteration

Here we instantiate the solver using ott-jax. Since the cost of each Sinkhorn iteration grows quadratically with the number of data points, we use a smaller dataset (of adjustable size `n_max`) so that the computation is relatively fast. Obtaining a transport map with better metrics would require a larger `n_max`.
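To make the quadratic scaling concrete, here is a back-of-the-envelope sketch (the helper `dense_matrix_bytes` is ours, not an ott-jax function) of the memory needed for a dense $n \times n$ cost matrix in float64, the precision enabled by `jax_enable_x64` at the top of the notebook. Whether ott-jax actually materializes this matrix depends on how the `PointCloud` geometry is configured (for instance its `batch_size` option for online cost evaluation, in recent versions); even without materializing it, each iteration still does work proportional to $n \cdot m$.

```python
def dense_matrix_bytes(n: int, m: int, bytes_per_entry: int = 8) -> int:
    """Bytes needed to store a dense n x m matrix (8 bytes per float64 entry)."""
    return n * m * bytes_per_entry


# n_max as chosen in the cell below, plus a few other sizes for comparison.
n_max = 20_040
for n in (5_000, 10_000, n_max, 2 * n_max):
    gigabytes = dense_matrix_bytes(n, n) / 1e9
    print(f"n = {n:>6,}: {gigabytes:6.2f} GB for one dense cost matrix")
```

Doubling $n$ quadruples both the memory and the per-iteration work: about 3.2 GB at `n_max = 20_040` versus about 12.9 GB at twice that size.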

In [None]:
# Maximum number of points to be used for the transport.
# (this takes around 30 seconds to compute with an A100)
n_max = 20_040

momentum = ott.solvers.linear.acceleration.Momentum(value=.5)

# Defining the geometry.
geom = pointcloud.PointCloud(x_src_train[:n_max],
                             x_trgt_train[:n_max],
                             epsilon=0.001)

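# Solve the entropic OT problem and extract the dual potentials.
out = sinkhorn.Sinkhorn(max_iterations=1000,
                        momentum=momentum,
                        parallel_dual_updates=True)(
                            linear_problem.LinearProblem(geom))
# With max_iterations=1000 the solver may stop before reaching its tolerance.
print(f"Sinkhorn converged: {bool(out.converged)}")
dual_potentials = out.to_dual_potentials()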
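What the solver and `to_dual_potentials()` compute can be illustrated with a small NumPy sketch of log-domain Sinkhorn for uniform weights and the squared Euclidean cost, followed by the entropic map. The functions `sinkhorn_potentials` and `entropic_map` are hypothetical and are not ott-jax's implementation: the sketch uses plain alternating updates, whereas the cell above uses parallel dual updates with momentum. For the squared Euclidean cost, $T(x) = x - \nabla f(x)/2$ reduces to the barycentric average $T(x) = \sum_j p_j(x)\, y_j$ with $p_j(x) \propto b_j \exp\big((g_j - \|x - y_j\|^2)/\varepsilon\big)$; we believe this is what `dual_potentials.transport` evaluates for this cost, but treat that as an assumption.

```python
import numpy as np


def _logsumexp(z: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log(sum(exp(z))) along `axis`."""
    z_max = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(z_max, axis=axis) + np.log(np.sum(np.exp(z - z_max), axis=axis))


def sinkhorn_potentials(x: np.ndarray, y: np.ndarray, epsilon: float,
                        n_iter: int = 500):
    """Log-domain Sinkhorn for uniform weights and c(x, y) = ||x - y||^2.

    Returns the dual potentials (f, g) and the transport plan P.
    """
    log_a = np.log(np.full(x.shape[0], 1.0 / x.shape[0]))
    log_b = np.log(np.full(y.shape[0], 1.0 / y.shape[0]))
    # The (n, m) cost matrix: the object whose size makes Sinkhorn quadratic.
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    f = np.zeros(x.shape[0])
    g = np.zeros(y.shape[0])
    for _ in range(n_iter):
        # Each update enforces one marginal constraint exactly.
        f = -epsilon * _logsumexp(log_b[None, :] + (g[None, :] - cost) / epsilon, axis=1)
        g = -epsilon * _logsumexp(log_a[:, None] + (f[:, None] - cost) / epsilon, axis=0)
    plan = np.exp(log_a[:, None] + log_b[None, :]
                  + (f[:, None] + g[None, :] - cost) / epsilon)
    return f, g, plan


def entropic_map(x_new: np.ndarray, y: np.ndarray, g: np.ndarray,
                 epsilon: float) -> np.ndarray:
    """Barycentric entropic map T(x) = sum_j p_j(x) y_j (softmax weights)."""
    log_b = np.log(np.full(y.shape[0], 1.0 / y.shape[0]))
    cost = np.sum((x_new[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    logits = log_b[None, :] + (g[None, :] - cost) / epsilon
    weights = np.exp(logits - _logsumexp(logits, axis=1)[:, None])
    return weights @ y


# Usage: a rigid shift should be recovered when epsilon is small.
x_demo = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [2.0, 2.0]])
y_demo = x_demo + np.array([0.1, -0.05])
_, g_demo, _ = sinkhorn_potentials(x_demo, y_demo, epsilon=0.01)
print(np.round(entropic_map(x_demo, y_demo, g_demo, epsilon=0.01), 3))  # close to y_demo
```

Because $T$ is defined for any input point, not only the ones used to fit the potentials, it can be applied to the validation set, which is exactly how the cells below use `dual_potentials.transport`.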

### Computing the Sinkhorn divergence

In [None]:
# Compute the Sinkhorn divergence before transport.
sinkhorn_div = sinkhorn_loss(x_src_train[:n_max],
                             x_trgt_train[:n_max],
                             epsilon=0.001)
print(f"Sinkhorn divergence between source and target data: {sinkhorn_div:.3f}")

# Compute the Sinkhorn divergence after transport.
tx_src_train = dual_potentials.transport(x_src_train[:n_max])
sinkhorn_div = sinkhorn_loss(tx_src_train,
                             x_trgt_train[:n_max],
                             epsilon=0.001)
print(f"Sinkhorn divergence between transported source and target data: {sinkhorn_div:.3f}")

# Compute the same divergence on held-out validation data.
tx_src_valid = dual_potentials.transport(x_src_valid[:n_max])
sinkhorn_div = sinkhorn_loss(tx_src_valid,
                             x_trgt_valid[:n_max],
                             epsilon=0.001)
print(f"Sinkhorn divergence (validation) between transported source and target data: {sinkhorn_div:.3f}")

### Comparing distributions

In [None]:
# Transports the validation set.
tx_src_valid = np.array(dual_potentials.transport(x_src_valid))

In [None]:
# We fix one point in space to plot the histograms.
idx_x = 2

plt.figure()
plt.hist(x_src_valid[:, idx_x], bins=50, density=True,
         alpha=0.5, label='Finite Volumes')
plt.hist(x_trgt_valid[:, idx_x], bins=50, density=True,
         alpha=0.5, label='Pseudo Spectral')
plt.hist(tx_src_valid[:, idx_x], bins=50, density=True,
         alpha=0.5, label='Finite Volumes Debiased')
plt.legend()
plt.show()
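The histograms above show a single spatial location. As a complementary summary over all locations, a minimal sketch (the helper `moment_errors` is hypothetical, not part of the notebook) compares the per-location mean and standard deviation of a sample set against reference samples; applied before and after transport, it gives a quick check of whether debiasing moved the first two moments toward the pseudo-spectral statistics:

```python
import numpy as np


def moment_errors(samples: np.ndarray, reference: np.ndarray) -> dict:
    """Average absolute error of the per-location mean and standard deviation.

    Both inputs have shape (n_samples, n_x); the sample counts may differ.
    Statistics are computed per spatial location, then averaged over locations.
    """
    mean_err = np.mean(np.abs(samples.mean(axis=0) - reference.mean(axis=0)))
    std_err = np.mean(np.abs(samples.std(axis=0) - reference.std(axis=0)))
    return {"mean": float(mean_err), "std": float(std_err)}
```

For example, compare `moment_errors(x_src_valid, x_trgt_valid)` with `moment_errors(tx_src_valid, x_trgt_valid)`. Matching moments is necessary but not sufficient for matching distributions, so this complements rather than replaces the Sinkhorn divergence.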