In [1]:
import json
import pickle
import numpy as np
import pandas as pd
import os.path as op
import itertools as it
from utilities import files
from mne import read_epochs
import matplotlib.pylab as plt
from joblib import Parallel, delayed
from tqdm.notebook import trange, tqdm
from sklearn.preprocessing import RobustScaler
from sklearn.decomposition import PCA
from scipy.ndimage import gaussian_filter, gaussian_filter1d
from extra.tools import many_is_in, cat, shuffle_array, shuffle_array_range, consecutive_margin_ix, dump_the_dict

In [2]:
# Every input lives under one derivatives root; change `derivatives_dir` (and the
# subject-specific file names) to run this notebook on another machine/subject.
derivatives_dir = "/home/mszul/datasets/explicit_implicit_beta/derivatives"
epochs_dir = op.join(derivatives_dir, "processed", "sub-145")
pca_results_dir = op.join(derivatives_dir, "PCA_results")

# The epoch files are only opened to recover the visual/motor epoch time axes.
visual_epoch_file = op.join(epochs_dir, "sub-145-002-visual-epo.fif")
motor_epoch_file = op.join(epochs_dir, "sub-145-002-motor-epo.fif")
burst_features_file = op.join(pca_results_dir, "burst_features.csv")
waveform_array_file = op.join(pca_results_dir, "all_waveforms.npy")

In [3]:
# Only the time axes of the two epoch types are needed downstream. preload=False
# reads the file's metadata without loading the signal data into memory
# (read_epochs preloads by default), and no Epochs object is kept around.
visual_epoch_times = read_epochs(visual_epoch_file, preload=False, verbose=False).times
motor_epoch_times = read_epochs(motor_epoch_file, preload=False, verbose=False).times

In [9]:
# Time-binning parameters, in the same units as the MNE epoch `times` vectors.
buffer = 0.125                  # trimmed from both ends of each epoch
bin_width = 0.05                # width of one time bin
baseline_range = [-0.5, -0.25]  # window used for baseline correction further down

# Bin edges per epoch type, keyed like the burst table's `epoch` column.
time_bins = {
    epoch_key: np.arange(epoch_times[0] + buffer, epoch_times[-1] - buffer, bin_width)
    for epoch_key, epoch_times in (("vis", visual_epoch_times), ("mot", motor_epoch_times))
}
visual_time_bins = time_bins["vis"]
motor_time_bins = time_bins["mot"]

# (left, right) edge pairs of every bin, and the left edges for plotting.
visual_bin_ranges = [(left, right) for left, right in zip(visual_time_bins[:-1], visual_time_bins[1:])]
motor_bin_ranges = [(left, right) for left, right in zip(motor_time_bins[:-1], motor_time_bins[1:])]
vis_time_plot = time_bins["vis"][:-1]
mot_time_plot = time_bins["mot"][:-1]

In [5]:
# Burst feature table (its columns are listed in the next cell).
burst_features = pd.read_csv(filepath_or_buffer=burst_features_file)

In [6]:
# Display the table schema: burst descriptors followed by the PC_1..PC_20 score columns.
burst_features.columns

Index(['subject', 'epoch', 'peak_time', 'peak_freq', 'peak_amp_base',
       'fwhm_freq', 'fwhm_time', 'trial', 'pp_ix', 'block', 'PC_1', 'PC_2',
       'PC_3', 'PC_4', 'PC_5', 'PC_6', 'PC_7', 'PC_8', 'PC_9', 'PC_10',
       'PC_11', 'PC_12', 'PC_13', 'PC_14', 'PC_15', 'PC_16', 'PC_17', 'PC_18',
       'PC_19', 'PC_20'],
      dtype='object')

In [7]:
# Subset of waveform components examined in detail below: PC_7 .. PC_10.
PC_to_analyse = ["PC_{}".format(n) for n in range(7, 11)]

In [None]:
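# Disabled: mean waveform within each PC-score quartile, for every PC in PC_to_analyse.
# Re-enabling it requires `waveform_array` (presumably `np.load(waveform_array_file)`),
# which is not defined in the cells above.
# prct = np.linspace(0,100, num=5)
# prct_ranges = list(zip(prct[:-1], prct[1:]))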

# wvfrms = {
#     k: [] for k in PC_to_analyse
# }

# for pc_ix, pc_key in enumerate(PC_to_analyse):
#     for low, hi in prct_ranges:
#         low_perc = np.percentile(burst_features[pc_key], low)
#         hi_perc = np.percentile(burst_features[pc_key], hi)
#         wvf_ixs = burst_features.loc[
#             (burst_features[pc_key] >= low_perc) &
#             (burst_features[pc_key] <= hi_perc) 
#         ].index
#         MWF = np.mean(waveform_array[wvf_ixs, :], axis=0)
#         wvfrms[pc_key].append(MWF)

In [None]:
###########################################################################################
#                                  NOT PARALLELIZED                                       #
###########################################################################################

# Serial computation of per-subject (time x PC score) burst-rate maps.
# NOTE(review): this cell writes the same cache file as the parallel cell below, but
# it bins PC scores differently (pooled 0.5-99.5th percentiles, no outlier filter).
# Whichever cell runs first decides what ends up cached.
PC_burst_rate_spec_file = "/home/mszul/datasets/explicit_implicit_beta/derivatives/PCA_results/PC_burst_rate_spec.pkl"
if op.exists(PC_burst_rate_spec_file):
    # Cached result produced by this pipeline (pickle: only load trusted files).
    with open(PC_burst_rate_spec_file, "rb") as cache:
        PC_burst_rate_spec = pickle.load(cache)

else:
    PC_burst_rate_spec = {i: {"vis": [], "mot": []} for i in PC_to_analyse}

    subjects = burst_features.subject.unique()

    # Indices of the visual time bins used as baseline (left edges inside baseline_range).
    # NOTE(review): edge indices index histogram rows, so a row whose left edge lies at
    # baseline_range[-1] extends past the window -- confirm this is intended.
    bl_ix = np.where(
        (time_bins["vis"] >= baseline_range[0]) &
        (time_bins["vis"] <= baseline_range[-1])
    )[0]

    for pc_key in tqdm(PC_to_analyse, colour="green"):
        # PC-score bin edges shared by all subjects: 40 bins spanning the
        # 0.5th-99.5th percentile of the pooled scores.
        comp_score = burst_features[pc_key]
        score_range = np.linspace(
            np.percentile(comp_score, 0.5),
            np.percentile(comp_score, 99.5),
            num=41
        )
        for sub in subjects:
            sub_PC_br = {"vis": [], "mot": []}
            for ep in ["vis", "mot"]:
                # "pp_ix" must be selected: trials are looked up by it below.
                data_snippet = burst_features.loc[
                    (burst_features.subject == sub) &
                    (burst_features.epoch == ep),
                    ["peak_time", pc_key, "pp_ix"]
                ]
                tr_uniq_ix = data_snippet.pp_ix.unique()

                PC_br_all_trials = []
                for tr_ix in tqdm(tr_uniq_ix, colour="purple"):
                    trial = data_snippet.loc[data_snippet.pp_ix == tr_ix]
                    PC_br, _, _ = np.histogram2d(
                        trial.peak_time.to_numpy(),
                        trial[pc_key].to_numpy(),
                        bins=[time_bins[ep], score_range]
                    )
                    # Counts -> rate per unit time, then smooth over time and score bins.
                    PC_br_all_trials.append(gaussian_filter(PC_br / bin_width, [1, 1]))
                # Average map over trials.
                sub_PC_br[ep] = np.mean(PC_br_all_trials, axis=0)

            # Relative change from the mean visual-epoch baseline, per score bin;
            # the same visual baseline is applied to the motor epoch.
            baseline = np.mean(sub_PC_br["vis"][bl_ix, :], axis=0).reshape(1, -1)
            for ep in ["vis", "mot"]:
                sub_PC_br[ep] = (sub_PC_br[ep] - baseline) / baseline
                PC_burst_rate_spec[pc_key][ep].append(sub_PC_br[ep])
    with open(PC_burst_rate_spec_file, "wb") as cache:
        pickle.dump(PC_burst_rate_spec, cache)

In [10]:
###########################################################################################
#                                      PARALLELIZED                                       #
###########################################################################################

def do_trials(data_snippet, pc_key, score_range, tr_ix, time_edges=None):
    """Smoothed (time x PC score) burst-rate map for a single trial.

    Parameters
    ----------
    data_snippet : pandas.DataFrame
        Bursts of one subject/epoch; needs "peak_time", "pp_ix" and `pc_key` columns.
    pc_key : str
        Column holding the PC score binned along the second axis.
    score_range : array_like
        Bin edges of the PC-score axis.
    tr_ix :
        Trial identifier; bursts with ``pp_ix == tr_ix`` are used.
    time_edges : array_like, optional
        Bin edges of the time axis. Defaults to ``time_bins[ep]`` taken from the
        notebook globals (the original behaviour); pass it explicitly instead.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(time_edges) - 1, len(score_range) - 1)``: burst counts divided
        by ``bin_width`` and smoothed with sigma = 1 bin on both axes. The map is
        returned (not written into a global dict), so joblib gathers results in
        trial order.
    """
    if time_edges is None:
        time_edges = time_bins[ep]
    trial = data_snippet.loc[data_snippet.pp_ix == tr_ix]
    PC_br, _, _ = np.histogram2d(
        trial.peak_time.to_numpy(),
        trial[pc_key].to_numpy(),
        bins=[time_edges, score_range]
    )
    return gaussian_filter(PC_br / bin_width, [1, 1])


PC_burst_rate_spec_file = "/home/mszul/datasets/explicit_implicit_beta/derivatives/PCA_results/PC_burst_rate_spec.pkl"
if op.exists(PC_burst_rate_spec_file):
    # Cached result produced by this pipeline (pickle: only load trusted files).
    with open(PC_burst_rate_spec_file, "rb") as cache:
        PC_burst_rate_spec = pickle.load(cache)

else:
    PC_burst_rate_spec = {i: {"vis": [], "mot": []} for i in PC_to_analyse}

    subjects = burst_features.subject.unique()

    # Indices of the visual time bins used as baseline (left edges inside baseline_range).
    # NOTE(review): edge indices index histogram rows, so a row whose left edge lies at
    # baseline_range[-1] extends past the window -- confirm this is intended.
    bl_ix = np.where(
        (time_bins["vis"] >= baseline_range[0]) &
        (time_bins["vis"] <= baseline_range[-1])
    )[0]

    for pc_key in tqdm(PC_to_analyse, colour="green"):
        for sub in tqdm(subjects, colour="purple"):
            sub_PC_br = {"vis": [], "mot": []}
            # Per-subject PC-score bins: 40 bins spanning the 1st-99th percentile of
            # this subject's scores (both epochs pooled).
            comp_score_all = burst_features.loc[burst_features.subject == sub, pc_key]
            lower_bound = np.percentile(comp_score_all, 1)
            upper_bound = np.percentile(comp_score_all, 99)
            score_range = np.linspace(lower_bound, upper_bound, num=41)
            for ep in ["vis", "mot"]:
                data_snippet = burst_features.loc[
                    (burst_features.subject == sub) &
                    (burst_features.epoch == ep) &
                    (burst_features[pc_key] >= lower_bound) &
                    (burst_features[pc_key] <= upper_bound),
                    ["peak_time", pc_key, "pp_ix"]
                ]

                # NOTE(review): trials come from the bound-filtered bursts, so a trial whose
                # bursts all fall outside the bounds is dropped from the average instead of
                # contributing a zero map (histogram2d already ignores out-of-range scores).
                tr_uniq_ix = data_snippet.pp_ix.unique()

                # require="sharedmem" runs threads, so data_snippet is not copied to workers;
                # Parallel returns the maps in the order of tr_uniq_ix.
                PC_br_all_trials = Parallel(n_jobs=20, require="sharedmem")(
                    delayed(do_trials)(data_snippet, pc_key, score_range, tr_ix, time_bins[ep])
                    for tr_ix in tr_uniq_ix
                )
                sub_PC_br[ep] = np.mean(PC_br_all_trials, axis=0)

            # Relative change from the mean visual-epoch baseline, per score bin;
            # the same visual baseline is applied to the motor epoch.
            baseline = np.mean(sub_PC_br["vis"][bl_ix, :], axis=0).reshape(1, -1)
            for ep in ["vis", "mot"]:
                sub_PC_br[ep] = (sub_PC_br[ep] - baseline) / baseline
                PC_burst_rate_spec[pc_key][ep].append(sub_PC_br[ep])
    with open(PC_burst_rate_spec_file, "wb") as cache:
        pickle.dump(PC_burst_rate_spec, cache)

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

In [14]:
# Every waveform component in the burst table: PC_1 .. PC_20.
all_PCs = ["PC_{}".format(i) for i in range(1, 21)]

def do_trials(data_snippet, pc_key, score_range, tr_ix, time_edges=None):
    """Smoothed (time x PC score) burst-rate map for a single trial.

    Parameters
    ----------
    data_snippet : pandas.DataFrame
        Bursts of one subject/epoch; needs "peak_time", "pp_ix" and `pc_key` columns.
    pc_key : str
        Column holding the PC score binned along the second axis.
    score_range : array_like
        Bin edges of the PC-score axis.
    tr_ix :
        Trial identifier; bursts with ``pp_ix == tr_ix`` are used.
    time_edges : array_like, optional
        Bin edges of the time axis. Defaults to ``time_bins[ep]`` taken from the
        notebook globals (the original behaviour); pass it explicitly instead.

    Returns
    -------
    numpy.ndarray
        Shape ``(len(time_edges) - 1, len(score_range) - 1)``: burst counts divided
        by ``bin_width`` and smoothed with sigma = 1 bin on both axes. The map is
        returned (not written into a global dict), so joblib gathers results in
        trial order.
    """
    if time_edges is None:
        time_edges = time_bins[ep]
    trial = data_snippet.loc[data_snippet.pp_ix == tr_ix]
    PC_br, _, _ = np.histogram2d(
        trial.peak_time.to_numpy(),
        trial[pc_key].to_numpy(),
        bins=[time_edges, score_range]
    )
    return gaussian_filter(PC_br / bin_width, [1, 1])


all_PC_burst_rate_spec_file = "/home/mszul/datasets/explicit_implicit_beta/derivatives/PCA_results/all_PC_burst_rate_spec.pkl"
if op.exists(all_PC_burst_rate_spec_file):
    # Cached result produced by this pipeline (pickle: only load trusted files).
    with open(all_PC_burst_rate_spec_file, "rb") as cache:
        all_PC_burst_rate_spec = pickle.load(cache)

else:
    all_PC_burst_rate_spec = {i: {"vis": [], "mot": []} for i in all_PCs}

    subjects = burst_features.subject.unique()

    # Indices of the visual time bins used as baseline (left edges inside baseline_range).
    # NOTE(review): edge indices index histogram rows, so a row whose left edge lies at
    # baseline_range[-1] extends past the window -- confirm this is intended.
    bl_ix = np.where(
        (time_bins["vis"] >= baseline_range[0]) &
        (time_bins["vis"] <= baseline_range[-1])
    )[0]

    for pc_key in tqdm(all_PCs, colour="green"):
        for sub in tqdm(subjects, colour="purple"):
            sub_PC_br = {"vis": [], "mot": []}
            # Per-subject PC-score bins: 40 bins spanning the 1st-99th percentile of
            # this subject's scores (both epochs pooled).
            comp_score_all = burst_features.loc[burst_features.subject == sub, pc_key]
            lower_bound = np.percentile(comp_score_all, 1)
            upper_bound = np.percentile(comp_score_all, 99)
            score_range = np.linspace(lower_bound, upper_bound, num=41)
            for ep in ["vis", "mot"]:
                data_snippet = burst_features.loc[
                    (burst_features.subject == sub) &
                    (burst_features.epoch == ep) &
                    (burst_features[pc_key] >= lower_bound) &
                    (burst_features[pc_key] <= upper_bound),
                    ["peak_time", pc_key, "pp_ix"]
                ]

                # NOTE(review): trials come from the bound-filtered bursts, so a trial whose
                # bursts all fall outside the bounds is dropped from the average instead of
                # contributing a zero map (histogram2d already ignores out-of-range scores).
                tr_uniq_ix = data_snippet.pp_ix.unique()

                # require="sharedmem" runs threads, so data_snippet is not copied to workers;
                # Parallel returns the maps in the order of tr_uniq_ix.
                PC_br_all_trials = Parallel(n_jobs=20, require="sharedmem")(
                    delayed(do_trials)(data_snippet, pc_key, score_range, tr_ix, time_bins[ep])
                    for tr_ix in tr_uniq_ix
                )
                sub_PC_br[ep] = np.mean(PC_br_all_trials, axis=0)

            # Relative change from the mean visual-epoch baseline, per score bin;
            # the same visual baseline is applied to the motor epoch.
            baseline = np.mean(sub_PC_br["vis"][bl_ix, :], axis=0).reshape(1, -1)
            for ep in ["vis", "mot"]:
                sub_PC_br[ep] = (sub_PC_br[ep] - baseline) / baseline
                all_PC_burst_rate_spec[pc_key][ep].append(sub_PC_br[ep])
    with open(all_PC_burst_rate_spec_file, "wb") as cache:
        pickle.dump(all_PC_burst_rate_spec, cache)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]