In [1]:
import sys
sys.path.insert(1, '../../')
from SPD_SURE_pytorch import *
import numpy as np
import pandas as pd
from timeit import default_timer as timer
import multiprocessing
from joblib import Parallel, delayed
import pickle
from plotnine import *
import scipy.linalg as sla
from datetime import datetime

In [3]:
def check_SPD(X):
    # X be a n x N x N array
    # check if X[i]'s are SPD
    n = X.shape[0]
    N = X.shape[1]
    I = np.eye(N)
    res = np.zeros(X.shape)
    for i in range(n):
        min_eigval = np.min(np.linalg.eigvalsh(X[i]))
        if min_eigval < 0:
            res[i] = X[i] + (abs(min_eigval) + 1e-3)*I
        else:
            res[i] = X[i]
            
    return res

In [8]:
def exp_rs_fMRI(N, M, Sigma, ran_seed = 0, verbose = False):
    n = 5
    
    mat = np.load('connectivity_matrix.npz')  
    names = ['TD', 'ADHD_C', 'ADHD_I', 'H', 'P', 'CON', 'PSP']
    #N = 10 # number of regions/nodes 
    p = len(names)
    
    res = {'risk_M': pd.Series([0, 0, 0], 
                index = ['FM_logE', 'SURE', 'SURE_full']),
           'risk_Sig': pd.Series([0, 0, 0], 
                index = ['FM_logE', 'SURE', 'SURE_full'])} 
    res = pd.DataFrame(res)
    

    
    X = np.zeros((p, n, N, N))
    
    for i, name in enumerate(names):
        tmp = mat['con_mat_' + name]
        region = mat[name + '_region'][0:N]
        ind = np.random.choice(tmp.shape[0], n, replace = True)
        X[i] = check_SPD(tmp[ind][:, :, region][:, region])
        
        
    ## log-Euclideam mean
    M_logE = np.array([FM_logE(X[i]) for i in range(p)])
    #logX = np.array([vec(X[i]) for i in range(p)])

    
    S_logE = (n-1)*np.array([cov_logE(X[i]) for i in range(p)])
    S_eigval = np.linalg.eigh(S_logE)[0]

    ## SURE (mean only)
    lam_hat, mu_hat, M_SURE = SURE_const(M_logE, np.mean(S_eigval, axis = 1)/(n*(n-1)), verbose = verbose)
    
    ## SURE (mean and covariance)
    lam_hat, mu_hat, nu_hat, Psi_hat, M_SURE_full, Sig_SURE_full = SURE_full(M_logE, S_logE, n, verbose = verbose)

    ## risk
    res.loc['FM_logE', 'risk_M'] = loss(M, M_logE)
    res.loc['SURE', 'risk_M'] = loss(M, M_SURE)
    res.loc['SURE_full', 'risk_M'] = loss(M, M_SURE_full)
    res.loc['FM_logE', 'risk_Sig'] = np.sum((S_logE/(n-1)-Sigma)**2)/p
    res.loc['SURE', 'risk_Sig'] = np.sum((S_logE/(n-1)-Sigma)**2)/p
    res.loc['SURE_full', 'risk_Sig'] = np.sum((Sig_SURE_full-Sigma)**2)/p

    return res.values

In [10]:
mat = np.load('connectivity_matrix.npz')  
names = ['TD', 'ADHD_C', 'ADHD_I', 'H', 'P', 'CON', 'PSP']

p = len(names)
q = int(N*(N + 1)/2)


M = np.zeros((p, N, N))
Sigma = np.zeros((p, q, q))


for i, name in enumerate(names):
    tmp = mat['con_mat_' + name]
    region = mat[name + '_region'][0:N]
    tmp1 = check_SPD(tmp[:, :, region][:, region])
    #print(tmp[:, :, region][:, region].shape)
    M[i] = FM_logE(tmp1)
    #print(M[i].shape)
    Sigma[i] = cov_logE(tmp1)

Sigma = check_SPD(Sigma)

n = 5
m = 50 # repetition
N_vec = np.array([3, 5, 7, 10])
ran_seed = 12345

N = 3



num_cores = -1

results = Parallel(n_jobs=num_cores)(delayed(exp_rs_fMRI)(N, M, Sigma, ran_seed + i) for i in range(m))

In [11]:
m = 100 # repetition
N_vec = np.array([3, 5, 7, 10])
ran_seed = 12345



out_file = 'rs-fMRI_exp.p'

risk_M = pd.DataFrame(np.zeros((len(N_vec), 4)))
risk_M.columns = ['N', 'FM_LogE', 'SURE', 'SURE_full']
risk_M_sd = risk_M.copy()
risk_Sig = risk_M.copy()
risk_Sig_sd = risk_M.copy()
r_ind = 0


for N in N_vec:
    #Compute the papulation mean/cov
    mat = np.load('connectivity_matrix.npz')  
    names = ['TD', 'ADHD_C', 'ADHD_I', 'H', 'P', 'CON', 'PSP']

    p = len(names)
    q = int(N*(N + 1)/2)


    M = np.zeros((p, N, N))
    Sigma = np.zeros((p, q, q))


    for i, name in enumerate(names):
        tmp = mat['con_mat_' + name]
        region = mat[name + '_region'][0:N]
        tmp1 = check_SPD(tmp[:, :, region][:, region])
        #print(tmp[:, :, region][:, region].shape)
        M[i] = FM_logE(tmp1)
        #print(M[i].shape)
        Sigma[i] = cov_logE(tmp1)

    Sigma = check_SPD(Sigma)

    ###############################################################
    print('N = ', N)
    #results = np.zeros((m, 3, 2))
    #for i in range(m):
    #    results[i] = exp_rs_fMRI(n, N, mat, M, Sigma, ran_seed + i)
    num_cores = -1

    results = Parallel(n_jobs=num_cores)(delayed(exp_rs_fMRI)(N, M, Sigma, ran_seed + i) for i in range(m))

    res = np.mean(np.array(results), axis = 0)
    res_sd = np.std(np.array(results), axis = 0)/np.sqrt(m)
    res = pd.DataFrame(res, index = ['FM_logE', 'SURE', 'SURE_full'])
    res_sd = pd.DataFrame(res_sd, index = ['FM_logE', 'SURE', 'SURE_full'])
    res.columns = ['risk_M', 'risk_Sig']
    res_sd.columns = ['risk_M', 'risk_Sig']
    risk_M.values[r_ind] = np.array([N,
        res.loc['FM_logE', 'risk_M'], 
        res.loc['SURE', 'risk_M'],
        res.loc['SURE_full', 'risk_M']])                
    risk_Sig.values[r_ind] = np.array([N,
        res.loc['FM_logE', 'risk_Sig'], 
        res.loc['SURE', 'risk_Sig'],
        res.loc['SURE_full', 'risk_Sig']])                             
    risk_M_sd.values[r_ind] = np.array([N,
        res_sd.loc['FM_logE', 'risk_M'], 
        res_sd.loc['SURE', 'risk_M'],
        res_sd.loc['SURE_full', 'risk_M']])                
    risk_Sig_sd.values[r_ind] = np.array([N,
        res_sd.loc['FM_logE', 'risk_Sig'], 
        res_sd.loc['SURE', 'risk_Sig'],
        res_sd.loc['SURE_full', 'risk_Sig']])                            
    r_ind += 1
    print('Success!')

#results = np.zeros((m, 3, 2))
#for i in range(m):
#    results[i] = exp_rs_fMRI(n, N, M, Sigma, ran_seed + i)
#print(exp_rs_fMRI(n, mat, M, Sigma, ran_seed))

#num_cores = -1
#results = Parallel(n_jobs=num_cores)(delayed(exp_rs_fMRI)(n, N_vec[0], M, Sigma, ran_seed + i) for i in range(m))
#pickle.dump({'N':N_vec, 'risk_M':risk_M, 'risk_Sig':risk_Sig, 
#    'risk_M_sd':risk_M_sd, 'risk_Sig_sd':risk_Sig_sd}, open(out_file, 'wb'))    

#print('res: ', res)
#print('res_sd: ', res_sd)

now = datetime.now()
dt_string = now.strftime("%d/%m/%Y %H:%M:%S")
print("Done @ ", dt_string)

print('risk_M: \n', risk_M)
print('\n')
print('risk_Sig: \n', risk_Sig)
print('\n')
print('risk_M_sd: \n', risk_M_sd)
print('\n')
print('risk_Sig_sd: \n', risk_Sig_sd)
print('\n')

N =  3
Success!
N =  5
Success!
N =  7


  w = np.expand_dims(np.log(w), axis=-1)
  w = np.expand_dims(np.log(w), axis=-1)
  w = np.expand_dims(np.log(w), axis=-1)
  w = np.expand_dims(np.log(w), axis=-1)
  w = np.expand_dims(np.log(w), axis=-1)
  w = np.expand_dims(np.log(w), axis=-1)
  w = np.expand_dims(np.log(w), axis=-1)
  w = np.expand_dims(np.log(w), axis=-1)


Success!
N =  10
Success!
Done @  19/01/2022 13:53:05
risk_M: 
       N   FM_LogE      SURE  SURE_full
0   3.0  0.409936  0.316225   0.236874
1   5.0  0.761145  0.626729   0.530997
2   7.0  1.362758  1.084409   0.978232
3  10.0  1.947868  1.797288   1.585236


risk_Sig: 
       N    FM_LogE       SURE  SURE_full
0   3.0   8.160027   8.160027   6.851964
1   5.0  13.807250  13.807250  11.490516
2   7.0  31.062862  31.062862  25.820230
3  10.0  45.469828  45.469828  38.030115


risk_M_sd: 
       N   FM_LogE      SURE  SURE_full
0   3.0  0.026528  0.019504   0.009990
1   5.0  0.031571  0.019733   0.015789
2   7.0  0.038421  0.019825   0.017887
3  10.0  0.050607  0.040007   0.036043


risk_Sig_sd: 
       N   FM_LogE      SURE  SURE_full
0   3.0  0.416185  0.416185   0.319874
1   5.0  0.650220  0.650220   0.521244
2   7.0  0.736314  0.736314   0.578092
3  10.0  0.959553  0.959553   0.793542


