### Create Wasserstein Metric

In [46]:
import numpy as np
from tqdm import tqdm
from scipy.optimize import linprog
from scipy.stats import wasserstein_distance
from ot.sliced import sliced_wasserstein_distance
import ot
import time

In [66]:
# Example: 6 Koopman eigenvalues per segment
eigs1 = np.array([0.5 - 0.4j, 0.99 + 0.01j, 0.2 + 0.3j, 1.1 - 0.2j, 0.6 + 0.7j])
eigs2 = np.array([0.9 + 1.1j, 0.4 - 0.3j, 0.95 + 0.05j, 0.3 + 0.35j, 1.2 - 0.1j, 0.55 + 0.6j])

# Weights = normalized mode norms
weights1 = np.array([0.3, 1.2, 0.4, 0.6, 0.8])
weights2 = np.array([1.0, 0.4, 1.6, 0.3, 0.7, 0.5])
weights1 /= weights1.sum()
weights2 /= weights2.sum()

In [67]:
l1 = np.stack([eigs1.real, eigs1.imag], axis=1)
l2 = np.stack([eigs2.real, eigs2.imag], axis=1)


start = time.time()
for i in range(10):
    C = ot.dist(l1, l2, metric='euclidean')
    emd2_dist = ot.emd2(weights1, weights2, C)
emd2_time = time.time() - start

In [61]:
start = time.time()
for i in range(1000):
    sinkhorn_dist = ot.sinkhorn2(weights1, weights2, C, reg=1e-2)  # sinkhorn2 returns (dist, log)
sinkhorn_time = time.time() - start

In [62]:
start = time.time()
for i in range(1000):
    scipy_dist_real = wasserstein_distance(u_values=eigs1.real, v_values=eigs2.real,
                                   u_weights=weights1, v_weights=weights2)
scipy_time = time.time() - start

In [65]:
emd2_time, sinkhorn_time, scipy_time

(0.19497990608215332, 17.353342533111572, 0.0720674991607666)

In [59]:
emd2_dist, sinkhorn_dist, np.sqrt(scipy_dist_real**2 + scipy_dist_imag**2)

(0.2560964183153938,
 np.float64(0.25613216287205276),
 np.float64(0.2311722091380748))

In [82]:
l1 = np.array([[ 1.37650088e-04,  0.00000000e+00],
                 [ 4.75407495e-01,  4.14592788e-01],
                 [ 4.75407495e-01, -4.14592788e-01],
                 [ 9.98255444e-01,  5.89792744e-02],
                 [ 9.98255444e-01, -5.89792744e-02],
                 [ 9.99179206e-01,  0.00000000e+00],
                 [ 8.56796503e-01,  3.37291220e-01],
                 [ 8.56796503e-01, -3.37291220e-01]])
l2 = np.array([[-0.32054115,  0.],
                 [ 0.29132765,  0.64505518],
                 [ 0.29132765, -0.64505518],
                 [ 0.99299793,  0.11811423],
                 [ 0.99299793, -0.11811423],
                 [ 0.72857638,  0.60422197],
                 [ 0.72857638, -0.60422197],
                 [ 0.99790184,  0.        ]])

m1 = np.array([2.25563911e-05, 1.03366498e-01, 1.03366498e-01, 1.63866996e-01,
1.63866996e-01, 1.63733111e-01, 1.50888672e-01, 1.50888672e-01])
m2 = np.array([0.04836847, 0.10680301, 0.10680301, 0.15089596, 0.15089596, 0.14282695,
 0.14282695, 0.15057967])

C = np.array([[3.20678799e-01, 7.07734273e-01, 7.07734273e-01, 9.99861245e-01,
  9.99861245e-01, 9.46418077e-01, 9.46418077e-01, 9.97764190e-01,],
 [8.97452742e-01, 2.94954752e-01, 1.07551811e+00, 5.96489224e-01,
  7.42749370e-01, 3.16312681e-01, 1.04979903e+00, 6.66998891e-01,],
 [8.97452742e-01, 1.07551811e+00, 2.94954752e-01, 7.42749370e-01,
  5.96489224e-01, 1.04979903e+00, 3.16312681e-01, 6.66998891e-01,],
 [1.32011477e+00, 9.18276582e-01, 9.97703076e-01, 5.93682111e-02,
  1.77171530e-01, 6.08289731e-01, 7.15934837e-01, 5.89803344e-02,],
 [1.32011477e+00, 9.97703076e-01, 9.18276582e-01, 1.77171530e-01,
  5.93682111e-02, 7.15934837e-01, 6.08289731e-01, 5.89803344e-02,],
 [1.31972035e+00, 9.57679497e-01, 9.57679497e-01, 1.18275862e-01,
  1.18275862e-01, 6.62049908e-01, 6.62049908e-01, 1.27736656e-03,],
 [1.22469968e+00, 6.43796306e-01, 1.13347231e+00, 2.58049185e-01,
  4.75336674e-01, 2.96129070e-01, 9.50203921e-01, 3.65617400e-01,],
 [1.22469968e+00, 1.13347231e+00, 6.43796306e-01, 4.75336674e-01,
  2.58049185e-01, 9.50203921e-01, 2.96129070e-01, 3.65617400e-01,]])
# C = ot.dist(l1, l2, metric='euclidean')
dist1 = ot.emd2(m1, m2, C)  # returns scalar distance

In [87]:
C = np.ascontiguousarray(C, dtype=np.float64)

C shape: (8, 8), dtype: float64, contiguous: True
m1 shape: (8,), dtype: float64, contiguous: True
m2 shape: (8,), dtype: float64, contiguous: True


In [40]:
l1 = np.array([[1,2,3,4],[1,2,3,4]])
l2 = np.array([[1,2,2.7,4.1],[1,2,3,4]])
m1 = np.array([0.7,0.5,0.7,1.0])
m2 = np.array([0.5,0.4,0.7,1.1])
m1 /= m1.sum()
m2 /= m2.sum()

vals = wasserstein_distance(u_values=l1, v_values=l2, u_weights=m1, v_weights=m2)
C = ot.dist(l1, l2, metric='euclidean')
dist1 = ot.emd2(m1, m2, C)  # returns scalar distance
dist2 = ot.sinkhorn2(m1, m2, C, reg=1e-2)  # reg = entropy smoothing param

ValueError: Value and weight array-likes for the same empirical distribution must be of the same size.

In [29]:
l1 = np.array([1,2,3,4]) + 1j * np.array([1.1,1.9,2.3,3.8])
l2 = np.array([1,2,2.7,4.1]) + 1j * np.array([1.01,1.8,2.6,4.2])
m1 = np.array([0.7,0.5,0.7,1.0])
m2 = np.array([0.5,0.4,0.7,1.1])
m1 /= m1.sum()
m2 /= m2.sum()

l1 = np.stack([l1.real, l1.imag], axis=1)
l2 = np.stack([l2.real, l2.imag], axis=1)

wasserstein_distance(u_values=l1, v_values=l2, u_weights=m1, v_weights=m2)

# Cost matrix between supports
C = ot.dist(l1, l2, metric='euclidean')

# Compute Wasserstein-1 distance
dist = ot.emd2(m1, m2, C)  # returns scalar distance

ValueError: object too deep for desired array

In [28]:
dist

0.5295319623010462

In [6]:
def compute_wasserstein_metric(l1, l2, m1, m2, q):
    
    # --- Step 1: Sample Inputs ---
    n = len(l1)
    n_bar = len(l1)

    m1 = m1 / np.sum(m1)
    m2 = m2 / np.sum(m2)
    
    assert np.isclose(np.sum(m1), 1.0)
    assert np.isclose(np.sum(m2), 1.0)
    
    # --- Step 2: Compute cost matrix C (shape n x n_bar) ---
    C = np.linalg.norm(l1[:, None, :] - l2[None, :, :], axis=2) ** 2
    # Flatten to 1D for linprog
    c = C.flatten()  # size (n * n_bar,)
    
    # --- Step 3: Equality Constraints ---
    
    # Total number of variables
    N = n * n_bar
    
    # 1. Row sum constraints: each row i must sum to m[i]
    A_eq_rows = np.zeros((n, N))
    for i in range(n):
        for j in range(n_bar):
            A_eq_rows[i, i * n_bar + j] = 1
    b_eq_rows = m1
    
    # 2. Column sum constraints: each column j must sum to m_bar[j]
    A_eq_cols = np.zeros((n_bar, N))
    for j in range(n_bar):
        for i in range(n):
            A_eq_cols[j, i * n_bar + j] = 1
    b_eq_cols = m2
    
    # 3. Total mass constraint
    A_eq_total = np.ones((1, N))
    b_eq_total = np.array([1.0])
    
    # Combine constraints
    A_eq = np.vstack([A_eq_rows, A_eq_cols, A_eq_total])
    b_eq = np.concatenate([b_eq_rows, b_eq_cols, b_eq_total])
    
    # --- Step 4: Bounds ---
    bounds = [(0, None)] * N  # rho_ij >= 0
    
    # --- Step 5: Solve ---
    result = linprog(c, A_eq=A_eq, b_eq=b_eq, bounds=bounds, method='highs')
    
    # --- Step 6: Extract solution ---
    if result.success:
        rho_star = result.x.reshape((n, n_bar))
        transport_cost = np.sum(rho_star * C)
        wasserstein_metric = transport_cost ** (1 / q)
    else:
        print("Optimization failed:", result.message)

    return wasserstein_metric

In [7]:
dim = 10
l1 = np.random.random(dim) + 1j * np.random.random(dim)
l2 = np.random.random(dim) + 1j * np.random.random(dim)
m1 = np.random.random(dim) + 1j * np.random.random(dim)
m2 = np.random.random(dim) + 1j * np.random.random(dim)
q = 1

In [8]:
for i in tqdm(range(100)):
    manual_metric = compute_wasserstein_metric(l1[:,np.newaxis],l2[:,np.newaxis],m1,m2,q)

  b = np.array(b, dtype=float, copy=True).squeeze()
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 119.44it/s]


In [9]:
for i in tqdm(range(100)):
    scipy_metric = wasserstein_distance(u_values=l1, v_values=l2, u_weights=m1, v_weights=m2)

  values = np.asarray(values, dtype=float)
  weights = np.asarray(weights, dtype=float)
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:00<00:00, 2015.44it/s]


In [4]:
def compute_swd(l1, l2, m1, m2, n_projections=50):
    return sliced_wasserstein_distance(l1, l2, n_projections=n_projections, p=1, a=m1, b=m2)

In [5]:
for i in tqdm(range(1000)):
    compute_swd(l1[:,np.newaxis], l2[:,np.newaxis], m1, m2)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [00:01<00:00, 644.44it/s]
