In [None]:
import pickle
import pandas as pd
import numpy as np
import nibabel as nib
from pathlib import Path
import matplotlib.pyplot as plt

def read_pickle(pkl_path):
    """Load and return the object stored in the pickle file at ``pkl_path``.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be pointed at trusted files.
    """
    with open(pkl_path, mode='rb') as handle:
        return pickle.load(handle)

In [None]:
experiment_id = "1_frequency_tagging"
mri_id = "3T"
deriv_dir_base = "oscprep_grayords_fmapless"
deriv_dir = f"/data/{experiment_id}/{mri_id}/bids/derivatives/{deriv_dir_base}/bold_preproc"
sub_ids = !ls {deriv_dir}
for sub_id in sub_ids:
    ses_ids = !ls {deriv_dir}/{sub_id}
    for ses_id in ses_ids:
        dtseries = !ls {deriv_dir}/{sub_id}/{ses_id}/func/*dtseries.nii
        for raw_bold in dtseries:
            task_id = raw_bold.split("_task-")[1].split("_")[0]
            if not any([task_id.startswith("Attend"), task_id.startswith("entrain"), task_id.startswith("control")]):
                continue
            task_id = f"task-{task_id}"

            run_id = raw_bold.split("_run-")[1].split("_")[0]
            run_id = f"run-{run_id}"
            
            sub_base = f"{sub_id}_{ses_id}_{task_id}_{run_id}"
            
            raw_bold = !ls {deriv_dir}/{sub_id}/{ses_id}/func/{sub_id}_{ses_id}_{task_id}_*_{run_id}_desc-preproc_bold.dtseries.nii
            confounds = !ls {deriv_dir}/{sub_id}/{ses_id}/func/{sub_id}_{ses_id}_{task_id}_*_{run_id}_desc-confounds_timeseries.tsv
            for i in [raw_bold, confounds]:
                assert len(i) == 1

            denoised_bold = Path("/scratch/fastfmri") / f"experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-39-219_desc-denoised_bold" / "00_experiment-min+motion24+wmcsfmean" / sub_id / ses_id / task_id / run_id/ "GLM" / f"{sub_base}_desc-denoised_bold.dtseries.nii"
            
            if not denoised_bold.exists():
                continue
            
            nvols = nib.load(raw_bold[0]).shape[0]

            data = {
                "raw_bold": Path(raw_bold[0]),
                "confounds": Path(confounds[0]),
                "denoised_bold": denoised_bold,
                "eyetracking": Path("/data/behaviour") / "eyetracking" / experiment_id / mri_id / f"{sub_base}_eyetracking.pkl", 
                "fingertracking": Path("/data/behaviour") / "fingertracking" / experiment_id / mri_id / f"{sub_base}_fingertracking.pkl", 
            }

            all_exists = True
            for k, v in data.items():
                if not v.exists():
                    print(f"[{sub_base} - {k} - n_vols: {nvols}] {v} not found.")
                    all_exists = False

            if all_exists and "task-entrain" in str(data['raw_bold']):
                import pdb; pdb.set_trace()

In [None]:
# Display the `data` dict (paths of the run's input files) assigned by the
# discovery loop in the previous cell.
data

Carpet plots

In [None]:
def read_pkl(datadir, n_bootstraps, sub_id, roi_task_id, roi_frequency, task_id, experiment_id="1_frequency_tagging", mri_id="7T", fo=.8, corr_type='fdrp', roi_frequency_2=None, control_roi_size=False):
    """Load the merged bootstrap pickle for one subject / ROI / task.

    The pickle is looked up under ``datadir`` in a directory whose name encodes
    the experiment, scanner, bootstrap count, ROI definition, p-value type and
    ``fo``. The ROI part of the name depends on ``roi_frequency_2``:

    * ``None``: ``roi-{roi_task_id}-{roi_frequency}``
    * set, ``control_roi_size`` truthy:
      ``roi-{roi_task_id}-{roi_frequency}_controlroisizetomatch-{roi_frequency_2}``
    * set, otherwise: ``roi-{roi_task_id}-{roi_frequency}-{roi_frequency_2}``

    Returns the unpickled object, or ``None`` (after printing a warning) when
    the file does not exist.
    """
    import pickle

    # Only the ROI tag differs between the three naming variants.
    if roi_frequency_2 is None:
        roi_tag = f"{roi_task_id}-{roi_frequency}"
    elif control_roi_size:
        roi_tag = f"{roi_task_id}-{roi_frequency}_controlroisizetomatch-{roi_frequency_2}"
    else:
        roi_tag = f"{roi_task_id}-{roi_frequency}-{roi_frequency_2}"

    bootstrap_dirname = f"experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-39-219_n-{n_bootstraps}_batch-merged_desc-IMall_roi-{roi_tag}_pval-{corr_type}_fo-{fo}_bootstrap"
    bootstrap_pkl = datadir / f"{bootstrap_dirname}/sub-{sub_id}/bootstrap/task-{task_id}_bootstrapped_data.pkl"

    if bootstrap_pkl.exists():
        print(f"Reading: {bootstrap_pkl}")
        with open(bootstrap_pkl, 'rb') as handle:
            return pickle.load(handle)

    print(f"Warning: {bootstrap_pkl} does not exist.\nReturn None")
    return None

def find_quadrant_id_from_keys(_dict, task_id=None):
    """Return the quadrant id ('Q1' or 'Q2') embedded in a key of ``_dict``.

    The first key containing ``task_id`` is inspected: the two characters
    starting at the first "Q" in that key are taken as the quadrant id.

    Parameters
    ----------
    _dict : mapping with string keys.
    task_id : str, optional
        Substring a key must contain. When omitted, the notebook-level
        ``task_id`` variable is used (the original behaviour); passing it
        explicitly avoids depending on hidden kernel state.

    Raises
    ------
    NameError      : ``task_id`` not given and no global ``task_id`` exists.
    AssertionError : the matching key does not carry 'Q1' or 'Q2'.
    ValueError     : no key contains ``task_id``.
    """
    if task_id is None:
        # Backward-compatible fallback to the module-level variable.
        try:
            task_id = globals()["task_id"]
        except KeyError:
            raise NameError("name 'task_id' is not defined") from None

    for key in _dict.keys():
        if task_id in key:
            q_idx = key.find("Q")
            q_id = key[q_idx:q_idx+2]
            assert q_id in ['Q1', 'Q2'], f"unexpected quadrant id {q_id!r} in key {key!r}"
            return q_id
    raise ValueError("No quadrant id found.")

def pad_denoised_data(denoised_bold_data, n_tps_raw, timepoints_raw, timepoints_denoised):
    """Pad truncated data with rows of -1 so it spans ``n_tps_raw`` rows.

    The first denoised timepoint is located on the raw time axis; rows of -1
    are stacked before and after ``denoised_bold_data`` so that the result has
    ``n_tps_raw`` rows and the data sits at its position on the raw axis.

    Parameters
    ----------
    denoised_bold_data : 2-D array, rows are timepoints, columns are vertices.
    n_tps_raw : int, number of rows of the padded output.
    timepoints_raw : 1-D array of raw timepoints.
    timepoints_denoised : 1-D array of timepoints of ``denoised_bold_data``;
        only its first element is used.

    Returns
    -------
    2-D float array of shape ``(n_tps_raw, n_vertices)``.

    Raises
    ------
    AssertionError : the first denoised timepoint is not on the raw axis.
    ValueError     : the data does not fit on the raw axis after alignment.
    """
    n_tps_denoised, n_vertices = denoised_bold_data.shape[0], denoised_bold_data.shape[1]

    # Both time axes are floating point, so compare with a tolerance instead
    # of `==`; the last match is kept, as the original exhaustive scan did.
    matches = np.flatnonzero(np.isclose(timepoints_raw, timepoints_denoised[0]))
    assert matches.size > 0, "first denoised timepoint not found in timepoints_raw"
    start_ix = int(matches[-1])

    n_end = n_tps_raw - start_ix - n_tps_denoised
    if n_end < 0:
        raise ValueError(
            f"denoised data ({n_tps_denoised} rows starting at index {start_ix}) "
            f"does not fit in {n_tps_raw} raw timepoints"
        )

    # -1 is the padding marker for rows without denoised data.
    start_paddings = np.full((start_ix, n_vertices), -1.0)
    end_paddings = np.full((n_end, n_vertices), -1.0)

    return np.vstack([start_paddings, denoised_bold_data, end_paddings])

def custom_colormap():
    """Return a blue-white-red colormap whose lowest entry is semi-transparent grey.

    256 colours are sampled from matplotlib's ``bwr`` colormap, the first one
    is overwritten with RGBA (0.5, 0.5, 0.5, 0.5), and a new
    ``LinearSegmentedColormap`` named 'custom_bwr' is built from the list.
    """
    from matplotlib.colors import LinearSegmentedColormap

    # Sample the stock blue-white-red map at 256 evenly spaced positions.
    palette = plt.cm.bwr(np.linspace(0, 1, 256))

    # The bottom of the range becomes grey; callers pad missing data with -1,
    # which lies below their vmin.
    palette[0] = (0.5, 0.5, 0.5, .5)

    return LinearSegmentedColormap.from_list('custom_bwr', palette)

def load_bold_data(
    raw_bold_path, raw_ax,
    denoised_bold_path, denoised_ax,
    truncate_window, repetition_time,
    experiment_id, mri_id, sub_id, task_id, frequencies, 
    roi_task_id,
    corr_type="fdrp", roi_frequency_2=None, control_roi_size=False, fo=.8,
    datadir = Path("/scratch/fastfmri"),
    n_bootstraps = 200, period_offset = 2,
    vmin=-.05, vmax=.05, # [-5,5] %BOLD
):
    """Draw raw and denoised carpet plots for the f1 / f1&f2 / f2 ROI vertices.

    Steps:
      1. Read the bootstrap pickle of each entry of ``frequencies`` (via
         ``read_pkl``) and take its ``'roi_coords'``.
      2. Split the vertices into "f1 only", "f1 and f2" and "f2 only".
      3. Load both BOLD files, express them as fractional change relative to
         the temporal mean of the *denoised* run, select the ROI vertices and
         pad the denoised data with -1 to the raw length.
      4. Draw the raw carpet on ``raw_ax`` and the padded denoised carpet on
         ``denoised_ax``, with coloured bars marking the three vertex groups
         and, on ``denoised_ax``, two bars one period (1/f1, 1/f2) long
         starting at ``period_offset``.

    Parameters
    ----------
    raw_bold_path, denoised_bold_path : Path to files loadable by nibabel.
    raw_ax, denoised_ax : matplotlib axes to draw on.
    truncate_window : sequence; ``truncate_window[1]`` is the end time used to
        reconstruct the denoised time axis.
    repetition_time : spacing between consecutive timepoints.
    frequencies : sequence with ``frequencies[0] < frequencies[1]``.
    Remaining parameters are forwarded to ``read_pkl`` or to ``imshow``.

    Returns
    -------
    dict with keys ``timepoints``, ``bold_raw``, ``bold_denoised``,
    ``bold_padded_denoised``, ``n_f1``, ``n_f2``, ``n_f1f2`` and
    ``stimulated_frequencies``; or ``None`` when a BOLD file or one of the two
    ROI pickles is missing.
    """
    assert corr_type in ["fdrp", "uncp"]

    if not raw_bold_path.exists() or not denoised_bold_path.exists():
        return None

    # Load pre-generated pickles
    f_data = {}
    for roi_f in frequencies:
        f_data[roi_f] = read_pkl(
            datadir, 
            n_bootstraps, 
            sub_id.split('-')[-1], 
            roi_task_id, 
            roi_f, 
            task_id.split('-')[-1].split("Q")[0],
            experiment_id=experiment_id,
            mri_id=mri_id,
            fo=fo,
            corr_type=corr_type,
            roi_frequency_2=roi_frequency_2,
            control_roi_size=control_roi_size,
        )
    # Get vertex coordinates of f1, f2, f1/f2
    assert frequencies[1]>frequencies[0]
    f1_dict = f_data[frequencies[0]]
    f2_dict = f_data[frequencies[1]]
    if f1_dict is None or f2_dict is None:
        # read_pkl already printed which pickle is missing; follow the same
        # "missing input -> None" convention as for the BOLD files above.
        return None
    # NOTE(review): 'roi_coords' is presumably a boolean mask over
    # grayordinates (it is cast to int and used as an index) - confirm.
    f1_coords = f1_dict['roi_coords']
    f2_coords = f2_dict['roi_coords']
    # Per-position membership count: 1 = in one ROI only, 2 = in both.
    membership_count = f1_coords.astype(int) + f2_coords.astype(int)
    count_within_f1 = membership_count[f1_coords]
    count_within_f2 = membership_count[f2_coords]
    # Masks (relative to the f1 / f2 selections respectively)
    inter_from_f1 = count_within_f1 == 2
    f1_from_f1 = count_within_f1 == 1
    f2_from_f2 = count_within_f2 == 1
    n_f1, n_f1f2, n_f2 = f1_from_f1.sum(), inter_from_f1.sum(), f2_from_f2.sum()
    # Load bold data
    bold_data = {}
    for bold_type, bold_path in zip(["raw", "denoised"], [raw_bold_path, denoised_bold_path]):
        bold_data[bold_type] = nib.load(bold_path).get_fdata()
    # Get timepoints of data
    # NOTE(review): np.arange with a float step can return one element more or
    # less than expected; the assertions below catch that case.
    n_tps_raw = bold_data['raw'].shape[0]
    timepoints_raw = np.arange(repetition_time, repetition_time*n_tps_raw+repetition_time, repetition_time)
    n_tps_denoised = bold_data['denoised'].shape[0]
    timepoints_denoised = np.arange(repetition_time, truncate_window[1]+repetition_time, repetition_time)[-n_tps_denoised:]
    assert n_tps_denoised < n_tps_raw
    assert timepoints_denoised.shape[0] == n_tps_denoised
    assert timepoints_raw.shape[0] == n_tps_raw
    # Convert to percent BOLD change 
    # The denoised temporal mean is the baseline for BOTH series.
    baseline = bold_data['denoised'].mean(0)
    bold_data['raw'] = ((bold_data['raw'] - baseline) / baseline).T
    bold_data['denoised'] = ((bold_data['denoised'] - baseline) / baseline).T
    # Extract vertices from bold_data, ordered f1-only, f1&f2, f2-only
    selected_bold_data = {}
    for bold_type in ["raw", "denoised"]:
        selected_bold_data[bold_type] = np.hstack([
            bold_data[bold_type][f1_coords,:][f1_from_f1,:].T,
            bold_data[bold_type][f1_coords,:][inter_from_f1,:].T,
            bold_data[bold_type][f2_coords,:][f2_from_f2,:].T,
        ])
    # Pad the truncated data
    selected_bold_data['padded_denoised'] = pad_denoised_data(selected_bold_data['denoised'], n_tps_raw, timepoints_raw, timepoints_denoised)

    # Plot carpet plots
    total_vertices = n_f1+n_f1f2+n_f2
    extent=[
        timepoints_raw[0], timepoints_raw[-1], total_vertices, 0
    ]
    cmap = custom_colormap()
    raw_ax.imshow(selected_bold_data['raw'].T, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto', extent=extent)
    denoised_ax.imshow(selected_bold_data['padded_denoised'].T, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto', extent=extent)
    
    # decorate: 2-unit-wide coloured bars at the left edge mark the groups
    for ax in [raw_ax, denoised_ax]:
        square_f1 = plt.Polygon([(0, 0), (2,0), (2, n_f1), (0, n_f1)], closed=True, color='red')
        square_f1f2 = plt.Polygon([(0, n_f1), (2, n_f1), (2, n_f1+n_f1f2), (0, n_f1+n_f1f2)], closed=True, color='gold')
        square_f2 = plt.Polygon([(0, n_f1+n_f1f2), (2, n_f1+n_f1f2), (2, total_vertices), (0, total_vertices)], closed=True, color='blue')
        for square in [square_f1, square_f2, square_f1f2]:
            ax.add_patch(square)
    # Scale bars one period of f1 (red) and f2 (blue) long
    period_f1 = 1/frequencies[0]
    period_f2 = 1/frequencies[1]
    timescale_f1 = plt.Polygon(
        [
            (period_offset, total_vertices*.08), 
            (period_offset+period_f1, total_vertices*.08), 
            (period_offset+period_f1, total_vertices*.11), 
            (period_offset, total_vertices*.11)
        ], 
        closed=True, color='red', linewidth=0., zorder=10,
    )
    timescale_f2 = plt.Polygon(
        [
            (period_offset, total_vertices*.12), 
            (period_offset+period_f2, total_vertices*.12),
            (period_offset+period_f2, total_vertices*.15),
            (period_offset, total_vertices*.15)
        ], 
        closed=True, color='blue', linewidth=0., zorder=10,
    )
    for square in [timescale_f1, timescale_f2]:
        denoised_ax.add_patch(square)
        
    data_dict = {
            "timepoints": timepoints_raw,
            "bold_raw": selected_bold_data['raw'],
            "bold_denoised": selected_bold_data['denoised'],
            "bold_padded_denoised": selected_bold_data['padded_denoised'],
            "n_f1": n_f1,
            "n_f2": n_f2,
            "n_f1f2": n_f1f2,
            "stimulated_frequencies": frequencies,
    }

    return data_dict

# Example call: draw the raw (top) and denoised (bottom) carpet plots for one run.
# NOTE: `data`, `experiment_id`, `mri_id`, `sub_id` and `task_id` are not
# defined in this cell; they must already exist in the kernel.
fig, axs = plt.subplots(nrows=2, figsize=(5,7), dpi=200)

# (start, end) passed to load_bold_data as `truncate_window`
truncate_window = (39, 219)
repetition_time = .3
# Passed as `frequencies`; load_bold_data requires frequencies[0] < frequencies[1]
frequencies = [.125, .2]
roi_task_id = "entrain"
corr_type = "fdrp"
n_bootstraps = 200
bold_data = load_bold_data(
    data['raw_bold'], axs[0], data['denoised_bold'], axs[1],
    truncate_window, repetition_time,
    experiment_id, mri_id, sub_id, task_id, frequencies, 
    roi_task_id,
    corr_type=corr_type, roi_frequency_2=None, control_roi_size=False, fo=.8,
    datadir = Path("/scratch/fastfmri"),
    n_bootstraps = n_bootstraps,
    vmin=-.05, vmax=.05
)

fig.tight_layout()

Finger tracking

In [None]:
def match_response_to_target(target_onset_time, combined_response_times, correct_response_times, incorrect_response_times):
    """Pair one target onset with the first response that comes after it.

    Responses are consumed from the front of ``combined_response_times``;
    responses at or before the onset are discarded.

    Returns a ``(match, remaining_responses)`` pair where ``match`` is

    * ``None`` when no responses were supplied at all,
    * ``(response_type, delay)`` for the first response with ``delay > 0``;
      ``response_type`` is "correct" or "wrong" ("wrong" wins if the time
      appears in both lists),
    * ``("wrong", "no_response")`` when every response was consumed without
      finding one after the onset.
    """
    if len(combined_response_times) == 0:
        return None, combined_response_times

    for position, response_time in enumerate(combined_response_times):
        # Deliberately two independent checks: membership in the incorrect
        # list overrides membership in the correct list.
        if response_time in correct_response_times:
            response_type = "correct"
        if response_time in incorrect_response_times:
            response_type = "wrong"

        response_delay = response_time - target_onset_time
        if response_delay > 0:
            # This response belongs to the target; hand back what is left.
            return (response_type, response_delay), combined_response_times[position+1:]

    # Ran out of responses: everything has been consumed.
    exhausted = combined_response_times[len(combined_response_times):]
    return ("wrong", "no_response"), exhausted

def condense_delay_times(target_onset_times_mapping):
    """Split matched responses into correct and wrong delay lists.

    ``target_onset_times_mapping`` is a sequence of ``(onset, match)`` pairs
    where ``match`` is ``None`` or a ``(response_type, delay)`` tuple. Pairs
    with ``None`` are skipped.

    Returns ``(correct_delay_times, wrong_delay_times)``; the wrong list may
    contain the string "no_response".
    """
    delays_by_type = {"correct": [], "wrong": []}
    for onset_and_match in target_onset_times_mapping:
        match = onset_and_match[1]
        if match is None:
            continue
        response_type, delay = match[0], match[1]
        # Any other response type is ignored.
        if response_type in delays_by_type:
            delays_by_type[response_type].append(delay)

    return delays_by_type["correct"], delays_by_type["wrong"]

def calculate_response_delay_times(target_onset_times, correct_response_times, incorrect_response_times):
    """Compute response delays for every target onset.

    All responses are merged and sorted, then each onset (in the given order)
    consumes responses through ``match_response_to_target``. The per-onset
    results are condensed with ``condense_delay_times``.

    Returns ``(correct_delay_times, wrong_delay_times)``.
    """
    pending_responses = sorted(correct_response_times + incorrect_response_times)

    onset_to_match = []
    for target_onset_time in target_onset_times:
        match, pending_responses = match_response_to_target(target_onset_time, pending_responses, correct_response_times, incorrect_response_times)
        onset_to_match.append((target_onset_time, match))

    return condense_delay_times(onset_to_match)

def load_fingertracking_data(pkl_path, ax):
    """Load button-response data, draw event lines on ``ax`` and summarise it.

    The pickle must contain the keys "correctResponsesTimes",
    "incorrectResponsesTimes" and "targetOnsetTimes". A vertical line is drawn
    for every target onset (black), correct response (green) and incorrect
    response (red).

    Returns ``None`` when the file does not exist or holds no responses at
    all; otherwise a dict with the response counts, the accuracy
    (correct / all responses), the stored onset and response times, and the
    delay lists from ``calculate_response_delay_times``.
    """
    if not pkl_path.exists():
        return None

    data = read_pickle(pkl_path)
    available_keys = list(data.keys())
    for required_key in ("correctResponsesTimes", "incorrectResponsesTimes", "targetOnsetTimes"):
        assert required_key in available_keys, f"key: {required_key} not found in {pkl_path}."

    n_correct = len(data["correctResponsesTimes"])
    n_incorrect = len(data["incorrectResponsesTimes"])
    n_responses = n_correct + n_incorrect
    if n_responses == 0:
        return None
    accuracy = n_correct / n_responses

    # Plain-float copies are used for matching and plotting; the returned dict
    # keeps the values exactly as stored in the pickle.
    target_onset_times = [float(t) for t in data["targetOnsetTimes"]]
    correct_response_times = [float(t) for t in data["correctResponsesTimes"]]
    incorrect_response_times = [float(t) for t in data["incorrectResponsesTimes"]]
    correct_delay_times, wrong_delay_times = calculate_response_delay_times(target_onset_times, correct_response_times, incorrect_response_times)

    event_series = (
        (target_onset_times, 'k'),
        (correct_response_times, 'g'),
        (incorrect_response_times, 'r'),
    )
    for event_times, line_colour in event_series:
        for event_time in event_times:
            ax.axvline(x=event_time, c=line_colour, lw=.3)

    return {
        "n_correct": n_correct,
        "n_incorrect": n_incorrect,
        "accuracy": accuracy,
        "timing_onset": data["targetOnsetTimes"],
        "correct_response_times": data["correctResponsesTimes"],
        "correct_delay_times": correct_delay_times,
        "wrong_response_times": data["incorrectResponsesTimes"],
        "wrong_delay_times": wrong_delay_times,
    }


# Example call: plot the finger-tracking events of the run described by the
# kernel-level `data` dict (not defined in this cell).
fig,ax = plt.subplots(figsize=(6,1), dpi=200)
fingertracking_data = load_fingertracking_data(data['fingertracking'], ax)
# load_fingertracking_data returns None for a missing file or zero responses.
if fingertracking_data is not None:
    # Print the length of every list-valued entry of the summary.
    for k, v in fingertracking_data.items():
        if isinstance(v, list):
            print(k, len(v))

Framewise displacement

In [None]:
def load_motion_data(pkl_path, tr, ax):
    """Read framewise displacement from a confounds table and plot it.

    Parameters
    ----------
    pkl_path : path or buffer of a tab-separated table (despite the name, this
        is read with ``pd.read_csv(sep='\\t')``, not unpickled) that has a
        ``framewise_displacement`` column.
    tr : spacing between consecutive timepoints.
    ax : axes-like object with a ``plot`` method.

    Returns
    -------
    dict with ``"timepoints"`` and ``"framewise_displacement"`` arrays of equal
    length. The first row is dropped from both.
    NOTE(review): presumably because the first framewise displacement is
    undefined (no previous volume) - confirm.
    """
    data = pd.read_csv(pkl_path, sep='\t')

    assert 'framewise_displacement' in data.columns

    fd = data.framewise_displacement.values
    n_tps = fd.shape[0]

    # Build the time axis from integer sample indices. np.arange with a float
    # step does not have a numerically stable length, which could make the
    # shape assertion below fail for some run lengths.
    timepoints = (tr * np.arange(1, n_tps + 1))[1:]
    fd = fd[1:]

    assert fd.shape == timepoints.shape

    data_dict = {
        "timepoints": timepoints,
        "framewise_displacement": fd
    }

    ax.plot(timepoints, fd, c='k', lw=.2, zorder=2, alpha=1.)

    return data_dict

# Example call: plot framewise displacement for the confounds table of the run
# described by the kernel-level `data` dict (not defined in this cell).
fig,ax = plt.subplots(figsize=(6,1), dpi=200)
# Passed to load_motion_data as the spacing between timepoints.
tr = .3
motion_data = load_motion_data(data['confounds'], tr, ax)

Combine all, which also includes eyetracking

In [None]:
def get_chunks(lst):
    """Return ``(start, end)`` index pairs of the consecutive non-NaN runs in ``lst``.

    Both indices are inclusive. An input with no non-NaN value (including an
    empty one) yields an empty list.
    """
    chunks = []
    run_start = None      # first index of the run currently being tracked
    last_valid = None     # most recent non-NaN index seen

    for position, value in enumerate(lst):
        if np.isnan(value):
            continue
        if run_start is None:
            run_start = position
        elif position != last_valid + 1:
            # A gap was crossed: close the previous run and open a new one.
            chunks.append((run_start, last_valid))
            run_start = position
        last_valid = position

    # Close the run that was still open at the end of the input.
    if run_start is not None:
        chunks.append((run_start, last_valid))

    return chunks

def load_eyetracking_data(pkl_path, ax, sample_rate=1/500):
    """Load gaze traces from a pickle and plot every gap-free segment on ``ax``.

    Parameters
    ----------
    pkl_path : Path to a pickle holding "gazeX" and "gazeY" arrays of identical
        shape whose last dimension has size 1.
    ax : matplotlib axes; x is drawn in red, y in blue.
    sample_rate : spacing between consecutive samples on the time axis (despite
        the name it is used as a period; the default is 1/500).

    Returns
    -------
    ``None`` when the file is missing or gazeX holds no non-NaN sample;
    otherwise a dict with ``"timepoints"``, ``"gaze_x"`` and ``"gaze_y"``
    (all with one entry per sample).

    Raises
    ------
    AssertionError : missing keys, mismatching shapes, or x/y traces whose
        non-NaN segments differ.
    """
    if not pkl_path.exists():
        return None
    data = read_pickle(pkl_path)
    PKL_KEYS = [i for i in data.keys()]

    assert "gazeX" in PKL_KEYS
    assert "gazeY" in PKL_KEYS

    gaze_x = data["gazeX"]
    gaze_y = data["gazeY"]

    assert gaze_x.shape == gaze_y.shape
    assert gaze_x.shape[-1] == 1

    gaze_x = gaze_x[:,0]
    gaze_y = gaze_y[:,0]
    
    n_tps = gaze_x.shape[0]
    # Build the time axis from integer sample indices: np.arange with a float
    # step can return one element too many, leaving `timepoints` longer than
    # the gaze traces.
    timepoints = sample_rate * np.arange(1, n_tps + 1)

    # Segments of consecutive valid samples, as inclusive (start, end) pairs.
    x_chunks = get_chunks(gaze_x)
    y_chunks = get_chunks(gaze_y)
    if len(x_chunks) == 0:
        return None

    assert len(x_chunks) == len(y_chunks)

    for i,j in zip(x_chunks, y_chunks):
        # x and y are expected to be missing at exactly the same samples.
        assert i == j
        segment = slice(i[0], i[1]+1)

        ax.plot(timepoints[segment], gaze_x[segment], c='r', lw=.2, zorder=2, alpha=.4)
        ax.plot(timepoints[segment], gaze_y[segment], c='b', lw=.2, zorder=2, alpha=.4)

    data_dict = {
        "timepoints": timepoints,
        "gaze_x": gaze_x,
        "gaze_y": gaze_y,
    }

    return data_dict

Functions to calculate vertex-wise power spectra

In [None]:
def get_im_frequencies(f1,f2):
    """Return the stimulation frequencies and their combination terms.

    Requires ``f2 > f1``. Every derived value is rounded to 10 decimals.

    Returns a dict (in this key order) with

    * ``"first_order"``:  ``[f1, f2]``
    * ``"second_order"``: ``[f2-f1, f1+f2, 2*f1, 2*f2]``
    * ``"third_order"``:  ``[2*f1-f2, 2*f2-f1]``
    """
    assert f2 > f1, f"{f2} <= {f1}"

    return {
        "first_order": [f1, f2],
        "second_order": [
            round(f2-f1, 10),   # f2-f1
            round(f1+f2, 10),   # f1+f2
            round(f1*2, 10),    # 2*f1
            round(f2*2, 10),    # 2*f2
        ],
        "third_order": [
            round(2*f1-f2, 10), # 2f1 - f2
            round(2*f2-f1, 10), # 2f2 - f1
        ],
    }

def calculate_power_spectrum_from_denoised_bold(bold_data, fois, repetition_time, nperseg=1024,):
    """Estimate Welch power at the stimulation / intermodulation frequencies per vertex.

    Parameters
    ----------
    bold_data : dict whose ``'bold_denoised'`` entry is a 2-D array with
        timepoints along axis 0 and vertices along axis 1.
    fois : sequence; ``fois[0]`` and ``fois[1]`` are passed to
        ``get_im_frequencies`` as f1 and f2.
    repetition_time : sampling period; ``1/repetition_time`` is passed to
        ``scipy.signal.welch`` as ``fs``.
    nperseg : forwarded to ``scipy.signal.welch``.

    Returns
    -------
    foi_list : flat list of all frequencies of interest, in the order first /
        second / third order (empty when there are no vertices).
    psd_per_f : ``defaultdict(list)`` mapping each frequency to one linearly
        interpolated power value per vertex.
    n_vertices : number of columns of the denoised array.
    """
    from collections import defaultdict
    from scipy.signal import welch

    denoised_bold = bold_data['bold_denoised']
    n_vertices = denoised_bold.shape[1]

    im_frequencies = get_im_frequencies(fois[0], fois[1])
    all_fois = [foi for order_fois in im_frequencies.values() for foi in order_fois]
    # Two combination terms can coincide (e.g. f2 = 1.5*f1 gives
    # f2-f1 == 2*f1-f2). Each distinct frequency must receive exactly one
    # value per vertex, otherwise the per-frequency lists end up ragged.
    unique_fois = list(dict.fromkeys(all_fois))

    psd_per_f = defaultdict(list)
    for vertex_id in range(n_vertices):
        freqs, power = welch(denoised_bold[:, vertex_id], fs=1/repetition_time, nperseg=nperseg)
        for foi in unique_fois:
            psd_per_f[foi].append(np.interp(foi, freqs, power))

    # Kept for backward compatibility: without vertices the list stays empty.
    foi_list = all_fois if n_vertices > 0 else []

    return foi_list, psd_per_f, n_vertices

Aggregate all plots

In [None]:
def set_base_dir(basedir):
    """Return ``basedir`` as a ``Path``, creating it (with parents) if it does not exist."""
    target = Path(basedir)
    if target.exists():
        return target

    target.mkdir(parents=True, exist_ok=True)
    return target

def explore_single_subject_run(
    experiment_id, 
    mri_id, 
    sub_id, 
    ses_id, 
    run_id, 
    roi_task_id, 
    task_id,
    data,
    repetition_time,
    frequencies, 
    corr_type = "uncp",
    STIMULUS_TIMINGS = [0, 14, 14+25, 219],
    FONTSIZE=4,
    DPI=300,
    CLOSE_FIGURES = True,
    SKIP_IF_EXISTS = True,
):
    """Build and save the summary figure of one run.

    The figure has five rows in its left column (finger-tracking events, eye
    tracking, framewise displacement, raw carpet plot, denoised carpet plot)
    and, next to the denoised carpet, the per-vertex power at the frequencies
    of interest.

    Parameters
    ----------
    experiment_id, mri_id, sub_id, ses_id, run_id, roi_task_id, task_id :
        identifiers used for the output path, the title and ``load_bold_data``.
    data : dict with the keys "eyetracking", "fingertracking", "confounds",
        "raw_bold" and "denoised_bold" (paths).
    repetition_time, frequencies, corr_type : forwarded to the loaders.
    STIMULUS_TIMINGS : times at which vertical marker lines are drawn; the
        last two entries form the truncation window. The default list is never
        modified here.
    FONTSIZE, DPI : figure styling.
    CLOSE_FIGURES : close the figure after saving.
    SKIP_IF_EXISTS : return immediately when the output PNG already exists.

    Returns
    -------
    None. The PNG is the product; nothing is written when the run is skipped
    or when ``load_bold_data`` returns no data.
    """

    def _mean_or_nan(values):
        # np.mean([]) is nan as well, but emits a RuntimeWarning.
        return np.mean(values) if len(values) > 0 else float("nan")

    png_out = Path(set_base_dir(f"./ComputeCanada/frequency_tagging/figures/data_exploration/carpet_plots/{experiment_id}/mri-{mri_id}/{sub_id}/{ses_id}/figures")) / f"{sub_id}_{ses_id}_{run_id}_roi-task-{roi_task_id}-{corr_type}_{task_id}_carpetplot.png"
    # Decide about skipping BEFORE creating the figure, so skipped runs do not
    # leave an open figure behind.
    if png_out.exists() and SKIP_IF_EXISTS:
        return None

    fig, axs = plt.subplots(
        nrows=5, ncols=2, 
        sharex="col", 
        sharey='row',
        figsize=(5,4), 
        dpi=DPI,
        gridspec_kw = {
            "height_ratios": [.5, 1, 1, 6, 6],
            "width_ratios": [9,1.5]
        }
    )

    fig.suptitle(f"{experiment_id}, {mri_id}, {sub_id}, {ses_id}, {run_id}, roi-task-{roi_task_id}, {task_id}", fontsize=FONTSIZE)

    eyetracking_data = load_eyetracking_data(data["eyetracking"], axs[1,0])
    fingertracking_data = load_fingertracking_data(data['fingertracking'], axs[0,0])
    motion_data = load_motion_data(data['confounds'], repetition_time, axs[2,0])

    # bold data
    truncate_window = (STIMULUS_TIMINGS[-2], STIMULUS_TIMINGS[-1])
    n_bootstraps = 200
    bold_data = load_bold_data(
        data['raw_bold'], axs[3,0], data['denoised_bold'], axs[4,0],
        truncate_window, repetition_time,
        experiment_id, mri_id, sub_id, task_id, frequencies, 
        roi_task_id,
        corr_type=corr_type, roi_frequency_2=None, control_roi_size=False, fo=.8,
        datadir = Path("/scratch/fastfmri"),
        n_bootstraps = n_bootstraps, period_offset = 16,
        vmin=-.04, vmax=.04
    )
    if bold_data is None:
        # load_bold_data signals missing inputs with None; without BOLD data
        # there is nothing meaningful to draw.
        print(f"No BOLD data for {sub_id} {ses_id} {task_id} {run_id}; figure not written.")
        plt.close(fig)
        return None

    # Style the left column and the PSD panel (index 5). Stimulus markers are
    # drawn on the left column only.
    for ix, ax in enumerate(list(axs[:,0])+[axs[4,1]]):
        if ix != 5:
            for stim_time in STIMULUS_TIMINGS:
                ax.axvline(x=stim_time, lw=1, c='k', zorder=1)
        for spine_type in ["top", "bottom", "right", "left"]:
            ax.spines[spine_type].set_visible(False)
        ax.tick_params(axis='x', bottom=False, pad=0, labelsize=FONTSIZE)
        ax.tick_params(axis='y', length=2, pad=0, labelsize=FONTSIZE)
        ax.set_xlim(0, max(ax.get_xlim()))
        if ix in [3,4]:
            ax.set_yticks([])

    # Add psd, bold
    xticklabels = [
        "$f_{1}$",
        "$f_{2}$",
        "$f_{2}$-$f_{1}$",
        "$f_{1}$+$f_{2}$",
        "2$f_{1}$",
        "2$f_{2}$",
        "2$f_{1}$-$f_{2}$",
        "2$f_{2}$-$f_{1}$",
    ]
    n_fois = len(xticklabels)
    fois, power_per_vertex, n_vertices = calculate_power_spectrum_from_denoised_bold(bold_data, bold_data['stimulated_frequencies'], repetition_time)
    extent=[
        0, 1, bold_data['n_f1']+bold_data['n_f1f2']+bold_data['n_f2'], 0
    ]
    # One row per frequency of interest, one column per vertex.
    psds = np.vstack([
        np.array(power_per_vertex[fois[k]])[np.newaxis,:]
        for k in range(n_fois)
    ])
    axs[4,1].imshow(
        psds.T, 
        cmap='magma', 
        vmin=0, vmax=.001, 
        aspect='auto', 
        extent=extent,
        interpolation='none',
    )
    # One tick at the centre of each of the eight columns.
    axs[4,1].set_xticks([
        axs[4,1].get_xlim()[-1]*(1/16)+((1/8)*k)
        for k in range(n_fois)
    ])
    axs[4,1].set_xticklabels(xticklabels, fontsize=FONTSIZE)
    axs[4,1].set_xlabel("Power", fontsize=FONTSIZE)
    axs[4,1].tick_params(axis='x', rotation=90)


    # Add plot descriptions
    if fingertracking_data is not None:
        accuracy = fingertracking_data["accuracy"]
        mean_correct_delay_times = _mean_or_nan(fingertracking_data["correct_delay_times"])
        # "no_response" placeholders carry no delay and are excluded.
        mean_wrong_delay_times = _mean_or_nan([i for i in fingertracking_data["wrong_delay_times"] if i != "no_response"])
        fingertracking_str = f"Trigger timings [accuracy, correct/wrong mean response times]: {100*accuracy:.2f}%, {mean_correct_delay_times:.3f}/{mean_wrong_delay_times:.3f}s"
        axs[0,0].text(STIMULUS_TIMINGS[-1]*.49, axs[0,0].get_ylim()[-1]*1.15, fingertracking_str, fontsize=FONTSIZE-1, zorder=10)
    else:
        axs[0,0].remove()
    if motion_data is not None:
        mean_fd = np.mean(motion_data['framewise_displacement'])
        median_fd = np.median(motion_data['framewise_displacement'])
        motion_str = f"Framewise displacement [mean/median]: {mean_fd:.3f}/{median_fd:.3f}mm"
        axs[2,0].text(STIMULUS_TIMINGS[-1]*.66, axs[2,0].get_ylim()[-1]*1.12, motion_str, fontsize=FONTSIZE-1, zorder=10)
    else:
        axs[2,0].remove()
    if eyetracking_data is not None:
        axs[1,0].text(STIMULUS_TIMINGS[-1]*.93, axs[1,0].get_ylim()[-1]*1.15, "Eye tracking", fontsize=FONTSIZE-1, zorder=10)
    else:
        axs[1,0].remove()
    axs[3,0].set_ylabel("Vertex", fontsize=FONTSIZE)
    axs[4,0].set_ylabel("Vertex", fontsize=FONTSIZE)
    axs[4,0].set_xlabel("Timepoints", fontsize=FONTSIZE)
    # Vertex counts of the three groups, written onto the denoised carpet.
    y_tracker = .25
    stimulated_frequencies = bold_data["stimulated_frequencies"]
    stimulated_frequencies = [stimulated_frequencies[0], None, stimulated_frequencies[1]]
    for i, (f_count, f) in enumerate(zip(["n_f1", "n_f1f2", "n_f2"], stimulated_frequencies)):
        n_f = bold_data[f_count]
        if f is not None:
            f_label = i+1
            if i == 2:
                f_label = 2
            f_subscript = f"{f_label},{f}Hz"
        else:
            f_subscript = "1,2"
        axs[4,0].text(16, axs[4,0].get_ylim()[0]*y_tracker, f"$f_{{{f_subscript}}}$={n_f}", fontsize=FONTSIZE-1, zorder=10)
        y_tracker += .08
    # Remove unused plots
    for unused_row in range(4):
        axs[unused_row,1].remove()

    fig.subplots_adjust(hspace=.1, wspace=0, top=.94)

    fig.savefig(png_out, dpi='figure')

    if CLOSE_FIGURES:
        # Close this figure explicitly rather than whatever is "current".
        plt.close(fig)

In [None]:
# Example call: build the summary figure for a single run and keep it open.
# NOTE: `experiment_id`, `mri_id`, `sub_id`, `ses_id`, `run_id`, `roi_task_id`,
# `task_id` and `data` are not defined in this cell; they must already exist
# in the kernel.
repetition_time = .3
frequencies = [.125, .2]
corr_type = "uncp"
# Times of the vertical marker lines; the last two form the truncation window.
STIMULUS_TIMINGS = [0, 14, 14+25, 219]
FONTSIZE = 4
DPI=300
explore_single_subject_run(
    experiment_id, 
    mri_id, 
    sub_id, 
    ses_id, 
    run_id, 
    roi_task_id, 
    task_id,
    data,
    repetition_time, frequencies, 
    corr_type=corr_type, 
    STIMULUS_TIMINGS = STIMULUS_TIMINGS,
    FONTSIZE = FONTSIZE,
    DPI=DPI,
    # Keep the figure open for inline display and always regenerate the PNG.
    CLOSE_FIGURES=False,
    SKIP_IF_EXISTS=False,
)

In [None]:
experiment_id = "1_attention"
mri_id = "7T"

deriv_dir_base = "oscprep_grayords_fmapless"
deriv_dir = f"/data/{experiment_id}/{mri_id}/bids/derivatives/{deriv_dir_base}/bold_preproc"
sub_ids = !ls {deriv_dir}
ATTENTION_7T = [f"sub-{i}" for i in ["010", "011", "012", "013", "014", "015", "016"]]
NORMAL_3T = [f"sub-{i}" for i in ["000", "002", "003", "004", "005", "006", "007", "008", "009"]]
NORMAL_7T = [f"sub-{i}" for i in ["Pilot001", "Pilot009", "Pilot010", "Pilot011"]]
VARY_3T = [f"sub-{i}" for i in ["020", "021"]]
VARY_7T = [f"sub-{i}" for i in ["020", "021"]]
STIMULUS_TIMINGS = [0, 14, 14+25, 219]
FONTSIZE = 4
DPI=300
SKIP_IF_EXISTS = True
for corr_type in ["fdrp", "uncp"]:
    for sub_id in sub_ids:

        # Set repetition_time
        if experiment_id == "1_attention" and sub_id in ATTENTION_7T:
            repetition_time = .25
        else:
            repetition_time = .3

        # Set roi_task_id
        # Attention 7T
        if experiment_id == "1_attention" and mri_id =="7T" and sub_id in ATTENTION_7T:
            roi_task_ids = ["AttendAway", "AttendInF1", "AttendInF2", "AttendInF1F2"]
            roi_task_ids = ["match"]
        # Normal 3T
        elif experiment_id == "1_frequency_tagging" and mri_id == "3T" and sub_id in NORMAL_3T:
            roi_task_ids = ["entrain"]
        # Normal 7T
        elif experiment_id == "1_attention" and mri_id == "7T" and sub_id in NORMAL_7T:
            roi_task_ids = ["match"]
        # Vary 3T
        elif experiment_id == "1_frequency_tagging" and mri_id == "3T" and sub_id in VARY_3T:
            roi_task_ids = ["match"]
        # Vary 7T
        elif experiment_id == "1_frequency_tagging" and mri_id == "7T" and sub_id in VARY_7T:
            roi_task_ids = ["match"]
        else:
            print(f"Skipping {experiment_id} {mri_id} {sub_id} {ses_id}")
            continue
        
        ses_ids = !ls {deriv_dir}/{sub_id}
        for ses_id in ses_ids:
            dtseries = !ls {deriv_dir}/{sub_id}/{ses_id}/func/*dtseries.nii
            for raw_bold in dtseries:
                task_id = raw_bold.split("_task-")[1].split("_")[0]
                if not any([task_id.startswith("Attend"), task_id.startswith("entrain"), task_id.startswith("control")]):
                    continue
                task_id = f"task-{task_id}"

                run_id = raw_bold.split("_run-")[1].split("_")[0]
                run_id = f"run-{run_id}"
                
                sub_base = f"{sub_id}_{ses_id}_{task_id}_{run_id}"
                
                raw_bold = !ls {deriv_dir}/{sub_id}/{ses_id}/func/{sub_id}_{ses_id}_{task_id}_*_{run_id}_desc-preproc_bold.dtseries.nii
                confounds = !ls {deriv_dir}/{sub_id}/{ses_id}/func/{sub_id}_{ses_id}_{task_id}_*_{run_id}_desc-confounds_timeseries.tsv
                for i in [raw_bold, confounds]:
                    assert len(i) == 1

                denoised_bold = Path("/scratch/fastfmri") / f"experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-39-219_desc-denoised_bold" / "00_experiment-min+motion24+wmcsfmean" / sub_id / ses_id / task_id / run_id/ "GLM" / f"{sub_base}_desc-denoised_bold.dtseries.nii"
                
                if not denoised_bold.exists():
                    continue
                
                nvols = nib.load(raw_bold[0]).shape[0]

                data = {
                    "raw_bold": Path(raw_bold[0]),
                    "confounds": Path(confounds[0]),
                    "denoised_bold": denoised_bold,
                    "eyetracking": Path("/data/behaviour") / "eyetracking" / experiment_id / mri_id / f"{sub_base}_eyetracking.pkl", 
                    "fingertracking": Path("/data/behaviour") / "fingertracking" / experiment_id / mri_id / f"{sub_base}_fingertracking.pkl", 
                }

                for roi_task_id in roi_task_ids:
                    _roi_task_id = roi_task_id
                    if roi_task_id == "match":
                        _roi_task_id = task_id.split("-")[-1].split("Q")[0]

                    # Set frequencies
                    if "entrainB" in task_id:
                        frequencies = [.125, .175]
                    elif "entrainC" in task_id:
                        frequencies = [.125, .15]
                    elif "entrainE" in task_id:
                        frequencies = [.15, .2]
                    elif "entrainF" in task_id:
                        frequencies = [.175, .2]
                    else:
                        frequencies = [.125, .2]
                    
                    print(
                        experiment_id, 
                        mri_id, 
                        sub_id, 
                        ses_id, 
                        run_id, 
                        f"roi-task-{_roi_task_id}", 
                        task_id,
                        f"TR: {repetition_time}", f"Frequencies: {frequencies}", corr_type,
                        STIMULUS_TIMINGS,
                        FONTSIZE,
                        DPI,
                        True,
                    )
                    explore_single_subject_run(
                        experiment_id, 
                        mri_id, 
                        sub_id, 
                        ses_id, 
                        run_id, 
                        _roi_task_id, 
                        task_id,
                        data,
                        repetition_time,
                        frequencies, 
                        corr_type = corr_type, 
                        STIMULUS_TIMINGS = STIMULUS_TIMINGS,
                        FONTSIZE = FONTSIZE,
                        DPI=DPI,
                        CLOSE_FIGURES=True,
                        SKIP_IF_EXISTS=SKIP_IF_EXISTS,
                    )