In [1]:
import numpy as np
import ot

In [2]:
def projection_simplex(V, z=1, axis=None):
    """Euclidean projection onto the simplex scaled by ``z``.

    Solves ``P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2`` with the
    sort-and-threshold method. It sorts ``x`` in decreasing order and finds
    the number ``rho`` of leading entries that stay above the running
    threshold. It then shifts every entry by ``theta`` and clips at zero.

    Parameters
    ----------
    V : array_like
        Input values. Must be 2-D when ``axis`` is 0 or 1.
    z : float or array_like, default 1
        Target sum of each projected vector (expected to be >= 0). If an
        array, ``len(z)`` must be compatible with the number of projected
        vectors: ``V.shape[0]`` for ``axis=1``, ``V.shape[1]`` for ``axis=0``.
    axis : None or int
        ``None``: project ``V.ravel()`` by ``P(V.ravel(); z)``. The result is 1-D.
        ``1`` (or ``-1``): project each row ``V[i]`` by ``P(V[i]; z[i])``.
        ``0`` (or ``-2``): project each column ``V[:, j]`` by ``P(V[:, j]; z[j])``.

    Returns
    -------
    numpy.ndarray
        The projection. It has the same shape as ``V`` for ``axis`` 0/1 and
        is a 1-D array for ``axis=None``.

    Raises
    ------
    ValueError
        If ``axis`` is not one of the values listed above.
    """
    V = np.asarray(V)
    if axis is not None and axis < 0:
        # Map negative axes onto 0/1 (for 2-D input). Previously they silently
        # fell through to the flattened axis=None branch.
        axis += V.ndim

    if axis == 1:
        n_rows, n_features = V.shape[0], V.shape[1]
        U = np.sort(V, axis=1)[:, ::-1]  # each row sorted in decreasing order
        z = np.ones(n_rows) * z  # one target sum per row (broadcasts scalar z)
        cssv = np.cumsum(U, axis=1) - z[:, np.newaxis]
        ind = np.arange(1, n_features + 1)
        cond = U - cssv / ind > 0
        # In exact arithmetic the first column satisfies `cond` whenever z > 0,
        # so rho >= 1. The clamp covers z == 0 and round-off (e.g. z tiny
        # relative to max(V)). Without it, rho == 0 indexes column -1 and
        # divides by zero, which yields NaN.
        rho = np.maximum(np.count_nonzero(cond, axis=1), 1)
        theta = cssv[np.arange(n_rows), rho - 1] / rho
        return np.maximum(V - theta[:, np.newaxis], 0)

    if axis == 0:
        # Column projection = row projection of the transpose.
        return projection_simplex(V.T, z, axis=1).T

    if axis is None:
        # Treat the whole array as a single row vector.
        return projection_simplex(V.reshape(1, -1), z, axis=1).ravel()

    raise ValueError(f"axis must be None, 0, 1, -1 or -2; got {axis!r}")


In [5]:
# Fixed seed so the instance, and every output derived from it below, can be
# regenerated with Restart & Run All.
SEED = 0
np.random.seed(SEED)

n = 5
# Random marginals normalised to sum to 1 (ot.emd needs equal total mass).
a = np.random.rand(n)
b = np.random.rand(n)
a = a / a.sum()
b = b / b.sum()
# Integer cost matrix with entries in {0, ..., 9}. size=(n, n) draws exactly
# the same values as randint(..., n**2).reshape(n, n).
C = np.random.randint(0, 10, size=(n, n))
# Exact OT plan. log=True also returns a dict whose "u"/"v" entries are used
# below as the dual potentials.
Pot, log_dict = ot.emd(a, b, C, log=True)

In [11]:
# Averaged projections. Each sweep projects P onto the row-marginal set
# {P >= 0, P.sum(1) = a} and the column-marginal set {P >= 0, P.sum(0) = b},
# then takes the midpoint. C never enters the iteration, so nothing here
# drives P toward the minimum-cost plan.
P = np.eye(n)
for _ in range(10000):
    P1, P2 = projection_simplex(P, a, axis=1), projection_simplex(P, b, axis=0)
    P = (P1 + P2) / 2

# Transport cost of the averaged-projection plan vs. the exact EMD plan.
(P * C).sum(), (Pot * C).sum()

(np.float64(5.553056709701885), np.float64(2.2212935318919604))

In [151]:
# Reduced costs C_ij - u_i - v_j from the EMD dual potentials (0 on tight entries).
dual_res = C - np.add.outer(log_dict["u"], log_dict["v"])

In [152]:
# Row-major 1-D copies so primal and dual entries line up index by index.
# NOTE(review): P is the averaged-projection plan, while dual_res comes from
# the EMD duals of Pot. Confirm that P (not Pot) is the intended plan here.
P_flat = P.ravel().copy()
dual_res_flat = dual_res.ravel().copy()

In [154]:
# Complementary-slackness count for P against the EMD reduced costs. It counts
# entries in the support of P, plus entries where P is zero AND the reduced cost
# is numerically zero. abs() matches the tightness test used in the UOT_L2 loop:
# a negative residual signals dual infeasibility, not a tight constraint.
CS_TOL = 1e-10
cs_support = np.sum(P_flat > CS_TOL)
cs_tight_but_empty = np.sum((P_flat < CS_TOL) & (np.abs(dual_res_flat) < CS_TOL))
cs_tight_but_empty + cs_support

np.int64(11)

In [155]:
# Plan entries on the first line, matching reduced costs on the second (3 d.p.).
print(np.round(P_flat, 3), np.round(dual_res_flat, 3), sep="\n")


[0.    0.195 0.    0.    0.    0.    0.    0.    0.036 0.051 0.    0.
 0.236 0.    0.056 0.    0.038 0.    0.    0.157 0.19  0.041 0.    0.
 0.   ]
[8. 0. 6. 8. 9. 0. 5. 3. 0. 0. 5. 4. 0. 2. 0. 0. 0. 8. 6. 0. 0. 0. 6. 6.
 6.]


In [None]:
from solvers.adapted_ns.uot_l2 import UOT_L2


In [1]:
import numpy as np

# Scratch sanity check that numpy imports and displays in this kernel.
np.zeros(shape=3)

array([0., 0., 0.])

In [None]:
import unittest
from time import time

import numpy as np

from solvers.adapted_ns.uot_l2 import UOT_L2
from solvers.cvx import Cvx
from utils.utils import get_instance, primal_cost, primal_cost_full_quad

# Randomised check of the UOT_L2 graph solver on small instances. For each seed:
# solve, rebuild the reduced costs from the solver's dual arrays, then check
# dual feasibility and complementary slackness. Finally verify that every entry
# is either in the support of the plan or strictly slack in the dual.
EPS = 1e-15  # threshold for the support / saturation masks kept for inspection
TOL = 1e-10  # "numerically zero" threshold for the support and tightness counts
FEAS_TOL = 1e-13  # round-off allowance for dual feasibility / complementary slackness
n = 5
tau = 0.1

# TODO: re-enable the CVXPY cross-check (Cvx + primal_cost_full_quad) that was
# previously sketched here as commented-out code.
for seed in range(100):
    np.random.seed(seed)
    a, b, C = get_instance(n)

    # Only the direct problem (C, a, b) is solved. The cell used to solve the
    # transposed problem (C.T, b, a) first, but that plan and its timing were
    # overwritten before being used, so the duplicate solve is dropped.
    # NOTE(review): if UOT_L2 draws from np.random internally, dropping it
    # changes the RNG state the direct solve sees. Confirm the solver is
    # deterministic.
    graph = UOT_L2(C, a, b, n, tau=tau)
    t0 = time()
    tp_graph = graph.solve(n_max=1000)
    tf = time() - t0

    # Reduced costs C_ij - f_i - g_j from the solver's dual arrays.
    sat = C - graph.f_array[:, None] - graph.g_array[None, :]
    # Boolean masks kept in the namespace for interactive inspection.
    mask_graph = tp_graph > EPS
    unsat = sat > EPS

    dual_feasible = (sat > -FEAS_TOL).sum() == n**2
    comp_slack = np.abs(sat * tp_graph).sum() < FEAS_TOL
    print(dual_feasible, comp_slack)

    n_support = np.sum(tp_graph > TOL)
    n_degenerate = np.sum((np.abs(sat) < TOL) & (tp_graph < TOL))
    n_strict = np.sum(sat > TOL)
    print(seed, "support", n_support, n_degenerate, n_strict, n_support + n_strict)

    # An entry that is zero in the plan yet tight in the dual breaks strict
    # complementarity.
    if n_degenerate > 0:
        print("NONE")
        break

    # Each of the n**2 entries should be counted by exactly one of the
    # support / strictly-slack sets.
    if n_support + n_strict != n**2:
        print("error")
        break


True True
0 support 8 0 17 25
True True
1 support 9 0 16 25
True True
2 support 6 0 19 25
True True
3 support 8 0 17 25
True True
4 support 7 0 18 25
True True
5 support 7 0 18 25
True True
6 support 7 0 18 25
True True
7 support 7 0 18 25
True True
8 support 6 0 19 25
True True
9 support 8 0 17 25
True True
10 support 6 0 19 25
True True
11 support 7 0 18 25
True True
12 support 8 0 17 25
True True
13 support 7 0 18 25
True True
14 support 8 0 17 25
True True
15 support 7 0 18 25
True True
16 support 8 0 17 25
True True
17 support 8 0 17 25
True True
18 support 6 0 19 25
True True
19 support 6 0 19 25
True True
20 support 6 0 19 25
True True
21 support 9 0 16 25
True True
22 support 7 0 18 25
True True
23 support 6 0 19 25
True True
24 support 8 0 17 25
True True
25 support 9 0 16 25
True True
26 support 5 0 20 25
True True
27 support 5 0 20 25
True True
28 support 6 0 19 25
True True
29 support 8 0 17 25
True True
30 support 7 0 18 25
True True
31 support 9 0 16 25
True True
32 suppo