### Configuration

In [1]:
import os
import numpy as np
import pandas as pd
import xarray as xr

import mne
import yasa
from scipy.stats import zscore
import mne_features as mf
from mne_features.feature_extraction import extract_features
from sklearn.neighbors import LocalOutlierFactor

from utils__helpers_macro import robust_zscore
import utils__config

  @nb.jit()


In [2]:
os.chdir(utils__config.working_directory)
os.getcwd()

'Z:\\Layton\\Sleep_083023'

### Parameters

In [3]:
fif_path = 'Cache/Subject05/Jul13/S05_Jul13_256hz.fif'
potato_path = 'Cache/Subject05/Jul13/S05_Jul13_potatogram.csv'
bad_channel_path = 'Cache/Subject05/Jul13/S05_bad_channels.csv'
bad_epoch_path = 'Cache/Subject05/Jul13/S05_bad_epochs.csv'

Please note that the number of samples must be whole-number divisible by (sampling_freq * epoch_length)

In [4]:
sampling_freq = 256 # Hz
epoch_length = 3 # seconds
lof_threshold = -2
chan_threshold = 3

### Load Data

In [5]:
raw = mne.io.read_raw_fif(fif_path, preload = True, verbose = False)

# Select only macroelectrodes
raw.pick_types(seeg = True, ecog = True)

# Remove bad channels
bad_channels = pd.read_csv(bad_channel_path)
bad_channels = bad_channels[bad_channels['channel'].isin(raw.ch_names)]
raw.drop_channels(ch_names = bad_channels['channel'].astype('string'))
print('Channel count after bad channel removal:', len(raw.ch_names))

# Load the upsampled hypnogram
hypnogram = np.loadtxt(potato_path, delimiter = ',')

  raw = mne.io.read_raw_fif(fif_path, preload = True, verbose = False)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Channel count after bad channel removal: 49


### Epoch Data

Add hypnogram as a channel to the Raw object

In [6]:
# Hypnogram dictionary: 
# (-2) = Unassigned
# (-1) = Artifact
# (0) = Awake
# (1) = N1
# (2) = N2
# (3) = N3
# (4) = REM 

# Re-value sleep stages for incorporation into Epochs object
hypnogram[(hypnogram != 2) & (hypnogram != 3)] = 0
hypnogram[(hypnogram == 2) | (hypnogram == 3)] = 1

# Create raw object from the hypnogram
hypnogram = hypnogram[np.newaxis, :]

hypno_info = mne.create_info(ch_names = ['hypno'], 
                             sfreq = raw.info['sfreq'], 
                             ch_types = ['misc'])

hypno = mne.io.RawArray(data = hypnogram,
                        info = hypno_info,
                        first_samp = raw.first_samp)

raw.add_channels([hypno], force_update_info = True)

Creating RawArray with float64 data, n_channels=1, n_times=9584640
    Range : 0 ... 9584639 =      0.000 ... 37439.996 secs
Ready.


0,1
Measurement date,"July 13, 2023 23:00:43 GMT"
Experimenter,Unknown
Digitized points,0 points
Good channels,"49 sEEG, 1 misc"
Bad channels,
EOG channels,Not available
ECG channels,Not available
Sampling frequency,256.00 Hz
Highpass,0.00 Hz
Lowpass,128.00 Hz


Create dummy stim data and an empty stim channel, then fill the channel with the data

In [7]:
# Record the first sample (which is not 0 since the Raw
# file was cropped from the original); you will need this
# to appropriately select the epoch sample number
start = raw.first_samp
step = sampling_freq * epoch_length
stop = raw.last_samp - step

epoch_stim = np.arange(start, stop, step)

# MNE Epochs expects a three column array where the second column
# is a dummy spacer with 0's and the third is an integer indicating
# the ID for the event. So we need to append these to our stim array.
dummy_row = np.zeros(len(epoch_stim))
event_row = np.ones(len(epoch_stim))

epoch_stim = np.vstack((epoch_stim, dummy_row, event_row)).transpose()

# Create a dummy numpy event array and MNE info object
# and use them to create an empty dummy Raw channel
events_info = mne.create_info(ch_names = ['epoch_stim'], 
                              sfreq = raw.info['sfreq'], 
                              ch_types = ['stim'])

empty_events = np.zeros((1, len(raw.times)))

events_channel = mne.io.RawArray(empty_events, events_info)

# Create an event dictionary
event_dictionary = {'epoch_start' : 1}

Creating RawArray with float64 data, n_channels=1, n_times=9584640
    Range : 0 ... 9584639 =      0.000 ... 37439.996 secs
Ready.


Epoch data using dummy stim data in the new channel

In [8]:
# Update the empty stim channel with the formatted epoch start times
raw.add_channels([events_channel], force_update_info = True)
raw.add_events(epoch_stim, 'epoch_stim')

# Find events and create epochs
events = mne.find_events(raw, 
                         stim_channel = 'epoch_stim', 
                         shortest_event = sampling_freq * epoch_length,
                         initial_event = True)

epochs = mne.Epochs(raw, 
                    preload = True, 
                    events = events, 
                    event_id = event_dictionary, 
                    baseline = None,
                    verbose = True,
                    tmin = 0, 
                    tmax = epoch_length)

# Drop the event channel before exporting data
epochs = epochs.drop_channels(['epoch_stim'])

12479 events found
Event IDs: [1]
Not setting metadata
12479 matching events found
No baseline correction applied
0 projection items activated
Using data from preloaded Raw for 12479 events and 769 original time points ...
0 bad epochs dropped


Select epochs with more than 50% N2/3 sleep and save their sample number start times

In [9]:
# Get epoched hypnogram and get SWS% per epoch
hypochs = epochs.get_data(picks = ['hypno']).squeeze()
hypochs = hypochs.mean(axis = 1).transpose()
hypochs = pd.DataFrame(hypochs, columns = ['hypno_score'])

# Keep epochs with more than 50% of SWS
nopochs = pd.Series(hypochs[hypochs['hypno_score'] <= 0.50].index)
hypochs = pd.Series(hypochs[hypochs['hypno_score'] > 0.50].index)

# Remove hypno channel and get data
epochs = epochs.drop_channels(['hypno'])
data = epochs.get_data()

# Delete non-SWS epochs
data = np.delete(data, nopochs, axis = 0)

### Feature Extraction

In [10]:
# Extract scalar features:
# Input is (n_epochs, n_channels, n_times)
# Output is (n_epochs, n_channels * n_features)
mf_scalar = extract_features(data, 
                             sfreq = sampling_freq, 
                             selected_funcs = ['mean', 'variance', 'std', 'ptp_amp', 
                                               'skewness', 'kurtosis', 'rms', 'quantile', 
                                               'zero_crossings', 'hurst_exp'], 
                             ch_names = epochs.ch_names, 
                             return_as_df = True, 
                             n_jobs = -1)

In [11]:
# Extract array features:
# Output array of (n_epochs, n_channels * n_features)
# but note that the second dimension will be a multi-index
freq_bands = np.array([0.5, 4, 8, 13, 30, 50, 70, 100, sampling_freq/2])

mf_params = {'pow_freq_bands__freq_bands' : freq_bands, 
             'spect_slope__fmin' : 0.3,
             'spect_slope__fmax' : sampling_freq/2}

mf_array = extract_features(data, 
                            sfreq = sampling_freq, 
                            selected_funcs = ['pow_freq_bands', 'spect_slope'], 
                            funcs_params = mf_params, 
                            ch_names = epochs.ch_names, 
                            return_as_df = True, 
                            n_jobs = -1)

### Feature Munging

In [12]:
# Scalar features - Reshape multi-index to long
mf_scalar = mf_scalar.unstack().reset_index()
mf_scalar.columns = ['feature', 'channel', 'epoch', 'value']

In [13]:
# Array features - Reshape multi-index to long
mf_array = mf_array.unstack().reset_index()
mf_array.columns = ['long_feature', 'elec_feature', 'epoch', 'value']

# Extract channel and feature names from the aggregate output column
# (Note that channels with underscores in their name will break this)
mf_array[['channel', 'feature']] = mf_array['elec_feature'].str.split('_', expand = True)

# Convert to xarray
mf_array = mf_array[['channel', 'epoch', 'feature', 'value']]

In [14]:
# Append the scalar and array feature datasets:
mf_feats = pd.concat([mf_scalar, mf_array])

# Normalize features for comparability (robust z-score)
#mf_feats['zscore'] = mf_feats.groupby(['feature'])['value'].transform(zscore)
mf_feats['rzscore'] = mf_feats.groupby(['feature'])['value'].transform(robust_zscore)
mf_feats.drop(columns = ['value'], inplace = True)

# Convert to xarray
mf_feats = mf_feats.set_index(['channel', 'epoch', 'feature'])
mf_feats = mf_feats.to_xarray()

# Save meta-data for later
channels = mf_feats.channel.to_pandas()
epochs = mf_feats.epoch.to_pandas()

# Convert to numpy
mf_feats = mf_feats.to_array().to_numpy().squeeze()

### Local Outlier Factor

Channel-wise LOF

In [15]:
for channel in range(mf_feats.shape[0]):

    chan_lof = LocalOutlierFactor()
    _ = chan_lof.fit_predict(mf_feats[channel, :, :])
    chan_lof = chan_lof.negative_outlier_factor_

    if channel == 0:
        lof = chan_lof[np.newaxis, :]

    else:
        lof = np.append(lof, chan_lof[np.newaxis, :], axis = 0)

In [16]:
# Convert from wide to long format
lof = pd.DataFrame(lof, index = channels, columns = epochs)
lof = lof.reset_index()

lof = lof.melt(id_vars = ['index'], var_name = 'epoch', value_name = 'lof')
lof.columns = ['channel', 'epoch', 'lof']

### Rejection Thresholding

Epoch rejection is all-or-none. An epoch must be deleted from all channels if it is rejected at all. Thus, if an epoch is selected as abnormal in more than N channels, the epoch times will be marked and saved to later remove any sleep events detected during those times.

In [17]:
# Create a table of channel counts with LOF < threshold for each epoch
epoch_ct = lof[['epoch', 'lof']].loc[lof['lof'] < lof_threshold].groupby(['epoch']).count().reset_index()

# Get epochs over the channel count threshold, these will be rejected
below_lof = epoch_ct.loc[epoch_ct['lof'] > chan_threshold, 'epoch']

### Save rejected epochs for exclusion

In [18]:
# Get epoch sample numbers
bad_epochs = pd.DataFrame(epoch_stim[:,0], columns = ['start'])
bad_epochs['stop'] = bad_epochs['start'] + (sampling_freq * epoch_length)

# Pad bad epochs by 1 second on either side, then
# convert to int64 for use in indexing
bad_epochs['start'] = bad_epochs['start'] - (sampling_freq * 1)
bad_epochs['stop'] = bad_epochs['stop'] - (sampling_freq * 1)

bad_epochs['start'] = bad_epochs['start'].astype('int64')
bad_epochs['stop'] = bad_epochs['stop'].astype('int64')

# Reset index to get the original hypnogram epoch index
bad_epochs = bad_epochs.reset_index()
bad_epochs.columns = ['hypno_epoch', 'start', 'stop']

# Select only SWS epochs
bad_epochs = bad_epochs[bad_epochs['hypno_epoch'].isin(hypochs)]

# Reset index twice (after keeping only SWS epochs)
# to get the equivalent LOF epoch index, which is
# different from the hypno epoch since its a reset subset
bad_epochs = bad_epochs.reset_index().reset_index()
bad_epochs.columns = ['lof_epoch', 'dummy_epoch', 'hypno_epoch', 'start', 'stop']
bad_epochs.drop(columns = ['dummy_epoch'], inplace = True)

# Keep epochs that were selected by LOF as bad
bad_epochs = bad_epochs[bad_epochs['lof_epoch'].isin(below_lof)]

# # Create a zero-filled 1d array with length of
# # total number of samples, and set value to 1
# # at time points corresponding to the rejected
# # epochs (with their padding added)
# bad_segments = np.zeros(len(raw))

# for index, row in bad_epochs.iterrows():
#     bad_segments[row['start']:row['stop']] = 1

# Convert sample numbers to seconds
bad_epochs['start_time'] = (bad_epochs['start'] / sampling_freq) - (raw.first_samp / sampling_freq)
bad_epochs['stop_time'] = (bad_epochs['stop'] / sampling_freq) - (raw.first_samp / sampling_freq)

# Save to CSV
bad_epochs.to_csv(bad_epoch_path, index = False)

In [19]:
bad_epochs

Unnamed: 0,lof_epoch,hypno_epoch,start,stop,start_time,stop_time
133,133,657,504320,505088,1970.0,1973.0
545,545,2430,1865984,1866752,7289.0,7292.0
570,570,2459,1888256,1889024,7376.0,7379.0
1538,1538,3529,2710016,2710784,10586.0,10589.0
1712,1712,4185,3213824,3214592,12554.0,12557.0
2213,2213,4687,3599360,3600128,14060.0,14063.0
2532,2532,5007,3845120,3845888,15020.0,15023.0
2957,2957,6114,4695296,4696064,18341.0,18344.0
3977,3977,7138,5481728,5482496,21413.0,21416.0
4050,4050,7214,5540096,5540864,21641.0,21644.0
