### MRIQC documentation: https://mriqc.readthedocs.io/en/latest/
### MRIQCeption: https://github.com/elizabethbeard/mriqception
### Example Notebook for querying the MRIQC Web API: https://notebook.community/poldracklab/mriqc/docs/notebooks/MRIQC%20Web%20API


# In [None]:
# -*- coding: utf-8 -*-
"""
Created on Tue Dec 21 17:01:52 2021

@author: Ben3
"""

import json
import glob
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import numpy as np

# Number of MRIQC JSON reports expected per subject for each known task; used
# as a sanity check that the MRIQC output tree is complete.
_EXPECTED_JSON_COUNTS = {
    'T1w': 1,
    'T2w': 1,
    'localizer': 5,  # runs in ses-01
    'rest': 5,       # runs in ses-01
    'train': 40,     # total across ses-02..ses-05 (10 per session)
    'test': 12,      # total across ses-02..ses-05
}

# Sessions that contain the 'train' and 'test' runs.
_TRAIN_TEST_SESSIONS = ('ses-02', 'ses-03', 'ses-04', 'ses-05')


def _find_task_jsons(mriqcpath, sub, task):
    """Return the sorted MRIQC JSON report paths for one subject and task.

    Returns None when *task* is not one of the known task names.
    """
    if task in ('T1w', 'T2w'):
        prefix = os.path.join(mriqcpath, sub, "ses-01", "anat", f"{sub}_ses-01_{task}")
        return sorted(glob.glob(prefix + "*.json"))
    if task in ('localizer', 'rest'):
        prefix = os.path.join(mriqcpath, sub, "ses-01", "func", f"{sub}_ses-01_task-{task}_run-")
        return sorted(glob.glob(prefix + "*_bold.json"))
    if task in ('train', 'test'):
        files = []
        for ses in _TRAIN_TEST_SESSIONS:
            prefix = os.path.join(mriqcpath, sub, ses, "func", f"{sub}_{ses}_task-{task}_run-")
            # Sort within each session so runs come out in session, then run, order.
            files.extend(sorted(glob.glob(prefix + "*_bold.json")))
        return files
    return None


def get_MRIQC_Metrics(taskList, subjects, metricList, mriqcpath):
    """Collect MRIQC image-quality metrics (IQMs) for this dataset into a DataFrame.

    Parameters
    ----------
    taskList : iterable of str
        Task/modality names, each one of 'T1w', 'T2w', 'localizer', 'rest',
        'train' or 'test'. Unknown names print "Invalid task name" and are
        skipped.
    subjects : iterable of str
        Subject IDs (e.g. 'sub-01'), used as directory and filename prefixes.
    metricList : iterable of str
        IQM keys to read from each JSON report (a dict iterates its keys).
    mriqcpath : str
        Root of the MRIQC output tree: <mriqcpath>/<sub>/<ses>/<anat|func>/.

    Returns
    -------
    pandas.DataFrame
        One row per JSON report, with columns 'Source' (always 'BMD'),
        'Subject', 'Task', then one column per metric.

    Raises
    ------
    ValueError
        If the number of JSON reports found for a subject/task differs from
        the count in _EXPECTED_JSON_COUNTS.
    """
    columns = ['Source', 'Subject', 'Task', *metricList]
    data_IQM = {c: [] for c in columns}

    for sub in subjects:
        for task in taskList:
            file_list = _find_task_jsons(mriqcpath, sub, task)
            if file_list is None:
                # Skip explicitly: falling through would reuse the previous
                # task's file list (or hit an undefined name).
                print("Invalid task name")
                continue

            expected = _EXPECTED_JSON_COUNTS[task]
            if len(file_list) != expected:
                raise ValueError(
                    f"Expected {expected} MRIQC JSON file(s) for {sub} task "
                    f"'{task}', found {len(file_list)}"
                )

            for file in file_list:
                with open(file, "r") as read_file:
                    data = json.load(read_file)
                data_IQM['Source'].append('BMD')
                data_IQM['Subject'].append(sub)
                data_IQM['Task'].append(task)
                for m in metricList:
                    data_IQM[m].append(data[m])

    return pd.DataFrame(data=data_IQM)

def mriqception_df(mriqception_df, modality, MetricList):
    """Reshape an MRIQCeption query result into the same layout as get_MRIQC_Metrics.

    Parameters
    ----------
    mriqception_df : pandas.DataFrame
        Query results (as loaded from the cached MRIQCeption pickle); must
        contain a column for every name in *MetricList*.
    modality : str
        Label written to every row's 'Task' column (e.g. 'T1w', 'T2w', 'bold').
    MetricList : list of str
        Metric columns to copy over.

    Returns
    -------
    pandas.DataFrame
        Same number of rows as the input, with columns 'Source' and 'Subject'
        (both 'mriqception'), 'Task' (*modality*), then one column per metric.
        A metric whose values are all missing prints "isnan" and is kept as an
        all-NaN column.
    """
    num_rows = mriqception_df.shape[0]
    # Scalars are broadcast by pandas to every row of the explicit index.
    # 'Subject' is a fixed placeholder so these reference rows form their own
    # category when plotted with x='Subject'.
    data_IQM = {
        'Source': 'mriqception',
        'Subject': 'mriqception',
        'Task': modality,
    }
    for m in MetricList:
        vals = mriqception_df.loc[:, m].values
        # pd.isna also works for object/int columns, unlike np.isnan.
        if pd.isna(vals).all():
            print("isnan")
            # Keep the column (as NaN) so the frame still has num_rows rows.
            data_IQM[m] = np.nan
        else:
            data_IQM[m] = vals
    return pd.DataFrame(data=data_IQM, index=range(num_rows))
    
def _styled_boxplot(data, x, y, hue, lim, save_path=None):
    """Draw one boxplot-plus-points figure for metric *y*; optionally save it; show and close it."""
    fig = plt.figure(figsize=(15, 8))
    ax = sns.boxplot(x=x, y=y, data=data, hue=hue)
    # Overlay the individual observations, dodged to line up with the hue boxes.
    sns.stripplot(x=x,
                  y=y,
                  data=data,
                  color="black",
                  hue=hue,
                  dodge=True,
                  edgecolor="gray")
    ax.set_title(y, fontsize=30)
    # Rotate the existing tick labels without re-setting their text.
    ax.tick_params(axis='x', labelrotation=45)
    plt.xticks(fontsize=20)
    plt.yticks(fontsize=20)
    ax.set_ylabel('')
    ax.set_xlabel('')
    ax.set_ylim(lim)
    ax.legend().set_visible(False)
    if save_path:
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        plt.savefig(save_path)
    plt.show()
    # close() rather than clf(): pyplot otherwise keeps every figure alive.
    plt.close(fig)


def boxplot_anatomical(data, MetricList, savefig=False):
    """Plot one boxplot per anatomical IQM, grouped by 'Task' with 'Source' as hue.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain 'Task', 'Source' and a column for every metric.
    MetricList : dict
        Maps metric name -> y-axis limits passed to ``ax.set_ylim``.
    savefig : str or False
        If truthy, the directory (created if missing) in which each figure is
        saved as an SVG.
    """
    for m, lim in MetricList.items():
        save_path = None
        if savefig:
            save_path = os.path.join(savefig, f"modality-anatomical_IQM-{m}_mriqception.svg")
        _styled_boxplot(data, x='Task', y=m, hue='Source', lim=lim, save_path=save_path)


def boxplot_functional(data, MetricList, tasks=('train', 'test', 'bold'), savefig=False):
    """Plot one boxplot per functional IQM, grouped by 'Subject' with 'Task' as hue.

    Parameters
    ----------
    data : pandas.DataFrame
        Must contain 'Subject', 'Task' and a column for every metric.
    MetricList : dict
        Maps metric name -> y-axis limits passed to ``ax.set_ylim``.
    tasks : sequence of str
        Only used to build the saved filename; it does not filter *data*.
    savefig : str or False
        If truthy, the directory (created if missing) in which each figure is
        saved as an SVG.
    """
    task_tag = '_'.join(tasks)
    for m, lim in MetricList.items():
        save_path = None
        if savefig:
            save_path = os.path.join(savefig, f"modality-functional_IQM-{m}_task-{task_tag}_mriqception.svg")
        _styled_boxplot(data, x='Subject', y=m, hue='Task', lim=lim, save_path=save_path)
        
# Configuration -- replace these placeholder paths before running.
root = "/your/path/to/BOLDMomentsDataset"
project_root = os.path.join(root, "derivatives", "versionA")
save_root = "your/save/path"  # path to where you want to save plots and results
subjects = [f"sub-{s:02}" for s in range(1, 11)]

# IQMs to plot, each mapped to the [ymin, ymax] limits of its figure's y-axis.
metricList_anatomical = {'snr_total': [4,19],'cnr':[-0.5, 5.5],'fwhm_avg':[2,4.5],'cjv':[0,1.8],'efc':[0.35, 0.8],'fber':[-2000,17000]}
metricList_functional = {'aqi':[0,0.0225],'aor':[0,0.016],'fd_mean':[0,0.75],'fwhm_avg':[2.1,3.2],'snr':[1.75,6],'tsnr':[20,100]}

# This directory holds both the per-subject MRIQC JSON reports and the cached
# MRIQCeption query results, so define it once and reuse it.
mriqcpath = os.path.join(project_root, "MRIQC")
figure_dir = os.path.join(save_root, "output_mriqc")
mriqception_suffix = "_TeslaMin-2_TeslaMax-4_TRMin-1_TRMax-3.pkl"

# Load MRIQCeption reference data.
# NOTE(review): read_pickle can execute arbitrary code; only load pickles you trust.
tmp_T1w = pd.read_pickle(os.path.join(mriqcpath, "MRIQCEPTION_API_modality-T1w" + mriqception_suffix))
mriqception_T1w = mriqception_df(tmp_T1w, 'T1w', list(metricList_anatomical.keys()))
mriqception_T1w.drop_duplicates(inplace=True)

tmp_T2w = pd.read_pickle(os.path.join(mriqcpath, "MRIQCEPTION_API_modality-T2w" + mriqception_suffix))
mriqception_T2w = mriqception_df(tmp_T2w, 'T2w', list(metricList_anatomical.keys()))
mriqception_T2w.drop_duplicates(inplace=True)

tmp_bold = pd.read_pickle(os.path.join(mriqcpath, "MRIQCEPTION_API_modality-bold" + mriqception_suffix))
mriqception_bold = mriqception_df(tmp_bold, 'bold', list(metricList_functional.keys()))
mriqception_bold.drop_duplicates(inplace=True)

print(f"MRIQception T1w shape: {mriqception_T1w.shape}")
print(f"MRIQception T2w shape: {mriqception_T2w.shape}")
print(f"MRIQception bold shape: {mriqception_bold.shape}")

# IQMs from this dataset's own MRIQC run.
BMD_anatomical = get_MRIQC_Metrics(['T1w','T2w'], subjects, metricList_anatomical, mriqcpath)
BMD_functional_traintest = get_MRIQC_Metrics(['train','test'], subjects, metricList_functional, mriqcpath)
BMD_functional_localizerrest = get_MRIQC_Metrics(['localizer','rest'], subjects, metricList_functional, mriqcpath)

# Each input keeps its own index; renumber so every row label is unique
# (duplicate labels can break later pandas/seaborn reindexing).
anatomical_all = pd.concat([BMD_anatomical, mriqception_T1w, mriqception_T2w], ignore_index=True)
bold_traintest = pd.concat([BMD_functional_traintest, mriqception_bold], ignore_index=True)
bold_localizerrest = pd.concat([BMD_functional_localizerrest, mriqception_bold], ignore_index=True)

# to_excel does not create missing directories.
os.makedirs(save_root, exist_ok=True)
anatomical_all.to_excel(os.path.join(save_root, 'fig2a_source_data.xlsx'), sheet_name='Figure 2a', index=False)
bold_traintest.to_excel(os.path.join(save_root, 'fig2b_source_data.xlsx'), sheet_name='Figure 2b', index=False)
bold_localizerrest.to_excel(os.path.join(save_root, 'sfig1_source_data.xlsx'), sheet_name='Supplementary Figure 1', index=False)

boxplot_anatomical(anatomical_all, metricList_anatomical, savefig=figure_dir)
boxplot_functional(bold_traintest, metricList_functional, tasks=['train','test','bold'], savefig=figure_dir)
boxplot_functional(bold_localizerrest, metricList_functional, tasks=['localizer','rest','bold'], savefig=figure_dir)