In [1]:
import itertools
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from pathlib import Path
from IPython.display import clear_output
from collections import defaultdict

import sys
sys.path.append("ComputeCanada/frequency_tagging")
from utils import (
    read_pkl, 
    find_quadrant_id_from_keys,
    read_bootstrap_txt,
    extract_carpet_data,
    decorate_fig_carpetplot,
    get_roi_colour_codes,
    set_base_dir,
    CARPET_PLOTS, SCRATCH_DIR, PICKLE_DIR,
    TimeSeries
)

Create subject-level carpet plots and save them in `CARPET_PLOTS`.

Settings

In [2]:
# --- Data selection --------------------------------------------------------
# `datadir` and `n_bootstraps` are passed to `read_pkl`; `bootstrap_id` picks the
# single bootstrap that is plotted.
datadir = SCRATCH_DIR
n_bootstraps = 400
bootstrap_id = 0
# Used in the `truncate-<start>-<end>` part of the bootstrap path and drawn on the
# raw carpet plot at `window_size / TR` (presumably seconds -- TODO confirm units).
window_size = (39, 219)
close_figures = True

# NOTE(review): this module-level value is rebound by the loop in the Run cell;
# the per-run values actually used come from `experiment_ids` below.
experiment_id = "1_frequency_tagging"
normal_3T_sub_ids = ["000", "002", "003", "004", "005", "006", "007", "008", "009"]
normal_7T_sub_ids = ["Pilot001", "Pilot009", "Pilot010", "Pilot011"]
vary_sub_ids = ["020"]*3 + ["021"]*3
vary_task_ids = [f"entrain{i}" for i in ["A", "B", "C", "D", "E", "F"]]

# --- Per-run parameter lists -----------------------------------------------
# One entry per carpet-plot run. The Run cell zips these lists together, so
# position k of every list describes the same run. Layout, in order:
#   1. 3T subjects, task "control"   (ROIs taken from task "entrain")
#   2. 3T subjects, task "entrain"
#   3. 7T subjects, task "AttendAway"
#   4. Varying-frequency runs: subject 020 with entrainA-C and subject 021 with
#      entrainD-F, repeated three times (ROIs from A/D, then B/E, then C/F),
#      first for 3T and then for 7T.
sub_ids = normal_3T_sub_ids*2 + normal_7T_sub_ids + vary_sub_ids*3*2
experiment_ids = ["1_frequency_tagging"]*len(normal_3T_sub_ids)*2 + ["1_attention"]*len(normal_7T_sub_ids) + ["1_frequency_tagging"]*len(vary_sub_ids)*3*2
mri_ids = ["3T"]*len(normal_3T_sub_ids)*2 + ["7T"]*len(normal_7T_sub_ids) + ["3T"]*len(vary_sub_ids)*3 + ["7T"]*len(vary_sub_ids)*3
# Task whose ROIs are used for each run.
roi_task_ids= ["entrain"]*len(normal_3T_sub_ids) + ["entrain"]*len(normal_3T_sub_ids) + ['AttendAway']*len(normal_7T_sub_ids) + (["entrainA"]*3 + ["entrainD"]*3 + ["entrainB"]*3 + ["entrainE"]*3 + ["entrainC"]*3 + ["entrainF"]*3) * 2
# Task whose time series are plotted for each run.
task_ids= ["control"]*len(normal_3T_sub_ids) + ["entrain"]*len(normal_3T_sub_ids) + ['AttendAway']*len(normal_7T_sub_ids) + vary_task_ids*3*2
# Frequency pair [f1, f2] (f1 < f2) defining the ROIs of each run.
roi_frequencies = [[.125, .2]]*(len(normal_3T_sub_ids)*2+len(normal_7T_sub_ids)) + ([[.125,.2]]*3 + [[.125,.2]]*3 + [[.125,.175]]*3 + [[.15,.2]]*3 + [[.125,.15]]*3 + [[.175,.2]]*3) * 2
# Frequency pair of the plotted task; passed to `decorate_fig_carpetplot`.
task_frequencies = [[.125, .2]]*(len(normal_3T_sub_ids)*2+len(normal_7T_sub_ids)) + [[.125,.2],[.125,.175],[.125,.15],[.125,.2],[.15,.2],[.175,.2]]*3*2

# --- Plot settings ---------------------------------------------------------
TR = .3  # divisor that converts `window_size` / `stim_start` to x positions (presumably seconds per volume -- TODO confirm)
fos = [.8]  # values passed to `read_pkl` as `fo`
pvals = ["uncp"]  # values passed to `read_pkl` as `pval` and used in file names
fig_out_dir = Path(set_base_dir(str(CARPET_PLOTS)))
stim_start = 14  # drawn as a dotted green line on the raw carpet plot at stim_start / TR
cmap = "Greys_r"
vmin, vmax = -1.31, 1.31  # colour limits applied to the z-scored time series

# --- Sanity check ----------------------------------------------------------
# `zip` stops at the shortest input, so a mismatch would silently drop runs.
per_run_lists = [sub_ids, experiment_ids, mri_ids, roi_task_ids, task_ids, roi_frequencies, task_frequencies]
for i in per_run_lists:
    print(len(i))
assert len({len(values) for values in per_run_lists}) == 1, (
    "Per-run parameter lists must all have the same length, got "
    f"{[len(values) for values in per_run_lists]}"
)

58
58
58
58
58
58
58


Run

In [3]:
def _stack_vertex_groups(data, f1_coords, f2_coords, f1_from_f1, inter_from_f1, f2_from_f2, first_col=0):
    """Stack ROI vertices column-wise in the order f1-only, f1&f2, f2-only.

    `data` is indexed as [vertex, column]. `f1_coords` / `f2_coords` are boolean
    masks over all vertices; the `*_from_*` masks select vertices within those
    ROIs. Columns before `first_col` are dropped. The result has one column per
    selected vertex (i.e. it is the transpose of the selected rows).
    """
    f1_rows = data[f1_coords, first_col:]
    f2_rows = data[f2_coords, first_col:]
    return np.hstack(
        [
            f1_rows[f1_from_f1, :].T,
            f1_rows[inter_from_f1, :].T,
            f2_rows[f2_from_f2, :].T,
        ]
    )


def _zscore_as_vertex_rows(timeseries):
    """Z-score every column of `timeseries` and transpose so each row is one vertex.

    Columns with zero standard deviation become NaN (numpy emits a RuntimeWarning).
    """
    return ((timeseries - timeseries.mean(0)) / timeseries.std(0)).T


def _correlation_sort_order(vertex_rows):
    """Return row indices ordered from highest to lowest summed |correlation|.

    `np.corrcoef` returns a 0-d value when given a single row, so the result is
    promoted to 2-D to keep the `axis=1` reduction valid for one-vertex groups.
    `np.nansum` is used so that a single NaN row (e.g. a zero-variance vertex)
    does not turn every vertex's score into NaN.
    """
    corr = np.atleast_2d(np.corrcoef(vertex_rows))
    correlation_strength = np.nansum(np.abs(corr), axis=1)
    return np.argsort(correlation_strength)[::-1]


for fo, pval in itertools.product(fos, pvals):

    for experiment_id, mri_id, sub_id, roi_task_id, task_id, frequencies, _task_frequencies in zip(
        experiment_ids, mri_ids, sub_ids, roi_task_ids, task_ids, roi_frequencies, task_frequencies,
    ):
        # Fail before the expensive reads if the ROI frequency pair is not ordered (f1 < f2).
        assert frequencies[1] > frequencies[0], f"Expected f1 < f2, got {frequencies}"

        # One bootstrap pickle per ROI frequency.
        f_data = {
            roi_f: read_pkl(
                datadir,
                n_bootstraps,
                sub_id,
                roi_task_id,
                roi_f,
                task_id,
                experiment_id=experiment_id,
                mri_id=mri_id,
                pval=pval,
                fo=fo,
            )
            for roi_f in frequencies
        }

        task_quadrant = find_quadrant_id_from_keys(f_data[frequencies[0]], task_id)

        f1_dict = f_data[frequencies[0]].copy()
        f2_dict = f_data[frequencies[1]].copy()
        f1_coords = f1_dict['roi_coords']
        f2_coords = f2_dict['roi_coords']

        # Per vertex: 0 = in neither ROI, 1 = in exactly one ROI, 2 = in both.
        roi_membership = f1_coords.astype(int) + f2_coords.astype(int)
        membership_within_f1 = roi_membership[f1_coords]
        membership_within_f2 = roi_membership[f2_coords]
        # Masks (relative to the vertices of the f1 ROI / f2 ROI respectively)
        inter_from_f1 = membership_within_f1 == 2
        f1_from_f1 = membership_within_f1 == 1
        f2_from_f2 = membership_within_f2 == 1
        n_f1, n_f1f2, n_f2 = f1_from_f1.sum(), inter_from_f1.sum(), f2_from_f2.sum()

        # Load data from nifti
        # 1) untruncated, 2) windowed, 3) preprocessed
        # NOTE(review): this path is hard-coded instead of being built from `datadir`,
        # and `n-25_batch-00` is fixed -- confirm this is intended.
        bootstrap_txt = Path(f"/scratch/fastfmri/experiment-{experiment_id}_mri-{mri_id}_smooth-0_truncate-{window_size[0]}-{window_size[1]}_n-25_batch-00_desc-basic_pval-{pval}_bootstrap/sub-{sub_id}/task-{task_id}{task_quadrant}_test_splits.txt")
        assert bootstrap_txt.exists(), f"{bootstrap_txt} not found."
        data_from_dtseries_raw, data_from_dtseries_windowed, data_from_dtseries_preprocessed = read_bootstrap_txt(bootstrap_txt, bootstrap_id) # Load single bootstrap

        roi_masks = (f1_coords, f2_coords, f1_from_f1, inter_from_f1, f2_from_f2)
        data_from_dtseries_raw = _stack_vertex_groups(data_from_dtseries_raw, *roi_masks)
        # The first column of the windowed / preprocessed arrays is skipped
        # (reason not visible in this notebook -- TODO confirm).
        data_from_dtseries_windowed = _stack_vertex_groups(data_from_dtseries_windowed, *roi_masks, first_col=1)
        data_from_dtseries_preprocessed = _stack_vertex_groups(data_from_dtseries_preprocessed, *roi_masks, first_col=1)

        # Load data from pickle
        # 4) preprocessed, 5) preprocessed & phased
        _, f1_data_from_pkl_preprocessed = extract_carpet_data(f1_dict, task_id, task_quadrant, bootstrap_id, False)
        f1_phased_tps, f1_data_from_pkl_preprocessed_phased = extract_carpet_data(f1_dict, task_id, task_quadrant, bootstrap_id, True)
        _, f2_data_from_pkl_preprocessed = extract_carpet_data(f2_dict, task_id, task_quadrant, bootstrap_id, False)
        f2_phased_tps, f2_data_from_pkl_preprocessed_phased = extract_carpet_data(f2_dict, task_id, task_quadrant, bootstrap_id, True)

        # Keep only the phased time points present for both frequencies so the
        # f1 and f2 blocks have matching rows. A set gives O(1) membership tests.
        shared_phased_tps = set(f1_phased_tps).intersection(f2_phased_tps)
        f1_phased_tp_mask = [tp in shared_phased_tps for tp in f1_phased_tps]
        f2_phased_tp_mask = [tp in shared_phased_tps for tp in f2_phased_tps]

        # Not plotted by default; per the original note it matches
        # `data_from_dtseries_preprocessed`.
        data_from_pkl_preprocessed = np.hstack(
            [
                f1_data_from_pkl_preprocessed[:, f1_from_f1],
                f1_data_from_pkl_preprocessed[:, inter_from_f1],
                f2_data_from_pkl_preprocessed[:, f2_from_f2],
            ]
        )
        data_from_pkl_preprocessed_phased = np.hstack(
            [
                f1_data_from_pkl_preprocessed_phased[:, f1_from_f1][f1_phased_tp_mask, :],
                f1_data_from_pkl_preprocessed_phased[:, inter_from_f1][f1_phased_tp_mask, :],
                f2_data_from_pkl_preprocessed_phased[:, f2_from_f2][f2_phased_tp_mask, :],
            ]
        )

        ts_labels = [
            "raw", "windowed", "denoised", "denoised_rephased"
        ]
        ts_data = [
            data_from_dtseries_raw,
            data_from_dtseries_windowed,
            data_from_dtseries_preprocessed,
            data_from_pkl_preprocessed_phased,
        ]

        # Row ranges of the three vertex groups in the stacked (vertex x time) arrays.
        group_rows = {
            "f1": slice(0, n_f1),
            "f1f2": slice(n_f1, n_f1 + n_f1f2),
            "f2": slice(n_f1 + n_f1f2, None),
        }

        # Get the sorting order from `data_from_dtseries_preprocessed`, separately
        # for each vertex group (f1, f1f2, f2), and reuse it for every plot so
        # rows are comparable across panels.
        reference = _zscore_as_vertex_rows(data_from_dtseries_preprocessed)
        sorted_voxels = {
            f_group: _correlation_sort_order(reference[rows, :])
            for f_group, rows in group_rows.items()
        }

        for y_ix, (y_label, timeseries) in enumerate(zip(ts_labels, ts_data)):

            fig, ax = plt.subplots(
                nrows=1, ncols=1, figsize=(2., 1.2), dpi=300,
            )

            y = _zscore_as_vertex_rows(timeseries)

            # Sort rows within each vertex group (fancy indexing copies, so the
            # in-place assignment is safe).
            for f_group, rows in group_rows.items():
                y[rows, :] = y[rows, :][sorted_voxels[f_group], :]

            im = ax.imshow(y, cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')
            if y_ix == 0:
                # Only the raw panel spans the full run: mark the analysis window
                # and the stimulus onset on it.
                n_vertices = n_f1 + n_f2 + n_f1f2
                for window_edge in window_size:
                    ax.plot([window_edge/TR]*2, [0, n_vertices], color='orange', linestyle='-', linewidth=1)
                ax.plot([stim_start/TR]*2, [0, n_vertices], color='green', linestyle='dotted', linewidth=1)
            fig, ax = decorate_fig_carpetplot(y, fig, ax, im, _task_frequencies[0], _task_frequencies[1], n_f1, n_f2, n_f1f2, FONTSIZE=8, TR=TR)

            fig.tight_layout()

            png_out = fig_out_dir / f"experiment-{experiment_id}_mri-{mri_id}_sub-{sub_id}_task-{roi_task_id}_task-{task_id}_pval-{pval}_fo-{fo}_{y_ix}{y_label}.png"
            fig.savefig(png_out, dpi='figure')

            if close_figures:
                # Close this specific figure so figures do not accumulate across runs.
                plt.close(fig)

        clear_output()


Reading: /scratch/fastfmri/experiment-1_frequency_tagging_mri-3T_smooth-0_truncate-39-219_n-400_batch-merged_desc-basic_roi-entrain-0.125_pval-uncp_fo-0.8_bootstrap/sub-008/bootstrap/task-control_bootstrapped_data.pkl
Reading: /scratch/fastfmri/experiment-1_frequency_tagging_mri-3T_smooth-0_truncate-39-219_n-400_batch-merged_desc-basic_roi-entrain-0.2_pval-uncp_fo-0.8_bootstrap/sub-008/bootstrap/task-control_bootstrapped_data.pkl


  y_all = (( y_all - y_all.mean(0)) / y_all.std(0) ).T
