# Crimson Stimulation Responses

Looking through the BrainPatch stimulation responses, in particular those LFP/dendritic spikes. Need to figure out what they are...

From the "artifact_exploration" stuff I was doing, it looks like I'll mostly need to look at clips around the stim. Can probably HPF at about 70 hz and keep the interesting stuff

In [None]:
%load_ext autoreload
%autoreload

from matplotlib import pyplot as plt
from matplotlib.patches import Polygon
import numpy as np
import pandas as pd
from open_ephys.analysis import Session
from scipy import signal
# from scipy.fft import fft, fftfreq
# from scipy.signal.windows import gaussian
from sklearn.decomposition import PCA
import os, glob, re
import openephys_utils 


# %matplotlib ipympl
%matplotlib qt

# pdfPages for saving images to a multi-page PDF
from matplotlib.backends.backend_pdf import PdfPages

## Define a few functions for later usage

First one just opens an open ephys directory and returns the signals, timestamps, and events

In [None]:
def open_sig_events(directory:str):
    # open up a session, then pull out the signal and events
    session = Session(directory)

    recording = session.recordnodes[0].recordings[0].continuous[0]

    # get out the signal
    sig = recording.samples[:,:64] * recording.metadata['bit_volts'][0]

    # pull out the events -- both giving the time and the indices
    events = np.argwhere(np.diff(recording.samples[:,64]>5000) == 1)
    events = events.reshape([int(events.shape[0]/2),2])
    event_ts = events/recording.metadata['sample_rate']

    # timestamps
    timestamps = recording.sample_numbers - recording.sample_numbers[0]
    timestamps = timestamps/recording.metadata['sample_rate']

    return sig, timestamps, events, event_ts


Function to find the minimum of a clipped period after the stimulation

In [None]:
def find_responses(sig, events, len_ms:int = 25, n_chans:int = 64, sample_rate:int = 30000):

    t_len = len_ms*30 # 52 ms * 30 khz

    # set up the events to plot patches
    n_events = events.shape[0] # number of stimulation events

    responses = np.zeros((n_events, t_len, n_chans))
    # maxs = np.zeros((n_events, n_chans)) # not getting much info from these
    # rel_maxs = np.zeros((n_events, n_chans))
    # abs_maxs = np.zeros((n_events, n_chans))
    mins = np.zeros((n_events, n_chans))
    rel_mins = np.zeros((n_events, n_chans))
    abs_mins = np.zeros((n_events, n_chans))

    for i_event, event in enumerate(events):
        response = sig[event[0]:event[0]+len_ms*int(sample_rate/1000),:]
        means = np.mean(sig[event[0]+4:event[1]-4,:], axis = 0)
        responses[i_event,:,:] = response - means # response for each channel
    
        mins[i_event,:] = np.min(response - means, axis=0)
        rel_mins[i_event,:] = np.argmin(response - means, axis=0)/30000
        abs_mins[i_event,:] = rel_mins[i_event,:] + event_ts[i_event,0]

    return mins, rel_mins, abs_mins

Plot the average post-stim responses for a particular channel. Will plot it into an existing axis if provided one

In [None]:
def plot_avg_response(sig, events, len_ms:int = 25, channel = 0, ax:plt.axes=None, label:str=None):
    # Plot the average response for a particular channel
    
    if ax is None:
        fig,ax = plt.subplots()
    
    if label == None:
        label = f'Channel {channel}'


    # set up the events to plot patches
    n_events = events.shape[0] # number of stimulation events

    # put together a NxT array
    t_len = len_ms * 30
    responses = np.zeros((n_events, t_len))
    
    # go through each event
    for i_event, event in enumerate(events):
        response = sig[event[0]:event[0]+len_ms*30,channel]
        means = np.mean(sig[event[0]+4:event[1]-4]) # center the during-stimulation to 0
        responses[i_event,:] = response - means
    
    # put together the means, STDs, and timestamps
    ts = np.arange(t_len)/30
    means = np.mean(responses, axis=0)
    line = ax.plot(ts, means, label=label)
    
    ts_std = np.ravel(np.array([ts, ts[::-1]]))
    std = np.ravel(np.array([means + np.std(responses, axis=0), means[::-1] - np.std(responses, axis=0)[::-1]]))
    patch_array = np.array([ts_std, std]).T
    std_patch = Polygon(patch_array, alpha=0.2, color=line[-1].get_color())
    ax.add_patch(std_patch)

    # print(dir(line))

Find spikes using the old-school default:

1. Filter and de-mean
1. Calculate a threshold: -4.5x the STD
1. Flag issues:
    1. Too-short ISIs (< 3 ms?)
    1. Simulataneous-ish (on more than N channels)
    1. Something about the wave shape -- deviations? Depth of field?
1. 


In [None]:
def find_spikes(sig, filter_high:float = 6000, filter_low:float = 150, sample_rate:int = 30000, CAR:bool = True):
    # number of channels
    n_chan = np.min(sig.shape)
    axis = np.argmax(sig.shape)
    sig = sig.T if axis == 1 else sig # make sure that Time is along axis 0

    # # CAR
    # if CAR:
    #     pca = PCA()
    #     xform = PCA.fit_transform(sig)
    #     # sig = np.matmul(xform, PCA.)


    # filter the thing
    sos = signal.butter(N=8, Wn=[filter_low, filter_high], fs=sample_rate, output='sos', btype='bandpass')
    filt_sig = signal.sosfiltfilt(sos=sos, x=sig, axis=0)

    # find a threshold for each channel
    thresholds = np.expand_dims(-4.5 * np.std(filt_sig, axis=0), axis=0)


    # find the crossings
    thresholds_rep = np.tile(thresholds, (sig.shape[0], 1)) # create a TxN array of N threshold values
    crossings = np.argwhere(np.diff((filt_sig<thresholds_rep).astype(int), axis=0) < 0)

    # Create a dataframe for the spikes, and also store chunks of 50 ms of data
    spike_df = pd.DataFrame({'sample_no':crossings[:,0].astype(int), 'electrode':crossings[:,1].astype(int)})
    sample_columns = [f'sample {i - 10}' for i in range(50)]
    spike_df.loc[:,sample_columns] = np.nan

    for i_row,row in spike_df.iterrows():
        spike_df.loc[i_row,sample_columns] = filt_sig[row['sample_no']-10:row['sample_no']+40,row['electrode']]

    # return spike_df, filt_sig
    return spike_df






## Single file analysis

Mostly to check the functioning of the code when I'm batch processing files

pull in the data -- we'll start with one file at a time

In [None]:
directory = 'Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um'

# load signals if we haven't already loaded it
if 'sig' not in locals():
    sig, timestamps, events, event_ts = open_sig_events(directory)

# pull out the spikes
if 'spike_df' not in locals():
    spike_df,filt_sig = find_spikes(sig)

# choose the channels to show
channels = np.arange(40,50)

# plot the continuous and show the times
fig_cont, ax_cont = plt.subplots(nrows = len(channels), sharex=True)

for i_channel, channel in enumerate(channels):
    ax_cont[i_channel].plot(timestamps, filt_sig[:,channel])
    for i_spike, spike in spike_df.loc[spike_df['electrode'] == channel].iterrows():
        ax_cont[i_channel].axvspan((int(spike['sample_no'])-10)/30000, (int(spike['sample_no'])+40)/30000, color = 'cyan')
    


In [None]:
def calc_spike_counts(spike_dict:dict, max_ts:int = None, min_ts:int = None, fs:int = 30000, bin_ms:float = .005):
    # calculate the binned firing rates and return, along with bin timestamps

    if (max_ts is None) or (min_ts is None):
        print('Calculating bin range start and finish is a bummer! . Next time give me some info!')
        max_ts,min_ts = 0,0
        for chan_xings in spike_dict.values():
            max_ts = int(max(max_ts, chan_xings['spike_ts'].max()))
            min_ts = int(min(min_ts, chan_xings['spike_ts'].min()))
        

    bins = np.arange(start=min_ts, step = int(fs*bin_ms), stop=max_ts+1) # put together the bins
    spike_counts = np.empty((len(bins),len(spike_dict.keys()))) # put together a pre-allocated array
    for channel, data in spike_dict.items(): # loop through the dict
        spike_counts[:,channel] = np.histogram(data['sample_no'], bins) # bin it


    return spike_counts, bins    

         

In [None]:
print('h')

max_ts= int(sig.shape[0]) + 1
min_ts= 0
fs = 30000
bin_ms = .005


bins = np.arange(start=min_ts, step = int(fs*bin_ms), stop=max_ts+1) # put together the bins
spike_counts = np.empty((len(bins)-1,len(spikes.keys()))) # put together a pre-allocated array
for channel, data in spikes.items(): # loop through the dict
    spike_counts[:,channel],_ = np.histogram(data['sample_no'], bins) # bin it




In [None]:
# import openephys_utils
# fig_pca, ax_pca = plt.subplots(nrows=4, sharex=True)
directory = 'Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um'

if not all([var in locals() for var in ['sig', 'timestamps', 'events', 'events_ts']]):
    sig, timestamps, events, event_ts = openephys_utils.open_sig_stims(directory)

sig_eraasr = openephys_utils.ERAASR(sig)
sig_mine = openephys_utils.ERAASR(sig, mode='mine')

# spikes = openephys_utils.threshold_crossings(sig_eraasr)
spike_dict = openephys_utils.threshold_crossings(sig_eraasr, multi_rejection=None)
fr,fr_ts = openephys_utils.calc_FR(spike_dict, max_samp = sig.shape[0], min_samp=0, bin_ms = 0.001)


In [None]:
def avg_stim_FR(firing_rate:np.array, stims:np.array, bin_sec:float = .001, normalize:bool = True, resp_length_samp:int = 12000, fs:int=30000):
    '''
    the average firing rate for a recording (per channel) for each recording setup
    
    inputs:
        firing_rate: np.array               - array of the firing rates
        stims:np.array                      - stimulation start and stop times (Nx2) 
        normalize:bool                      - Should we normalize so the average (pre-stim) firing rate is 1?
        fs:int                              - sample rate (Hz) [30000]
        bin_sec:float                         - bin window length (s) [.001]
        bin_ms:int
        resp_length_samp:int                - length of interest for response (samples) [15000]

        
    outputs:
        avg_fr:np.array            - 
        std_fr:np.array
    
    '''
    # convert bin_win (seconds) into a sample value
    bin_samp = int(fs*bin_win)

    # normalize off the average firing rates before the first stim:
    if normalize:
        pre_stim = firing_rate[:stims[0,0], :] # pull out the firing rates before the first stimulation
        firing_rate = np.matmul(firing_rate, np.linalg.pinv(pre_stim.mean(axis=0)*np.eye(firing_rate.shape[1]))) 

    # pull out the stimulation segments, then reshape
    stim_resp = firing_rate[(np.concatenate([np.arange(start=int(row[0]/bin_samp), stop=int((row[0]+resp_length_samp)/bin_samp)) for row in stims])).astype(int), :] 
    stim_resp = stim_resp.reshape((stims.shape[0],int(resp_length_samp/bin_samp),64)).transpose((1,0,2))
    
    # then pull out the mean for each channel and the std
    avg_fr = stim_resp.mean(axis=1) # mean across trials
    std_fr = stim_resp.std(axis=1) # standard deviation across trials

    return avg_fr, std_fr

In [None]:
def plot_mean_FR(mean_FR:np.array, stims:np.array, std_FR:np.array = None, channel:int = None, ax:plt.axes = None):
    '''
    Plot the mean firing rates for a specific channel. 

    inputs:
        mean_FR : np.array
        stims : np.array
        std_FR : np.array
        channel : int
        ax : matplotlib axes

    outputs:
    '''
    if ax is None:
        fig, ax = plt.subplots()

    
    FR_line = ax.plot(mean_FR[:,channel])
    if std_FR is not None:
        ts = np.arange(mean_FR.shape[0])
        ts_std = np.ravel(np.array([ts, ts[::-1]]))
        std_patch = np.concatenate([mean_FR[:,channel] + std_FR[:,channel], np.maximum(mean_FR[:,channel][::-1] - std_FR[:,channel][::-1],0)])
        std_patch = Polygon(np.array(list(zip(ts_std,std_patch))), color=FR_line[0].get_color(), alpha=0.2)
        ax.add_patch(std_patch)


In [None]:
# mean_fr,std_fr = avg_stim_FR(fr, events, normalize=False)


fig,ax = plt.subplots()
plot_mean_FR(mean_fr, events, std_fr, 51, ax)

In [None]:
resp_length = 12000
channel=51
fs = 30000
bin_win = .001
bin_samp = int(fs*bin_win)
resp_samp = int(resp_length/bin_samp)

stim_resp = fr[(np.concatenate([np.arange(start=int(row[0]/bin_samp), stop=int((row[0]+resp_length)/bin_samp)) for row in events])).astype(int),:]
stim_resp = stim_resp.reshape((events.shape[0],resp_samp,64)).transpose((1,0,2))

fig,ax = plt.subplots(nrows=3, sharex=True, sharey=True)
for i_event in np.arange(events.shape[0]):
    ax[0].plot(np.arange(resp_samp)*bin_samp, stim_resp[:,i_event,channel])
ax[1].plot(np.arange(resp_samp)*bin_samp, stim_resp[:,:,channel].mean(axis=1))
ax[1].plot(np.arange(resp_samp)*bin_samp, stim_resp[:,:,channel].mean(axis=1) + stim_resp[:,:,channel].std(axis=1))
ax[1].plot(np.arange(resp_samp)*bin_samp, stim_resp[:,:,channel].mean(axis=1) - stim_resp[:,:,channel].std(axis=1))
openephys_utils.plot_PSTH(spike_dict, events, channel, ax=ax[2])


In [None]:
pre_stim = fr[:int(events[0,0]/bin_samp),:]
avg_prestim = np.linalg.pinv(pre_stim.mean(axis=0)*np.eye(64))
fr_demean = np.matmul(fr, avg_prestim)

In [None]:
fig,ax = plt.subplots()

line1 = ax.plot(fr_demean[:,51])
line2 = ax.plot(fr[:,51])

In [None]:
zip(np.arange(fr.shape[0]),fr[:,51])

In [None]:
line1[0].get_color()

In [None]:
# fr.shape
[np.arange(start=int(row[0]/bin_samp), stop=int((row[0]+resp_length)/bin_samp)) for row in events]

In [None]:
bin_width = 30000 * .01
bins = np.arange(start = 0, step = bin_width, stop = sig.shape[0] + 150)
FR = np.histogram(spike_dict[51]['sample_no'], bins)

In [None]:
spike_dict[51]['sample_no'].shape

In [None]:
fig,ax = plt.subplots()
fr,fr_ts = openephys_utils.calc_FR(spike_dict, max_samp = sig.shape[0], min_samp=0, bin_ms = 0.001)

# ax.plot(FR[1][:-1],FR[0])
ax.plot(fr_ts,fr[:,51]*50)
ax.scatter(spike_dict[51]['sample_no'], np.ones((spike_dict[51]['sample_no'].shape[0], 1)), color='k')

In [None]:
np.diff(events, axis=0)

In [None]:
# import openephys_utils
spike_dict = openephys_utils.threshold_crossings(sig_eraasr, multi_rejection=None)
fr2,fr2_ts = openephys_utils.calc_FR(spike_dict, max_samp = sig.shape[0], min_samp=0)

In [None]:
def plot_spike_binary(spike_dict:dict, ax = None, stims:np.array = None, fs = 30000):
    '''
    Plot an on/off of channels over time. show the stimulations if given
    '''

    # create an axis if not given
    if ax is None:
        fig,ax = plt.subplots()

    # for each channel, plot the spike times
    for channel, data in spike_dict.items():
        ax.vlines(data['sample_no']/fs, channel, channel+1)
    
    # plot the stimulation times if given
    if stims is not None:
        for i_stim in range(stims.shape[0]):
            patch_array = np.array([[stims[i_stim,0]/fs,-1],
                                        [stims[i_stim,1]/fs,-1],
                                        [stims[i_stim,1]/fs,len(spike_dict.keys())],
                                        [stims[i_stim,0]/fs,len(spike_dict.keys())]])
            stim_patch = Polygon(patch_array, alpha=0.2, color='k')
            ax.add_patch(stim_patch)

In [None]:
def plot_PSTH(spike_dict:dict, stims:np.array, channel:int = 0, ax=None, fs=30000):
    '''
    Plot the PSTH for a single channel
    
    inputs:
        spike_dict:dict     - firing rates
        events:np.array     - event times
        channel:int         - which channel are we working with?
    '''

    # create an axis if not given
    if ax == None:
        fig,ax = plt.subplots()

    # split the spikes for the channel into a series of new channels
    spikes = spike_dict[channel]['sample_no']
    for i_stim in range(stims.shape[0]-1):
        spike_subset = spikes[np.logical_and(spikes>=stims[i_stim,0], spikes<stims[i_stim+1,0])] - stims[i_stim,0]
        ax.vlines(spike_subset/fs,i_stim,i_stim+1)
        # create a patch at the stimulus point
        patch_array = np.array([[0, i_stim], [stims[i_stim,1]/fs-stims[i_stim,0]/fs,i_stim],
                                [stims[i_stim,1]/fs-stims[i_stim,0]/fs,i_stim+1], [0, i_stim+1]])
        stim_patch = Polygon(patch_array, alpha=0.4, color='k')
        ax.add_patch(stim_patch)

In [None]:

spike_subset = spikes[51]['sample_no']
for i_stim in range(events.shape[0]-1):
        resh = spike_subset[np.logical_and(spike_subset>=events[i_stim,0], spike_subset<events[i_stim+1,0])] - events[i_stim,0]
        ax.vlines(spike_subset,i_stim,i_stim+1)

In [None]:
def plot_mean_waveforms(spike_dict:dict, channel:int=0, fs:int=30000, ax=None):
    '''
    plot the mean threshold crossing for a channel, and a patch around the standard deviation.

    could theoretically look at splitting into different units

    inputs:
        spike_dict
        channels
        std_flag
        map
    '''



    # create an axis if it doesn't exist
    if ax is None:
        fig,ax = plt.subplots()

    # go into the waveforms
    mean_wf = spike_dict[channel]['waveform'].mean(axis=0)
    std_= spike_dict[channel]['waveform'].std(axis=0)

    # elapsed ts
    ts = [t/(fs/1000) for t in range(mean_wf.shape[0])]

    # create std patch
    std_array = np.ndarray((2*mean_wf.shape[0],2))
    std_array[:mean_wf.shape[0],:] = np.column_stack((ts, mean_wf+std_))
    std_array[mean_wf.shape[0]:,:] = np.column_stack((ts, mean_wf-std_))[::-1,:]
    std_patch = Polygon(std_array, alpha=0.2, color = 'orange')

    # plot em
    ax.plot(ts, mean_wf, color='orange') 
    ax.add_patch(std_patch)
    ax.set_xlabel('time (ms)')
    ax.set_ylabel('magnitude (uV)')

In [None]:
plot_spike_binary(spikes, stims = events)

In [None]:
def multisave_PDF(fig, pdf):
    pdf.savefig(fig)

In [None]:
fig,ax = plt.subplots(nrows=2, sharex=True)

openephys_utils.plot_PSTH(spike_dict=spike_dict, stims=events, ax=ax[0], channel=51)

stim_reps_T = stim_reps.reshape((121,450,64)).transpose([1,0,2])
for i_stim in range(121):
    ax[1].plot(stim_reps_T[:,i_stim,51])

In [None]:
with PdfPages(os.path.join(directory, 'channel_plots_minimum_nomultichannel.pdf')) as pdf:
    fig, ax = plt.subplots(nrows = 2)
    for channel in range(64):
        # plot_PSTH(spikes, events, ax=ax[0], channel=channel)
        # plot_mean_waveforms(spikes, ax=ax[1], channel=channel)
        pdf.attach_note(f'Channel {channel}')
        plot_PSTH(spike_dict, events, ax=ax[0], channel=channel)
        plot_mean_waveforms(spike_dict, ax=ax[1], channel=channel)
        pdf.savefig(fig)
        for sub_ax in ax:
            sub_ax.cla()

In [None]:
high_pass, low_pass = 300, 6000
fs = 30000
thresh_mult = -3.5

sos_bpf = signal.butter(N = 8, Wn = [high_pass, low_pass], btype='bandpass', fs = fs, output='sos')
sig_filt = signal.sosfiltfilt(sos=sos_bpf, x = sig_eraasr, axis=0)

# find the threshold values for each channel
thresholds = np.std(sig_filt, axis=0) * thresh_mult
xings = np.nonzero(np.diff(np.where(sig_filt < thresholds, 1, 0), axis=0) == 1)

# fig,ax = plt.subplots(nrows = len(channels))

# for i_channel,channel in enumerate(channels):
#     xings_channel = xings[0][xings[1] == channel]
#     ax[i_channel].plot(timestamps, sig_filt[:,channel])
#     ax[i_channel].plot(timestamps, sig_eraasr[:,channel])
#     ax[i_channel].hlines(thresholds[channel], timestamps[0], timestamps[-1])

#     ax[i_channel].vlines(timestamps[xings_channel], 1.2*np.min(sig_filt[:,channel]), 1.2*np.max(sig_filt[:,channel]))
    

# need to introduce some basic cross-channel artifact rejection

# split into per-channel dictionary
# bt = int(.0003*fs)
# at = int(.0012 * fs)
# spike_dict = {}
# for i_channel in np.arange(sig.shape[1]):
#     spike_ts = xings[0][xings[1] == i_channel] # sample #
#     spike_wf = [sig_filt[ts-bt:ts+at,i_channel] for ts in spike_ts] # waveform

#     spike_dict[i_channel] = {'sample_no':spike_ts, 'waveform':spike_wf}


In [None]:
channels = [32, 36, 39, 48, 51] # the mapping from Sara just seems to be 1:1, but I'm not sure that's right...
fig_filt, ax_filt = plt.subplots(nrows = len(channels), sharex=True, sharey=True)

for i_channel, channel in enumerate(channels):
    ax_filt[i_channel].plot(timestamps,sig_clean[:,channel])
    ax_filt[i_channel].plot(timestamps,sig_filt[:,channel])
    ax_filt[i_channel].hlines(np.std(sig_filt[:,channel])*-3.5, timestamps[0], timestamps[-1])
    ax_filt[i_channel].set_title(channel)

In [None]:
# directory = 'Z:\\BrainPatch\\20240821\\Crimson__2024-08-21_13-29-59__20mA_MinOil_2ms'
# directory = 'Z:\\BrainPatch\\20240821\\Crimson__2024-08-21_13-46-01__20mA_MinOil_2ms'
# directory = 'Z:\\BrainPatch\\20240821\\Crimson__2024-08-21_15-10-03__20mA_MinOil_2ms'


# directory = 'Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um'
directory = 'Z:\\BrainPatch\\20241002\\Crimson__2024-10-02_12-00-49__spontaneous_waking'

# # load signals if we haven't already loaded it
# if 'sig' not in locals():
#     sig, timestamps, events, event_ts = open_sig_events(directory)
session = Session(directory)
print(session)


for i_rec in range(len(session.recordnodes)):
    print(f'{len(session.recordnodes[i_rec].recordings)} recording(s) in session "{session.recordnodes[i_rec].directory}"\n')
    recordings = session.recordnodes[i_rec].recordings
    
    for i_rec,recording in enumerate(recordings):
        recording.load_continuous()
        recording.load_spikes()
        recording.load_events()
        recording.load_messages()

        print(f'Recording {i_rec} has:')
        print(f'\t{len(recording.continuous)} continuous streams')
        print(f'\t{len(recording.spikes)} spike streams')
        print(f'\t{len(recording.events)} event streams')
    
    print('\n')


Compare the offline filtering to the online filtering, and take a look at the specific channels that I think might have some good stuff

In [None]:
# channel list -- 
channels = [32, 36, 39, 48, 51] # the mapping from Sara just seems to be 1:1, but I'm not sure that's right...

# put together some filters
# sos_h = signal.butter(N = 8, Wn = [150], btype = 'high', output = 'sos', fs=30000)
# sos_l = signal.butter(N = 8, Wn = [6000], btype = 'low', output = 'sos', fs=30000)
sos_bp = signal.butter(N=4, Wn = [150, 8000], btype='bandpass', output='sos', fs=30000)

# timestamps -- raw
# ts_raw = np.arange(len(session.recordnodes[0].recordings[0].continuous[0].sample_numbers))/session.recordnodes[0].recordings[0].continuous[0].metadata['sample_rate']
ts_raw = session.recordnodes[0].recordings[0].continuous[0].sample_numbers/session.recordnodes[0].recordings[0].continuous[0].metadata['sample_rate']
# ts_filt = np.arange(len(session.recordnodes[0].recordings[1].continuous[0].sample_numbers))/session.recordnodes[0].recordings[1].continuous[0].metadata['sample_rate']
# ts_filt = session.recordnodes[0].recordings[1].continuous[0].sample_numbers/session.recordnodes[0].recordings[1].continuous[0].metadata['sample_rate']


fig,ax = plt.subplots(nrows=len(channels), sharex=True)
# raw recording -- filter it and plot it
for i_channel, channel in enumerate(channels):
    # sig_temp = signal.sosfilt(sos_l, signal.sosfilt(sos_h, session.recordnodes[0].recordings[0].continuous[0].samples[:,channel])/4)
    sig_temp = signal.sosfilt(sos_bp, session.recordnodes[0].recordings[0].continuous[0].samples[:,channel])
    ax[i_channel].plot(ts_raw, sig_temp, label='filtered offline')
    ax[i_channel].plot(ts_raw, session.recordnodes[0].recordings[0].continuous[0].samples[:,channel], label='raw')
    # ax[i_channel].plot(ts_filt, session.recordnodes[0].recordings[1].continuous[0].samples[:,channel], label='filtered online')

    ax[i_channel].set_ylabel('uV')
    ax[i_channel].legend()
    


In [None]:
plt_freq, ax_freq = plt.subplots(nrows=2)

w, h = signal.sosfreqz(sos=sos_bp, fs = 30000)

ax_freq[0].semilogx(w, 20*np.log10(np.abs(h)))
ax_freq[1].semilogx(w, np.angle(h))



In [None]:
from scipy.io import loadmat

In [None]:
probe_map = loadmat("Z:\\BrainPatch\\20241002\\64-4shank-poly-brainpatch-chanMap.mat")

Grab 50 ms after each stimulation. Set the mean of the stimulation period to 0.

Find the minimum, maximum, depth of modulation, and time of each after the stimulation starts

In [None]:
n_chans = 64 # 64 recording channels
len_ms = 150
t_len = len_ms*30 # 52 ms * 30 khz
n_events = events.shape[0] # number of stimulation events

# set up the events to plot patches
events = np.argwhere(np.diff(recording.continuous[0].samples[:,64]>5000) == 1)
events = events.reshape([int(events.shape[0]/2),2])
event_ts = events/recording.continuous[0].metadata['sample_rate']

responses = np.zeros((n_events, t_len, n_chans))
maxs = np.zeros((n_events, n_chans))
rel_maxs = np.zeros((n_events, n_chans))
abs_maxs = np.zeros((n_events, n_chans))
mins = np.zeros((n_events, n_chans))
rel_mins = np.zeros((n_events, n_chans))
abs_mins = np.zeros((n_events, n_chans))

for i_event, event in enumerate(events):
    response = sig[event[0]:event[0]+len_ms*30,:]
    means = np.mean(sig[event[0]+4:event[1]-4,:], axis = 0)
    responses[i_event,:,:] = response - means # response for each channel
    
    mins[i_event,:] = np.min(response - means, axis=0)
    rel_mins[i_event,:] = np.argmin(response - means, axis=0)/30000
    abs_mins[i_event,:] = rel_mins[i_event,:] + event_ts[i_event,0]

    # maxs[i_event,:] = np.max(response[int(rel_mins*30000),:] - means, axis=0) # only interested in stuff after the negative deviation
    # rel_maxs[i_event,:] = np.argmax(response[int(rel_mins*30000),:] - means, axis=0)/30000
    # abs_maxs[i_event,:] = rel_maxs[i_event,:] + event_ts[i_event,0]


Now let's do the same thing, but look at the same channel for a couple of different stimulation amplitudes

### Average stimulation responses

let's take a look at the average stimulation response for a couple different electrodes

In [None]:
# directory = 'Z:\\BrainPatch\\20240821\\Crimson__2024-08-21_13-46-01__20mA_MinOil_2ms'
directory = 'Z:\\BrainPatch\\20240821\\Crimson__2024-08-21_13-29-59__20mA_MinOil_2ms'

# get the signal etc
signal, timestamps, events, event_ts = open_sig_events(directory)

fig_avg, ax_avg = plt.subplots()

for channel in [0,5,10,15,20]:
    plot_avg_response(signal, events, len_ms= 40, channel=channel, ax=ax_avg)


ax_avg.axvspan(0, 2, color='k', alpha=.1, label='Stimulation Period')

# clean up the plot, add a legend etc
ax_avg.legend()
for spine in ['top','bottom','right','left']:
    ax_avg.spines[spine].set_visible(False)

ax_avg.set_xlabel('Time after stimulation onset (ms)')
ax_avg.set_ylabel('Magnitude (uV)')
ax_avg.set_title('Mean stimulation responses with standard deviations\n20 mA, 400 um')

## Multi-file analysis

Looking at the responses over different distances and currents

First we need to put together a list of the different recordings and the parameters

### August 21

In [None]:
## August 21 data
# lets go through recordings in groups of locations
base_dir = 'Z:\\BrainPatch\\20240821'

dir_400 = ['Crimson__2024-08-21_13-44-07__10mA_MinOil_2ms','Crimson__2024-08-21_13-46-01__20mA_MinOil_2ms','Crimson__2024-08-21_13-47-40__15mA_MinOil_2ms','Crimson__2024-08-21_13-49-43__10mA_MinOil_2ms','Crimson__2024-08-21_13-51-50__5mA_MinOil_2ms']
dir_700 = ['Crimson__2024-08-21_13-56-49__5mA_MinOil_2ms','Crimson__2024-08-21_13-58-50__10mA_MinOil_2ms','Crimson__2024-08-21_14-00-53__15mA_MinOil_2ms','Crimson__2024-08-21_14-02-54__20mA_MinOil_2ms']
dir_1000 = ['Crimson__2024-08-21_14-05-52__5mA_MinOil_2ms','Crimson__2024-08-21_14-07-41__10mA_MinOil_2ms','Crimson__2024-08-21_14-09-46__15mA_MinOil_2ms','Crimson__2024-08-21_14-11-45__20mA_MinOil_2ms']
dir_1300 = ['Crimson__2024-08-21_14-14-26__5mA_MinOil_2ms','Crimson__2024-08-21_14-16-02__10mA_MinOil_2ms','Crimson__2024-08-21_14-17-58__15mA_MinOil_2ms','Crimson__2024-08-21_14-20-21__20mA_MinOil_2ms']
dir_1600 = ['Crimson__2024-08-21_14-23-13__5mA_MinOil_2ms','Crimson__2024-08-21_14-25-16__10mA_MinOil_2ms','Crimson__2024-08-21_14-27-12__15mA_MinOil_2ms','Crimson__2024-08-21_14-29-03__20mA_MinOil_2ms']

# dictionary of direct groups
dir_dict = {400: dir_400, 700:dir_700, 1000:dir_1000, 1300:dir_1300, 1600:dir_1600}


### September 25

In [None]:
# base_dir = 'Z:\\BrainPatch\\20240925\\No_Mineral_Oil'
base_dir = 'Z:\\BrainPatch\\20240925'

dir_400 = glob.glob('*0mm_2ms*', root_dir=base_dir) + glob.glob('*2ms_0mm*', root_dir=base_dir)
dir_600 = glob.glob('*2ms_.6mm', root_dir=base_dir) 
dir_1200 = glob.glob('*2ms_1.2mm', root_dir=base_dir) 
dir_1500 = glob.glob('*2ms_1.5mm', root_dir=base_dir) 

dir_dict = {400:dir_400, 600:dir_600, 1200:dir_1200, 1500:dir_1500}

channel = 42

### October 2

In [None]:
base_dir = 'Z:\\BrainPatch\\20241002\\lateral'

dir_300 = glob.glob('*2ms_400um', root_dir=base_dir)
dir_600 = glob.glob('*2ms_600um', root_dir=base_dir)
dir_900 = glob.glob('*2ms_900um', root_dir=base_dir)
dir_1200 = glob.glob('*2ms_1200us', root_dir=base_dir)
dir_1500 = glob.glob('*2ms_1500um', root_dir=base_dir)

dir_dict = {300:dir_300, 600:dir_600, 900:dir_900, 1200:dir_1200, 1500:dir_1500}

channel = 36

Next let's take a look at a single channel for a few different current levels

In [None]:
fig_avg_dist, ax_avg_dist = plt.subplots()

channel = 48
distance = 1500

for sub_dir in dir_1500:
    directory = os.path.join(base_dir,sub_dir)

    # get the signal etc
    sig, timestamps, events, event_ts = open_sig_events(directory)

    amp = re.search('(\d)+mA',sub_dir)[0]
    plot_avg_response(sig, events, len_ms= 40, channel=channel, ax=ax_avg_dist, label=amp)

ax_avg_dist.axvspan(0, 2, color='k', alpha=.1, label='Stimulation Period')
    
# clean up the plot, add a legend etc
ax_avg_dist.legend()
for spine in ['top','bottom','right','left']:
    ax_avg_dist.spines[spine].set_visible(False)

ax_avg_dist.set_xlabel('Time after stimulation onset (ms)')
ax_avg_dist.set_ylabel('Magnitude (uV)')
ax_avg_dist.set_title(f'Mean stimulation at different stimulation amplitudes\nChannel {channel}, {distance} um')

Load all of the different directories, then put the mean and median negative deviation for each channel into a dataframe for easy analysis and plotting

In [None]:
resp_df = pd.DataFrame(columns=['Channel_no','Current','Distance','uMin','uMin_ts','medMin','medMin_ts'])

for dist,dir_list in dir_dict.items():
    for sub_dir in dir_list:
        directory = os.path.join(base_dir, sub_dir) # go through the subdir, check to make sure it exists and that there's data inside
        if not os.path.exists(directory):
            continue
        if not len([file for file in os.listdir(directory) if not file.startswith('.')]): # if the directory is empty, skip it
            continue


        # open the directory
        sig, timestamps, events, event_ts = open_sig_events(directory)

        # pull out the stim responses
        mins, rel_mins, abs_mins = find_responses(sig, events)

        # means and medians for each channel
        uMins = np.mean(mins, axis=0)
        uMins_ts = np.mean(rel_mins, axis=0)
        medMins = np.median(mins, axis=0)
        medMins_ts = np.median(rel_mins, axis=0)

        # a nested dictionary of all of the channels responses
        tdict = {ii:{'Channel_no':ii, 
                'Current':re.search('([0-9]+)mA', sub_dir)[1],
                'Distance': dist,
                'uMin':uMins[ii],
                'uMin_ts':uMins_ts[ii],
                'medMin':medMins[ii],
                'medMin_ts':medMins_ts[ii],
                } for ii in range(64)}

        t_df = pd.DataFrame.from_dict(tdict, orient='index') # create a dataframe

        resp_df = pd.concat([resp_df, t_df], ignore_index=True)

resp_df.Current = resp_df.Current.astype(int)

In [None]:
os.listdir(directory)

Plot the effects of distance on the magnitude of the response for the different current levels. Different channels on different axes

In [None]:
currents = resp_df.Current.unique()
currents.sort()
# channels = [10, 15, 20, 25, 30]
channels = np.arange(64, step = 5)

fig_dist,ax_dist = plt.subplots(nrows=len(channels), sharex=True, sharey=True)
# fig_time, ax_time = plt.subplots(nrows=len(channels), sharex=True, sharey=True)
for i_chan,chan in enumerate(channels):
    for i_curr,curr in enumerate(currents):
        dist_cmp = resp_df.loc[(resp_df.Current==curr) * (resp_df.Channel_no==chan)]
        ax_dist[i_chan].plot(dist_cmp.Distance, dist_cmp.uMin)
        # ax_time[i_chan].plot(dist_cmp.Distance, dist_cmp.uMin_ts)

    ax_dist[i_chan].legend([f'{current} mA' for current in currents], loc=4)
    ax_dist[i_chan].set_title(f'Channel {chan}')
    ax_dist[i_chan].set_ylabel('Magnitude (uV)')


    # ax_time[i_chan].legend([f'{current} mA' for current in currents], loc=4)
    # ax_time[i_chan].set_title(f'Channel {chan}')
    # ax_time[i_chan].set_ylabel('Time (ms)')

    # remove the outer boxes
    for spine in ['top','bottom','right','left']:
        ax_dist[i_chan].spines[spine].set_visible(False)
        # ax_time[i_chan].spines[spine].set_visible(False)
    

ax_dist[-1].set_xlabel('Distance (um)')
fig_dist.suptitle('Mean response minimum as a function of distance (per current level)')


# ax_time[-1].set_xlabel('Distance (um)')
# fig_time.suptitle('Mean minimum time as a function of distance (per current level)')



mean negative deviation for all channels as a function of distance. Different axis per current level

In [None]:
currents = resp_df.Current.unique()
distances = resp_df.Distance.unique()
currents.sort()
distances.sort()

fig_min_scatter,ax_min_scatter = plt.subplots(ncols=len(currents), sharex=True, sharey=True)
for i_curr,curr in enumerate(currents):
    dist_cmp = resp_df.loc[resp_df.Current==curr ]
    ax_min_scatter[i_curr].scatter(dist_cmp.Distance, dist_cmp.uMin, s = 2, color='grey')
    current_means = dist_cmp.groupby('Distance').mean('uMin')
    ax_min_scatter[i_curr].plot(current_means.index, current_means.uMin, color='k')

    # ax_min_scatter[i_curr].legend([f'{current} mA' for current in currents], loc=4)
    ax_min_scatter[i_curr].set_title(f'LED current: {curr} mA')
    ax_min_scatter[i_curr].set_xlabel('Distance (um)')


    # remove the outer boxes
    for spine in ['top','bottom','right','left']:
        ax_min_scatter[i_curr].spines[spine].set_visible(False)
        # ax_time[i_chan].spines[spine].set_visible(False)
    

ax_min_scatter[0].set_ylabel('Magnitude (uV)')
fig_dist.suptitle('Mean response minimum as a function of distance (per current level)')


# ax_time[-1].set_xlabel('Distance (um)')
# fig_time.suptitle('Mean minimum time as a function of distance (per current level)')



Time of the minimum value as a function of current. Each distance on a different plot

In [None]:
currents = resp_df.Current.unique()
distances = resp_df.Distance.unique()
currents.sort()
distances.sort()

fig_time_scatter,ax_time_scatter = plt.subplots(ncols=len(distances), sharex=True, sharey=True)
for i_dist,dist in enumerate(distances):
    curr_cmp = resp_df.loc[resp_df.Distance == dist]
    ax_time_scatter[i_dist].scatter(curr_cmp.Current, curr_cmp.uMin_ts, s = 2, color='blue')
    curr_means = curr_cmp.groupby('Current').mean('uMin_ts')
    ax_time_scatter[i_dist].plot(curr_means.index, curr_means.uMin_ts, color='k')

    # ax_time_scatter[i_dist].legend([f'{distent} mA' for distent in distents], loc=4)
    ax_time_scatter[i_dist].set_title(f'Distance: {dist}um')
    ax_time_scatter[i_dist].set_xlabel('Current (mA)')


    # remove the outer boxes
    for spine in ['top','bottom','right','left']:
        ax_time_scatter[i_dist].spines[spine].set_visible(False)
        # ax_time[i_chan].spines[spine].set_visible(False)
    

ax_time_scatter[0].set_ylabel('Time (ms)')
fig_dist.suptitle('Mean response minimum as a function of current (per distance)')


# Multi-Channel spike processing


October 2

In [None]:
base_dir = 'Z:\\BrainPatch\\20241002\\lateral'

dir_300 = glob.glob(os.path.join(base_dir,'*2ms_400um'))
dir_600 = glob.glob(os.path.join(base_dir,'*2ms_600um'))
dir_900 = glob.glob(os.path.join(base_dir,'*2ms_900um'))
dir_1200 = glob.glob(os.path.join(base_dir,'*2ms_1200us'))
dir_1500 = glob.glob(os.path.join(base_dir,'*2ms_1500um'))

dir_dict = {300:dir_300, 600:dir_600, 900:dir_900, 1200:dir_1200, 1500:dir_1500}

channel = 51

In [None]:
channels = [51]
with PdfPages(os.path.join(base_dir, f'channel_{channel}.pdf')) as pdf:
    fig, ax = plt.subplots(nrows = 2)

    for directory in dir_300:
        amplitude = re.search('\d{1,2}mA', directory)[0]
        if not all([var in locals() for var in ['sig', 'timestamps', 'events', 'events_ts']]):
            sig, timestamps, stims, stim_ts = openephys_utils.open_sig_stims(directory)

        sig_eraasr = openephys_utils.ERAASR(sig)
        sig_mine = openephys_utils.ERAASR(sig, mode='mine')

        spike_dict = openephys_utils.threshold_crossings(sig_eraasr, multi_rejection=None, low_cutoff=-20)

        for channel in channels:
            openephys_utils.plot_mean_waveforms(spike_dict, ax=ax[0], channel=channel)
            ax[0].set_title('Threshold Crossing Waveform')
            openephys_utils.plot_PSTH(spike_dict, stims, ax=ax[1], channel=channel)
            ax[1].set_title('')
            ax[1].set_xlabel('Time (s)')
            ax[1].set_ylabel('Stimulation Number')
            # plt.title(f'Channel {channel}, {directory}')
            fig.text(0.05, 0.95, f'Channel {channel}, {amplitude}', transform=fig.transFigure, size=24)
            pdf.savefig(fig)
            for sub_ax in ax:
                sub_ax.cla()

Put together mean firing rates, PSTHs, and units for highest firing channels that seem like reasonable waveforms. Iterate through all conditions for a particular unit

In [2]:
# %load_ext autoreload
# %autoreload

from matplotlib import pyplot as plt
from matplotlib.patches import Polygon
import numpy as np
import pandas as pd
from open_ephys.analysis import Session
from scipy import signal
from scipy.io import loadmat
# from scipy.fft import fft, fftfreq
# from scipy.signal.windows import gaussian
from sklearn.decomposition import PCA
import os, glob, re
import openephys_utils 


# %matplotlib ipympl
%matplotlib qt

# pdfPages for saving images to a multi-page PDF
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.gridspec import GridSpec

# status bar
from tqdm.notebook import tqdm

# kilosort
import kilosort

Show the mean post-stim firing rates and waveform for all conditions for a couple of channels.

This turned out pretty messy, so the next cell down is the same thing for just a few current/distances

In [None]:
channels = [36, 40, 42, 46, 51]
bin_sec = .001
base_dir = 'Z:\\BrainPatch\\20241002\\lateral'

figs = {}
for channel in channels:
    fig = plt.figure()
    gs = GridSpec(3,2, fig)
    ax_wf = fig.add_subplot(gs[0,0])
    ax_means = fig.add_subplot(gs[1,:])
    # ax_meanmax = fig.add_subplot(gs[2,:])
    figs[channel] = {'figure':fig, 'thresh':ax_wf, 'fr_means':ax_means}

# fig_max,ax_max = plt.subplots()

directories = [directory for directory in os.listdir(base_dir) if '2ms' in directory]
for directory in tqdm(directories):
    current = int(re.search('(\d{1,2})mA', directory)[1])
    distance = int(re.search('(\d{3,4})um', directory)[1])

    print(f'{current}mA {distance}mm')

    # read everything
    sig, timestamps, stims, stim_ts = openephys_utils.open_sig_stims(os.path.join(base_dir,directory))

    # clean, pull out threshold crossings, get firing rates
    sig_eraasr = openephys_utils.ERAASR(sig)
    spike_dict = openephys_utils.threshold_crossings(sig_eraasr, multi_rejection=None, low_cutoff=-20)
    fr, bins = openephys_utils.calc_FR(spike_dict, max_samp = sig_eraasr.shape[0], min_samp = 0, bin_sec = bin_sec)

    mean_fr,std_fr = openephys_utils.mean_stim_FR(fr, stims, bin_sec = bin_sec, normalize=False, resp_length_samp = 4000)

    for channel in channels: # for each of the channels we want
        # make the plots and label them, etc
        openephys_utils.plot_mean_waveforms(spike_dict = spike_dict, channel=channel, ax=figs[channel]['thresh'], label=f'{current}mA, {distance/1000}mm')
        openephys_utils.plot_mean_FR(mean_FR=mean_fr, channel = channel, ax = figs[channel]['fr_means'], label=f'{current}mA, {distance/1000}mm')



# add a legend to everything
for fig in figs:
    fig['thresh'].legend()
    fig['fr_means'].legend()



Mean post-stim firing rates and waveforms for just a couple recordings at 400 um

In [7]:

channels = [36, 40, 42, 46, 51]
bin_sec = .001
base_dir = 'Z:\\BrainPatch\\20241002\\lateral'


figs = {}
for channel in channels:
    fig = plt.figure()
    gs = GridSpec(3,2, fig)
    ax_wf = fig.add_subplot(gs[0,0])
    ax_means = fig.add_subplot(gs[1,:])
    # ax_meanmax = fig.add_subplot(gs[2,:])
    figs[channel] = {'figure':fig, 'thresh':ax_wf, 'fr_means':ax_means}

# directories = [directory for directory in os.listdir(base_dir) if '2ms' in directory]
# directories = ['Crimson__2024-10-02_12-26-46__20mA_2ms_400um',
            #    'Crimson__2024-10-02_12-35-37__15mA_2ms_400um',
            #    'Crimson__2024-10-02_12-39-41__10mA_2ms_400um',
            #    'Crimson__2024-10-02_12-49-56__5mA_2ms_400um']
directories = ['Crimson__2024-10-02_12-26-46__20mA_2ms_400um']
for directory in tqdm(directories):
    current = int(re.search('(\d{1,2})mA', directory)[1])
    distance = int(re.search('(\d{3,4})um', directory)[1])

    print(f'{current}mA {distance}mm')

    # read everything
    sig, timestamps, stims, stim_ts = openephys_utils.open_sig_stims(os.path.join(base_dir,directory))

    # clean, pull out threshold crossings, get firing rates
    sig_eraasr = openephys_utils.ERAASR(sig)
    spike_dict = openephys_utils.threshold_crossings(sig_eraasr, multi_rejection=None, low_cutoff=-20, thresh_mult=-4.5)
    fr, bins = openephys_utils.calc_FR(spike_dict, max_samp = sig_eraasr.shape[0], min_samp = 0, bin_sec = bin_sec)

    mean_fr,std_fr = openephys_utils.mean_stim_FR(fr, stims, bin_sec = bin_sec, normalize=False, resp_length_samp = 4000)

    for i_channel, channel in enumerate(channels): # for each of the channels we want
        # make the plots and label them, etc
        openephys_utils.plot_mean_waveforms(spike_dict = spike_dict, channel=channel, ax=figs[channel]['thresh'], label=f'{current}mA, {distance/1000}mm')
        openephys_utils.plot_mean_FR(mean_FR=mean_fr/10, std_FR=std_fr/10, channel = channel, ax = figs[channel]['fr_means'], label=f'{current}mA, {distance/1000}mm', bin_ms = bin_sec * 1000)



  0%|          | 0/1 [00:00<?, ?it/s]

20mA 400mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-26-46__20mA_2ms_400um\raw_signal.pkl


Same process as above, but storing a "max firing rate" dictionary similar to the one for the LFPs

In [None]:

channels = np.arange(64)
bin_sec = .005
base_dir = 'Z:\\BrainPatch\\20241002\\lateral'

fr_list = []
directories = [directory for directory in os.listdir(base_dir) if '2ms' in directory]
for directory in tqdm(directories):
    current = int(re.search('(\d{1,2})mA', directory)[1])
    distance = int(re.search('(\d{3,4})um', directory)[1])

    print(f'{current}mA {distance}mm')

    # read everything
    sig, timestamps, stims, stim_ts = openephys_utils.open_sig_stims(os.path.join(base_dir,directory))

    # clean, pull out threshold crossings, get firing rates
    sig_eraasr = openephys_utils.ERAASR(sig, stims)
    spike_dict = openephys_utils.threshold_crossings(sig_eraasr, multi_rejection=None, low_cutoff=-20, thresh_mult=-4.5)
    fr, bins = openephys_utils.calc_FR(spike_dict, max_samp = sig_eraasr.shape[0], min_samp = 0, bin_sec = bin_sec)

    mean_fr,std_fr = openephys_utils.mean_stim_FR(fr, stims, bin_sec = bin_sec, normalize=False, resp_length_samp = 4000)

    temp_list = [{'Channel_no':ii,
                  'current':current,
                  'distance':distance,
                  'max_fr':mean_fr[:,ii].max(),
                  'spike':spike_dict[ii]['waveform'].mean(axis=0)} for ii in channels]
    
    fr_list.extend(temp_list)


fr_df = pd.DataFrame(fr_list)

with open(os.path.join(base_dir,'firing_rates.pkl'), 'wb') as fid:
    pd.to_pickle(fr_df, fid)

  0%|          | 0/24 [00:00<?, ?it/s]

20mA 400mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\raw_signal.pkl
20mA 400mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-24-31__20mA_2ms_400um\raw_signal.pkl
20mA 400mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-26-46__20mA_2ms_400um\raw_signal.pkl
15mA 400mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-35-37__15mA_2ms_400um\raw_signal.pkl
15mA 400mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-37-44__15mA_2ms_400um\raw_signal.pkl
10mA 400mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-39-41__10mA_2ms_400um\raw_signal.pkl
10mA 400mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-41-41__10mA_2ms_400um\raw_signal.pkl
5mA 400mm
loading previously converted file Z:\B

  'spike':spike_dict[ii]['waveform'].mean(axis=0)} for ii in channels]
  ret = ret.dtype.type(ret / rcount)


15mA 400mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-52-58__15mA_2ms_400um\raw_signal.pkl
5mA 600mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_13-10-56__5mA_2ms_600um\raw_signal.pkl
10mA 600mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_13-17-09__10mA_2ms_600um\raw_signal.pkl
15mA 600mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_13-20-08__15mA_2ms_600um\raw_signal.pkl
20mA 600mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_13-28-37__20mA_2ms_600um\raw_signal.pkl
20mA 900mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_13-31-23__20mA_2ms_900um\raw_signal.pkl
15mA 900mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_13-39-44__15mA_2ms_900um\raw_signal.pkl


  'spike':spike_dict[ii]['waveform'].mean(axis=0)} for ii in channels]
  ret = ret.dtype.type(ret / rcount)


10mA 900mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_13-42-12__10mA_2ms_900um\raw_signal.pkl
5mA 900mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_13-51-03__5mA_2ms_900um\raw_signal.pkl


  'spike':spike_dict[ii]['waveform'].mean(axis=0)} for ii in channels]
  ret = ret.dtype.type(ret / rcount)


5mA 1200mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_13-54-13__5mA_2ms_1200um\raw_signal.pkl
10mA 1200mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_14-02-07__10mA_2ms_1200um\raw_signal.pkl


  'spike':spike_dict[ii]['waveform'].mean(axis=0)} for ii in channels]
  ret = ret.dtype.type(ret / rcount)


15mA 1200mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_14-05-20__15mA_2ms_1200um\raw_signal.pkl
20mA 1200mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_14-12-40__20mA_2ms_1200um\raw_signal.pkl
20mA 1500mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_14-15-26__20mA_2ms_1500um\raw_signal.pkl
15mA 1500mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_14-22-56__15mA_2ms_1500um\raw_signal.pkl
10mA 1500mm
loading previously converted file Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_14-26-40__10mA_2ms_1500um\raw_signal.pkl


  'spike':spike_dict[ii]['waveform'].mean(axis=0)} for ii in channels]
  ret = ret.dtype.type(ret / rcount)


Plots for the max firing rates as a function of distance or current. Also looking at the least squares fit

In [141]:
fig_fr, ax_fr = plt.subplots()

electrode_map = loadmat('Z:\\BrainPatch\\20241002\\64-4shank-poly-brainpatch-chanMap.mat')
cmap = np.array(plt.colormaps.get_cmap('tab10').colors)
curr = 15

intrinsic_dist = electrode_map['xcoords'][fr_df.Channel_no].astype(float) - 375
fr_df['Chan_dist'] = np.sqrt(fr_df.distance.astype(float)**2 + intrinsic_dist.ravel()**2)
fr_temp = fr_df.loc[fr_df.current.eq(curr)]

ax_fr.scatter(fr_temp['Chan_dist'], fr_temp['max_fr'], c = cmap[electrode_map['kcoords'][fr_temp.Channel_no],:])
# ax_fr.

# # lstsq regression
# B = np.linalg.lstsq(np.column_stack([fr_temp['Chan_dist'], np.ones_like(fr_temp['Chan_dist'])]), fr_temp['distance'])
# Xi = np.array([[fr_temp['distance'].min(),1],[fr_temp['distance'].max(),1]])
# ax_fr.plot(Xi[:,0], np.matmul(Xi,B[0]), color='black')

ax_fr.set_xlabel('distance (um)')
ax_fr.set_ylabel('firing rate (Hz)')

Text(0, 0.5, 'firing rate (Hz)')

Now that we have cleaned stuff, let's try running things through kilosort again to compare with the plain threshold crossings

In [204]:
kilosort.io.load_probe('Z:\\BrainPatch\\20241002\\64-4shank-poly-brainpatch-chanMap.mat')

{'xc': array([500., 500., 500., 500., 500., 500., 500., 500., 500., 500., 500.,
        750., 500., 750., 500., 750., 500., 750., 500., 750., 750., 750.,
        500., 750., 750., 750., 750., 750., 750., 750., 750., 750.,   0.,
          0.,   0.,   0.,   0.,   0.,   0.,   0., 250.,   0.,   0.,   0.,
        250.,   0., 250.,   0., 250.,   0., 250.,   0., 250.,   0., 250.,
        250., 250., 250., 250., 250., 250., 250., 250., 250.],
       dtype=float32),
 'yc': array([  0., 250., 200., 350., 150., 450.,  50., 400., 100., 600., 300.,
        750., 500., 550., 700., 350., 650., 700., 550., 100., 500., 600.,
        750.,  50., 300., 400., 250., 200.,   0., 450., 150., 650., 100.,
        600.,  50., 400., 200., 250., 350., 450., 700.,   0., 550., 650.,
        500., 150., 600., 750., 750., 300., 550., 500., 350., 700., 150.,
        650.,   0., 450., 100., 400., 250., 300.,  50., 200.],
       dtype=float32),
 'kcoords': array([3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 3., 4., 3., 4., 3

In [38]:
filename = 'Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\\sig_eraasr.npy'
np.save(filename, sig_eraasr.astype(np.float32))
settings = {'filename':filename,
            'probe_name':'Z:\\BrainPatch\\20241002\\64-4shank-poly-brainpatch-chanMap.mat',
            'n_chan_bin':64,
            'nearest_chans':1,
            'data_dtype':'float32'}

kilosort.run_kilosort(settings,  file_object=sig_eraasr.astype(np.float32), data_dtype='float32')

kilosort.run_kilosort: Kilosort version 4.0.17
kilosort.run_kilosort: Kilosort version 4.0.17
kilosort.run_kilosort: Kilosort version 4.0.17
kilosort.run_kilosort: Kilosort version 4.0.17
kilosort.run_kilosort: Kilosort version 4.0.17
kilosort.run_kilosort: Sorting Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\sig_eraasr.npy
kilosort.run_kilosort: Sorting Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\sig_eraasr.npy
kilosort.run_kilosort: Sorting Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\sig_eraasr.npy
kilosort.run_kilosort: Sorting Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\sig_eraasr.npy
kilosort.run_kilosort: Sorting Z:\BrainPatch\20241002\lateral\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\sig_eraasr.npy
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: ----------------------------------------
kilosort.run_kilosort: ----

({'n_chan_bin': 64,
  'fs': 30000,
  'batch_size': 60000,
  'nblocks': 1,
  'Th_universal': 9,
  'Th_learned': 8,
  'tmin': 0,
  'tmax': inf,
  'nt': 61,
  'shift': None,
  'scale': None,
  'artifact_threshold': inf,
  'nskip': 25,
  'whitening_range': 32,
  'highpass_cutoff': 300,
  'binning_depth': 5,
  'sig_interp': 20,
  'drift_smoothing': [0.5, 0.5, 0.5],
  'nt0min': 20,
  'dmin': 50.0,
  'dminx': 32,
  'min_template_size': 10,
  'template_sizes': 5,
  'nearest_chans': 1,
  'nearest_templates': 100,
  'max_channel_distance': 50.0,
  'templates_from_data': True,
  'n_templates': 6,
  'n_pcs': 6,
  'Th_single_ch': 6,
  'acg_threshold': 0.2,
  'ccg_threshold': 0.25,
  'cluster_downsampling': 20,
  'x_centers': None,
  'duplicate_spike_ms': 0.25,
  'filename': WindowsPath('Z:/BrainPatch/20241002/lateral/Crimson__2024-10-02_12-21-01__20mA_2ms_400um/sig_eraasr.npy'),
  'probe_name': 'Z:\\BrainPatch\\20241002\\64-4shank-poly-brainpatch-chanMap.mat',
  'data_dtype': 'float32',
  'data_dir

In [17]:
with open(filename, 'rb') as fid: 
    templates = np.load('Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\\kilosort4\\templates.npy')

In [20]:
# load spike times, template shape, and channel number
spike_times = np.load('Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\\kilosort4\\spike_times.npy')
chan_map = np.load('Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\\kilosort4\\channel_map.npy')
templates = np.load('Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\\kilosort4\\templates.npy')
st = np.load('Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\\kilosort4\\spike_templates.npy')

# best channel for each template. Might change this to channel with greatest minimum value
chan_best = (templates**2).sum(axis=1).argmax(axis=-1)
chan_best = chan_map[chan_best]

# channels of interest
channels = [5, 17, 35, 37, 50, 51]

fig, ax = plt.subplots(nrows=len(channels), sharex=True, sharey=True)

cmap = np.array(plt.colormaps.get_cmap('tab10').colors)

for i_channel, channel in enumerate(channels):
    ax[i_channel].plot(np.arange(sig_eraasr.shape[0])/30000, sig_eraasr[:,channel], label=channel)

    chan_templates = np.nonzero(chan_best == channel)

    for i_ct, ct in enumerate(chan_templates[0]):
        ax[i_channel].vlines(spike_times[st == ct]/30000, sig_eraasr[:,channel].min(), sig_eraasr[:,channel].max(), color = cmap[i_ct+1, :], alpha=0.3)




    ax[i_channel].set_ylabel('magnitude (uV)')
    ax[i_channel].set_title(channel)



In [17]:
np.nonzero(chan_best == 5)

(array([44], dtype=int64),)

In [3]:
st = np.load('Z:\\BrainPatch\\20241002\\lateral\\Crimson__2024-10-02_12-21-01__20mA_2ms_400um\\kilosort4\\spike_templates.npy')

In [None]:
spike_times

array([60, 69, 25, 73, 50, 13, 49, 52, 32, 52])