# Signal Tools

### Pick and choose functions as needed
- Reads in an HDF5 (.h5) file from the LeafLabs Willow system
- Scaling function to turn raw data into microvolts
- Filtering routine between user-specified cutoffs (Butterworth bandpass)
- Downsampling routine to 1 kHz for oscillatory activity
- Displays raw traces from each channel as they are arranged on the probe shank

## Import Packages 

In [None]:
import h5py
import numpy as np
import os
import matplotlib.pyplot as plt
from IPython.display import HTML
from scipy.signal import butter, sosfiltfilt, sosfreqz, filtfilt
import time
import pandas as pd
import scipy 
import csv

## Input File (Required)

#### Function Definition: Import a single h5 file

In [None]:
def get_single_h5 (datafile, chan_start, chan_end, time_start, time_end, fs):
    """Load a time window and channel range from one HDF5 recording.

    Parameters
    ----------
    datafile : str
        Path to an HDF5 file with ``sample_index`` and ``channel_data`` datasets.
    chan_start, chan_end : int
        Channel (column) range to keep; ``chan_end`` is excluded.
    time_start, time_end : float
        Window to keep, in seconds; converted to sample rows using ``fs``.
    fs : float
        Sample rate in Hz.

    Returns
    -------
    x_values : ndarray
        ``sample_index`` rows for the window (time in datapoints).
    y_values : ndarray, shape (samples, channels)
        Raw ``channel_data`` for the window; still needs scaling to microvolts.
    """
    # Slice bounds must be integers; round so fractional windows land on the
    # intended row (int() alone would turn 0.57 * 100 == 56.99999999999999 into 56).
    first = int(round(time_start * fs))
    last = int(round(time_end * fs))

    # Read only the requested slice, and close the file even if reading fails.
    with h5py.File(datafile, 'r') as f_data:
        x_values = f_data['sample_index'][first:last]  # time in datapoints (x)
        # raw data, must be scaled to microvolts (y)
        y_values = f_data['channel_data'][first:last, chan_start:chan_end]

    minutes = round(y_values.shape[0] / (fs * 60), 3)
    display (HTML("Data consists of " + str(y_values.shape[1]) +
              " columns of data (channels, x) and " + str(y_values.shape[0]) +
              " rows of data (measurements, y). The recording is " +
              str(minutes) + " mins long."))
    display (HTML("x-axis data (in datapoints) is stored in the <strong>x_values</strong> variable. <br>"))
    display (HTML("y-axis data is stored in the <strong>y_values</strong> variable. <br>"))
    return x_values, y_values

#### User Input: Import a single file

In [None]:
datafile = 'experiment_C20200330-174152_concat.h5'
fs = 30000  # sample rate (Hz)
window = [0, 30]  # [start, end] of analysis window in seconds (for filtering, 4 mins max for most computers)
chan_num = [0, 128]  # [first, last) channels on the shank; last is excluded

# Load the data: the two lists unpack into chan_start, chan_end, time_start, time_end
x_values, y_values = get_single_h5(datafile, *chan_num, *window, fs)

## Scale the data

#### Function Definition: Scale Data

In [None]:
# Define the reusable scaling function
def scale_data (data, scale_factor):
    """Return a new list with every element of `data` multiplied by `scale_factor`.

    `data` can be any iterable (list, tuple, 1-D numpy array); the input is
    not modified.
    """
    scaled = []
    for value in data:
        scaled.append(value * scale_factor)
    return scaled

#### User Input: Scale X values
- Example: scale datapoints to seconds:
    - `x_unscaled = x_values`
    - `x_scale_factor = 1/fs` (i.e. `1/30000` for a 30 kHz recording)

In [None]:
x_unscaled = x_values      # sample indices (datapoints)
x_scale_factor = 1/fs      # datapoints -> seconds

# Call the scale_data function
x_scaled = scale_data (x_unscaled, x_scale_factor)
# With x_scale_factor = 1/fs the scaled values are in seconds, not datapoints.
display(HTML("""Scaled X-axis data saved to the <strong>x_scaled</strong> 
             variable as <strong> seconds </strong>."""))

#### User Input: Scale Y Values
- Example: scale raw y values to microvolts:
   - `y_unscaled = y_values`
   - `y_scale_factor = 0.195` (microvolts per raw unit)

In [None]:
y_unscaled = y_values      # raw data, shape (samples, channels)
y_scale_factor = 0.195     # raw units -> microvolts

# Scale every channel at once. This is element-for-element what calling
# scale_data() on each channel column produces, but NumPy does the multiply
# instead of a Python loop over every sample of every channel (which took
# minutes for 128 channels x 30 s at 30 kHz).
# NOTE(review): assumes the raw data are integers; float32 raw data could come
# out float32 here rather than float64 -- confirm the dtype of channel_data.
y_scaled = np.asarray(y_unscaled) * y_scale_factor
display(HTML("Scaled y-axis data saved to the <strong>y_scaled</strong> variable."))

## Bandpass Filter the Data

#### Function Definition: Butterworth Bandpass Filtering

In [None]:
def butter_bandpass(lowcut, highcut, fs, order):
    """Design a digital Butterworth band-pass filter as second-order sections.

    `lowcut` and `highcut` are the pass-band edges in Hz. They are normalised
    by the Nyquist frequency (fs / 2) here because `butter` is called without
    its own `fs` argument. Returns the `sos` array produced by
    scipy.signal.butter.
    """
    nyquist = fs / 2
    band = [edge / nyquist for edge in (lowcut, highcut)]
    return butter(order, band, analog=False, btype='band', output='sos')

def butter_bandpass_filter(data, lowcut, highcut, fs, order):
    """Band-pass filter a 1-D signal with a Butterworth filter, forwards and backwards.

    The second-order sections come from butter_bandpass(). They are applied
    with scipy's sosfiltfilt, which runs the filter in both directions so the
    output is not phase-shifted relative to `data`.
    """
    return sosfiltfilt(butter_bandpass(lowcut, highcut, fs, order), data)

def get_filtered (ys, lowcut, highcut, order, fs, df):
    """Band-pass filter every channel of `ys` and store the results in `df`.

    Parameters
    ----------
    ys : ndarray, shape (samples, channels)
        Each column is one channel.
    lowcut, highcut : float
        Pass-band edges in Hz.
    order : int
        Butterworth order passed to butter_bandpass_filter().
    fs : float
        Sample rate of `ys` in Hz.
    df : pandas.DataFrame
        Destination. Column ``chan`` receives the filtered channel ``chan``.

    Returns
    -------
    pandas.DataFrame
        The same `df`, now holding one filtered column per channel.
    """
    display (HTML("<h4>Analyzing channel: </h4>"))
    for chan in range(ys.shape[1]):
        print(chan, end = ' ')  # progress indicator
        # Write into the frame we were given (not a global) and honour `order`.
        df[chan] = butter_bandpass_filter(ys[:, chan], lowcut, highcut, fs, order=order)
    return df

#### User Input: Filtering (scaled data)

In [None]:
y_unfiltered = y_scaled # or y_values if you didn't need the scaling step; shape (samples, channels)
fs = 30000              # Sample Rate, if using downsampled data remember to change this to 1000
lowcut = 450            # Hz, lower edge of the bandpass filter
highcut = 5000          # Hz, upper edge of the bandpass filter
order = 6               # Butterworth order for the bandpass filter

# get_filtered() fills this DataFrame with one column per channel and returns it.
# NOTE(review): keep the global name `filtered` -- get_filtered() may write to the
# global `filtered` directly rather than through its `df` argument; check its body
# before renaming this variable.
filtered = pd.DataFrame()
filtered = get_filtered(y_unfiltered, lowcut, highcut, order, fs, filtered)
# One DataFrame column per channel, so this array is (samples, channels).
y_filtered = np.array(filtered)
display(HTML("Filtered y-axis data saved to the <strong>y_filtered</strong> variable."))

## Downsample the Data

#### Define the downsampling() function

In [None]:
def downsampling(xs,ys,fs, new_fs):
    """Downsample by keeping every (fs / new_fs)-th sample.

    Parameters
    ----------
    xs : array-like
        Sample index (datapoints) of each row of `ys`.
    ys : array-like, shape (samples, channels)
        Data to downsample along the first axis.
    fs, new_fs : float
        Current and target sample rates in Hz. fs / new_fs must be a whole number.

    Returns
    -------
    x_downsampled : list of float
        Times in seconds of the kept samples (kept sample index / fs).
    y_downsampled : ndarray
        Every step-th row of `ys`, where step = fs / new_fs.

    Raises
    ------
    ValueError
        If fs / new_fs is not a positive whole number.

    NOTE: this is plain decimation with no anti-aliasing filter. Content above
    new_fs / 2 folds back into the result, so low-pass the data first (or use
    scipy.signal.decimate) when that matters.
    """
    ratio = fs / new_fs
    step = int(round(ratio))
    if step < 1 or not np.isclose(ratio, step):
        raise ValueError("fs / new_fs must be a whole number, got " + str(ratio))
    # Slice x and y with the same step so they stay aligned sample for sample.
    x_downsampled = np.asarray(xs)[::step] / fs
    y_downsampled = np.asarray(ys)[::step]
    return x_downsampled.tolist(), y_downsampled

#### User input: Downsampling

In [None]:
x_fs = x_values     # sample indices (datapoints), as loaded by get_single_h5()
y_fs = y_filtered   # or y_values or y_scaled, etc..
fs = 30000          # current sample rate
new_fs = 1000       # the desired sample rate

# Call the downsampling() function and turn both outputs into numpy arrays
x_downsampled, y_downsampled = [np.array(out) for out in downsampling(x_fs, y_fs, fs, new_fs)]
display(HTML("""Downsampled x-axis data saved to the 
            <strong>x_downsampled</strong> variable as <strong>seconds</strong>"""))
display(HTML("Downsampled y-axis data saved to the <strong>y_downsampled</strong> variable."))

## Lowpass Filter the Data

#### Function Definition: Lowpass Filter (butter_lowpass_filter() and get_lowpass())

In [None]:
def butter_lowpass_filter(data, cutoff, fs, order, nyq):
    """Low-pass filter a 1-D signal with a Butterworth filter, forwards and backwards.

    `cutoff` (Hz) is normalised by the caller-supplied Nyquist frequency
    `nyq`. `fs` is accepted for call compatibility but not used here. The
    (b, a) coefficients from scipy's butter are applied with filtfilt, so the
    output is not phase-shifted relative to `data`.
    """
    numerator, denominator = butter(order, cutoff / nyq, btype='low', analog=False)
    return filtfilt(numerator, denominator, data)

def get_lowpass (ys, highcut, order, fs):
    """Low-pass filter every channel (column) of `ys`.

    Parameters
    ----------
    ys : ndarray, shape (samples, channels)
        Each column is one channel.
    highcut : float
        Low-pass cutoff in Hz.
    order : int
        Butterworth order.
    fs : float
        Sample rate of `ys` in Hz.

    Returns
    -------
    ndarray, shape (channels, samples)
        One row per channel. This is the transpose of the input layout;
        transpose it (``.T``) to get back to (samples, channels).
    """
    nyq = fs / 2  # Nyquist frequency; the same for every channel
    filtered = []
    display (HTML("<h4>Analyzing channel: </h4>"))
    for chan in range(ys.shape[1]):
        print(chan, end = ' ')  # progress indicator
        filtered.append(butter_lowpass_filter(ys[:, chan], highcut, fs, order, nyq))
    return np.array(filtered)

#### User Input: Filtering (downsampled data)

In [None]:
y_lowpass = y_downsampled  # or y_values / y_scaled / y_filtered; shape (samples, channels)
fs = 1000                  # Sample rate of y_lowpass (1000 after downsampling)
highcut = 15               # Hz, low-pass cutoff
order = 2                  # Butterworth filter order

# get_lowpass() returns one row per channel, i.e. (channels, samples).
# Transpose back to (samples, channels) so this array is laid out like
# y_values / y_filtered / y_downsampled, which line_plot() and grid_plot()
# index as [samples, channel].
ds_filtered = get_lowpass(y_lowpass, highcut, order, fs)
y_downsampled_filtered = np.array(ds_filtered).T
display(HTML("""Downsampled and filtered y-axis data 
            saved to the <strong>y_downsampled_filtered</strong> variable."""))

## Make a Line Plot

#### Define line_plot() function

In [None]:
def line_plot (x, y, x_label, y_label, start, end):
    """Plot `y` against `x` for the window [`start`, `end`) given in seconds.

    NOTE(review): seconds are turned into array indices using the module-level
    `fs` variable (not a parameter), so `fs` must equal the sample rate of
    `x`/`y` when this is called.
    """
    window = slice(int(start * fs), int(end * fs))
    _, axis = plt.subplots(figsize=(20, 5))
    axis.plot(x[window], y[window])
    for set_label, text in ((axis.set_xlabel, x_label), (axis.set_ylabel, y_label)):
        set_label(text, fontsize=16)
    axis.tick_params(labelsize=14)

#### User Input: Plot Line Function

In [None]:
x_line = x_downsampled  # or x_values if you didn't scale
y_line = y_downsampled  # or y_scaled or y_values as you prefer
channel = 0
start_plot, end_plot = 0, 10  # plotting window, in seconds
x_label = 'Time (s)'  # units of your x data
y_label = 'Microvolts'

# Call the line_plot function on one channel's column
line_plot(x_line, y_line[:, channel], x_label=x_label, y_label=y_label,
          start=start_plot, end=end_plot)

## Plot Data as Grid

#### Function Definition: grid_plot()

In [None]:
def grid_plot(y_grid,columns, start, end, fs):
    """Plot each channel in its own panel, arranged like the probe shank.

    Parameters
    ----------
    y_grid : array-like, shape (samples, channels)
        Data to plot, indexed as ``y_grid[sample, channel]``.
    columns : list of lists of int
        One sub-list of channel numbers per shank column. Each sub-list becomes
        one column of panels (left to right). Shorter sub-lists are centred
        vertically; e.g. the 20-channel middle column of a 22/20/22 layout is
        shifted down by one panel.
    start, end : float
        Plot window in seconds.
    fs : float
        Sample rate of `y_grid` in Hz; used to turn seconds into rows.

    The figure is saved with a name built from the module-level `datafile`
    and `chan_num` variables.
    """
    # Plain lists: the sub-lists can differ in length, which np.array()
    # rejects in NumPy >= 1.24.
    columns = [list(col) for col in columns]
    first = int(start * fs)
    last = int(end * fs)
    segment = np.asarray(y_grid[first:last])
    x = np.arange(segment.shape[0]) / fs  # seconds from the start of the window

    n_rows = max(len(col) for col in columns)
    n_cols = len(columns)
    display(HTML('<hr><h4>Plotting channel: </h4>'))
    # squeeze=False keeps `ax` 2-D even for a single row or column.
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(15, 40), sharex=True, sharey=True,
                           squeeze=False)
    for j, column in enumerate(columns):
        offset = (n_rows - len(column)) // 2  # centre shorter shank columns
        for i, col in enumerate(column, start=offset):
            print(col, end = ' ')  # progress indicator
            panel = ax[i][j]
            panel.plot(x, segment[:, col], color='dimgray', label=str(col))
            panel.legend(loc='upper right', fontsize=8, shadow=False)
            panel.tick_params(labelsize=12)
    plt.tight_layout()
    fig.text(0.0, 0.5, r'Amplitude ($\mu$V)' + '\n', ha='center', rotation='vertical', fontsize = 18)
    fig.text(0.5, 0.0, '\n Time (s)', va='center',  fontsize = 18)
    # NOTE(review): the file name always says "_unfiltered" even when filtered
    # data is plotted; kept as-is so existing output names do not change.
    plt.savefig (datafile.replace('.h5','_unfiltered' + str(chan_num[1]) + '.png'))
    display(HTML('<hr>'))



#### User Input: Grid plot

In [None]:
y_grid = y_downsampled_filtered  # can be y_values, y_scaled, y_filtered, etc... as needed
start_grid, end_grid = 0, 5  # plotting window, in seconds
fs = 1000  # sample rate of y_grid; change if not downsampled

# `columns` describes how channels sit on the probe shank: one sub-list per
# shank column, plotted left to right. Within a sub-list, channels listed
# first are closer to the tip of the shank.
columns = [
    list(range(0, 22)),   # shank column 1: channels 0-21
    list(range(22, 42)),  # shank column 2: channels 22-41 (two fewer sites)
    list(range(42, 64)),  # shank column 3: channels 42-63
]

# call the grid_plot() function
grid_plot(y_grid, columns, start_grid, end_grid, fs)

## Save Traces to a Nex CSV

In [None]:
def nex_traces (ys, fs, save_name):
    """Write a (samples, channels) array to CSV with a leading time column.

    Columns are ``time_seconds`` (row index / fs, starting at 0) followed by
    ``Ch0``, ``Ch1``, ... one per channel. The CSV is written without the
    DataFrame index, and the first rows are shown as an HTML table.
    """
    ys = np.asarray(ys)
    # Build the frame in one go instead of one column at a time (avoids
    # pandas fragmentation warnings for many channels).
    df = pd.DataFrame(ys, columns=['Ch' + str(col) for col in range(ys.shape[1])])
    df.insert(0, 'time_seconds', np.arange(ys.shape[0]) / fs)
    df.to_csv(save_name, index = False)
    display(HTML(df.head().to_html()))  # display() returns None, so don't print() it

In [None]:
ys, fs = y_filtered, 30000  # data to export and its sample rate (Hz)
save_name = datafile.replace('.h5', '_traces_nex.csv')  # Can change this if needed

nex_traces(ys, fs, save_name)

## Concatenate CSVs 

In [None]:
def csv_concatenate(csv_files, common_column,save_file):
    """Concatenate channel CSVs side by side, renumbering channels consecutively.

    Every column other than `common_column` must be named ``Ch<n>``. Channels
    from each file are renumbered to follow on from the files before it
    (file 1's Ch0 becomes Ch<k>, where k is the number of channels already
    taken). If the last file has `common_column` (e.g. ``time_seconds``), it is
    appended as the final column. Pass ``''`` for no common column.

    Returns
    -------
    pandas.DataFrame
        The concatenated data, which is also written to `save_file`
        without an index.
    """
    out = pd.DataFrame()
    col_count = 0
    df = None
    for csv_file in csv_files:
        df = pd.read_csv(csv_file)
        data_cols = df.columns[df.columns != common_column]
        for name in data_cols:
            out['Ch' + str(int(name.split('Ch')[1]) + col_count)] = df[name]
        col_count += len(data_cols)
    if df is not None and common_column in df.columns:
        out[common_column] = df[common_column]
    out.to_csv(save_file, index = False)
    return out

In [None]:
# list the filenames inside the brackets, separated by commas, in the order to concatenate
csv_files = [
    'experiment_C20200330-174152_filtered_192_events_nex.csv',
    'experiment_C20200330-174152_filtered_192_events_nex.csv',
]
common_column = 'time_seconds'  # shared column (ie time_seconds), otherwise leave as empty quotation marks
save_file = 'concatenated.csv'

# Call the function and convert its result to a numpy array for speed/compatibility
csv_data = np.array(csv_concatenate(csv_files, common_column, save_file))

# Check the conversion by reloading the saved file
df = pd.read_csv(save_file)
display(HTML(df.head().to_html()))


## H5 Concatenate

#### Define Function: h5_concatenate_columns() and h5_concatenate_rows()

In [None]:
def h5_concatenate_columns(datafiles, chan_range, save_name):
    """Join HDF5 recordings side by side (each file's channels appended as new columns).

    Reads ``channel_data[:, chan_range[0]:chan_range[1]]`` from each file in
    the order given and stacks them horizontally. Writes ``channel_data`` and
    a ``sample_index`` of 0..n_samples-1 to `save_name`, with ``.h5``
    replaced by ``_cols.h5``.

    Raises
    ------
    ValueError
        If `datafiles` is empty or the files have different numbers of samples.
    """
    if not datafiles:
        raise ValueError("h5_concatenate_columns: no files to concatenate")
    blocks = []
    n_columns = 0
    for i, f in enumerate(datafiles):
        # Read the requested channels and close the source file straight away.
        with h5py.File(f, 'r') as f_data:
            block = f_data['channel_data'][:, chan_range[0]:chan_range[1]]
        if blocks and block.shape[0] != blocks[0].shape[0]:
            raise ValueError("h5_concatenate_columns: " + str(f) + " has " +
                             str(block.shape[0]) + " samples, expected " +
                             str(blocks[0].shape[0]))
        blocks.append(block)
        n_columns += block.shape[1]
        print("Columns after file " + str(i) + ": " + str(n_columns))
    data = np.hstack(blocks)
    with h5py.File(save_name.replace('.h5', '_cols.h5'), 'w') as hf:
        # One sample index per row (sample), not per column.
        hf.create_dataset('sample_index', data = np.arange(data.shape[0]))
        hf.create_dataset('channel_data', data = data)
    
def h5_concatenate_rows(datafiles, chan_range, save_name):
    """Join HDF5 recordings end to end (each file's samples appended as new rows).

    Reads ``channel_data[:, chan_range[0]:chan_range[1]]`` from each file in
    the order given and stacks them vertically. Writes ``channel_data`` and a
    ``sample_index`` of 0..n_rows-1 to `save_name`, with ``.h5`` replaced by
    ``_rows.h5``.

    Raises
    ------
    ValueError
        If `datafiles` is empty or the files yield different channel counts.
    """
    if not datafiles:
        raise ValueError("h5_concatenate_rows: no files to concatenate")
    blocks = []
    n_rows = 0
    for i, f in enumerate(datafiles):
        # Read the requested channels and close the source file straight away.
        with h5py.File(f, 'r') as f_data:
            block = f_data['channel_data'][:, chan_range[0]:chan_range[1]]
        if blocks and block.shape[1] != blocks[0].shape[1]:
            raise ValueError("h5_concatenate_rows: " + str(f) + " has " +
                             str(block.shape[1]) + " channels, expected " +
                             str(blocks[0].shape[1]))
        blocks.append(block)
        n_rows += block.shape[0]
        print('Rows after file '+ str(i) +': ' + str(n_rows))
    # np.vstack replaces DataFrame.append, which was removed in pandas 2.0.
    data = np.vstack(blocks)
    with h5py.File(save_name.replace('.h5', '_rows.h5'), 'w') as hf:
        # One sample index per row (sample), not per column.
        hf.create_dataset('sample_index', data = np.arange(data.shape[0]))
        hf.create_dataset('channel_data', data = data)
    
def h5_concatenate (datafiles, chan_range, save_name, rows_or_columns):
    """Concatenate HDF5 recordings by 'rows' (end to end) or 'columns' (side by side).

    Dispatches to h5_concatenate_rows() or h5_concatenate_columns().

    Raises
    ------
    ValueError
        If `rows_or_columns` is neither 'rows' nor 'columns'.
    """
    if rows_or_columns == 'rows':
        h5_concatenate_rows (datafiles, chan_range, save_name)
    elif rows_or_columns == 'columns':
        h5_concatenate_columns(datafiles, chan_range, save_name)
    else:
        # Fail loudly instead of silently writing nothing.
        raise ValueError("rows_or_columns must be 'rows' or 'columns', got " +
                         repr(rows_or_columns))

#### User Input: h5_concatenate
- Be sure to list the files in the order in which you want them concatenated

In [None]:
# Files are concatenated in the order listed here
datafiles = [
    'experiment_C20200330-174152_chunk0.h5',
    'experiment_C20200330-174152_chunk1.h5',
]
chan_range = [0, 64]
save_name = 'experiment_C20200330-174152_concat.h5'  # output gets a _cols / _rows suffix
rows_or_columns = 'columns'  # Choose whether to concatenate by 'rows' or 'columns'

# Call the h5_concatenate function
h5_concatenate(datafiles, chan_range, save_name, rows_or_columns=rows_or_columns)

## H5 Splitter

#### Function Definition: h5_splitter

In [None]:
def h5_saver (save_name, x_values, y_values):
    """Write `x_values` as ``sample_index`` and `y_values` as ``channel_data`` to a new HDF5 file.

    The file is opened in mode 'w', so an existing `save_name` is overwritten.
    The context manager closes the file even if writing a dataset fails.
    """
    with h5py.File(save_name, 'w') as hf:
        hf.create_dataset('sample_index', data = x_values)
        hf.create_dataset('channel_data', data = y_values)

def h5_splitter (datafile, chunk_length, chunk_num, chan_range, fs):
    """Split a recording into `chunk_num` consecutive chunks of `chunk_length` seconds.

    Chunk k covers samples [k*L, (k+1)*L), where L = round(chunk_length * fs),
    and channels chan_range[0] .. chan_range[1]-1. Each chunk is written with
    h5_saver() to `datafile` with ``.h5`` replaced by ``_chunk<k>.h5``. Chunks
    that run past the end of the recording come out shorter (or empty), as a
    plain slice would.
    """
    # Round rather than truncate: int(0.57 * 100) is 56, not 57.
    samples_per_chunk = int(round(chunk_length * fs))
    channels = slice(chan_range[0], chan_range[1])  # channels for the current shank
    # Keep the source open only while chunks are being copied out.
    with h5py.File(datafile, 'r') as f_data:
        x_values = f_data['sample_index']  # time in datapoints (x)
        y_values = f_data['channel_data']  # raw data, must be scaled to microvolts (y)
        for chunk in range(chunk_num):
            start = chunk * samples_per_chunk
            stop = start + samples_per_chunk
            h5_saver(datafile.replace('.h5', '_chunk' + str(chunk) + '.h5'),
                     x_values[start:stop], y_values[start:stop, channels])

#### User Input: h5_splitter

In [None]:
datafile = 'experiment_C20200330-174152.h5'
chunk_length = 0.2    # length of each chunk, in seconds
chunk_num = 10        # number of chunks to write
chan_range = [0, 64]  # start and end of channel range (end not included)
fs = 30000            # sample rate (Hz)

h5_splitter(datafile, chunk_length=chunk_length, chunk_num=chunk_num,
            chan_range=chan_range, fs=fs)