In [1]:
import numpy as np

In [2]:
# Discrete measures: weights of the source distribution (a) and of the
# target distribution (b); each weight vector sums to 1.
a = np.array([0.5, 0.5])  # (n,)
b = np.array([0.4, 0.3, 0.3])  # (m,)

# Number of support points of each measure.
n, m = len(a), len(b)

In [3]:
# Transportation costs: entry [i, j] is the cost of moving one unit of mass
# from source point i to target point j.
M_xy = np.array(
    [
        [0, 1, 2],
        [1, 0, 1],
    ],
    dtype=float,
)  # (n, m)

In [4]:
# Hyperparameters
gamma = 1.0  # entropic regularization weight (temperature in K = exp(-M / gamma))

In [5]:
# Sinkhorn's algorithm: alternately rescale the rows and columns of the Gibbs
# kernel K = exp(-M / gamma) until P = diag(u) K diag(v) has marginals a and b.
SINKHORN_MAX_ITER = 1000  # hard cap on the number of (u, v) updates
SINKHORN_TOL = 1e-9  # stop once the L1 error of the row marginals is below this

# (n, m) Gibbs kernel. NOTE: exp underflows to 0 for very small gamma; a
# log-domain implementation would be needed in that regime.
K = np.exp(-M_xy / gamma)

# Deterministic start so that re-running the notebook gives identical results
# (u is overwritten on the first update anyway; only v's start matters).
u = np.ones(n)  # (n,)
v = np.ones(m)  # (m,)

for _ in range(SINKHORN_MAX_ITER):
    u = a / (K @ v)  # (n,) rescale rows towards marginal a
    v = b / (K.T @ u)  # (m,) rescale columns so they match marginal b exactly
    # After the v-update the column sums equal b by construction, so only the
    # row sums u * (K @ v) can still deviate from a.
    marginal_err = np.abs(u * (K @ v) - a).sum()
    if marginal_err < SINKHORN_TOL:
        break

# Broadcasting form of diag(u) @ K @ diag(v), without building dense diagonals.
P_gamma = u[:, None] * K * v[None, :]  # (n, m)
# Scalar transport cost <P_gamma, M_xy> of the entropic plan (an approximation
# of the exact Wasserstein cost that tightens as gamma -> 0).
W = np.sum(P_gamma * M_xy)

In [6]:
# Show the entropic plan and its transport cost. W = <P_gamma, M_xy> is the
# cost of the regularized plan, which approximates the Wasserstein distance.
print('*** Entropic regularized OT ***', P_gamma, sep='\n')
print('\n*** Wasserstein distance ***', W, sep='\n')

*** Entropic regularized OT ***
[[0.30973228 0.09513388 0.09513388]
 [0.09026772 0.20486612 0.20486612]]

*** Wasserstein distance ***
0.5805354719411625


In [7]:
# Cross-check P_gamma & W against POT (https://pythonot.github.io/).
# POT is imported here, so the NumPy-only cells above run without it installed.
import ot

# Use the same regularization as the hand-written implementation instead of a
# hard-coded value, so the comparison stays valid if gamma is changed.
P_gamma_ = ot.sinkhorn(a, b, M_xy, reg=gamma)
W_ = ot.sinkhorn2(a, b, M_xy, reg=gamma)

print('*** Entropic regularized OT (calculated by POT) ***')
print(P_gamma_)

print('\n*** Wasserstein distance (calculated by POT) ***')
print(W_)

# Both implementations are iterative, so they should agree up to a small
# tolerance; fail loudly instead of relying on eyeballing printed digits.
np.testing.assert_allclose(P_gamma_, P_gamma, atol=1e-5)
np.testing.assert_allclose(W_, W, atol=1e-5)

*** Entropic regularized OT (calculated by POT) ***
[[0.30973227 0.09513387 0.09513387]
 [0.09026773 0.20486613 0.20486613]]

*** Wasserstein distance (calculated by POT) ***
[0.58053546]
