In [2]:
# env: sex_diff
import pandas as pd
import scipy.io
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pickle
import matplotlib.pyplot as plt
from scipy.interpolate import make_interp_spline
from sklearn.metrics.pairwise import cosine_similarity
from scipy.stats import permutation_test

import os


## Save ensemble CovBat site-prediction results
### Data: HCPA, HCPD; newYA95 (same acquisition as the original A and D data)
### (1) Load the single-flavor results (logistic regression after CovBat)
### (2) Compute the ensemble (majority-vote) CovBat results
### (3) Save them as a pickle file (used by site_acc_allmodels.ipynb)

In [None]:
# Age-bin edges; bin i covers [bins[i], bins[i+1]).
bins = [8, 11, 14, 18, 22, 36, 45, 55, 65, 80, 101]
# Interval-style labels, kept verbatim (note the inconsistent spacing in the later bins).
group_labels = ['[8,11)', '[11,14)', '[14,18)', '[18,22)', '[22,36)', '[36,45)', '[45, 55)', '[55, 65)', '[65, 80)', '[80,101)']
# 'lo-hi' labels built from consecutive bin edges; these strings are used as keys into the
# per-age-bin structs of the .mat result files (e.g. data['8-11']).
group_labels1 = [f'{lo}-{hi}' for lo, hi in zip(bins, bins[1:])]

# Reference result file. NOTE(review): this path has no '/covbat/' segment, unlike load_data;
# it is read here to obtain n_connectivity (second dimension of the first field of the '8-11'
# struct). n_connectivity is not used later in this notebook as shown — confirm before removing.
flavor = 'SCifod2act_fs86_volnormicv'
reference_mat_path = f'/home/out_log/results_HCPdata_Kraken/newYA95/logistic_regression_ensemble/{flavor}[8, 11, 14, 18, 22, 36, 45, 55, 65, 80, 101]stratifiedCV_split0328_test30newYA_train_downsampledtrain30site.mat'
data = scipy.io.loadmat(reference_mat_path)
n_connectivity = data['8-11'][0][0][0].shape[1]

In [None]:
def load_data(flavor, outer_folds=100, bins=[8, 11, 14, 18, 22, 36, 45, 55, 65, 80, 101], mul_etiv=False,
              results_dir='/home/out_log/results_HCPdata_Kraken/newYA95/logistic_regression_ensemble/covbat'):
    """Load the CovBat logistic-regression site-prediction results for one flavor.

    Parameters
    ----------
    flavor : str
        Flavor name used as the file-name prefix, e.g. 'SCifod2act_fs86_volnormicv'.
    outer_folds : int, optional
        Unused; kept so existing calls keep working.
    bins : sequence of int, optional
        Age-bin edges. They are written into the file name in Python-list form,
        e.g. '[8, 11, 14, ...]'. The list default is never mutated.
    mul_etiv : bool, optional
        Unused; kept so existing calls keep working.
    results_dir : str, optional
        Directory holding the ``.mat`` result files (no trailing slash).

    Returns
    -------
    dict
        The mapping returned by ``scipy.io.loadmat`` for the matching file.
    """
    # Previously the bin list was hard-coded in the path, so a non-default `bins` was silently
    # ignored. Render it like str() of a list of Python ints ('[8, 11, 14]'); int() keeps numpy
    # integers from leaking their repr (e.g. 'np.int64(8)') into the file name.
    bins_tag = '[' + ', '.join(str(int(edge)) for edge in bins) + ']'
    mat_path = f'{results_dir}/{flavor}{bins_tag}stratifiedCV_split0328_test30newYA_train_downsampledtrain30site.mat'
    return scipy.io.loadmat(mat_path)

# Load the CovBat results for every flavor, three atlases (shen268, fs86, coco439) per group.
# NOTE(review): the fc_*_FCcov_* variables are loaded from 'FCcorr_*' files — the "cov" in the
# variable names does not match the file names; confirm which one is intended.
sc_shen268_ifod2act, sc_fs86_ifod2act, sc_cocommpsuit439_ifod2act = map(load_data, (
    'SCifod2act_shen268_volnormicv',
    'SCifod2act_fs86_volnormicv',
    'SCifod2act_coco439_volnormicv',
))

sc_shen268_sdstream, sc_fs86_sdstream, sc_cocommpsuit439_sdstream = map(load_data, (
    'SCsdstream_shen268_volnormicv',
    'SCsdstream_fs86_volnormicv',
    'SCsdstream_coco439_volnormicv',
))

fc_shen268_FCcov_hpfgsr, fc_fs86_FCcov_hpfgsr, fc_cocommpsuit439_FCcov_hpfgsr = map(load_data, (
    'FCcorr_shen268_hpfgsr',
    'FCcorr_fs86_hpfgsr',
    'FCcorr_coco439_hpfgsr',
))

fc_shen268_FCcov_hpf, fc_fs86_FCcov_hpf, fc_cocommpsuit439_FCcov_hpf = map(load_data, (
    'FCcorr_shen268_hpf',
    'FCcorr_fs86_hpf',
    'FCcorr_coco439_hpf',
))

fc_shen268_FCpcorr_hpf, fc_fs86_FCpcorr_hpf, fc_cocommpsuit439_FCpcorr_hpf = map(load_data, (
    'FCpcorr_shen268_hpf',
    'FCpcorr_fs86_hpf',
    'FCpcorr_coco439_hpf',
))

In [5]:
# All flavors, grouped by atlas (fs86, shen268, coco439). Within each atlas the three FC
# variants come first, then the two SC variants. flavors_name[i] is the label of flavors[i].
flavors = [
    fc_fs86_FCcov_hpf, fc_fs86_FCcov_hpfgsr, fc_fs86_FCpcorr_hpf, sc_fs86_ifod2act, sc_fs86_sdstream,
    fc_shen268_FCcov_hpf, fc_shen268_FCcov_hpfgsr, fc_shen268_FCpcorr_hpf, sc_shen268_ifod2act, sc_shen268_sdstream,
    fc_cocommpsuit439_FCcov_hpf, fc_cocommpsuit439_FCcov_hpfgsr, fc_cocommpsuit439_FCpcorr_hpf, sc_cocommpsuit439_ifod2act, sc_cocommpsuit439_sdstream,
]
flavors_name = [
    'fc_fs86_FCcov_hpf', 'fc_fs86_FCcov_hpfgsr', 'fc_fs86_FCpcorr_hpf', 'sc_fs86_ifod2act', 'sc_fs86_sdstream',
    'fc_shen268_FCcov_hpf', 'fc_shen268_FCcov_hpfgsr', 'fc_shen268_FCpcorr_hpf', 'sc_shen268_ifod2act', 'sc_shen268_sdstream',
    'fc_cocommpsuit439_FCcov_hpf', 'fc_cocommpsuit439_FCcov_hpfgsr', 'fc_cocommpsuit439_FCpcorr_hpf', 'sc_cocommpsuit439_ifod2act', 'sc_cocommpsuit439_sdstream',
]

# FC-only and SC-only subsets, picked by name prefix so they keep the same relative order as in
# `flavors` and cannot drift out of sync with it.
flavorsFC = [result for result, label in zip(flavors, flavors_name) if label.startswith('fc_')]
flavorsSC = [result for result, label in zip(flavors, flavors_name) if label.startswith('sc_')]

In [6]:
from sklearn.metrics import recall_score

# Number of site classes; site labels are coded 0..N_SITES-1.
N_SITES = 5


def _majority_vote_baccs(flavor_results, age, true, n_classes=N_SITES):
    """Majority-vote the predictions of several flavors; return one macro recall per repetition.

    Parameters
    ----------
    flavor_results : list
        ``scipy.io.loadmat`` outputs, one per flavor. For each, ``res[age][0][0][4]`` is indexed
        as ``[:, 0, :]`` for ground-truth site labels and ``[:, 1, :]`` for predicted labels
        (presumably shape (n_reps, >=2, n_test) — confirm against the MATLAB export).
    age : str
        Age-bin key, e.g. '8-11'.
    true : ndarray
        Ground-truth labels, shape (n_reps, n_test), which every flavor must share.
    n_classes : int
        Number of site classes; labels must be integers in 0..n_classes-1.

    Returns
    -------
    list of float
        Macro-averaged recall over labels 0..n_classes-1 for each repetition. A class absent from
        a repetition's ground truth contributes 0 (``zero_division=0``) to that average.

    Raises
    ------
    ValueError
        If a flavor's ground-truth labels differ from ``true``.
    """
    predictions = []
    for res in flavor_results:
        arr = res[age][0][0][4]
        # Voting across flavors only makes sense if all were scored on the same test subjects in
        # the same order; previously this was silently assumed.
        if not np.array_equal(arr[:, 0, :], true):
            raise ValueError(f"ground-truth labels for age bin {age!r} differ across flavors")
        predictions.append(arr[:, 1, :])
    preds = np.stack(predictions, axis=0)  # (n_flavors, n_reps, n_test)

    # Per (repetition, subject): most frequent label across flavors. Ties resolve to the smallest
    # label because argmax returns the first maximum.
    ensemble_pred = np.apply_along_axis(
        lambda votes: np.bincount(votes, minlength=n_classes).argmax(), axis=0, arr=preds
    )
    labels = list(range(n_classes))
    return [
        recall_score(rep_true, rep_pred, labels=labels, average="macro", zero_division=0)
        for rep_true, rep_pred in zip(true, ensemble_pred)
    ]


ensemble_reps, ensembleFC_reps, ensembleSC_reps = [], [], []
ensemble_mean, ensembleFC_mean, ensembleSC_mean = [], [], []
for age in group_labels1:
    # Ground truth is read from one flavor; _majority_vote_baccs checks that the others agree.
    true = sc_shen268_sdstream[age][0][0][4][:, 0, :]

    # All flavors (fusion).
    baccs = _majority_vote_baccs(flavors, age, true)
    ensemble_reps.append(baccs)
    ensemble_mean.append(np.mean(baccs))

    # FC flavors only.
    baccs_FC = _majority_vote_baccs(flavorsFC, age, true)
    ensembleFC_reps.append(baccs_FC)
    ensembleFC_mean.append(np.mean(baccs_FC))

    # SC flavors only.
    baccs_SC = _majority_vote_baccs(flavorsSC, age, true)
    ensembleSC_reps.append(baccs_SC)
    ensembleSC_mean.append(np.mean(baccs_SC))

In [None]:
# Persist the per-age-bin lists of per-repetition scores (all flavors, FC-only, SC-only)
# for use in site_acc_allmodels.ipynb.
covbat_ensemble_payload = {
    'ensemble_reps': ensemble_reps,
    'ensembleFC_reps': ensembleFC_reps,
    'ensembleSC_reps': ensembleSC_reps,
}
ensemble_pickle_path = '/home/out_log/results_HCPdata_Kraken/newYA95/logistic_regression_ensemble/covbat/covbat_ensemble_reps0912_site.pkl'
with open(ensemble_pickle_path, 'wb') as out_file:
    pickle.dump(covbat_ensemble_payload, out_file)