In [32]:
import sys
import json
import pickle
import numpy as np
import pandas as pd
import nibabel as nb
from copy import deepcopy
import matplotlib.pylab as plt
from matplotlib.patches import Patch
from matplotlib.lines import Line2D
from matplotlib import colors
from mne import read_epochs, pick_types
from mne.time_frequency import psd_array_welch
from scipy.interpolate import interp1d
import new_files
import os.path as op
from os import sep
import trimesh
import networkx as nx
import open3d as o3d
from tools import compute_rel_power, get_crossover, detect_crossing_points, data_to_rgb

In [34]:
def comp_np_arr(ar1, ar2):
    """Return True if two arrays have the same shape and identical elements.

    Parameters
    ----------
    ar1, ar2 : array_like
        Arrays to compare.

    Returns
    -------
    bool
        False when the shapes differ; otherwise whether every element pair
        compares equal with ``==`` (so NaN never equals NaN).
    """
    # np.array_equal performs the shape check and the element-wise test in a
    # single call; the original's trailing ``else`` branch was unreachable and
    # the matching-shape branch leaked a ``np.bool_`` instead of a ``bool``.
    return bool(np.array_equal(ar1, ar2))

In [7]:
# Locations: the processed derivatives to search, and the directory the
# result pickles are written to.
dataset_path = "/home/common/bonaiuto/multiburst/derivatives/processed"
img_path = "/scratch/poster_visualisations"

# Every JSON file under dataset_path matching "*.json" and filtered with
# strings=["info"] (presumably names containing "info", one pipeline-info
# file per subject — confirm against new_files.Files.get_files).
dir_search = new_files.Files()
all_jsons = dir_search.get_files(
    dataset_path,
    "*.json",
    strings=["info"],
)

In [8]:
# Per epoch type: [601-sample time grid spanning the epoch, a two-element
# window, a scalar]. The keys are matched against epoch file paths.
# NOTE(review): the meaning of the window and the trailing scalar is not
# established in this notebook — document before relying on them.
epoch_types = dict(
    visual=[np.linspace(-0.2, 0.8, num=601), [0.0, 0.2], -0.01],
    motor=[np.linspace(-0.5, 0.5, num=601), [-0.2, 0.2], -0.2],
)

# (tmin, tmax) passed to Epochs.crop for each epoch type.
crop_info = dict(
    visual=(-0.2, 0.8),
    motor=(-0.5, 0.5),
)

# Frequency limits (fmin, fmax) for the Welch PSD.
flims = [0.1, 125]

In [58]:
# Cluster high-variance vertices and compute their per-layer PSDs.
#
# For each of the first N_SUBJECTS pipeline JSONs and each of its first
# N_FILES (epochs, MU) pairs:
#   1. crop the MEG epochs, average over trials and project the average
#      through every MU block -> src (n_vertices, n_surf, n_times);
#   2. keep vertices whose log10 variance (over time) of the layer-averaged
#      signal is at or above the TOP_PERCENTILE-th percentile and group them
#      into connected clusters on the subject's inflated mesh;
#   3. for every clustered vertex, project each trial through MU and compute a
#      trial-averaged Welch PSD per surface;
#   4. pickle the results into img_path.
# NOTE(review): MU is presumably a per-surface source-reconstruction weight
# matrix with surfaces stacked row-wise — confirm against the pipeline.

N_SUBJECTS = 6        # pipeline JSONs processed (was range(6))
N_FILES = 4           # (epochs, MU) pairs processed per subject (was range(4))
TOP_PERCENTILE = 99   # keep vertices with metric >= this percentile
MIN_CLUSTER_SIZE = 5  # keep clusters with more than 4 vertices
N_FFT = 2000          # FFT length passed to psd_array_welch


def _mu_cache_path(mu_path):
    """Return the ``.npy`` cache path for a tab-separated MU file.

    Only the file name is cut at its first "."; the previous
    ``mu_path.split(".")[0]`` broke on dots in directory names
    (e.g. ``./data/MU.tsv`` or ``/scratch/v1.2/MU.tsv``).
    """
    directory, filename = op.split(mu_path)
    return op.join(directory, filename.split(".")[0] + ".npy")


def _load_mu(mu_path, n_surf):
    """Load MU (caching the TSV as ``.npy`` on first use) and split it row-wise into ``n_surf`` blocks."""
    cache_path = _mu_cache_path(mu_path)
    if op.exists(cache_path):
        mu = np.load(cache_path)
    else:
        mu = pd.read_csv(mu_path, sep="\t", header=None).to_numpy()
        np.save(cache_path, mu)
    return np.split(mu, n_surf, axis=0)


def _decode_labels(labels):
    """Return atlas labels as a unicode array.

    Byte-string labels are decoded as UTF-8. Labels that are already text
    (``np.load`` can return a ``<U`` array) are kept as-is; the previous
    unconditional ``.decode`` raised AttributeError on them.
    """
    return np.array([
        label.decode("utf-8") if isinstance(label, bytes) else str(label)
        for label in labels
    ])


def _mesh_neighbours(mesh_path):
    """Return, for every vertex of the surface at ``mesh_path``, an array of its neighbour indices."""
    vertices, faces = nb.load(mesh_path).agg_data()
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False, validate=False)
    return [np.array(n) for n in mesh.vertex_neighbors]


def _project_average(evoked, mu_blocks):
    """Project (n_channels, n_times) data through every MU block.

    Returns an (n_vertices, n_surf, n_times) array whose ``[v, k]`` row is
    ``mu_blocks[k][v] @ evoked`` — the values the former per-vertex ``np.dot``
    loop produced — filled one block at a time instead of one vertex at a time.
    """
    n_vertices = mu_blocks[0].shape[0]
    src = np.empty(
        (n_vertices, len(mu_blocks), evoked.shape[-1]),
        dtype=np.result_type(mu_blocks[0], evoked),
    )
    for k, block in enumerate(mu_blocks):
        src[:, k, :] = block @ evoked
    return src


def _find_clusters(metric, neighbours, percentile, min_size):
    """Group top-percentile vertices into connected clusters on the mesh.

    A vertex is kept when ``metric >= np.percentile(metric, percentile)``;
    kept vertices are linked when they are mesh neighbours. Returns one sorted
    vertex-index array per connected component with at least ``min_size``
    vertices, ordered by lowest vertex index.
    """
    kept = np.arange(len(neighbours))[metric >= np.percentile(metric, percentile)]
    adjacency = {}
    for vx in kept:
        linked = np.intersect1d(neighbours[vx], kept)
        if linked.size > 0:
            adjacency[vx] = linked
    graph = nx.from_dict_of_lists(adjacency)
    # connected_components yields each component exactly once, in
    # node-insertion (ascending vertex) order. This replaces the per-node
    # component search followed by a quadratic de-duplication.
    clusters = [np.sort(np.array(list(c))) for c in nx.connected_components(graph)]
    return [c for c in clusters if c.shape[0] >= min_size]


def _trial_psd(trial_data, sfreq):
    """Welch PSD of (n_trials, n_surf, n_times) data, averaged over trials.

    Uses Hann windows of ``int(sfreq)`` samples (1 s) with 50 % overlap,
    ``N_FFT`` FFT points and the ``flims`` frequency range.
    Returns ``(mean_psd, freqs)``.
    """
    winsize = int(sfreq)
    psd, freqs = psd_array_welch(
        trial_data, sfreq, fmin=flims[0], fmax=flims[1], n_fft=N_FFT,
        n_overlap=winsize // 2, n_per_seg=winsize,
        window="hann", verbose=False, n_jobs=1,
    )
    return np.nanmean(psd, axis=0), freqs


if not op.isdir(img_path):
    # Fail before any expensive processing instead of at the first pickle.dump.
    raise FileNotFoundError("Output directory does not exist: {}".format(img_path))

for json_file in all_jsons[:N_SUBJECTS]:
    with open(json_file) as pipeline_file:
        info = json.load(pipeline_file)

    # NOTE(review): cortical_thickness, atlas_colours and atlas are loaded but
    # not used in this cell.
    cortical_thickness = np.load(info["cortical_thickness_path"])
    atlas_labels = np.load(info["atlas_labels_path"])
    atlas_colours = np.load(info["atlas_colors_path"])
    atlas = pd.read_csv(info["atlas"])
    decoded_labels = _decode_labels(atlas_labels)

    # The mesh depends only on the subject: build its adjacency once here
    # instead of once per epochs file.
    vx_neighbours = _mesh_neighbours(info["pial_ds_nodeep_inflated"])

    fif_mu = list(zip(info["sensor_epochs_paths"], info["MU_paths"]))
    for fif_path, mu_path in fif_mu[:N_FILES]:
        core_name = fif_path.split(sep)[-1].split("_")[-1].split(".")[0]
        epo_type = [k for k in crop_info if k in fif_path][0]
        tmin, tmax = crop_info[epo_type]

        epochs = read_epochs(fif_path, verbose=False)
        epochs = epochs.pick_types(meg=True, ref_meg=False, misc=False, verbose=False)
        # NOTE(review): per-trial data for the PSDs is taken BEFORE cropping,
        # so Welch sees the full epoch length — confirm this is intended.
        fif_all = epochs.get_data()
        epochs = epochs.crop(tmin=tmin, tmax=tmax)
        sfreq = epochs.info["sfreq"]
        evoked = np.mean(epochs.get_data(), axis=0)  # trial average (no split on conditions)

        mu_blocks = _load_mu(mu_path, info["n_surf"])
        src = _project_average(evoked, mu_blocks)

        # Per-vertex log10 variance over time of the surface-averaged signal.
        metric = np.log10(np.var(np.mean(src, axis=1), axis=1))
        unique_clusters = _find_clusters(metric, vx_neighbours, TOP_PERCENTILE, MIN_CLUSTER_SIZE)
        # np.hstack([]) raises, so handle "no cluster survived" explicitly.
        if unique_clusters:
            cluster_vertices = np.hstack(unique_clusters)
        else:
            cluster_vertices = np.array([], dtype=int)

        data = {
            "json": json_file,
            # Previously referenced the undefined name `log_vpv` (NameError).
            "log_variance": metric,
            "atlas_labels": decoded_labels,
            "clusters": unique_clusters,
        }

        vertex_results = {}
        for vx in cluster_vertices:
            # (n_surf, n_channels) weights of this vertex applied to every
            # trial -> (n_trials, n_surf, n_times). All MU blocks are used so
            # the rows match `src`; the former range(info["layers"]) could
            # disagree with the n_surf split.
            weights = np.array([block[vx] for block in mu_blocks])
            vertex_source = weights @ fif_all
            mean_psd, freqs = _trial_psd(vertex_source, sfreq)

            vertex_results[(vx, "signal")] = src[vx]
            vertex_results[(vx, "mean_psd")] = mean_psd
            vertex_results["freqs"] = freqs
        data["vertex_mean_signal_psd"] = vertex_results

        # NOTE(review): the file name depends only on core_name; confirm it is
        # unique across subjects, otherwise later pickles overwrite earlier ones.
        out_path = op.join(img_path, "raw_preproc_{}.pickle".format(core_name))
        with open(out_path, "wb") as handle:
            pickle.dump(data, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [57]:
# Inspect the atlas labels left in the namespace by the last subject processed
# in the loop cell above (only defined after that cell has run).
atlas_labels

array(['L_A4_ROI', 'L_A4_ROI', 'L_A4_ROI', ..., 'R_PFop_ROI',
       'R_PFop_ROI', 'R_PFop_ROI'], dtype='<U12')