In [None]:
import pickle
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.pyplot import gcf
import seaborn as sns
import numpy as np
from sklearn.preprocessing import StandardScaler, PowerTransformer, RobustScaler
from scipy.stats import f_oneway, ttest_ind
from sklearn.metrics.pairwise import cosine_similarity

In [None]:
# RELAPSE
def genfi(data, res_dir, args):
    ## Finding Best Initialization
    ofile = open(f'{res_dir}/results.txt', 'w')

    exp_logs, ofile = find_bestrun(res_dir, args, ofile)
    brun = np.nanargmax(exp_logs) + 1
    print('Best run: ', brun, file=ofile)

    ## Setting Up Plot Directories
    plot_path = f'{res_dir}/plots_{brun}'

    if not os.path.exists(plot_path):
        os.makedirs(plot_path)
        os.makedirs(f'{plot_path}/svgs')

    ## Standardizing Data and Calculating Total Variance
    X = data.get('X')
    Y = data.get('Y')

    scaler = StandardScaler()
    robust_scaler = RobustScaler()
    transformer = PowerTransformer(method='yeo-johnson')

    # Modality Slices
    X1_columns = slice(0, 33)
    X2_columns = slice(33, 62)
    X3_columns = slice(62, 68)
    X4_columns = slice(68, 108)
    #X5_columns = slice(81, 121)

    # Calculate Variance for Each Modality
    """var_X1 = np.var(X[:, X1_columns], axis=0).mean()
    var_X2 = np.var(X[:, X2_columns], axis=0).mean()
    var_X3 = np.var(X[:, X3_columns], axis=0).mean()
    var_X4 = np.var(X[:, X4_columns], axis=0).mean()
    var_X5 = np.var(X[:, X5_columns], axis=0).mean()"""

    # Compute the Square Root of the Variance for Scaling
    """scaling_factors = {
        'X1': np.sqrt(var_X1),
        'X2': np.sqrt(var_X2),
        'X3': np.sqrt(var_X3),
        'X4': np.sqrt(var_X4),
        'X5': np.sqrt(var_X5)}"""

    # Apply Scaling by the Square Root of Variance
    """X[:, X1_columns] = X[:, X1_columns] / scaling_factors['X1']
    X[:, X2_columns] = X[:, X2_columns] / scaling_factors['X2']
    X[:, X3_columns] = X[:, X3_columns] / scaling_factors['X3']
    X[:, X4_columns] = X[:, X4_columns] / scaling_factors['X4']
    X[:, X5_columns] = X[:, X5_columns] / scaling_factors['X5']"""

    # Apply Standard Scaler (Feature-wise) and Box-Cox
    X[:, X1_columns] = transformer.fit_transform(scaler.fit_transform(X[:, X1_columns]))
    X[:, X2_columns] = transformer.fit_transform(scaler.fit_transform(X[:, X2_columns]))
    X[:, X3_columns] = transformer.fit_transform(scaler.fit_transform(X[:, X3_columns]))
    X[:, X4_columns] = transformer.fit_transform(scaler.fit_transform(X[:, X4_columns]))
    #X[:, X5_columns] = transformer.fit_transform(scaler.fit_transform(X[:, X5_columns]))

    Tvar = np.trace(np.dot(X.T, X))

    ## Loading Robust Parameters

    rparams_path = f'{res_dir}/[{brun}]Robust_params.dictionary'

    if os.stat(rparams_path).st_size > 5:
        with open(rparams_path, 'rb') as parameters:
            rob_params = pickle.load(parameters)

        ## Calculating Variance Explained
        var_comps = []
        X_inf = rob_params['infX']

        for k in range(len(X_inf)):
            var_Xk = np.trace(np.dot(X_inf[k][0].T, X_inf[k][0])) / Tvar
            var_comps.append(var_Xk)

        ## Creating Scree Plot (visualizes the variance explained by each factor)

        varexp_comps = np.array(var_comps)
        ids_var = np.argsort(-varexp_comps)
        varexp_comps = varexp_comps[ids_var]

        # print(f'ids_var: {ids_var}')
        # print(f'varexp_comps: {varexp_comps}')

        x = np.arange(len(var_comps) + 1)
        cum_var = [0]

        for i in range(1, varexp_comps.size + 1):
            if i == 1:
                cum_var.append(varexp_comps[i - 1] * 100)
            else:
                cum_var.append(varexp_comps[i - 1] * 100 + cum_var[i - 1])

        plt.figure(figsize=(5, 5), dpi=300)
        plt.plot(x, cum_var, 'ko-', linewidth=2)
        plt.xlabel('Factors')
        plt.ylabel('Covariance explained (%)')
        plt.xticks(x, [f'{i}' for i in range(x.size)])
        plt.savefig(f'{plot_path}/Scree_plot.png')
        plt.savefig(f'{plot_path}/svgs/Scree_plot.svg')
        plt.close()

        ## Plotting Weights and Printing Total Explained Variance

        df_var = pd.read_csv(f'./aida_model/var_labels.csv')

        W = rob_params.get('W')[:, ids_var]

        print(f'\nTotal variance explained: {np.around(sum(var_comps) * 100, 2)}\n', file=ofile)

        structural_MRI_labels = list(df_var.iloc[0:33, 1])
        pos_fMRI_labels = list(df_var.iloc[33:62, 1])
        #neg_fMRI_labels = list(df_var.iloc[62:68, 1])
        # pos_fMRI_beta1_labels = list(df_var.iloc[69:138, 0])
        # pos_fMRI_beta3_labels = list(df_var.iloc[138:207, 0])
        # neg_fMRI_beta1_labels = list(df_var.iloc[138:207, 0])
        # neg_fMRI_beta3_labels = list(df_var.iloc[276:345, 0])
        # sad_EEG_labels = list(df_var.iloc[33:95, 0])
        # neutral_EEG_labels = list(df_var.iloc[75:168, 0])
        fert_labels = list(df_var.iloc[62:68, 1])
        # effort_labels = list(df_var.iloc[95:110, 0])
        questionnaire_labels = list(df_var.iloc[68:108, 1])
        # clinical_labels = list(df_var.iloc[83:90, 0])



        if 'sparseGFA' in args.model:

            ## Plotting Structural MRI Weights for Each Component

            fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_smri = W[0:33, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_smri]
                ax = axes[j]
                ax.barh(np.arange(w_smri.size), w_smri, color=colours)
                ax.set_xlabel('Loadings', fontsize=8, labelpad=10)
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-0.50, 0.50])
                ax.tick_params(axis='x', labelsize=8)

                ax.set_xticks(np.arange(-0.50, 1.00, 0.50))
                ax.set_yticks(np.arange(w_smri.size))

                if j == 0:
                    ax.set_ylabel('structural MRI variables')
                    ax.set_yticklabels(structural_MRI_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_sMRI_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_sMRI_loadings.svg')
            plt.close()

            ## Plotting Positive fMRI Weights for Each Component

            fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_pos_fmri = W[33:62, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_pos_fmri]
                ax = axes[j]
                ax.barh(np.arange(w_pos_fmri.size), w_pos_fmri, color=colours)
                ax.set_xlabel('Loadings', fontsize=8, labelpad=10)
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-0.50, 0.50])
                ax.tick_params(axis='x', labelsize=8)

                ax.set_xticks(np.arange(-0.50, 1.00, 0.50))
                ax.set_yticks(np.arange(w_pos_fmri.size))

                if j == 0:
                    ax.set_ylabel('positive fMRI variables')
                    ax.set_yticklabels(pos_fMRI_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_pos_fMRI_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_pos_fMRI_loadings.svg')
            plt.close()

            ## Plotting positive fMRI beta1 Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_pos_fmri_beta1 = W[69:138, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_pos_fmri_beta1]
                ax = axes[j]
                ax.barh(np.arange(w_pos_fmri_beta1.size), w_pos_fmri_beta1, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_pos_fmri_beta1.size))

                if j == 0:
                    ax.set_ylabel('positive fMRI beta1 variables')
                    ax.set_yticklabels(pos_fMRI_beta1_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_pos_fMRI_beta1_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_pos_fMRI_beta1_loadings.svg')
            plt.close()"""

            ## Plotting positive fMRI beta3 Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_pos_fmri_beta3 = W[138:207, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_pos_fmri_beta3]
                ax = axes[j]
                ax.barh(np.arange(w_pos_fmri_beta3.size), w_pos_fmri_beta3, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_pos_fmri_beta3.size))

                if j == 0:
                    ax.set_ylabel('positive fMRI beta3 variables')
                    ax.set_yticklabels(pos_fMRI_beta3_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_pos_fMRI_beta3_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_pos_fMRI_beta3_loadings.svg')
            plt.close()"""

            ## Plotting negative fMRI Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_neg_fmri = W[62:91, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_neg_fmri]
                ax = axes[j]
                ax.barh(np.arange(w_neg_fmri.size), w_neg_fmri, color=colours)
                ax.set_xlabel('Loadings', fontsize=8, labelpad=10)
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-0.50, 0.50])
                ax.tick_params(axis='x', labelsize=8)

                ax.set_xticks(np.arange(-0.50, 1.00, 0.50))
                ax.set_yticks(np.arange(w_neg_fmri.size))

                if j == 0:
                    ax.set_ylabel('negative fMRI variables')
                    ax.set_yticklabels(neg_fMRI_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_neg_fMRI_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_neg_fMRI_loadings.svg')
            plt.close()"""

            ## Plotting negative fMRI beta1 Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_neg_fmri_beta1 = W[138:207, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_neg_fmri_beta1]
                ax = axes[j]
                ax.barh(np.arange(w_neg_fmri_beta1.size), w_neg_fmri_beta1, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_neg_fmri_beta1.size))

                if j == 0:
                    ax.set_ylabel('negative fMRI beta1 variables')
                    ax.set_yticklabels(neg_fMRI_beta1_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_neg_fMRI_beta1_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_neg_fMRI_beta1_loadings.svg')
            plt.close()"""

            ## Plotting negative fMRI beta1 Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_neg_fmri_beta3 = W[276:345, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_neg_fmri_beta3]
                ax = axes[j]
                ax.barh(np.arange(w_neg_fmri_beta3.size), w_neg_fmri_beta3, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_neg_fmri_beta3.size))

                if j == 0:
                    ax.set_ylabel('negative fMRI beta3 variables')
                    ax.set_yticklabels(neg_fMRI_beta3_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_neg_fMRI_beta3_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_neg_fMRI_beta3_loadings.svg')
            plt.close()"""

            ## Plotting sad EEG Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_sad_eeg = W[33:95, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sad_eeg]
                ax = axes[j]
                ax.barh(np.arange(w_sad_eeg.size), w_sad_eeg, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_sad_eeg.size))

                if j == 0:
                    ax.set_ylabel('Sad EEG Variables')
                    ax.set_yticklabels(sad_EEG_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_sad_EEG_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_sad_EEG_loadings.svg')
            plt.close()"""

            ## Plotting neutral EEG Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_neutral_eeg = W[75:168, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_neutral_eeg]
                ax = axes[j]
                ax.barh(np.arange(w_neutral_eeg.size), w_neutral_eeg, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_neutral_eeg.size))

                if j == 0:
                    ax.set_ylabel('Neutral EEG Variables')
                    ax.set_yticklabels(neutral_EEG_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_neutral_EEG_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_neutral_EEG_loadings.svg')
            plt.close()"""

            ## Plotting fert Weights for Each Component

            fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_fert = W[62:68, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_fert]
                ax = axes[j]
                ax.barh(np.arange(w_fert.size), w_fert, color=colours)
                ax.set_xlabel('Loadings', fontsize=8, labelpad=10)
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-0.50, 0.50])
                ax.tick_params(axis='x', labelsize=8)

                ax.set_xticks(np.arange(-0.50, 1.00, 0.50))
                ax.set_yticks(np.arange(w_fert.size))

                if j == 0:
                    ax.set_ylabel('FERT Variables')
                    ax.set_yticklabels(fert_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_fert_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_fert_loadings.svg')
            plt.close()

            ## Plotting effort Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_effort = W[95:110, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_effort]
                ax = axes[j]
                ax.barh(np.arange(w_effort.size), w_effort, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_effort.size))

                if j == 0:
                    ax.set_ylabel('EFFORT Variables')
                    ax.set_yticklabels(effort_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_effort_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_effort_loadings.svg')
            plt.close()"""

            ## Plotting questionnaire Weights for Each Component

            fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_ques = W[68:108, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_ques]
                ax = axes[j]   
                ax.barh(np.arange(w_ques.size), w_ques, color=colours)
                ax.set_xlabel('Loadings', fontsize=8, labelpad=10)
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-0.50, 0.50])
                ax.tick_params(axis='x', labelsize=8)

                ax.set_xticks(np.arange(-0.50, 1.00, 0.50))
                ax.set_yticks(np.arange(w_ques.size))
                

                if j == 0:
                    ax.set_ylabel('questionnaire variables')
                    ax.set_yticklabels(questionnaire_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_questionnaire_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_questionnaire_loadings.svg')
            plt.close()

            ## Plotting clinical Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_cli = W[83:90, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_cli]
                ax = axes[j]
                ax.barh(np.arange(w_cli.size), w_cli, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_cli.size))

                if j == 0:
                    ax.set_ylabel('clinical variables')
                    ax.set_yticklabels(clinical_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_clinical_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_clinical_loadings.svg')
            plt.close()"""

        ## Plotting Top Components

        top = 5
        if len(var_comps) > top:
            pass
        else:
            top = len(var_comps)

        for j in range(top):
            # print(f'Variance explained by cmp {j+1}: {np.around(var_comps[j] * 100, 2)}', file=ofile)
            print(f'Variance explained by cmp {j + 1}: {np.around(varexp_comps[j] * 100, 2)}', file=ofile)

            # structural MRI
            w_smri = W[0:33, j]
            w_sort = w_smri[np.argsort(w_smri)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_smri.size), w_sort, color=colours)
            plt.ylabel('structural MRI variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-0.50, 0.50])

            ax.set_xticks(np.arange(-0.50, 0.75, 0.25))
            plt.yticks(np.arange(w_smri.size).tolist(),
                       [structural_MRI_labels[np.argsort(w_smri)[i]] for i in range(len(structural_MRI_labels))],
                       fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/structuralMRI_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/structuralMRI_loadings{j + 1}.svg')
            plt.close()

            # positive fMRI
            w_pos_fmri = W[33:62, j]
            w_sort = w_pos_fmri[np.argsort(w_pos_fmri)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_pos_fmri.size), w_sort, color=colours)
            plt.ylabel('positive fMRI variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-0.50, 0.50])

            ax.set_xticks(np.arange(-0.50, 0.75, 0.25))
            plt.yticks(np.arange(w_pos_fmri.size).tolist(),
                       [pos_fMRI_labels[np.argsort(w_pos_fmri)[i]] for i in range(len(pos_fMRI_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/positive_fMRI_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/positive_fMRI_loadings{j + 1}.svg')
            plt.close()

            # positive fMRI beta1
            """w_pos_fmri_beta1 = W[69:138, j]
            w_sort = w_pos_fmri_beta1[np.argsort(w_pos_fmri_beta1)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_pos_fmri_beta1.size), w_sort, color=colours)
            plt.ylabel('positive fMRI beta1 variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_pos_fmri_beta1.size).tolist(),
                       [pos_fMRI_beta1_labels[np.argsort(w_pos_fmri_beta1)[i]] for i in
                        range(len(pos_fMRI_beta1_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/positive_fMRI_beta1_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/positive_fMRI_beta1_loadings{j + 1}.svg')
            plt.close()"""

            # positive fMRI beta3
            """w_pos_fmri_beta3 = W[138:207, j]
            w_sort = w_pos_fmri_beta3[np.argsort(w_pos_fmri_beta3)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_pos_fmri_beta3.size), w_sort, color=colours)
            plt.ylabel('positive fMRI beta3 variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_pos_fmri_beta3.size).tolist(), [pos_fMRI_beta3_labels[np.argsort(w_pos_fmri_beta3)[i]] for i in range(len(pos_fMRI_beta3_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/negative_fMRI_beta3_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/negative_fMRI_beta3_loadings{j + 1}.svg')
            plt.close()"""

            # negative fMRI beta1
            """w_neg_fmri_beta1 = W[46:59, j]
            w_sort = w_neg_fmri_beta1[np.argsort(w_neg_fmri_beta1)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_neg_fmri_beta1.size), w_sort, color=colours)
            plt.ylabel('negative fMRI beta1 variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_neg_fmri_beta1.size).tolist(),
                       [neg_fMRI_beta1_labels[np.argsort(w_neg_fmri_beta1)[i]] for i in
                        range(len(neg_fMRI_beta1_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/negative_fMRI_beta1_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/negative_fMRI_beta1_loadings{j + 1}.svg')
            plt.close()"""

            # negative fMRI
            """w_neg_fmri = W[62:91, j]
            w_sort = w_neg_fmri[np.argsort(w_neg_fmri)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_neg_fmri.size), w_sort, color=colours)
            plt.ylabel('negative fMRI variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-0.50, 0.50])

            ax.set_xticks(np.arange(-0.50, 0.75, 0.25))
            plt.yticks(np.arange(w_neg_fmri.size).tolist(),
                       [neg_fMRI_labels[np.argsort(w_neg_fmri)[i]] for i in range(len(neg_fMRI_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/negative_fMRI_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/negative_fMRI_loadings{j + 1}.svg')
            plt.close()"""

            # negative fMRI beta3
            """w_neg_fmri_beta3 = W[276:345, j]
            w_sort = w_neg_fmri_beta3[np.argsort(w_neg_fmri_beta3)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_neg_fmri_beta3.size), w_sort, color=colours)
            plt.ylabel('negative fMRI beta3 variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_neg_fmri_beta3.size).tolist(),[neg_fMRI_beta3_labels[np.argsort(w_neg_fmri_beta3)[i]] for i in range(len(neg_fMRI_beta3_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/negative_fMRI_beta3_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/negative_fMRI_beta3_loadings{j + 1}.svg')
            plt.close()"""

            # sad EEG
            """w_sad_eeg = W[42:104, j]
            w_sort = w_sad_eeg[np.argsort(w_sad_eeg)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_sad_eeg.size), w_sort, color=colours)
            plt.ylabel('Sad EEG Variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_sad_eeg.size).tolist(),
                       [sad_EEG_labels[np.argsort(w_sad_eeg)[i]] for i in range(len(sad_EEG_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/sad_eeg_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/sad_eeg_loadings{j + 1}.svg')
            plt.close()"""

            # neutral EEG
            """w_neutral_eeg = W[75:168, j]
            w_sort = w_neutral_eeg[np.argsort(w_neutral_eeg)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_neutral_eeg.size), w_sort, color=colours)
            plt.ylabel('Neutral EEG Variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_neutral_eeg.size).tolist(),
                       [neutral_EEG_labels[np.argsort(w_neutral_eeg)[i]] for i in range(len(neutral_EEG_labels))],
                       fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/neutral_eeg_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/neutral_eeg_loadings{j + 1}.svg')
            plt.close()"""

            # fert
            w_fert = W[62:68, j]
            w_sort = w_fert[np.argsort(w_fert)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_fert.size), w_sort, color=colours)
            plt.ylabel('FERT Variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-0.50, 0.50])

            ax.set_xticks(np.arange(-0.50, 0.75, 0.25))
            plt.yticks(np.arange(w_fert.size).tolist(),
                       [fert_labels[np.argsort(w_fert)[i]] for i in range(len(fert_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/fert_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/fert_loadings{j + 1}.svg')
            plt.close()

            # effort
            """w_effort = W[95:110, j]
            w_sort = w_effort[np.argsort(w_effort)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_effort.size), w_sort, color=colours)
            plt.ylabel('EFFORT Variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_effort.size).tolist(), [effort_labels[np.argsort(w_effort)[i]] for i in range(len(effort_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/effort_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/effort_loadings{j + 1}.svg')
            plt.close()"""

            # questionnaire
            w_ques = W[68:108, j]
            w_sort = w_ques[np.argsort(w_ques)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_ques.size), w_sort, color=colours)
            plt.ylabel('questionnaire variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-0.50, 0.50])

            ax.set_xticks(np.arange(-0.50, 0.75, 0.25))
            plt.yticks(np.arange(w_ques.size).tolist(),
                       [questionnaire_labels[np.argsort(w_ques)[i]] for i in range(len(questionnaire_labels))],
                       fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/questionnaire_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/questionnaire_loadings{j + 1}.svg')
            plt.close()

            # clinical
            """w_cli = W[83:90, j]
            w_sort = w_cli[np.argsort(w_cli)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_cli.size), w_sort, color=colours)
            plt.ylabel('clinical variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.0, 2.0])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_cli.size).tolist(), [clinical_labels[np.argsort(w_cli)[i]] for i in range(len(clinical_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/clinical_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/clinical_loadings{j + 1}.svg')
            plt.close()"""

        ## Plotting Robust Components

        abs_mean_scores, scores_dist = plot_components(rob_params, top, ids_var, plot_path)

        total_scores = np.sum(abs_mean_scores, axis=0)

        x = np.arange(abs_mean_scores.shape[1])

        colors = ['#fdbb84', '#2b8cbe']

        width = 0.3

        b_diff = -width

        plt.figure(figsize=(7, 5), dpi=300)
        dpi = plt.gcf().get_dpi()
        fontsize = 5 * (dpi / 100)

        for s in range(abs_mean_scores.shape[0]):
            plt.bar(x + b_diff, abs_mean_scores[s, :] / total_scores, width=width, color=colors[s])

            b_diff += width

        plt.xticks(x, [f'{i + 1}' for i in range(x.size)], fontsize=0.8 * fontsize)
        plt.ylim([0, 1])

        plt.yticks(fontsize=0.8 * fontsize)
        plt.xlabel('Factors', fontsize=fontsize)
        plt.ylabel('Factor contributions', fontsize=fontsize)

        plt.legend(['relapse', 'no_relapse'], fontsize=0.85 * fontsize)

        plt.tight_layout()
        plt.savefig(f'{plot_path}/Subtype_scores.png')
        plt.savefig(f'{plot_path}/svgs/Subtype_scores.svg')
        plt.close()

        ## Plotting Subtype Representation for Top Components

        plt.figure(figsize=(4, 5), dpi=300)
        dpi = plt.gcf().get_dpi()
        fontsize = 5 * (dpi / 100)

        x = np.arange(top)
        height = 0.2
        b_diff = -height

        for s in range(abs_mean_scores.shape[0]):
            plt.barh(x + b_diff, width=abs_mean_scores[s, 0:top] / total_scores[0:top], height=height, color=colors[s])

            b_diff += height

        plt.yticks(x, [f'{i + 1}' for i in range(top)], fontsize=0.8 * fontsize)
        plt.xticks(fontsize=0.8 * fontsize)
        plt.xlabel('Factor contributions', fontsize=fontsize)
        plt.ylabel('Factors', fontsize=fontsize)

        plt.legend(['relapse', 'no_relapse'], fontsize=0.85 * fontsize)

        plt.tight_layout()
        plt.savefig(f'{plot_path}/Subtype_scores_top{top}.png')
        plt.savefig(f'{plot_path}/svgs/Subtype_scores_top{top}.svg')
        plt.close()

        ## Creating Boxplot for Subtype Scores

        # Creating Boxplot:
        plt.figure(figsize=(7, 5), dpi=300)
        dpi = plt.gcf().get_dpi()
        fontsize = 5 * (dpi / 100)

        # Extracting Scores:
        rel = [scores_dist['s1'][:, i] for i in range(top)]
        norel = [scores_dist['s2'][:, i] for i in range(top)]

        ticks = [f'Factor {i + 1}' for i in range(top)]

        # Plotting Boxplots:
        rel_plot = plt.boxplot(rel, positions=np.array(np.arange(len(rel))) * 2.0 - 0.6, widths=0.5)
        norel_plot = plt.boxplot(norel, positions=np.array(np.arange(len(norel))) * 2.0, widths=0.5)

        # Setting Boxplot Properties:
        define_box_properties(rel_plot, '#fdbb84', 'rel')
        define_box_properties(norel_plot, '#2b8cbe', 'norel')

        plt.xticks(np.arange(0, len(ticks) * 2, 2), ticks, fontsize=0.8 * fontsize)
        plt.yticks(fontsize=0.8 * fontsize)
        plt.legend(fontsize=0.8 * fontsize)
        plt.xlim(-2, len(ticks) * 2)
        plt.ylabel('Absolute latent scores', fontsize=fontsize)
        plt.savefig(f'{plot_path}/Subtype_scores_boxplot.png', dpi=300)
        plt.savefig(f'{plot_path}/svgs/Subtype_scores_boxplot.svg')
        plt.close()

        ## Computing F Statistic (performing the F-test and t-tests helps in statistically validating the differences between the subtypes across the factors)

        scores = np.abs(rob_params['Z'][:, ids_var])

        N = scores.shape[0]

        df_subjs = pd.read_csv(f'./aida_model/visit11_data_{N}subjs.csv')

        ids = list(df_subjs["relapse"])

        ns = [sum([x == 1 for x in ids]), sum([x == 0 for x in ids])]

        g1 = scores[0:ns[0], :]
        g2 = scores[ns[0]:ns[0] + ns[1], :]

        comps = [f'Factor {i + 1}' for i in range(scores.shape[1])]

        ## F Test

        stats_all = f_oneway(g1, g2)

        df_all = pd.DataFrame(index=comps, columns=['F score', 'p-value'])
        df_all['F score'] = stats_all[0]
        df_all['p-value'] = stats_all[1]
        df_all.to_csv(f'{plot_path}/Ftest.csv')

        ## T Tests:

        stats_g1g2 = ttest_ind(g1, g2)
        df_g1g2 = pd.DataFrame(index=comps, columns=['t', 'p-value'])
        df_g1g2['t'] = stats_g1g2[0]
        df_g1g2['p-value'] = stats_g1g2[1]
        df_g1g2.to_csv(f'{plot_path}/Ttest_RELvsNOREL.csv')


In [None]:
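# DISEASE — NOTE: this cell redefines `genfi`; running it replaces the RELAPSE `genfi` defined in the earlier cell.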
def genfi(data, res_dir, args):

    ## Finding Best Initialization
    ofile = open(f'{res_dir}/results.txt', 'w')

    exp_logs, ofile = find_bestrun(res_dir, args, ofile)
    brun = np.nanargmax(exp_logs) + 1
    print('Best run: ', brun, file=ofile)

    ## Setting Up Plot Directories
    plot_path = f'{res_dir}/plots_{brun}'

    if not os.path.exists(plot_path):
        os.makedirs(plot_path)
        os.makedirs(f'{plot_path}/svgs')

    ## Standardizing Data and Calculating Total Variance
    X = data.get('X')
    Y = data.get('Y')

    scaler = StandardScaler()
    transformer = PowerTransformer(method='yeo-johnson')

    X1_columns = slice(0, 155)
    X2_columns = slice(155, 310)
    X3_columns = slice(310, 350)
    #X4_columns = slice(169, 240)
    #X5_columns = slice(240, 245)
    
    # Standard Scaler (Feature-wise) and Box-Cox if 
    """X[:, X1_columns] = transformer.fit_transform(scaler.fit_transform(X[:, X1_columns]))
    X[:, X2_columns] = transformer.fit_transform(scaler.fit_transform(X[:, X2_columns]))
    X[:, X3_columns] = transformer.fit_transform(scaler.fit_transform(X[:, X3_columns]))"""
    
    # Apply Standard Scaler (Feature-wise) only
    X[:, X1_columns] = scaler.fit_transform(X[:, X1_columns])
    X[:, X2_columns] = scaler.fit_transform(X[:, X2_columns])
    X[:, X3_columns] = scaler.fit_transform(X[:, X3_columns])

    Tvar = np.trace(np.dot(X.T, X))

    ## Loading Robust Parameters

    rparams_path = f'{res_dir}/[{brun}]Robust_params.dictionary'

    if os.stat(rparams_path).st_size > 5:
        with open(rparams_path, 'rb') as parameters:
            rob_params = pickle.load(parameters)

        ## Calculating Variance Explained
        var_comps = []
        X_inf = rob_params['infX']

        for k in range(len(X_inf)):
            var_Xk = np.trace(np.dot(X_inf[k][0].T, X_inf[k][0])) / Tvar
            var_comps.append(var_Xk)

        ## Creating Scree Plot (visualizes the variance explained by each factor)

        varexp_comps = np.array(var_comps)
        ids_var = np.argsort(-varexp_comps)
        varexp_comps = varexp_comps[ids_var]

        # print(f'ids_var: {ids_var}')
        # print(f'varexp_comps: {varexp_comps}')

        x = np.arange(len(var_comps) + 1)
        cum_var = [0]

        for i in range(1, varexp_comps.size + 1):
            if i == 1:
                cum_var.append(varexp_comps[i - 1] * 100)
            else:
                cum_var.append(varexp_comps[i - 1] * 100 + cum_var[i - 1])

        plt.figure(figsize=(5, 5), dpi=300)
        plt.plot(x, cum_var, 'ko-', linewidth=2)
        plt.xlabel('Factors')
        plt.ylabel('Covariance explained (%)')
        plt.xticks(x, [f'{i}' for i in range(x.size)])
        plt.savefig(f'{plot_path}/Scree_plot.png')
        plt.savefig(f'{plot_path}/svgs/Scree_plot.svg')
        plt.close()

        ## Plotting Weights and Printing Total Explained Variance

        df_var = pd.read_csv(f'./aida_model/var_labels.csv')

        W = rob_params.get('W')[:, ids_var]

        print(f'\nTotal variance explained: {np.around(sum(var_comps) * 100, 2)}\n', file=ofile)

        #structural_MRI_labels = list(df_var.iloc[0:69, 0])

        #pos_fMRI_beta1_labels = list(df_var.iloc[0:69, 0])
        #pos_fMRI_beta3_labels = list(df_var.iloc[138:207, 0])

        #neg_fMRI_beta1_labels = list(df_var.iloc[69:138, 0])
        #neg_fMRI_beta3_labels = list(df_var.iloc[276:345, 0])

        sad_EEG_labels = list(df_var.iloc[0:155, 0])
        neutral_EEG_labels = list(df_var.iloc[155:310, 0])
        
        questionnaire_labels = list(df_var.iloc[310:350, 0])
        #clinical_labels = list(df_var.iloc[226:231, 0])

        ## Plotting Structural MRI Weights for Each Component

        if 'sparseGFA' in args.model:

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_smri = W[0:69, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_smri]
                ax = axes[j]
                ax.barh(np.arange(w_smri.size), w_smri, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_smri.size))

                if j == 0:
                    ax.set_ylabel('structural MRI variables')
                    ax.set_yticklabels(structural_MRI_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_sMRI_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_sMRI_loadings.svg')
            plt.close()"""

            ## Plotting positive fMRI beta1 Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_pos_fmri_beta1 = W[0:69, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_pos_fmri_beta1]
                ax = axes[j]
                ax.barh(np.arange(w_pos_fmri_beta1.size), w_pos_fmri_beta1, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_pos_fmri_beta1.size))

                if j == 0:
                    ax.set_ylabel('positive fMRI beta1 variables')
                    ax.set_yticklabels(pos_fMRI_beta1_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_pos_fMRI_beta1_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_pos_fMRI_beta1_loadings.svg')
            plt.close()"""

            ## Plotting positive fMRI beta3 Weights for Each Component
            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_pos_fmri_beta3 = W[138:207, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_pos_fmri_beta3]
                ax = axes[j]
                ax.barh(np.arange(w_pos_fmri_beta3.size), w_pos_fmri_beta3, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_pos_fmri_beta3.size))

                if j == 0:
                    ax.set_ylabel('positive fMRI beta3 variables')
                    ax.set_yticklabels(pos_fMRI_beta3_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_pos_fMRI_beta3_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_pos_fMRI_beta3_loadings.svg')
            plt.close()"""

            ## Plotting negative fMRI beta1 Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_neg_fmri_beta1 = W[69:138, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_neg_fmri_beta1]
                ax = axes[j]
                ax.barh(np.arange(w_neg_fmri_beta1.size), w_neg_fmri_beta1, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_neg_fmri_beta1.size))

                if j == 0:
                    ax.set_ylabel('negative fMRI beta1 variables')
                    ax.set_yticklabels(neg_fMRI_beta1_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_neg_fMRI_beta1_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_neg_fMRI_beta1_loadings.svg')
            plt.close()"""

            ## Plotting negative fMRI beta1 Weights for Each Component

            """fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_neg_fmri_beta3 = W[276:345, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_neg_fmri_beta3]
                ax = axes[j]
                ax.barh(np.arange(w_neg_fmri_beta3.size), w_neg_fmri_beta3, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_neg_fmri_beta3.size))

                if j == 0:
                    ax.set_ylabel('negative fMRI beta3 variables')
                    ax.set_yticklabels(neg_fMRI_beta3_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_neg_fMRI_beta3_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_neg_fMRI_beta3_loadings.svg')
            plt.close()"""

            ## Plotting sad EEG Weights for Each Component

            fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_sad_eeg = W[0:155, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sad_eeg]
                ax = axes[j]
                ax.barh(np.arange(w_sad_eeg.size), w_sad_eeg, color=colours)
                ax.set_xlabel('Loadings', fontsize=8, labelpad=10)
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-0.5, 0.5])
                ax.tick_params(axis='x', labelsize=8)

                ax.set_xticks(np.arange(-0.5, 0.5, 1.0))
                ax.set_yticks(np.arange(w_sad_eeg.size))

                if j == 0:
                    ax.set_ylabel('Sad EEG Variables')
                    ax.set_yticklabels(sad_EEG_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_sad_EEG_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_sad_EEG_loadings.svg')
            plt.close()
            
            
            ## Plotting neutral EEG Weights for Each Component

            fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_neutral_eeg = W[155:310, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_neutral_eeg]
                ax = axes[j]
                ax.barh(np.arange(w_neutral_eeg.size), w_neutral_eeg, color=colours)
                ax.set_xlabel('Loadings', fontsize=8, labelpad=10)
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-0.5, 0.5])
                ax.tick_params(axis='x', labelsize=8)

                ax.set_xticks(np.arange(-0.5, 0.5, 1.0))
                ax.set_yticks(np.arange(w_neutral_eeg.size))

                if j == 0:
                    ax.set_ylabel('Neutral EEG Variables')
                    ax.set_yticklabels(neutral_EEG_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_neutral_EEG_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_neutral_EEG_loadings.svg')
            plt.close()

            ## Plotting questionnaire Weights for Each Component

            fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_ques = W[310:350, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_ques]
                ax = axes[j]
                ax.barh(np.arange(w_ques.size), w_ques, color=colours)
                ax.set_xlabel('Loadings', fontsize=8, labelpad=10)
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-0.5, 0.5])
                ax.tick_params(axis='x', labelsize=8)

                ax.set_xticks(np.arange(-0.5, 0.5, 1.0))
                ax.set_yticks(np.arange(w_ques.size))

                if j == 0:
                    ax.set_ylabel('questionnaire variables')
                    ax.set_yticklabels(questionnaire_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_questionnaire_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_questionnaire_loadings.svg')
            plt.close()

            ## Plotting Clinical Weights for Each Component
            """
            fig, axes = plt.subplots(1, len(var_comps), figsize=(20, 10), dpi=300)

            if not isinstance(axes, np.ndarray):
                axes = [axes]

            for j in range(len(var_comps)):

                w_cli = W[226:231, j]
                colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_cli]
                ax = axes[j]
                ax.barh(np.arange(w_cli.size), w_cli, color=colours)
                ax.set_xlabel('Loadings')
                ax.set_title(f'Factor {j + 1}')
                ax.set_xlim([-2.5, 2.5])

                ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
                ax.set_yticks(np.arange(w_cli.size))

                if j == 0:
                    ax.set_ylabel('clinical variables')
                    ax.set_yticklabels(clinical_labels, fontsize=10)
                else:
                    ax.set_yticklabels([])

                ax.tick_params(axis='x', labelsize=10)
                ax.tick_params(axis='y', labelsize=8, pad=5)

            plt.tight_layout()
            plt.savefig(f'{plot_path}/all_clinical_loadings.png')
            plt.savefig(f'{plot_path}/svgs/all_clinical_loadings.svg')
            plt.close()
            """
        ## Plotting Top Components

        top = 5
        if len(var_comps) > top:
            pass
        else:
            top = len(var_comps)

        for j in range(top):
            # print(f'Variance explained by cmp {j+1}: {np.around(var_comps[j] * 100, 2)}', file=ofile)
            print(f'Variance explained by cmp {j + 1}: {np.around(varexp_comps[j] * 100, 2)}', file=ofile)

            # structural MRI
            """w_smri = W[0:69, j]
            w_sort = w_smri[np.argsort(w_smri)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_smri.size), w_sort, color=colours)
            plt.ylabel('structural MRI variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_smri.size).tolist(),
                       [structural_MRI_labels[np.argsort(w_smri)[i]] for i in range(len(structural_MRI_labels))],
                       fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/structuralMRI_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/structuralMRI_loadings{j + 1}.svg')
            plt.close()"""

            # positive fMRI beta1
            """w_pos_fmri_beta1 = W[0:69, j]
            w_sort = w_pos_fmri_beta1[np.argsort(w_pos_fmri_beta1)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_pos_fmri_beta1.size), w_sort, color=colours)
            plt.ylabel('positive fMRI beta1 variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_pos_fmri_beta1.size).tolist(),
                       [pos_fMRI_beta1_labels[np.argsort(w_pos_fmri_beta1)[i]] for i in
                        range(len(pos_fMRI_beta1_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/positive_fMRI_beta1_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/positive_fMRI_beta1_loadings{j + 1}.svg')
            plt.close()"""

            # positive fMRI beta3
            """w_pos_fmri_beta3 = W[138:207, j]
            w_sort = w_pos_fmri_beta3[np.argsort(w_pos_fmri_beta3)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_pos_fmri_beta3.size), w_sort, color=colours)
            plt.ylabel('positive fMRI beta3 variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_pos_fmri_beta3.size).tolist(), [pos_fMRI_beta3_labels[np.argsort(w_pos_fmri_beta3)[i]] for i in range(len(pos_fMRI_beta3_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/negative_fMRI_beta3_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/negative_fMRI_beta3_loadings{j + 1}.svg')
            plt.close()"""

            # negative fMRI beta1
            """w_neg_fmri_beta1 = W[69:138, j]
            w_sort = w_neg_fmri_beta1[np.argsort(w_neg_fmri_beta1)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_neg_fmri_beta1.size), w_sort, color=colours)
            plt.ylabel('negative fMRI beta1 variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_neg_fmri_beta1.size).tolist(),
                       [neg_fMRI_beta1_labels[np.argsort(w_neg_fmri_beta1)[i]] for i in
                        range(len(neg_fMRI_beta1_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/negative_fMRI_beta1_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/negative_fMRI_beta1_loadings{j + 1}.svg')
            plt.close()"""

            # negative fMRI beta3
            """w_neg_fmri_beta3 = W[276:345, j]
            w_sort = w_neg_fmri_beta3[np.argsort(w_neg_fmri_beta3)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_neg_fmri_beta3.size), w_sort, color=colours)
            plt.ylabel('negative fMRI beta3 variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.5, 2.5])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_neg_fmri_beta3.size).tolist(),[neg_fMRI_beta3_labels[np.argsort(w_neg_fmri_beta3)[i]] for i in range(len(neg_fMRI_beta3_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/negative_fMRI_beta3_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/negative_fMRI_beta3_loadings{j + 1}.svg')
            plt.close()"""

            # sad EEG
            w_sad_eeg = W[0:155, j]
            w_sort = w_sad_eeg[np.argsort(w_sad_eeg)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_sad_eeg.size), w_sort, color=colours)
            plt.ylabel('Sad EEG Variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-0.50, 0.50])
            ax.set_xticks(np.arange(-0.50, 0.75, 0.25))
            plt.yticks(np.arange(w_sad_eeg.size).tolist(),
                       [sad_EEG_labels[np.argsort(w_sad_eeg)[i]] for i in range(len(sad_EEG_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/sad_eeg_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/sad_eeg_loadings{j + 1}.svg')
            plt.close()
            
            
            # neutral EEG
            w_neutral_eeg = W[155:310, j]
            w_sort = w_neutral_eeg[np.argsort(w_neutral_eeg)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_neutral_eeg.size), w_sort, color=colours)
            plt.ylabel('Neutral EEG Variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-0.50, 0.50])
            ax.set_xticks(np.arange(-0.50, 0.75, 0.25))
            plt.yticks(np.arange(w_neutral_eeg.size).tolist(),
                       [neutral_EEG_labels[np.argsort(w_neutral_eeg)[i]] for i in range(len(neutral_EEG_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/neutral_eeg_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/neutral_eeg_loadings{j + 1}.svg')
            plt.close()

            # questionnaire
            w_ques = W[310:350, j]
            w_sort = w_ques[np.argsort(w_ques)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_ques.size), w_sort, color=colours)
            plt.ylabel('questionnaire variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-0.50, 0.50])
            ax.set_xticks(np.arange(-0.50, 0.75, 0.25))
            plt.yticks(np.arange(w_ques.size).tolist(),
                       [questionnaire_labels[np.argsort(w_ques)[i]] for i in range(len(questionnaire_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/questionnaire_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/questionnaire_loadings{j + 1}.svg')
            plt.close()
            
            """
            # clinical
            w_cli = W[226:231, j]
            w_sort = w_cli[np.argsort(w_cli)]
            colours = ['#b15352' if w > 0 else '#5ba3b4' for w in w_sort]
            plt.figure(figsize=(4, 7), dpi=300)
            ax = plt.axes()
            plt.barh(np.arange(w_cli.size), w_sort, color=colours)
            plt.ylabel('clinical variables')
            plt.xlabel('Loadings')
            ax.set_xlim([-2.0, 2.0])
            ax.set_xticks(np.arange(-2.5, 3.0, 0.5))
            plt.yticks(np.arange(w_cli.size).tolist(),
                       [clinical_labels[np.argsort(w_cli)[i]] for i in range(len(clinical_labels))], fontsize=9)
            ax.tick_params(axis='y', labelsize=8, pad=5)
            plt.xticks(fontsize=9)
            plt.tight_layout()
            plt.savefig(f'{plot_path}/clinical_loadings{j + 1}.png')
            plt.savefig(f'{plot_path}/svgs/clinical_loadings{j + 1}.svg')
            plt.close()
            """
            
        ## Plotting Robust Components

        abs_mean_scores, scores_dist = plot_components(rob_params, top, ids_var, plot_path)

        total_scores = np.sum(abs_mean_scores, axis=0)

        x = np.arange(abs_mean_scores.shape[1])

        colors = ['#fdbb84', '#2b8cbe']

        width = 0.3

        b_diff = -width

        plt.figure(figsize=(7, 5), dpi=300)
        dpi = plt.gcf().get_dpi()
        fontsize = 5 * (dpi / 100)

        for s in range(abs_mean_scores.shape[0]):
            plt.bar(x + b_diff, abs_mean_scores[s, :] / total_scores, width=width, color=colors[s])

            b_diff += width

        plt.xticks(x, [f'{i + 1}' for i in range(x.size)], fontsize=0.8 * fontsize)
        plt.ylim([0, 1])

        plt.yticks(fontsize=0.8 * fontsize)
        plt.xlabel('Factors', fontsize=fontsize)
        plt.ylabel('Factor contributions', fontsize=fontsize)

        #plt.legend(['relapse', 'no_relapse'], fontsize=0.85 * fontsize)
        plt.legend(['control', 'patient'], fontsize=0.85 * fontsize)

        plt.tight_layout()
        plt.savefig(f'{plot_path}/Subtype_scores.png')
        plt.savefig(f'{plot_path}/svgs/Subtype_scores.svg')
        plt.close()

        ## Plotting Subtype Representation for Top Components

        plt.figure(figsize=(4, 5), dpi=300)
        dpi = plt.gcf().get_dpi()
        fontsize = 5 * (dpi / 100)

        x = np.arange(top)
        height = 0.2
        b_diff = -height

        for s in range(abs_mean_scores.shape[0]):
            plt.barh(x + b_diff, width=abs_mean_scores[s, 0:top] / total_scores[0:top], height=height, color=colors[s])

            b_diff += height

        plt.yticks(x, [f'{i + 1}' for i in range(top)], fontsize=0.8 * fontsize)
        plt.xticks(fontsize=0.8 * fontsize)
        plt.xlabel('Factor contributions', fontsize=fontsize)
        plt.ylabel('Factors', fontsize=fontsize)

        #plt.legend(['relapse', 'no_relapse'], fontsize=0.85 * fontsize)
        plt.legend(['control', 'patient'], fontsize=0.85 * fontsize)

        plt.tight_layout()
        plt.savefig(f'{plot_path}/Subtype_scores_top{top}.png')
        plt.savefig(f'{plot_path}/svgs/Subtype_scores_top{top}.svg')
        plt.close()

        ## Creating Boxplot for Subtype Scores

        # Creating Boxplot:
        plt.figure(figsize=(7, 5), dpi=300)
        dpi = plt.gcf().get_dpi()
        fontsize = 5 * (dpi / 100)

        # Extracting Scores:
        rel = [scores_dist['s1'][:, i] for i in range(top)]
        norel = [scores_dist['s2'][:, i] for i in range(top)]

        ticks = [f'Factor {i + 1}' for i in range(top)]

        # Plotting Boxplots:
        rel_plot = plt.boxplot(rel, positions=np.array(np.arange(len(rel))) * 2.0 - 0.6, widths=0.5)
        norel_plot = plt.boxplot(norel, positions=np.array(np.arange(len(norel))) * 2.0, widths=0.5)

        # Setting Boxplot Properties:
        define_box_properties(rel_plot, '#fdbb84', 'rel')
        define_box_properties(norel_plot, '#2b8cbe', 'norel')

        plt.xticks(np.arange(0, len(ticks) * 2, 2), ticks, fontsize=0.8 * fontsize)
        plt.yticks(fontsize=0.8 * fontsize)
        plt.legend(fontsize=0.8 * fontsize)
        plt.xlim(-2, len(ticks) * 2)
        plt.ylabel('Absolute latent scores', fontsize=fontsize)
        plt.savefig(f'{plot_path}/Subtype_scores_boxplot.png', dpi=300)
        plt.savefig(f'{plot_path}/svgs/Subtype_scores_boxplot.svg')
        plt.close()

        ## Computing F Statistic (performing the F-test and t-tests helps in statistically validating the differences between the subtypes across the factors)

        scores = np.abs(rob_params['Z'][:, ids_var])

        N = scores.shape[0]

        df_subjs = pd.read_csv(f'./aida_model/visit11_data_{N}subjs.csv')

        #ids = list(df_subjs["relapse"])
        ids = list(df_subjs["isControl"])

        ns = [sum([x == 1 for x in ids]), sum([x == 0 for x in ids])]

        g1 = scores[0:ns[0], :]
        g2 = scores[ns[0]:ns[0] + ns[1], :]

        comps = [f'Factor {i + 1}' for i in range(scores.shape[1])]

        ## F Test

        stats_all = f_oneway(g1, g2)

        df_all = pd.DataFrame(index=comps, columns=['F score', 'p-value'])
        df_all['F score'] = stats_all[0]
        df_all['p-value'] = stats_all[1]
        df_all.to_csv(f'{plot_path}/Ftest.csv')

        ## T Tests:

        stats_g1g2 = ttest_ind(g1, g2)
        df_g1g2 = pd.DataFrame(index=comps, columns=['t', 'p-value'])
        df_g1g2['t'] = stats_g1g2[0]
        df_g1g2['p-value'] = stats_g1g2[1]
        df_g1g2.to_csv(f'{plot_path}/Ttest_RELvsNOREL.csv')
        

In [None]:
# EEG

# Layout of the EEG data set. Each entry is one view (modality):
#   span        -- columns of X (and, identically, rows of W) that belong to the view
#   grid_ylabel -- y label of the all-factors panel figure; grid_stem -- its file name
#   ylabel      -- y label of the per-factor sorted figure; stem -- its file-name prefix
_EEG_VIEWS = (
    {'span': slice(0, 63), 'grid_ylabel': 'alpha variables', 'grid_stem': 'alpha_loadings',
     'ylabel': 'Alpha Variables', 'stem': 'alpha_loadings'},
    {'span': slice(63, 126), 'grid_ylabel': 'beta variables', 'grid_stem': 'beta_loadings',
     'ylabel': 'Beta Variables', 'stem': 'beta_loadings'},
    {'span': slice(126, 189), 'grid_ylabel': 'delta variables', 'grid_stem': 'delta_loadings',
     'ylabel': 'Delta Variables', 'stem': 'delta_loadings'},
    {'span': slice(189, 252), 'grid_ylabel': 'gamma variables', 'grid_stem': 'gamma_loadings',
     'ylabel': 'Gamma Variables', 'stem': 'gamma_loadings'},
    {'span': slice(252, 315), 'grid_ylabel': 'theta variables', 'grid_stem': 'theta_loadings',
     'ylabel': 'Theta Variables', 'stem': 'theta_loadings'},
    {'span': slice(315, 355), 'grid_ylabel': 'questionnaire variables',
     'grid_stem': 'all_questionnaire_loadings',
     'ylabel': 'Questionnaire variables', 'stem': 'questionnaire_loadings'},
)


def genfi(data, res_dir, args):
    """Summarise the best-scoring run of a (sparse) GFA model on the EEG + questionnaire data.

    Steps:
      1. Pick the run with the highest expected log joint density (via ``find_bestrun``)
         and log per-run details plus the winner to ``{res_dir}/results.txt``.
      2. Normalise a copy of ``data['X']`` view by view (standardise, then Yeo-Johnson)
         and use its total sum of squares as the variance denominator.
      3. Load ``[brun]Robust_params.dictionary``, rank factors by the share of that
         variance carried by their ``infX[k][0]`` matrix, and draw a scree plot.
      4. Draw loading bar charts for every view in ``_EEG_VIEWS`` (all factors side by
         side when ``args.model`` contains 'sparseGFA'; sorted per-factor charts for
         the top 5 factors).
      5. Plot group contributions via ``plot_components`` and write per-factor
         F-test / t-test tables comparing the ``isControl`` groups.

    Figures go to ``{res_dir}/plots_{brun}`` (PNG) and its ``svgs`` subfolder.
    Steps 3-5 are skipped when the robust-parameter file is 5 bytes or smaller.

    Args:
        data: mapping whose ``'X'`` entry is a 2-D array with columns laid out as in
            ``_EEG_VIEWS``. It is not modified.
        res_dir: directory holding the per-run result dictionaries.
        args: namespace with at least ``num_runs`` (read by ``find_bestrun``) and ``model``.
    """
    # The context manager guarantees results.txt is flushed and closed even if a
    # later step (plotting, file I/O) raises.
    with open(f'{res_dir}/results.txt', 'w') as ofile:
        ## Finding Best Initialization
        exp_logs, _ = find_bestrun(res_dir, args, ofile)
        brun = np.nanargmax(exp_logs) + 1
        print('Best run: ', brun, file=ofile)

        ## Setting Up Plot Directories
        plot_path = f'{res_dir}/plots_{brun}'
        # exist_ok also creates svgs/ when plots_N already exists without it.
        os.makedirs(f'{plot_path}/svgs', exist_ok=True)

        ## Standardizing Data and Calculating Total Variance
        X = _eeg_normalise_views(data.get('X'))
        # Sum of squares equals trace(X.T @ X) without forming the p x p Gram matrix.
        Tvar = np.sum(np.square(X))

        ## Loading Robust Parameters
        rparams_path = f'{res_dir}/[{brun}]Robust_params.dictionary'
        if os.stat(rparams_path).st_size <= 5:
            return  # treated as an empty result file

        # NOTE(review): pickle.load can execute arbitrary code -- only load files this pipeline wrote.
        with open(rparams_path, 'rb') as parameters:
            rob_params = pickle.load(parameters)

        ## Calculating Variance Explained (share of Tvar per factor, sorted descending)
        X_inf = rob_params['infX']
        var_comps = [np.sum(np.square(X_inf[k][0])) / Tvar for k in range(len(X_inf))]
        n_comps = len(var_comps)

        varexp_comps = np.array(var_comps)
        ids_var = np.argsort(-varexp_comps)
        varexp_comps = varexp_comps[ids_var]

        _eeg_plot_scree(varexp_comps, plot_path)

        ## Plotting Weights and Printing Total Explained Variance
        W = rob_params.get('W')[:, ids_var]  # columns follow the variance ranking
        print(f'\nTotal variance explained: {np.around(sum(var_comps) * 100, 2)}\n', file=ofile)

        df_var = pd.read_csv('./aida_model/var_labels.csv')
        view_labels = [list(df_var.iloc[view['span'], 1]) for view in _EEG_VIEWS]

        if 'sparseGFA' in args.model:
            for view, labels in zip(_EEG_VIEWS, view_labels):
                _eeg_plot_loadings_grid(
                    W, view['span'], labels, view['grid_ylabel'], n_comps,
                    f'{plot_path}/{view["grid_stem"]}.png',
                    f'{plot_path}/svgs/{view["grid_stem"]}.svg')

        ## Plotting Top Components (at most 5)
        top = min(5, n_comps)
        for j in range(top):
            print(f'Variance explained by cmp {j + 1}: {np.around(varexp_comps[j] * 100, 2)}', file=ofile)
            for view, labels in zip(_EEG_VIEWS, view_labels):
                _eeg_plot_sorted_loadings(
                    W[view['span'], j], labels, view['ylabel'],
                    f'{plot_path}/{view["stem"]}{j + 1}.png',
                    f'{plot_path}/svgs/{view["stem"]}{j + 1}.svg')

        ## Plotting Robust Components
        abs_mean_scores, scores_dist = plot_components(rob_params, top, ids_var, plot_path)
        _eeg_plot_subtype_scores(abs_mean_scores, scores_dist, top, plot_path)

        ## F and t tests between the two groups
        _eeg_group_tests(rob_params, ids_var, plot_path)


def _eeg_normalise_views(X):
    """Return a float copy of ``X`` with each view standardised, then Yeo-Johnson transformed.

    Both transforms are fitted separately on each view's columns in ``_EEG_VIEWS``.
    Working on a copy leaves the caller's array untouched, so calling ``genfi``
    repeatedly on the same ``data`` gives the same result.
    """
    X = np.array(X, dtype=float)
    scaler = StandardScaler()
    transformer = PowerTransformer(method='yeo-johnson')
    for view in _EEG_VIEWS:
        cols = view['span']
        # For feature-wise standardisation only, drop the transformer step.
        X[:, cols] = transformer.fit_transform(scaler.fit_transform(X[:, cols]))
    return X


def _eeg_bar_colours(values):
    """Red for positive loadings, blue for zero or negative ones."""
    return ['#b15352' if v > 0 else '#5ba3b4' for v in values]


def _eeg_plot_scree(varexp_comps, plot_path):
    """Plot cumulative variance explained (%) against the number of factors.

    ``varexp_comps`` holds each factor's fraction of the total variance, already
    sorted in decreasing order. The curve starts at 0 for zero factors.
    """
    cum_var = np.concatenate(([0.0], np.cumsum(varexp_comps * 100)))
    x = np.arange(cum_var.size)

    plt.figure(figsize=(5, 5), dpi=300)
    plt.plot(x, cum_var, 'ko-', linewidth=2)
    plt.xlabel('Factors')
    plt.ylabel('Covariance explained (%)')
    plt.xticks(x, [f'{i}' for i in range(x.size)])
    plt.savefig(f'{plot_path}/Scree_plot.png')
    plt.savefig(f'{plot_path}/svgs/Scree_plot.svg')
    plt.close()


def _eeg_plot_loadings_grid(W, span, labels, ylabel, n_comps, png_path, svg_path):
    """Draw one view's loadings for every factor as side-by-side horizontal bar panels.

    Bars keep the original variable order. Only the first panel carries variable
    names, so the panels line up row for row.

    Args:
        W: loading matrix whose columns are the factors to draw.
        span: slice selecting this view's rows of ``W``.
        labels: variable names for those rows.
        ylabel: y-axis label of the first panel.
        n_comps: number of factor columns to draw.
        png_path, svg_path: output files.
    """
    fig, axes = plt.subplots(1, n_comps, figsize=(40, 20), dpi=300)
    if not isinstance(axes, np.ndarray):  # a single subplot comes back as a bare Axes
        axes = [axes]

    for j in range(n_comps):
        w = W[span, j]
        ax = axes[j]
        ax.barh(np.arange(w.size), w, color=_eeg_bar_colours(w))
        ax.set_xlabel('Loadings', fontsize=8, labelpad=10)
        ax.set_title(f'Factor {j + 1}')
        ax.set_xlim([-0.5, 0.5])
        ax.set_xticks([-0.5, 0, 0.5])
        ax.set_yticks(np.arange(w.size))

        if j == 0:
            ax.set_ylabel(ylabel)
            ax.set_yticklabels(labels, fontsize=6)
        else:
            ax.set_yticklabels([])

        ax.tick_params(axis='x', labelsize=10)
        ax.tick_params(axis='y', labelsize=6, pad=5)

    fig.tight_layout()
    fig.savefig(png_path)
    fig.savefig(svg_path)
    plt.close(fig)


def _eeg_plot_sorted_loadings(w, labels, ylabel, png_path, svg_path):
    """Plot one factor's loadings for one view as horizontal bars sorted ascending.

    Args:
        w: 1-D array of loadings for this view and factor.
        labels: variable names aligned with ``w``.
        ylabel: y-axis label.
        png_path, svg_path: output files.

    Raises:
        ValueError: if ``labels`` and ``w`` differ in length (e.g. a mis-sliced view).
    """
    if len(labels) != w.size:
        raise ValueError(f'{len(labels)} labels for {w.size} loadings ({ylabel})')

    order = np.argsort(w)
    w_sort = w[order]

    fig, ax = plt.subplots(figsize=(4, 7), dpi=300)
    ax.barh(np.arange(w.size), w_sort, color=_eeg_bar_colours(w_sort))
    ax.set_ylabel(ylabel)
    ax.set_xlabel('Loadings')
    ax.set_xlim([-0.50, 0.50])
    ax.set_xticks(np.arange(-0.50, 0.75, 0.25))
    ax.set_yticks(np.arange(w.size))
    ax.set_yticklabels([labels[k] for k in order], fontsize=6)
    ax.tick_params(axis='y', labelsize=6, pad=5)
    ax.tick_params(axis='x', labelsize=9)
    fig.tight_layout()
    fig.savefig(png_path)
    fig.savefig(svg_path)
    plt.close(fig)


def _eeg_plot_subtype_scores(abs_mean_scores, scores_dist, top, plot_path):
    """Plot how the two subject groups contribute to the factors.

    Writes three figures (PNG + SVG) to ``plot_path``:
      * ``Subtype_scores``: each group's share of the summed mean absolute score, all factors.
      * ``Subtype_scores_top{top}``: the same shares for the first ``top`` factors, horizontal.
      * ``Subtype_scores_boxplot``: per-group distributions of absolute scores, first ``top`` factors.

    NOTE(review): assumes ``abs_mean_scores`` is (n_groups, n_factors) and ``scores_dist['s1']``
    / ``['s2']`` are (n_subjects_in_group, n_factors), as returned by ``plot_components`` -- confirm.
    """
    colors = ['#fdbb84', '#2b8cbe']
    legend = ['control', 'patient']  # relapse analysis: ['relapse', 'no_relapse']
    n_groups = abs_mean_scores.shape[0]
    total_scores = np.sum(abs_mean_scores, axis=0)

    # All factors, grouped vertical bars offset by one bar width per group.
    plt.figure(figsize=(7, 5), dpi=300)
    fontsize = 5 * (plt.gcf().get_dpi() / 100)
    x = np.arange(abs_mean_scores.shape[1])
    width = 0.3
    for s in range(n_groups):
        plt.bar(x + (s - 1) * width, abs_mean_scores[s, :] / total_scores, width=width, color=colors[s])
    plt.xticks(x, [f'{i + 1}' for i in range(x.size)], fontsize=0.8 * fontsize)
    plt.ylim([0, 1])
    plt.yticks(fontsize=0.8 * fontsize)
    plt.xlabel('Factors', fontsize=fontsize)
    plt.ylabel('Factor contributions', fontsize=fontsize)
    plt.legend(legend, fontsize=0.85 * fontsize)
    plt.tight_layout()
    plt.savefig(f'{plot_path}/Subtype_scores.png')
    plt.savefig(f'{plot_path}/svgs/Subtype_scores.svg')
    plt.close()

    # Top factors only, horizontal bars.
    plt.figure(figsize=(4, 5), dpi=300)
    fontsize = 5 * (plt.gcf().get_dpi() / 100)
    y = np.arange(top)
    height = 0.2
    for s in range(n_groups):
        plt.barh(y + (s - 1) * height, width=abs_mean_scores[s, 0:top] / total_scores[0:top],
                 height=height, color=colors[s])
    plt.yticks(y, [f'{i + 1}' for i in range(top)], fontsize=0.8 * fontsize)
    plt.xticks(fontsize=0.8 * fontsize)
    plt.xlabel('Factor contributions', fontsize=fontsize)
    plt.ylabel('Factors', fontsize=fontsize)
    plt.legend(legend, fontsize=0.85 * fontsize)
    plt.tight_layout()
    plt.savefig(f'{plot_path}/Subtype_scores_top{top}.png')
    plt.savefig(f'{plot_path}/svgs/Subtype_scores_top{top}.svg')
    plt.close()

    # Boxplots of absolute latent scores, one pair per factor.
    plt.figure(figsize=(7, 5), dpi=300)
    fontsize = 5 * (plt.gcf().get_dpi() / 100)
    rel = [scores_dist['s1'][:, i] for i in range(top)]
    norel = [scores_dist['s2'][:, i] for i in range(top)]
    ticks = [f'Factor {i + 1}' for i in range(top)]

    rel_plot = plt.boxplot(rel, positions=np.arange(len(rel)) * 2.0 - 0.6, widths=0.5)
    norel_plot = plt.boxplot(norel, positions=np.arange(len(norel)) * 2.0, widths=0.5)
    define_box_properties(rel_plot, '#fdbb84', 'rel')
    define_box_properties(norel_plot, '#2b8cbe', 'norel')

    plt.xticks(np.arange(0, len(ticks) * 2, 2), ticks, fontsize=0.8 * fontsize)
    plt.yticks(fontsize=0.8 * fontsize)
    plt.legend(fontsize=0.8 * fontsize)
    plt.xlim(-2, len(ticks) * 2)
    plt.ylabel('Absolute latent scores', fontsize=fontsize)
    plt.savefig(f'{plot_path}/Subtype_scores_boxplot.png', dpi=300)
    plt.savefig(f'{plot_path}/svgs/Subtype_scores_boxplot.svg')
    plt.close()


def _eeg_group_tests(rob_params, ids_var, plot_path):
    """Write per-factor one-way ANOVA and two-sample t-test results to CSV.

    Compares absolute latent scores (columns ordered by ``ids_var``) between subjects
    with ``isControl == 1`` (g1) and ``isControl == 0`` (g2). Group membership is
    read row by row from ``./aida_model/visit11_data_{N}subjs.csv``.

    Raises:
        ValueError: if the CSV does not have exactly one row per row of ``Z``.
    """
    scores = np.abs(rob_params['Z'][:, ids_var])
    n_subjects = scores.shape[0]

    df_subjs = pd.read_csv(f'./aida_model/visit11_data_{n_subjects}subjs.csv')
    ids = df_subjs["isControl"].to_numpy()  # relapse analysis: df_subjs["relapse"]
    if ids.shape[0] != n_subjects:
        raise ValueError(f'subject file has {ids.shape[0]} rows but Z has {n_subjects}')

    # Boolean masks select each group wherever its subjects sit in the file, so the
    # CSV need not list all isControl == 1 subjects first.
    g1 = scores[ids == 1, :]
    g2 = scores[ids == 0, :]

    comps = [f'Factor {i + 1}' for i in range(scores.shape[1])]

    ## F Test
    stats_all = f_oneway(g1, g2)
    pd.DataFrame({'F score': stats_all.statistic, 'p-value': stats_all.pvalue},
                 index=comps).to_csv(f'{plot_path}/Ftest.csv')

    ## T Tests
    stats_g1g2 = ttest_ind(g1, g2)
    pd.DataFrame({'t': stats_g1g2.statistic, 'p-value': stats_g1g2.pvalue},
                 index=comps).to_csv(f'{plot_path}/Ttest_RELvsNOREL.csv')

In [None]:
def find_bestrun(res_dir, args, ofile):
    """Collect the expected log joint density of every MCMC run.

    For each run number ``run`` in ``1 .. args.num_runs``, the pickled dictionary
    ``{res_dir}/[run]Model_params.dictionary`` is loaded. Its ``'exp_logdensity'``
    value is stored, and a three-line summary is written to ``ofile``: run number,
    elapsed time and log density. If any of these files is 5 bytes or smaller, a
    message is printed and the process exits with status 1.

    Args:
        res_dir: directory containing the per-run result files.
        args: namespace providing ``num_runs``.
        ofile: writable text handle for the summary.

    Returns:
        ``(exp_logs, ofile)``: ``exp_logs`` is a ``(1, args.num_runs)`` float array,
        NaN-initialised. ``ofile`` is the same handle that was passed in.
    """
    exp_logs = np.full((1, args.num_runs), np.nan)

    for run in range(1, args.num_runs + 1):
        res_path = f'{res_dir}/[{run}]Model_params.dictionary'

        # Guard clause: a (near-)empty result file aborts the whole analysis.
        if os.stat(res_path).st_size <= 5:
            print('The model output file is empty!')
            sys.exit(1)

        # NOTE(review): pickle.load can execute arbitrary code -- only open files this pipeline wrote.
        with open(res_path, 'rb') as handle:
            samples = pickle.load(handle)

        column = run - 1
        exp_logs[0, column] = samples['exp_logdensity']

        # 'time_elapsed' / 60 is reported in hours -- presumably stored in minutes; TODO confirm.
        print('Run: ', run, file=ofile)
        print('MCMC elapsed time: {:.2f} h'.format(samples['time_elapsed'] / 60), file=ofile)
        print('Expected log joint density: {:.2f}\n'.format(exp_logs[0, column]), file=ofile)

    return exp_logs, ofile

In [None]:
# RELAPSE
def plot_components(params, top, ids_var, path, z_csv_dir=None):
    """Plot factor scores and per-factor data-space maps, coloured by relapse status.

    NOTE: this notebook defines ``plot_components`` three times (RELAPSE,
    DISEASE and EEG cells); whichever cell was executed last is the one in use.

    Reads ``./aida_model/visit11_data_{N}subjs.csv`` (subject table; columns
    ``relapse`` and ``ppid`` are used) and ``./aida_model/var_labels.csv``
    (feature table; column ``labels`` is used). Writes ``infZ_ord``,
    ``infZ_ord_top`` and, for each of the first ``top`` factors,
    ``infZ_clustermap_comp{k}`` and ``infX_comp{k}`` as PNG into ``path`` and
    as SVG into ``path/svgs``.

    Parameters
    ----------
    params : dict
        Model output. ``params['Z']`` is a subjects x factors score matrix;
        ``params['infX'][c][0]`` is a subjects x features matrix showing
        component ``c`` in data space.
    top : int
        Number of leading (reordered) factors that get individual plots.
    ids_var : sequence of int
        Column order applied to ``Z``; ``ids_var[k]`` also selects the
        data-space matrix shown for the k-th plotted factor.
    path : str
        Output directory; must already contain an ``svgs`` sub-directory.
    z_csv_dir : str, optional
        Directory for ``df_Z_all.csv`` and ``df_Z_top.csv``. Defaults to
        ``path`` (this used to be a hard-coded absolute path on one machine).

    Returns
    -------
    subtype_scores : numpy.ndarray, shape (2, len(params['infX']))
        Mean absolute score per reordered factor; row 0 for ``relapse == 1``
        subjects, row 1 for ``relapse == 0`` subjects.
    scores_dict : dict
        ``'s1'`` / ``'s2'``: per-subject absolute scores of the
        ``relapse == 1`` / ``relapse == 0`` subjects, in file order.
    """
    data_dir = './aida_model'
    group_col = 'relapse'
    legend_title = 'Relapse Status'
    if z_csv_dir is None:
        z_csv_dir = path

    ## Reading Data
    X = params['infX']
    N = params['Z'].shape[0]

    # Subject table: one row per subject (row of Z).
    df_subjs = pd.read_csv(f'{data_dir}/visit11_data_{N}subjs.csv')
    subtype_labels = df_subjs[group_col]
    patient_ids = df_subjs['ppid'].tolist()

    # Feature table: assign each feature to its modality ("view"). The .loc
    # ranges are label-based and inclusive; rows outside them stay at view 0.
    df_var = pd.read_csv(f'{data_dir}/var_labels.csv')
    df_var['new_labels'] = df_var['labels']
    df_var['view'] = 0
    df_var.loc[0:32, 'view'] = 1
    df_var.loc[33:61, 'view'] = 2
    df_var.loc[62:67, 'view'] = 3
    df_var.loc[68:107, 'view'] = 4
    var_labels = list(df_var['new_labels'])
    view_labels = df_var['view']

    ## Setting Up Colors and Labels
    # Colours are assigned to group labels in order of first appearance.
    subtype_lut = dict(zip(subtype_labels.unique(), ['#fdbb84', '#2b8cbe']))
    subtype_colors = subtype_labels.map(subtype_lut)
    subtype_colors.name = ''  # keep the colour strip unlabelled
    # One colour per modality; extend this list when views are added.
    view_lut = dict(zip(view_labels.unique(), ['#993404', '#4daf4a', '#ff7f00', '#e41a1c']))
    view_colors = [view_labels.map(view_lut)]

    def _add_legend(ax, lut, title, **legend_kw):
        # Zero-size bars act as legend proxies: one coloured patch per category.
        for label, color in lut.items():
            ax.bar(0, 0, color=color, label=label, linewidth=0)
        ax.legend(title=title, ncol=1, bbox_transform=gcf().transFigure, **legend_kw)

    def _scores_clustermap(df, vmin, vmax, figsize):
        # Factor-score heatmap with subjects as rows, kept in file order.
        return sns.clustermap(df,
                              vmin=vmin,
                              vmax=vmax,
                              cmap="vlag",
                              center=0.00,
                              row_colors=subtype_colors,
                              row_cluster=False,
                              col_cluster=False,
                              xticklabels=True,
                              yticklabels=patient_ids,
                              figsize=figsize)

    def _save(name, dpi=None):
        # PNG into `path`, SVG into `path/svgs`, then release the figure.
        plt.savefig(f'{path}/{name}.png', dpi=dpi)
        plt.savefig(f'{path}/svgs/{name}.svg')
        plt.close()

    ## Plotting Clustermap (all factors, reordered by ids_var)
    Z = params['Z'][:, ids_var]
    lcomps = [f'Factor {k + 1}' for k in range(Z.shape[1])]

    df_Z = pd.DataFrame(Z, columns=lcomps)
    df_Z.to_csv(f'{z_csv_dir}/df_Z_all.csv', index=False)
    cm = _scores_clustermap(df_Z, np.min(Z), np.max(Z), (20, 15))
    _add_legend(cm.ax_row_dendrogram, subtype_lut, legend_title, loc="center")
    _save('infZ_ord')

    ## Plotting Clustermap for Top Components
    Z_top = Z[:, :top]
    df_Z = pd.DataFrame(Z_top, columns=lcomps[:top])
    df_Z.to_csv(f'{z_csv_dir}/df_Z_top.csv', index=False)
    cm = _scores_clustermap(df_Z, np.min(Z_top), np.max(Z_top), (20, 15))
    _add_legend(cm.ax_row_dendrogram, subtype_lut, legend_title, loc="center")
    _save('infZ_ord_top')

    ## Calculating Subtype Scores
    # Select each group with a boolean mask rather than contiguous slices: slicing
    # silently mixed groups unless every label-1 subject preceded every label-0
    # subject in the CSV.
    group_values = subtype_labels.to_numpy()
    group_masks = [group_values == 1, group_values == 0]
    n_comp = len(X)
    # Fancy indexing raises IndexError if Z has fewer than len(X) columns.
    abs_scores = np.abs(np.asarray(Z[:, np.arange(n_comp)], dtype=float))
    subtype_scores = np.zeros((len(group_masks), n_comp))
    scores_dict = {}
    for s, mask in enumerate(group_masks):
        group_abs = abs_scores[mask]
        subtype_scores[s] = group_abs.mean(axis=0)  # NaN (with a warning) for an empty group
        scores_dict[f's{s + 1}'] = group_abs

    ## Plotting Individual Scores and Components
    for k in range(top):
        # Scores of factor k for every subject.
        df_Z = pd.DataFrame(Z[:, k], columns=[f'Factor {k + 1}'])
        cm = _scores_clustermap(df_Z, np.min(Z), np.max(Z), (2.5, 5))
        cm.ax_heatmap.set_yticklabels(patient_ids, fontsize=6, rotation=0)
        plt.tight_layout()
        _add_legend(cm.ax_col_dendrogram, subtype_lut, legend_title, loc="center", fontsize=8)
        _save(f'infZ_clustermap_comp{k + 1}', dpi=300)

        ## Plotting Components on Data Space (features as rows, subjects as columns)
        X_k = X[ids_var[k]][0]
        df_X = pd.DataFrame(X_k, columns=var_labels)
        cm = sns.clustermap(df_X.T,
                            vmin=np.min(X_k),
                            vmax=np.max(X_k),
                            cmap="vlag",
                            center=0.00,
                            row_colors=view_colors,
                            col_colors=subtype_colors,
                            row_cluster=True,
                            col_cluster=True,
                            xticklabels=False,
                            yticklabels=True,
                            figsize=(30, 30))
        _add_legend(cm.ax_col_dendrogram, subtype_lut, legend_title, loc="center")
        _add_legend(cm.ax_row_dendrogram, view_lut, 'Modality', loc="upper left")
        plt.setp(cm.ax_heatmap.get_yticklabels(), fontsize=4)
        plt.setp(cm.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
        _save(f'infX_comp{k + 1}', dpi=300)

    return subtype_scores, scores_dict

In [None]:
# DISEASE
def plot_components(params, top, ids_var, path, z_csv_dir=None):
    """Plot factor scores and per-factor data-space maps, coloured by disease status.

    NOTE: this notebook defines ``plot_components`` three times (RELAPSE,
    DISEASE and EEG cells); whichever cell was executed last is the one in use.

    Reads ``./aida_model/visit11_data_{N}subjs.csv`` (subject table; columns
    ``isControl`` and ``ppid`` are used) and ``./aida_model/var_labels.csv``
    (feature table; column ``labels`` is used). Writes ``infZ_ord``,
    ``infZ_ord_top`` and, for each of the first ``top`` factors,
    ``infZ_clustermap_comp{k}`` and ``infX_comp{k}`` as PNG into ``path`` and
    as SVG into ``path/svgs``.

    Parameters
    ----------
    params : dict
        Model output. ``params['Z']`` is a subjects x factors score matrix;
        ``params['infX'][c][0]`` is a subjects x features matrix showing
        component ``c`` in data space.
    top : int
        Number of leading (reordered) factors that get individual plots.
    ids_var : sequence of int
        Column order applied to ``Z``; ``ids_var[k]`` also selects the
        data-space matrix shown for the k-th plotted factor.
    path : str
        Output directory; must already contain an ``svgs`` sub-directory.
    z_csv_dir : str, optional
        Directory for ``df_Z_all.csv`` and ``df_Z_top.csv``. Defaults to
        ``path`` (this used to be a hard-coded absolute path on one machine).

    Returns
    -------
    subtype_scores : numpy.ndarray, shape (2, len(params['infX']))
        Mean absolute score per reordered factor; row 0 for ``isControl == 1``
        subjects, row 1 for ``isControl == 0`` subjects.
    scores_dict : dict
        ``'s1'`` / ``'s2'``: per-subject absolute scores of the
        ``isControl == 1`` / ``isControl == 0`` subjects, in file order.
    """
    data_dir = './aida_model'
    group_col = 'isControl'
    legend_title = 'Disease Status'
    if z_csv_dir is None:
        z_csv_dir = path

    ## Reading Data
    X = params['infX']
    N = params['Z'].shape[0]

    # Subject table: one row per subject (row of Z).
    df_subjs = pd.read_csv(f'{data_dir}/visit11_data_{N}subjs.csv')
    subtype_labels = df_subjs[group_col]
    patient_ids = df_subjs['ppid'].tolist()

    # Feature table: assign each feature to its modality ("view"). The .loc
    # ranges are label-based and inclusive; rows outside them stay at view 0.
    df_var = pd.read_csv(f'{data_dir}/var_labels.csv')
    df_var['new_labels'] = df_var['labels']
    df_var['view'] = 0
    df_var.loc[0:154, 'view'] = 1
    df_var.loc[155:309, 'view'] = 2
    df_var.loc[310:349, 'view'] = 3
    var_labels = list(df_var['new_labels'])
    view_labels = df_var['view']

    ## Setting Up Colors and Labels
    # Colours are assigned to group labels in order of first appearance.
    subtype_lut = dict(zip(subtype_labels.unique(), ['#fdbb84', '#2b8cbe']))
    subtype_colors = subtype_labels.map(subtype_lut)
    subtype_colors.name = ''  # keep the colour strip unlabelled
    # One colour per modality; extend this list when views are added.
    view_lut = dict(zip(view_labels.unique(), ['#993404', '#fec44f', '#ff7f00']))
    view_colors = [view_labels.map(view_lut)]

    def _add_legend(ax, lut, title, **legend_kw):
        # Zero-size bars act as legend proxies: one coloured patch per category.
        for label, color in lut.items():
            ax.bar(0, 0, color=color, label=label, linewidth=0)
        ax.legend(title=title, ncol=1, bbox_transform=gcf().transFigure, **legend_kw)

    def _scores_clustermap(df, vmin, vmax, figsize):
        # Factor-score heatmap with subjects as rows, kept in file order.
        return sns.clustermap(df,
                              vmin=vmin,
                              vmax=vmax,
                              cmap="vlag",
                              center=0.00,
                              row_colors=subtype_colors,
                              row_cluster=False,
                              col_cluster=False,
                              xticklabels=True,
                              yticklabels=patient_ids,
                              figsize=figsize)

    def _save(name, dpi=None):
        # PNG into `path`, SVG into `path/svgs`, then release the figure.
        plt.savefig(f'{path}/{name}.png', dpi=dpi)
        plt.savefig(f'{path}/svgs/{name}.svg')
        plt.close()

    ## Plotting Clustermap (all factors, reordered by ids_var)
    Z = params['Z'][:, ids_var]
    lcomps = [f'Factor {k + 1}' for k in range(Z.shape[1])]

    df_Z = pd.DataFrame(Z, columns=lcomps)
    df_Z.to_csv(f'{z_csv_dir}/df_Z_all.csv', index=False)
    cm = _scores_clustermap(df_Z, np.min(Z), np.max(Z), (20, 15))
    _add_legend(cm.ax_row_dendrogram, subtype_lut, legend_title, loc="center")
    _save('infZ_ord')

    ## Plotting Clustermap for Top Components
    Z_top = Z[:, :top]
    df_Z = pd.DataFrame(Z_top, columns=lcomps[:top])
    df_Z.to_csv(f'{z_csv_dir}/df_Z_top.csv', index=False)
    cm = _scores_clustermap(df_Z, np.min(Z_top), np.max(Z_top), (20, 15))
    _add_legend(cm.ax_row_dendrogram, subtype_lut, legend_title, loc="center")
    _save('infZ_ord_top')

    ## Calculating Subtype Scores
    # Select each group with a boolean mask rather than contiguous slices: slicing
    # silently mixed groups unless every label-1 subject preceded every label-0
    # subject in the CSV.
    group_values = subtype_labels.to_numpy()
    group_masks = [group_values == 1, group_values == 0]
    n_comp = len(X)
    # Fancy indexing raises IndexError if Z has fewer than len(X) columns.
    abs_scores = np.abs(np.asarray(Z[:, np.arange(n_comp)], dtype=float))
    subtype_scores = np.zeros((len(group_masks), n_comp))
    scores_dict = {}
    for s, mask in enumerate(group_masks):
        group_abs = abs_scores[mask]
        subtype_scores[s] = group_abs.mean(axis=0)  # NaN (with a warning) for an empty group
        scores_dict[f's{s + 1}'] = group_abs

    ## Plotting Individual Scores and Components
    for k in range(top):
        # Scores of factor k for every subject.
        df_Z = pd.DataFrame(Z[:, k], columns=[f'Factor {k + 1}'])
        cm = _scores_clustermap(df_Z, np.min(Z), np.max(Z), (2.5, 5))
        _add_legend(cm.ax_col_dendrogram, subtype_lut, legend_title, loc="center", fontsize=8)
        _save(f'infZ_clustermap_comp{k + 1}', dpi=300)

        ## Plotting Components on Data Space (features as rows, subjects as columns)
        X_k = X[ids_var[k]][0]
        df_X = pd.DataFrame(X_k, columns=var_labels)
        cm = sns.clustermap(df_X.T,
                            vmin=np.min(X_k),
                            vmax=np.max(X_k),
                            cmap="vlag",
                            center=0.00,
                            row_colors=view_colors,
                            col_colors=subtype_colors,
                            row_cluster=True,
                            col_cluster=True,
                            xticklabels=False,
                            yticklabels=True,
                            figsize=(20, 20))
        _add_legend(cm.ax_col_dendrogram, subtype_lut, legend_title, loc="center")
        _add_legend(cm.ax_row_dendrogram, view_lut, 'Modality', loc="upper left")
        plt.setp(cm.ax_heatmap.get_yticklabels(), fontsize=3)
        plt.setp(cm.ax_heatmap.yaxis.get_majorticklabels(), rotation=45)
        _save(f'infX_comp{k + 1}', dpi=300)

    return subtype_scores, scores_dict

In [None]:
# EEG
def plot_components(params, top, ids_var, path, z_csv_dir=None):
    """Plot factor scores and per-factor data-space maps for the EEG setup, coloured by disease status.

    NOTE: this notebook defines ``plot_components`` three times (RELAPSE,
    DISEASE and EEG cells); whichever cell was executed last is the one in use.

    Reads ``./aida_model/visit11_data_{N}subjs.csv`` (subject table; columns
    ``isControl`` and ``ppid`` are used) and ``./aida_model/var_labels.csv``
    (feature table; column ``labels`` is used). Writes ``infZ_ord``,
    ``infZ_ord_top`` and, for each of the first ``top`` factors,
    ``infZ_clustermap_comp{k}`` and ``infX_comp{k}`` as PNG into ``path`` and
    as SVG into ``path/svgs``.

    Parameters
    ----------
    params : dict
        Model output. ``params['Z']`` is a subjects x factors score matrix;
        ``params['infX'][c][0]`` is a subjects x features matrix showing
        component ``c`` in data space.
    top : int
        Number of leading (reordered) factors that get individual plots.
    ids_var : sequence of int
        Column order applied to ``Z``; ``ids_var[k]`` also selects the
        data-space matrix shown for the k-th plotted factor.
    path : str
        Output directory; must already contain an ``svgs`` sub-directory.
    z_csv_dir : str, optional
        Directory for ``df_Z_all.csv`` and ``df_Z_top.csv``. Defaults to
        ``path`` (this used to be a hard-coded absolute path on one machine).

    Returns
    -------
    subtype_scores : numpy.ndarray, shape (2, len(params['infX']))
        Mean absolute score per reordered factor; row 0 for ``isControl == 1``
        subjects, row 1 for ``isControl == 0`` subjects.
    scores_dict : dict
        ``'s1'`` / ``'s2'``: per-subject absolute scores of the
        ``isControl == 1`` / ``isControl == 0`` subjects, in file order.
    """
    data_dir = './aida_model'
    group_col = 'isControl'
    legend_title = 'Disease Status'
    if z_csv_dir is None:
        z_csv_dir = path

    ## Reading Data
    X = params['infX']
    N = params['Z'].shape[0]

    # Subject table: one row per subject (row of Z).
    df_subjs = pd.read_csv(f'{data_dir}/visit11_data_{N}subjs.csv')
    subtype_labels = df_subjs[group_col]
    patient_ids = df_subjs['ppid'].tolist()

    # Feature table: assign each feature to its modality ("view"). The .loc
    # ranges are label-based and inclusive; rows outside them stay at view 0.
    df_var = pd.read_csv(f'{data_dir}/var_labels.csv')
    df_var['new_labels'] = df_var['labels']
    df_var['view'] = 0
    df_var.loc[0:62, 'view'] = 1
    df_var.loc[63:125, 'view'] = 2
    df_var.loc[126:188, 'view'] = 3
    df_var.loc[189:251, 'view'] = 4
    df_var.loc[252:314, 'view'] = 5
    df_var.loc[315:354, 'view'] = 6
    var_labels = list(df_var['new_labels'])
    view_labels = df_var['view']

    ## Setting Up Colors and Labels
    # Colours are assigned to group labels in order of first appearance.
    subtype_lut = dict(zip(subtype_labels.unique(), ['#fdbb84', '#2b8cbe']))
    subtype_colors = subtype_labels.map(subtype_lut)
    subtype_colors.name = ''  # keep the colour strip unlabelled
    # One colour per modality; extend this list when views are added.
    view_lut = dict(zip(view_labels.unique(),
                        ['#993404', '#fec44f', '#ff7f00', '#377eb8', '#4daf4a', '#e41a1c']))
    view_colors = [view_labels.map(view_lut)]

    def _add_legend(ax, lut, title, **legend_kw):
        # Zero-size bars act as legend proxies: one coloured patch per category.
        for label, color in lut.items():
            ax.bar(0, 0, color=color, label=label, linewidth=0)
        ax.legend(title=title, ncol=1, bbox_transform=gcf().transFigure, **legend_kw)

    def _scores_clustermap(df, vmin, vmax, figsize):
        # Factor-score heatmap with subjects as rows, kept in file order.
        return sns.clustermap(df,
                              vmin=vmin,
                              vmax=vmax,
                              cmap="vlag",
                              center=0.00,
                              row_colors=subtype_colors,
                              row_cluster=False,
                              col_cluster=False,
                              xticklabels=True,
                              yticklabels=patient_ids,
                              figsize=figsize)

    def _save(name, dpi=None):
        # PNG into `path`, SVG into `path/svgs`, then release the figure.
        plt.savefig(f'{path}/{name}.png', dpi=dpi)
        plt.savefig(f'{path}/svgs/{name}.svg')
        plt.close()

    ## Plotting Clustermap (all factors, reordered by ids_var)
    Z = params['Z'][:, ids_var]
    lcomps = [f'Factor {k + 1}' for k in range(Z.shape[1])]

    df_Z = pd.DataFrame(Z, columns=lcomps)
    df_Z.to_csv(f'{z_csv_dir}/df_Z_all.csv', index=False)
    cm = _scores_clustermap(df_Z, np.min(Z), np.max(Z), (20, 15))
    _add_legend(cm.ax_row_dendrogram, subtype_lut, legend_title, loc="center")
    _save('infZ_ord')

    ## Plotting Clustermap for Top Components
    Z_top = Z[:, :top]
    df_Z = pd.DataFrame(Z_top, columns=lcomps[:top])
    df_Z.to_csv(f'{z_csv_dir}/df_Z_top.csv', index=False)
    cm = _scores_clustermap(df_Z, np.min(Z_top), np.max(Z_top), (20, 15))
    _add_legend(cm.ax_row_dendrogram, subtype_lut, legend_title, loc="center")
    _save('infZ_ord_top')

    ## Calculating Subtype Scores
    # Select each group with a boolean mask rather than contiguous slices: slicing
    # silently mixed groups unless every label-1 subject preceded every label-0
    # subject in the CSV.
    group_values = subtype_labels.to_numpy()
    group_masks = [group_values == 1, group_values == 0]
    n_comp = len(X)
    # Fancy indexing raises IndexError if Z has fewer than len(X) columns.
    abs_scores = np.abs(np.asarray(Z[:, np.arange(n_comp)], dtype=float))
    subtype_scores = np.zeros((len(group_masks), n_comp))
    scores_dict = {}
    for s, mask in enumerate(group_masks):
        group_abs = abs_scores[mask]
        subtype_scores[s] = group_abs.mean(axis=0)  # NaN (with a warning) for an empty group
        scores_dict[f's{s + 1}'] = group_abs

    ## Plotting Individual Scores and Components
    for k in range(top):
        # Scores of factor k for every subject.
        df_Z = pd.DataFrame(Z[:, k], columns=[f'Factor {k + 1}'])
        cm = _scores_clustermap(df_Z, np.min(Z), np.max(Z), (2.5, 5))
        _add_legend(cm.ax_col_dendrogram, subtype_lut, legend_title, loc="center", fontsize=8)
        _save(f'infZ_clustermap_comp{k + 1}', dpi=300)

        ## Plotting Components on Data Space (features as rows, subjects as columns)
        X_k = X[ids_var[k]][0]
        df_X = pd.DataFrame(X_k, columns=var_labels)
        cm = sns.clustermap(df_X.T,
                            vmin=np.min(X_k),
                            vmax=np.max(X_k),
                            cmap="vlag",
                            center=0.00,
                            row_colors=view_colors,
                            col_colors=subtype_colors,
                            row_cluster=True,
                            col_cluster=True,
                            xticklabels=False,
                            yticklabels=True,
                            figsize=(30, 30))
        _add_legend(cm.ax_col_dendrogram, subtype_lut, legend_title, loc="center")
        _add_legend(cm.ax_row_dendrogram, view_lut, 'Modality', loc="upper left")
        plt.setp(cm.ax_heatmap.get_yticklabels(), fontsize=3)
        plt.setp(cm.ax_heatmap.yaxis.get_majorticklabels(), rotation=0)
        _save(f'infX_comp{k + 1}', dpi=300)

    return subtype_scores, scores_dict


In [None]:
def plot_param(params, paths, args, cids=None, tr_vals=False):
    """Save diagnostic plots for whichever model parameters are present in ``params``.

    Each plot is produced only when its key exists: ``'W'`` (loadings),
    ``'lmbW'``, ``'Z'`` (scores), ``'lmbZ'``, and the per-view sample
    histograms ``'tauW_inf'`` and ``'sigma_inf'``. Output path stems come from
    ``paths`` (keys ``'W'``, ``'lmbW'``, ``'Z'``, ``'Z_svg'``, ``'lmbZ'``,
    ``'tauW'``, ``'sigma'``); the file extension is appended here.

    Parameters
    ----------
    params : dict
        Parameter arrays; the matrices are 2-D with factors along axis 1.
    paths : dict
        Output path stems (without extension), keyed as listed above.
    args : object
        Reads ``args.num_sources`` (number of views) and ``args.dataset``.
    cids : sequence of int, optional
        Column subset applied to ``lmbW`` and ``lmbZ`` before plotting.
    tr_vals : dict, optional
        True values (``'tauW'``, ``'sigma'``), drawn as red reference lines when
        ``'synthetic' in args.dataset``.

    Factor tick labels are numbered 1..n from each plotted matrix's own column
    count, so a missing ``'W'`` or a ``cids`` subset no longer breaks or
    misaligns the other plots.
    """

    def _factor_labels(mat):
        # One tick label per plotted column, numbered from 1.
        return [str(c) for c in range(1, mat.shape[1] + 1)]

    def _symmetric_heatmap(mat, ax):
        # Diverging colour map with limits symmetric about zero.
        lim = np.max(np.abs(mat))
        sns.heatmap(mat,
                    vmin=-lim,
                    vmax=lim,
                    cmap="vlag",
                    yticklabels=False,
                    xticklabels=_factor_labels(mat),
                    ax=ax)

    def _per_view_hists(samples, out_stem, true_value):
        # One histogram per view (column m of `samples`). squeeze=False keeps
        # `axes` an array even when there is a single view.
        fig, axes = plt.subplots(args.num_sources, 1, figsize=(8, 6), squeeze=False)
        fig.subplots_adjust(hspace=0.5, wspace=0.2)
        for m, ax in enumerate(axes.flat):
            sns.histplot(samples[:, m], ax=ax, color='#2b8cbe')
            if 'synthetic' in args.dataset:
                ax.axvline(x=true_value(m), color='red')
            ax.set_title(f'View {m + 1}')
            ax.set_ylabel('Number of samples')
        fig.savefig(f'{out_stem}.png')
        plt.close(fig)

    ## Plotting Loading Matrices (W)
    if 'W' in params:
        W = params['W']
        pathW = paths['W']
        fig, ax = plt.subplots()
        _symmetric_heatmap(W, ax)
        ax.set_xlabel('Factors', fontsize=11)
        ax.set_ylabel('D', fontsize=11)
        ax.set_title('Loading matrices (W)', fontsize=12)
        fig.savefig(f'{pathW}.png', dpi=200)
        plt.close(fig)

    ## Plotting Lambda W
    if 'lmbW' in params:
        lmbW = params['lmbW'] if cids is None else params['lmbW'][:, cids]
        pathlmbW = paths['lmbW']
        fig, ax = plt.subplots()
        _symmetric_heatmap(lmbW, ax)
        ax.set_xlabel('Factors')
        ax.set_ylabel('D')
        fig.savefig(f'{pathlmbW}.png')
        plt.close(fig)

    ## Plotting Z
    if 'Z' in params:
        Z = params['Z']
        pathZ = paths['Z']
        pathZ_svg = paths['Z_svg']
        fig, ax = plt.subplots(figsize=(6, 5), dpi=300)
        fontsize = 6 * (fig.get_dpi() / 100)  # scale text with the figure dpi
        _symmetric_heatmap(Z, ax)
        ax.set_xlabel('Factors', fontsize=fontsize)
        ax.set_ylabel('Latent variables', fontsize=fontsize)
        ax.tick_params(axis='x', labelsize=0.85 * fontsize)
        fig.savefig(f'{pathZ}.png')
        fig.savefig(f'{pathZ_svg}.svg')
        plt.close(fig)

    ## Plotting Lambda Z
    if 'lmbZ' in params:
        lmbZ = params['lmbZ'] if cids is None else params['lmbZ'][:, cids]
        pathlmbZ = paths['lmbZ']
        fig, ax = plt.subplots()
        _symmetric_heatmap(lmbZ, ax)
        ax.set_xlabel('Factors')
        ax.set_ylabel('Training samples')
        fig.savefig(f'{pathlmbZ}.png')
        plt.close(fig)

    ## Plotting Tau W (true values only read for synthetic datasets)
    if 'tauW_inf' in params:
        _per_view_hists(params['tauW_inf'], paths['tauW'],
                        lambda m: tr_vals['tauW'][0, m])

    ## Plotting Sigmas
    if 'sigma_inf' in params:
        _per_view_hists(params['sigma_inf'], paths['sigma'],
                        lambda m: tr_vals['sigma'][m])

In [None]:
def plot_X(data, args, hypers, path, true_data=True):  # true_data=False IN THE ORIGINAL CODE => MAKE EXPERIMENTS
    """Plot the per-factor contributions to the data matrix X and the combined X.

    For each factor k a figure is saved to ``{path}_comp{k + 1}.png`` showing
    that factor's contribution X_k as one heatmap panel per data source (the
    columns are split according to ``hypers['Dm']``). Finally, the combined
    matrix X is saved as a single heatmap to ``{path}.png``.

    Parameters
    ----------
    data : dict or sequence
        If ``true_data`` is True: a dict holding ``'Z'`` (N x K latent
        variables) and ``'W'`` (D x K loadings); X_k is the outer product of
        the k-th columns of Z and W, and X = Z @ W.T.
        Otherwise: a sequence whose k-th item holds the k-th component matrix
        at position 0 (``data[k][0]``); X is the sum of these matrices.
    args : object
        Must provide ``num_sources``, the number of data sources (views).
    hypers : dict
        Must provide ``'Dm'``, the number of columns of each data source, in
        the order the sources appear in X.
    path : str
        Output path prefix, without file extension.
    true_data : bool, default True
        Selects which layout of ``data`` is expected (see above).
    """
    num_sources = args.num_sources
    Dm = hypers['Dm']

    if true_data:
        Z, W = data['Z'], data['W']
        X = np.dot(Z, W.T)
        K = Z.shape[1]
    else:
        # Filled in by summing the components inside the loop below.
        X = np.zeros(data[0][0].shape[:2])
        K = len(data)

    ## Plotting Data (X) Components
    for k in range(K):
        if true_data:
            # Rank-1 contribution of factor k: outer product of Z[:, k] and W[:, k].
            X_k = np.outer(Z[:, k], W[:, k])
        else:
            X_k = data[k][0]
            X += X_k

        # squeeze=False keeps `axes` a 2-D array even for a single source;
        # otherwise plt.subplots returns a bare Axes that cannot be indexed.
        fig, axes = plt.subplots(ncols=num_sources, squeeze=False)
        fig.subplots_adjust(wspace=0.02)

        # Shared colour limits so the panels of one factor are comparable.
        vmin, vmax = np.min(X_k), np.max(X_k)
        d = 0
        for m in range(num_sources):
            sns.heatmap(X_k[:, d:d + Dm[m]],
                        vmin=vmin,
                        vmax=vmax,
                        cmap="vlag",
                        ax=axes[0, m],
                        cbar=(m == num_sources - 1),  # a single colorbar, on the last panel
                        xticklabels=False,
                        yticklabels=False)
            d += Dm[m]

        # suptitle labels the whole figure; plt.title would title only one panel.
        fig.suptitle(f'Factor {k + 1} (Input space)')
        fig.savefig(f'{path}_comp{k + 1}.png')
        plt.close(fig)

    ## Plotting Combined Data Matrix (X)
    fig, ax = plt.subplots()
    sns.heatmap(X,
                vmin=np.min(X),
                vmax=np.max(X),
                cmap="vlag",
                xticklabels=False,
                yticklabels=False,
                ax=ax)
    ax.set_xlabel('D')
    ax.set_ylabel('N')
    fig.savefig(f'{path}.png')
    plt.close(fig)

In [None]:
def define_box_properties(plot_name, color_code, label):
    """Recolour every artist group in ``plot_name`` and add a legend entry for it.

    Parameters
    ----------
    plot_name : dict
        Mapping of element names to artists. Presumably the dict returned by
        ``plt.boxplot`` (e.g. 'boxes', 'whiskers', ...); confirm at call sites.
    color_code : str
        Colour passed to matplotlib for every artist and for the legend handle.
    label : str
        Legend text for this group.
    """
    for artists in plot_name.values():
        plt.setp(artists, color=color_code)

    # An empty line carries the colour and label, so the legend gets one entry
    # for the whole group without anything being drawn on the axes.
    plt.plot([], c=color_code, label=label)
    plt.legend()