First, install the POT package for comparison. https://pythonot.github.io/

In [1]:
!pip install POT



Import packages.

In [2]:
import gc
import ot
import torch as th
import time

from mdot_tnt import solve_OT
from mdot_tnt.rounding import round_altschuler 
device = "cuda:0"

Add a function for sampling random OT problems.

In [3]:
def sample_random_problem(N, M, dim=100):
    # Sample some distributions r and c according to a Dirichlet distribution.
    r = th.distributions.Dirichlet(th.ones(N)).sample()
    c = th.distributions.Dirichlet(th.ones(M)).sample()

    # Sample N points x and M points y from a multivariate normal distribution in 100D.
    x = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((N,))
    y = th.distributions.MultivariateNormal(th.zeros(dim), th.eye(dim)).sample((M,))

    # Compute the cost matrix C = ||x - y||_2^2.
    C = th.cdist(x, y, p=2) ** 2
    C /= C.max()

    # Change to double precision.
    # r, c, C = r.double(), c.double(), C.double()
    r /= r.sum()
    c /= c.sum()

    return r, c, C

Sample a large problem (n=14000)

In [6]:
n = 14000
r, c, C = sample_random_problem(n, n)
r, c, C = r.double().to(device), c.double().to(device), C.double().to(device)

Let's use the highly efficient exact solver (CPU-based) of the POT library and time it.

In [7]:
r_, c_, C_ = r.cpu().numpy(), c.cpu().numpy(), C.cpu().numpy()
time_start = time.time()
cost_emd = ot.emd2(r_, c_, C_, numItermax=int(1e10))
elapsed = time.time() - time_start
print("OT Cost: {:.10f}, Time: {:.3f}".format(cost_emd, elapsed))

OT Cost: 0.3182945673, Time: 97.668


Now use mdot_tnt to tackle the problem on a GPU (NVIDIA RTX 2080 Ti in this case).

In [8]:
time_start = time.time()
cost = solve_OT(r, c, C, gamma_f=1000)  # gamma_f is the inverse of the final regularization weight (1e-3 here) 
elapsed = time.time() - time_start
print("MDOT-TNT error: {:.3e}, Time: {:.3f}".format((cost - cost_emd), elapsed))
gc.collect()
th.cuda.empty_cache()

MDOT-TNT error: 8.803e-05, Time: 4.435


4-5 decimal precision with more than 20x speedup. Needless to say, the speedup can be better on higher-end GPUs.
Let's also check the speedup using FP32 precision.

In [9]:
time_start = time.time()
cost = solve_OT(r.float(), c.float(), C.float(), gamma_f=1000)  # gamma_f is the inverse of the final regularization weight (1e-3 here) 
elapsed = time.time() - time_start
print("MDOT-TNT error: {:.3e}, Time: {:.3f}".format((cost - cost_emd), elapsed))
gc.collect()
th.cuda.empty_cache()

MDOT-TNT error: 8.821e-05, Time: 1.705


57x speedup on this random problem! Not bad!

If either marginal is known to have many tiny entries (is effectively a sparse vector), we can further accelerate computation by dropping those particles by setting `drop_tiny=True`. Note that this feature was not used in the paper for fairness in benchmarking, but can be useful in practice.

In [10]:
# Set a random half of the entries of r and c to 1e-20, and renormalize.
r2 = r.clone()
c2 = c.clone()
r2[th.randperm(n)[:n // 2]] = 1e-20
c2[th.randperm(n)[:n // 2]] = 1e-20
r2 /= r2.sum()
c2 /= c2.sum()

In [11]:
time_start = time.time()
cost_emd2 = ot.emd2(r2.cpu().numpy(), c2.cpu().numpy(), C.cpu().numpy(), numItermax=int(1e10))
elapsed = time.time() - time_start
print("OT Cost: {:.10f}, Time: {:.3f}".format(cost_emd, elapsed))

OT Cost: 0.3182945673, Time: 82.562


A similar runtime as before for the exact solver... Let's rerun MDOT-TNT with `drop_tiny=True`.

In [12]:
time_start = time.time()
cost = solve_OT(r2, c2, C, gamma_f=1000, drop_tiny=True)  # gamma_f is the inverse of the final regularization weight (1e-3 here) 
elapsed = time.time() - time_start
print("MDOT-TNT error: {:.3e}, Time: {:.3f}".format((cost - cost_emd2), elapsed))
gc.collect()
th.cuda.empty_cache()

Dropped 7028 entries from r and 7032 entries from c.
MDOT-TNT error: 8.172e-05, Time: 1.155


Same level of precision as before, but this time ~70x speedup. And now doing the same with FP32 precision.

In [13]:
time_start = time.time()
cost = solve_OT(r2.float(), c2.float(), C.float(), gamma_f=1000, drop_tiny=True)  # gamma_f is the inverse of the final regularization weight (1e-3 here) 
elapsed = time.time() - time_start
print("MDOT-TNT error: {:.3e}, Time: {:.3f}".format((cost - cost_emd2), elapsed))
gc.collect()
th.cuda.empty_cache()

Dropped 7028 entries from r and 7032 entries from c.
MDOT-TNT error: 8.187e-05, Time: 0.535


154x speedup. Let's go back to the original problem (dense marginals) and see how Sinkhorn fares; starting with strong regularization and gradually decreasing regularization weight.

In [14]:
gc.collect()
th.cuda.empty_cache()
time_start = time.time()
plan = ot.sinkhorn(r, c, C, reg=1/100, numItermax=int(1e10))
plan = round_altschuler(plan, r, c)
cost = (plan * C).sum()
elapsed = time.time() - time_start
print("Sinkhorn error: {:.3e}, Time: {:.3f}".format(cost - cost_emd, elapsed))
del plan

Sinkhorn error: 1.458e-02, Time: 0.511


Remember the optimal cost is about 0.318. Relative error here is about 0.0146 * 100 / 0.318 = 4.6% (hardly negligible). Let's run at the same temperature as MDOT-TNT.

In [15]:
time_start = time.time()
plan = ot.sinkhorn(r, c, C, reg=1/1000, numItermax=int(1e10))
plan = round_altschuler(plan, r, c)
cost = (plan * C).sum()
elapsed = time.time() - time_start
print("Sinkhorn error: {:.3e}, Time: {:.3f}".format(cost - cost_emd, elapsed))
del plan

Sinkhorn error: 7.315e-05, Time: 67.102


MDOT-TNT exhibits 15x speedup (took 4.435 seconds under the same setup of dense vectors + FP64 precision). As we show in the paper, the gap grows with weaker regularization.

Let's also give Greenkhorn by Altschuler et al. (2017) a try.

In [17]:
gc.collect()
th.cuda.empty_cache()
time_start = time.time()
plan = ot.bregman.greenkhorn(r, c, C, reg=1/1000, numItermax=int(1e10))
plan = round_altschuler(plan, r, c)
cost = (plan * C).sum()
elapsed = time.time() - time_start
print("Greenkhorn error: {:.3e}, Time: {:.3f}".format(cost - cost_emd, elapsed))
del plan

Greenkhorn error: 7.461e-05, Time: 2929.723


For this value of n=14000, Greenkhorn suffers from low GPU utilization. Even if the total number of row or column updates are fewer than those of Sinkhorn, in practice, it is substantially slower because of limited parallelization, updating one row/column at a time.