In [563]:
from solvers.adapted_ns.suot_l2 import SUOT_L2
from utils.utils import get_instance, primal_cost, primal_cost_full_quad
from solvers_reg.cvx import Cvx_ineq
import numpy as np


# Tolerance handed to the reference solver below (single source of truth,
# instead of repeating the literal at the call site).
EPS = 1e-15
# Problem size passed to get_instance and to the solver.
n = 50
# NOTE(review): the seed is disabled; if get_instance draws from NumPy's global
# RNG the instance changes on every run — confirm and re-enable if needed.
# np.random.seed(0)
a, b, C = get_instance(n)
# Weight parameter: the quadratic marginal penalty in the objective cells is
# scaled by 0.5 / tau, and tau is later passed as `alpha` to the PDHG routine.
tau = 1


# Reference transport plan computed with the CVX-based solver.
solver = Cvx_ineq(C, a, b, n, tau=tau)
tp_graph = solver.solve(eps=EPS)

In [564]:
# Objective value of the reference plan: linear transport cost plus the
# quadratic penalty (weight 0.5 / tau) on the column-marginal mismatch with b.
obj = np.sum(tp_graph * C) + 0.5 / tau * (
    (tp_graph.T @ np.ones(n) - b) ** 2
).sum()

In [565]:
def projection_simplex(V, z=1, axis=None):
    """
    Euclidean projection onto the simplex scaled by z:

        P(x; z) = argmin_{y >= 0, sum(y) = z} ||y - x||^2

    Parameters
    ----------
    V : ndarray
        Values to project. Must be 2-D when ``axis`` is 0 or 1.
    z : float or array, default 1
        Simplex mass (expected to be >= 0). If an array, its length must match
        the number of projected vectors.
    axis : None or int
        axis=None: project the flattened array, P(V.ravel(); z)
        axis=1: project each row, P(V[i]; z[i])
        axis=0: project each column, P(V[:, j]; z[j])

    Returns
    -------
    ndarray
        The projected values: same shape as ``V`` for axis 0 or 1, flattened
        for axis=None.

    Notes
    -----
    With z == 0 no sorted entry passes the threshold test, so the support size
    would be 0, which previously caused a division by zero (inf/nan output).
    The support size is now clamped to 1, which yields the all-zero vector —
    the only point of the zero-mass simplex. Results for z > 0 are unchanged
    because the largest entry always passes the test in that case.
    """
    if axis == 1:
        n_features = V.shape[1]
        # Sort each row in decreasing order.
        U = np.sort(V, axis=1)[:, ::-1]
        # Broadcast a scalar z to one mass per row.
        z = np.ones(len(V)) * z
        cssv = np.cumsum(U, axis=1) - z[:, np.newaxis]
        ind = np.arange(n_features) + 1
        cond = U - cssv / ind > 0
        # Number of sorted entries passing the threshold test (support size).
        rho = np.count_nonzero(cond, axis=1)
        # Guard the z == 0 case (rho == 0); a no-op whenever z > 0.
        rho = np.maximum(rho, 1)
        # Per-row shift such that the clipped row sums to z.
        theta = cssv[np.arange(len(V)), rho - 1] / rho
        return np.maximum(V - theta[:, np.newaxis], 0)

    elif axis == 0:
        # Column-wise projection is the row-wise projection of the transpose.
        return projection_simplex(V.T, z, axis=1).T

    else:
        # Treat the whole array as a single vector.
        V = V.ravel().reshape(1, -1)
        return projection_simplex(V, z, axis=1).ravel()


import cvxpy as cp


def proj_simplex_cvx(V, a):
    """
    Reference projection of V onto {P >= 0, P 1 <= a}, solved with CVXPY.

    Minimises ``sum_squares(V - P)`` subject to ``P >= 0`` and row sums of P
    bounded above by ``a``, using the CLARABEL solver with very tight
    tolerances.

    Parameters
    ----------
    V : ndarray of shape (n, m)
        Matrix to project.
    a : ndarray of shape (n,)
        Upper bound on each row sum.

    Returns
    -------
    The content of ``P.value`` after the solve.
    """
    matrix = V.copy()
    n_rows, n_cols = V.shape
    eps = 1e-15
    P = cp.Variable((n_rows, n_cols))
    # Row sums are P @ 1 with a ones vector of length n_cols. The previous
    # (n_rows, 1) vector only had the right shape for square inputs.
    ones_cols = np.ones((n_cols, 1))
    constraints = [P >= 0, cp.matmul(P, ones_cols) <= a[:, None]]
    objective = cp.Minimize(cp.sum_squares(matrix - P))
    prob = cp.Problem(objective, constraints)
    prob.solve(
        solver=cp.CLARABEL,
        max_iter=int(1e4),
        tol_gap_abs=eps,
        tol_feas=eps,
        tol_gap_rel=eps,
        tol_infeas_abs=eps,
        tol_infeas_rel=eps,
        tol_ktratio=eps,
        # verbose=True,
    )
    return P.value


def proj_adapted_err(V, a):
    """
    Project each row of V onto {y >= 0, sum(y) <= a[i]}.

    Negative entries are clipped to zero first. Rows whose clipped sum still
    exceeds ``a[i]`` are then projected onto the simplex of mass ``a[i]`` via
    ``projection_simplex``; all other rows are returned clipped only.

    Parameters
    ----------
    V : ndarray of shape (n, m)
        Matrix to project. It is not modified.
    a : ndarray of shape (n,)
        Upper bound on each row sum.

    Returns
    -------
    ndarray of shape (n, m)
        A new array holding the projected rows.
    """
    # np.maximum allocates a new array, so V itself is left untouched.
    V_pos = np.maximum(V, 0)
    # Row sums need a ones vector with one entry per column. The previous
    # V.shape[0] only worked for square matrices.
    row_mass = V_pos @ np.ones(V.shape[1])
    mask = row_mass > a
    if mask.any():
        # Boolean indexing already yields a copy of the selected rows.
        V_pos[mask] = projection_simplex(V_pos[mask], a[mask], axis=1)
    return V_pos


def pdhg_transport_one_relaxed_proj(
    C,
    a,
    b,
    alpha,
    tau=0.1,
    sigma=0.1,
    niter=1000,
    verbose=False,
    P0=None,
    v0=None,
    tol=1e-14,
    proj_ineq=False,
):
    """
    PDHG (primal-dual hybrid gradient) iterations for

        min_P  <P, C> + (1 / (2 * alpha)) * ||P^T 1 - b||^2

    where the primal step applies ``proj_adapted_err``, i.e. it keeps
    P >= 0 with row sums bounded above by ``a``.

    Parameters
    ----------
    C : ndarray of shape (n, m)
        Cost matrix.
    a : ndarray of shape (n,)
        Row-mass bounds passed to the projection.
    b : ndarray of shape (m,)
        Target column marginal.
    alpha : float
        Weight parameter of the quadratic penalty (the penalty is scaled by
        ``0.5 / alpha``); it also enters the dual shrinkage factor.
    tau, sigma : float
        Primal and dual step sizes.
    niter : int
        Maximum number of iterations.
    verbose : bool
        If True, print the objective every 1000 iterations and at the end.
    P0 : ndarray of shape (n, m), optional
        Primal warm start (copied). Defaults to zeros.
    v0 : ndarray of shape (m,), optional
        Dual warm start (copied). Defaults to zeros.
    tol : float
        Stopping tolerance. NOTE(review): the KKT residual is currently
        stubbed to 0, so the stopping test never fires for tol >= 0.
    proj_ineq : bool
        Accepted for interface compatibility; not used by the body.

    Returns
    -------
    P : ndarray of shape (n, m)
        Last primal iterate.
    u : ndarray of shape (n,)
        Zeros; never updated, returned for interface compatibility.
    v : ndarray of shape (m,)
        Last dual iterate.
    obj_values : ndarray
        Objective value after each iteration.
    kkt_conds : list
        Recorded (stubbed) KKT residuals, one per iteration.
    """
    n, m = C.shape
    P = P0.copy() if P0 is not None else np.zeros((n, m))
    # The dual variable is compared with b and broadcast along the rows of C,
    # so it has one entry per column (length m, not n).
    v = v0.copy() if v0 is not None else np.zeros(m)
    u = np.zeros(n)

    P_prev = np.zeros_like(P)

    obj_values = []
    kkt_conds = []

    ones_n = np.ones(n)
    # Shrinkage factor of the dual update; loop-invariant.
    factor = 1 / (1 + sigma * alpha)

    # Keeps the final report well-defined when niter == 0.
    k = -1
    for k in range(niter):
        # Placeholder residual: the real KKT evaluation is disabled.
        kkt_cond = 0  # kkt(P, None, v, C, a, b, alpha, proj=True)
        kkt_conds.append(kkt_cond)
        if k % 100 == 0:
            print(k)
        if kkt_cond + 3 * tol < tol:
            print(k)
            break

        # Extrapolation of the primal iterate (skipped on the first pass).
        P_bar = 2 * P - P_prev if k > 0 else P.copy()

        # Dual update driven by the column-marginal mismatch.
        col_sum = P_bar.T @ ones_n
        v = factor * (v + sigma * (b - col_sum))

        # Primal update: gradient step followed by the projection that
        # enforces P >= 0 and the row-mass bounds.
        grad = C - v[None, :]
        P_next = proj_adapted_err(P - tau * grad, a)

        # Shift the iterates.
        P_prev = P
        P = P_next

        # Objective evaluation; the penalty weight follows alpha instead of
        # the previously hard-coded 1.
        obj = np.sum(P * C) + 0.5 / alpha * np.sum((P.T @ ones_n - b) ** 2)
        obj_values.append(obj)

        if verbose and (k % 1000 == 0 or k == niter - 1):
            print(f"Iter {k:4d} | Obj = {obj:.2e} | KKT = {kkt_cond:.2e}")

    print("k", k, f"KKT = {0:.2e}")
    # Largest deviation between the column marginal and b - v.
    print("FO", np.abs(P.T @ ones_n - (b - v)).max())

    return P, u, v, np.array(obj_values), kkt_conds


In [566]:
# Run PDHG warm-started from the reference plan tp_graph. The global `tau`
# is the penalty parameter (alpha); the keyword `tau` is the primal step size.
# Warm-starting the dual variable (v0) is left disabled.
tp_graph_prime, u_prime, v_prime, _, _ = pdhg_transport_one_relaxed_proj(
    C,
    a,
    b,
    alpha=tau,
    tau=(2 * n) ** (-0.5),
    sigma=(2 * n) ** (-0.5),
    niter=1000,
    P0=tp_graph.copy(),
    tol=1e-13,
    proj_ineq=True,
)
# Same objective as for the reference plan, evaluated on the PDHG output.
obj_prime = np.sum(tp_graph_prime * C) + 0.5 / tau * (
    (tp_graph_prime.T @ np.ones(n) - b) ** 2
).sum()

0
100
200
300
400
500
600
700
800
900
k 999 KKT = 0.00e+00
FO 5.551115123125783e-17


In [567]:
# Objective gap between the PDHG output (obj_prime) and the reference plan (obj).
obj_prime - obj

np.float64(-6.990935608186533e-16)