In [5]:
import sys
sys.path.insert(1, '../../')
from SPD_SURE_pytorch import *
from scipy.io import loadmat
from scipy.linalg import logm, expm
import numpy.linalg as nla
import cmath
import pickle
from numpy.random import choice
import multiprocessing
from joblib import Parallel, delayed
import pandas as pd

def format_data(X):
    """Move the two trailing sample axes of ``X`` to the front.

    ``X`` is expected as ``(d, d, p, n)``; the result is a new float64 array of
    shape ``(p, n, d, d)`` with ``out[i, j]`` equal to ``X[:, :, i, j]``.
    """
    dim, n_outer, n_inner = X.shape[1:]
    out = np.zeros((n_outer, n_inner, dim, dim))
    # np.ndindex walks every (i, j) pair in row-major order
    for i, j in np.ndindex(n_outer, n_inner):
        out[i, j] = X[..., i, j]
    return out

def format_mu(mu):
    """Reorder a ``(d, d, p)`` stack of matrices into ``(p, d, d)``.

    Slice ``mu[:, :, i]`` becomes ``out[i]``; the result is a new float64 array.
    """
    dim, n_mats = mu.shape[1:]
    out = np.zeros((n_mats, dim, dim))
    # bring the stacking axis to the front and copy one matrix at a time
    for idx, mat in enumerate(np.moveaxis(mu, -1, 0)):
        out[idx] = mat
    return out

def remove_non_SPD(X, eps=1e-8):
    """Project every matrix in a stack onto the SPD cone by eigenvalue clipping.

    Each matrix is decomposed with ``numpy.linalg.eigh``. ``eigh`` reads only the
    lower triangle, so the input is treated as symmetric. Eigenvalues below
    ``eps`` are raised to ``eps``, and the matrix is rebuilt as
    ``Q @ diag(max(S, eps)) @ Q.T``.

    Parameters
    ----------
    X : ndarray, shape (p, n, d, d)
        Stack of (nominally symmetric) matrices. Any number of leading batch
        axes is accepted, because ``eigh`` broadcasts over them.
    eps : float, default 1e-8
        Smallest eigenvalue allowed in the output.

    Returns
    -------
    ndarray of float64, same shape as ``X``
        Matrices whose eigenvalues are all >= ``eps`` (up to rounding).
    """
    X = np.asarray(X)
    # A single batched decomposition replaces the per-matrix Python loop.
    S, Q = nla.eigh(X)
    S_clipped = np.maximum(S, eps)
    # Q * s[..., None, :] scales column j of Q by s[j], i.e. Q @ diag(s),
    # without materialising the diagonal matrix.
    new_X = np.matmul(Q * S_clipped[..., np.newaxis, :], np.swapaxes(Q, -1, -2))
    # Output is float64, as the np.zeros buffer of the looped version was.
    return new_X.astype(float, copy=False)


def test_M(k, X, Y, ran_seed, mu_X=None, mu_Y=None, var_X=None, var_Y=None):
    """Score three mean estimators on one bootstrap replicate of two groups.

    Resamples ``k`` subjects with replacement from each group. For every tract
    it then computes three mean estimates:

    - ``FM_logE``, the plain estimator;
    - ``SURE_const``, a shrinkage estimator;
    - ``SURE_full``, a shrinkage estimator that also uses the covariance.

    Each estimate is compared with the reference means using the project's
    ``loss``.

    Parameters
    ----------
    k : int
        Bootstrap sample size drawn from each group.
    X, Y : ndarray
        Group data with tracts on axis 0 and subjects on axis 1.
    ran_seed : int
        Seed for NumPy's global RNG (``np.random.seed``). The draws depend only
        on this seed, so each replicate is reproducible.
    mu_X, mu_Y : ndarray, optional
        Reference per-tract means. If omitted, the module-level ``mu_X`` /
        ``mu_Y`` are used, which was the only behaviour before.
    var_X, var_Y : ndarray, optional
        Per-tract variances; ``var / k`` is passed to ``SURE_const``. If
        omitted, the module-level ``var_X`` / ``var_Y`` are used.

    Returns
    -------
    ndarray, shape (3,)
        ``[loss_logE, loss_SURE, loss_SURE_full]``, each summed over X and Y.
    """
    # Explicit arguments win. Otherwise read the notebook globals, so existing
    # calls of the form test_M(k, X, Y, seed) keep working.
    namespace = globals()
    if mu_X is None:
        mu_X = namespace['mu_X']
    if mu_Y is None:
        mu_Y = namespace['mu_Y']
    if var_X is None:
        var_X = namespace['var_X']
    if var_Y is None:
        var_Y = namespace['var_Y']

    p, n1 = X.shape[0:2]
    n2 = Y.shape[1]

    # Seed first, then draw X's indices before Y's (same draws as before).
    np.random.seed(ran_seed)
    tmp_X = X[:, choice(n1, k, replace=True)]
    tmp_Y = Y[:, choice(n2, k, replace=True)]

    # Plain per-tract means of the resampled data.
    meanX_logE = np.array([FM_logE(tmp_X[i]) for i in range(p)])
    meanY_logE = np.array([FM_logE(tmp_Y[i]) for i in range(p)])

    # Shrinkage using the variance of the mean (var / k).
    _, _, meanX_SURE = SURE_const(meanX_logE, var_X / k)
    _, _, meanY_SURE = SURE_const(meanY_logE, var_Y / k)

    # Scatter matrices ((k-1) * cov_logE) as input to the full SURE estimator.
    tmpX_S_logE = (k - 1) * np.array([cov_logE(tmp_X[i]) for i in range(p)])
    tmpY_S_logE = (k - 1) * np.array([cov_logE(tmp_Y[i]) for i in range(p)])

    # SURE_full returns six values; only the mean estimate (5th) is scored here.
    _, _, _, _, MX_SURE_full, _ = SURE_full(meanX_logE, tmpX_S_logE, k)
    _, _, _, _, MY_SURE_full, _ = SURE_full(meanY_logE, tmpY_S_logE, k)

    l_logE = loss(mu_X, meanX_logE) + loss(mu_Y, meanY_logE)
    l_SURE = loss(mu_X, meanX_SURE) + loss(mu_Y, meanY_SURE)
    l_SURE_full = loss(mu_X, MX_SURE_full) + loss(mu_Y, MY_SURE_full)

    return np.array([l_logE, l_SURE, l_SURE_full])

def loss_Sig(X, Y):
    """Average squared Frobenius distance between two stacks of matrices.

    Squares ``X - Y`` element-wise, sums over every element, and divides by
    ``X.shape[0]``, the number of stacked matrices.
    """
    n_stack = X.shape[0]
    squared_error = (X - Y) ** 2
    return np.sum(squared_error) / n_stack
    

def test_Sig(k, X, Y, ran_seed, S_X=None, S_Y=None):
    """Score two covariance estimators on one bootstrap replicate of two groups.

    Resamples ``k`` subjects with replacement from each group and compares two
    per-tract covariance estimates against the references using ``loss_Sig``:

    - "MLE": the resample's ``cov_logE``;
    - the covariance output of ``SURE_full``.

    Parameters
    ----------
    k : int
        Bootstrap sample size drawn from each group.
    X, Y : ndarray
        Group data with tracts on axis 0 and subjects on axis 1.
    ran_seed : int
        Seed for NumPy's global RNG (``np.random.seed``). The draws depend only
        on this seed.
    S_X, S_Y : ndarray, optional
        Reference per-tract covariances. If omitted, the module-level ``S_X`` /
        ``S_Y`` are used, which was the only behaviour before.

    Returns
    -------
    ndarray, shape (2,)
        ``[loss_MLE, loss_SURE_full]``, each summed over X and Y.
    """
    # Explicit arguments win. Otherwise read the notebook globals, so existing
    # calls of the form test_Sig(k, X, Y, seed) keep working.
    namespace = globals()
    if S_X is None:
        S_X = namespace['S_X']
    if S_Y is None:
        S_Y = namespace['S_Y']

    p, n1 = X.shape[0:2]
    n2 = Y.shape[1]

    # Seed first, then draw X's indices before Y's (same draws as before).
    np.random.seed(ran_seed)
    tmp_X = X[:, choice(n1, k, replace=True)]
    tmp_Y = Y[:, choice(n2, k, replace=True)]

    meanX_logE = np.array([FM_logE(tmp_X[i]) for i in range(p)])
    meanY_logE = np.array([FM_logE(tmp_Y[i]) for i in range(p)])

    # Scatter matrices ((k-1) * cov_logE) as input to the full SURE estimator.
    tmpX_S_logE = (k - 1) * np.array([cov_logE(tmp_X[i]) for i in range(p)])
    tmpY_S_logE = (k - 1) * np.array([cov_logE(tmp_Y[i]) for i in range(p)])

    # SURE_full returns six values; only the covariance estimate (6th) is scored.
    _, _, _, _, _, SigX_SURE_full = SURE_full(meanX_logE, tmpX_S_logE, k)
    _, _, _, _, _, SigY_SURE_full = SURE_full(meanY_logE, tmpY_S_logE, k)

    # Baseline: undo the (k-1) scaling to get back the resample's cov_logE.
    l_MLE = loss_Sig(S_X, tmpX_S_logE / (k - 1)) + loss_Sig(S_Y, tmpY_S_logE / (k - 1))
    l_SURE_full = loss_Sig(S_X, SigX_SURE_full) + loss_Sig(S_Y, SigY_SURE_full)

    return np.array([l_MLE, l_SURE_full])

In [2]:
# Parkinson's tracts
# 141 patients, 4 classes, 33 SPD(3) per patient
# Steps in this cell:
#   1. Load the tract tensors and project every matrix onto the SPD cone.
#   2. Split out the two groups being compared (label 0 vs label 1).
#   3. Compute the full-group reference quantities that test_M reads as
#      globals (mu_X, mu_Y, var_X, var_Y).
mat = loadmat('tracts.mat')  # relative path: resolved against the kernel's working directory
# NOTE(review): [0][0][0] unwraps the nested object/struct arrays that loadmat
# returns; the exact MATLAB layout of 'tracts3' is not visible here — TODO confirm.
data = mat['tracts3'][0][0][0]
# Swap the first two axes. Afterwards axis 0 is iterated per tract (X[i] below)
# and axis 1 is masked by `label`, i.e. indexes patients.
data = data.transpose((1,0,2,3))
data = remove_non_SPD(data)  # clip eigenvalues at 1e-8 so every matrix is SPD
name = mat['names']  # not used further in this view
label = mat['labels'] # 0:CON; 1:PDL; 2:MSA; 3:PSP
# Flatten to 1-D; assumes loadmat returned one label per row — TODO confirm shape.
label = label.reshape(label.shape[0])
p, n, d = data.shape[0:3]  # p: tracts, n: patients, d: matrix size

# Group X = label 0 (CON), group Y = label 1 (PDL).
X = data[:, label == 0]
Y = data[:, label == 1]
n1 = X.shape[1]
n2 = Y.shape[1]
# Full-group per-tract references, treated as ground truth by test_M.
mu_X = np.array([FM_logE(X[i]) for i in range(p)])
mu_Y = np.array([FM_logE(Y[i]) for i in range(p)])
var_X = np.array([var_logE(X[i]) for i in range(p)])
var_Y = np.array([var_logE(Y[i]) for i in range(p)])


# Experiment configuration.
num_cores = 8  # worker count for the optional joblib Parallel path
m = 100  # bootstrap replicates per sample size k
k_vec = np.array([10, 20, 50, 100])  # bootstrap sample sizes k

In [3]:
# Fréchet-mean (FM) estimators: for each bootstrap size k, average the losses
# returned by test_M over m seeded replicates and report their standard error.
use_parallel = False  # True: spread the m replicates over num_cores joblib workers
ran_seed = 1000

res = pd.DataFrame(np.zeros((3, len(k_vec))),
                   index=['FM_logE', 'SURE', 'SURE.Full'],
                   columns=np.copy(k_vec))
res_se = res.copy()

for k in k_vec:
    seeds = [ran_seed + i for i in range(m)]
    if use_parallel:
        # Same seeds as the serial path, so both draw identical bootstrap samples.
        results = Parallel(n_jobs=num_cores)(
            delayed(test_M)(k, X, Y, s) for s in seeds)
    else:
        results = [test_M(k, X, Y, s) for s in seeds]
    # (m, 3) float array: one row of [FM_logE, SURE, SURE.Full] losses per replicate
    results = np.asarray(results, dtype=float)

    res.loc[:, k] = results.mean(axis=0)
    # Monte-Carlo standard error (np.std default ddof=0, as before)
    res_se.loc[:, k] = results.std(axis=0) / np.sqrt(m)
    print("k =", k, "finished!")

print(res)
print(res_se)

k = 10 finished!
k = 20 finished!
k = 50 finished!
k = 100 finished!
                10        20        50        100
FM_logE    0.774482  0.404533  0.159456  0.079296
SURE       0.774328  0.404487  0.159339  0.079705
SURE.Full  0.388456  0.168624  0.094449  0.057247
                10        20        50        100
FM_logE    0.028893  0.012023  0.004607  0.001904
SURE       0.028877  0.012020  0.004601  0.001859
SURE.Full  0.019272  0.004584  0.001909  0.000981


In [7]:
# Covariance (Sigma) estimators: for each bootstrap size k, average the losses
# returned by test_Sig over m seeded replicates and report their standard error.

# Full-group per-tract references, read by test_Sig as S_X / S_Y.
S_X = np.array([cov_logE(X[i]) for i in range(p)])
S_Y = np.array([cov_logE(Y[i]) for i in range(p)])

use_parallel = False  # True: spread the m replicates over num_cores joblib workers
ran_seed = 1000

res_Sig = pd.DataFrame(np.zeros((2, len(k_vec))),
                       index=['MLE', 'SURE.Full'],
                       columns=np.copy(k_vec))
res_Sig_se = res_Sig.copy()

for k in k_vec:
    seeds = [ran_seed + i for i in range(m)]
    if use_parallel:
        # Argument order must match test_Sig(k, X, Y, ran_seed); the seeds are
        # the same as in the serial path.
        results = Parallel(n_jobs=num_cores)(
            delayed(test_Sig)(k, X, Y, s) for s in seeds)
    else:
        results = [test_Sig(k, X, Y, s) for s in seeds]
    # (m, 2) float array: one row of [MLE, SURE.Full] losses per replicate
    results = np.asarray(results, dtype=float)

    res_Sig.loc[:, k] = results.mean(axis=0)
    # Monte-Carlo standard error (np.std default ddof=0, as before)
    res_Sig_se.loc[:, k] = results.std(axis=0) / np.sqrt(m)
    print("k =", k, "finished!")


print(res_Sig)
print(res_Sig_se)

k = 10 finished!
k = 20 finished!
k = 50 finished!
k = 100 finished!
                  10         20         50         100
MLE        123.693576  66.795144  25.539377  12.912664
SURE.Full  111.128060  63.273982  24.993598  12.798364
                10        20        50        100
MLE        5.709112  2.689674  0.906877  0.409572
SURE.Full  5.007468  2.525299  0.883503  0.404218
