### Configuration

In [1]:
import os
import re
import numpy as np
import pandas as pd
from scipy import io
import hdf5storage
from tqdm import tqdm

import utils__config



In [2]:
# Run from the configured project folder so that the relative paths
# used in the rest of the notebook resolve; the last line echoes the
# resulting working directory as the cell output
project_dir = utils__config.working_directory
os.chdir(project_dir)
os.getcwd()

'G:\\My Drive\\Residency\\Research\\Lab - Damisah\\Project - Sleep\\Revisions'

### Parameters

In [None]:
# Recording session to process, as (subject ID, recording date).
# Change this single line to switch sessions; every path below is derived from it.
session = ('S01', 'Feb02')

# Sessions available for processing, mapped to the number embedded in the name of
# each session's folder of per-channel MAT files ('..._cnato__<number>___yemi_to_mat')
session_folder_numbers = {
    ('S01', 'Feb02'): 5,
    ('S05', 'Jul11'): 4,
    ('S05', 'Jul12'): 4,
    ('S05', 'Jul13'): 4,
}

# Fail here, with a readable message, rather than later on a missing path
if session not in session_folder_numbers:
    raise KeyError(f'Unknown session {session!r}; add it to session_folder_numbers')

subject_id, session_date = session
subject_folder = 'Subject' + subject_id[1:]  # e.g. 'S01' -> 'Subject01'
session_cache = f'Cache/{subject_folder}/{session_date}'
folder_number = session_folder_numbers[session]

# Folder of MAT files (one per channel) for the session
root_dir = f'{session_cache}/{subject_id}_{session_date}_cnato__{folder_number}___yemi_to_mat'

# Excel dictionary for the subject (channel number, laterality, region)
dict_dir = f'Data/{subject_folder}/{subject_id}_dictionary.xlsx'

# CSV of spike metrics, used to select which units are kept
metric_dir = f'{session_cache}/{subject_id}_spike_metrics.csv'

# Destination CSV for the sampled waveforms
out_dir = f'{session_cache}/{subject_id}_waveforms_sampled.csv'

In [None]:
# MAT file version of the per-channel files:
# SciPy reads versions < 7.3, hdf5storage reads versions >= 7.3
MAT_version = '7.3'

# Maximum number of spikes kept per unit; units with more spikes are randomly sub-sampled
samples = 1_000

### Munging

In [None]:
# Load the subject's dictionary and keep only the columns needed to
# label each channel number with its laterality and region
dictionary_columns = ['number', 'laterality', 'region']
micro_dict = pd.read_excel(dict_dir)[dictionary_columns]

In [None]:
# Seed for the random spike sub-sampling, so that re-running the notebook
# reproduces the same export (change the value to draw a different sample)
sampling_seed = 0
rng = np.random.default_rng(sampling_seed)

# Only MAT files are read: any other entry in the folder (e.g. operating-system
# or cloud-sync metadata files) would make the MAT loaders fail. Sorting fixes
# the processing order, which the seeded sampling depends on.
channel_files = sorted(f for f in os.listdir(root_dir) if f.lower().endswith('.mat'))

# Per-unit frames are collected here and concatenated once at the end;
# growing a DataFrame with concat inside the loops re-copies all rows each time
unit_frames = []

for channel in tqdm(channel_files):

    channel_path = os.path.join(root_dir, channel)

    if MAT_version == '7.3':
        raw_data = hdf5storage.loadmat(channel_path)
    else:
        raw_data = io.loadmat(channel_path)

    # Channel meta-data, parsed once from the file name:
    # '<subject>_<text ending in "l"><channel number>_<sign>.<extension>'
    name_parts = channel.split('_')
    subject = name_parts[0]
    channel_num = int(name_parts[1].split('l')[1])  # integer, to merge with micro_dict
    sign = name_parts[2].split('.')[0]

    # Extract waveforms, spike times and unit type for every unit in the channel
    for unit in range(len(raw_data['sp_types'])):

        # Waveform table: one row per time point, one column per spike
        waveforms = pd.DataFrame(raw_data['sp_waveforms'][unit][0])

        # Skip units without waveforms. A channel in which every unit is skipped
        # simply contributes no rows.
        if waveforms.empty:
            continue

        # Reshape to long format with 1-based time points
        unit_data = waveforms.reset_index()
        unit_data['index'] = unit_data['index'] + 1
        unit_data = pd.melt(unit_data, id_vars='index')
        unit_data.columns = ['time_point', 'spike_id', 'amplitude']

        # Extract spike times and attach them by (0-based) spike position
        unit_times = pd.DataFrame(raw_data['sp_times'][unit][0])
        unit_times.columns = ['milliseconds']
        unit_times['spike_id'] = unit_times.index
        unit_data = pd.merge(unit_data, unit_times, on='spike_id', how='inner')

        # Randomly sample unique spike ids if their count is greater than 'samples'
        unique_spike_ids = unit_data['spike_id'].unique()

        if len(unique_spike_ids) > samples:
            sampled_spike_ids = rng.choice(unique_spike_ids, size=samples, replace=False)
            unit_data = unit_data[unit_data['spike_id'].isin(sampled_spike_ids)]

        # Set unit and channel meta-data. assign() returns a new frame, so nothing
        # is written into the filtered slice above (avoids SettingWithCopyWarning).
        unit_frames.append(
            unit_data.assign(
                unit_type=raw_data['sp_types'][unit][0],
                unit_num=unit + 1,
                subject=subject,
                channel=channel_num,
                sign=sign,
            )
        )

# Stop with a clear message instead of failing later on missing columns
if not unit_frames:
    raise ValueError(f'No spike waveforms were found in the MAT files under {root_dir!r}')

data = pd.concat(unit_frames)

In [None]:
# Attach the dictionary meta-data (laterality, region) to each row by channel number
data = data.merge(micro_dict, left_on='channel', right_on='number')

# Shift unit numbers down by one to account for the offset between Combinato
# and MATLAB, so that units can be compared between the Combinato GUI and
# this analysis (optional)
data['unit_num'] -= 1

# Build a unique unit ID of the form '<subject>_Ch<channel>_<sign>_Unit<unit_num>'
channel_part = '_Ch' + data['channel'].astype('str')
sign_part = '_' + data['sign']
unit_part = '_Unit' + data['unit_num'].astype('str')
data['unit_id'] = data['subject'] + channel_part + sign_part + unit_part

# Prefix laterality/region with 'unit_' to make explicit that they describe the unit
data = data.rename(columns={'laterality': 'unit_laterality', 'region': 'unit_region'})

### Quality Control

In [None]:
# Unit type codes: artifact = -1 | unassigned = 0 | MUA = 1 | SUA = 2
# Drop artifactual and unassigned units, keeping SUA + MUA
# (to keep SUA only, select data['unit_type'] == 2 instead)
is_artifact = data['unit_type'] == -1
is_unassigned = data['unit_type'] == 0
data = data[~is_artifact & ~is_unassigned]

# Restrict to units located in CLA, AMY, ACC, or aINS
regions_of_interest = ['CLA', 'AMY', 'ACC', 'aINS']
in_region = data['unit_region'].isin(regions_of_interest)
data = data[in_region]

# Restrict to the units listed in the spike-metrics file
# written by the quality control script
spike_metrics = pd.read_csv(metric_dir)
passed_qc = data['unit_id'].isin(spike_metrics['unit_id'])
data = data[passed_qc]

### Export

In [None]:
# Drop the bookkeeping columns that are not part of the export.
# errors='ignore' keeps this cell re-runnable: on a second run the columns
# are already gone from `data`, which would otherwise raise a KeyError.
bookkeeping_columns = ['unit_type', 'unit_num', 'subject', 'channel', 'number', 'sign']
data = data.drop(columns=bookkeeping_columns, errors='ignore')

# Save to CSV (without the row index)
data.to_csv(out_dir, index=False)