In [None]:
import sys
sys.path.append('../code')
import utils, spike_train_functions, lfp_functions

import numpy as np
import matplotlib.pyplot as plt
import scipy
import scipy.io as sio
import pandas as pd
from bycycle import BycycleGroup
from bycycle.utils import get_extrema_df

from scipy.stats import wasserstein_distance
from sklearn.manifold import TSNE
from sklearn.cluster import KMeans

import seaborn as sns


In [None]:
# Root folder that holds one sub-folder per recording session.
data_path = '../data/'

# File table for each session to analyse. 'event_file', 'spike_file' and
# 'lfp_file' hold suffixes that get appended to the session name
# (e.g. '<session>_30x_downsample.mat'); 'map_file' holds the electrode-map
# filename for that array.
# Disabled: 'SPK121107_PMv_TT_NOGO_ob001' (spike suffix
# '_DSXII-SORTED-02_gNTSNR.mat') -- its entry had no 'map_file'.
session_dict = {
    'SPK121107_MI_PMd_TT_NOGO_ob001': dict(
        event_file='_ev_explicit.mat',
        spike_file='_DSXII-SORTED-01_gNTSNR.mat',
        lfp_file='_30x_downsample.mat',
        map_file='SPK_map_MI_PMd.mat',
    ),
    'RUSRH120619_MI_PMd_TT_KG_TC_NOGO_Ob001': dict(
        event_file='_ev_explicit.mat',
        spike_file='_DSXI_corrected_gNTSNR_split.mat',
        lfp_file='_30x_downsample.mat',
        map_file='RUS_map_MI_PMd.mat',
    ),
}

In [None]:
def process_session(session_name: str, session_dict: dict):
    """Load the electrode map, sorted single units and downsampled LFP for one session.

    Files are read from ``../data/{session_name}/`` and are named
    ``{session_name}<suffix>``.

    Parameters
    ----------
    session_name : str
        Session identifier; used both as the sub-folder name and as the
        filename prefix.
    session_dict : dict
        Per-session file table. NOTE(review): not read in these lines -- the
        file suffixes below are hard-coded (and the parameter shadows the
        notebook-level ``session_dict``). Presumably the suffixes should come
        from ``session_dict[session_name]`` -- TODO confirm.

    NOTE(review): ``mapping_dict`` is not defined in this function or anywhere
    else in the notebook code shown, so building ``emap_dict`` raises NameError
    on a fresh kernel. It presumably should be ``sio.loadmat`` of
    ``session_dict[session_name]['map_file']`` -- TODO confirm where the map
    files live.
    """
    # Load spiking data and electrode map

    # NOTE(review): duplicates the notebook-level data_path ('../data/') as a literal.
    fpath = f'../data/{session_name}'
    # Alternative spike-file suffix; it matches the RUSRH120619 session's 'spike_file'.
    # unit_fname = f'{fpath}/{session_name}_DSXI_corrected_gNTSNR_split.mat'
    # Hard-coded suffix matches only the SPK121107 session's 'spike_file'.
    unit_fname = f'{fpath}/{session_name}_DSXII-SORTED-01_gNTSNR.mat'

    # No '.mat' extension here; scipy.io.loadmat appends it by default
    # (appendmat=True) if this path is later passed to loadmat.
    event_fname = f'{fpath}/{session_name}_ev_explicit'
    lfp_fname = f'{fpath}/{session_name}_30x_downsample.mat'
    
    # Load spiking data and electrode map
    # Flatten the MATLAB map struct's fields into columns; 'Num' has 1
    # subtracted, presumably converting 1-based electrode numbers to 0-based.
    emap_dict = {'label_idx': np.stack(mapping_dict['map_struct'][0]['Num']).squeeze() - 1,
               'row': np.stack(mapping_dict['map_struct'][0]['Row']).squeeze(),
               'col': np.stack(mapping_dict['map_struct'][0]['Column']).squeeze(),
               'area': np.stack(mapping_dict['map_struct'][0]['SubArrayName']).squeeze()}
    # One row per entry in the electrode map.
    emap_df_full = pd.DataFrame(emap_dict)

    single_unit_dict = sio.loadmat(unit_fname)

    # First row of the MATLAB 'sorted_timestamps' array; each element is one
    # unit's spike times, squeezed to drop singleton dimensions.
    sorted_timestamps = single_unit_dict['sorted_timestamps'][0]
    unit_timestamps = [sorted_timestamps[unit_idx].squeeze() for unit_idx in range(len(sorted_timestamps))]
    num_units = len(unit_timestamps)
    # Units are identified simply by their position: 0 .. num_units-1.
    unit_names = np.array(range(num_units))

    # Per-unit electrode index, shifted by -1 (presumably 1-based -> 0-based).
    unit_electrodes = single_unit_dict['unit_index'][0,:] - 1
    # NOTE(review): indexes the map by row *position*, not by the 'label_idx'
    # column; correct only if map rows are ordered by electrode number -- verify.
    unit_areas = [emap_df_full['area'].values[elec_idx] for elec_idx in unit_electrodes]
    
    # Hard-coded LFP sampling rate (presumably Hz, after the 30x downsampling);
    # not used in the lines above -- TODO confirm it matches the file.
    samp_freq = 1000
    nsx_dict = sio.loadmat(lfp_fname)
    # Recording duration from the NSx MetaTags struct; still wrapped in
    # loadmat's nested arrays, which is why the linspace result is squeezed below.
    nsx_duration = nsx_dict['out'][0][0]['MetaTags'][0][0]['DataDurationSec']

    lfp_data_raw = nsx_dict['out'][0][0]['Data']
    # One timestamp per sample along axis 1 (treated as the time axis), evenly
    # spaced from 0 to the duration *inclusive* (np.linspace endpoint=True), so
    # the step is duration/(n-1) rather than exactly 1/samp_freq.
    lfp_times_raw = np.linspace(0, nsx_duration, lfp_data_raw.shape[1]).squeeze()
