In [None]:
from brainstates import BrainStates
import numpy as np
from scipy.signal.windows import hamming
import scipy.io as sio
import scipy.stats as stats
import matplotlib.pyplot as plt
from glob import glob
from tqdm import tqdm
import pandas as pd
from plot_points import plot_points
import statsmodels.api as sm
import statsmodels.formula.api as smf
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

In [None]:
# Hyper-parameter grid for the sweep: every (window size, stride) pair.
# Rows of `combinations` are ordered stride-major, i.e. all window sizes for
# the first stride, then all window sizes for the second stride, and so on.
WINDOW_SIZES = [30, 40, 50, 60, 70]
STRIDES = [2, 8, 16, 32]

grid1, grid2 = np.meshgrid(WINDOW_SIZES, STRIDES)  # both of shape (len(STRIDES), len(WINDOW_SIZES))
combinations = np.column_stack((grid1.ravel(), grid2.ravel()))  # shape (20, 2): [window_size, stride]

In [None]:
# Sweep over every (window size, stride) combination. For each one:
#   1. fit BrainStates on the patient group and export its variables,
#   2. refit it leaving one patient out at a time and score the held-out patient,
#   3. predict each held-out patient's binary outcome from the weighted entropy
#      ("we" = sum over states of state probability * state entropy),
#   4. store the row-normalised confusion matrix in `cm_list[counter]`.

# --- Configuration (identical for every combination) -------------------------
MAIN_DIR = "/home/usuario/disco1/proyectos/2023-Coma3D"
N_CLUSTERS = 3
REPLICATES = 1000
GROUP = "patients"
N_PATIENTS = 25  # the first 25 (sorted) files are used as patients, the rest as controls
# NOTE(review): the two ignore lists differ (16 vs 17); presumably the CSV rows
# and the sorted data files are not ordered identically -- confirm.
IGNORE_OUTCOME = [6, 16]  # demographics rows to drop: we don't have the data on these ones
IGNORE_DATA = [6, 17]  # patient data indices to drop: we don't have the outcomes on these ones

# --- Load the ROI time series once (does not depend on window/stride) --------
subs = sorted(glob(f"{MAIN_DIR}/Preprocess/parcellation/DMN_all/*.mat"))
n_time = 500
n_rois = 34
data = np.zeros((n_time, n_rois, len(subs)))
for i, sub in enumerate(subs):
    # "func_roi" is transposed so that it fits the (n_time, n_rois) slot.
    data[:, :, i] = sio.loadmat(sub)["func_roi"].T

# Patients that take part in the leave-one-out analysis.
patients_data = np.delete(data[:, :, 0:N_PATIENTS], IGNORE_DATA, axis=2)
n_loo = patients_data.shape[2]

# --- Load and clean the outcome table once ------------------------------------
outcomes_base = pd.read_csv(f"{MAIN_DIR}/scripts/demographics.csv").drop(index=IGNORE_OUTCOME)
outcomes_base["etiology"] = outcomes_base["etiology"].astype("category")
outcomes_base["sex"] = outcomes_base["sex"].astype("category")
outcomes_base["binary_outcome_cat"] = pd.Series(
    np.where(outcomes_base["binary_outcome"] == 0, "WORSE", "BETTER"),
    index=outcomes_base.index,
).astype("category")

# Fail fast: a mismatch would otherwise only surface after the (long) model fits.
assert len(outcomes_base) == n_loo, (
    f"{len(outcomes_base)} outcome rows but {n_loo} patients after exclusions"
)

cm_list = np.zeros((len(combinations), 2, 2))
# tqdm wraps the array (not the enumerate object) so it knows the total and can show an ETA.
for counter, (WINDOW_SIZE, STRIDE) in enumerate(tqdm(combinations, desc="window/stride")):
    INPUT_PATIENTS_DIR = f"{MAIN_DIR}/scripts/brainstates/pipeline/outputs/patients/{N_CLUSTERS}/{WINDOW_SIZE}_{STRIDE}"
    INPUT_ALL_DIR = f"{MAIN_DIR}/scripts/brainstates/pipeline/outputs/all/{N_CLUSTERS}/{WINDOW_SIZE}_{STRIDE}"
    INPUT_LEAVE_ONE_OUT = f"{MAIN_DIR}/scripts/brainstates/pipeline/outputs/leave_one_out/{N_CLUSTERS}/{WINDOW_SIZE}_{STRIDE}"

    # ---- 1. Fit on the whole group and export the variables ------------------
    if GROUP == "patients":
        data_conditions = {"patients": data[:, :, 0:N_PATIENTS]}
        path = INPUT_PATIENTS_DIR
    else:
        data_conditions = {"patients": data[:, :, 0:N_PATIENTS], "controls": data[:, :, N_PATIENTS:]}
        path = INPUT_ALL_DIR

    b = BrainStates(
        from_dict=data_conditions,
        output_path=path,
        export_vars=True,
        verbose=True
    )
    b.run(
        window_size=WINDOW_SIZE,
        stride=STRIDE,
        tapering_function=None,
        subsampling=1,
        n_clusters=N_CLUSTERS,
        n_init=REPLICATES
    )

    probs_patients = b.get_probs()["patients"]
    if GROUP != "patients":
        probs_controls = b.get_probs()["controls"]

    # ---- 2. Leave-one-out refits ---------------------------------------------
    for iSub in tqdm(range(n_loo), desc="leave-one-out", leave=False):
        patients_data_train = np.delete(patients_data, iSub, axis=2)
        data_conditions = {"patients": patients_data_train}
        # Every fold shares the same output_path; only the {iSub}.mat written
        # below is fold-specific.
        b = BrainStates(
            from_dict=data_conditions,
            output_path=INPUT_LEAVE_ONE_OUT,
        )
        b.run(
            window_size=WINDOW_SIZE,
            stride=STRIDE,
            tapering_function=None,
            subsampling=1,
            n_clusters=N_CLUSTERS,
            n_init=REPLICATES
        )
        # NOTE(review): `fit_transform` is called on the held-out subject. If
        # `dfc_all_pipeline` learns anything in `fit`, this refits it on test
        # data and `transform` would be the correct call -- confirm.
        held_out = np.expand_dims(patients_data[:, :, iSub], axis=2)
        prob = b.calculate_prob(b.dfc_all_pipeline.fit_transform(held_out))

        # Assumes the output directory already exists (presumably created by
        # BrainStates -- TODO confirm).
        sio.savemat(
            f"{INPUT_LEAVE_ONE_OUT}/{iSub}.mat",
            {"prob": prob, "probs": b.get_probs()["patients"], "entropy": b.entropy},
        )

    # ---- 3. Outcomes ---------------------------------------------------------
    # Fresh copy per combination so columns from a previous iteration never leak.
    outcomes = outcomes_base.copy()

    # Import relevant variables and match the corresponding outcome data.
    # These files are read from the patients-only output directory (presumably
    # written by the export_vars=True run above -- TODO confirm file names).
    probs = sio.loadmat(f"{INPUT_PATIENTS_DIR}/probs_patients.mat")["probs"]
    entropy = sio.loadmat(f"{INPUT_PATIENTS_DIR}/entropy.mat")["H"]
    probs = np.delete(probs, IGNORE_DATA, axis=0)
    outcomes["prob1"] = probs[:, 0]
    outcomes["prob2"] = probs[:, 1]
    # Weighted entropy per patient: state probabilities times state entropies.
    outcomes["we"] = np.sum(probs * np.repeat(entropy, probs.shape[0], axis=0), axis=1)

    predicted_outcome = np.zeros(n_loo)
    for iSub in range(n_loo):
        fold = sio.loadmat(f"{INPUT_LEAVE_ONE_OUT}/{iSub}.mat")  # load the fold file once
        prob = fold["prob"]  # held-out patient
        probs = fold["probs"]  # remaining (training) patients
        entropy = fold["entropy"]

        # Despite its name, `outcomes_test` holds the training patients of this
        # fold (the held-out row is removed); the name is kept unchanged.
        outcomes_test = outcomes.drop(index=outcomes.index[iSub]).assign(
            we=np.sum(probs * np.repeat(entropy, probs.shape[0], axis=0), axis=1)
        )

        worse = outcomes_test.loc[outcomes_test["binary_outcome"] == 0, "we"].mean()
        better = outcomes_test.loc[outcomes_test["binary_outcome"] == 1, "we"].mean()
        we = np.sum(prob * entropy)
        # Midpoint threshold between the two class means.
        # NOTE(review): this assumes `better` > `worse`; if a fold has the means
        # the other way round the rule is inverted -- confirm this is intended.
        predicted_outcome[iSub] = 1 if we > (worse + better) / 2 else 0

    # ---- 4. Confusion matrix -------------------------------------------------
    outcomes["predicted_outcome"] = predicted_outcome.astype(int)
    # labels=[0, 1] guarantees a 2x2 matrix even if only one class shows up,
    # so the assignment into cm_list cannot fail on shape.
    cm = confusion_matrix(
        outcomes["binary_outcome"].values,
        outcomes["predicted_outcome"].values,
        labels=[0, 1],
        normalize="true",
    )
    cm_list[counter, :, :] = cm

In [None]:
# Per-class recall for every (window size, stride) combination of the sweep.
# `cm_list` holds row-normalised confusion matrices (normalize="true") and
# scikit-learn orders rows/columns by sorted label, therefore:
#   cm[0, 0] = fraction of true outcome-0 ("WORSE") patients predicted as 0
#   cm[1, 1] = fraction of true outcome-1 ("BETTER") patients predicted as 1
# The names `tp` / `tn` are kept as they are because a later cell indexes them.
tp = cm_list[:, 0, 0]
tn = cm_list[:, 1, 1]

fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(tp, marker="o", label='tp = cm[0, 0]: recall for outcome 0 ("WORSE")')
ax.plot(tn, marker="o", label='tn = cm[1, 1]: recall for outcome 1 ("BETTER")')
ax.set_xticks(range(len(combinations)))
ax.set_xticklabels([f"{w}/{s}" for w, s in combinations], rotation=45)
ax.set_xlabel("window size / stride")
ax.set_ylabel("recall (row-normalised confusion matrix)")
ax.set_title("Leave-one-out outcome prediction across the window/stride sweep")
ax.legend();

In [None]:
# Inspect a single grid point of the sweep: its two diagonal confusion-matrix
# entries and the (window size, stride) pair that produced them.
selected = 9  # row index into `combinations` / `cm_list`
(tp[selected], tn[selected], combinations[selected])