In [193]:
import sys
import json
import pickle
import numpy as np
import pandas as pd
import nibabel as nb
from copy import deepcopy
import matplotlib.pylab as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from matplotlib import colors
from mne import read_epochs, pick_types
from mne.time_frequency import psd_array_welch
from scipy.interpolate import interp1d
import new_files
import os.path as op
from os import sep
import trimesh
import open3d as o3d
from tools import compute_rel_power, get_crossover, detect_crossing_points, data_to_rgb
from ffntr import fooofinator

In [242]:
def custom_draw_geometry(mesh, filename="render.png", visible=True, wh=(960, 960), save=True):
    """Show one or more Open3D geometries in a viewer window and optionally save a screenshot.

    Parameters
    ----------
    mesh : geometry or list of geometries
        A single object accepted by ``Visualizer.add_geometry`` or a list of them.
    filename : str
        Path handed to ``capture_screen_image`` when ``save`` is true.
    visible : bool
        Forwarded to ``create_window``.
    wh : sequence of two ints
        Window ``(width, height)``. The default is an immutable tuple so it
        cannot be modified between calls; lists are still accepted.
    save : bool
        When true, a screenshot is captured after ``vis.run()`` returns.
    """
    geometries = mesh if isinstance(mesh, list) else [mesh]

    vis = o3d.visualization.Visualizer()
    vis.create_window(width=wh[0], height=wh[1], visible=visible)
    try:
        for geometry in geometries:
            vis.add_geometry(geometry)
        vis.get_render_option().mesh_show_back_face = True
        vis.run()
        if save:
            vis.capture_screen_image(filename, do_render=True)
    finally:
        # Always release the window, even if adding geometry or rendering failed.
        vis.destroy_window()


def csd_compute(surf_tcs, spacing):
    """Second difference of ``surf_tcs`` along axis 0 with a fixed step of two rows.

    The first and last rows are each repeated twice before differencing, so the
    result has the same shape as the input. Every entry is

        (x[z + 2] - 2 * x[z] + x[z - 2]) / (2 * spacing) ** 2

    evaluated on the padded array.

    Parameters
    ----------
    surf_tcs : 2-D array
        Rows are differenced against each other; columns are treated independently.
    spacing : float
        Distance between neighbouring rows.

    Returns
    -------
    numpy.ndarray
        Float array with the shape of ``surf_tcs``.
    """
    step = 2
    top, bottom = surf_tcs[0, :], surf_tcs[-1, :]
    padded = np.vstack([top, top, surf_tcs, bottom, bottom])

    rows = surf_tcs.shape[0]
    upper = padded[2 + step:2 + step + rows]
    centre = padded[2:2 + rows]
    lower = padded[2 - step:2 - step + rows]

    csd = np.zeros((surf_tcs.shape[0], surf_tcs.shape[1]))
    csd[...] = (upper - 2 * centre + lower) / ((step * spacing) ** 2)
    return csd


def csd_compute_upd(surf_tcs, spacing, nd=2):
    # Compute CSD

    ex_surf_tcs = np.vstack([surf_tcs[0, :], surf_tcs[0, :], surf_tcs, surf_tcs[-1, :], surf_tcs[-1, :]])

    csd = np.zeros((surf_tcs.shape[0], surf_tcs.shape[1]))
    for t in range(surf_tcs.shape[1]):
        phi = ex_surf_tcs[:, t]
        for z in range(surf_tcs.shape[0]):
            csd[z, t] = (phi[(z + 2) + nd] - 2 * phi[z+nd] + phi[(z - 2) + nd]) / ((nd * spacing) ** 2)
        
    return csd


def csd_smooth(csd, layers=11):
    """Resample every column of ``csd`` onto 500 evenly spaced points along axis 0.

    Each column is treated as samples on ``linspace(0, 1, n_rows)`` and is
    interpolated with a cubic ``interp1d`` onto ``linspace(0, 1, 500)``.

    Parameters
    ----------
    csd : 2-D array
        Array whose first axis is resampled.
    layers : int
        Kept for call compatibility; the value is replaced by ``csd.shape[0]``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(500, csd.shape[1])``.
    """
    layers, time = csd.shape
    depth_in = np.linspace(0, 1, num=layers)
    depth_out = np.linspace(0, 1, num=500)
    columns = [
        interp1d(depth_in, csd[:, col], kind="cubic")(depth_out)
        for col in range(time)
    ]
    return np.array(columns).T


def plot_csd(smooth_csd, list_ROI_vertices, bb_path, times, ax, cb=True, cmap="RdBu_r", vmin_vmax=None, return_details=False):
    """Draw ``smooth_csd`` as an image on ``ax`` and overlay layer boundaries.

    Parameters
    ----------
    smooth_csd : 2-D array
        Image data; shown with ``imshow`` over x = ``times[0]..times[-1]`` and
        a y axis that runs from 0 at the top to 1 at the bottom.
    list_ROI_vertices : index array
        Indices used to select entries from each per-layer array in the JSON file.
    bb_path : str
        Path of a JSON file holding one array per key. For every key the
        selected entries are averaged; the cumulative sum of those means gives
        the y positions of the dashed boundary lines.
        (presumably fractional layer thickness per vertex - TODO confirm)
    times : sequence
        Only the first and last values are used (image extent, label x position).
    ax : matplotlib Axes
        Target axes.
    cb : bool
        Attach a colorbar to ``ax``.
    cmap : str
        Colormap name passed to ``imshow``.
    vmin_vmax : None, "norm" or pair
        ``None``: diverging norm symmetric around 0 at +/- max(|data|);
        ``"norm"``: plain ``Normalize()``;
        otherwise ``(vmin, vmax)`` for a diverging norm centred on 0.
    return_details : bool
        When true, return ``(layers_params, image)``; otherwise return ``None``.

    Returns
    -------
    tuple or None
        ``([[y, label], ...], AxesImage)`` if ``return_details`` is true.
    """
    layer_labels = ["I", "II", "III", "IV", "V", "VI"]
    with open(bb_path, "r") as fp:
        bb = json.load(fp)
    bb_mean = [np.mean(np.array(layer)[list_ROI_vertices]) for layer in bb.values()]

    # `is None` / isinstance keep the checks scalar even when vmin_vmax is an array.
    if vmin_vmax is None:
        max_smooth = np.max(np.abs(smooth_csd))
        if max_smooth == 0:
            # TwoSlopeNorm needs vmin < vcenter < vmax; all-zero data would raise.
            max_smooth = 1.0
        divnorm = colors.TwoSlopeNorm(vmin=-max_smooth, vcenter=0, vmax=max_smooth)
    elif isinstance(vmin_vmax, str) and vmin_vmax == "norm":
        divnorm = colors.Normalize()
    else:
        divnorm = colors.TwoSlopeNorm(vmin=vmin_vmax[0], vcenter=0, vmax=vmin_vmax[1])

    extent = [times[0], times[-1], 1, 0]
    csd_imshow = ax.imshow(
        smooth_csd, norm=divnorm, origin="lower",
        aspect="auto", extent=extent,
        cmap=cmap, interpolation="none"
    )
    ax.set_ylim(1, 0)

    layers_params = []
    for l_ix, th in enumerate(np.cumsum(bb_mean)):
        ax.axhline(th, linestyle=(0, (5, 5)), c="black", lw=0.5)
        ax.annotate(layer_labels[l_ix], [times[0] + 0.06, th - 0.01], size=15, ha='left')
        layers_params.append([th, layer_labels[l_ix]])
    if cb:
        plt.colorbar(csd_imshow, ax=ax)
    plt.tight_layout()
    if return_details:
        return layers_params, csd_imshow

def plot_spaced_signal(signal, times, ax=None):
    """Plot each row of ``signal`` as a vertically offset, colour-coded trace.

    Rows are shifted by evenly spaced offsets between 0 and
    ``1.05 * max(|signal|) * n_rows`` and coloured with the "rainbow" colormap.
    Y ticks sit on the offsets; they are labelled ``n_rows`` (row 0, lowest
    offset) down to ``1`` (last row, highest offset).

    Parameters
    ----------
    signal : 2-D array
        One trace per row.
    times : sequence
        X values shared by all traces.
    ax : matplotlib Axes, optional
        Target axes. Defaults to the current axes; the previous default
        (``ax=ax``) was evaluated when the function was defined and needed a
        global ``ax`` to exist already.
    """
    if ax is None:
        ax = plt.gca()
    v_contacts = signal.shape[0]
    cm_l = plt.colormaps["rainbow"](np.linspace(0, 1, num=v_contacts))[:, :3]
    max_amp = np.abs(signal).max()
    max_amp = max_amp + max_amp * 0.05  # 5 % headroom between neighbouring traces
    spacing = np.linspace(0, max_amp * v_contacts, num=v_contacts)
    vis_signal = signal + spacing.reshape(-1, 1)
    for ix, vs in enumerate(vis_signal):
        ax.plot(times, vs, lw=1, c=cm_l[ix])
    ax.set_yticks(spacing)
    # One label per trace (was hard-coded to 11 contacts).
    ax.set_yticklabels(np.arange(1, v_contacts + 1)[::-1])
    ax.set_ylim(spacing[0] - max_amp, spacing[-1] + max_amp)
    

In [3]:
# File-search helper from the project-local `new_files` module.
dir_search = new_files.Files()
# Root directory that is searched for the per-subject pipeline JSON files.
dataset_path = "/home/common/bonaiuto/multiburst/derivatives/processed"
# Directory that receives the SVG figures and pickled results written further down.
img_path = "/scratch/poster_visualisations"
# JSON files found under dataset_path; one of them is selected by index later.
# NOTE(review): strings=["info"] presumably keeps only paths containing "info" -
# confirm against new_files.Files.get_files.
all_jsons = dir_search.get_files(dataset_path,"*.json", strings=["info"])

In [4]:
# Atlas label name -> human-readable area name. Several labels share one area
# name; the ROI-vertex code further down groups labels by that name and pools
# their vertices.
ROIS = {
    "L_4_ROI": "Left Primary Motor Cortex",
    "L_6ma_ROI": "Left Supplementary Motor Area",
    "L_6mp_ROI": "Left Supplementary Motor Area",
    "L_6d_ROI": "Left Pre-Motor Area",
    "L_6v_ROI": "Left Pre-Motor Area",
    "L_6r_ROI": "Left Pre-Motor Area",
    "L_6a_ROI": "Left Pre-Motor Area",
    "L_LIPv_ROI": "Left Intraparietal Area", 
    "L_VIP_ROI": "Left Intraparietal Area", 
    "L_MIP_ROI": "Left Intraparietal Area",
    "L_V1_ROI": "Left Primary Visual Cortex",
    "L_MT_ROI": "Left V5-MT Visual Cortex",
    "R_4_ROI": "Right Primary Motor Cortex",
    "R_6ma_ROI": "Right Supplementary Motor Area",
    "R_6mp_ROI": "Right Supplementary Motor Area",
    "R_6d_ROI": "Right Pre-Motor Area",
    "R_6v_ROI": "Right Pre-Motor Area",
    "R_6r_ROI": "Right Pre-Motor Area",
    "R_6a_ROI": "Right Pre-Motor Area",
    "R_LIPv_ROI": "Right Intraparietal Area", 
    "R_VIP_ROI": "Right Intraparietal Area", 
    "R_MIP_ROI": "Right Intraparietal Area",
    "R_V1_ROI": "Right Primary Visual Cortex",
    "R_MT_ROI": "Right V5-MT Visual Cortex",
}

# Per epoch type (matched as a substring of the epochs file path):
#   [0] time axis used for the x axis of the time-domain plots and for the
#       baseline mask,
#   [1] a two-element window (not referenced by the code visible in this notebook),
#   [2] baseline limit: samples whose time is below this value are averaged and
#       subtracted from the smoothed CSD.
epoch_types = {
    "visual": [np.linspace(-0.2, 0.8, num=601), [0.0, 0.2], -0.01],
    "motor": [np.linspace(-0.5, 0.5, num=601), [-0.2, 0.2], -0.2]
}

# (tmin, tmax) handed to the epochs crop() call for each epoch type.
crop_info = {
    "visual": (-0.2, 0.8),
    "motor": (-0.5, 0.5)
}

# [fmin, fmax] passed to the Welch PSD and to the FOOOFinator fit.
flims = [0.1,125] # freq limits for psd

In [250]:
# Index into all_jsons selecting which pipeline JSON is processed.
subject = 5

json_file = all_jsons[subject]
with open(json_file) as pipeline_file:
    info = json.load(pipeline_file)
cortical_thickness = np.load(info["cortical_thickness_path"])
atlas_labels = np.load(info["atlas_labels_path"])
atlas_colours = np.load(info["atlas_colors_path"])
atlas = pd.read_csv(info["atlas"])

# Pairs of (epochs file path, matching MU matrix path).
fif_mu = list(zip(info["sensor_epochs_paths"], info["MU_paths"]))

# ROI VERTICES
vertex_num = np.arange(atlas_labels.shape[0])
roi_k = np.array(list(ROIS.keys()))
roi_v = np.array(list(ROIS.values()))
# Area name -> list of atlas label names that belong to it.
ROI_subfields = {i : list(roi_k[roi_v == i]) for i in np.unique(roi_v)}
# Decode the per-vertex byte labels once (previously every label was decoded
# again for each ROI key, and the codec name was misspelled as "utf=8").
atlas_label_names = np.array([al.decode("utf-8") for al in atlas_labels], dtype=str)
# Atlas label name -> indices of the vertices carrying that label.
subfields_vertices = {i: vertex_num[atlas_label_names == i] for i in roi_k}
# Area name -> vertex indices pooled over all of its labels.
ROI_vertices = {i: np.hstack([subfields_vertices[j] for j in ROI_subfields[i]]) for i in ROI_subfields.keys()}

# Number of virtual contacts per vertex, read from the pipeline JSON so that
# splits, colours and tick labels follow the data instead of a literal 11.
n_layers = info["layers"]
# One RGB colour per contact, sampled from "rainbow" the same way
# plot_spaced_signal does internally. It must be defined here: the cm_l inside
# that function is local and is not visible at cell scope.
cm_l = plt.colormaps["rainbow"](np.linspace(0, 1, num=n_layers))[:, :3]
# Frequencies marked with vertical lines on every spectral panel.
band_edges = (7, 30, 50)
# Y tick positions / labels for panels whose y axis spans 0..1 over the contacts.
contact_ticks = np.linspace(0, 1, num=n_layers)
contact_labels = np.arange(1, n_layers + 1)

# Process at most the first four (epochs, MU) pairs; slicing avoids an
# IndexError when fewer than four files are listed.
for file_no, (fif, MU) in enumerate(fif_mu[:4]):
    # SUBJ NAMES
    # `fif` is still the epochs file path here. The epoch type is the known key
    # that occurs in the path; the subject label is the third-from-last path
    # component (this line used an undefined name `epochs` before).
    epoch_type = [i for i in epoch_types.keys() if i in fif][0]
    epo_type = [i for i in crop_info.keys() if i in fif][0]
    subject = fif.split(sep)[-3]
    epoch_times = epoch_types[epo_type][0]

    # CSD CALC with FIF MU
    core_name = fif.split(sep)[-1].split("_")[-1].split(".")[0]
    epochs = read_epochs(fif, verbose=False)
    epochs = epochs.pick_types(meg=True, ref_meg=False, misc=False, verbose=False)
    epochs = epochs.crop(tmin=crop_info[epo_type][0], tmax=crop_info[epo_type][1])
    sfreq = epochs.info["sfreq"]
    fif_times = epochs.times
    # All single trials; nothing below modifies this array, so no extra copy is made.
    fif_all = epochs.get_data()
    del epochs
    fif = np.mean(fif_all, axis=0) # if no split on conditions

    # The MU matrix is cached as .npy next to the original tab-separated file.
    new_mu_path = MU.split(".")[0] + ".npy"
    if op.exists(new_mu_path):
        MU = np.load(new_mu_path)
    else:
        MU = pd.read_csv(MU, sep="\t", header=None).to_numpy()
        np.save(new_mu_path, MU)

    # Split the stacked matrix into one block per surface, then project the
    # trial-averaged sensor data through every vertex's rows.
    MU = np.split(MU, info["n_surf"], axis=0)
    layer_shape = MU[0].shape[0]
    src = []
    for i in range(layer_shape):
        vertex_layers = np.array([mx[i] for mx in MU])
        vertex_source = np.dot(vertex_layers, fif)
        src.append(vertex_source)
    src = np.array(src)
    # Per vertex: log10 of the variance over time of the mean across surfaces.
    log_vpv = np.log10(np.var(np.mean(src, axis=1), axis=1))

    n_bins=100
    data_in = log_vpv
    datacolors, mappable = data_to_rgb(
        data_in, n_bins, "pink_r",
        np.percentile(data_in, 5),
        np.percentile(data_in, 99.5),
        vcenter=np.percentile(data_in, 50),
        ret_map=True, 

    )

    # Optional interactive inspection of the variance map (off by default).
    viz = False
    if viz:
        # dist
        f, ax = plt.subplots(figsize=(10,3))
        hist, bins, barlist = ax.hist(data_in, bins=n_bins, edgecolor='black', linewidth=0.2)
        for _bin_ix, _bin in enumerate(barlist):
            plt.setp(_bin, "facecolor", mappable.to_rgba(bins[_bin_ix+1]))
        ax.set_title(core_name)
        ax.set_xlabel("log10(variance)")

        # full brain
        brain = nb.load(info["pial_ds_nodeep_inflated"])
        vertices, faces = brain.agg_data()
        mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False, validate=False)
        mesh = mesh.as_open3d
        mesh.compute_vertex_normals(normalized=True)
        mesh.vertex_colors = o3d.utility.Vector3dVector(datacolors[:,:3])
        custom_draw_geometry(mesh, save=False)

    ROI_data = {}
    ROI_data["json"] = json_file
    ROI_data["log_variance"] = log_vpv

    # Welch settings depend only on the sampling rate, so compute them once per file.
    winsize = int(sfreq)
    overlap = int(winsize/2)

    for k in ROI_vertices:
        roi_vx = ROI_vertices[k]
        # Use the ROI vertex with the largest log-variance.
        max_vx = roi_vx[np.argmax(log_vpv[roi_vx])]
        signal = src[max_vx]
        # NOTE(review): spacing is thickness / 10 * 1e3 - presumably the distance
        # between neighbouring contacts after a unit conversion; confirm.
        gradient = csd_compute_upd(signal, (cortical_thickness[max_vx]/10)*1e3, nd=2)

        # Single-trial source time courses at the chosen vertex, one row per contact.
        vertex_source = np.array([
            [np.dot(trial.T, MU[i][max_vx]) for i in range(n_layers)]
            for trial in fif_all
        ])
        psd, freqs = psd_array_welch(
            vertex_source, sfreq, fmin=flims[0], 
            fmax=flims[1], n_fft=2000, 
            n_overlap=overlap, n_per_seg=winsize,
            window="hann", verbose=False, n_jobs=1
        )
        mean_psd = np.nanmean(psd,axis=0)
        mean_split = np.vsplit(mean_psd, n_layers)
        layer_power = []
        layer_periodic = []
        layer_aperiodic = []
        for i in mean_split:
            ff = fooofinator.FOOOFinator()
            PsD = i.flatten()
            ff.fit(freqs, PsD, flims, n_jobs=-1)
            log_psd = np.log10(PsD)
            layer_power.append(log_psd)
            # Periodic part = log power minus the fitted aperiodic component.
            layer_periodic.append(log_psd - ff._ap_fit)
            layer_aperiodic.append(ff._ap_fit)
            del ff
        layer_power = np.array(layer_power)
        layer_periodic = np.array(layer_periodic)
        layer_aperiodic = np.array(layer_aperiodic)    

        relative_total = compute_rel_power(layer_power, freqs)
        relative_periodic = compute_rel_power(layer_periodic, freqs)
        relative_aperiodic = compute_rel_power(layer_aperiodic, freqs)
        smooth_rel_total = csd_smooth(relative_total)
        smooth_rel_per = csd_smooth(relative_periodic)
        smooth_rel_aper = csd_smooth(relative_aperiodic)
        # Best effort: fall back to 0 when no crossover can be determined, but
        # do not swallow KeyboardInterrupt/SystemExit and say what happened.
        try:
            crossover = get_crossover(freqs, smooth_rel_per, smooth_rel_aper)
        except Exception as err:
            print("{} {}: crossover not found ({}); using 0".format(core_name, k, err))
            crossover = 0

        ROI_data[(k, "ROI_vertex")] = max_vx
        ROI_data[(k, "signal")] = signal
        ROI_data[(k, "gradient")] = gradient
        ROI_data[(k, "layer_PSD")] = layer_power
        ROI_data[(k, "layer_periodic")] = layer_periodic
        ROI_data[(k, "layer_aperiodic")] = layer_aperiodic

        f, ax = plt.subplots(3,4, figsize=(14, 10))

        # --- Row 0: time domain -------------------------------------------
        plot_spaced_signal(signal, epoch_times, ax[0,0])
        plot_spaced_signal(gradient, epoch_times, ax[0,1])
        CSD_res = csd_smooth(gradient)
        baseline_lim = epoch_types[epo_type][2]
        # Mean over the samples before baseline_lim, one value per row.
        baseline = np.mean(CSD_res[:, epoch_times < baseline_lim], axis=1, keepdims=True)
        CSD_res = CSD_res - baseline
        CSD_res = (CSD_res - CSD_res.mean()) / CSD_res.std()
        im = plot_csd(
            CSD_res, roi_vx,
            info["big_brain_layers_path"], 
            epoch_times, ax=ax[0,2],
            return_details=True, cb=False
        )
        plt.colorbar(im[1], ax=ax[0,3], label="Z-score")

        ax[0,2].set_yticks(contact_ticks)
        ax[0,2].set_yticklabels(contact_labels)

        # --- Row 1: per-contact spectra -------------------------------------
        spectra_panels = (
            (ax[1,0], layer_power),
            (ax[1,1], layer_periodic),
            (ax[1,2], layer_aperiodic),
        )
        for panel, spectra in spectra_panels:
            for ix in range(n_layers):
                panel.plot(freqs, spectra[ix], c=cm_l[ix], lw=1.5)
            for f_edge in band_edges:
                panel.axvline(f_edge, lw=0.5, c="black")
            panel.set_xlim(0, 125)

        # --- Row 2: relative power images -----------------------------------
        rel_panels = (
            (ax[2,0], smooth_rel_total),
            (ax[2,1], smooth_rel_per),
            (ax[2,2], smooth_rel_aper),
        )
        for panel, rel_power in rel_panels:
            # After the loop layer_deets belongs to the aperiodic panel, which
            # is the image the shared colorbar below is built from.
            layer_deets = plot_csd(
                rel_power, roi_vx,
                info["big_brain_layers_path"], 
                freqs, ax=panel, cmap="YlGnBu",
                vmin_vmax="norm", cb=False, return_details=True
            )
        plt.colorbar(layer_deets[1], ax=ax[1,3], label="Relative PSD (cubic interp.)")

        handles = [
            Line2D([], [], color="blue", label="Aperiodic Gamma"),
            Line2D([], [], color="green", label="Periodic Alpha-beta"),
            Line2D([], [], color="red", linestyle="dashed", label="Crossover")
        ]
        ax[1,3].legend(handles=handles, loc="center", frameon=False)

        yl=ax[2,0].get_ylim()
        # Crossover position rescaled to the 0..1 y axis of the smoothed images.
        crossover_depth = crossover/smooth_rel_total.shape[0]
        for panel, _ in rel_panels:
            panel.set_yticks(contact_ticks)
            panel.set_yticklabels(contact_labels)
            for f_edge in band_edges:
                panel.plot([f_edge, f_edge], yl, ':', lw=0.5, c="black")
            panel.axhline(crossover_depth, linestyle=(0, (5,5)), c="red", lw=1)
            panel.set_xlim(0, 125)

        # --- Crossover panel: band-averaged relative power along the y axis --
        ab_idx = np.where((freqs >= 7) & (freqs <= 30))[0]
        g_idx = np.where((freqs >= 50) & (freqs <= 125))[0]

        ab_rel_pow = np.mean(smooth_rel_per[:, ab_idx], axis=1)
        g_rel_pow = np.mean(smooth_rel_aper[:, g_idx], axis=1)

        depth_axis = np.linspace(0,1,smooth_rel_total.shape[0])
        ax[2,3].plot(ab_rel_pow, depth_axis, label='Periodic alpha-beta', c="green")
        ax[2,3].plot(g_rel_pow, depth_axis, label='Aperiodic gamma', c="blue")

        ax[2,3].set_ylim(1, 0)
        ax[2,3].set_yticks(contact_ticks)
        ax[2,3].set_yticklabels(contact_labels)
        for th, lab in layer_deets[0]:
            ax[2,3].axhline(th, linestyle=(0, (5,5)), c="black", lw=0.5)
            ax[2,3].annotate(lab,[0.01, th-0.01],size=15)
        ax[2,3].axhline(crossover_depth, linestyle=(0, (5,5)), c="red", lw=1, label="Crossover")

        # --- Titles and axis labels (set once; earlier duplicates removed) ---
        ax[0,0].set_title("Source Localized Signal")
        ax[0,1].set_title("Second spatial\nderivative of the signal")
        ax[0,2].set_title("Current Source Density\n(cubic interpolation)")
        ax[0,0].set_xlabel("Time [s]")
        ax[0,0].set_ylabel("Virtual electrode contact\nDipole moment [nAm]")
        ax[0,1].set_xlabel("Time [s]")
        ax[0,2].set_xlabel("Time [s]")

        ax[1,0].set_title("Power Spectrum Density (PSD)\n(per layer)")
        ax[1,1].set_title("Periodic PSD per layer\n(PSD with aperiodic\nspectrum subtracted)")
        ax[1,2].set_title("Aperiodic PSD per layer\n(1/f fit to the PSD)")
        # log base 10 is a subscript (the label previously rendered it as an exponent).
        ax[1,0].set_ylabel("Power Spectrum Density\n$log_{10}(nAm^{2}/Hz)$")

        ax[2,0].set_title("Relative PSD")
        ax[2,1].set_title("Relative Periodic")
        ax[2,2].set_title("Relative Aperiodic")
        ax[2,3].set_title("Crossover")
        ax[2,0].set_ylabel("Virtual electrode contact\nBig Brain layer thickness")

        ax[0,3].axis("off")
        ax[1,3].axis("off")

        f.suptitle(k + " " + core_name)
        f.tight_layout()
        filename = "results_csd_power_{}_{}.svg".format(core_name, k)
        out_path = op.join(img_path, filename)
        f.savefig(out_path)
        # Close this specific figure so figures do not pile up across ROIs.
        plt.close(f)

    # One pickle per input file with everything collected for its ROIs.
    filename = "data_csd_power_{}.pickle".format(core_name)
    out_path = op.join(img_path, filename)
    with open(out_path, "wb") as handle:
        pickle.dump(ROI_data, handle, protocol=pickle.HIGHEST_PROTOCOL)