In [1]:
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path
import platform
import tkinter as tk
from tkinter import filedialog
from scipy.signal import butter, filtfilt

from neuropy.io.neuroscopeio import NeuroscopeIO
from neuropy.io.binarysignalio import BinarysignalIO 
from neuropy.io.miniscopeio import MiniscopeIO
from neuropy.core import Epoch
from neuropy.utils import plot_util
from neuropy.plotting.spikes import plot_raster
from neuropy.plotting.signals import plot_signal_w_epochs

sys.path.insert(1, 'C:/BrianKim/Code/Repositories/cnn-ripple/src/cnn/')

In [2]:
# Define a class for a typical recording or set of recordings
class ProcessData:
    def __init__(self, basepath):
        basepath = Path(basepath)
        self.basepath = basepath
        xml_files = sorted(basepath.glob("*.xml"))
        assert len(xml_files) == 1, "Found more/less than one .xml file"
        
        fp = xml_files[0].with_suffix("")
        self.filePrefix = fp
        
        self.recinfo = NeuroscopeIO(xml_files[0])
        eegfiles = sorted(basepath.glob('*.eeg'))
        assert len(eegfiles) == 1, "Fewer/more than one .eeg file detected"
        self.eegfile = BinarysignalIO(eegfiles[0], n_channels=self.recinfo.n_channels,
                                     sampling_rate=self.recinfo.eeg_sampling_rate,
                                     )
        try:
            self.datfile = BinarysignalIO(eegfiles[0].with_suffix('.dat'),
                                         n_channels=self.recinfo.n_channels,
                                         sampling_rate=self.recinfo.dat_sampling_rate,
                                         )
        except FileNotFoundError:
            print('No dat file found, not loading')
                
        
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.recinfo.source_file.name})"
    
def sess_use(basepath=os.getcwd()):
    """Load in data. Uses current directory as default"""

    return ProcessData(basepath)

In [3]:
# Open a directory chooser dialog
dir_use = filedialog.askdirectory(title="Please select a data folder")

# Check if user selected a directory or pressed cancel
if dir_use:
    print(f"Selected Data Directory: {dir_use}")
else:
    print("No directory was selected.")

sess = sess_use(dir_use)

print(sess.recinfo)
print(sess.eegfile)


Selected Data Directory: D:/Data/RippleDetection/Orange/20230905
No dat file found, not loading
filename: D:\Data\RippleDetection\Orange\20230905\Orange_20230905.xml 
# channels: 35
sampling rate: 30000
lfp Srate (downsampled): 1250

duration: 601.19 seconds 
duration: 0.17 hours 


In [4]:
from neuropy.analyses.artifact import detect_artifact_epochs

cur_file = 'D:\\Data\\RippleDetection\\Orange\\20230905\\Orange_20230905.eeg'

In [5]:
binary_data = BinarysignalIO(cur_file, n_channels=32,sampling_rate=30000)
signal_obj = binary_data.get_signal()
whole_data = signal_obj.traces
print(np.shape(whole_data))
data = whole_data.T
print(np.shape(data))

(32, 821940)
(821940, 32)


In [6]:
signal = sess.eegfile.get_signal()
buffer_add = 0.1  # seconds, None = don't add


art_epochs_file = sess.filePrefix.with_suffix(".art_epochs.npy")
if art_epochs_file.exists():
    art_epochs = Epoch(epochs=None, file=art_epochs_file)
    print('Existing artifact epochs file loaded')
else:
    art_epochs = detect_artifact_epochs(signal, thresh=6, 
                                    edge_cutoff=1, merge=6)

    if buffer_add is not None:  # Add in buffer to prevent erroneous detection of start/stop of artifact as SWRs
        art_epochs.add_epoch_buffer(buffer_add)
    sess.recinfo.write_epochs(epochs=art_epochs, ext='art')  # Write to neuroscope
    art_epochs.save(art_epochs_file)
art_epochs

Buffer of 0.1 added before/after each epoch
D:\Data\RippleDetection\Orange\20230905\Orange_20230905.art_epochs.npy saved


1 epochs
Snippet: 
       start      stop label
0  338.9664  339.2272      

In [12]:
# Helper Functions
def filter0(b, x):
    if x.shape[0] == 1:
        x = np.transpose(x)
    
    if len(b) % 2 != 1:
        raise ValueError("Filter order should be odd")
    
    shift = (len(b) - 1) // 2
    
    from scipy.signal import lfilter
    y0 = lfilter(b, [1], x)
    
    z = np.zeros(shift)
    
    y = np.concatenate((y0[shift:], z[:shift]))
    return y


def unity(A):
    meanA = np.mean(A)
    stdA = np.std(A)
    U = (A - meanA) / stdA
    return U, stdA


def in_intervals(timestamps, restrict):
    # Assuming `restrict` is an Nx2 array, where each row is [start, stop] interval
    keep = np.any([(timestamps >= start) & (timestamps <= stop) for start, stop in restrict], axis=0)
    return keep


def butter_bandpass_filter(data, lowcut, highcut, fs, order=3):
    b, a = butter(order, [lowcut, highcut], btype='band', fs=fs)
    filtered_data = filtfilt(b, a, data)
    return filtered_data

[3.23068691e+02 5.05080773e+06 1.14270444e+07 ... 6.45454946e+07
 6.94142373e+07 6.41030022e+07]
[0. 0. 0. 0. 0.]


In [18]:
def bz_FindRipples(basepath=None, channel=None, lfp_data=None, timestamps=None, **kwargs):
    """
    Find hippocampal ripples (100~200Hz oscillations).
    """
    # Default parameters
    params = {
        'thresholds': [2, 5],
        'durations': [30, 100],
        'restrict': None,
        'frequency': 1250,
        'stdev': None,
        'show': 'off',
        'noise': None,
        'passband': [130, 200],
        'EMGThresh': 0.9,
        'saveMat': False,
        'minDuration': 20,
        'plotType': 2
    }
    
    # Update parameters with provided kwargs
    for key, value in kwargs.items():
        if key in params:
            params[key] = value
    
    # Extract and assign parameters
    frequency = params['frequency']
    restrict = params['restrict']
    sd = params['stdev']
    passband = params['passband']
    
    signal = butter_bandpass_filter(data, passband[0], passband[1], frequency, order=3)
    
    
    windowLength = 11
    # 
    squaredSignal = signal ** 2
    window = np.ones(windowLength) / windowLength
    test = np.sum(squaredSignal, axis=1)
    filter0(window, test)
    normalizedSquaredsignal, sd = unity(filter0(window, np.sum(squaredSignal, axis=1)))
    
    return normalizedSquaredsignal, sd

In [20]:
lfp_data = np.copy(data[:,24])

normalizedSquaredsignal, sd = bz_FindRipples(lfp_data=lfp_data)

print(normalizedSquaredsignal)
print(sd)

[-3.30351022 -2.60833883 -1.9287399  ... -6.81899092 -6.81899092
 -6.81899092]
9330154.266621364
