In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import ot

from dtmrpy import DT_GMM

from time import time

device='cuda'


In [2]:
from scipy.spatial.distance import cdist

def bures_distance_matrix(Sigma_x,Sigma_y):
    Ux,Sx,Vx = torch.linalg.svd(Sigma_x)
    Sigma_x_sqrt = torch.matmul(torch.matmul(Ux,torch.diag_embed(torch.sqrt(Sx))),Vx)
    
    cross_term = torch.matmul(torch.matmul(Sigma_x_sqrt, Sigma_y.unsqueeze(1)),Sigma_x_sqrt.unsqueeze(0))
    Uc,Sc,Vc = torch.linalg.svd(cross_term)
    sqrt_cross_term = torch.matmul(torch.matmul(Uc,torch.diag_embed(torch.sqrt(Sc))),Vc)

    return torch.einsum('ijkk -> ij', Sigma_x.unsqueeze(0) + Sigma_y.unsqueeze(1) - 2*sqrt_cross_term).T #the transpose here should be fixed


def wasserstein_type_distance(mu0,mu1,reg=1,reg_m=1):
    (weights0,means0,sigma0)=(mu0.weights_,mu0.means_,torch.tensor(mu0.covariances_))
    (weights1,means1,sigma1)=(mu1.weights_,mu1.means_,torch.tensor(mu1.covariances_))

    M = cdist(means0,means1)**2 + bures_distance_matrix(sigma0,sigma1).cpu().numpy()
    
    # return M,ot.unbalanced.sinkhorn_unbalanced(weights0.reshape(-1),weights1.reshape(-1),M,reg=reg,reg_m=reg_m)
    a = weights0.reshape(-1)/np.sum(weights0)
    b = weights1.reshape(-1)/np.sum(weights1)
    return M, ot.emd(a,b,M)

def test_plotter(X,y,model_list,model_name_list):
    
    y_hat_list=[]
    acc_list=[]

    for mdl in model_list:
        y_hat_list.append(cross_val_predict(mdl,X,y,cv=LeaveOneOut()))
        acc_list.append(y_hat_list[-1]==y)

    for i, y_hat in enumerate(y_hat_list):

        ConfusionMatrixDisplay(confusion_matrix(y,y_hat)).plot()
        plt.title(tract+" "+model_name_list[i]+" Accuracy: "+str(np.mean(acc_list[i])))
        plt.show()
        
def get_data(df, tract):
    
    barycenter = np.load("data_files/data_test/"+tract+"_barycenter.npy", allow_pickle=True).item()
    
    measure_list=[]
    label_list=[]

    for i, subject in enumerate(df[tract]):
        if type(subject)==DT_GMM:
            measure_list.append(subject)
            label_list.append(int(df['labels'].values[i]))

    y=np.array(label_list) 

    X=[]

    for subject in measure_list:

        # M,Pi = wasserstein_type_distance(barycenter, subject)
        b_temp = barycenter
        s_temp = subject
        b_temp.means_ = b_temp.means_-np.mean(b_temp.means_,0)
        s_temp.means_ = subject.means_-np.mean(subject.means_,0)
        M,Pi = wasserstein_type_distance(b_temp,s_temp)
        X.append(np.sum(M*Pi,1))
        

    return np.array(X),y


In [3]:
df = pd.read_pickle("data_files/dtmri_dataframe_11_2.pkl")
df = df[df['labels']!='#NULL!']


In [4]:
from sklearn.model_selection import cross_val_predict, LeaveOneOut
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC


def get_result(clf, X, y):
    
    y_hat = cross_val_predict(clf,X,y,cv=LeaveOneOut())
    
    acc = np.mean(y_hat==y)
    conf_mat = confusion_matrix(y,y_hat)
    
    return acc, conf_mat
    
    
def energy_statistic(D,y):
    
    X = D[y==0][:,y==0]
    Y = D[y==1][:,y==1]
    
    n1 = sum(y==0)
    n2 = sum(y==1)
    
    return (2/(n1*n2))*np.sum(D[y==0][:,y==1]) - (1/n1**2)*np.sum(X) - (1/n2**2)*np.sum(Y)


def energy_test(D,y,n_perm=1000):    
    
    n=y.shape[0]
    e_obs = energy_statistic(D,y)
    e_perm = [energy_statistic(D,y[np.random.choice(n,n,replace=False)]) for i in range(n_perm)]
    
    return np.mean(e_obs<e_perm)



def run_tests(X,y, tract_name):
    rf = RandomForestClassifier(class_weight='balanced')
    dt = DecisionTreeClassifier(class_weight='balanced')
    svm = SVC(class_weight='balanced')

    # compute 2 sample energy statistic permutation test 
    D = cdist(X,X)
    energy_stat = energy_test(D,y)

    # build classifiers
    rf_result = get_result(rf, X, y)
    dt_result = get_result(dt, X, y)
    svm_result = get_result(svm, X, y)
    
    return pd.DataFrame({"Tract Name": [tract],
                         "Energy Test p-value": [energy_stat],
                         "Random Forest Accuracy": [rf_result[0]],
                         "Random Forest Confusion Matrix": [rf_result[1]],
                        "Decision Tree Accuracy": [dt_result[0]],
                         "Decision Tree Confusion Matrix": [dt_result[1]],
                        "SVM Accuracy": [svm_result[0]],
                         "SVM Confusion Matrix": [svm_result[1]]})
    
    
    



In [5]:
frames =[]

for tract in df.columns[1:11]:
    
    X,y = get_data(df,tract)
    frames.append(run_tests(X,y,tract))
    
result_df = pd.concat(frames)
result_df

Unnamed: 0,Tract Name,Energy Test p-value,Random Forest Accuracy,Random Forest Confusion Matrix,Decision Tree Accuracy,Decision Tree Confusion Matrix,SVM Accuracy,SVM Confusion Matrix
0,Cingulum_Frontal_Parahippocampal_L,0.449,0.590164,"[[72, 10], [40, 0]]",0.540984,"[[50, 32], [24, 16]]",0.467213,"[[36, 46], [19, 21]]"
0,Cingulum_Frontal_Parahippocampal_R,0.863,0.610687,"[[73, 14], [37, 7]]",0.480916,"[[50, 37], [31, 13]]",0.633588,"[[63, 24], [24, 20]]"
0,Cingulum_Frontal_Parietal_L,0.312,0.661654,"[[81, 8], [37, 7]]",0.62406,"[[64, 25], [25, 19]]",0.556391,"[[55, 34], [25, 19]]"
0,Cingulum_Frontal_Parietal_R,0.98,0.586466,"[[76, 13], [42, 2]]",0.481203,"[[53, 36], [33, 11]]",0.533835,"[[67, 22], [40, 4]]"
0,Cingulum_Parahippocampal_L,0.799,0.639098,"[[84, 5], [43, 1]]",0.556391,"[[58, 31], [28, 16]]",0.43609,"[[34, 55], [20, 24]]"
0,Cingulum_Parahippocampal_R,0.277,0.686567,"[[83, 6], [36, 9]]",0.529851,"[[58, 31], [32, 13]]",0.58209,"[[60, 29], [27, 18]]"
0,Cingulum_Parahippocampal_Parietal_L,0.101,0.616541,"[[76, 13], [38, 6]]",0.556391,"[[63, 26], [33, 11]]",0.56391,"[[64, 25], [33, 11]]"
0,Cingulum_Parahippocampal_Parietal_R,0.314,0.643939,"[[76, 13], [34, 9]]",0.530303,"[[58, 31], [31, 12]]",0.522727,"[[43, 46], [17, 26]]"
0,Cingulum_Parolfactory_L,0.985,0.580153,"[[75, 13], [42, 1]]",0.51145,"[[54, 34], [30, 13]]",0.351145,"[[33, 55], [30, 13]]"
0,Cingulum_Parolfactory_R,0.563,0.618321,"[[80, 8], [42, 1]]",0.549618,"[[63, 25], [34, 9]]",0.519084,"[[56, 32], [31, 12]]"
