In [None]:
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import scipy
from sklearn.preprocessing import MinMaxScaler
from nilearn.plotting import plot_surf_stat_map
from nilearn.surface import vol_to_surf
from nilearn.datasets import fetch_surf_fsaverage, fetch_atlas_difumo
from umap import UMAP

from datasets_shuffle.eegfmri_vu import EEGfMRIVuDataset, EEGfMRIVuDatasetConfig
from datasets_shuffle.NIH_ecr import NIHECRDataset, NIHECRDatasetConfig
from datasets_shuffle.NIH_ect import NIHECTDataset, NIHECTDatasetConfig

# Registry of supported datasets: name -> [dataset class, config class].
# `build_dataset` reads element 0 as the dataset class and instantiates element 1
# (with no arguments) to get the config.
DATASET_FUNCTIONS = dict(
    eegfmri_vu=[EEGfMRIVuDataset, EEGfMRIVuDatasetConfig],
    NIH_ecr=[NIHECRDataset, NIHECRDatasetConfig],
    NIH_ect=[NIHECTDataset, NIHECTDatasetConfig],
)

def build_dataset(dataset_name, splits=("train", "test", "zero_shot")):
    """Instantiate one dataset object per split for a registered dataset.

    Parameters
    ----------
    dataset_name : str
        Key into ``DATASET_FUNCTIONS`` selecting the dataset class and config class.
    splits : iterable of str, optional
        Split names, each passed to the dataset class as ``split_set``.
        Defaults to ``("train", "test", "zero_shot")``.

    Returns
    -------
    dataset_dict : dict
        Maps each split name to ``dataset_class(dataset_config, split_set=split)``,
        in the order given by ``splits``.
    dataset_config
        The single config instance shared by every split.

    Raises
    ------
    KeyError
        If ``dataset_name`` is not a key of ``DATASET_FUNCTIONS``; the message lists
        the valid names (the notebook's placeholder ``""`` ends up here).
    """
    if dataset_name not in DATASET_FUNCTIONS:
        raise KeyError(
            f"Unknown dataset_name {dataset_name!r}; "
            f"expected one of {sorted(DATASET_FUNCTIONS)}"
        )
    dataset_builder = DATASET_FUNCTIONS[dataset_name][0]
    dataset_config = DATASET_FUNCTIONS[dataset_name][1]()
    # One config object is shared by all splits, as in the original construction.
    dataset_dict = {
        split: dataset_builder(dataset_config, split_set=split)
        for split in splits
    }
    return dataset_dict, dataset_config

In [None]:
# Must be one of the keys of DATASET_FUNCTIONS ("eegfmri_vu", "NIH_ecr", "NIH_ect").
# TODO: fill in before running — the empty placeholder is not a registered dataset.
dataset_name = ""
dataset, dataset_config = build_dataset(dataset_name=dataset_name)

In [None]:
# Folder holding the model's saved outputs. `dataset_path` is later appended to it by
# plain string concatenation, so both keep their trailing "/".
model_path = "outputs/cbrain/"
dataset_path = f"{dataset_name}/predictions_{dataset_name}/"
# Key of the split (as built by `build_dataset`) whose predictions are analysed.
dataset_partition = "zero_shot"

# Header-less CSV files that `collect_data` reads from the predictions folder.
(
    fmri_feats_total_path,
    eeg_feats_total_path,
    fmri_map_feats_total_path,
    eeg_map_feats_total_path,
    bad_tr_total_path,
    eeg_index_total_path,
    predictions_total_path,
    indices_total_path,
) = (
    f"{stem}.csv"
    for stem in (
        "fmri_feats_total",
        "eeg_feats_total",
        "fmri_map_feats_total",
        "eeg_map_feats_total",
        "bad_tr_total",
        "eeg_index_total",
        "predictions_total",
        "indices_total",
    )
)

In [None]:
# Task-code substitutions tried when an NIH scan's EEG ``.set`` file is not found under
# its fMRI scan name (characters 17:20 of the name hold the task code).
_NIH_TASK_ALIASES = {'ect': 'ectp'}


def _nih_vigall_mat_path(scan_name, eeg_data_path):
    """Resolve the VIGALL ``.mat`` file for one NIH scan inside ``eeg_data_path``.

    The EEG recording name may differ from the fMRI scan name. If no ``<scan_name>*.set``
    exists, first the task code at characters 17:20 is swapped via ``_NIH_TASK_ALIASES``.
    If that still finds no file, character 20 is replaced by ``'_'``. The ``.mat`` file is
    then looked up as ``scan_name[:12]`` followed by the 4-digit number at characters
    12:16 of the resolved EEG name plus one.

    Raises
    ------
    FileNotFoundError
        If no matching ``.mat`` file exists.
    """
    eeg_scan_name = scan_name
    eeg_scan_paths = glob.glob(os.path.join(eeg_data_path, scan_name + '*.set'))
    if not eeg_scan_paths:
        scan_task = scan_name[17:20]
        if scan_task in _NIH_TASK_ALIASES:
            eeg_scan_name = scan_name[:17] + _NIH_TASK_ALIASES[scan_task] + scan_name[20:]
            eeg_scan_paths = glob.glob(os.path.join(eeg_data_path, eeg_scan_name + '*.set'))
        if not eeg_scan_paths:
            eeg_scan_name = scan_name[:20] + '_' + scan_name[21:]
    scan_vigall_index = scan_name[0:12] + "{:04d}".format(int(eeg_scan_name[12:16]) + 1)
    # sorted() makes the choice deterministic when several files match the pattern
    # (glob order depends on the filesystem).
    mat_paths = sorted(glob.glob(os.path.join(eeg_data_path, scan_vigall_index + '*.mat')))
    if not mat_paths:
        raise FileNotFoundError(
            f"No VIGALL file matching {scan_vigall_index!r}*.mat in {eeg_data_path!r} "
            f"for scan {scan_name!r}"
        )
    return mat_paths[0]


def _load_vigall_reference(mat_path, slide_cnt):
    """Return elements ``2 : slide_cnt * 5 + 2`` of ``VIG_SIG[0][0][1]`` from ``mat_path``."""
    # `import scipy` alone does not guarantee that the `scipy.io` submodule is loaded
    # (SciPy releases without lazy submodule loading raise AttributeError), so import it.
    import scipy.io

    eeg_index_data_raw = scipy.io.loadmat(mat_path)
    return eeg_index_data_raw['VIG_SIG'][0][0][1][2:slide_cnt * 5 + 2]


def collect_data(dataset_path, curr_dataset, curr_dataset_config,
                 vigall_path="vigall_path", eeg_data_path="eeg_data_path"):
    """Load saved model outputs, restore sample order, and attach VIGALL references.

    Reads the eight header-less CSV files named by the module-level ``*_total_path``
    constants from ``dataset_path``. It reorders every array's rows by
    ``argsort(indices_total)``. For each scan in ``curr_dataset.scan_names`` it loads a
    VIGALL reference series from a ``.mat`` file. Finally it reshapes every CSV-derived
    array to ``(n_scans, rows_per_scan, -1)`` and prints each entry's shape.

    Parameters
    ----------
    dataset_path : str
        Folder containing the prediction CSV files.
    curr_dataset
        Dataset object; only its ``scan_names`` sequence is used.
    curr_dataset_config
        Config object; ``name == "eegfmri_vu"`` selects the VU file layout
        (``<vigall_path>/<scan>_eeg_vig-sig.mat``, 114 windows). Any other name uses the
        NIH layout under ``eeg_data_path`` with 138 windows.
    vigall_path, eeg_data_path : str, optional
        Folders searched for the VIGALL ``.mat`` files (VU and NIH layouts respectively).
        The defaults keep the original placeholder values.

    Returns
    -------
    dict
        The eight CSV arrays (same keys as the file constants' stems), ``scan_names``,
        and ``eeg_index_data_reference_total`` (references stacked along axis 0).

    Raises
    ------
    ValueError
        If ``scan_names`` is empty, or a CSV's row count is not a multiple of the
        number of scans (a plain reshape would otherwise silently mis-shape the data).
    FileNotFoundError
        If an NIH VIGALL ``.mat`` file cannot be found.
    """
    # Insertion order here fixes the key order of the result and of the printed summary.
    csv_files = {
        "fmri_feats_total": fmri_feats_total_path,
        "eeg_feats_total": eeg_feats_total_path,
        "fmri_map_feats_total": fmri_map_feats_total_path,
        "eeg_map_feats_total": eeg_map_feats_total_path,
        "bad_tr_total": bad_tr_total_path,
        "eeg_index_total": eeg_index_total_path,
        "predictions_total": predictions_total_path,
        "indices_total": indices_total_path,
    }
    sample_data = {
        key: pd.read_csv(os.path.join(dataset_path, file_name), header=None).to_numpy()
        for key, file_name in csv_files.items()
    }

    # Rows were saved in the order given by `indices_total` (presumably each row's
    # original sample index — confirm); argsort yields the permutation sorting by it.
    inverse_indices = np.argsort(sample_data["indices_total"].reshape(-1))
    sample_data_seq = {key: values[inverse_indices] for key, values in sample_data.items()}

    scan_names = curr_dataset.scan_names
    sample_data_seq["scan_names"] = scan_names
    scan_len = len(scan_names)
    if scan_len == 0:
        raise ValueError("curr_dataset.scan_names is empty; nothing to collect")

    is_vu = curr_dataset_config.name == "eegfmri_vu"
    slide_cnt = 114 if is_vu else 138
    references = []
    for scan_name in scan_names:
        if is_vu:
            mat_path = os.path.join(vigall_path, scan_name + '_eeg_vig-sig.mat')
        else:
            mat_path = _nih_vigall_mat_path(scan_name, eeg_data_path)
        references.append(_load_vigall_reference(mat_path, slide_cnt))
    sample_data_seq["eeg_index_data_reference_total"] = np.stack(references, axis=0)

    for key, values in sample_data_seq.items():
        if key == "scan_names":
            continue
        if key != "eeg_index_data_reference_total":
            n_rows = values.shape[0]
            if n_rows % scan_len:
                raise ValueError(
                    f"{key}: {n_rows} rows cannot be split evenly across {scan_len} scans"
                )
            values = values.reshape(scan_len, n_rows // scan_len, -1)
            # Replacing the value of an existing key is safe while iterating items().
            sample_data_seq[key] = values
        print(f"{key}: {values.shape}")

    return sample_data_seq

In [None]:
def plot_results(dataset_path, sample_data_seq):
    """Save one time-course figure per scan: ground truth, VIGALL reference, predictions.

    The first column of each per-TR row is repeated 5 times. The series is then padded
    with 5 more copies of its last value (or -1 when the scan has no rows), giving
    ``5 * (n_tr + 1)`` samples. These are drawn against the scan's
    ``eeg_index_data_reference_total`` entry. Runs of samples whose bad-TR flag is > 0
    are shaded grey.

    Parameters
    ----------
    dataset_path : str
        Output root; figures go to ``<dataset_path>/plots/<scan_name>_predictions.png``.
    sample_data_seq : dict
        Per-scan arrays ``eeg_index_total``, ``bad_tr_total``, ``predictions_total`` and
        ``eeg_index_data_reference_total``, plus ``scan_names`` (as built by
        ``collect_data``).
    """
    plots_dir = os.path.join(dataset_path, "plots")
    os.makedirs(plots_dir, exist_ok=True)
    samples_per_tr = 5

    def upsample(rows, n_tr):
        # First column of the first n_tr rows, each repeated, plus one extra TR of
        # padding that repeats the final value (-1 if there are no rows).
        column = np.array([rows[t][0] for t in range(n_tr)])
        tail = column[-1] if n_tr else -1
        return np.concatenate([np.repeat(column, samples_per_tr),
                               np.full(samples_per_tr, tail)])

    for scan_idx, scan_name in enumerate(sample_data_seq["scan_names"]):
        truth_rows = sample_data_seq["eeg_index_total"][scan_idx]
        reference = sample_data_seq["eeg_index_data_reference_total"][scan_idx]
        n_tr = len(truth_rows)
        truth = upsample(truth_rows, n_tr)
        bad_flags = upsample(sample_data_seq["bad_tr_total"][scan_idx], n_tr)
        preds = upsample(sample_data_seq["predictions_total"][scan_idx], n_tr)
        x = np.arange(truth.size)

        plt.figure(figsize=(20, 6))
        for series, label, width, color in (
            (truth, '10-frame-wise ground truth', 5, 'orange'),
            (reference, 'VIGALL labels ranging 2 to 6', 3, 'green'),
            (preds, 'predictions', 4, 'blue'),
        ):
            plt.plot(x, series, label=label, linewidth=width, color=color)

        # Shade each contiguous run of flagged samples; a run closes half a sample
        # before the first unflagged position.
        span_start = None
        for pos, flag in zip(x, bad_flags):
            if span_start is None:
                if flag > 0:
                    span_start = pos
            elif flag == 0:
                plt.axvspan(span_start, pos - 0.5, color='gray', alpha=0.3)
                span_start = None
        if span_start is not None:
            plt.axvspan(span_start, x[-1] + 0.05, color='gray', alpha=0.3)

        plt.xlabel('time', fontsize=25)
        plt.ylabel('VIGALL labels', fontsize=25)
        plt.xticks(fontsize=16)
        plt.yticks(fontsize=16)
        plt.ylim(-0.25, 6.25)
        plt.legend(fontsize=20, loc='upper right', frameon=True)
        plt.grid(True)
        out_file = os.path.join(plots_dir, f"{scan_name}_predictions.png")
        plt.savefig(out_file, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Plot saved at: {out_file}")


In [None]:
def _cluster_new_umap(n_components, n_neighbors):
    """Create a cosine-metric UMAP reducer with the fixed settings used in this notebook."""
    return UMAP(n_components=n_components, n_neighbors=n_neighbors, min_dist=0.1,
                metric='cosine', init='random', random_state=0)


def _cluster_save_current_figure(image_path):
    """Save the active figure as a 300-dpi JPEG, report the path, and close it."""
    plt.savefig(image_path, format='jpeg', dpi=300, bbox_inches='tight')
    print(f"Image saved at: {image_path}")
    plt.close()


def _cluster_scatter_2d(points, colors, cmap, marker_size, title, image_path, annotate=False):
    """Save a 2D scatter of ``points[:, :2]`` coloured by ``colors``, optionally numbering each point."""
    plt.figure(figsize=(8, 6))
    scatter = plt.scatter(points[:, 0], points[:, 1], c=colors, cmap=cmap, s=marker_size)
    if annotate:
        for i, (x_pos, y_pos) in enumerate(points):
            plt.text(x_pos, y_pos, str(i), fontsize=8, alpha=0.7)
    plt.colorbar(scatter, label='Label')
    plt.title(title)
    plt.xlabel("UMAP Dimension 1")
    plt.ylabel("UMAP Dimension 2")
    _cluster_save_current_figure(image_path)


def _cluster_scatter_3d(points, colors, title, image_path):
    """Save a 3D scatter of ``points[:, :3]`` coloured by ``colors``."""
    fig = plt.figure(figsize=(10, 8))
    ax = fig.add_subplot(111, projection='3d')
    scatter = ax.scatter(points[:, 0], points[:, 1], points[:, 2],
                         c=colors, cmap='viridis', s=2)
    plt.colorbar(scatter, label='Label')
    ax.set_title(title)
    ax.set_xlabel("UMAP Dimension 1")
    ax.set_ylabel("UMAP Dimension 2")
    ax.set_zlabel("UMAP Dimension 3")
    _cluster_save_current_figure(image_path)


def _cluster_channel_means(feats, n_channels, labels):
    """Average per-channel features within each class.

    ``feats`` (n_samples, n_channels * d) is viewed as (n_samples, n_channels, d).
    Returns a (2 * n_channels, d) array: the label == 1 channel means first, then the
    label == 0 channel means. An empty class yields NaN rows (NumPy mean of nothing).
    """
    per_channel = feats.reshape(feats.shape[0], n_channels, -1)
    alert_means = per_channel[labels == 1].mean(axis=0)
    drowsy_means = per_channel[labels == 0].mean(axis=0)
    return np.vstack((alert_means, drowsy_means))


def plot_umap_cluster(dataset_path, sample_data_seq, neighbour_cnt=30,
                      n_fmri_channels=66, n_eeg_channels=26, eeg_channel_neighbour_cnt=20):
    """Save UMAP visualisations of the model features, excluding bad TRs.

    Produces, under ``<dataset_path>/umaps``:

    * channel-wise 2D UMAPs of per-class (``eeg_index == 1`` vs ``== 0``) channel means
      for ``fmri_feats`` and ``eeg_feats``, each plotted once coloured by class and once
      coloured/annotated by point index;
    * 3D and 2D UMAPs of ``fmri_feats``, ``eeg_feats``, ``fmri_map_feats`` and
      ``eeg_map_feats``, coloured by ``eeg_index``.

    Parameters
    ----------
    dataset_path : str
        Output root.
    sample_data_seq : dict
        Arrays shaped (n_scans, rows_per_scan, d) as built by ``collect_data``; samples
        with ``bad_tr_total != 0`` are dropped.
    neighbour_cnt : int, optional
        ``n_neighbors`` for every UMAP except the EEG channel-wise one.
    n_fmri_channels, n_eeg_channels : int, optional
        Number of channels each fMRI / EEG feature row is split into.
    eeg_channel_neighbour_cnt : int, optional
        ``n_neighbors`` for the EEG channel-wise UMAP (it has only 2 * n_eeg_channels points).
    """
    save_folder = os.path.join(dataset_path, "umaps")
    os.makedirs(save_folder, exist_ok=True)

    keep = sample_data_seq["bad_tr_total"].reshape(-1) == 0
    labels = sample_data_seq["eeg_index_total"].reshape(-1)[keep]
    feature_names = ("fmri_feats", "eeg_feats", "fmri_map_feats", "eeg_map_feats")
    feats = {}
    for name in feature_names:
        values = sample_data_seq[f"{name}_total"]
        feats[name] = values.reshape(-1, values.shape[2])[keep]

    # Channel-wise class means, projected to 2D.
    for name, n_channels, n_neighbors, file_tag in (
        ("fmri_feats", n_fmri_channels, neighbour_cnt, "fmri"),
        ("eeg_feats", n_eeg_channels, eeg_channel_neighbour_cnt, "eeg"),
    ):
        channel_feats = _cluster_channel_means(feats[name], n_channels, labels)
        proj = _cluster_new_umap(2, n_neighbors).fit_transform(channel_feats)
        class_labels = np.vstack((np.ones((n_channels, 1)), np.zeros((n_channels, 1))))
        _cluster_scatter_2d(
            proj, class_labels, 'viridis', 5,
            f"UMAP Visualization of channel wise {name}",
            os.path.join(save_folder, f"umap_channel_wise_{file_tag}.jpeg"),
        )
        n_points = 2 * n_channels
        # plt.get_cmap replaces plt.cm.get_cmap, which Matplotlib 3.9 removed.
        _cluster_scatter_2d(
            proj, np.arange(n_points).reshape(n_points, 1),
            plt.get_cmap('tab10', n_points), 5,
            f"UMAP Visualization of channel wise {name} with labels",
            os.path.join(save_folder, f"umap_channel_wise_{file_tag}_withlabels.jpeg"),
            annotate=True,
        )

    # Sample-level 3D and 2D projections; each UMAP is fitted exactly once.
    for name in feature_names:
        proj_3d = _cluster_new_umap(3, neighbour_cnt).fit_transform(feats[name])
        _cluster_scatter_3d(
            proj_3d, labels,
            f"UMAP Visualization (3D) of remove bad tr {name}",
            os.path.join(save_folder, f"umap_3d_removebad_{name}.jpeg"),
        )
        proj_2d = _cluster_new_umap(2, neighbour_cnt).fit_transform(feats[name])
        _cluster_scatter_2d(
            proj_2d, labels, 'viridis', 2,
            f"UMAP Visualization of {name}",
            os.path.join(save_folder, f"umap_2d_removebad_{name}.jpeg"),
        )
    

In [None]:
def _surface_plot_hemisphere(vertex_values, fsaverage, hemi, title, image_path):
    """Min-max scale per-vertex values to [0, 1], draw them on one inflated fsaverage hemisphere, and save."""
    scaled = MinMaxScaler().fit_transform(vertex_values.reshape(-1, 1)).reshape(-1)
    # plot_surf_stat_map creates its own figure, so no plt.figure() beforehand (an extra
    # one would be left open as an empty, never-closed figure).
    plot_surf_stat_map(
        fsaverage[f'infl_{hemi}'],
        stat_map=scaled,
        hemi=hemi,
        colorbar=True,
        cmap='Spectral_r',
        title=title,
        bg_map=fsaverage[f'sulc_{hemi}'],
    )
    plt.savefig(image_path, format='jpeg', dpi=300, bbox_inches='tight')
    print(f"Image saved at: {image_path}")
    plt.close()


def plot_umap_surface(dataset_path, sample_data_seq_1, neighbour_cnt=30,
                      n_fmri_channels=66, difumo_dimension=64):
    """Project class-wise fMRI channel means to 1D with UMAP and paint them on the cortex.

    fMRI feature rows without bad TRs are split into ``n_fmri_channels`` channels and
    averaged per class (``eeg_index == 1`` "alert", ``== 0`` "drowsy"). All
    ``2 * n_fmri_channels`` channel means are embedded jointly with a 1D UMAP. The first
    ``difumo_dimension`` values of each class are weighted onto fsaverage vertices
    through the DiFuMo atlas (``vol_to_surf``). The results are saved as left/right
    hemisphere maps under ``<dataset_path>/umaps``.

    Parameters
    ----------
    dataset_path : str
        Output root.
    sample_data_seq_1 : dict
        Arrays shaped (n_scans, rows_per_scan, d) as built by ``collect_data``.
    neighbour_cnt : int, optional
        ``n_neighbors`` for the 1D UMAP.
    n_fmri_channels : int, optional
        Number of channels each fMRI feature row is split into.
    difumo_dimension : int, optional
        DiFuMo atlas size; also how many leading channels are mapped to the surface
        (assumes those channels correspond to the atlas components — TODO confirm).
    """
    save_folder = os.path.join(dataset_path, "umaps")
    os.makedirs(save_folder, exist_ok=True)

    fmri_feats = sample_data_seq_1["fmri_feats_total"]
    fmri_feats = fmri_feats.reshape(-1, fmri_feats.shape[2])
    keep = sample_data_seq_1["bad_tr_total"].reshape(-1) == 0
    labels = sample_data_seq_1["eeg_index_total"].reshape(-1)[keep]
    kept_feats = fmri_feats[keep]
    per_channel = kept_feats.reshape(kept_feats.shape[0], n_fmri_channels, -1)

    alert_means = per_channel[labels == 1].mean(axis=0)
    drowsy_means = per_channel[labels == 0].mean(axis=0)
    channel_feats = np.vstack((alert_means, drowsy_means))  # (2 * n_fmri_channels, d)

    umap_1d = UMAP(n_components=1, n_neighbors=neighbour_cnt, min_dist=0.1,
                   metric='cosine', init='random', random_state=0)
    proj_1d = umap_1d.fit_transform(channel_feats)  # (2 * n_fmri_channels, 1)
    alert_proj = proj_1d[:n_fmri_channels, 0]   # (n_fmri_channels,)
    drowsy_proj = proj_1d[n_fmri_channels:, 0]  # (n_fmri_channels,)

    fsaverage = fetch_surf_fsaverage()
    difumo_maps = fetch_atlas_difumo(dimension=difumo_dimension)['maps']
    roi_on_surface = {
        hemi: vol_to_surf(difumo_maps, fsaverage[f'pial_{hemi}'])
        for hemi in ('left', 'right')
    }

    # File names are kept exactly as before (including the trailing "_" on alert maps).
    surface_jobs = (
        ("Alert", alert_proj, "left", "reverse_umap_alert_features_left_.jpeg"),
        ("Alert", alert_proj, "right", "reverse_umap_alert_features_right_.jpeg"),
        ("Drowsy", drowsy_proj, "left", "reverse_umap_drowsy_features_left.jpeg"),
        ("Drowsy", drowsy_proj, "right", "reverse_umap_drowsy_features_right.jpeg"),
    )
    for state, proj, hemi, file_name in surface_jobs:
        vertex_values = np.dot(roi_on_surface[hemi], proj[:difumo_dimension])
        _surface_plot_hemisphere(
            vertex_values, fsaverage, hemi,
            f"{state} UMAP Features on {hemi} Cortical Surface",
            os.path.join(save_folder, file_name),
        )

In [None]:
# Predictions folder for the selected dataset; load it, then plot each scan's time course.
predictions_dir = f"{model_path}{dataset_path}"
dataset_final = collect_data(predictions_dir, dataset[dataset_partition], dataset_config)
plot_results(predictions_dir, dataset_final)

In [None]:
plot_umap_cluster(dataset_path=f"{model_path}{dataset_path}", sample_data_seq=dataset_final)

In [None]:
plot_umap_surface(dataset_path=f"{model_path}{dataset_path}", sample_data_seq_1=dataset_final)