
Numerical instability of Sinkhorn (even with lse_mode=True) & weird behavior with lse_mode=False #6

Closed
theouscidda6 opened this issue May 6, 2021 · 3 comments

Comments

@theouscidda6

theouscidda6 commented May 6, 2021

Generate two empirical measures:

  • support points drawn i.i.d. from U([0,1]^5)
  • weights drawn i.i.d. from U([0,1]) and then normalized
import jax
import jax.numpy as jnp

from ott.core import sinkhorn
from ott.geometry import pointcloud

rng = jax.random.PRNGKey(1)
keys = jax.random.split(rng, 2)

# parameters of the measures
dim = 5
n = 100
m = 150

# support points of the measures
x = jax.random.uniform(keys[0], (n, dim))
y = jax.random.uniform(keys[1], (m, dim))

# weights of the measures
a = jax.random.uniform(keys[0], (len(x),))
b = jax.random.uniform(keys[1], (len(y),))
a = a.ravel() / jnp.sum(a)
b = b.ravel() / jnp.sum(b)

Compute the regularized Wasserstein distance for decreasing epsilons with lse_mode=True:

# regularization strength candidates 
eps_cand = [10**(-i) for i in range(10)]

for eps in eps_cand:
    
    # define the geometry
    geom = pointcloud.PointCloud(x, y, epsilon=eps)

    # run the Sinkhorn algorithm
    out = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=True)  # lse_mode=True is the default; set explicitly here for emphasis

    print(f'epsilon = {eps}: regularised optimal transport cost = {out.reg_ot_cost}')
    
    if jnp.isnan(out.reg_ot_cost):
        break

Compute the regularized Wasserstein distance for decreasing epsilons with lse_mode=False:

# regularization strength candidates 
eps_cand = [10**(-i) for i in range(10)]

for eps in eps_cand:
    
    # define the geometry
    geom = pointcloud.PointCloud(x, y, epsilon=eps)

    # run the Sinkhorn algorithm
    out = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=False)

    print(f'epsilon = {eps}: regularised optimal transport cost = {out.reg_ot_cost}')
    
    if jnp.isnan(out.reg_ot_cost):
        break

Comments:

With the log-sum-exp mode (lse_mode=True), I get an overflow (i.e. NaN) from epsilon = 1e-3 onwards. This is quite strange: since the support points are drawn in U([0,1]^5) and the cost is the squared Euclidean distance, the maximum cost between a support point of the first measure and one of the second is at most 5. So ||C||_inf / eps ≈ 5 × 10^3 at that epsilon, which is reasonable, and the logsumexp should not generate any overflow.
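
A minimal sketch of what I would expect (plain JAX in float32, not OTT internals): even at a ratio ||C||_inf / eps of a few thousand, a log-domain evaluation stays finite, whereas exponentiating first underflows:

import jax.numpy as jnp
from jax.scipy.special import logsumexp

eps = 1e-3
c = jnp.array([0.5, 1.0, 5.0])               # one row of costs, bounded by ||C||_inf <= 5
naive = jnp.log(jnp.sum(jnp.exp(-c / eps)))  # exp(-500) underflows to 0 in float32, so this is -inf
stable = logsumexp(-c / eps)                 # subtracts the max before exponentiating: -500.0
print(naive, stable)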

Moreover, when not using the log-sum-exp mode (lse_mode=False), I don't get any overflow (down to epsilon = 1e-9). This is strange, since logsumexp is supposed to be more stable than the version using matrix-vector products against the Gibbs kernel. On the other hand, as epsilon decreases, the regularized Wasserstein distance tends to 0 (9.99 × 10^-10 for epsilon = 1e-9). But as epsilon becomes very small, the regularized Wasserstein distance should tend towards the unregularized Wasserstein distance, and there is no reason a priori for the Wasserstein distance between the two measures to be zero. Indeed, even though the support points of both measures are drawn from the same law U([0,1]^5), the weights are not uniform: they are also drawn randomly and then normalized, so the two empirical measures differ. What we compute is therefore not an estimator of the distance between two samples of the same law U([0,1]^5), for which tending towards 0 would make sense.
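
A quick way to see that these tiny values look like an underflow artifact rather than a genuine transport cost (a sketch, assuming the default squared Euclidean cost of PointCloud):

eps = 1e-5
geom = pointcloud.PointCloud(x, y, epsilon=eps)
# in float32, exp(-C/eps) underflows once C/eps exceeds roughly 100, so for very
# small eps the Gibbs kernel is numerically the zero matrix and the kernel-mode
# iterations no longer carry any information about the cost matrix
print(jnp.max(geom.kernel_matrix))  # expected to be 0.0 (or vanishingly small)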

I hope this remark is helpful, and thank you very much for developing OTT, which really facilitates the use of numerical optimal transport, especially the differentiation of OT metrics.

@ersisimou

ersisimou commented May 6, 2021

Hi @theouscidda6 ,

When I run your code snippet with the latest OTT release I get:

lse_mode=True:
epsilon = 1: regularised optimal transport cost = 0.7936503887176514
epsilon = 0.1: regularised optimal transport cost = 0.43291008472442627
epsilon = 0.01: regularised optimal transport cost = 0.18572384119033813
epsilon = 0.001: regularised optimal transport cost = 0.15050488710403442
epsilon = 0.0001: regularised optimal transport cost = 0.13975945115089417
epsilon = 1e-05: regularised optimal transport cost = 0.11445839703083038
epsilon = 1e-06: regularised optimal transport cost = 0.10161541402339935
epsilon = 1e-07: regularised optimal transport cost = 0.09994244575500488
epsilon = 1e-08: regularised optimal transport cost = 0.09977512806653976
epsilon = 1e-09: regularised optimal transport cost = 0.0790291577577591

lse_mode=False:
epsilon = 1: regularised optimal transport cost = 0.7936503887176514
epsilon = 0.1: regularised optimal transport cost = 0.43291008472442627
epsilon = 0.01: regularised optimal transport cost = 0.18572384119033813
epsilon = 0.001: regularised optimal transport cost = 0.044053830206394196
epsilon = 0.0001: regularised optimal transport cost = 0.00014410055882763118
epsilon = 1e-05: regularised optimal transport cost = 9.999999747378752e-06
epsilon = 1e-06: regularised optimal transport cost = 9.999999974752427e-07
epsilon = 1e-07: regularised optimal transport cost = 1.0000000116860974e-07
epsilon = 1e-08: regularised optimal transport cost = 9.99999993922529e-09
epsilon = 1e-09: regularised optimal transport cost = 9.999999717180685e-10

So, I would suggest you download the latest release (which might not be the one you get with pip install). The overflow is related to a previous issue. I agree with you, though, that for epsilon <= 0.001 with lse_mode=False, instead of getting a NaN (so that we know there was a numerical issue), we get extremely low values for the OT objective. It would be great if this could be fixed.

@LaetitiaPapaxanthos
Contributor

Hi both,

Thanks a lot for raising this issue!

As @ersisimou said, it is likely that you get the NaNs with lse_mode=True because you do not have the latest version of OTT.

Regarding the very small reg_ot_cost values, thanks a lot for the comment, we will see what we can do to indicate numerical issues. In fact, with such small values of epsilon (< 1e-4) and lse_mode=False, you can see that the kernel matrix (geom.kernel_matrix) is numerically zero: every entry underflows. Meanwhile, it is possible to inspect other outputs of sinkhorn to check whether the algorithm has actually converged. For example:

  • out.converged: indicates True (converged) or False. For the code you shared, it indicates False as soon as epsilon=1e-3 for lse_mode=False and epsilon=1e-4 for lse_mode=True. In some cases, Sinkhorn may simply be reaching max_iterations before it has converged; it is possible to increase max_iterations at the expense of a longer runtime.
  • out.errors: lets you check the size of the error (the error on the second marginal in our case). For very small values of epsilon and lse_mode=False, for example, the error is 1 or close to 1 from epsilon=1e-3 and below, which indicates that the algorithm has not converged at all and the results should not be trusted (see the sketch after this list).
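
A small sketch of both checks on the code above (assuming that error slots not yet filled are marked with -1):

for eps in [1e-2, 1e-3, 1e-4]:
    geom = pointcloud.PointCloud(x, y, epsilon=eps)
    out = sinkhorn.sinkhorn(geom, a=a, b=b, lse_mode=False)
    # keep only the errors that were actually recorded
    errors = out.errors[out.errors > -1]
    print(eps, out.converged, errors[-1])  # converged is False and the error is near 1 once epsilon <= 1e-3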

Nevertheless, we will see how to better indicate numerical issues.

@marcocuturi
Contributor

Hi Théo,
so I guess the take-home message is to look into converged. We thought it best (as is usually done, e.g., in scipy.optimize with OptimizeResult) to return a solution even if it is sub-optimal or wrong.
