### Configuration

In [1]:
import os
import re
import numpy as np
import pandas as pd
from scipy import io
import hdf5storage
from tqdm import tqdm

import utils__config

In [2]:
# Resolve every relative path in this notebook against the project root
# declared in utils__config, then echo the active directory as the cell output.
project_root = utils__config.working_directory
os.chdir(project_root)
os.getcwd()

'Z:\\Layton\\Sleep_083023'

### Parameters

In [3]:
# Session identifiers: change these to process another subject/session.
SUBJECT_DIR = 'Subject01'  # folder name under Data/ and Cache/
SUBJECT_ID = 'S01'         # file-name prefix
SESSION = 'Feb02'          # recording-session folder / tag

session_cache = f'Cache/{SUBJECT_DIR}/{SESSION}'

root_dir = f'{session_cache}/{SUBJECT_ID}_{SESSION}_cnato__5___yemi_to_mat'  # exported channel MAT files
dict_dir = f'Data/{SUBJECT_DIR}/{SUBJECT_ID}_dictionary.xlsx'              # channel dictionary (Excel)
ld_dir = f'{session_cache}/{SUBJECT_ID}_logdensity_forms.csv'              # output: log-density waveforms
metric_dir = f'{session_cache}/{SUBJECT_ID}_spike_metrics.csv'             # input: QC spike metrics
out_dir = f'{session_cache}/{SUBJECT_ID}_spikeforms.csv'                   # output: average waveforms

In [4]:
# MAT file version of the exported channel files. SciPy's io.loadmat reads
# versions < 7.3; hdf5storage reads >= 7.3 (the HDF5-based format).
MAT_version = '7.3'

### Munging

In [5]:
# Channel dictionary: the channel number plus its laterality and region,
# later joined onto the spike data by channel number.
dictionary_columns = ['number', 'laterality', 'region']
micro_dict = pd.read_excel(dict_dir).loc[:, dictionary_columns]

In [6]:
def _load_mat(path):
    """Load one exported channel file with the reader that matches ``MAT_version``.

    SciPy's ``io.loadmat`` handles MAT files older than v7.3; v7.3 files are
    HDF5-based and are read with ``hdf5storage.loadmat``.
    """
    if MAT_version == '7.3':
        return hdf5storage.loadmat(path)
    return io.loadmat(path)


def _parse_channel_filename(filename):
    """Return ``(subject, channel_number, sign)`` parsed from a channel file name.

    The name is split on '_': field 0 is the subject; field 1 is split on 'l'
    and its second piece is the channel number; field 2 is the sign with the
    file extension removed. NOTE(review): presumably names look like
    'S01_Channel12_neg.mat' -- confirm against the export script.
    """
    fields = filename.split('_')
    subject = fields[0]
    channel_number = int(fields[1].split('l')[1])  # int so it merges with micro_dict['number']
    sign = fields[2].split('.')[0]
    return subject, channel_number, sign


def _waveforms_long(waveforms):
    """Reshape a unit's waveform matrix (rows = time points, columns = spikes) to long form.

    Returns columns ``time_point`` (1-based row index), ``spike_id`` (column
    index) and ``amplitude``, one row per (time point, spike) pair.
    """
    wide = pd.DataFrame(waveforms)
    wide.index = np.arange(1, len(wide) + 1)  # 1-based time points
    return (wide
            .rename_axis('time_point')
            .reset_index()
            .melt(id_vars='time_point', var_name='spike_id', value_name='amplitude'))


def _log_density(unit_long, n_amp_bins=100):
    """2-D log-density of a unit's spikes: one bin per time point x ``n_amp_bins`` amplitude bins.

    Returns one row per bin with the time point (bin centre), the lower
    amplitude edge of the bin, and log(count + 1) (the +1 avoids log(0)).
    """
    n_time = int(unit_long['time_point'].max())
    # Half-integer edges give exactly one bin per sample. The previous fixed
    # edges np.arange(1, 65) made only 63 bins, so for 64-sample waveforms
    # samples 63 and 64 were counted together in the last bin.
    time_edges = np.arange(0.5, n_time + 1.5)
    amp_edges = np.linspace(unit_long['amplitude'].min(),
                            unit_long['amplitude'].max(),
                            n_amp_bins + 1)  # n + 1 edges -> n bins
    counts, time_edges, amp_edges = np.histogram2d(
        unit_long['time_point'], unit_long['amplitude'], bins=(time_edges, amp_edges))
    time_centres = (time_edges[:-1] + time_edges[1:]) / 2
    t_grid, a_grid = np.meshgrid(time_centres, amp_edges[:-1], indexing='ij')
    return pd.DataFrame({
        'time_point': t_grid.ravel(),
        'amplitude': a_grid.ravel(),
        'log_density': np.log1p(counts).ravel()})


avg_frames = []  # average waveform, one frame per unit
ld_frames = []   # log-density waveform, one frame per unit

# Sorted for a deterministic row order; only MAT exports are read.
channel_files = sorted(f for f in os.listdir(root_dir) if f.lower().endswith('.mat'))

for channel in tqdm(channel_files):
    raw_data = _load_mat(os.path.join(root_dir, channel))
    subject, channel_number, sign = _parse_channel_filename(channel)

    # One entry per unit in sp_types / sp_waveforms
    for unit in range(len(raw_data['sp_types'])):
        unit_long = _waveforms_long(raw_data['sp_waveforms'][unit][0])
        if unit_long.empty:
            continue  # no spikes -> nothing to average or histogram

        meta = {'unit_type': raw_data['sp_types'][unit][0],
                'unit_num': unit + 1,
                'subject': subject,
                'channel': channel_number,
                'sign': sign}

        # Average across spikes here (rather than later) to keep RAM down
        unit_mean = unit_long.groupby('time_point', as_index=False)['amplitude'].mean()
        avg_frames.append(unit_mean.assign(**meta))
        ld_frames.append(_log_density(unit_long).assign(**meta))

# Concatenate once instead of growing the frames inside the loops (quadratic)
data = pd.concat(avg_frames, ignore_index=True) if avg_frames else pd.DataFrame()
data_ld = pd.concat(ld_frames, ignore_index=True) if ld_frames else pd.DataFrame()

100%|██████████| 84/84 [02:10<00:00,  1.56s/it]


In [7]:
def _annotate_units(frame):
    """Attach channel-dictionary meta-data and a unique unit ID to a per-unit frame.

    Inner-joins ``frame`` to ``micro_dict`` on channel number, so channels that
    are missing from the dictionary are dropped. It then shifts ``unit_num``
    down by one, builds ``unit_id``, and renames laterality/region to
    ``unit_laterality``/``unit_region``. Returns a new frame.
    """
    annotated = frame.merge(micro_dict, left_on='channel', right_on='number')
    # Account for the offset in unit number between Combinato and MATLAB
    # so that you can compare units between Combinato GUI and your analysis (optional)
    annotated['unit_num'] = annotated['unit_num'] - 1
    annotated['unit_id'] = (annotated['subject'] + '_Ch' + annotated['channel'].astype('str')
                            + '_' + annotated['sign'] + '_Unit' + annotated['unit_num'].astype('str'))
    # Laterality/region apply to the unit's recording channel
    return annotated.rename(columns={'laterality': 'unit_laterality', 'region': 'unit_region'})


# NOTE: not idempotent -- re-running this cell without re-running the loop
# cell would shift unit_num again and merge the dictionary twice.
data = _annotate_units(data)
data_ld = _annotate_units(data_ld)

### Quality Control

In [8]:
# Unit classes: artifact = -1 | unassigned = 0 | MUA = 1 | SUA = 2
EXCLUDED_UNIT_TYPES = [-1, 0]              # keeps SUA + MUA; use [-1, 0, 1] for SUA only
KEPT_REGIONS = ['CLA', 'AMY', 'ACC', 'aINS']

# Units selected by the quality control script
spike_metrics = pd.read_csv(metric_dir)


def _quality_filter(frame):
    """Keep rows whose unit is not excluded by type, lies in a kept region, and passed QC."""
    keep = (~frame['unit_type'].isin(EXCLUDED_UNIT_TYPES)
            & frame['unit_region'].isin(KEPT_REGIONS)
            & frame['unit_id'].isin(spike_metrics['unit_id']))
    return frame[keep]


data = _quality_filter(data)
data_ld = _quality_filter(data_ld)

### Average Waveforms and Export

In [9]:
grouping_vars = ['subject', 'unit_id', 'unit_region',
                 'unit_laterality', 'sign', 'unit_type']

# If the per-unit averaging in the loop cell is ever skipped, average here instead:
# data = data[(grouping_vars + ['time_point', 'amplitude'])]
# data = data.groupby((grouping_vars + ['time_point'])).mean().reset_index()

# Merge with spike-metric meta-data. The metrics file must have one row per
# unit_id; validate raises instead of silently duplicating waveform rows.
unit_metrics = (pd.read_csv(metric_dir)
                [['unit_id', 'num_count', 'perc_isi_violations']]
                .rename(columns={'num_count': 'count', 'perc_isi_violations': 'isi'}))
data = data.merge(unit_metrics, on='unit_id', validate='many_to_one')

# Min / max / abs-max of each unit's mean waveform. Group by unit_id alone:
# grouping by every column in grouping_vars drops groups with a NaN key
# (e.g. missing laterality), and the inner merge below would then remove
# those units from the output entirely.
min_max = data.groupby('unit_id')['amplitude'].agg(['min', 'max']).reset_index()
min_max['abs'] = min_max[['min', 'max']].abs().max(axis=1)
data = data.merge(min_max, on='unit_id')

In [10]:
# Save to CSV
data.to_csv(out_dir, index = False)
data_ld.to_csv(ld_dir, index = False)