In [1]:
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from pathlib import Path
import platform
import tkinter as tk
from tkinter import filedialog

from neuropy.io.neuroscopeio import NeuroscopeIO
from neuropy.io.binarysignalio import BinarysignalIO 
from neuropy.io.miniscopeio import MiniscopeIO
from neuropy.core import Epoch
from neuropy.utils import plot_util
from neuropy.plotting.spikes import plot_raster
from neuropy.plotting.signals import plot_signal_w_epochs

sys.path.insert(1, 'C:/BrianKim/Code/Repositories/cnn-ripple/src/cnn/')

In [2]:
# Define a class for a typical recording or set of recordings
class ProcessData:
    """Handles to the recording files found in a single session directory.

    The directory must contain exactly one ``.xml`` file and exactly one
    ``.eeg`` file. A ``.dat`` file sharing the ``.eeg`` file's stem is loaded
    when it exists.

    Attributes
    ----------
    basepath : Path
        The session directory.
    filePrefix : Path
        Path of the ``.xml`` file with its suffix removed.
    recinfo : NeuroscopeIO
        Recording metadata read from the ``.xml`` file.
    eegfile : BinarysignalIO
        Reader for the ``.eeg`` file, opened with ``recinfo.eeg_sampling_rate``.
    datfile : BinarysignalIO or None
        Reader for the ``.dat`` file, opened with ``recinfo.dat_sampling_rate``;
        ``None`` when that file could not be found.
    """

    def __init__(self, basepath):
        """Locate and open the recording files under ``basepath``.

        Raises
        ------
        AssertionError
            If the directory does not hold exactly one ``.xml`` file and
            exactly one ``.eeg`` file.
        """
        basepath = Path(basepath)
        self.basepath = basepath
        xml_files = sorted(basepath.glob("*.xml"))
        assert len(xml_files) == 1, "Found more/less than one .xml file"

        fp = xml_files[0].with_suffix("")
        self.filePrefix = fp

        self.recinfo = NeuroscopeIO(xml_files[0])
        eegfiles = sorted(basepath.glob('*.eeg'))
        assert len(eegfiles) == 1, "Fewer/more than one .eeg file detected"
        self.eegfile = BinarysignalIO(eegfiles[0], n_channels=self.recinfo.n_channels,
                                     sampling_rate=self.recinfo.eeg_sampling_rate,
                                     )

        # Define the attribute up front so that a missing .dat file leaves
        # `datfile` as None rather than undefined (which would surface later
        # as an AttributeError far away from the real cause).
        self.datfile = None
        try:
            self.datfile = BinarysignalIO(eegfiles[0].with_suffix('.dat'),
                                         n_channels=self.recinfo.n_channels,
                                         sampling_rate=self.recinfo.dat_sampling_rate,
                                         )
        except FileNotFoundError:
            print('No dat file found, not loading')

    def __repr__(self) -> str:
        """Return the class name followed by the ``.xml`` file name."""
        return f"{self.__class__.__name__}({self.recinfo.source_file.name})"
    
def sess_use(basepath=None):
    """Load in data. Uses current directory as default

    Parameters
    ----------
    basepath : str or path-like, optional
        Session directory handed to ``ProcessData``. When omitted, the working
        directory at the time of the call is used.

    Returns
    -------
    ProcessData
    """
    # Resolve the default here rather than in the signature: a default of
    # os.getcwd() is evaluated only once, when the function is defined, and
    # would ignore any later change of working directory.
    if basepath is None:
        basepath = os.getcwd()

    return ProcessData(basepath)

In [3]:
# Open a directory chooser dialog. The Tk root window is created explicitly and
# hidden so the dialog does not leave an empty top-level window behind, and it
# is destroyed again once the dialog has closed.
root = tk.Tk()
root.withdraw()
try:
    dir_use = filedialog.askdirectory(title="Please select a data folder")
finally:
    root.destroy()

# askdirectory returns an empty value when the user presses cancel. Stop here in
# that case: passing an empty path on would silently load from the working
# directory instead of a folder the user actually chose.
if not dir_use:
    raise RuntimeError("No directory was selected.")
print(f"Selected Data Directory: {dir_use}")

# dir_use = '/home/kimqi/Documents/Data/Orange/20230905'  # Add directory here!
sess = sess_use(dir_use)

print(sess.recinfo)
print(sess.eegfile)

Selected Data Directory: D:/Data/RippleDetection/20230309Recall1
No dat file found, not loading
filename: D:\Data\RippleDetection\20230309Recall1\Django_recall1_denoised.xml 
# channels: 134
sampling rate: 30000
lfp Srate (downsampled): 1250

duration: 4532.47 seconds 
duration: 1.26 hours 


In [35]:
# Reuse the .eeg reader that the session object already built from the .xml
# metadata. This keeps the cell tied to the directory selected above (no
# hard-coded absolute path) and guarantees that the channel count and sampling
# rate handed to the reader are the ones recorded for this session.
binary_data = sess.eegfile
signal_obj = binary_data.get_signal()
whole_data = signal_obj.traces  # indexed below as (channel, sample)

# Proof of concept: keep only 8 channels (rows 54-61 of the recording) and drop
# everything else, including the aux channels.
channel_rows = slice(54, 62)
data = np.copy(whole_data[channel_rows, :]).T  # -> (sample, channel)
print(data)
print(data.shape)



[[ 167  190  189 ...   48   21    5]
 [ 320  265  226 ...  -91 -128 -115]
 [ 199  153   91 ... -228 -216 -220]
 ...
 [-692 -694 -683 ... -665 -595 -558]
 [-873 -905 -864 ... -887 -798 -748]
 [-600 -645 -609 ... -612 -551 -522]]
(5665590, 8)


In [36]:
# Z-score the selected channels (no downsampling is performed in this cell).
# NOTE(review): downsample_data is imported but never called here - presumably
# it is unnecessary because the .eeg data is already at the LFP rate; confirm.
from load_data import z_score_normalization, downsample_data


# Normalize it with z-score; `data` is overwritten with the normalized array.
print("Normalizing data...", end=" ")
data = z_score_normalization(data)
print("Done!")

print("Shape of loaded data after z-score: ", np.shape(data))

Normalizing data... Done!
Shape of loaded data after z-score:  (5665590, 8)


In [37]:
overlapping = True
window_size = 0.0128
print("Generating Windows...", end=" ")
if not overlapping:
    # No overlap: the stride equals the window length and the whole recording
    # is passed on as a single batch entry (a leading axis of length 1).
    stride = window_size
    X = np.expand_dims(data, 0)
else:
    from load_data import generate_overlapping_windows

    stride = 0.0064

    # 12.8 ms windows advanced in 6.4 ms steps. The trailing 1250 is
    # presumably the sampling rate of `data` in Hz - confirm against load_data.
    X = generate_overlapping_windows(data, window_size, stride, 1250)
print("Done!")

Generating Windows... Done!


In [38]:
import tensorflow.keras.backend as K
import tensorflow.keras as kr

print("Loading CNN model...", end=" ")
model_path = r'C:\BrianKim\Code\Repositories\cnn-ripple\model'

# The saved model is loaded uncompiled and then compiled here with an
# explicitly configured Adam optimizer and a binary cross-entropy loss.
optimizer = kr.optimizers.Adam(
    learning_rate=0.001,
    beta_1=0.9,
    beta_2=0.999,
    epsilon=1e-07,
    amsgrad=False,
)
model = kr.models.load_model(model_path, compile=False)
model.compile(loss="binary_crossentropy", optimizer=optimizer)

print("Done!")





Done!


In [39]:
# Run the loaded model over the windows in X; verbose=True asks Keras to show
# its own progress output while predicting.
print("Detecting ripples...", end=" ")
predictions = model.predict(X, verbose=True)
print("Done!")

Done!


In [28]:
from format_predictions import get_predictions_indexes

# Detection threshold applied to the model output; adjust as needed.
threshold = 0.7

print("Getting detected ripples indexes and times...", end=" ")
pred_indexes = get_predictions_indexes(
    data,
    predictions,
    window_size=window_size,
    stride=stride,
    fs=1250,
    threshold=threshold,
)

# Convert sample indices to seconds using the same 1250 Hz rate passed above.
pred_times = pred_indexes / 1250
print("Done!")

Getting detected ripples indexes and times... Done!


In [56]:


#@markdown This is an interactive plot of the loaded data, where detected ripples are shown in blue. Data is displayed in chunks of 1 second and you can **move forward, backward, or jump to a specific second** using the control bar at the bottom.\
#@markdown \
#@markdown Run this cell to load the plotting method. Execute the **following** cell to use the method.
# Select matplotlib's interactive notebook backend so the figure (and the
# slider widget attached to it) stays live after the cell has run.
%matplotlib notebook
from matplotlib.widgets import Slider
# Rate used by plot_ripples to convert between sample indices and seconds.
# NOTE(review): presumably the sampling rate of `data` in Hz - confirm.
downsampled_fs = 1250


def plot_ripples(k):
    """Open an interactive 1-second view of the traces with ripples shaded.

    Reads the module-level ``data`` (indexed as sample x channel),
    ``pred_indexes`` (pairs of start/end sample indices) and
    ``downsampled_fs``. Each channel is scaled by its peak absolute value in
    the displayed window and stacked vertically with the first channel on top.
    A slider under the plot selects which second is displayed.

    Parameters
    ----------
    k : int or float
        Second of the recording at which to open the view. A message is
        printed and nothing is drawn if ``k`` is negative or not smaller than
        the time of the last sample.

    Returns
    -------
    None
    """
    data_size = data.shape[0]
    n_channels = data.shape[1]
    times = np.arange(data_size) / downsampled_fs

    if k >= times[-1]:
        print("Data is only %ds long!"%(times[-1]))
        return
    elif k < 0:
        print("Please introduce a valid integer.")
        return

    # One vertical offset per channel, first channel on top. The y-limits
    # leave some room around the stack (-3 .. 9 for 8 channels).
    offsets = np.arange(n_channels - 1, -1, -1)
    y_min, y_max = -3, n_channels + 1

    fig = plt.figure(figsize=(9.75,5))
    ax = fig.add_subplot(1,1,1)
    plt.tight_layout()

    plt.subplots_adjust(bottom=0.25)
    ax_slider = plt.axes([0.25, 0.1, 0.65, 0.03])
    slider = Slider(ax_slider, 'Time (s)', 0, times[-1]-1, valinit=k)

    def _draw_window(second):
        """Redraw the axes for the 1-second window starting at ``second``."""
        ini_idx = int(second * downsampled_fs)
        end_idx = min(int((second + 1) * downsampled_fs), data_size - 1)
        window = data[ini_idx:end_idx, :]

        # Scale by the peak *absolute* value so a window whose maximum is
        # negative does not flip the trace, and guard flat (all-zero) channels
        # against division by zero.
        scale = np.max(np.abs(window), axis=0)
        scale = np.where(scale == 0, 1, scale)

        ax.clear()
        ax.set_ylim(y_min, y_max)
        ax.margins(x=0)
        ax.set_xlabel("Time (s)")
        ax.plot(times[ini_idx:end_idx], window / scale + offsets, color='k', linewidth=1)

        for pred in pred_indexes:
            # Interval-overlap test: also catches a ripple that starts before
            # the window and ends after it.
            if pred[0] <= end_idx and pred[1] >= ini_idx:
                rip_ini = pred[0] / downsampled_fs
                rip_end = pred[1] / downsampled_fs
                ax.fill_between([rip_ini, rip_end], [y_min, y_min], [y_max, y_max],
                                color="tab:blue", alpha=0.3)

        fig.canvas.draw_idle()

    def update(val):
        """Slider callback: show the whole second the slider points at."""
        _draw_window(int(slider.val))

    slider.on_changed(update)
    # Matplotlib widgets stop responding once they are garbage-collected, and
    # nothing outside this function would otherwise hold the slider. Hanging
    # it on the figure keeps it alive for as long as the figure exists.
    fig.ripple_slider = slider
    update(k)  # Draw the initial window

    plt.show()

print("Loaded!")



Loaded!


In [57]:
# Open the interactive viewer at second 450 of the recording.
plot_ripples(k=450)

<IPython.core.display.Javascript object>

In [None]:
# Neuropy artifact detection: load previously saved artifact epochs when they
# exist next to the recording, otherwise detect them and save the result.
from neuropy.analyses.artifact import detect_artifact_epochs

signal = sess.eegfile.get_signal()
buffer_add = 0.1  # seconds, None = don't add
art_epochs_file = sess.filePrefix.with_suffix(".art_epochs.npy")

if not art_epochs_file.exists():
    art_epochs = detect_artifact_epochs(
        signal, thresh=6, edge_cutoff=1, merge=6
    )

    # Pad each artifact so its start/stop edges are not mistaken for SWRs.
    # NOTE(review): the return value is discarded - confirm that
    # add_epoch_buffer modifies art_epochs in place.
    if buffer_add is not None:
        art_epochs.add_epoch_buffer(buffer_add)

    sess.recinfo.write_epochs(epochs=art_epochs, ext='art')  # Write to neuroscope
    art_epochs.save(art_epochs_file)
else:
    art_epochs = Epoch(epochs=None, file=art_epochs_file)
    print('Existing artifact epochs file loaded')
art_epochs