In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, roc_curve, roc_auc_score, precision_score, f1_score, recall_score
import scipy.io as sio
from scipy.stats import chi2_contingency, fisher_exact, linregress
import statsmodels.formula.api as smf
from plot_tools import plot_points, plot_correlation, plot_confusion_matrix
from pycirclize import Circos
import matplotlib.cm as cm
import matplotlib
from scipy.spatial.distance import cdist
from brainstates import AgglomerateTransformer
from statsmodels.stats.multitest import fdrcorrection

In [None]:
# Project root and the clustering configuration whose pipeline outputs are analysed below.
MAIN_DIR = "/home/usuario/disco1/proyectos/2023-Coma3D"
WINDOW_SIZE = 30
STRIDE = 16
N_CLUSTERS = 2

# Pipeline output folders for this (N_CLUSTERS, WINDOW_SIZE, STRIDE) configuration.
INPUT_PATIENTS_DIR = f"{MAIN_DIR}/scripts/brainstates/pipeline/outputs/patients/{N_CLUSTERS}/{WINDOW_SIZE}_{STRIDE}"
INPUT_ALL_DIR = f"{MAIN_DIR}/scripts/brainstates/pipeline/outputs/all/{N_CLUSTERS}/{WINDOW_SIZE}_{STRIDE}"
INPUT_LEAVE_ONE_OUT = f"{MAIN_DIR}/scripts/brainstates/pipeline/outputs/leave_one_out/{N_CLUSTERS}/{WINDOW_SIZE}_{STRIDE}"

# Row labels removed from the outcome table (no imaging data for them) ...
IGNORE_OUTCOME = [6, 16]
# ... and row positions removed from the imaging data (no outcome for them).
IGNORE_DATA = [6, 17]

# Patient outcome table, restricted to patients that also have imaging data.
outcomes = pd.read_csv(f"{MAIN_DIR}/scripts/demographics.csv").drop(index=IGNORE_OUTCOME)
for column in ("etiology", "sex"):
    outcomes[column] = outcomes[column].astype("category")

# Categorical copy of the binary outcome: 0 -> "WORSE", anything else -> "BETTER",
# with "WORSE" as the first (reference) level.
outcomes["binary_outcome_cat"] = (
    pd.Series(
        np.where(outcomes["binary_outcome"] == 0, "WORSE", "BETTER"),
        index=outcomes.index,
    )
    .astype("category")
    .cat.reorder_categories(["WORSE", "BETTER"])
)
outcomes

In [None]:
# Load the pipeline outputs and attach the per-patient state probabilities
# (and their entropy-weighted sum) to the outcome table.
probs = sio.loadmat(f"{INPUT_PATIENTS_DIR}/probs_patients.mat")["probs"]
probs_all_controls = sio.loadmat(f"{INPUT_ALL_DIR}/probs_controls.mat")["probs"]
probs_all_patients = sio.loadmat(f"{INPUT_ALL_DIR}/probs_patients.mat")["probs"]

c_ord_matrix = sio.loadmat(f"{INPUT_PATIENTS_DIR}/centroids.mat")["Cord"]
c_ord_matrix_all = sio.loadmat(f"{INPUT_ALL_DIR}/centroids.mat")["Cord"]
entropy = sio.loadmat(f"{INPUT_PATIENTS_DIR}/entropy.mat")["H"]

# Drop the patients without outcome data so the rows line up with `outcomes`.
probs = np.delete(probs, IGNORE_DATA, axis=0)
outcomes["prob1"] = probs[:, 0]
outcomes["prob2"] = probs[:, 1]

# Entropy-weighted sum of the state probabilities. The entropy row is tiled once per
# remaining patient (previously a hard-coded 23) so the cell follows the data size.
outcomes["we"] = np.sum(probs * np.repeat(entropy, len(probs), axis=0), axis=1)
outcomes

In [None]:
# Print column dtypes and non-null counts of the patient outcome table.
outcomes.info()

In [None]:
# Demographics of patients and controls together, with the state probabilities attached.
demographics = pd.read_csv(f"{MAIN_DIR}/scripts/demographics_all.csv")

# Probabilities are stacked patients first, then controls.
# NOTE(review): this assumes demographics_all.csv lists subjects in that same order - confirm.
probs_stacked = np.vstack([probs_all_patients, probs_all_controls])
for state, column in enumerate(("prob1", "prob2")):
    demographics[column] = probs_stacked[:, state]

for column in ("sex", "group"):
    demographics[column] = demographics[column].astype("category")
demographics

#### Circlize

In [None]:
# Chord diagram of the strongest connections in the first patient centroid,
# with the ROIs grouped into four anatomical sectors.

# ROI display names (kept for reference; the plot itself shows only sector names).
roi_names = np.array(["Frontal_Sup_L","Frontal_Sup_Medial_L_R","Cingulum_Ant_L","Frontal_Sup_Medial_L","Cingulum_Ant_L_2","Frontal_Sup_Medial_L_2","Frontal_Med_Orb_R","Cingulum_Ant_L_R","Frontal_Sup_L_2","Cingulum_Ant_L_3","Frontal_Med_Orb_L","Angular_L","Frontal_Sup_R","PreCuneus_L_R","PreCuneus_L_R_2","Cingulum_Post_L_R","Cingulum_Mid_L_R","Angular_R","Thalamus_L_R","ParaHippocampal_L","ParaHippocampal_R","Calcarine_L","Frontal_Mid_L","Fusiform_L","Occipital_Mid_L","Precuneus_R/Lingual_R","PreCuneus_L_R","PreCuneus_R","PreCuneus_L","PreCuneus_R_2","Frontal_Mid_R","Frontal_Mid_R_2","Fusiform_R","Occipital_Mid_R"])
sector_names = ["Frontal", "Precuneus", "Cingulum", "Other"]
name2color = {"Frontal": "#ffb152", "Precuneus": "#ff3bf2", "Cingulum": "green", "Other": "gray"}
# Row/column indices of the centroid matrix that belong to each sector.
sector_rois = {
    "Frontal": [0, 1, 3, 5, 6, 8, 10, 12, 22, 30, 31],
    "Cingulum": [2, 4, 7, 9, 15, 16],
    "Precuneus": [13, 14, 25, 26, 27, 28, 29],
    "Other": [11, 17, 18, 19, 20, 21, 23, 32, 24, 33]
}
# Sector widths are derived from sector_rois so the two can never drift apart.
sectors = {name: len(sector_rois[name]) for name in sector_names}

circos = Circos(sectors, space=5)
for sector in circos.sectors:
    track = sector.add_track((95, 100))
    track.axis(fc=name2color[sector.name])
    track.text(sector.name, color="white", size=12)
    # One tick at the centre of every ROI slot of the sector.
    track.xticks(np.arange(0, len(sector_rois[sector.name])) + 0.5)

# Plot links: keep only the top decile of positive and the bottom decile of negative values.
matrix = c_ord_matrix[0, :, :]
thr_pos = np.percentile(matrix[matrix > 0], 90)
thr_neg = np.percentile(matrix[matrix < 0], 10)

# The previous colour-mapper set-up (Normalize / ScalarMappable) was never used and has
# been removed; it relied on the `cm` module name, which a later cell rebinds.
# NOTE(review): every ROI pair is visited in both directions and the diagonal is
# included, so each link may be drawn twice and self-links can appear - confirm intent.
for isector in sector_names:
    for jsector in sector_names:
        for k, kroi in enumerate(sector_rois[isector]):
            for l, lroi in enumerate(sector_rois[jsector]):
                value = matrix[kroi, lroi]
                if value > thr_pos or value < thr_neg:
                    color = "red" if value > 0 else "blue"
                    circos.link((isector, k, k+1), (jsector, l, l+1), color=color, alpha=0.2)

fig = circos.plotfig()
# NOTE(review): the matrix comes from the patients-only centroids but the figure is
# written under INPUT_ALL_DIR - confirm the intended output folder.
circos.savefig(f"{INPUT_ALL_DIR}/figures/circlize_state1.png")

#### Centroids patients

In [None]:
# Patients-only centroid 1 as an axis-free image on the fixed [-0.2, 0.8] colour scale.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(c_ord_matrix[0], cmap="turbo", vmin=-0.2, vmax=0.8)
ax.axis("off")
fig.savefig(f"{INPUT_PATIENTS_DIR}/figures/centroid-1.png")

In [None]:
# Patients-only centroid 2 as an axis-free image on the fixed [-0.2, 0.8] colour scale.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(c_ord_matrix[1], cmap="turbo", vmin=-0.2, vmax=0.8)
ax.axis("off")
fig.savefig(f"{INPUT_PATIENTS_DIR}/figures/centroid-2.png")

#### Centroids controls + patients

#### Centroid 1

In [None]:
# Controls + patients centroid 1 as an axis-free image on the fixed [-0.2, 0.8] colour scale.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(c_ord_matrix_all[0], cmap="turbo", vmin=-0.2, vmax=0.8)
ax.axis("off")
fig.savefig(f"{INPUT_ALL_DIR}/figures/centroid-1.png")

#### DMN dorsal, centroid 1

In [None]:
# Dorsal DMN ROIs (atlas label -> name); these are rows/columns 0-20 of the centroid matrix.
# 115 Frontal_Sup_L
# 116 Frontal_Sup_Medial_L_R
# 117 Cingulum_Ant_L
# 118 Frontal_Sup_Medial_L
# 119 Cingulum_Ant_L_2
# 120 Frontal_Sup_Medial_L_2
# 121 Frontal_Med_Orb_R
# 122 Cingulum_Ant_L_R
# 123 Frontal_Sup_L_2
# 124 Cingulum_Ant_L_3
# 125 Frontal_Med_Orb_L

# 126 Angular_L
# 127 Frontal_Sup_R
# 128 PreCuneus_L_R
# 129 PreCuneus_L_R_2
# 130 Cingulum_Post_L_R
# 131 Cingulum_Mid_L_R
# 132 Angular_R
# 133 Thalamus_L_R
# 134 ParaHippocampal_L
# 135 ParaHippocampal_R

# The list above has 21 entries, so the dorsal block is 0:21 (the previous 0:20 slice
# dropped ParaHippocampal_R).
plt.imshow(c_ord_matrix_all[0, 0:21, 0:21], vmin=-0.2, vmax=0.8, cmap="turbo")

#### DMN ventral, centroid 1

In [None]:
# Ventral DMN ROIs (atlas label -> name); these are the last 13 rows/columns of the centroid matrix.
# 216 Calcarine_L
# 217 Frontal_Mid_L
# 218 Fusiform_L
# 219 Occipital_Mid_L
# 220 Precuneus_R/Lingual_R
# 221 PreCuneus_L_R
# 222 PreCuneus_R
# 223 PreCuneus_L
# 224 PreCuneus_R_2
# 225 Frontal_Mid_R
# 226 Frontal_Mid_R_2
# 227 Fusiform_R 
# 228 Occipital_Mid_R

# 13 ventral ROIs out of 34 start at index 21 (the previous 20: slice also pulled in
# the last dorsal ROI).
plt.imshow(c_ord_matrix_all[0, 21:, 21:], vmin=-0.2, vmax=0.8, cmap="turbo")

In [None]:
# Dorsal (rows 0-20) x ventral (columns 21-33) block of centroid 1; the split index is 21
# because the matrix holds 21 dorsal ROIs followed by 13 ventral ROIs.
plt.imshow(c_ord_matrix_all[0, :21, 21:], vmin=-0.2, vmax=0.8, cmap="turbo")

#### Centroid 2

In [None]:
# Controls + patients centroid 2 as an axis-free image on the fixed [-0.2, 0.8] colour scale.
fig, ax = plt.subplots(figsize=(6, 6))
ax.imshow(c_ord_matrix_all[1], cmap="turbo", vmin=-0.2, vmax=0.8)
ax.axis("off")
fig.savefig(f"{INPUT_ALL_DIR}/figures/centroid-2.png")

#### DMN dorsal, centroid 2

In [None]:
# Dorsal DMN block of centroid 2: the first 21 ROIs (indices 0-20) are dorsal, so the
# slice ends at 21 rather than 20.
plt.imshow(c_ord_matrix_all[1, 0:21, 0:21], vmin=-0.2, vmax=0.8, cmap="turbo")

#### DMN ventral, centroid 2

In [None]:
# Ventral DMN block of centroid 2: the 13 ventral ROIs start at index 21 (index 20 is
# still a dorsal ROI).
plt.imshow(c_ord_matrix_all[1, 21:, 21:], vmin=-0.2, vmax=0.8, cmap="turbo")

In [None]:
# Dorsal (rows 0-20) x ventral (columns 21-33) block of centroid 2; the split index is 21
# because the matrix holds 21 dorsal ROIs followed by 13 ventral ROIs.
plt.imshow(c_ord_matrix_all[1, :21, 21:], vmin=-0.2, vmax=0.8, cmap="turbo")

#### DMN subdivision (frontal, cingulum, precuneus, other) on the centroids

In [None]:
# Mean within-sector connectivity of the two patient centroids, side by side.
intra_sectors = ["Frontal", "Cingulum", "Precuneus", "Other"]
bar_positions = np.array([1, 2, 3, 4])

plt.title("intra")
for state, offset in ((0, -0.2), (1, 0.2)):
    sector_values = []
    for name in intra_sectors:
        rois = sector_rois[name]
        block = c_ord_matrix[state][np.ix_(rois, rois)]
        # Upper triangle only, so each ROI pair is counted once.
        # NOTE(review): the diagonal is included, as in the original - confirm it should be.
        sector_values.append(block[np.triu_indices(len(block))])
    plt.bar(
        bar_positions + offset,
        [values.mean() for values in sector_values],
        # SEM over the values that were averaged; the previous version divided by the
        # number of ROIs in the sector instead of the number of values.
        yerr=[values.std() / np.sqrt(values.size) for values in sector_values],
        width=0.4
    )
plt.xticks([1,2,3,4], labels=["frontal", "cingulum", "precuneus", "other"])

In [None]:
# Mean between-sector connectivity of the two patient centroids, side by side.
inter_pairs = [
    ("Frontal", "Cingulum"),
    ("Frontal", "Precuneus"),
    ("Frontal", "Other"),
    ("Cingulum", "Precuneus"),
    ("Cingulum", "Other"),
    ("Precuneus", "Other"),
]
bar_positions = np.array([1, 2, 3, 4, 5, 6])

plt.title("inter")
for state, offset in ((0, -0.2), (1, 0.2)):
    # Rectangular (rows of the first sector) x (columns of the second sector) blocks.
    blocks = [
        c_ord_matrix[state][np.ix_(sector_rois[first], sector_rois[second])]
        for first, second in inter_pairs
    ]
    plt.bar(
        bar_positions + offset,
        [block.mean() for block in blocks],
        # SEM over all entries of the block; the previous version divided by the number
        # of rows (len of a 2-D array) instead of the number of entries.
        yerr=[block.std() / np.sqrt(block.size) for block in blocks],
        width=0.4
    )

plt.xticks([1,2,3,4,5,6], labels=["frontal-cingulum","frontal-precuneus","frontal-other","cingulum-precuneus","cingulum-other","precuneus-other"], rotation=90)

#### DMN subdivision on the subjects

In [None]:
# For every subject and window: assign the windowed FC matrix to its nearest centroid
# (city-block distance) and store the mean intra- and inter-sector connectivity under
# that state. Entries for the state a window was NOT assigned to stay NaN.
dfc = sio.loadmat(f"{INPUT_ALL_DIR}/dfc.mat")["dfc"]
n_windows, n_rois, n_rois, n_subs = dfc.shape

n_states = c_ord_matrix_all.shape[0]
centroids_flat = c_ord_matrix_all.reshape(n_states, -1)

intra_sectors = ["Frontal", "Cingulum", "Precuneus", "Other"]
inter_pairs = [
    ("Frontal", "Cingulum"),
    ("Frontal", "Precuneus"),
    ("Frontal", "Other"),
    ("Cingulum", "Precuneus"),
    ("Cingulum", "Other"),
    ("Precuneus", "Other"),
]

# Axes: (subject, window, state, sector or sector pair).
# NOTE(review): 48 subjects x 30 windows is hard-coded here and in the cells below;
# presumably it matches dfc.shape - confirm before switching to n_subs / n_windows.
conn_subrois = np.full((48, 30, n_states, len(intra_sectors)), np.nan)
conn_subrois_inter = np.full((48, 30, n_states, len(inter_pairs)), np.nan)

for sub in range(n_subs):
    for window in range(n_windows):
        fc = dfc[window, :, :, sub]
        # Index of the closest centroid. The earlier `pred_list[window] = pred` line was
        # removed: `pred_list` was never defined, so the cell failed on a fresh kernel.
        pred = int(cdist(fc.reshape(1, -1), centroids_flat, metric="cityblock").argmin(axis=1)[0])

        intra_means = []
        for name in intra_sectors:
            rois = sector_rois[name]
            block = fc[np.ix_(rois, rois)]
            # Upper triangle (diagonal included) so each ROI pair is counted once.
            intra_means.append(block[np.triu_indices(len(block))].mean())

        inter_means = [
            fc[np.ix_(sector_rois[first], sector_rois[second])].mean()
            for first, second in inter_pairs
        ]

        conn_subrois[sub, window, pred, :] = intra_means
        conn_subrois_inter[sub, window, pred, :] = inter_means

# Some subjects spend every window in a single state, so NaNs remain after this step.
conn_subrois_avg_window = np.nanmean(conn_subrois, axis=1)
conn_subrois_inter_avg_window = np.nanmean(conn_subrois_inter, axis=1)

In [None]:
# Subject-level mean intra-sector connectivity per state (dark = state 1, light = state 2).
fig, ax = plt.subplots(figsize=(10,10))
for state, offset, color in ((0, -0.2, "#333333"), (1, 0.2, "#cccccc")):
    per_subject = conn_subrois_avg_window[:, state, :]
    # Subjects that never visited this state are NaN and are skipped by nanmean/nanstd,
    # so the SEM must use the number of contributing subjects, not a fixed 48.
    n_valid = np.sum(~np.isnan(per_subject), axis=0)
    ax.bar(
        np.array([1,2,3,4]) + offset,
        np.nanmean(per_subject, axis=0),
        yerr=np.nanstd(per_subject, axis=0) / np.sqrt(n_valid),
        width=0.4,
        color=color,
        capsize=6,
        error_kw=dict(lw=5, capsize=7, capthick=4)
    )
ax.set_xticks([1,2,3,4], labels="")
ax.set_ylim([0.2, 0.7])

plt.rcParams.update({'font.size': 50})
ax.spines[['right', 'top']].set_visible(False)
plt.setp(ax.spines.values(), linewidth=8)

plt.tight_layout()
plt.savefig(f"{INPUT_ALL_DIR}/figures/subdivision_intra.png")

In [None]:
# Subject-level mean inter-sector connectivity per state (dark = state 1, light = state 2).
fig, ax = plt.subplots(figsize=(14,10))
for state, offset, color in ((0, -0.2, "#333333"), (1, 0.2, "#cccccc")):
    per_subject = conn_subrois_inter_avg_window[:, state, :]
    # Subjects that never visited this state are NaN and are skipped by nanmean/nanstd,
    # so the SEM must use the number of contributing subjects, not a fixed 48.
    n_valid = np.sum(~np.isnan(per_subject), axis=0)
    ax.bar(
        np.array([1,2,3,4,5,6]) + offset,
        np.nanmean(per_subject, axis=0),
        yerr=np.nanstd(per_subject, axis=0) / np.sqrt(n_valid),
        width=0.4,
        color=color,
        capsize=6,
        error_kw=dict(lw=5, capsize=7, capthick=4)
    )
ax.set_xticks([1,2,3,4,5,6], labels="")
ax.set_ylim([-0.05, 0.42])

plt.rcParams.update({'font.size': 50})
ax.spines[['right', 'top', 'bottom']].set_visible(False)
plt.setp(ax.spines.values(), linewidth=8)

plt.tight_layout()
plt.savefig(f"{INPUT_ALL_DIR}/figures/subdivision_inter.png")

In [None]:
# Test, per sector / sector pair, whether connectivity differs between the two states
# (adjusting for age and sex), then FDR-correct the ten centroid p-values.

# Code stored in the "subroi" column -> name; 0-3 are intra-sector, 4-9 inter-sector.
subroi_names = [
    "frontal", "cingulum", "precuneus", "other",
    "frontal_cingulum", "frontal_precuneus", "frontal_other",
    "cingulum_precuneus", "cingulum_other", "precuneus_other",
]
n_intra = 4

# Long-format table: one row per (subject, centroid, subroi).
records = []
for sub in range(conn_subrois_avg_window.shape[0]):
    for centroid in range(2):
        for subroi in range(len(subroi_names)):
            if subroi < n_intra:
                conn = conn_subrois_avg_window[sub, centroid, subroi]
            else:
                conn = conn_subrois_inter_avg_window[sub, centroid, subroi - n_intra]
            records.append({
                "subject": sub + 1,
                "age": demographics.loc[sub, "age"],
                "sex": demographics.loc[sub, "sex"],
                "centroid": centroid + 1,
                "subroi": subroi,
                "conn": conn,
            })

df = pd.DataFrame(records)
for column in ("centroid", "subroi", "sex"):
    df[column] = df[column].astype("category")

# NOTE(review): every subject contributes one row per centroid, so the OLS residuals
# are presumably not independent - confirm that a plain OLS is intended here.
centroid_pvalues = {}
for subroi, name in enumerate(subroi_names):
    ols_model = smf.ols("conn ~ centroid + age + sex", df[df["subroi"] == subroi]).fit()
    # Second coefficient (right after the intercept), as in the original positional
    # lookup; `.iloc` avoids the deprecated integer indexing of a label-indexed Series.
    centroid_pvalues[name] = ols_model.pvalues.iloc[1]

# `ols_model` now holds the last fit (precuneus-other); print its summary as before.
print(ols_model.summary2(), ols_model.pvalues.iloc[1])

h, pcorregido = fdrcorrection(
    list(centroid_pvalues.values()),
    alpha=0.05, method="indep", is_sorted=False
)

h, pcorregido

#### Colorbar

In [None]:
# Stand-alone vertical colorbar for the centroid images (scale -0.2 to 0.8); the blank
# image only exists to give the colorbar something to describe.
fig, ax = plt.subplots(figsize=(6, 6))
blank_image = ax.imshow(np.zeros((34, 34)), cmap="turbo", vmin=-0.2, vmax=0.8)
fig.colorbar(blank_image, ticks=[-0.2, 0.8])

plt.rcParams.update({'font.size': 30})
fig.tight_layout()
fig.savefig(f"{INPUT_ALL_DIR}/figures/colorbar.png")

In [None]:
# Stand-alone horizontal colorbar (scale 0 to 0.4), placed below a blank image.
fig, ax = plt.subplots(figsize=(6, 6))
blank_image = ax.imshow(np.zeros((34, 34)), cmap="turbo", vmin=0, vmax=0.4)
fig.colorbar(blank_image, location="bottom", ticks=[0, 0.4])

plt.rcParams.update({'font.size': 30})
fig.tight_layout()
fig.savefig(f"{INPUT_ALL_DIR}/figures/colorbar-horizontal.png")

#### Probabilities (controls + patients)

Prob 1

In [None]:
# State-1 probability of controls and patients, plotted with the project helper.
point_style = dict(
    facecolors=["#222222", "#cc0000"],
    alpha=0.25,
    capsize=5,
    elinewidth=5,
    capthick=5,
    dot_size=500,
    figsize=(10, 10),
)
plot_points(probs_all_controls[:, 0], probs_all_patients[:, 0], **point_style)

plt.rcParams.update({'font.size': 50})
plt.tight_layout()
plt.savefig(f"{INPUT_ALL_DIR}/figures/state1.png")

In [None]:
# Group effect on the state-1 probability, adjusted for sex and age; the summary is
# written to disk and also shown as the cell output.
ols_model = smf.ols("prob1 ~ group + sex + age", data=demographics).fit()
model_summary = ols_model.summary2()

with open(f"{INPUT_ALL_DIR}/statistics/prob1_group_age_sex.txt", "w") as f:
    f.write(str(model_summary))

model_summary

Prob 2

In [None]:
# State-2 probability of controls and patients, plotted with the project helper.
point_style = dict(
    facecolors=["#222222", "#cc0000"],
    alpha=0.25,
    capsize=5,
    elinewidth=5,
    capthick=5,
    dot_size=500,
    figsize=(10, 10),
)
plot_points(probs_all_controls[:, 1], probs_all_patients[:, 1], **point_style)

plt.rcParams.update({'font.size': 50})
plt.tight_layout()
plt.savefig(f"{INPUT_ALL_DIR}/figures/state2.png")

In [None]:
# Group effect on the state-2 probability, adjusted for sex and age.
# The response was previously `prob1` (copy-paste from the cell above), so the file
# named prob2_... contained the state-1 model; it now models `prob2`.
ols_model = smf.ols("prob2 ~ group + sex + age", demographics).fit()

with open(f"{INPUT_ALL_DIR}/statistics/prob2_group_age_sex.txt", "w") as f:
    f.write(str(ols_model.summary2()))

ols_model.summary2()

#### Correlations with CRSR (only patients)

In [None]:
# Simple linear regression of the state-1 probability on the CRS-R outcome, drawn with
# the project helper (which also writes the figure to `output_file`).
crsr_scores = outcomes["outcome_crsr"]
state_measure = outcomes["prob1"]
crsr_ticks = [0, 5, 10, 15, 20]

lr = linregress(crsr_scores, state_measure)
plot_correlation(
    crsr_scores,
    state_measure,
    lr,
    output_file=f"{INPUT_PATIENTS_DIR}/figures/prob1_crsr.png",
    xmin=0, xmax=23,
    ymin=0, ymax=1,
    xticks=crsr_ticks,
    xtickslabels=list(crsr_ticks),
    facecolor="#cc0000",
    figsize=(10, 10),
)

plt.rcParams.update({'font.size': 40})

In [None]:
# CRS-R effect on the state-1 probability, adjusted for age, sex and etiology.
# NOTE(review): this patients-only model is saved under INPUT_ALL_DIR, while the
# matching figure goes to INPUT_PATIENTS_DIR - confirm the intended folder.
ols_model = smf.ols("prob1 ~ outcome_crsr + age + sex + etiology", data=outcomes).fit()
model_summary = ols_model.summary2()

with open(f"{INPUT_ALL_DIR}/statistics/prob1_crsr_age_sex_etiology.txt", "w") as f:
    f.write(str(model_summary))

model_summary

In [None]:
# Simple linear regression of the state-2 probability on the CRS-R outcome, drawn with
# the project helper (which also writes the figure to `output_file`).
crsr_scores = outcomes["outcome_crsr"]
state_measure = outcomes["prob2"]
crsr_ticks = [0, 5, 10, 15, 20]

lr = linregress(crsr_scores, state_measure)
plot_correlation(
    crsr_scores,
    state_measure,
    lr,
    output_file=f"{INPUT_PATIENTS_DIR}/figures/prob2_crsr.png",
    xmin=0, xmax=23,
    ymin=0, ymax=1,
    xticks=crsr_ticks,
    xtickslabels=list(crsr_ticks),
    facecolor="#cc0000",
    figsize=(10, 10),
)

plt.rcParams.update({'font.size': 40})

In [None]:
# CRS-R effect on the state-2 probability, adjusted for age, sex and etiology.
ols_model = smf.ols("prob2 ~ outcome_crsr + age + sex + etiology", data=outcomes).fit()
model_summary = ols_model.summary2()

with open(f"{INPUT_ALL_DIR}/statistics/prob2_crsr_age_sex_etiology.txt", "w") as f:
    f.write(str(model_summary))

model_summary

In [None]:
# Simple linear regression of the entropy-weighted measure ("we") on the CRS-R outcome,
# drawn with the project helper (which also writes the figure to `output_file`).
crsr_scores = outcomes["outcome_crsr"]
state_measure = outcomes["we"]
crsr_ticks = [0, 5, 10, 15, 20]

lr = linregress(crsr_scores, state_measure)
plot_correlation(
    crsr_scores,
    state_measure,
    lr,
    output_file=f"{INPUT_PATIENTS_DIR}/figures/we_crsr.png",
    xmin=0, xmax=23,
    ymin=0, ymax=1,
    xticks=crsr_ticks,
    xtickslabels=list(crsr_ticks),
    facecolor="#cc0000",
    figsize=(10, 10),
)

plt.rcParams.update({'font.size': 40})

In [None]:
# CRS-R effect on the entropy-weighted measure, adjusted for age, sex and etiology.
ols_model = smf.ols("we ~ outcome_crsr + age + sex + etiology", data=outcomes).fit()
model_summary = ols_model.summary2()

with open(f"{INPUT_ALL_DIR}/statistics/we_crsr_age_sex_etiology.txt", "w") as f:
    f.write(str(model_summary))

model_summary

#### Binary outcome

In [None]:
# Split the patients by binary outcome: 1 -> improved, 0 -> worse.
outcome_flag = outcomes["binary_outcome"]
improve = outcomes.loc[outcome_flag.eq(1)]
worse = outcomes.loc[outcome_flag.eq(0)]

In [None]:
# State-1 probability of the "worse" and "improved" patient groups.
point_style = dict(
    facecolors=["#222222", "#cc0000"],
    alpha=0.25,
    capsize=5,
    elinewidth=5,
    capthick=5,
    dot_size=500,
    figsize=(10, 10),
)
plot_points(worse["prob1"], improve["prob1"], **point_style)

plt.rcParams.update({'font.size': 40})
plt.tight_layout()
plt.savefig(f"{INPUT_PATIENTS_DIR}/figures/prob1_outcome.png")

In [None]:
# Binary-outcome effect on the state-1 probability, adjusted for age, sex and etiology.
# NOTE(review): saved under INPUT_ALL_DIR, whereas the prob2 / we binary-outcome
# models are saved under INPUT_PATIENTS_DIR - confirm the intended folder.
ols_model = smf.ols("prob1 ~ binary_outcome_cat + age + sex + etiology", data=outcomes).fit()
model_summary = ols_model.summary2()

with open(f"{INPUT_ALL_DIR}/statistics/prob1_binaryoutcome_age_sex_etiology.txt", "w") as f:
    f.write(str(model_summary))

model_summary

In [None]:
# Exploratory OLS of the entropy-weighted measure on binary outcome (adjusted for age,
# sex and etiology); the summary is only displayed, nothing is written to disk.
pruebita_ols = smf.ols("we ~ binary_outcome_cat + age + sex + etiology", outcomes).fit()
pruebita_ols.summary2()

In [None]:
# State-2 probability of the "worse" and "improved" patient groups.
point_style = dict(
    facecolors=["#222222", "#cc0000"],
    alpha=0.25,
    capsize=5,
    elinewidth=5,
    capthick=5,
    dot_size=500,
    figsize=(10, 10),
)
plot_points(worse["prob2"], improve["prob2"], **point_style)

plt.rcParams.update({'font.size': 40})
plt.tight_layout()
plt.savefig(f"{INPUT_PATIENTS_DIR}/figures/prob2_outcome.png")

In [None]:
# Mixed linear model of the state-2 probability on binary outcome (plus age, sex and
# etiology), with a random intercept per value of the "patient" column.
# NOTE(review): the prob1 counterpart uses plain OLS; if "patient" is unique per row
# the random intercept is presumably not identifiable - confirm the model choice.
md = smf.mixedlm(
    "prob2 ~ binary_outcome_cat + age + sex + etiology",
    outcomes,
    groups=outcomes["patient"],
)
mdf = md.fit()
model_summary = mdf.summary()

with open(f"{INPUT_PATIENTS_DIR}/statistics/prob2_binaryoutcome_age_sex_etiology.txt", "w") as f:
    f.write(str(model_summary))

model_summary

In [None]:
# Entropy-weighted measure of the "worse" and "improved" patient groups.
point_style = dict(
    facecolors=["#222222", "#cc0000"],
    alpha=0.25,
    capsize=5,
    elinewidth=5,
    capthick=5,
    dot_size=500,
    figsize=(10, 10),
)
plot_points(worse["we"], improve["we"], **point_style)

plt.rcParams.update({'font.size': 40})
plt.tight_layout()
plt.savefig(f"{INPUT_PATIENTS_DIR}/figures/we_outcome.png")

In [None]:
# Mixed linear model of the entropy-weighted measure on binary outcome (plus age, sex
# and etiology), with a random intercept per value of the "patient" column.
# NOTE(review): if "patient" is unique per row the random intercept is presumably not
# identifiable - confirm the model choice.
md = smf.mixedlm(
    "we ~ binary_outcome_cat + age + sex + etiology",
    outcomes,
    groups=outcomes["patient"],
)
mdf = md.fit()
model_summary = mdf.summary()

with open(f"{INPUT_PATIENTS_DIR}/statistics/we_binaryoutcome_age_sex_etiology.txt", "w") as f:
    f.write(str(model_summary))

model_summary

#### Leave one out

In [None]:
# Leave-one-out prediction of the binary outcome from the entropy-weighted measure.
# For each held-out patient, its "we" value is placed on the line between the training
# means of the worse (-> 0) and better (-> 1) groups; scores above 0.5 predict "better".
n_patients = len(outcomes)
predicted_score_list = np.zeros(n_patients)

for iSub in range(n_patients):
    # One file per fold, read once. Local names are used so the notebook-level
    # `probs`, `entropy` and the `worse` DataFrame are no longer overwritten.
    fold = sio.loadmat(f"{INPUT_LEAVE_ONE_OUT}/{iSub}.mat")
    prob_heldout = fold["prob"]
    probs_train = fold["probs"]
    entropy_fold = fold["entropy"]

    # Outcome table without the held-out patient (dropped by position).
    outcomes_train = outcomes.drop(outcomes.index[iSub])
    outcomes_train["we"] = np.sum(
        probs_train * np.repeat(entropy_fold, len(outcomes_train), axis=0), axis=1
    )

    worse_mean = outcomes_train.loc[outcomes_train["binary_outcome"] == 0, "we"].mean()
    better_mean = outcomes_train.loc[outcomes_train["binary_outcome"] == 1, "we"].mean()

    we_heldout = np.sum(prob_heldout * entropy_fold)
    predicted_score = (we_heldout - worse_mean) / (better_mean - worse_mean)
    predicted_score_list[iSub] = np.clip(predicted_score, 0, 1)  # clamp to [0, 1]

predicted_outcome = (predicted_score_list > 0.5).astype(int)

In [None]:
# Store the leave-one-out predictions and build raw and row-normalised confusion matrices.
outcomes["predicted_outcome"] = predicted_outcome
y_true = outcomes["binary_outcome"].values
y_pred = outcomes["predicted_outcome"].values

# NOTE: `cm` rebinds the name of the `matplotlib.cm` module imported at the top of the
# notebook; any cell that uses `cm` as a module must not be re-run after this one.
cm = confusion_matrix(y_true, y_pred)
cm_norm = confusion_matrix(y_true, y_pred, normalize="true")

In [None]:
# Quick look at the raw-count confusion matrix (this figure is not saved).
disp = ConfusionMatrixDisplay(cm)
disp.plot()

#plt.savefig(f"{INPUT_LEAVE_ONE_OUT}/figures/confusion_matrix.png")

In [None]:
# Fisher's exact test on the confusion-matrix counts; the statistic and p-value are
# shown as the cell output.
stat, p = fisher_exact(cm)
stat, p

In [None]:
# Draw the row-normalised confusion matrix with the project helper, which is given the
# output path for the figure.
plot_confusion_matrix(cm_norm, output_path=f"{INPUT_PATIENTS_DIR}/figures/confusion_matrix.png")

In [None]:
# ROC curve of the continuous leave-one-out score, plus threshold-based metrics of the
# binarised prediction (reported in the next cell).
y_true = outcomes["binary_outcome"].values
y_pred = outcomes["predicted_outcome"].values

fpr, tpr, thresholds = roc_curve(y_true, predicted_score_list)
auc = roc_auc_score(y_true, predicted_score_list)
prec = precision_score(y_true, y_pred)
f1 = f1_score(y_true, y_pred)
recall = recall_score(y_true, y_pred)

fig, ax = plt.subplots(figsize=(10, 10))
ax.plot(fpr, tpr, c="#333333", lw=6)
# Chance level.
ax.plot([0, 1], [0, 1], c="red", linestyle="dashed", lw=6)
ax.set_yticks([0, 0.5, 1])

ax.spines[['right', 'top']].set_visible(False)
plt.setp(ax.spines.values(), linewidth=6)

plt.rcParams.update({'font.size': 40})
fig.tight_layout()
fig.savefig(f"{INPUT_LEAVE_ONE_OUT}/figures/roc.png")

In [None]:
# Report the leave-one-out classification metrics, one per line.
metric_lines = [
    f"AUC = {auc:.3f}",
    f"Precision = {prec}",
    f"F1-score = {f1}",
    f"Recall = {recall}",
]
print("\n".join(metric_lines))