In [1]:
import pandas as pd
import os.path as op
from os import sep
import nibabel as nb
import numpy as np
import json
import trimesh
import open3d as o3d
import open3d.visualization.rendering as rendering
import matplotlib.pylab as plt
from matplotlib import cm, colors
from utilities import files
import new_files
import tqdm.auto as tqdm
from copy import copy
from sklearn.cluster import KMeans
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler, RobustScaler, minmax_scale
from sklearn.manifold import MDS
from scipy.spatial.distance import euclidean
from scipy.stats import spearmanr
from brain_tools import *
import pickle
from functools import reduce
import itertools as it

# File-search helper from the local `new_files` module; its get_files() is used below to locate
# the CSD (.npy) and per-subject info (.json) files.
dir_search = new_files.Files()

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def plot_csd(smooth_csd, list_ROI_vertices, bb_path, times, ax, cb=True, cmap="RdBu_r", vmin_vmax=None):
    """Draw a 2-D CSD image on ``ax`` with averaged layer boundaries overlaid.

    Parameters
    ----------
    smooth_csd : 2-D array
        Image to draw. Rows span the vertical range 0-1 (0 at the top, presumably cortical depth),
        columns span ``times[0]``..``times[-1]``. May contain NaN (e.g. an eigenvector padded with NaN
        outside its PCA window); NaNs are ignored when choosing the colour limits.
    list_ROI_vertices : index array
        Vertices used to select per-vertex values from every entry of the JSON file.
    bb_path : str
        JSON file whose entries are per-vertex arrays; each entry is averaged over ``list_ROI_vertices``
        and the cumulative sums are drawn as boundaries (presumably layer thicknesses - TODO confirm).
    times : 1-D array
        Time axis of the columns of ``smooth_csd``.
    ax : matplotlib Axes
        Target axes.
    cb : bool
        Add a colourbar next to ``ax``.
    cmap : str
        Colormap name.
    vmin_vmax : (float, float) or None
        Explicit colour limits (a TwoSlopeNorm centred on 0 requires vmin < 0 < vmax). When None the
        limits are symmetric: +/- the largest finite absolute value of ``smooth_csd``.
    """
    layer_labels = ["I", "II", "III", "IV", "V", "VI"]
    with open(bb_path, "r") as fp:
        bb = json.load(fp)
    # Mean of every JSON entry over the ROI vertices; the running sum gives each boundary position.
    bb_mean = [np.mean(np.array(bb[key])[list_ROI_vertices]) for key in bb.keys()]

    if vmin_vmax is None:
        # nanmax: NaN padding must not turn the limits into NaN (that would render a blank image).
        max_smooth = np.nanmax(np.abs(smooth_csd))
        if not np.isfinite(max_smooth) or max_smooth == 0:
            # TwoSlopeNorm needs vmin < 0 < vmax; fall back to a unit range for all-zero / all-NaN input.
            max_smooth = 1.0
        divnorm = colors.TwoSlopeNorm(vmin=-max_smooth, vcenter=0, vmax=max_smooth)
    else:
        divnorm = colors.TwoSlopeNorm(vmin=vmin_vmax[0], vcenter=0, vmax=vmin_vmax[1])

    extent = [times[0], times[-1], 1, 0]
    csd_imshow = ax.imshow(
        smooth_csd, norm=divnorm, origin="lower",
        aspect="auto", extent=extent,
        cmap=cmap
    )
    ax.set_ylim(1, 0)
    for l_ix, th in enumerate(np.cumsum(bb_mean)):
        ax.axhline(th, linestyle=(0, (5, 5)), c="black", lw=0.5)
        ax.annotate(layer_labels[l_ix], [times[0] + 0.01, th - 0.01], size=15)
    if cb:
        ax.figure.colorbar(csd_imshow, ax=ax)
    # Lay out the figure that owns ``ax`` (not whichever figure happens to be current).
    ax.figure.tight_layout()

In [3]:
def shift_cor(x1, x2, shifts=(-2, -1, 0, 1, 2), ret_sr=True):
    """Spearman correlation of ``x1`` vs ``x2`` at several relative shifts; report the strongest one.

    For a shift ``s`` the compared pairs are ``(x1[i], x2[i + s])`` for every ``i`` where both exist:
    ``s > 0`` drops the last ``s`` values of ``x1`` and the first ``s`` of ``x2``; ``s < 0`` the
    converse. Every shift is computed from the ORIGINAL inputs (shifts do not accumulate).

    Parameters
    ----------
    x1, x2 : array-like
        Equal-length sequences; they are flattened before correlating.
    shifts : iterable of int
        Shifts to evaluate.
    ret_sr : bool
        If True return ``(best_r, best_shift)``; otherwise ``(best_r, results)`` where ``results``
        maps each shift to ``[r, p]``.

    Returns
    -------
    tuple
        ``best_r`` is the correlation with the largest absolute value (first one on ties).
    """
    x1 = np.asarray(x1)
    x2 = np.asarray(x2)
    results = {}
    for s in shifts:
        # Slice fresh views each time; re-assigning x1/x2 here would compound the truncation.
        if s < 0:
            a, b = x1[-s:], x2[:s]
        elif s > 0:
            a, b = x1[:-s], x2[s:]
        else:
            a, b = x1, x2
        results[s] = list(spearmanr(a.flatten(), b.flatten()))
    max_corr_ix = np.argmax(np.abs(np.array(list(results.values()))[:, 0]))
    max_corr_key = list(results.keys())[max_corr_ix]
    max_corr_value = results[max_corr_key][0]
    if ret_sr:
        return max_corr_value, max_corr_key
    return max_corr_value, results

In [4]:
# Root of the processed derivatives; CSD and per-subject info files are searched below it.
dataset_location = "/home/common/bonaiuto/multiburst/derivatives/processed"

# Per epoch type: [epoch time axis, PCA window [start, end], baseline cut-off (times < cut-off)].
# Presumably all in seconds - TODO confirm.
epoch_types = dict(
    visual=[np.linspace(-0.2, 0.8, num=601), [0.0, 0.2], -0.01],
    motor=[np.linspace(-0.5, 0.5, num=601), [-0.2, 0.2], -0.2],
)

csd_files = dir_search.get_files(dataset_location, "*.npy", prefix="time_CSD_autoreject")
json_files = dir_search.get_files(dataset_location, "*.json", prefix="info")

# Subject id = third-from-last path component of each info JSON.
info_dict = {}
for json_path in json_files:
    subject_id = json_path.split(sep)[-3]
    with open(json_path, "r") as handle:
        info_dict[subject_id] = json.load(handle)

In [5]:
# Main analysis: for every CSD file, run a PCA per ROI category on the vertex-wise CSD, save
# diagnostic figures, and collect results in TOTAL_RESULTS keyed by (core_name, ROI category).
TOTAL_RESULTS = {}
n_bins = 100
# 20 percentile bins (0-5, 5-10, ..., 95-100) used to stratify vertices by PC score.
prc = np.linspace(0, 100, num=21)
prc_bounds = list(zip(prc[:-1], prc[1:]))

for csd_file in tqdm.tqdm(csd_files):
    # The epoch type is the first `epoch_types` key that occurs in the file path.
    epoch_type = [i for i in epoch_types.keys() if i in csd_file][0]
    subject = csd_file.split(sep)[-4]
    core_name = csd_file.split(sep)[-1].split("_")[-1].split(".")[0]
    info = info_dict[subject]
    atlas = pd.read_csv(info["atlas"])
    atlas_labels = np.load(info["atlas_colors_path"])
    visual_ROI = atlas.loc[(atlas.PRIMARY_SECTION == 1)].USED_LABEL.values
    visual_ROI = np.hstack([visual_ROI, [i for i in atlas.USED_LABEL.values if "_MT_" in i]])
    sensorimotor_ROI = ["L_4_ROI", "R_4_ROI"]
    labels_xxx = {
        "visual": visual_ROI,
        "motor": sensorimotor_ROI
    }
    ROI_labels = labels_xxx[epoch_type]
    vertex_num = np.arange(atlas_labels.shape[0])
    # Decode the per-vertex atlas labels once, then pick each ROI's vertices with a boolean mask.
    decoded_labels = np.array([al.decode("utf-8") for al in atlas_labels])
    ROI_vertices = {i: vertex_num[decoded_labels == i] for i in ROI_labels}
    times, pca_sel, baseline_lim = epoch_types[epoch_type]

    csd_data = np.load(csd_file)
    # One CSD stack per ROI: the csd_data rows of that ROI's vertices.
    true_CSD = {i: csd_data[ROI_vertices[i]] for i in ROI_labels}
    pca_time_sel = np.where((times >= pca_sel[0]) & (times <= pca_sel[1]))[0]
    # PCA features: all samples of the PCA time window (presumably layer x time), flattened per vertex.
    pca_csd_dataset = {
        i: true_CSD[i][:, :, pca_time_sel].reshape(true_CSD[i].shape[0], -1) for i in ROI_labels
    }

    # Vertices left out of the PCA fit: NaN rows (also left out of the transform) and rows whose
    # std exceeds the 99.995th percentile (outliers; still transformed).
    nans_map = {}
    outlier_map = {}
    all_map = {}
    for roi in ROI_labels:
        metric = pca_csd_dataset[roi].std(axis=1)
        nan_map = np.isnan(metric)
        upper_limit = np.percentile(metric[~nan_map], 99.995)
        out_map = metric > upper_limit
        outlier_map[roi] = out_map
        nans_map[roi] = nan_map
        all_map[roi] = out_map | nan_map

    # Group ROIs into categories by dropping the 2-character prefix ("L_"/"R_"). dict.fromkeys keeps
    # first-seen order, so category (and TOTAL_RESULTS) order is reproducible between kernel runs.
    unique_lab = list(dict.fromkeys(i[2:] for i in ROI_labels))
    lab_cats = {i: [j for j in ROI_labels if i in j] for i in unique_lab}

    gray = np.array([0.5, 0.5, 0.5])
    brain = nb.load(info["pial_ds_nodeep_inflated"])
    vertices, faces = brain.agg_data()

    pca_results = {}
    for lab_cat in lab_cats:
        cat_labels = lab_cats[lab_cat]
        lab_row = atlas.loc[atlas.USED_LABEL == cat_labels[0]]
        title = lab_row.LONG_NAME.values[0]

        pca_fit = np.vstack([pca_csd_dataset[i][~all_map[i]] for i in cat_labels])
        pca_transform = np.vstack([pca_csd_dataset[i][~nans_map[i]] for i in cat_labels])
        scaler = RobustScaler()
        # Fixed random_state: the "auto" solver may choose randomized SVD, which otherwise varies per run.
        pca = PCA(n_components=10, random_state=0)
        pca_fit = scaler.fit_transform(pca_fit)
        pca.fit(pca_fit)
        pca_transform = pca.transform(scaler.transform(pca_transform))
        pca_results[lab_cat] = [pca_transform, pca.components_, pca.explained_variance_ratio_]
        pca_scores, eigenv, exp_var = pca_results[lab_cat]
        cat_vertices = np.hstack([ROI_vertices[i] for i in cat_labels])
        cat_nan_map = np.hstack([nans_map[i] for i in cat_labels])

        # Figure 1: explained variance and log10|score| histograms for PCs 1-4.
        f, ax = plt.subplots(5, 1, figsize=(14, 20), facecolor="white")
        pc_numbers = np.arange(exp_var.shape[0]) + 1
        ax[0].bar(pc_numbers, exp_var)
        ax[0].set_xticks(pc_numbers)
        ax[0].set_ylabel("Var Exp Ratio")

        colours_pc = []
        for i in range(4):
            pc_abs_log = np.log10(np.abs(pca_scores[:, i]))
            datacolors, mappable = data_to_rgb(
                pc_abs_log, n_bins, "pink_r",
                np.percentile(pc_abs_log, 5),
                np.percentile(pc_abs_log, 95),
                vcenter=np.percentile(pc_abs_log, 50),
                ret_map=True,
            )
            hist, bins, barlist = ax[i + 1].hist(pc_abs_log, bins=n_bins, edgecolor='black', linewidth=0.2)
            ax[i + 1].set_ylabel("log10(|PC {}|)".format(i + 1))
            # Colour each bar by its right bin edge.
            for _bin_ix, _bin in enumerate(barlist):
                plt.setp(_bin, "facecolor", mappable.to_rgba(bins[_bin_ix + 1]))
            # Whole-brain vertex colours: gray outside the category, white for NaN vertices.
            data_colour_map = np.repeat(np.array([[1., 1., 1.]]), cat_vertices.shape[0], axis=0)
            data_colour_map[~cat_nan_map] = datacolors[:, :3]
            colours = np.repeat(gray.reshape(1, -1), vertices.shape[0], axis=0)
            colours[cat_vertices] = data_colour_map
            colours_pc.append(colours)
        f.suptitle(title, y=1)
        f.tight_layout()
        filename_template = "/home/mszul/git/DANC_multilayer_laminar/output/{}_{}_PCA_results.{}"
        f.savefig(filename_template.format(core_name, lab_cat, "png"), dpi=300)
        plt.close(f)

        # Rows of `csd` must line up with `pca_scores`, which only contain the non-NaN vertices
        # (stacked in the same ROI order as `pca_transform`).
        csd = np.vstack([true_CSD[i][~nans_map[i]] for i in cat_labels])
        baseline_ix = np.where(times < baseline_lim)[0]

        # For PCs 1-4: mean CSD of the vertices in every score-percentile bin, smoothed and
        # baseline-corrected (mean over samples before `baseline_lim` subtracted per row).
        csd_smoothed_pc = []
        for pc_comp in range(4):
            pc_sc = pca_scores[:, pc_comp]
            pc_smoothed_csd = []
            for pb in prc_bounds:
                bounds = [np.percentile(pc_sc, i) for i in pb]
                # Both edges inclusive: a vertex exactly on an edge is counted in both neighbouring bins.
                pr_mask = np.where((pc_sc >= bounds[0]) & (pc_sc <= bounds[1]))[0]
                mean_smooth_csd = smooth_csd(np.mean(csd[pr_mask], axis=0), info["n_surf"])
                baseline = np.mean(mean_smooth_csd[:, baseline_ix], axis=1, keepdims=True)
                pc_smoothed_csd.append(mean_smooth_csd - baseline)
            csd_smoothed_pc.append(np.array(pc_smoothed_csd))

        # Figure 2: lowest-bin CSD, highest-bin CSD and the eigenvector for PCs 1-4.
        f, ax = plt.subplots(4, 3, figsize=(14, 4 * 4), facecolor="white")
        ax[0, 0].set_title("{} - {} percentile".format(*prc_bounds[0]))
        ax[0, 1].set_title("{} - {} percentile".format(*prc_bounds[-1]))
        ax[0, 2].set_title("eigenvector motifs")
        bb_path = info["big_brain_layers_path"]
        for pc_comp in range(4):
            ax[pc_comp, 0].set_ylabel("PC {}".format(pc_comp + 1))
            lowest_bin = csd_smoothed_pc[pc_comp][0]
            highest_bin = csd_smoothed_pc[pc_comp][-1]
            plot_csd(lowest_bin, cat_vertices, bb_path, times, ax=ax[pc_comp, 0])
            plot_csd(highest_bin, cat_vertices, bb_path, times, ax=ax[pc_comp, 1])
            ev = smooth_csd(np.array(np.split(eigenv[pc_comp], info["layers"])), info["layers"])
            # The eigenvector only covers the PCA window; the rest of the epoch is NaN.
            nan_ar = np.full(lowest_bin.shape, np.nan)
            nan_ar[:, pca_time_sel] = ev
            plot_csd(nan_ar, cat_vertices, bb_path, times, ax=ax[pc_comp, 2], cmap="viridis")
            for col in range(3):
                ax[pc_comp, col].axvline(times[pca_time_sel][0], lw=0.2, c="black")
                ax[pc_comp, col].axvline(times[pca_time_sel][-1], lw=0.2, c="black")
        f.suptitle(title, y=1)
        f.tight_layout()
        filename_template = "/home/mszul/git/DANC_multilayer_laminar/output/{}_{}_CSD_results.{}"
        f.savefig(filename_template.format(core_name, lab_cat, "png"), dpi=300)
        plt.close(f)

        TOTAL_RESULTS[(core_name, lab_cat)] = {
            "info": info,
            "brain": info["pial_ds_nodeep_inflated"],
            "pca_results": [pca_transform, pca.components_, pca.explained_variance_ratio_],
            "ROI_vertices": cat_vertices,
            "nan_vertices": cat_nan_map,
            "PC4_colours": colours_pc,
            "PC4_smooth_csd": csd_smoothed_pc
        }

  0%|          | 0/12 [00:00<?, ?it/s]

In [6]:
def rotation_matrix(theta1, theta2, theta3):
    """Build a 3x3 rotation matrix from three angles given in degrees.

    The entries equal Rx(theta1) @ Ry(theta2) @ Rz(theta3), i.e. elementary rotations about the
    x, y and z axes composed in that order.
    """
    a1, a2, a3 = (angle * np.pi / 180 for angle in (theta1, theta2, theta3))
    c1, c2, c3 = np.cos(a1), np.cos(a2), np.cos(a3)
    s1, s2, s3 = np.sin(a1), np.sin(a2), np.sin(a3)

    first_row = [c2*c3, -c2*s3, s2]
    second_row = [c1*s3+c3*s1*s2, c1*c3-s1*s2*s3, -c2*s1]
    third_row = [s1*s3-c1*c3*s2, c3*s1+c1*s2*s3, c1*c2]
    return np.array([first_row, second_row, third_row])

In [7]:
def custom_draw_geometry(mesh, filename="render.png", visible=True, wh=(960, 960), save=True):
    """Show ``mesh`` in an Open3D Visualizer window and optionally save a screenshot.

    Parameters
    ----------
    mesh : open3d geometry
        Geometry to display (back faces are rendered too).
    filename : str
        Screenshot path used when ``save`` is True.
    visible : bool
        Passed to ``create_window``.
    wh : sequence of two ints
        Window (width, height) in pixels.
    save : bool
        Capture the screen to ``filename`` after ``vis.run()`` returns.

    Notes
    -----
    ``vis.run()`` presumably blocks while the interactive window is open (Open3D behaviour) - confirm.
    The window is destroyed even if adding the geometry, rendering or capturing raises.
    """
    vis = o3d.visualization.Visualizer()
    vis.create_window(width=wh[0], height=wh[1], visible=visible)
    try:
        vis.add_geometry(mesh)
        vis.get_render_option().mesh_show_back_face = True
        vis.run()
        if save:
            vis.capture_screen_image(filename, do_render=True)
    finally:
        vis.destroy_window()

In [8]:
# ROI -> [mesh rotation used before rendering, hemisphere letter to hide or None].
# Angles are rotation_matrix(theta1, theta2, theta3) arguments in degrees. When the second entry is
# set, the render loop removes vertices whose atlas label starts with that letter (or with "?").
rotations = {
    roi_name: [rotation_matrix(*angles), hidden_hemisphere]
    for roi_name, angles, hidden_hemisphere in (
        ("L_4_ROI", (-45, 0, 90), None),
        ("R_4_ROI", (-45, 0, -90), None),
        ("LR_4_ROI", (-25, 0, 0), None),
        ("L_V1_ROI", (-130, 45, 0), "R"),
        ("R_V1_ROI", (-130, -45, 0), "L"),
        ("L_MT_ROI", (-130, -60, 0), None),
        ("R_MT_ROI", (-130, 60, 0), None),
    )
}

In [31]:
# TOTAL_RESULTS keys per subject: V1 categories in `v1`, area-4 (M1) categories in `m1`.
v1 = {s: [k for k in TOTAL_RESULTS.keys() if s in k[0] and "V1" in k[1]] for s in ["sub-001", "sub-002"]}
m1 = {s: [k for k in TOTAL_RESULTS.keys() if s in k[0] and "4_" in k[1]] for s in ["sub-001", "sub-002"]}

# Render the PC 1-4 colour maps of every sub-001 V1 result from each matching viewpoint in `rotations`.
for core_name, lab_cat in v1["sub-001"]:
    results = TOTAL_RESULTS[(core_name, lab_cat)]
    atlas_labels = np.load(results["info"]["atlas_colors_path"])
    brain = nb.load(results["brain"])
    vertices, faces = brain.agg_data()
    # First character of each decoded atlas label, compared against the hemisphere to hide.
    label_initials = np.array([al.decode("utf-8")[0] for al in atlas_labels])
    rot_keys = [r for r in rotations.keys() if lab_cat in r]
    for pc_comp in range(4):
        for roi in rot_keys:
            rot_mx, hem_sel = rotations[roi]
            # NOTE(review): the PC index lands before "_PC_" and the ROI name after it - confirm the intended order.
            filename = "/home/mszul/git/DANC_multilayer_laminar/output/{}_{}_PC_{}_brain_render.png".format(core_name, str(pc_comp).zfill(2), roi)
            # A fresh mesh per render: rotate() and remove_vertices_by_mask() modify it in place.
            mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False, validate=False)
            mesh = mesh.as_open3d
            mesh.compute_vertex_normals(normalized=True)
            mesh.vertex_colors = o3d.utility.Vector3dVector(results["PC4_colours"][pc_comp])
            mesh.rotate(rot_mx)
            if hem_sel is not None:
                vxmask = (label_initials == hem_sel) | (label_initials == "?")
                mesh.remove_vertices_by_mask(vxmask)
            custom_draw_geometry(mesh, filename, save=True)

In [None]:
# Single exploratory render of one sub-002 visual result. The parameters are set here instead of
# being inherited from the previous cell's loop variables, which named the file after a sub-001 result.
core_name, lab_cat = ('autoreject-sub-002-ses-03-003-visual-epo', 'V1_ROI')
pc_comp = 3          # last PC the previous cell's loop rendered
roi = "R_V1_ROI"     # last V1 key in `rotations`; only used in the file name
results = TOTAL_RESULTS[(core_name, lab_cat)]
brain = nb.load(results["brain"])
vertices, faces = brain.agg_data()
mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False, validate=False)
mesh = mesh.as_open3d
mesh.compute_vertex_normals(normalized=True)
mesh.vertex_colors = o3d.utility.Vector3dVector(results["PC4_colours"][pc_comp])
# Same angles as rotations["R_MT_ROI"]; NOTE(review): file name says R_V1_ROI - confirm the intended view.
mesh.rotate(rotation_matrix(-130, 60, 0))
filename = "/home/mszul/git/DANC_multilayer_laminar/output/{}_{}_PC_{}_brain_render.png".format(core_name, str(pc_comp).zfill(2), roi)
custom_draw_geometry(mesh, filename)

In [14]:
# Per subject: show the L_V1_ROI vertices in the top 5 % of log10|product of PC-1 scores across runs|.
for subs in v1.keys():
    results = TOTAL_RESULTS[v1[subs][0]]
    gray = np.array([0.5, 0.5, 0.5])
    brain = nb.load(results["info"]["pial_ds_nodeep_inflated"])
    vertices, faces = brain.agg_data()
    atlas_labels = np.load(results["info"]["atlas_colors_path"])
    ROI_labels = ['L_V1_ROI', 'R_V1_ROI']
    vertex_num = np.arange(atlas_labels.shape[0])
    decoded_labels = np.array([al.decode("utf-8") for al in atlas_labels])
    ROI_vertices = {i: vertex_num[decoded_labels == i] for i in ROI_labels}
    # Element-wise product of the PC-1 scores of all V1 results of this subject. `reduce` comes from
    # the import cell, so this cell no longer depends on `multiply_list` (defined further down).
    # NOTE(review): assumes all results have the same vertex rows in the same order and that the first
    # rows are the L_V1_ROI vertices (only true if no vertex was dropped as NaN) - confirm.
    pc1_scores = [TOTAL_RESULTS[k]["pca_results"][0][:, 0] for k in v1[subs]]
    multiplied_score = reduce(lambda acc, score: acc * score, pc1_scores, 1)[:ROI_vertices["L_V1_ROI"].shape[0]]
    log_multiplied_score = np.log10(np.abs(multiplied_score))
    datacolors, mappable = data_to_rgb(
        log_multiplied_score, 100, "pink_r",
        np.percentile(log_multiplied_score, 5),
        np.percentile(log_multiplied_score, 95),
        vcenter=np.percentile(log_multiplied_score, 50),
        ret_map=True,
    )
    colours = np.repeat(gray.reshape(1, -1), vertices.shape[0], axis=0)

    map_thr = log_multiplied_score >= np.percentile(log_multiplied_score, 95)

    colours[ROI_vertices["L_V1_ROI"][map_thr]] = datacolors[map_thr, :3]
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False, validate=False)
    mesh = mesh.as_open3d
    mesh.compute_vertex_normals(normalized=True)
    mesh.vertex_colors = o3d.utility.Vector3dVector(colours)
    custom_draw_geometry(mesh, save=False)

NameError: name 'x' is not defined

In [28]:
# Per subject: colour the L_4_ROI (M1) vertices by log10|PC-1 score x PC-2 score| of an M1 result.
for subs in m1.keys():
    results = TOTAL_RESULTS[m1[subs][0]]
    gray = np.array([0.5, 0.5, 0.5])
    brain = nb.load(results["info"]["pial_ds_nodeep_inflated"])
    vertices, faces = brain.agg_data()
    atlas_labels = np.load(results["info"]["atlas_colors_path"])
    ROI_labels = ['L_4_ROI', 'R_4_ROI']
    vertex_num = np.arange(atlas_labels.shape[0])
    decoded_labels = np.array([al.decode("utf-8") for al in atlas_labels])
    ROI_vertices = {i: vertex_num[decoded_labels == i] for i in ROI_labels}
    # One product per M1 result of this subject (previously the V1 results were used by mistake).
    # `reduce` comes from the import cell, so `multiply_list` (defined further down) is not needed.
    multiplied_score = [
        reduce(lambda acc, score: acc * score, [TOTAL_RESULTS[k]["pca_results"][0][:, h] for h in range(2)], 1)
        for k in m1[subs]
    ]
    # NOTE(review): only the second M1 result (index 1) is shown; needs at least two results. The slice
    # assumes the first rows are the L_4_ROI vertices (no NaN vertices dropped) - confirm.
    log_multiplied_score = np.log10(np.abs(multiplied_score[1][:ROI_vertices["L_4_ROI"].shape[0]]))
    datacolors, mappable = data_to_rgb(
        log_multiplied_score, 100, "pink_r",
        np.percentile(log_multiplied_score, 5),
        np.percentile(log_multiplied_score, 95),
        vcenter=np.percentile(log_multiplied_score, 50),
        ret_map=True,
    )
    colours = np.repeat(gray.reshape(1, -1), vertices.shape[0], axis=0)

    # NOTE(review): computed but not applied - every L_4_ROI vertex is coloured below.
    map_thr = log_multiplied_score >= np.percentile(log_multiplied_score, 90)

    colours[ROI_vertices["L_4_ROI"]] = datacolors[:, :3]
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False, validate=False)
    mesh = mesh.as_open3d
    mesh.compute_vertex_normals(normalized=True)
    mesh.vertex_colors = o3d.utility.Vector3dVector(colours)
    custom_draw_geometry(mesh, save=False)

array([  4682243.84607544,   -796695.1917763 ,    412211.27869708, ...,
          495901.39132174, -17269966.95093748,  -1389302.3933078 ])

In [15]:
def multiply_list(iterable):
    """Return the running product ``1 * item0 * item1 * ...`` of ``iterable`` (1 when it is empty).

    Works element-wise when the items are numpy arrays.
    """
    return reduce(lambda product, element: product * element, iterable, 1)

In [None]:
# One row per TOTAL_RESULTS entry; the column order here becomes the column order of `corr_data`.
sort_df = {column: [] for column in ("subject", "epoch_type", "roi", "key", "run")}

# Short display names for the ROI categories used in TOTAL_RESULTS keys.
label_names = dict(
    [("4_ROI", "M1"), ("MT_ROI", "MT"), ("V1_ROI", "V1")]
)

for result_key in TOTAL_RESULTS.keys():
    # NOTE(review): positional parsing of names like "autoreject-sub-002-ses-03-003-visual-epo"
    # (index 2 = subject number, 5 = run, 6 = epoch type) - confirm for every file.
    core_parts = result_key[0].split("-")
    sort_df["subject"].append("sub-" + core_parts[2])
    sort_df["epoch_type"].append(core_parts[6])
    sort_df["roi"].append(label_names[result_key[1]])
    sort_df["run"].append(core_parts[5])
    sort_df["key"].append(result_key)

corr_data = pd.DataFrame.from_dict(sort_df)

# Replicate the table once per principal component, tagging the copies with PC = 1..4.
all_cds = [corr_data.assign(PC=pc_number) for pc_number in range(1, 5)]
corr_data = pd.concat(all_cds).sort_values(
    by=["subject", "epoch_type", "roi", "run", "PC"], ignore_index=True
)
corr_data["res_label"] = (
    corr_data.subject.map(str)
    + corr_data.roi.map(lambda value: str(value).rjust(4))
    + corr_data.run.map(lambda value: "run {}".format(value).rjust(8))
    + corr_data.PC.map(lambda value: "PC {}".format(value).rjust(6))
)
# Split by epoch type; reset_index makes the index equal to the row position within each table.
corr_data_split = {
    epoch: corr_data.loc[corr_data.epoch_type == epoch].reset_index(drop=True)
    for epoch in ("motor", "visual")
}

In [None]:
# Spearman correlation between the PCA eigenvectors of every pair of rows, per epoch-type table.
corr_mx_res = {}
for datakey in corr_data_split.keys():
    data = corr_data_split[datakey]
    n_rows = data.shape[0]
    # "PC" is 1-based (PC 1-4, as in the "PC n" labels) while pca.components_ rows are 0-based.
    eigvecs = [
        TOTAL_RESULTS[key]["pca_results"][1][pc - 1] for key, pc in zip(data["key"], data["PC"])
    ]
    data_res = np.zeros([n_rows, n_rows])
    for i in it.product(range(n_rows), repeat=2):
        r, p = spearmanr(eigvecs[i[0]], eigvecs[i[1]])
        data_res[i] = r
    corr_mx_res[datakey] = data_res

In [None]:
# |Spearman rho| between the eigenvectors of every pair of rows, one panel per epoch type.
divnorm = colors.Normalize(vmin=0, vmax=1)

f, ax = plt.subplots(1, 2, figsize=(25, 10), facecolor="white")
for ix, key in enumerate(corr_mx_res.keys()):
    labels = corr_data_split[key].res_label
    ax[ix].set_title(key)
    im = ax[ix].imshow(np.abs(corr_mx_res[key]), norm=divnorm, cmap="YlGnBu")
    ax[ix].set_xticks(range(len(labels)))
    ax[ix].set_xticklabels(labels, rotation=90)
    ax[ix].set_yticks(range(len(labels)))
    ax[ix].set_yticklabels(labels)
# The plotted values are absolute correlation coefficients, not p-values.
plt.colorbar(im, ax=ax[1], shrink=0.75, label="|Spearman ρ|");

In [None]:
k = "motor"

# sub-002 only: `x` = run 002 rows, `y` = run 001 rows of the epoch-type table.
x = corr_data_split[k].loc[(corr_data_split[k].run == "002") & (corr_data_split[k].subject == "sub-002")]
y = corr_data_split[k].loc[(corr_data_split[k].run == "001") & (corr_data_split[k].subject == "sub-002")]

divnorm = colors.Normalize(vmin=0, vmax=1)
f, ax = plt.subplots(1, 1, figsize=(15, 10), facecolor="white")
# imshow puts matrix rows on the vertical axis, so rows come from `y` and columns from `x` to match
# the tick labels. The indices are matrix positions because the split tables were reset_index'ed.
im = ax.imshow(np.abs(corr_mx_res[k][np.ix_(y.index, x.index)]), norm=divnorm, cmap="YlGnBu")
ax.set_xticks(range(x.shape[0]))
ax.set_xticklabels(x.res_label, rotation=25)
ax.set_yticks(range(y.shape[0]))
ax.set_yticklabels(y.res_label)
plt.colorbar(im, ax=ax, shrink=0.75, label="|Spearman ρ|");

In [None]:
# For every pair of rows: the strongest Spearman correlation over small shifts (shift_cor) and the
# shift at which it occurs, per epoch-type table.
corr_mx_res_shift = {}
shift_mx_res_shift = {}
for datakey in corr_data_split.keys():
    data = corr_data_split[datakey]
    n_rows = data.shape[0]
    # "PC" is 1-based (PC 1-4, as in the "PC n" labels) while pca.components_ rows are 0-based.
    eigvecs = [
        TOTAL_RESULTS[key]["pca_results"][1][pc - 1] for key, pc in zip(data["key"], data["PC"])
    ]
    data_res = np.zeros([n_rows, n_rows])
    data_res_shift = np.zeros([n_rows, n_rows])
    for i in it.product(range(n_rows), repeat=2):
        r, s = shift_cor(eigvecs[i[0]], eigvecs[i[1]])
        data_res[i] = r
        data_res_shift[i] = s
    corr_mx_res_shift[datakey] = data_res
    shift_mx_res_shift[datakey] = data_res_shift

In [None]:
# Best (max |rho|) Spearman correlation over small shifts, one panel per epoch type.
divnorm = colors.Normalize(vmin=0, vmax=1)

f, ax = plt.subplots(1, 2, figsize=(25, 10), facecolor="white")
for ix, key in enumerate(corr_mx_res_shift.keys()):
    labels = corr_data_split[key].res_label
    ax[ix].set_title(key)
    im = ax[ix].imshow(np.abs(corr_mx_res_shift[key]), norm=divnorm, cmap="YlGnBu")
    ax[ix].set_xticks(range(len(labels)))
    ax[ix].set_xticklabels(labels, rotation=90)
    ax[ix].set_yticks(range(len(labels)))
    ax[ix].set_yticklabels(labels)
plt.colorbar(im, ax=ax[1], shrink=0.75, label="max |Spearman ρ| over shifts");

In [None]:
# Difference between the shifted and unshifted |Spearman rho|, one panel per epoch type.
divnorm = colors.TwoSlopeNorm(0, vmin=-0.10, vmax=0.15)

f, ax = plt.subplots(1, 2, figsize=(25, 10), facecolor="white")
for ix, key in enumerate(corr_mx_res_shift.keys()):
    labels = corr_data_split[key].res_label
    ax[ix].set_title(key)
    im = ax[ix].imshow(np.abs(corr_mx_res_shift[key]) - np.abs(corr_mx_res[key]), norm=divnorm, cmap="PiYG")
    ax[ix].set_xticks(range(len(labels)))
    ax[ix].set_xticklabels(labels, rotation=90)
    ax[ix].set_yticks(range(len(labels)))
    ax[ix].set_yticklabels(labels)
plt.colorbar(im, ax=ax[1], shrink=0.75, label="|Spearman ρ| (best shift) - |Spearman ρ| (no shift)");

In [None]:
# Shift (in vector elements) at which |Spearman rho| peaked for every pair of rows, per epoch type.
divnorm = colors.Normalize(vmin=-2, vmax=2)

f, ax = plt.subplots(1, 2, figsize=(25, 10), facecolor="white")
for ix, key in enumerate(corr_mx_res_shift.keys()):
    labels = corr_data_split[key].res_label
    ax[ix].set_title(key)
    im = ax[ix].imshow(shift_mx_res_shift[key], norm=divnorm, cmap="rainbow")
    ax[ix].set_xticks(range(len(labels)))
    ax[ix].set_xticklabels(labels, rotation=90)
    ax[ix].set_yticks(range(len(labels)))
    ax[ix].set_yticklabels(labels)
plt.colorbar(im, ax=ax[1], shrink=0.75, label="shift at max |Spearman ρ|");

In [None]:
k = "motor"

# sub-002 only: `x` = run 002 rows, `y` = run 001 rows of the epoch-type table.
x = corr_data_split[k].loc[(corr_data_split[k].run == "002") & (corr_data_split[k].subject == "sub-002")]
y = corr_data_split[k].loc[(corr_data_split[k].run == "001") & (corr_data_split[k].subject == "sub-002")]

divnorm = colors.Normalize(vmin=0, vmax=1)
f, ax = plt.subplots(1, 1, figsize=(15, 10), facecolor="white")
# imshow puts matrix rows on the vertical axis, so rows come from `y` and columns from `x` to match
# the tick labels. The indices are matrix positions because the split tables were reset_index'ed.
im = ax.imshow(np.abs(corr_mx_res_shift[k][np.ix_(y.index, x.index)]), norm=divnorm, cmap="YlGnBu")
ax.set_xticks(range(x.shape[0]))
ax.set_xticklabels(x.res_label, rotation=25)
ax.set_yticks(range(y.shape[0]))
ax.set_yticklabels(y.res_label)
plt.colorbar(im, ax=ax, shrink=0.75, label="max |Spearman ρ| over shifts");

In [None]:
k = "motor"

# sub-002 only: `x` = run 002 rows, `y` = run 001 rows of the epoch-type table.
x = corr_data_split[k].loc[(corr_data_split[k].run == "002") & (corr_data_split[k].subject == "sub-002")]
y = corr_data_split[k].loc[(corr_data_split[k].run == "001") & (corr_data_split[k].subject == "sub-002")]

divnorm = colors.Normalize(vmin=-2, vmax=2)
f, ax = plt.subplots(1, 1, figsize=(15, 10), facecolor="white")
# imshow puts matrix rows on the vertical axis, so rows come from `y` and columns from `x` to match
# the tick labels. The indices are matrix positions because the split tables were reset_index'ed.
im = ax.imshow(shift_mx_res_shift[k][np.ix_(y.index, x.index)], norm=divnorm, cmap="rainbow")
ax.set_xticks(range(x.shape[0]))
ax.set_xticklabels(x.res_label, rotation=25)
ax.set_yticks(range(y.shape[0]))
ax.set_yticklabels(y.res_label)
plt.colorbar(im, ax=ax, shrink=0.75, label="shift at max |Spearman ρ|");

In [None]:
k = "visual"

# sub-002 only: `x` = run 002 rows, `y` = run 001 rows of the epoch-type table.
x = corr_data_split[k].loc[(corr_data_split[k].run == "002") & (corr_data_split[k].subject == "sub-002")]
y = corr_data_split[k].loc[(corr_data_split[k].run == "001") & (corr_data_split[k].subject == "sub-002")]

divnorm = colors.Normalize(vmin=0, vmax=1)
f, ax = plt.subplots(1, 1, figsize=(15, 10), facecolor="white")
# imshow puts matrix rows on the vertical axis, so rows come from `y` and columns from `x` to match
# the tick labels. The indices are matrix positions because the split tables were reset_index'ed.
im = ax.imshow(np.abs(corr_mx_res_shift[k][np.ix_(y.index, x.index)]), norm=divnorm, cmap="YlGnBu")
ax.set_xticks(range(x.shape[0]))
ax.set_xticklabels(x.res_label, rotation=25)
ax.set_yticks(range(y.shape[0]))
ax.set_yticklabels(y.res_label)
plt.colorbar(im, ax=ax, shrink=0.75, label="max |Spearman ρ| over shifts");

In [None]:
k = "visual"

# sub-002 only: `x` = run 002 rows, `y` = run 001 rows of the epoch-type table.
x = corr_data_split[k].loc[(corr_data_split[k].run == "002") & (corr_data_split[k].subject == "sub-002")]
y = corr_data_split[k].loc[(corr_data_split[k].run == "001") & (corr_data_split[k].subject == "sub-002")]

divnorm = colors.Normalize(vmin=-2, vmax=2)
f, ax = plt.subplots(1, 1, figsize=(15, 10), facecolor="white")
# imshow puts matrix rows on the vertical axis, so rows come from `y` and columns from `x` to match
# the tick labels. The indices are matrix positions because the split tables were reset_index'ed.
im = ax.imshow(shift_mx_res_shift[k][np.ix_(y.index, x.index)], norm=divnorm, cmap="rainbow")
ax.set_xticks(range(x.shape[0]))
ax.set_xticklabels(x.res_label, rotation=25)
ax.set_yticks(range(y.shape[0]))
ax.set_yticklabels(y.res_label)
plt.colorbar(im, ax=ax, shrink=0.75, label="shift at max |Spearman ρ|");