In [None]:
from dotenv import load_dotenv
load_dotenv()
import os
import sys
# Make the project importable: PYTHONPATH is read here (after load_dotenv(), so
# presumably it may come from the .env file) and added to sys.path by hand.
# The variable may be unset or may list several directories separated by
# os.pathsep, so each non-empty entry is appended individually, and entries
# already on sys.path are skipped so re-running the cell does not duplicate them.
for _path_entry in (os.getenv('PYTHONPATH') or '').split(os.pathsep):
    if _path_entry and _path_entry not in sys.path:
        sys.path.append(_path_entry)
import numpy as np
import h5py
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

#local
from src.utils.transforms import SelectROIs

In [None]:
# Dataset location: <DATASETS_ROOT>/MOSAIC. DATASETS_ROOT falls back to a
# placeholder path when the environment variable is not set.
root = os.path.join(os.getenv("DATASETS_ROOT", "/default/path/to/datasets"), "MOSAIC")

# PROJECT_ROOT has no fallback. Fail right here with an explicit message rather
# than with the opaque TypeError that os.path.join(None) would raise.
project_root = os.getenv("PROJECT_ROOT")
if project_root is None:
    raise EnvironmentError("PROJECT_ROOT environment variable is not set")

print(f"root: {root}")
print(f"project root: {project_root}")

# ROI configuration: GlasserGroup_1 ... GlasserGroup_5.
config = {
    'fmri': {
        'rois': [f"GlasserGroup_{x}" for x in range(1,6)],
    }
}
# SelectROIs is the project-local transform; later cells read its
# `selected_roi_indices` attribute to index the noise-ceiling arrays.
ROI_selection = SelectROIs(selected_rois=config['fmri']['rois'])

In [None]:
# Collect one summary row (median and mean noise ceiling over the selected ROI
# indices) per subject that has a 'noiseceilings' entry in the HDF5 file.
cols = ['subject', 'dataset', 'median_nc', 'mean_nc']
# Datasets whose noise ceiling is read from the "test" phase; all other
# datasets use the "train" phase.
test_phase_datasets = ('THINGS', 'BOLD5000', 'GOD')
hdf5_path = os.path.join(root, 'mosaic_version-1_0_0_chunks_renamed.hdf5')

print("loading hdf5 file...")
with h5py.File(hdf5_path, 'r') as file:
    print(f"Keys: {file.keys()}")
    data = {col: [] for col in cols}
    for subjectID in file.keys():
        subject_group = file[subjectID]
        # Subjects without stored noise ceilings are skipped.
        if 'noiseceilings' not in subject_group.keys():
            continue

        # The dataset name is the last underscore-separated token of the subject key.
        dataset = subjectID.split('_')[-1]
        task = 'test' if dataset in test_phase_datasets else 'train'

        nc_key = f"{subjectID}_phase-{task}_n-1_noiseceiling"
        # Divided by 100 -- presumably the stored values are percentages; TODO confirm.
        nc = subject_group['noiseceilings'][nc_key][ROI_selection.selected_roi_indices] / 100

        data['subject'].append(subjectID)
        data['dataset'].append(dataset)
        data['median_nc'].append(np.median(nc))
        data['mean_nc'].append(np.mean(nc))


In [None]:
# One row per subject; the dict keys of `data` become the columns
# (subject, dataset, median_nc, mean_nc).
df = pd.DataFrame.from_dict(data, orient='columns')

In [None]:
# Bar plot of the per-subject noise-ceiling summary, sorted from highest to
# lowest and coloured by dataset; the figure is saved as SVG and displayed.
sort_col = 'mean_nc'
df_sorted = df.sort_values(sort_col, ascending=False)
fs = 10

# Create the plot on an explicit figure/axes pair
fig, ax = plt.subplots(figsize=(15, 6))
sns.barplot(data=df_sorted,
            x='subject',
            y=sort_col,
            hue='dataset',
            dodge=False,  # dodge=False ensures bars aren't grouped by dataset
            ax=ax)

# Customize the plot
ax.set_ylabel(sort_col, fontsize=fs)
ax.set_xlabel('SubjectID', fontsize=fs)
# Fixed y-range; any bar above 0.25 would be clipped.
ax.set_ylim([0, 0.25])
# tick_params sets the label size for every tick, including ticks created
# after the axis limits change.
ax.tick_params(axis='both', labelsize=fs)
# Rotate the x-axis labels and anchor them at their right edge
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
sns.despine(ax=ax)
fig.tight_layout()  # Adjust layout to prevent label cutoff

# Create the output directory first so savefig does not fail on a fresh checkout.
plot_dir = os.path.join(project_root, "src", "fmriDatasetPreparation", "process_nans", "output", "plots")
os.makedirs(plot_dir, exist_ok=True)
fig.savefig(os.path.join(plot_dir, f"dataquality_{sort_col}.svg"))
plt.show()