In [1]:
# Libraries
import numpy as np
from scipy.sparse.linalg import cg
from numpy.linalg import norm
from scipy.io import loadmat

In [2]:
def gradient(gamma, X, Phi, Q, z, lambda_1, lambda_2):
    '''
    Partial gradients of the saddle function at (gamma, X).

    Parameters
    ----------
    gamma : vector of length n (one entry per column of Phi).
    X : p x p matrix, where p = Phi.shape[0].
    Phi : p x n matrix whose columns phi_i define Phi diag(gamma) Phi^T.
    Q : n x n matrix; z : vector of length n.
    lambda_1, lambda_2 : regularization weights.

    Returns
    -------
    g_gamma : (Q gamma - z) / (2 lambda_2) - diag(Phi^T X Phi), length n.
    g_X : -(Phi diag(gamma) Phi^T + lambda_1 I), p x p (same shape as X).
    '''
    # g_X must match X's shape, which is set by the rows of Phi (not by len(z)).
    p = Phi.shape[0]

    # diag(Phi^T X Phi) computed column-wise, without forming the full n x n product.
    diag_H = np.sum(Phi * (X @ Phi), axis=0)
    g_gamma = (Q @ gamma - z) / (2 * lambda_2) - diag_H

    # Phi diag(gamma) Phi^T == (Phi scaled column-wise by gamma) Phi^T.
    g_X = -((Phi * gamma) @ Phi.T + lambda_1 * np.eye(p))

    return g_gamma, g_X

In [3]:
def T_operator(P_positive, P_negative, S, mu, epsilon):
    '''
    Evaluate the T operator at the matrix S, using the decomposition given in page 9.

    With P = [P_positive, P_negative] and
    Psi = [[1/mu, epsilon], [epsilon^T, 0]], this returns
    P (Psi o (P^T S P)) P^T for symmetric S, built as G + G^T so that only
    products against P_positive are needed.

    epsilon is the (n_pos x n_neg) cross block of Psi; it multiplies elementwise.
    '''
    # Project S onto the positive eigen-directions once; both terms reuse it.
    left = P_positive.T @ S
    pos_block = (1 / (2 * mu)) * (left @ P_positive)
    cross_block = epsilon * (left @ P_negative)
    half = P_positive @ (pos_block @ P_positive.T + cross_block @ P_negative.T)
    return half + half.T

In [4]:
def build_A_matrix(Phi):
    '''
    Given a matrix Phi, this function constructs the matrix A as given in page 8 of the article.

    Row i of A is kron(phi_i, phi_i) == ravel(phi_i phi_i^T) (row-major), with
    phi_i = Phi[:, i]. Consequently A @ S.ravel() == diag(Phi.T @ S @ Phi) for
    any p x p matrix S.

    Parameters
    ----------
    Phi : array of shape (p, n).

    Returns
    -------
    A : float64 array of shape (n, p * p).
    '''
    Phi = np.asarray(Phi)
    p, n = Phi.shape
    # A[i, k * p + l] = Phi[k, i] * Phi[l, i]. Row width is p * p (rows of Phi
    # squared), which differs from n * n whenever Phi is not square.
    return np.einsum('ki,li->ikl', Phi, Phi).reshape(n, p * p).astype(np.float64)


In [5]:
def residue(gamma, X, Phi, Q, z, lambda_1, lambda_2):
    '''
    Natural residual R(gamma, X) of the saddle-point problem.

    Returns
    -------
    r_gamma : g_gamma, the gamma-gradient (gamma is unconstrained).
    r_X : X - Proj_PSD(X + g_X), zero iff X is a projected-ascent fixed point.

    The X-block projects X + g_X = X - (Phi diag(gamma) Phi^T + lambda_1 I),
    the same matrix Z whose eigen-decomposition SSN_step linearizes, so the
    Newton directions it computes target this residual.
    '''
    r_gamma, g_X = gradient(gamma, X, Phi, Q, z, lambda_1, lambda_2)
    # Ascent in X (X is maximized). Projecting X - g_X instead would add the
    # positive-definite lambda_1 I term, so that residual could never vanish.
    D, V = np.linalg.eigh(X + g_X)
    # Scale eigenvector columns by the clipped eigenvalues: V diag(max(D, 0)) V^T.
    X_proj = (V * np.maximum(D, 0)) @ V.T
    r_X = X - X_proj
    return r_gamma, r_X

In [6]:
def SSN_step(r_1, r_2, gamma, X, mu, Q, Phi, lambda_1, lambda_2, A):
    '''
    Compute one regularized semismooth-Newton direction (delta_gamma, delta_X).

    The X-block of the mu-regularized Newton system for the residual
    (r_1, r_2) is eliminated with the T operator. The remaining n x n system
    for delta_gamma is solved by conjugate gradient.

    Parameters
    ----------
    r_1, r_2 : residual blocks at (gamma, X) (vector of length n, p x p matrix).
    gamma, X : current iterate.
    mu : regularization parameter (> 0).
    Q, Phi, lambda_1, lambda_2 : problem data, as in `gradient`.
    A : (n, p * p) matrix from `build_A_matrix(Phi)`; row i is ravel(phi_i phi_i^T).

    Returns
    -------
    delta_omega_1 : Newton direction for gamma.
    delta_omega_2 : Newton direction for X.
    '''
    n = len(gamma)
    p = Phi.shape[0]

    # Point where the PSD projection in the X-residual is evaluated:
    # Z = X + g_X = X - (Phi diag(gamma) Phi^T + lambda_1 I).
    Z = X - ((Phi * gamma) @ Phi.T + lambda_1 * np.eye(p))
    sigma, P = np.linalg.eigh(Z)
    pos = sigma > 0

    # Omega: generalized-Jacobian weights of the PSD projection at Z, indexed
    # like the columns of P. eigh sorts eigenvalues ascending, so the positive
    # ones are the trailing columns, not a leading block. Each entry is:
    #   1 if both eigenvalues are positive,
    #   0 if both are non-positive,
    #   s_+ / (s_+ - s_-), which lies in (0, 1], for a mixed pair.
    s_i, s_j = sigma[:, None], sigma[None, :]
    pos_i, pos_j = pos[:, None], pos[None, :]
    mixed = pos_i != pos_j
    s_hi = np.maximum(s_i, s_j)
    s_lo = np.minimum(s_i, s_j)
    # The dummy denominator 1 outside mixed pairs avoids 0/0 for repeated eigenvalues.
    cross = s_hi / np.where(mixed, s_hi - s_lo, 1.0)
    Omega = np.where(pos_i & pos_j, 1.0, np.where(mixed, cross, 0.0))

    # Psi = Omega / (mu + 1 - Omega) gives 1/mu on the positive block, 0 on the
    # non-positive block, and the non-negative cross block Epsilon.
    Psi = Omega / (mu + 1 - Omega)

    P_positive = P[:, pos]
    P_negative = P[:, ~pos]
    # (n_pos, n_neg) block consumed by T_operator. It stays 2-D even when one side is empty.
    Epsilon = Psi[pos][:, ~pos]

    # Right-hand side after eliminating the X-block.
    T_r = r_2 + T_operator(P_positive, P_negative, r_2, mu, Epsilon)
    a_1 = -r_1 - (1 / (1 + mu)) * np.diag(Phi.T @ T_r @ Phi)
    a_2 = -r_2

    # Schur complement A T A^T, with T(S) = P (Psi o (P^T S P)) P^T, evaluated
    # without forming the (p^2 x p^2) operator. With S_i = reshape(A[i]) and
    # B_i = P^T S_i P, (A T A^T)_ij = sum(B_i * Psi * B_j).
    B_flat = (P.T @ A.reshape(n, p, p) @ P).reshape(n, p * p)
    schur = B_flat @ (B_flat * Psi.reshape(-1)).T
    # Psi >= 0 elementwise, so schur is PSD and M is SPD for mu > 0, as CG requires.
    M = Q / (2 * lambda_2) + mu * np.eye(n) + schur
    a_tilde_1 = cg(M, a_1, atol=1e-5)[0]

    T_a_2 = T_operator(P_positive, P_negative, a_2, mu, Epsilon)
    a_tilde_2 = (a_2 + T_a_2) / (mu + 1)

    # Back-substitute for the X-direction.
    delta_omega_1 = a_tilde_1
    Phi_star_a_tilde_1 = (Phi * a_tilde_1) @ Phi.T
    T_Phi_star = T_operator(P_positive, P_negative, Phi_star_a_tilde_1, mu, Epsilon)
    delta_omega_2 = a_tilde_2 - T_Phi_star

    return delta_omega_1, delta_omega_2

In [7]:
def SSN(data: dict, lambda_1: float, lambda_2: float, tau: float, kappa: float, alpha_1: float, alpha_2: float,
        beta_0: float, beta_1: float, beta_2: float, theta_lower: float,  theta_upper: float, eta: float,
        n_iter: int = 300, seed=None):
    '''
    Hybrid extragradient / semismooth-Newton solver (NumPy version).

    Two sequences are maintained:
    - v, advanced by a projected extragradient (EG) step;
    - w, advanced by a regularized semismooth-Newton step.
    The Newton candidate is kept only when its residual is no larger than
    that of v. The "line N" comments refer to the algorithm in the article.

    Parameters
    ----------
    data : dict with keys 'Phi', 'M', 'KX1', 'KY1', 'KX2', 'KY2' (e.g. from loadmat).
    lambda_1, lambda_2 : regularization weights (> 0).
    tau, kappa : accepted but not used by this implementation.
    alpha_1, alpha_2, beta_0, beta_1, beta_2, theta_lower, theta_upper :
        parameters of the adaptive rule for theta (mu = theta * ||R(w)||).
    eta : EG step size.
    n_iter : number of outer iterations (default 300).
    seed : seed for the random PSD initialization of X (None = unseeded).

    Returns
    -------
    (w_gamma, w_X). If the hyperparameters are invalid, a message is printed
    and None is returned.

    NOTE: this calls the NumPy gradient/residue/SSN_step/build_A_matrix; the
    JAX cells later in the notebook rebind those names.
    '''

    # Verify that the user has provided credible hyperparameters.
    valid_hyperparameters = (alpha_2 >= alpha_1) and \
                        (alpha_1 > 0) and \
                        (0 < beta_0 < 1) and \
                        (0 < beta_1 < 1) and \
                        (beta_2 > 1) and \
                        (theta_lower > 0) and \
                        (theta_upper > 0) and \
                        (lambda_1 > 0) and \
                        (lambda_2 > 0)

    if not valid_hyperparameters:
        print("Invalid hyperparameters provided.")
        return None

    def project_psd(S):
        # Euclidean projection onto the PSD cone: clip negative eigenvalues.
        D, V = np.linalg.eigh(S)
        return (V * np.maximum(D, 0)) @ V.T

    def residual_norm(r_gamma, r_X):
        return float(norm(r_gamma) + norm(r_X, 'fro'))

    # Information from data. M is flattened because MAT-file vectors may come back
    # as (m, 1) or (1, m). Either shape would make len() or the z broadcast wrong.
    Phi = data['Phi']
    M_vec = np.ravel(data['M'])
    m = M_vec.size
    p = Phi.shape[0]
    Q = data['KX1'] + data['KY1']
    z = np.mean(data['KX2'], axis=0) + np.mean(data['KY2'], axis=0) - 2 * lambda_2 * M_vec
    A = build_A_matrix(Phi)

    # Initialization: X must be symmetric PSD, hence R R^T with the SAME factor.
    rng = np.random.default_rng(seed)
    R = rng.standard_normal((p, p))
    w_gamma = np.ones(m) / m
    w_X = R @ R.T
    v_gamma = w_gamma
    v_X = w_X
    theta = 0.5

    for _ in range(n_iter):
        # Line 4: one projected EG step on v. This is descent in gamma and ascent
        # in X, with X kept in the PSD cone.
        g_gamma, g_X = gradient(v_gamma, v_X, Phi, Q, z, lambda_1, lambda_2)
        v_gamma_mid = v_gamma - eta * g_gamma
        v_X_mid = project_psd(v_X + eta * g_X)
        g_gamma, g_X = gradient(v_gamma_mid, v_X_mid, Phi, Q, z, lambda_1, lambda_2)
        v_gamma = v_gamma - eta * g_gamma
        v_X = project_psd(v_X + eta * g_X)

        # Lines 5-6: regularized semismooth-Newton direction at w.
        w_r_1, w_r_2 = residue(w_gamma, w_X, Phi, Q, z, lambda_1, lambda_2)
        w_res = residual_norm(w_r_1, w_r_2)
        if w_res == 0:
            # w solves the problem exactly; mu = 0 would divide by zero in SSN_step.
            break
        mu = theta * w_res
        delta_1, delta_2 = SSN_step(w_r_1, w_r_2, w_gamma, w_X, mu, Q, Phi, lambda_1, lambda_2, A)

        # Line 7: Newton candidate.
        w_tilde_1 = w_gamma + delta_1
        w_tilde_2 = w_X + delta_2

        # Line 8: adaptive update of theta from rho = -<R(w_tilde), delta>.
        r_1, r_2 = residue(w_tilde_1, w_tilde_2, Phi, Q, z, lambda_1, lambda_2)
        rho = -float(np.dot(r_1, delta_1) + np.sum(r_2 * delta_2))
        delta_norm = float(np.dot(delta_1, delta_1) + np.sum(delta_2 * delta_2))

        if rho >= alpha_2 * delta_norm:
            theta = max(theta_lower, beta_0 * theta)
        elif rho >= alpha_1 * delta_norm:
            theta = beta_1 * theta
        else:
            theta = min(theta_upper, beta_2 * theta)

        # Line 9: keep the Newton candidate if its residual is no larger than v's.
        v_r_1, v_r_2 = residue(v_gamma, v_X, Phi, Q, z, lambda_1, lambda_2)
        if residual_norm(r_1, r_2) <= residual_norm(v_r_1, v_r_2):
            w_gamma, w_X = w_tilde_1, w_tilde_2
        else:
            w_gamma, w_X = v_gamma, v_X

    return w_gamma, w_X

In [8]:
# Load the data from the MATLAB file (path relative to the working directory).
# SSN reads the keys 'Phi', 'M', 'KX1', 'KY1', 'KX2' and 'KY2' from this dict.
data = loadmat("X_perturbed_perturbed-drug-erlotinib.mat")

# JAX implementation: the cells below redefine (and thereby shadow) the NumPy functions above.

In [9]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax.scipy.sparse.linalg import cg
from jax.numpy.linalg import norm

In [11]:
# Defining the gradient function
def gradient(gamma, X, Phi, Q, z, lambda_1, lambda_2):
    '''
    Partial gradients of the saddle function at (gamma, X) (JAX version).

    Parameters
    ----------
    gamma : vector of length n (one entry per column of Phi).
    X : p x p matrix, where p = Phi.shape[0].
    Phi : p x n matrix; Q : n x n matrix; z : vector of length n.
    lambda_1, lambda_2 : regularization weights.

    Returns
    -------
    g_gamma : (Q gamma - z) / (2 lambda_2) - diag(Phi^T X Phi), length n.
    g_X : -(Phi diag(gamma) Phi^T + lambda_1 I), p x p (same shape as X).
    '''
    Phi = jnp.asarray(Phi)
    X = jnp.asarray(X)
    # g_X must match X's shape, which is set by the rows of Phi (not by len(z)).
    p = Phi.shape[0]

    # diag(Phi^T X Phi) computed column-wise, without forming the full n x n product.
    diag_H = jnp.sum(Phi * (X @ Phi), axis=0)
    g_gamma = (Q @ gamma - z) / (2 * lambda_2) - diag_H

    # Phi diag(gamma) Phi^T == (Phi scaled column-wise by gamma) Phi^T.
    g_X = -((Phi * gamma) @ jnp.transpose(Phi) + lambda_1 * jnp.eye(p))

    return g_gamma, g_X

In [12]:
# Defining the T_operator function
def T_operator(P_positive, P_negative, S, mu, epsilon):
    '''
    Evaluate the T operator at the matrix S, using the decomposition given in page 9 (JAX version).

    With P = [P_positive, P_negative] and
    Psi = [[1/mu, epsilon], [epsilon^T, 0]], this returns
    P (Psi o (P^T S P)) P^T for symmetric S, assembled as G + G^T.

    epsilon is the (n_pos x n_neg) cross block of Psi; it multiplies elementwise.
    '''
    # Shared projection of S onto the positive eigen-directions.
    left = jnp.transpose(P_positive) @ S
    pos_block = (1 / (2 * mu)) * (left @ P_positive)
    cross_block = epsilon * (left @ P_negative)
    half = P_positive @ (pos_block @ jnp.transpose(P_positive) + cross_block @ jnp.transpose(P_negative))
    return half + jnp.transpose(half)

In [13]:
def build_A_matrix(Phi):
    '''
    Given a matrix Phi, this function constructs the matrix A as given in page 8 of the article (JAX version).

    Row i of A is kron(phi_i, phi_i) == ravel(phi_i phi_i^T) (row-major), with
    phi_i = Phi[:, i]. Consequently A @ S.ravel() == diag(Phi.T @ S @ Phi) for
    any p x p matrix S.

    Built in one vectorized expression: jax.ops.index_update no longer exists
    in JAX, and the name `jax` is not bound by this notebook's imports.

    Parameters
    ----------
    Phi : array of shape (p, n).

    Returns
    -------
    A : array of shape (n, p * p) in JAX's default float dtype.
    '''
    Phi = jnp.asarray(Phi, dtype=float)
    p, n = Phi.shape
    # A[i, k * p + l] = Phi[k, i] * Phi[l, i]; the width is p * p, which differs
    # from n * n whenever Phi is not square.
    return jnp.einsum('ki,li->ikl', Phi, Phi).reshape(n, p * p)

In [14]:
# Defining the residue function
def residue(gamma, X, Phi, Q, z, lambda_1, lambda_2):
    '''
    Natural residual R(gamma, X) of the saddle-point problem (JAX version).

    Returns
    -------
    r_gamma : g_gamma, the gamma-gradient (gamma is unconstrained).
    r_X : X - Proj_PSD(X + g_X), zero iff X is a projected-ascent fixed point.

    The X-block projects X + g_X = X - (Phi diag(gamma) Phi^T + lambda_1 I),
    the same matrix Z whose eigen-decomposition SSN_step linearizes.
    '''
    r_gamma, g_X = gradient(gamma, X, Phi, Q, z, lambda_1, lambda_2)
    # Ascent in X (X is maximized). Projecting X - g_X instead would add the
    # positive-definite lambda_1 I term, so that residual could never vanish.
    D, V = jnp.linalg.eigh(X + g_X)
    # Scale eigenvector columns by the clipped eigenvalues: V diag(max(D, 0)) V^T.
    X_proj = (V * jnp.maximum(D, 0)) @ jnp.transpose(V)
    r_X = X - X_proj
    return r_gamma, r_X

In [15]:
# Defining the SSN_step function
def SSN_step(r_1, r_2, gamma, X, mu, Q, Phi, lambda_1, lambda_2, A):
    '''
    Compute one regularized semismooth-Newton direction (JAX version).

    The X-block of the mu-regularized Newton system for the residual
    (r_1, r_2) is eliminated with the T operator. The remaining n x n system
    for delta_gamma is solved by conjugate gradient.

    Runs eagerly only: eigenvectors are selected with boolean masks, whose
    output shapes depend on the data, so this function cannot be traced by jax.jit.

    Parameters
    ----------
    r_1, r_2 : residual blocks at (gamma, X) (vector of length n, p x p matrix).
    gamma, X : current iterate.
    mu : regularization parameter (> 0).
    Q, Phi, lambda_1, lambda_2 : problem data, as in `gradient`.
    A : (n, p * p) matrix from `build_A_matrix(Phi)`; row i is ravel(phi_i phi_i^T).

    Returns
    -------
    delta_omega_1 : Newton direction for gamma.
    delta_omega_2 : Newton direction for X.
    '''
    n = len(gamma)
    p = Phi.shape[0]

    # Point where the PSD projection in the X-residual is evaluated:
    # Z = X + g_X = X - (Phi diag(gamma) Phi^T + lambda_1 I).
    Z = X - ((Phi * gamma) @ jnp.transpose(Phi) + lambda_1 * jnp.eye(p))
    sigma, P = jnp.linalg.eigh(Z)
    pos = sigma > 0
    neg = ~pos

    # Omega: generalized-Jacobian weights of the PSD projection at Z, indexed
    # like the columns of P. eigh sorts eigenvalues ascending, so the positive
    # ones are the trailing columns. Each entry is:
    #   1 if both eigenvalues are positive,
    #   0 if both are non-positive,
    #   s_+ / (s_+ - s_-), which lies in (0, 1], for a mixed pair.
    s_i, s_j = sigma[:, None], sigma[None, :]
    pos_i, pos_j = pos[:, None], pos[None, :]
    mixed = pos_i != pos_j
    s_hi = jnp.maximum(s_i, s_j)
    s_lo = jnp.minimum(s_i, s_j)
    # The dummy denominator 1 outside mixed pairs avoids 0/0 for repeated eigenvalues.
    cross = s_hi / jnp.where(mixed, s_hi - s_lo, 1.0)
    Omega = jnp.where(pos_i & pos_j, 1.0, jnp.where(mixed, cross, 0.0))

    # Psi = Omega / (mu + 1 - Omega) gives 1/mu on the positive block, 0 on the
    # non-positive block, and a non-negative cross block. It is symmetric by construction.
    Psi = Omega / (mu + 1 - Omega)

    P_positive = P[:, pos]
    P_negative = P[:, neg]
    # (n_pos, n_neg) block consumed by T_operator. It stays 2-D even when one side is empty.
    Epsilon = Psi[pos][:, neg]

    # Right-hand side after eliminating the X-block.
    T_r = r_2 + T_operator(P_positive, P_negative, r_2, mu, Epsilon)
    a_1 = -r_1 - (1 / (1 + mu)) * jnp.diag(jnp.transpose(Phi) @ T_r @ Phi)
    a_2 = -r_2

    # Schur complement A T A^T, with T(S) = P (Psi o (P^T S P)) P^T, evaluated
    # without forming the (p^2 x p^2) operator. With S_i = reshape(A[i]) and
    # B_i = P^T S_i P, (A T A^T)_ij = sum(B_i * Psi * B_j).
    B_flat = (jnp.transpose(P) @ jnp.reshape(A, (n, p, p)) @ P).reshape(n, p * p)
    schur = B_flat @ jnp.transpose(B_flat * jnp.ravel(Psi))
    # Psi >= 0 elementwise, so schur is PSD and M is SPD for mu > 0, as CG requires.
    M = Q / (2 * lambda_2) + mu * jnp.eye(n) + schur
    a_tilde_1 = cg(M, a_1, atol=1e-5)[0]

    T_a_2 = T_operator(P_positive, P_negative, a_2, mu, Epsilon)
    a_tilde_2 = (a_2 + T_a_2) / (mu + 1)

    # Back-substitute for the X-direction.
    delta_omega_1 = a_tilde_1
    Phi_star_a_tilde_1 = (Phi * a_tilde_1) @ jnp.transpose(Phi)
    T_Phi_star = T_operator(P_positive, P_negative, Phi_star_a_tilde_1, mu, Epsilon)
    delta_omega_2 = a_tilde_2 - T_Phi_star

    return delta_omega_1, delta_omega_2

In [19]:
# Defining the SSN function
def SSN(data: dict, lambda_1: float, lambda_2: float, tau: float, kappa: float, alpha_1: float, alpha_2: float,
        beta_0: float, beta_1: float, beta_2: float, theta_lower: float,  theta_upper: float, eta: float,
        n_iter: int = 300, seed: int = 0):
    '''
    Hybrid extragradient / semismooth-Newton solver (JAX version).

    Two sequences are maintained:
    - v, advanced by a projected extragradient (EG) step;
    - w, advanced by a regularized semismooth-Newton step.
    The Newton candidate is kept only when its residual is no larger than
    that of v. The "line N" comments refer to the algorithm in the article.

    The loop runs eagerly rather than under jax.jit, for two reasons:
    - the theta rule and the acceptance test branch in Python on computed values;
    - SSN_step selects eigenvectors with boolean masks, which have
      data-dependent shapes.
    Neither can be traced.

    Parameters
    ----------
    data : dict with keys 'Phi', 'M', 'KX1', 'KY1', 'KX2', 'KY2' (e.g. from loadmat).
    lambda_1, lambda_2 : regularization weights (> 0).
    tau, kappa : accepted but not used by this implementation.
    alpha_1, alpha_2, beta_0, beta_1, beta_2, theta_lower, theta_upper :
        parameters of the adaptive rule for theta (mu = theta * ||R(w)||).
    eta : EG step size.
    n_iter : number of outer iterations (default 300).
    seed : seed of the jax.random key for the PSD initialization of X (default 0).

    Returns
    -------
    (w_gamma, w_X). If the hyperparameters are invalid, a message is printed
    and None is returned.
    '''

    # Verify that the user has provided credible hyperparameters.
    valid_hyperparameters = (alpha_2 >= alpha_1) and \
                        (alpha_1 > 0) and \
                        (0 < beta_0 < 1) and \
                        (0 < beta_1 < 1) and \
                        (beta_2 > 1) and \
                        (theta_lower > 0) and \
                        (theta_upper > 0) and \
                        (lambda_1 > 0) and \
                        (lambda_2 > 0)

    if not valid_hyperparameters:
        print("Invalid hyperparameters provided.")
        return None

    def project_psd(S):
        # Euclidean projection onto the PSD cone: clip negative eigenvalues.
        D, V = jnp.linalg.eigh(S)
        return (V * jnp.maximum(D, 0)) @ jnp.transpose(V)

    def residual_norm(r_gamma, r_X):
        return float(norm(r_gamma) + norm(r_X, 'fro'))

    # Information from data. M is flattened because MAT-file vectors may come back
    # as (m, 1) or (1, m). Either shape would make len() or the z broadcast wrong.
    Phi = jnp.asarray(data['Phi'])
    M_vec = jnp.ravel(jnp.asarray(data['M']))
    m = M_vec.shape[0]
    p = Phi.shape[0]
    Q = jnp.asarray(data['KX1']) + jnp.asarray(data['KY1'])
    z = (jnp.mean(jnp.asarray(data['KX2']), axis=0) + jnp.mean(jnp.asarray(data['KY2']), axis=0)
         - 2 * lambda_2 * M_vec)
    A = build_A_matrix(Phi)

    # Initialization: X must be symmetric PSD, hence R R^T with the SAME factor
    # (jax.random has no randn; normal(key, shape) draws standard normals).
    R = random.normal(random.PRNGKey(seed), (p, p))
    w_gamma = jnp.ones(m) / m
    w_X = R @ jnp.transpose(R)
    v_gamma = w_gamma
    v_X = w_X
    theta = 0.5

    for _ in range(n_iter):
        # Line 4: one projected EG step on v. This is descent in gamma and ascent
        # in X, with X kept in the PSD cone.
        g_gamma, g_X = gradient(v_gamma, v_X, Phi, Q, z, lambda_1, lambda_2)
        v_gamma_mid = v_gamma - eta * g_gamma
        v_X_mid = project_psd(v_X + eta * g_X)
        g_gamma, g_X = gradient(v_gamma_mid, v_X_mid, Phi, Q, z, lambda_1, lambda_2)
        v_gamma = v_gamma - eta * g_gamma
        v_X = project_psd(v_X + eta * g_X)

        # Lines 5-6: regularized semismooth-Newton direction at w.
        w_r_1, w_r_2 = residue(w_gamma, w_X, Phi, Q, z, lambda_1, lambda_2)
        w_res = residual_norm(w_r_1, w_r_2)
        if w_res == 0:
            # w solves the problem exactly; mu = 0 would divide by zero in SSN_step.
            break
        mu = theta * w_res
        delta_1, delta_2 = SSN_step(w_r_1, w_r_2, w_gamma, w_X, mu, Q, Phi, lambda_1, lambda_2, A)

        # Line 7: Newton candidate.
        w_tilde_1 = w_gamma + delta_1
        w_tilde_2 = w_X + delta_2

        # Line 8: adaptive update of theta from rho = -<R(w_tilde), delta>.
        r_1, r_2 = residue(w_tilde_1, w_tilde_2, Phi, Q, z, lambda_1, lambda_2)
        rho = -float(jnp.vdot(r_1, delta_1) + jnp.sum(r_2 * delta_2))
        delta_norm = float(jnp.vdot(delta_1, delta_1) + jnp.sum(delta_2 * delta_2))

        if rho >= alpha_2 * delta_norm:
            theta = max(theta_lower, beta_0 * theta)
        elif rho >= alpha_1 * delta_norm:
            theta = beta_1 * theta
        else:
            theta = min(theta_upper, beta_2 * theta)

        # Line 9: keep the Newton candidate if its residual is no larger than v's.
        v_r_1, v_r_2 = residue(v_gamma, v_X, Phi, Q, z, lambda_1, lambda_2)
        if residual_norm(r_1, r_2) <= residual_norm(v_r_1, v_r_2):
            w_gamma, w_X = w_tilde_1, w_tilde_2
        else:
            w_gamma, w_X = v_gamma, v_X

    return w_gamma, w_X