In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy.spatial.distance import cdist

import ot

import geomstats
from geomstats.geometry.spd_matrices import SPDMatrices

from dtmrpy import DT_GMM

from sklearn.svm import SVC
from sklearn.model_selection import cross_val_score, LeaveOneOut

INFO: Using numpy backend


In [2]:
class dtgmm:
    def __init__(self, weights, locations):
        assert weights.shape[0] == locations.shape[0]
        self.weights = weights
        self.locations = locations
     
    def means_(self):
        return self.locations[:,:3]
    
    def covariances_(self):
        C = self.locations[:,3:]
        return np.array([C[:,0],C[:,1]/2,C[:,2]/2,C[:,1]/2,C[:,3],C[:,4]/2,C[:,2]/2,C[:,4]/2,C[:,5]]).T.reshape(-1,3,3)
        
    def plot_gmm(self):
        w = self.weights/max(self.weights)
        x,y,z = self.locations[:,0],self.locations[:,1],self.locations[:,2]
        c = np.array([self.locations[:,3],self.locations[:,6],self.locations[:,8]]).T
        c=c/(np.sum(c,1).reshape(-1,1))
 
        for i in range(w.shape[0]):
            plt.plot(x[i],y[i],z[i],'.',c=c[i],alpha = w[i])
            
        # plt.gca().scatter(x,y,z,s=8,c=c, alpha = w)

def DT_GMM_to_dtgmm(DT_GMM):
    w = DT_GMM.weights_.reshape(-1)
    weights = w/sum(w)
    
    x = DT_GMM.means_
    x = x-np.mean(x,0)
    cov_temp = geomstats.geometry.spd_matrices.SPDMatrices(3).projection(DT_GMM.covariances_)
    C = (cov_temp.reshape(-1,9)[:,np.array([0,1,2,4,5,8])])*(np.array([1,2,2,1,2,1]).reshape(1,-1))   
    locations = np.concatenate([x,C],1)
    
    return dtgmm(weights, locations)

def dtgmm_to_DT_GMM(dtgmm):
    weights_ = dtgmm.weights
    means_ = dtgmm.locations[:,:3]
    dtgmm.covariances_
    return DT_GMM(weights_,means_,covarinaces_)

In [3]:
class free_support_barycenter(dtgmm):
    
    def __init__(self,barycenter=None, N=200, lr=0.0000001):
        self.lr=lr #learning rate for weights update
        
        if barycenter==None:
            init_weights = np.ones(N)/N
            x = np.zeros(3)+np.random.normal(size=(N,3))
            cov_temp = geomstats.geometry.spd_matrices.SPDMatrices(3).random_point(N)
            C = (cov_temp.reshape(-1,9)[:,np.array([0,1,2,4,5,8])])*(np.array([1,2,2,1,2,1]).reshape(1,-1)) 
            init_locations = np.concatenate([x,C],1)
            
            super().__init__(init_weights, init_locations)
            self.N = N
        
        else:
            super().__init__(barycenter.weights, barycenter.locations)
            self.N = barycenter.weights.shape[0]
               
    def get_barycentric_projection_embedding(self, measures_list):
        self.M_list = [np.square(cdist(self.locations,measure.locations)) for measure in measures_list]
        #calculate optimal couplings and optimal dual variables
        result_list = [ot.emd(self.weights,measure.weights,self.M_list[i],log=True) for i, measure in enumerate(measures_list)]
        #store optimal couplings
        self.Pi_list = [result[0] for result in result_list]
        #store optimal dual variables - center_ot_dual can probably just be replaced with result[1]['u']-np.mean(result[1]['u'])
        self.alpha_list = [ot.lp.center_ot_dual(result[1]['u'],result[1]['v'])[0] for result in result_list]
        #calculate and store barycentric projection locations
        self.embedding = np.array([(measures_list[i].locations.T@self.Pi_list[i].T@np.diag(1/self.weights.reshape(-1))).T for i in range(len(measures_list))])

    def weights_update(self):
        #get subgradient
        alpha = np.mean(np.array(self.alpha_list),0)
        #calculate subgradient update
        a_star = self.weights+(self.lr*alpha.reshape(1,-1))
        #project a_star into (interior of) probability simplex
        a_star[a_star<0]=1e-8
        a = a_star/np.sum(a_star)

        return a.reshape(-1)
        
    def free_support_barycenter_update(self, measures_list):
        self.get_barycentric_projection_embedding(measures_list)
        self.locations = np.mean(self.embedding,0)
        self.weights = self.weights_update()
        # print(np.mean(np.square(np.linalg.norm(barycenter.pseudo_log(),axis=1))))
        
    def pseudo_log(self):
        #calculate vector field representations
        return (self.embedding - self.locations).reshape(-1,self.N*9)
    
    def fit(self, measures_list, K=10, plot_steps=False):
        
        for i in range(K):
            self.free_support_barycenter_update(measures_list)
            
            if plot_steps:
                plt.figure().add_subplot(projection='3d')
                barycenter.plot_gmm()
                plt.gca().view_init(35,135)
                plt.show()
        

In [None]:
df = np.load('D:/DTMRI/HCP/hcp_centered_full_MZ_removed.pkl', allow_pickle=True)
keep_columns = list(df.columns[:2])+list(df.columns[4:12])+list(df.columns[16:18])+list(df.columns[24:-1])
df = df[keep_columns].dropna()
df


In [66]:
from sklearn.model_selection import cross_val_predict, LeaveOneOut

# X_list2=[]
for k in [200]:
    X_list=[]

    for i, tract in enumerate(df.columns[2:]):
        print(i, end=' ')

        #get tract data
        ind = df[tract].dropna().index
        measure_list = list(df[tract][ind])
        dtgmm_list = [DT_GMM_to_dtgmm(measure) for measure in measure_list]
        y = np.array((df['label'][ind]=='M').astype(int))

        #calculate barycenter
        #initialize barycenter with k support points
        barycenter = free_support_barycenter(N=k)

        #fit k-support barycenter to data
        barycenter.fit(dtgmm_list)

        #get barycentric_projections
        X_list.append((barycenter.embedding - barycenter.locations)[:,:,3:].reshape(len(y),-1))

    X_list2.append(X_list)
    # Z = np.array([x.T for x in X_list]).reshape(-1,739).T
    # print(' ')
    # print(k, np.mean(cross_val_score(SVC(kernel='rbf', probability=True),Z,y,cv=10)))
    # print(' ')

# test = np.array([cross_val_predict(SVC(kernel='rbf', probability=True),x,y,cv=cv, method='predict_proba')[:,1] for x in X_list])
    

0 

  result_code_string = check_result(result_code)


1 

  result_code_string = check_result(result_code)


2 3 4 5 6 7 8 9 10 11 

In [82]:
np.save("X_list2_5",X_list2[5])

IndexError: list index out of range

In [78]:
np.load('X_list2_0.npy').shape

(12, 739, 6)

In [83]:
for X in X_list2:
    y_hat=[]
    for i, x in enumerate(X):
        print(i, end=' ')
        y_hat.append(cross_val_predict(SVC(kernel='rbf', probability=True),x,y,cv=10, method='predict_proba')[:,1])
        print(np.mean((np.mean(np.array(y_hat),0)>.5).astype(int)==y))
        
    print(np.mean((np.mean(np.array(y_hat),0)>.5).astype(int)==y))
    print(" ")
    

0 0.591339648173207
1 0.6359945872801083
2 0.6684709066305818
3 0.6684709066305818
4 0.6847090663058186
5 0.6901217861975643
6 0.6982408660351827
7 0.6874154262516915
8 0.6901217861975643
9 0.6901217861975643
10 0.6955345060893099
11 0.6928281461434371
0.6928281461434371
 
0 0.6630581867388363
1 0.6955345060893099
2 0.7212449255751014
3 0.7266576454668471
4 0.7469553450608931
5 0.7456021650879567
6 0.7469553450608931
7 0.7469553450608931
8 0.7456021650879567
9 0.7415426251691475
10 0.7523680649526387
11 0.7510148849797023
0.7510148849797023
 
0 0.6806495263870095
1 0.6820027063599459
2 0.7280108254397835
3 0.7428958051420839
4 0.7469553450608931
5 0.7469553450608931
6 0.7510148849797023
7 0.7428958051420839
8 0.7456021650879567
9 0.7428958051420839
10 0.7550744248985115
11 0.7523680649526387
0.7523680649526387
 
0 0.6698240866035182
1 0.6738836265223275
2 0.7253044654939107
3 0.7374830852503383
4 0.7374830852503383
5 0.7320703653585927
6 0.7374830852503383
7 0.7347767253044655
8 0.7374