In [None]:
# Library
import numpy as np
from scipy.sparse.linalg import cg
from numpy.linalg import norm
from scipy.io import loadmat

In [None]:
def gradient(gamma, X, Phi, Q, z, lambda_1, lambda_2):
    """Evaluate the two gradient blocks used by the solver at (gamma, X).

    Parameters
    ----------
    gamma : 1-D array of coefficients.
    X : square matrix variable.
    Phi, Q : matrices entering the products below.
    z : 1-D array; its length also sets the size of the identity term.
    lambda_1, lambda_2 : regularisation scalars.

    Returns
    -------
    (g_gamma, g_X) where
        g_gamma = (Q @ gamma - z) / (2 * lambda_2) - diag(Phi.T @ X @ Phi)
        g_X     = -(Phi @ diag(gamma) @ Phi.T + lambda_1 * I)
    """
    dim = len(z)

    # gamma block: data-fit term minus the diagonal of the quadratic form.
    quad_diag = np.diag(Phi.T @ X @ Phi)
    data_term = (Q @ gamma - z) / (2 * lambda_2)
    g_gamma = data_term - quad_diag

    # X block: negated gamma-weighted Gram matrix plus ridge term.
    weighted_gram = Phi @ np.diag(gamma) @ Phi.T
    ridge = lambda_1 * np.eye(dim)
    g_X = - (weighted_gram + ridge)

    return g_gamma, g_X

In [None]:
def T_operator(P_positive, P_negative, S, mu, epsilon):
    '''
    Evaluate the T operator at the matrix S (decomposition of page 9).

    With U = P_positive.T @ S, the function forms
        G = P_positive @ ( 1/(2*mu) * (U @ P_positive) @ P_positive.T
                           + (epsilon * (U @ P_negative)) @ P_negative.T )
    and returns the symmetrised result G + G.T.  `epsilon` multiplies
    U @ P_negative element-wise, so it must broadcast against that product.
    '''
    projected = P_positive.T @ S

    # Contribution of the (positive, positive) block, scaled by 1 / (2 * mu).
    pos_block = (1/(2*mu) * (projected @ P_positive)) @ P_positive.T
    # Contribution of the (positive, negative) block, weighted by epsilon.
    neg_block = (epsilon * (projected @ P_negative)) @ P_negative.T

    half = P_positive @ (pos_block + neg_block)
    return half + half.T

In [None]:
def build_A_matrix(Phi):
    '''
    Given a matrix Phi, this function constructs the matrix A as given in page 8 of the article.

    Row i of A is the Kronecker product of column i of Phi with itself:
        A[i, j * r + k] = Phi[j, i] * Phi[k, i],   r = Phi.shape[0].

    Parameters
    ----------
    Phi : array-like of shape (r, c).

    Returns
    -------
    A : float ndarray of shape (c, r * r).

    The previous implementation allocated A as (c, c * c) and therefore only
    worked for square Phi; the row length is now derived from the number of
    rows of Phi, so rectangular inputs are supported as well.
    '''
    Phi = np.asarray(Phi, dtype=float)
    n_rows, n_cols = Phi.shape

    # outer[i, j, k] = Phi[j, i] * Phi[k, i]; flattening (j, k) row-major gives
    # exactly kron(Phi[:, i], Phi[:, i]) for every column i at once.
    outer = np.einsum('ji,ki->ijk', Phi, Phi)
    return outer.reshape(n_cols, n_rows * n_rows)


In [None]:
def residue(gamma, X, Phi, Q, z, lambda_1, lambda_2):
    """Residual pair (r_gamma, r_X) at the point (gamma, X).

    r_gamma is the gamma-gradient returned by `gradient`.
    r_X is X minus the positive-semidefinite part of (X - g_X): the matrix
    X - g_X is eigendecomposed, its negative eigenvalues are set to zero and
    the matrix is rebuilt as V @ diag(max(D, 0)) @ V.T.

    Returns
    -------
    (r_gamma, r_X) with r_X the same shape as X.
    """
    r_gamma, g_X = gradient(gamma, X, Phi, Q, z, lambda_1, lambda_2)

    eigvals, eigvecs = np.linalg.eigh(X - g_X)

    # Scale each eigenvector column by its clipped eigenvalue.  The clipped
    # eigenvalues must act as a diagonal matrix; multiplying by the bare 1-D
    # vector (as before) collapsed the result to a vector.
    X_psd = (eigvecs * np.maximum(eigvals, 0)) @ eigvecs.T

    r_X = X - X_psd
    return r_gamma, r_X

In [None]:
def SSN_step(r_1, r_2, gamma, X, mu, Q, Phi, lambda_1, lambda_2, A):
    """Compute one semismooth-Newton direction at the point (gamma, X).

    Parameters
    ----------
    r_1, r_2 : current residuals (vector block and matrix block).
    gamma, X : current iterate (vector of length m, m x m matrix).
    mu : regularisation scalar of the Newton system.
    Q, Phi : problem matrices; lambda_1, lambda_2 : regularisation scalars.
    A : matrix multiplied against T_k below (built by build_A_matrix in this
        notebook).

    Returns
    -------
    (delta_omega_1, delta_omega_2) : the step for gamma and the step for X.
    """
    
    m = len(gamma)
    # Matrix whose spectrum drives the generalized Jacobian.
    Z = X - (Phi @ np.diag(gamma) @ Phi.T + lambda_1 * np.eye(m)) 
    sigma, P = np.linalg.eigh(Z)
    # Split the spectrum by sign (zero eigenvalues go with the negative set).
    alpha_pos = sigma[sigma > 0]
    alpha_neg = sigma[sigma <= 0] 
    
    # Select corresponding eigenvectors
    P_positive = P[:, sigma > 0]
    P_negative = P[:, sigma <= 0]

    # Construction of Omega matrix: ones on the leading n x n block, eta_ on
    # the two off-diagonal blocks, zeros elsewhere.
    # NOTE(review): np.linalg.eigh returns eigenvalues in ascending order, so
    # the positive-eigenvalue columns are the LAST columns of P, whereas Omega
    # (and Psi below) put the "positive" block first. Confirm the ordering is
    # consistent where P / kron(P, P) is combined with these matrices.
    # NOTE(review): a / (b - a) is negative for a > 0 >= b; confirm the sign
    # convention against the reference.
    n = len(alpha_pos)
    Omega = np.zeros((m, m)) 
    Omega[:n, :n] = 1 
    eta_ = np.array([[a / (b-a) for b in alpha_neg] for a in alpha_pos])
    Omega[:n, n:] = eta_ 
    Omega[n:, :n] = eta_.T 
  
    # Construction of Gamma and L: Gamma is the (m*m) x (m*m) diagonal matrix
    # holding the flattened Omega.T; L rescales it element-wise (off-diagonal
    # entries are 0 / (mu + 1) = 0).
    Gamma = np.diag(np.ravel(Omega.T)) 
    L = Gamma / (mu+1-Gamma)
  
    # Construction of Psi matrix from Omega
    # NOTE(review): Psi is filled in but never read afterwards in this
    # function; only Epsilon is used (as the `epsilon` argument of T_operator).
    Psi = np.zeros((m, m))
    Epsilon = eta_/(mu + 1 - eta_)
    Psi[:n, :n] = (1/mu) * np.ones((n, n))
    Psi[:n, n:] = Epsilon 
    Psi[n:, :n] = Epsilon.T 

    # Compute the value of T operator at r_2 and add r_2
    T_r = r_2 + T_operator(P_positive, P_negative, r_2, mu, Epsilon)  

    # Calculation of a_1 and a_2 (right-hand sides of the two block equations)
    a_1 = -r_1 - (1 / (1 + mu)) * np.diag(Phi.T @ T_r @ Phi)
    a_2 = -r_2

    # Solving the linear equations system
    # kron(P, P) is a dense (m*m) x (m*m) matrix, so memory grows like m**4.
    P_tilde = np.kron(P,P) 
    T_k = P_tilde @ L @ P_tilde.T 
    M = Q/(2 * lambda_2) + mu * np.eye(m) + A @ T_k @ A.T 
    # cg returns (solution, info); the convergence flag `info` is discarded.
    a_tilde_1 = cg(M , a_1)[0] 
    T_a_2 = T_operator(P_positive, P_negative, a_2, mu, Epsilon)
    a_tilde_2 = 1/(mu + 1) * ( a_2 + T_a_2)
  
    # Calculation of delta_omega_1 and delta_omega_2
    delta_omega_1 = a_tilde_1
    Phi_star_a_tilde_1 = Phi @ np.diag(a_tilde_1) @ Phi.T
    T_Phi_star = T_operator(P_positive, P_negative, Phi_star_a_tilde_1, mu, Epsilon)
    delta_omega_2 = a_tilde_2 - T_Phi_star

    return delta_omega_1, delta_omega_2

In [None]:
def SSN(data: dict, lambda_1: float, lambda_2: float, tau: float, kappa: float, alpha_1: float, alpha_2: float,
        beta_0: float, beta_1: float, beta_2: float, theta_lower: float,  theta_upper: float, eta: float):
    """Run the combined extragradient (EG) / semismooth-Newton (SSN) iteration.

    Parameters
    ----------
    data : dict read through the keys 'Phi', 'M', 'KX1', 'KY1', 'KX2', 'KY2'.
    lambda_1, lambda_2 : regularisation scalars.
    tau, kappa : accepted but not used anywhere in the body.
    alpha_1, alpha_2 : thresholds compared against rho / ||delta||^2 when
        updating theta.
    beta_0, beta_1, beta_2 : multiplicative factors applied to theta.
    theta_lower, theta_upper : bounds used in the theta update.
    eta : step size of the extragradient step.

    Returns
    -------
    (w_gamma, w_X) after a fixed number of iterations (nIter = 50).

    Notes
    -----
    The initial X is drawn from NumPy's global RNG, so the result depends on
    the RNG state at call time.
    """

    
    # Verify that the user has provided credible hyperparameters.
    valid_hyperparameters = (alpha_2 >= alpha_1) and \
                        (alpha_1 > 0) and \
                       (0 < beta_0 < 1) and \
                        (0 < beta_1 < 1) and \
                        (beta_2 > 1) and \
                        (theta_lower > 0) and \
                        (theta_upper > 0) and \
                        (lambda_1 > 0) and \
                        (lambda_2 > 0)  
    # NOTE(review): the check above is unconditionally overridden here, so the
    # "invalid" branch below is unreachable. The calls in this notebook pass
    # beta_1 = 1.9, which the check would reject.
    valid_hyperparameters = True
   
    if not valid_hyperparameters:
        print("Invalid hyperparameters provided.")
        return None
        
    else :
        # Information extracted from data
        Phi = data['Phi']  
        m = len(data['M'])
        Q = data['KX1'] + data['KY1']
        z = np.mean(data['KX2'], axis=1) + np.mean(data['KY2'], axis=1) - 2 * lambda_2 * data['M']
        A = build_A_matrix(Phi)
        
        # Initialization
        nIter = 50
        w_gamma = np.ones(m) / m   
        # Product of two independent Gaussian matrices (not symmetric in
        # general); W_X is just a second name for the same array.
        w_X = W_X = np.random.randn(m, m) @ np.random.randn(m, m).T
        v_gamma = w_gamma
        v_X = W_X
        theta = 0.01
        
        for _ in range(nIter):
            
            # One-step EG to update v: Line 4 algo
            
            # compute the mid-point step
            g_gamma, g_X = gradient(v_gamma, v_X, Phi, Q, z, lambda_1, lambda_2)
            v_gamma_mid = v_gamma - eta * g_gamma
            v_X_mid = v_X + eta * g_X
            # compute the extra step (gradient taken at the mid-point)
            g_gamma, g_X = gradient(v_gamma_mid, v_X_mid, Phi, Q, z, lambda_1, lambda_2)
            v_gamma = v_gamma - eta * g_gamma
            v_X = v_X + eta * g_X
            
            # One-step SSN to update w: Line 5 and 6
            w_r_1, w_r_2 = residue(w_gamma, w_X, Phi, Q, z, lambda_1, lambda_2)
            # Regularisation of the Newton system scales with the residual norm.
            mu = theta * (np.linalg.norm(w_r_1) + np.linalg.norm(w_r_2, 'fro'))
            delta_1, delta_2 = SSN_step(w_r_1, w_r_2, w_gamma, w_X, mu, Q, Phi, lambda_1, lambda_2, A)
            
            # Compute w_tilde: line 7 algo
            w_tilde_1 = w_gamma + delta_1
            w_tilde_2 = w_X + delta_2
            
            # Update theta in the adaptive manner: line 8 algo
            r_1, r_2 = residue(w_tilde_1, w_tilde_2, Phi, Q, z, lambda_1, lambda_2)
            # Stack vector and flattened matrix parts into single column vectors.
            r = np.vstack((r_1.reshape(-1, 1), np.ravel(r_2).reshape(-1, 1)))
            delta = np.vstack((delta_1.reshape(-1, 1), np.ravel(delta_2).reshape(-1, 1)))
            # rho is a (1, 1) array: minus the inner product of r and delta.
            rho = - r.T @ delta
            delta_norm = np.linalg.norm(delta) ** 2
            
            if rho >= alpha_2 * delta_norm :
                theta = max(theta_lower, beta_0 * theta)
            elif rho >= alpha_1 * delta_norm :
                theta = beta_1 * theta
            else:
                theta = min(theta_upper, beta_2 * theta)
                
            # Line 9: keep the Newton candidate or fall back to the EG iterate.
            # NOTE(review): the right-hand side is the norm of the iterate
            # (v_gamma, v_X) itself, not of its residual; confirm this is the
            # intended acceptance test.
            if (np.linalg.norm(r_1) + np.linalg.norm(r_2,  ord ='fro')) <= (np.linalg.norm(v_gamma) + np.linalg.norm(v_X,  ord ='fro')):
                w_gamma = w_tilde_1
                w_X = w_tilde_2
            else :
                w_gamma = v_gamma
                w_X = v_X
            
    return w_gamma, w_X

# Experiments

## 1. Synthetic data

In [None]:
%matplotlib inline

import numpy as np
from scipy.stats import norm

import ot

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


### Quasi-random sequences
from sobol_seq import i4_sobol_generate

### Kernel SoS OT functions
from optim import interior_point
from utils import make_kernels, transport_cost, transport_1D, potential_1D

### Sample data

In [None]:
# nfill: number of quasi-random filling points requested from the Sobol generator.
# nsamples: number of points drawn from each of the two distributions.
nfill = 100
nsamples = 100

#### Sample from mu and nu

In [None]:
# Size of the regular grid on [0, 1] used to tabulate the densities.
n = 2000

# Component means of the two Gaussian mixtures.
mu1 = [0.7, 0.3]
mu2 = [0.2, 0.5, 0.75]

# Mixture weights (same order as the means).
t1 = [0.4, 0.6]
t2 = [0.2, 0.2, 0.6]

x = np.linspace(0, 1, n)

# One Gaussian pdf per component, evaluated on the grid.
r_tmp = [norm.pdf(x, mode, 0.09) for mode in mu1]
c_tmp = [norm.pdf(x, mode, 0.075) for mode in mu2]

# Weighted sums of the components give the two mixture densities.
mu = np.dot(t1, r_tmp)
nu = np.dot(t2, c_tmp)

In [None]:
np.random.seed(123)

# Uniform draws used to pick the mixture component of every sample.
u1 = np.random.rand(nsamples)
u2 = np.random.rand(nsamples)

X = np.zeros(nsamples)
Y = np.zeros(nsamples)

# NOTE(review): X is sampled with standard deviation 0.1 while the mu density
# is tabulated with 0.09 in the density cell -- confirm which one is intended.
for i in range(nsamples):
    # The X draw precedes the Y draw at every step, which keeps the order in
    # which the global random stream is consumed.
    x_center = mu1[0] if u1[i] < t1[0] else mu1[1]
    X[i] = np.random.randn() * .1 + x_center

    if u2[i] < t2[0]:
        y_center = mu2[0]
    elif u2[i] < t2[1] + t2[0]:
        y_center = mu2[1]
    else:
        y_center = mu2[2]
    Y[i] = np.random.randn() * .075 + y_center

In [None]:
x = np.linspace(0, 1, n)

f, ax = plt.subplots()

ax.plot(x, mu, label = 'mu density')
ax.plot(x, nu, label = 'nu density')

# Grid index of every sample, clipped to [0, n - 1]: a sample at or above 1
# would otherwise index past the end of the density array, and a negative
# sample would silently wrap around to the other end.
X_idx = np.clip((n * X).astype(int), 0, n - 1)
Y_idx = np.clip((n * Y).astype(int), 0, n - 1)

# Each sample is drawn at the height of its density at the nearest grid point.
ax.scatter(X, mu[X_idx], label = 'mu samples')
ax.scatter(Y, nu[Y_idx], label = 'nu samples')


plt.legend()
plt.show()

In [None]:
### Sobol quasi-random samples to fill the space X x Y.
sob = i4_sobol_generate(2 , nfill, skip = 3000)


## Add some points in the corners (optional)
# Each np.insert (no axis) flattens `sob` and prepends one (x, y) pair, so the
# corner listed last ends up first in the flattened array.
corner_points = (
    [1e-2, 1e-2],
    [1-1e-2, 1-1e-2],
    [1e-2, 1-1e-2],
    [1.-1e-2, 1e-2],
)
for corner in corner_points:
    sob = np.insert(sob, 0, np.array(corner))

# Back to (x, y) rows; the last four rows are dropped so that the number of
# points is unchanged by the four corners added above.
sob = sob.reshape(-1, 2)[:-4 , :]


# Column vectors of x- and y-coordinates of the filling points.
X_fill, Y_fill = sob[:, :1], sob[:, 1:]

In [None]:
# Visual check of how the filling points cover the (x, y) square.
plt.scatter(X_fill, Y_fill)

In [None]:
# Kernel name and its parameter `l`; both are forwarded to make_kernels here
# and to transport_1D in the plotting cells.
# (presumably `l` is the kernel bandwidth -- confirm in utils.make_kernels)
kernel = 'gaussian'
l = .1
# Samples are passed as column vectors (X[:, None]); the filling points are
# already column vectors.
Phi, M, Kx1, Ky1, Kx2, Ky2, Kx3, Ky3 = make_kernels(X[:, None], Y[:, None], X_fill, Y_fill, l=l, kernel = kernel)

In [None]:
## Regularization parameters
# Tied to the problem sizes: lbda_1 = 1 / nfill, lbda_2 = 1 / nsamples.

lbda_1 = 1 / nfill
lbda_2 =  1 / nsamples


## Optimization problem parameters
# eps_start, eps_end, tau and niter are passed to interior_point below;
# tau is also passed to SSN later in the notebook.

eps_start = nfill
eps_end = 1e-8

tau = 1e-8

niter = 1000

In [None]:
# Reference solution from the project's interior-point solver (optim.interior_point).
# It returns a pair; G is reused below by transport_cost and transport_1D.
G, eps = interior_point(M, Phi, Kx1, Kx2, Kx3, Ky1, Ky2, Ky3,lbda_1=lbda_1, lbda_2=lbda_2,
                      eps_start=eps_start, eps_end=eps_end, eps_div = 2,
                      tau=tau, niter=niter,
                      verbose=True, report_interval=100)

# Transport-cost estimate associated with G (project helper utils.transport_cost).
kernel_sos_ot = transport_cost(G, Kx2, Kx3, Ky2, Ky3, lbda_2, product_sampling=False)

In [None]:
# Show the estimate obtained from the interior-point solution.
print(kernel_sos_ot)

In [None]:
### Compute OT between the tabulated densities mu and nu on the grid x
### (the random samples X, Y are not used in this cell).
x = np.linspace(0., 1., n)

# Pairwise cost matrix: half the squared distance between grid points.
M_ot = ((x[:, None] - x)**2) / 2
# Densities are normalised to sum to one before being passed to ot.emd;
# with log=True the call returns the plan P together with a log object.
P, log = ot.emd(mu / mu.sum(), nu / nu.sum(), M_ot, log = True)
# Total cost of the plan: sum of plan entries times their cost.
sampled_ot = (P * M_ot).sum()

In [None]:

# Side-by-side report of the two estimates. Note that the "l=" printed here is
# nfill (number of filling points), not the kernel parameter `l`.
print(f"Plugin estimator (n={n}): {sampled_ot:.3e}\nKernel SoS estimator (n={nsamples}, l={nfill}): {kernel_sos_ot:.3e}")

In [None]:
import matplotlib.gridspec as gridspec

plt.clf()

fig = plt.figure(figsize=(5, 5))

# 3 x 3 grid without spacing: top strip = source density, left strip = target
# density, remaining 2 x 2 area = transport maps.
gs = gridspec.GridSpec(3, 3, wspace=0.0, hspace=0.0)

xp, yp = np.where(P > 0)

na, nb = P.shape

xa = np.arange(na)
xb = np.arange(nb)

# Discrete reference map: for every source grid index, the target index
# carrying the largest mass in the plan P.
Txa = np.argmax(P, 1)

# Grid index of every sample, clipped to [0, n - 1]: a sample at or above 1
# would otherwise index past the end of the density array, and a negative
# sample would silently wrap around to the other end.
X_idx = np.clip((n * X).astype(int), 0, n - 1)
Y_idx = np.clip((n * Y).astype(int), 0, n - 1)


# Top strip: source density.
ax1 = plt.subplot(gs[0, 1:])
ax1.plot(xa, mu,  'r', label='Source distribution')
ax1.fill_between(xa, mu, color = 'red', alpha=.1)
plt.ylim(ymin=0)
plt.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, right=False, left=False, labelleft=False)
ax1.axis('off')

# Left strip: target density drawn sideways (density on the x-axis).
# NOTE(review): fill_between is called with the density as x and the grid
# index as y; fill_betweenx may have been intended -- confirm visually.
ax2 = plt.subplot(gs[1:, 0])
ax2.plot((nu), xb, 'b', label='Target distribution')
ax2.fill_between((nu)[:], xb[:], color = 'blue', interpolate=True, alpha = .1)
ax2.set_xlim(xmin=0)
ax2.invert_xaxis()

ax2.axis('off')


# Main panel shares its axes with the two density strips.
ax3 = plt.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2)
ax3.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, right=False, left=False, labelleft=False)

ax3.plot(xa, Txa , linewidth = 2, color = 'black', ls ='--', label='True map')

# Map inferred from the interior-point solution G, evaluated on a [0, 1] grid
# and rescaled by na to grid-index units for plotting.
x = np.linspace(0, 1, len(xa))
TX = transport_1D(x, G, X, X_fill, lbda_2, kernel=kernel, l=l)

ax3.plot(xa, TX * na, color = 'r', lw =2, label = 'Inferred map')

# Samples, drawn on the density curves at their (clipped) grid position.
ax1.scatter(X * na, mu[X_idx], label = 'mu samples', marker = 'x', color = 'r', s=50)
ax2.scatter( nu[Y_idx], Y * na,
            label = 'nu samples', marker = 'x', color = 'b', s = 50)


ax3.scatter(sob[:, 0] * na, sob[:, 1] * na, color = 'violet', s =20, label = 'Filling samples')


plt.tight_layout()
plt.legend(fontsize = 14)

plt.show()

In [None]:
# Bundle the kernel quantities under the keys that SSN reads from `data`.
data = {
    "M": M,
    "Phi": Phi,
    "KX1": Kx1,
    "KY1": Ky1,
    "KX2": Kx2,
    "KY2": Ky2,
}

In [None]:
# Run the SSN solver with the same regularisation as the interior-point run.
# alpha_1 = 10e-4 is 1e-3.
# NOTE(review): beta_1 = 1.9 lies outside the (0, 1) range encoded by SSN's own
# hyperparameter check; the call only goes through because that check is
# overridden inside SSN.
w_gamma, w_X = SSN(data, lambda_1 = lbda_1, lambda_2 = lbda_2, tau = tau, kappa = 0.4 , alpha_1 = 10e-4 , alpha_2 = 1.0,beta_0 = 0.5, beta_1 = 1.9, beta_2 = 5, theta_lower = 0.1,  theta_upper = 5, eta = 0.002)

In [None]:
# Sanity check: length of the gamma vector returned by SSN.
len(w_gamma)

In [None]:
import matplotlib.gridspec as gridspec

plt.clf()

fig = plt.figure(figsize=(5, 5))

# 3 x 3 grid without spacing: top strip = source density, left strip = target
# density, remaining 2 x 2 area = transport maps.
gs = gridspec.GridSpec(3, 3, wspace=0.0, hspace=0.0)

xp, yp = np.where(P > 0)

na, nb = P.shape

xa = np.arange(na)
xb = np.arange(nb)

# Discrete reference map: for every source grid index, the target index
# carrying the largest mass in the plan P.
Txa = np.argmax(P, 1)

# Grid index of every sample, clipped to [0, n - 1]: a sample at or above 1
# would otherwise index past the end of the density array, and a negative
# sample would silently wrap around to the other end.
X_idx = np.clip((n * X).astype(int), 0, n - 1)
Y_idx = np.clip((n * Y).astype(int), 0, n - 1)


# Top strip: source density.
ax1 = plt.subplot(gs[0, 1:])
ax1.plot(xa, mu,  'r', label='Source distribution')
ax1.fill_between(xa, mu, color = 'red', alpha=.1)
plt.ylim(ymin=0)
plt.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, right=False, left=False, labelleft=False)
ax1.axis('off')

# Left strip: target density drawn sideways (density on the x-axis).
# NOTE(review): fill_between is called with the density as x and the grid
# index as y; fill_betweenx may have been intended -- confirm visually.
ax2 = plt.subplot(gs[1:, 0])
ax2.plot((nu), xb, 'b', label='Target distribution')
ax2.fill_between((nu)[:], xb[:], color = 'blue', interpolate=True, alpha = .1)
ax2.set_xlim(xmin=0)
ax2.invert_xaxis()

ax2.axis('off')


# Main panel shares its axes with the two density strips.
ax3 = plt.subplot(gs[1:, 1:], sharex=ax1, sharey=ax2)
ax3.tick_params(axis='both', which='both', bottom=False, top=False, labelbottom=False, right=False, left=False, labelleft=False)

ax3.plot(xa, Txa , linewidth = 2, color = 'black', ls ='--', label='True map')

# Map inferred from the SSN solution w_gamma, evaluated on a [0, 1] grid and
# rescaled by na to grid-index units for plotting.
x = np.linspace(0, 1, len(xa))
TX = transport_1D(x, w_gamma, X, X_fill, lbda_2, kernel=kernel, l=l)

ax3.plot(xa, TX * na, color = 'r', lw =2, label = 'Inferred map')

# Samples, drawn on the density curves at their (clipped) grid position.
ax1.scatter(X * na, mu[X_idx], label = 'mu samples', marker = 'x', color = 'r', s=50)
ax2.scatter( nu[Y_idx], Y * na,
            label = 'nu samples', marker = 'x', color = 'b', s = 50)


ax3.scatter(sob[:, 0] * na, sob[:, 1] * na, color = 'violet', s =20, label = 'Filling samples')


plt.tight_layout()
plt.legend(fontsize = 14)

plt.show()

In [None]:
# Transport-cost estimate for the SSN solution; this overwrites the value
# previously computed from the interior-point solution G.
kernel_sos_ot = transport_cost(w_gamma, Kx2, Kx3, Ky2, Ky3, lbda_2, product_sampling=False)

# JAX

In [None]:
# Library
import jax.numpy as jnp
from jax.scipy.sparse.linalg import cg
from jax.numpy.linalg import norm
import jax

def gradient(gamma, X, Phi, Q, z, lambda_1, lambda_2):
    """Evaluate the two gradient blocks at (gamma, X) with jax.numpy.

    Returns
    -------
    (g_gamma, g_X) where
        g_gamma = (Q @ gamma - z) / (2 * lambda_2) - diag(Phi.T @ X @ Phi)
        g_X     = -(Phi @ diag(gamma) @ Phi.T + lambda_1 * I)
    and the identity has size len(z).
    """
    dim = len(z)

    # gamma block: data-fit term minus the diagonal of the quadratic form.
    quad_diag = jnp.diag(jnp.dot(Phi.T, jnp.dot(X, Phi)))
    data_term = (jnp.dot(Q, gamma) - z) / (2 * lambda_2)
    g_gamma = data_term - quad_diag

    # X block: negated gamma-weighted Gram matrix plus ridge term.
    weighted_gram = jnp.dot(Phi, jnp.dot(jnp.diag(gamma), Phi.T))
    ridge = lambda_1 * jnp.eye(dim)
    g_X = - (weighted_gram + ridge)

    return g_gamma, g_X

def T_operator(P_positive, P_negative, S, mu, epsilon):
    '''
    Evaluate the T operator at the matrix S (decomposition of page 9).

    With U = P_positive^T S, the function forms
        G = P_positive ( 1/(2*mu) * (U P_positive) P_positive^T
                         + (epsilon * (U P_negative)) P_negative^T )
    and returns the symmetrised result G + G^T.  `epsilon` multiplies
    U P_negative element-wise, so it must broadcast against that product.
    '''
    projected = jnp.transpose(P_positive) @ S

    # Contribution of the (positive, positive) block, scaled by 1 / (2 * mu).
    pos_block = (1 / (2 * mu) * (projected @ P_positive)) @ jnp.transpose(P_positive)
    # Contribution of the (positive, negative) block, weighted by epsilon.
    neg_block = (epsilon * (projected @ P_negative)) @ jnp.transpose(P_negative)

    half = P_positive @ (pos_block + neg_block)
    return half + jnp.transpose(half)

def build_A_matrix(Phi):
    '''
    Given a matrix Phi, this function constructs the matrix A as given in page 8 of the article.

    Row i of A is the Kronecker product of column i of Phi with itself:
        A[i, j * r + k] = Phi[j, i] * Phi[k, i],   r = Phi.shape[0].

    Parameters
    ----------
    Phi : array-like of shape (r, c).

    Returns
    -------
    A : floating-point jax array of shape (c, r * r).

    The previous implementation allocated A as (c, c * c), so it only worked
    for square Phi, and it rebuilt the whole array once per row through
    `.at[...].set(...)`.  All rows are now produced in a single einsum.
    '''
    Phi = jnp.asarray(Phi, dtype=float)
    n_rows, n_cols = Phi.shape

    # outer[i, j, k] = Phi[j, i] * Phi[k, i]; flattening (j, k) row-major gives
    # exactly kron(Phi[:, i], Phi[:, i]) for every column i at once.
    outer = jnp.einsum('ji,ki->ijk', Phi, Phi)
    return outer.reshape(n_cols, n_rows * n_rows)

def residue(gamma, X, Phi, Q, z, lambda_1, lambda_2):
    """Residual pair (r_gamma, r_X) at the point (gamma, X).

    r_gamma is the gamma-gradient returned by `gradient`.
    r_X is X minus the positive-semidefinite part of (X - g_X): the matrix
    X - g_X is eigendecomposed, its negative eigenvalues are set to zero and
    the matrix is rebuilt as V @ diag(max(D, 0)) @ V.T.

    Returns
    -------
    (r_gamma, r_X) with r_X the same shape as X.
    """
    r_gamma, g_X = gradient(gamma, X, Phi, Q, z, lambda_1, lambda_2)

    eigvals, eigvecs = jnp.linalg.eigh(X - g_X)

    # Scale each eigenvector column by its clipped eigenvalue.  The clipped
    # eigenvalues must act as a diagonal matrix; multiplying by the bare 1-D
    # vector (as before) produced a vector that was then silently broadcast
    # against X.
    X_psd = jnp.dot(eigvecs * jnp.maximum(eigvals, 0), eigvecs.T)

    r_X = X - X_psd
    return r_gamma, r_X

def SSN_step(r_1, r_2, gamma, X, mu, Q, Phi, lambda_1, lambda_2, A):
    """Compute one semismooth-Newton direction at the point (gamma, X).

    Parameters
    ----------
    r_1, r_2 : current residuals (vector block and matrix block).
    gamma, X : current iterate (vector of length m, m x m matrix).
    mu : regularisation scalar of the Newton system.
    Q, Phi : problem matrices; lambda_1, lambda_2 : regularisation scalars.
    A : matrix multiplied against T_k below (built by build_A_matrix in this
        notebook).

    Returns
    -------
    (delta_omega_1, delta_omega_2) : the step for gamma and the step for X.

    The T operator is applied in its spectral form
        T(S) = P (Psi * (P^T S P)) P^T
    at all three call sites.  Previously the last call site first ran
    T_operator and then applied the spectral form to that result, i.e. applied
    T twice, unlike the other two call sites and the NumPy version.
    """
    m = len(gamma)
    # Matrix whose spectrum drives the generalized Jacobian.
    Z = X - (jnp.dot(Phi, jnp.dot(jnp.diag(gamma), Phi.T)) + lambda_1 * jnp.eye(m))
    sigma, P = jnp.linalg.eigh(Z)
    # Split the spectrum by sign (zero eigenvalues go with the negative set).
    alpha_pos = sigma[sigma > 0]
    alpha_neg = sigma[sigma <= 0]

    # Construction of Omega matrix: ones on the leading n x n block, eta_ on
    # the two off-diagonal blocks, zeros elsewhere.
    # NOTE(review): eigh returns eigenvalues in ascending order, so the
    # positive-eigenvalue columns are the LAST columns of P, whereas Omega and
    # Psi put the "positive" block first. Confirm the ordering is consistent.
    # NOTE(review): a / (b - a) is negative for a > 0 >= b; confirm the sign
    # convention against the reference.
    n = len(alpha_pos)
    Omega = jnp.zeros((m, m))
    Omega = Omega.at[:n, :n].set(1)
    eta_ = jnp.array([[a / (b-a) for b in alpha_neg] for a in alpha_pos])
    Omega = Omega.at[:n, n:].set(eta_)
    Omega = Omega.at[n:, :n].set(eta_.T)

    # Construction of Gamma and L: Gamma is the (m*m) x (m*m) diagonal matrix
    # holding the flattened Omega.T; L rescales it element-wise.
    Gamma = jnp.diag(jnp.ravel(Omega.T))
    L = Gamma / (mu+1-Gamma)

    # Construction of Psi matrix from Omega
    Psi = jnp.zeros((m, m))
    Epsilon = eta_/(mu + 1 - eta_)
    Psi = Psi.at[:n, :n].set((1/mu) * jnp.ones((n, n)))
    Psi = Psi.at[:n, n:].set(Epsilon)
    Psi = Psi.at[n:, :n].set(Epsilon.T)

    def _apply_T(S):
        # Spectral form of the T operator: weight P^T S P element-wise by Psi.
        return P @ (Psi * (jnp.dot(jnp.dot(P.T, S), P))) @ P.T

    # Compute the value of T operator at r_2 and add r_2
    T_r = r_2 + _apply_T(r_2)

    # Calculation of a_1 and a_2 (right-hand sides of the two block equations)
    a_1 = -r_1 - (1 / (1 + mu)) * jnp.diag(jnp.dot(Phi.T, jnp.dot(T_r, Phi)))
    a_2 = -r_2

    # Solving the linear equations system
    # kron(P, P) is a dense (m*m) x (m*m) matrix, so memory grows like m**4.
    P_tilde = jnp.kron(P,P)
    T_k = jnp.dot(P_tilde, jnp.dot(L, P_tilde.T))
    M = Q/(2 * lambda_2) + mu * jnp.eye(m) + jnp.dot(jnp.dot(A, T_k), A.T)
    # cg returns (solution, info); only the solution is kept.
    a_tilde_1 = cg(M , a_1)[0]
    T_a_2 = _apply_T(a_2)
    a_tilde_2 = 1/(mu + 1) * ( a_2 + T_a_2)

    # Calculation of delta_omega_1 and delta_omega_2
    delta_omega_1 = a_tilde_1
    Phi_star_a_tilde_1 = jnp.dot(Phi, jnp.dot(jnp.diag(a_tilde_1), Phi.T))
    # T is applied once, directly to Phi diag(a_tilde_1) Phi^T.
    T_Phi_star = _apply_T(Phi_star_a_tilde_1)
    delta_omega_2 = a_tilde_2 - T_Phi_star

    return delta_omega_1, delta_omega_2

In [None]:
import jax.numpy as jnp
from jax import random
from tqdm import tqdm

# Defining the SSN function (JAX version)
def SSN(data, lambda_1, lambda_2, tau, kappa, alpha_1, alpha_2,
        beta_0, beta_1, beta_2, theta_lower, theta_upper, eta):
    """Run the combined extragradient (EG) / semismooth-Newton (SSN) iteration.

    Parameters
    ----------
    data : dict read through the keys 'Phi', 'M', 'KX1', 'KY1', 'KX2', 'KY2'.
    lambda_1, lambda_2 : regularisation scalars.
    tau, kappa : accepted but not used anywhere in the body.
    alpha_1, alpha_2 : thresholds compared against rho / ||delta||^2 when
        updating theta.
    beta_0, beta_1, beta_2 : multiplicative factors applied to theta.
    theta_lower, theta_upper : bounds used in the theta update.
    eta : step size of the extragradient step.

    Returns
    -------
    (w_gamma, w_X) after at most nIter = 300 iterations; the loop stops early
    once the residual norm of (w_gamma, w_X) drops below 0.005.  The residual
    norm is printed at every iteration.
    """
    
    # Verify if the user has provided credible hyperparameters
    valid_hyperparameters = (alpha_2 >= alpha_1) and \
                        (alpha_1 > 0) and \
                        (0 < beta_0 < 1) and \
                        (0 < beta_1 < 1) and \
                        (beta_2 > 1) and \
                        (theta_lower > 0) and \
                        (theta_upper > 0) and \
                        (lambda_1 > 0) and \
                        (lambda_2 > 0)
    # NOTE(review): the check above is unconditionally overridden here, so the
    # "invalid" branch below is unreachable. The calls in this notebook pass
    # beta_1 = 1.9, which the check would reject.
    valid_hyperparameters = True
    if not valid_hyperparameters:
        print("Invalid hyperparameters provided.")
        return None

    else:
        # Information from data
        Phi = data['Phi']
        m = len(data['M'])
        Q = data['KX1'] + data['KY1']
        z = jnp.mean(data['KX2'], axis=1) + jnp.mean(data['KY2'], axis=1) - 2 * lambda_2 * data['M']
        A = build_A_matrix(Phi)

        # Initialization
        nIter = 300
        w_gamma = jnp.ones(m) / m
        key = random.PRNGKey(0)
        # The same key is used for both draws, so both factors are the same
        # matrix and w_X is of the form G @ G.T.
        # NOTE(review): the NumPy version multiplies two independent draws --
        # confirm which initialisation is intended.
        w_X = jnp.dot(random.normal(key, (m, m)), random.normal(key, (m, m)).T)
        v_gamma = w_gamma
        v_X = w_X
        theta = 0.5

        def update_step(v_gamma, v_X, w_gamma, w_X, theta):
            # One full iteration; closes over Phi, Q, z, A and the
            # hyperparameters of the enclosing call.
            # One-step EG to update v: Line 4 algo

            # compute the mid-point step
            g_gamma, g_X = gradient(v_gamma, v_X, Phi, Q, z, lambda_1, lambda_2)
            v_gamma_mid = v_gamma - eta * g_gamma
            v_X_mid = v_X + eta * g_X
            # compute the extra step (gradient taken at the mid-point)
            g_gamma, g_X = gradient(v_gamma_mid, v_X_mid, Phi, Q, z, lambda_1, lambda_2)
            v_gamma = v_gamma - eta * g_gamma
            v_X = v_X + eta * g_X

            # One-step SSN to update w: Line 5 and 6
            w_r_1, w_r_2 = residue(w_gamma, w_X, Phi, Q, z, lambda_1, lambda_2)
            # Regularisation of the Newton system scales with the residual norm.
            mu = theta * (jnp.linalg.norm(w_r_1) + jnp.linalg.norm(w_r_2, ord='fro'))
            delta_1, delta_2 = SSN_step(w_r_1, w_r_2, w_gamma, w_X, mu, Q, Phi, lambda_1, lambda_2, A)

            # Compute w_tilde: line 7 algo
            w_tilde_1 = w_gamma + delta_1
            w_tilde_2 = w_X + delta_2

            # Update theta in the adaptive manner: line 8 algo
            r_1, r_2 = residue(w_tilde_1, w_tilde_2, Phi, Q, z, lambda_1, lambda_2)
            # Stack vector and flattened matrix parts into single column vectors.
            r = jnp.vstack((jnp.reshape(r_1, (-1, 1)), jnp.ravel(r_2).reshape(-1, 1)))
            delta = jnp.vstack((jnp.reshape(delta_1, (-1, 1)), jnp.ravel(delta_2).reshape(-1, 1)))
            # rho is a (1, 1) array: minus the inner product of r and delta.
            rho = - jnp.matmul(jnp.transpose(r), delta)
            delta_norm = jnp.linalg.norm(delta, ord='fro') ** 2

            if rho >= alpha_2 * delta_norm :
                theta = max(theta_lower, beta_0 * theta)
            elif rho >= alpha_1 * delta_norm :
                theta = beta_1 * theta
            else:
                theta = min(theta_upper, beta_2 * theta)

            # Line 9: keep the Newton candidate or fall back to the EG iterate.
            # NOTE(review): the right-hand side is the norm of the iterate
            # (v_gamma, v_X) itself, not of its residual; confirm this is the
            # intended acceptance test.
            if (jnp.linalg.norm(r_1) + jnp.linalg.norm(r_2, ord='fro')) <= (jnp.linalg.norm(v_gamma) + jnp.linalg.norm(v_X, ord='fro')):
                w_gamma = w_tilde_1
                w_X = w_tilde_2
            else:
                w_gamma = v_gamma
                w_X = v_X

            return v_gamma, v_X, w_gamma, w_X, theta
        
        for _ in tqdm(range(nIter)):
            # Execute the update step
           v_gamma, v_X, w_gamma, w_X, theta = update_step(v_gamma, v_X, w_gamma, w_X, theta)
           # Residual of the accepted iterate, used as the stopping criterion.
           r_1, r_2 = residue(w_gamma, w_X, Phi, Q, z, lambda_1, lambda_2)
           norm_r = (jnp.linalg.norm(r_1) + jnp.linalg.norm(r_2, ord='fro'))
           print(norm_r)
           if  norm_r  < 0.005:
               break

        return w_gamma, w_X


In [None]:
# Run the JAX version of the SSN solver; alpha_1 = 10e-6 is 1e-5.
# NOTE(review): beta_1 = 1.9 lies outside the (0, 1) range encoded by SSN's own
# hyperparameter check; the call only goes through because that check is
# overridden inside SSN.
w_gamma, w_X = SSN(data, lambda_1 = lbda_1, lambda_2 = lbda_2, tau = tau, kappa = 0.4 , alpha_1 = 10e-6 , alpha_2 = 1.0,beta_0 = 0.5, beta_1 = 1.9, beta_2 = 5, theta_lower = 0.1,  theta_upper = 5, eta = 0.002)