In [1]:
import torch
import torch.distributions as dist

# Uniform

In [2]:
lx = ly = 3

torch.manual_seed(0)
U_x = U_y = dist.Uniform(torch.zeros(lx), torch.ones(lx))

n = 10

X = U_x.sample(
    (
        lx,
        n,
    )
)
Y = U_y.sample(
    (
        ly,
        n,
    )
)

# Numpy

In [3]:
N_1 = dist.MultivariateNormal(
    loc=torch.zeros(lx),
    covariance_matrix=torch.eye(lx),
)

In [4]:
assert ly == lx == 3

c2 = torch.eye(3)
c2[-1, -1] = (1 / 10) ** 2

c3 = torch.eye(3)
c3[1, 1] = (1 / 2) ** 2
c3[-1, -1] = (1 / 10) ** 2

N_2 = dist.MultivariateNormal(
    loc=torch.zeros(ly),
    covariance_matrix=c2,
)

N_3 = dist.MultivariateNormal(
    loc=torch.zeros(ly),
    covariance_matrix=c3,
)

# global gw

In [5]:
from cryo_challenge.map_to_map.gromov_wasserstein.global_gw import gw_global

## 6 1d points

In [6]:
import numpy as np

n_points = 6
X = np.random.rand(n_points, 1)
np.random.seed(0)
noise = 0.01 * np.random.randn(n_points)
Y = X + noise

In [7]:
transport_plan_normalized, logs = gw_global(
    X, Y, IterMax=20, log=True, do_optimize_with_sparse=False, verbose=True
)

do_optimize_with_sparse False
Iter |Bound gap
----------------------
    0|5.438691e+00
    1|2.081324e+00
    2|1.932792e+00
    3|4.704649e-03
    4|1.897541e-03
    5|0.000000e+00


In [8]:
transport_plan_normalized * n_points

array([[1., 0., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0.],
       [0., 0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 0., 1.]])

## N3

In [19]:
n_points = 300

torch.manual_seed(0)
X = N_3.sample((n_points,))
Y = N_3.sample((n_points,))

# noise_level = 0.01

# noise = 0.01 * N_3.sample((n_points,))
# Y = X + noise

In [20]:
transport_plan_normalized = gw_global(
    X.numpy(),
    Y.numpy(),
    IterMax=200,
    log=False,
    do_optimize_with_sparse=True,
    verbose=True,
    epsilon=1e-2,
)

do_optimize_with_sparse True
Iter |Bound gap
----------------------
    0|2.272398e+05
    1|2.271654e+05
    2|2.268146e+05
    3|2.267639e+05
    4|2.264206e+05
    5|2.253235e+05
    6|2.250026e+05
    7|2.142872e+05
    8|7.026240e+04
    9|6.983631e+04
   10|6.442906e+04
   11|5.887060e+04
   12|5.052449e+04
   13|4.779783e+04
   14|4.492473e+04
   15|4.237846e+04
   16|3.283231e+04
   17|3.218363e+04
   18|2.755369e+04
   19|2.634179e+04
   20|2.011572e+04
   21|1.920572e+04
   22|1.836656e+04
   23|1.687362e+04
   24|1.679705e+04
   25|1.676905e+04
   26|1.670450e+04
   27|1.600120e+04
   28|1.535170e+04
   29|1.496941e+04
   30|1.411362e+04
   31|1.374033e+04
   32|1.334262e+04
   33|1.233532e+04
   34|1.227521e+04
   35|1.216696e+04
   36|1.048445e+04
   37|9.868544e+03
   38|9.534319e+03
   39|8.920476e+03
   40|8.668986e+03
   41|8.652804e+03
   42|8.569905e+03
   43|8.282840e+03
   44|7.894350e+03
   45|7.398217e+03
   46|7.014508e+03
   47|6.260786e+03
   48|4.136922e+03
 