In [1]:
"""
Programmer: Chris Tralie
Purpose: To use the POT library (https://github.com/rflamary/POT)
to compute the Entropic regularized Wasserstein distance 
between points on a 2D grid
"""
import numpy as np
import matplotlib.pyplot as plt
import ot

  from .autonotebook import tqdm as notebook_tqdm


In [21]:
## Step 1: Setup problem
pix = np.linspace(-1, 1, 17)
# Setup grid
X, Y = np.meshgrid(pix, pix)
# Compute pariwise distances between points on 2D grid so we know
# how to score the Wasserstein distance
coords = np.array([X.flatten(), Y.flatten()]).T
coordsSqr = np.sum(coords**2, 1)
M = coordsSqr[:, None] + coordsSqr[None, :] - 2*coords.dot(coords.T)
M[M < 0] = 0
M = np.sqrt(M)

In [107]:
n = 3

## Step 1: Setup problem
pix = np.linspace(-1, 1, n)
# Setup grid
X, Y = np.meshgrid(pix, pix)
# Compute pariwise distances between points on 2D grid so we know
# how to score the Wasserstein distance
coords = np.array([X.flatten(), Y.flatten()]).T
coordsSqr = np.sum(coords**2, 1)
M = coordsSqr[:, None] + coordsSqr[None, :] - 2*coords.dot(coords.T)
M[M < 0] = 0
M = np.sqrt(M)

A = np.random.rand(n,n)
B = np.random.rand(n,n)

wass = ot.emd2(A.flatten(), B.flatten(), M, 1.0)
print(wass)

wass = ot.emd2(A.flatten(), A.flatten(), M, 1.0)
print(wass)

AssertionError: 
Arrays are not almost equal to 6 decimals
a and b vector must have the same sum
Mismatched elements: 1 / 1 (100%)
Max absolute difference: 0.99403697
Max relative difference: 0.23798587
 x: array(5.170911)
 y: array([4.176874])

# GROMOV-WASSERSTIN

In [112]:
# Way more metricy

n = 17

A = np.random.rand(n,n)
B = np.random.rand(n,n)

wass = ot.gromov_wasserstein2(A, B)
print(wass)

wass = ot.gromov_wasserstein2(B, A)
print(wass)

wass = ot.gromov_wasserstein2(A*5, B)
print(wass)

wass = ot.gromov_wasserstein2(A, A)
print(wass)

0.09879834613051615
0.09879834613051632
5.945480213377253
2.0816681711721685e-17
