<a href="https://colab.research.google.com/github/neurologic/Neurophysiology-Lab/blob/main/modules/crayfish-erg/Data-Explorer_crayfish-erg.ipynb" target="_blank" rel="noopener noreferrer"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg"/></a>   

# Data Explorer

<a id="toc"></a>
# Table of Contents

- [Introduction](#intro)
- [Setup](#setup)
- [Part I. Process Data](#one)
- [Part II. Analyze Processed Data](#two)


<a id="setup"></a>
# Setup

[toc](#toc)

Import and define functions

In [None]:
#@title {display-mode: "form"}

#@markdown Run this code cell to import packages and define functions 
from pathlib import Path
import random
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from scipy import ndimage, optimize, signal
from scipy.signal import hilbert,medfilt,resample, find_peaks,butter
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from datetime import datetime,timezone,timedelta
pal = sns.color_palette(n_colors=15)
pal = pal.as_hex()
# `numpy.NaN` was removed in NumPy 2.0; `numpy.nan` is available in every release,
# so import it under the name the rest of the notebook uses.
from numpy import nan as NaN

from matplotlib.ticker import (AutoMinorLocator, MultipleLocator)
from ipywidgets import widgets, interact, interactive,interactive_output
%config InlineBackend.figure_format = 'retina'
plt.style.use("https://raw.githubusercontent.com/NeuromatchAcademy/course-content/master/nma.mplstyle")

def monoExp(x, m, t, b):
    """Evaluate a single exponential decay, ``m * exp(-x / t) + b``.

    Parameters
    ----------
    x : number or numpy array
        Point(s) at which to evaluate the curve.
    m : number
        Scale of the exponential term (its value at ``x == 0``).
    t : number
        Decay constant, in the same units as ``x``.
    b : number
        Constant offset added to the exponential term.

    Returns
    -------
    number or numpy array
        The curve evaluated at ``x``.
    """
    decay = np.exp(-x / t)
    return m * decay + b

# Report when the cell finished, using a fixed UTC-5 offset.
print(f'Task completed at {datetime.now(timezone(-timedelta(hours=5)))}')

Mount Google Drive

In [None]:
#@title {display-mode: "form"}

#@markdown Run this cell to mount your Google Drive.

# Imported in this cell rather than in Setup so the rest of the notebook does not
# require the `google.colab` package unless Drive is actually mounted.
from google.colab import drive

# Make the user's Google Drive available under /content/drive.
drive.mount('/content/drive')

# Report when the cell finished, using a fixed UTC-5 offset.
print(f'Task completed at {datetime.now(timezone(-timedelta(hours=5)))}')

Import data digitized with *Nidaq USB6211* and recorded using *Bonsai-rx* as a *.bin* file

In [None]:
# #@title {display-mode: "form"}

# #@markdown Specify the file path 
# #@markdown to your recorded data on Drive (find the filepath in the colab file manager:

# filepath = "full filepath goes here"  #@param 
# filepath = "/Volumes/NO NAME/BIOL247_FA22/data/crayfish-erg/KP_20221028/condition2_diff-duration_0.bin"

# #@markdown Specify the sampling rate and number of channels recorded.
# sampling_rate = None #@param
# number_channels = None #@param
# retina_channel = 0 #@param
# stimulus_channel = 1 #@param

# sampling_rate = 30000 #@param
# number_channels = 2 #@param

# downsample = False #@param
# newfs = 10000 #@param

# #@markdown After you have filled out all form fields, 
# #@markdown run this code cell to load the data. 

# filepath = Path(filepath)

# # No need to edit below this line
# #################################
# data = np.fromfile(Path(filepath), dtype = np.float64)
# data = data.reshape(-1,number_channels)
# data = data-data[0,:] # only do this offset adjustment for motor nerve recordings

# dur = np.shape(data)[0]/sampling_rate
# print('duration of recording was %0.2f seconds' %dur)

# fs = sampling_rate
# if downsample:
#     # newfs = 2500 #downsample data
#     chunksize = int(sampling_rate/newfs)
#     if number_channels>1:
#         data = data[0::chunksize,:]
#     if number_channels==1:
#         data = data[0::chunksize]
#     fs = int(np.shape(data)[0]/dur)

# time = np.linspace(0,dur,np.shape(data)[0])

# sos = butter(4, 500, 'lp', fs=fs, output='sos')


# if len(np.shape(data))>1:
#     retina_signal = data[:,retina_channel]
#     # retina_signal = medfilt(retina_signal,51)
#     retina_signal = signal.sosfilt(sos, retina_signal)
#     stimulus_signal = data[:,stimulus_channel]
# if len(np.shape(data))==1:
#     signal = data

# print('Now be a bit patient while it plots.')

# f = go.FigureWidget(make_subplots(rows=2, cols=1, row_width=[3, 1], 
#                                   vertical_spacing=0, shared_xaxes= True)) #,layout=go.Layout(height=500, width=800))
# f.add_trace(go.Scatter(x = time[0:fs], y = retina_signal[0:fs],
#                              opacity=1),row=2,col=1)
# f.add_trace(go.Scatter(x = time[0:fs], y = stimulus_signal[0:fs],
#                              opacity=1),row=1,col=1)

# f.update_layout(height=600, width=1000,
#                 showlegend=False,
#                xaxis2_title="time(seconds)", 
#                   yaxis1_title='photoresistor voltage',
#                yaxis2_title='retina voltage')

# slider = widgets.FloatRangeSlider(
#     min=0,
#     max=dur,
#     value=(0,1),
#     step= 1,
#     readout=False,
#     description='Time')
# slider.layout.width = '600px'

# # our function that will modify the xaxis range
# def response(x):
#     with f.batch_update():
#         starti = int(x[0]*fs)
#         stopi = int(x[1]*fs)
#         f.data[0].x = time[starti:stopi]
#         f.data[0].y = retina_signal[starti:stopi]
#         f.data[1].x = time[starti:stopi]
#         f.data[1].y = stimulus_signal[starti:stopi]

# vb = VBox((f, interactive(response, x=slider)))
# vb.layout.align_items = 'center'
# vb

<a id="two_import"></a>
## Import data 

Import data digitized with *Nidaq USB6211* and recorded using *Bonsai-rx* as a *.bin* file

> If you would like to explore the analysis for this lab, but do not have data, you can download examples for the following experiments using the linked shared files:  

In [None]:
#@title {display-mode: "form" }

#@markdown Specify the file path 
#@markdown to your recorded data in the colab runtime (find the filepath in the colab file manager):

filepath = "full filepath goes here"  #@param 
# NOTE(review): the next assignment overrides the form field above with a hard-coded
# path on a local volume; comment it out (or delete it) to use the form value.
filepath = '/Volumes/NO NAME/BIOL247_FA22/data/crayfish-erg/KP_20221028/condition2_diff-duration_0.bin'

#@markdown Specify the sampling rate and number of channels recorded.

sampling_rate = None #@param
number_channels = None #@param

# NOTE(review): these two assignments replace the `None` placeholders above, so the
# values used by the loading code below are 30000 and 2 unless these lines are edited.
sampling_rate = 30000 #@param
number_channels = 2 #@param

# downsample = False #@param
# newfs = 10000 #@param

#@markdown After you have filled out all form fields, 
#@markdown run this code cell to load the data. 

filepath = Path(filepath)

# No need to edit below this line
#################################
# The file is read as one flat stream of float64 values and then reshaped to
# (samples, channels), i.e. consecutive values are treated as the channels of one sample.
data = np.fromfile(filepath, dtype = np.float64)
if data.size == 0:
    raise ValueError(f'No samples were read from {filepath}')
if data.size % number_channels != 0:
    # reshape() would fail with a much less helpful message
    raise ValueError(
        f'{filepath} holds {data.size} values, which is not a multiple of '
        f'number_channels={number_channels}; check the number of channels recorded.')
data = data.reshape(-1,number_channels)
data_dur = np.shape(data)[0]/sampling_rate
print('duration of recording was %0.2f seconds' %data_dur)

fs = sampling_rate
# if downsample:
#     # newfs = 10000 #downsample emg data
#     chunksize = int(sampling_rate/newfs)
#     data = data[0::chunksize,:]
#     fs = int(np.shape(data)[0]/data_dur)

# One time stamp per sample, spaced exactly 1/fs apart, so that time[i] == i/fs.
# (np.linspace(0, data_dur, N) includes the end point and therefore stretches the
# spacing to data_dur/(N-1), which drifts away from event times computed as index/fs.)
time = np.arange(np.shape(data)[0])/fs

print('Data upload completed at ' + str(datetime.now(timezone(-timedelta(hours=5)))))

## Plot raw data

In [None]:
#@title {display-mode: "form"}

#@markdown Run this code cell to plot the imported data. <br> 
#@markdown Use the range slider to scroll through the data in time.
#@markdown Use the channel slider to choose which channel to plot.
#@markdown Be patient with the range refresh... the more data you are plotting the slower it will be. 

# Time window of the recording to display, in seconds.
slider_xrange = widgets.FloatRangeSlider(
    min=0,
    max=data_dur,
    value=(0,1),
    step= 1,
    readout=True,
    continuous_update=False,
    description='Time Range (s)')
slider_xrange.layout.width = '600px'

# Index of the recorded channel (column of `data`) to display.
slider_chan = widgets.IntSlider(
    min=0,
    max=number_channels-1,
    value=0,
    step= 1,
    continuous_update=False,
    description='channel')
slider_chan.layout.width = '300px'

def update_plot(x,chan):
    """Plot one channel of the raw recording between two time points.

    Parameters
    ----------
    x : (float, float)
        Start and end of the window to plot, in seconds.
    chan : int
        Column of ``data`` to plot.
    """
    fig, ax = plt.subplots(figsize=(10,5),num=1); #specify figure number so that it does not keep creating new ones
    # convert the window from seconds to sample indices
    starti = int(x[0]*fs)
    stopi = int(x[1]*fs)
    ax.plot(time[starti:stopi], data[starti:stopi,chan])
    # label the axes so the figure can be read on its own
    ax.set_xlabel('seconds')
    ax.set_ylabel(f'channel {chan}')

w = interact(update_plot, x=slider_xrange, chan=slider_chan);

For a more extensive ***RAW*** Data Explorer than the one provided in the above figure, use the [DataExplorer.py](https://raw.githubusercontent.com/neurologic/Neurophysiology-Lab/main/howto/Data-Explorer.py) application found in the [howto section](https://neurologic.github.io/Neurophysiology-Lab/howto/Dash-Data-Explorer.html) of the course website.

<a id="two_one"></a>
## Define bout and stimulus times

The time between stimulus onset and action potential, and the time between two stimulus pulses are critical parameters of the data on each trial. 

Our first task in processing and analyzing data from this experiment is to figure out the stimulus onset times. You can then segment the data into separate bouts if the raw recording was not one continuous successful protocol. 

### Define stimulus times

In [None]:
#@title {display-mode: "form"}

#@markdown Run this cell to create an interactive plot with a slider to scroll 
#@markdown through the signal
#@markdown and set an appropriate event detection threshold  
#@markdown (you can do so based on level crossing or peaks). 

# Window of the recording (in seconds) shown in the detection plot.
slider_xrange = widgets.FloatRangeSlider(
    description='Time Range (s)',
    min=0,
    max=data_dur,
    step=0.5,
    value=(0,1),
    readout=True,
    continuous_update=False,
    style={'description_width': '200px'},
    layout={'width': '600px'})

# One entry per recorded channel (column of `data`); channel 0 is preselected.
select_channel = widgets.Select(
    description='Channel used to detect events',
    options=np.arange(data.shape[1]),
    value=0,
    disabled=False,
    style={'description_width': '200px'}
)

# Threshold applied to the median-subtracted signal of the chosen channel.
slider_threshold = widgets.FloatSlider(
    description='event detection threshold',
    min=-2,
    max=2,
    step=0.001,
    value=0.2,
    readout=True,
    readout_format='.3f',
    continuous_update=False,
    style={'description_width': '200px'},
    layout={'width': '600px'})

# Detection method; 'level crossing' is the initial choice.
detect_type_radio = widgets.RadioButtons(
    description='Type of event detection',
    options=['peak', 'level crossing'],
    value='level crossing',
    disabled=False,
    style={'description_width': '200px'},
    layout={'width': 'max-content'}  # lets long option names fit
)

# Minimum inter-event interval, typed as text and parsed by the plotting callback.
iei_text = widgets.Text(
    description='min IEI (seconds)',
    value='0.005',
    placeholder='0.005',
    disabled=False,
    style={'description_width': '200px'}
)

def update_plot(chan_,xrange,thresh_,detect_type,iei):
    """Detect events on one channel and plot them over the selected time window.

    The chosen channel is median-subtracted before the threshold is applied.

    Parameters
    ----------
    chan_ : int
        Column of ``data`` used for event detection.
    xrange : (float, float)
        Start and end of the plotted window, in seconds.
    thresh_ : float
        Threshold applied to the median-subtracted signal. For 'peak' detection a
        negative threshold looks for downward peaks.
    detect_type : {'peak', 'level crossing'}
        'peak' uses ``find_peaks``; 'level crossing' reports the samples at which the
        signal rises from at-or-below the threshold to above it.
    iei : str or float
        Minimum inter-event interval in seconds. It is passed to ``find_peaks`` as the
        minimum peak distance and is not applied to level crossings.

    Returns
    -------
    numpy.ndarray
        Event times in seconds. Empty when ``iei`` is not a number greater than 0.001.

    Raises
    ------
    ValueError
        If ``detect_type`` is not one of the two supported methods.
    """
    fig, ax = plt.subplots(figsize=(10,5),num=1); #specify figure number so that it does not keep creating new ones

    # Named `trace` (not `signal`) so the scipy `signal` module imported in Setup is not shadowed.
    trace = data[:,chan_]
    trace = trace-np.median(trace)

    # The IEI arrives as free text from a widget, so it may be empty or half-typed.
    try:
        min_iei = float(iei)
    except (TypeError, ValueError):
        min_iei = 0.0
    if not min_iei > 0.001:
        print('min IEI must be a number greater than 0.001 seconds')
        return np.array([])

    if detect_type == 'peak':
        # find_peaks requires a distance of at least one sample
        min_distance = max(1, min_iei*fs)
        if thresh_ >= 0:
            peak_idx, _ = find_peaks(trace,height=thresh_,distance=min_distance)
        else:
            peak_idx, _ = find_peaks(-1*trace,height=-1*thresh_,distance=min_distance)
        trial_times = peak_idx/fs
    elif detect_type == 'level crossing':
        above = trace > thresh_
        # A rising crossing is a sample above threshold whose predecessor is not.
        # Sample 0 has no predecessor, so a recording that starts above threshold
        # does not produce an event at t=0.
        onset_idx = np.flatnonzero(above[1:] & ~above[:-1]) + 1
        trial_times = onset_idx/fs
    else:
        raise ValueError(f'unknown detect_type: {detect_type!r}')

    # Keep the window indices inside the recording so time[starti]/time[stopi] are valid.
    last = len(trace)-1
    starti = min(max(int(xrange[0]*fs)+1, 0), last)
    stopi = min(max(int(xrange[1]*fs)-1, 0), last)
    ax.plot(time[starti:stopi], trace[starti:stopi], color='black')

    ax.hlines(thresh_, time[starti],time[stopi],linestyle='--',color='green')
    # mark every detected event at the threshold level
    ax.scatter(trial_times,[thresh_]*len(trial_times),marker='^',s=300,color='purple',zorder=3)
    ax.set_xlim(xrange[0],xrange[1])
    ax.set_xlabel('seconds')

    ax.xaxis.set_minor_locator(AutoMinorLocator(5))

    return trial_times

# Wire each control to the matching argument of update_plot; the value returned by
# the most recent call is kept on the widget as `w_trials_.result`.
w_trials_ = interactive(
    update_plot,
    chan_=select_channel,
    xrange=slider_xrange,
    thresh_=slider_threshold,
    detect_type=detect_type_radio,
    iei=iei_text,
)
display(w_trials_)

In [None]:
#@title {display-mode: "form"}

#@markdown Run this cell to finalize the list of event times 
#@markdown after settling on a channel and threshold in the interactive plot. <br> 
#@markdown This stores the event times in an array called 'trial_times'. <br>
#@markdown NOTE: You may have to use "peaks" method for shorter stimulus pulse durations (separately).
# Take the event times returned by the most recent run of the interactive detection plot.
trial_times = w_trials_.result
# `result` is None when the detection plot has not produced any event times yet;
# stop here with a clear message instead of failing later in the analysis cells.
if trial_times is None:
    raise ValueError('No event times are available: run the interactive detection cell above '
                     'with a min IEI greater than 0.001 seconds before running this cell.')
print(f'{len(trial_times)} event times stored in trial_times')



### Define Bouts

In [None]:
#@title {display-mode: "form"}

#@markdown For this experiment, the entire file should be one long bout, 
#@markdown but if something went wrong during part of the recording, or there are regions that you want to exclude, you can specify bouts with good data.
#@markdown Specify the list of bout ranges as follows: [[start of bout 0, end of bout 0],[start 1, end 1],...] <br>

bouts_list = [[0,125]] #@param
# bouts_list = [[2,10],[10,20],[20,30],[30,45],[45,55],[55,70],[70,85],[85,100],[100,120]]
# bouts_list = [[0,20]]

#@markdown Then run this code cell to programmatically define the list of bouts as 'bouts_list'.

<a id="two_two"></a>
## Analyze Data

### Measure the raw data

Obtain necessary information from the raw signal time-locked to each event (which should be the stimulus pulse onsets).

> Just to give you a ballpark, this data processing step took me about 20 minutes for a paired pulse experiment in which I tested 18 different ISI values (with 2 trials at each value). So there were a total of 72 events that I processed data for.

In [None]:
#@title {display-mode:"form"}

#@markdown Run this code cell to create an interactive plot to  
#@markdown examine the raw signal time-locked to individual events (trial_times).
#@markdown You can overlay multiple channels by selecting more than one.
#@markdown You can overlay multiple event times by selecting more than one. 
#@markdown (To select more than one item from an option menu, press the control/command key 
#@markdown while mouse clicking or shift while using up/down arrows)

# Time window shown around each event, in seconds relative to the event time.
slider_xrange = widgets.FloatRangeSlider(
    min=-0.01,
    max=2,
    value=(-0.001,1),
    step=0.001,
    continuous_update=False,
    readout=True,
    readout_format='.4f',
    description='xrange (s)'
)
slider_xrange.layout.width = '600px'

# Vertical limits of the plot.
slider_yrange = widgets.FloatRangeSlider(
    min=-5,
    max=5,
    value=(-2,2),
    step=0.01,
    continuous_update=False,
    readout=True,
    description='yrange'
)
slider_yrange.layout.width = '600px'

ui_range = widgets.VBox([slider_xrange, slider_yrange])

# Events that fall inside bout 0; used only to populate the event menu initially.
# The plotting callback recomputes this for whichever bout is selected.
trials_t = trial_times[(trial_times>bouts_list[0][0]) & (trial_times<bouts_list[0][1])]

odd_even_radio = widgets.RadioButtons(
    options=['odd', 'even', 'all'],
    value='all', # start by listing every event
    layout={'width': 'max-content'}, # If the items' names are long
    description='show only events by: ',
    style = {'description_width': '400px'},
    disabled=False
)

# One entry per recorded channel (column of `data`).
select_channels = widgets.SelectMultiple(
    options=np.arange(np.shape(data)[1]),
    value=[0],
    #rows=10,
    description='Channels',
    disabled=False
)

# One entry per bout defined in `bouts_list`.
select_bouts = widgets.Select(
    options=np.arange(len(bouts_list)),
    value=0,
    #rows=10,
    description='Bouts',
    disabled=False
)

# One entry per event in the current bout. Event 0 is preselected only when the bout
# actually contains events; selecting a value that is not an option raises an error.
select_trials = widgets.SelectMultiple(
    options=np.arange(len(trials_t)),
    value=[0] if len(trials_t) > 0 else [],
    #rows=10,
    description='Events',
    disabled=False
)

ui_trials = widgets.HBox([select_channels, select_trials, select_bouts])

slider_threshold = widgets.FloatSlider(
    min=-1,
    max=1,
    value=0.25,
    step=0.001,
    readout_format='.3f',
    continuous_update=False,
    readout=True,
    description='peak detection threshold',
    style = {'description_width': '200px'})
slider_threshold.layout.width = '600px'

# Offer exactly the channels that were recorded (as strings), plus 'none' to switch
# peak detection off, instead of assuming a two-channel recording.
detect_chan_radio = widgets.RadioButtons(
    options=[str(chan) for chan in range(np.shape(data)[1])] + ['none'],
    value='none', # Defaults to 'none'
    layout={'width': 'max-content'}, # If the items' names are long
    description='detect delay to peaks on channel: ',
    style = {'description_width': '400px'},
    disabled=False
)

ui_peaks = widgets.VBox([detect_chan_radio, slider_threshold])

# Text readouts filled in by the plotting callback; they show nan until a single
# event and a detection channel are selected. `np.nan` is used directly so these
# lines do not depend on the `NaN` alias that NumPy 2.0 removed.
trial_abs_readout = widgets.Label(
    value=f'time of this event is (sec): {np.nan}'
)
trial_abs_readout.layout.width = '600px'

trial_readout = widgets.Label(
    value=f'time since last event is: {np.nan}'
)
trial_readout.layout.width = '600px'

lagging_time_readout = widgets.Label(
    value=f'lagging peak times are: {np.nan}'
)
lagging_time_readout.layout.width = '600px'

lagging_amp_readout = widgets.Label(
    value=f'lagging peak amplitudes are: {np.nan}'
)
lagging_amp_readout.layout.width = '600px'

def _event_window(event_time, win_0, win_1, n_samples):
    """Return (start, stop) sample indices of the window around an event, or None if it leaves the recording."""
    center = int(fs*event_time)
    start, stop = center+win_0, center+win_1
    if (start >= 0) and (stop <= n_samples):
        return start, stop
    return None


def update_plot(trial_sort_,chan_list,trial_list,bout_,yrange,xrange,lagging_chan_,thresh_):
    """Overlay the raw signal around the selected events and measure lagging peaks.

    Parameters
    ----------
    trial_sort_ : {'all', 'even', 'odd'}
        Which events of the bout are offered in the event menu.
    chan_list : sequence of int
        Columns of ``data`` to plot.
    trial_list : sequence of int
        Indices (within the selected bout) of the events to overlay.
    bout_ : int
        Index into ``bouts_list``; only events strictly inside that bout are used.
    yrange : (float, float)
        Vertical limits of the plot.
    xrange : (float, float)
        Window around each event, in seconds relative to the event time.
    lagging_chan_ : str
        Channel number (as text) on which to detect peaks, or 'none'.
    thresh_ : float
        Peak detection threshold; a negative value looks for downward peaks.

    Notes
    -----
    Peak detection runs only when exactly one event is selected. Results are written
    to the module-level readout labels rather than returned. Sweeps whose window
    would extend past either end of the recording are skipped.
    """
    fig, ax = plt.subplots(figsize=(8,4))

    win_0 = int(xrange[0]*fs)
    win_1 = int(xrange[1]*fs)
    # One time stamp per sample, spaced exactly 1/fs apart (np.linspace would include
    # the end point and stretch the spacing, skewing the reported peak times).
    xtime = np.arange(win_0, win_1)/fs

    trials_t = trial_times[(trial_times>bouts_list[bout_][0]) & (trial_times<bouts_list[bout_][1])]
    trials_init_ = np.arange(len(trials_t))

    # Restrict the event menu to the requested subset and drop selections outside it.
    subsets = {'all': trials_init_, 'even': trials_init_[0::2], 'odd': trials_init_[1::2]}
    if trial_sort_ in subsets:
        allowed = subsets[trial_sort_]
        select_trials.options = allowed
        trial_list = [t_try for t_try in trial_list if t_try in allowed]
        select_trials.value = trial_list

    # Reset the readouts; they are filled in below when a single event is measured.
    lagging_time_readout.value=f'lagging peak times are: {np.nan}'
    lagging_amp_readout.value=f'lagging peak amplitudes are: {np.nan}'
    trial_abs_readout.value=f'time of this event is: {np.nan}'
    trial_readout.value=f'time since last event is: {np.nan}'

    n_samples = np.shape(data)[0]
    channel_colors = ['purple','green','blue','orange']
    for chan_ in chan_list:
        this_chan = data[:,chan_]
        # cycle through the palette so recordings with more than four channels still plot
        color = channel_colors[chan_ % len(channel_colors)]
        for trial_ in trial_list:
            if trial_ in trials_init_:
                window = _event_window(trials_t[trial_], win_0, win_1, n_samples)
                if window is not None:
                    ax.plot(xtime,this_chan[window[0]:window[1]],color=color,linewidth=2,alpha=0.5)

    if (lagging_chan_ != 'none') and (len(trial_list)==1) and (trial_list[0] in trials_init_):
        ax.hlines(thresh_, xrange[0],xrange[1],linestyle='--',color='green')
        event_i = trial_list[0]

        trial_abs_readout.value=f'time of this event is (sec): {trials_t[event_i]}'
        if event_i > 0:
            iti = 1000*(trials_t[event_i] - trials_t[event_i-1])
            trial_readout.value=f'time since last event is (msec): {iti:.2f}'

        window = _event_window(trials_t[event_i], win_0, win_1, n_samples)
        if window is not None:
            lagging_signal = data[window[0]:window[1],int(lagging_chan_)]
            # minimum spacing between peaks; find_peaks requires at least one sample
            d = max(1, 0.0005*fs)
            if thresh_ >= 0:
                peak_idx, props = find_peaks(lagging_signal,height=thresh_,distance=d)
                peak_amp = props['peak_heights']
            else:
                peak_idx, props = find_peaks(-1*lagging_signal,height=-1*thresh_,distance=d)
                peak_amp = -1*props['peak_heights']

            # .tolist() yields plain Python floats so the readouts print as numbers
            lagging_peak_amp = np.round(peak_amp,2).tolist()
            lagging_peak_ms = np.round(1000*xtime[peak_idx],2).tolist()
            lagging_time_readout.value=f'lagging peak times are (ms): {lagging_peak_ms}'
            lagging_amp_readout.value=f'lagging peak amplitudes are (V): {lagging_peak_amp}'

            # all peaks are drawn with a single scatter call
            ax.scatter(xtime[peak_idx],lagging_peak_amp,color='black',s=50,zorder=3)

    ax.set_ylim(yrange[0],yrange[1]);
    ax.set_xlabel('seconds')

    ax.xaxis.set_minor_locator(AutoMinorLocator(10))

    # Turn grid on for both major and minor ticks and style minor slightly differently.
    ax.grid(which='major', color='gray', linestyle='-')
    ax.grid(which='minor', color='gray', linestyle=':')
#     ax_pwm.grid(which='major', color='gray', linestyle='-')
#     ax_pwm.grid(which='minor', color='gray', linestyle=':')


# display(w)
# Map each control onto the update_plot argument of the same name; the plot is
# redrawn whenever any of them changes.
plot_controls = dict(
    trial_sort_=odd_even_radio,
    chan_list=select_channels,
    trial_list=select_trials,
    bout_=select_bouts,
    yrange=slider_yrange,
    xrange=slider_xrange,
    lagging_chan_=detect_chan_radio,
    thresh_=slider_threshold,
)
w = interactive_output(update_plot, plot_controls);
del plot_controls  # temporary name; leave the notebook namespace as it was

# Readouts on top, then the selection and detection controls, the plot, and the axis ranges.
display(
    trial_abs_readout,
    trial_readout,
    lagging_time_readout,
    lagging_amp_readout,
    odd_even_radio,
    ui_trials,
    ui_peaks,
    w,
    ui_range,
)


> Need some way to get the pulse durations from the signal (might need to sample at a higher rate?... although I guess the shortest was 500microseconds?)

<hr> 
Written by Dr. Krista Perks for courses taught at Wesleyan University.

<a id="setup"></a>

<a id="one"></a>

<a id="two"></a>

<a id="three"></a>

<a id="four"></a>