In [4]:
import os
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
import pandas as pd

from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache

In [5]:
import platform 

platstring = platform.platform()
if ('Darwin' in platstring) or ('macOS' in platstring):
    # macOS 
    data_root = "/Volumes/Brain2023/"
elif 'Windows'  in platstring:
    # Windows (replace with the drive letter of USB drive)
    data_root = "E:/"
elif ('amzn' in platstring):
    # then on Code Ocean
    data_root = "/data/"
else:
    # then your own linux platform
    # EDIT location where you mounted hard drive
    data_root = "/media/$USERNAME/Brain2023/"

manifest_path = os.path.join(data_root,'allen-brain-observatory/visual-coding-neuropixels/ecephys-cache/manifest.json') 
cache = EcephysProjectCache.from_warehouse(manifest=manifest_path) #creates a cache object
sessions = cache.get_session_table() #session_table is a pandas dataframe object that contains metadata about all sessions in the cache
sessions.head() #show the first 5 rows of the dataframe object 

Unnamed: 0_level_0,published_at,specimen_id,session_type,age_in_days,sex,full_genotype,unit_count,channel_count,probe_count,ecephys_structure_acronyms
id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
715093703,2019-10-03T00:00:00Z,699733581,brain_observatory_1.1,118.0,M,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,884,2219,6,"[CA1, VISrl, nan, PO, LP, LGd, CA3, DG, VISl, ..."
719161530,2019-10-03T00:00:00Z,703279284,brain_observatory_1.1,122.0,M,Sst-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,755,2214,6,"[TH, Eth, APN, POL, LP, DG, CA1, VISpm, nan, N..."
721123822,2019-10-03T00:00:00Z,707296982,brain_observatory_1.1,125.0,M,Pvalb-IRES-Cre/wt;Ai32(RCL-ChR2(H134R)_EYFP)/wt,444,2229,6,"[MB, SCig, PPT, NOT, DG, CA1, VISam, nan, LP, ..."
732592105,2019-10-03T00:00:00Z,717038288,brain_observatory_1.1,100.0,M,wt/wt,824,1847,5,"[grey, VISpm, nan, VISp, VISl, VISal, VISrl]"
737581020,2019-10-03T00:00:00Z,718643567,brain_observatory_1.1,108.0,M,wt/wt,568,2218,6,"[grey, VISmma, nan, VISpm, VISp, VISl, VISrl]"


In [6]:
session_id = 715093703 #import a session id of interest
session = cache.get_session_data(session_id) #creates a session object


In [9]:
#get the unique orientations and temporal frequencies from the stimulus table 
unique_presentation = presentations['orientation'].unique() 
unique_temporal_frequency = presentations['temporal_frequency'].unique() 

#rempve the null values from the unique orientations and temporal frequencies
unique_presentation = unique_presentation[unique_presentation != 'null']
unique_temporal_frequency = unique_temporal_frequency[unique_temporal_frequency != 'null']

#print the unique orientations and temporal frequencies
print(unique_presentation)
print(unique_temporal_frequency)


[315.0 90.0 225.0 135.0 0.0 270.0 180.0 45.0]
[4.0 8.0 2.0 1.0 15.0]


In [11]:
def get_spike_counts(session_id, region, window_duration, time_step=0.001, pre_stim_window=-0.1):
    
    #to-do: 
        #update drifting_gratings can be changed to other stimuli for scaling purposes
    
    #input: session id, region of interest, window duration
        #session id: int, specific session id of interest
        #region of interest: string, specific region of interest
        #window duration: float, duration of the window of interest in seconds
        
    #output: spike counts for the region of interest in xarray format where the dimensions are unit, time, and stimulus presentation id in that order
    
    #example: spikes = get_spike_counts(715093703, "VISp", 0.5) 
    
    
    session = cache.get_session_data(session_id) #creates a session object  
    presentations = session.get_stimulus_table('drifting_gratings') #creates a stimulus table object for drifting gratings,  
    
    presentations = presentations[presentations.orientation != 'null'] #remove the null 
    #time_step = 0.001 #in seconds 
    #pre_stim_window = -0.1 #in seconds
    
    time_bins = np.arange(pre_stim_window, window_duration + time_step, time_step) #create an array of time bins for the window of interest 
    
    #apply the .presentationwise_spike_counts method to the session object to get the spike counts for the region of interest 
    
    spikes = session.presentationwise_spike_counts(
        stimulus_presentation_ids = presentations.index.values, 
        bin_edges = time_bins, 
        unit_ids = session.units.query('snr > 2.5 & isi_violations < 0.5 & amplitude_cutoff < 0.1 & presence_ratio > 0.9 & ecephys_structure_acronym == @region').index.values
        )
    return spikes


#test the function
spikes = get_spike_counts(715093703, "VISp", 2)

In [41]:
def get_metrics(session_id, spikes, pre=-0.05, post= 0.1): 
    #input: session id, spikes variable, pre and post stimulus window
        #session id: int, specific session id of interest
        #spikes: xarray, spike counts for the region of interest in xarray format where the dimensions are unit, time, and stimulus presentation id in that order
        #pre: float, pre stimulus window in seconds
        #post: float, post stimulus window in seconds
        
    #output: dataframe with the following columns: unit_id, orientation, temporal_frequency, pre_stimulus_mean, post_stimulus_mean
    
    #example: df_metrics = get_metrics(715093703, spikes, pre=-0.05, post= 0.1)
    
    #user input parameters
    pre = -0.05
    post = 0.1
    
    # spikes is the input xarray variable for the user to provide
    session = cache.get_session_data(session_id) #creates a session object
    presentations = session.get_stimulus_table('drifting_gratings') #creates a stimulus table object for drifting gratings
    
    #get the unique orientations and temporal frequencies from the stimulus table
    unique_orientations = presentations['orientation'].unique()
    unique_temporal_frequency = presentations['temporal_frequency'].unique()
    
    #rempve the null values from the unique orientations and temporal frequencies
    unique_orientations = unique_presentation[unique_presentation != 'null']
    unique_temporal_frequency = unique_temporal_frequency[unique_temporal_frequency != 'null']
    
    #creating a mask for using np.where
    pre_indices = np.where((spikes.time_relative_to_stimulus_onset.values >= pre) & (spikes.time_relative_to_stimulus_onset.values <0))[0] #np.where returns a list with two elements and only need first
    post_indices = np.where((spikes.time_relative_to_stimulus_onset.values >= 0 ) & (spikes.time_relative_to_stimulus_onset.values < post))[0] #np.where returns a list with two elements and only need first
    
    #select the pre stimulus onset spikes
    pre_spikes = spikes.loc[:, spikes.time_relative_to_stimulus_onset.values[pre_indices], :]
    
    #select the post stimulus onset spikes
    post_spikes = spikes.loc[:, spikes.time_relative_to_stimulus_onset.values[post_indices], :]
    
    #initialize an empty dataframe with all the columns
    df_metrics = pd.DataFrame(columns=['unit_id', 'orientation', 'temporal_frequency', 'pre_stimulus_mean', 'post_stimulus_mean'], dtype=float)
    
    #for each unit_id, for each unique orientaiton, calculate the mean pre and post and normalize based on the number of trials for each orientation
    #for the unit, iterate throuh each unit but ensure the values are integers and not floats
    for unit_id in spikes.unit_id.values:
        #for orientation in unique_orientations:
        for orientation in unique_orientations:
            #for temporal_freq in unique_temporal_frequency:
            for temporal_freq in unique_temporal_frequency:
                #grab the trial ids for the specific orientation and temporal frequency
                trial_ids = presentations[(presentations.orientation == orientation) & (presentations.temporal_frequency == temporal_freq)].index.values.tolist()
                
                #create an empty list
                trial_spikes_pre = []
                trial_spikes_post = []
                
                #for each trial id, grab the spike count for the unique pair here per trial and store into a list to collect the length of the list
                #repeat this process for both pre and post spikes variables
                
                for trial_id in trial_ids:
                    #grab the pre and post spikes for the specific trial id
                    trial_spikes_pre.append(pre_spikes.loc[trial_id, :, unit_id])
                    trial_spikes_post.append(post_spikes.loc[trial_id, :, unit_id])
                    
                    #now once the pre and post spikes are stored in a list, calculate the mean and proceed to store the data into the dataframe
                    pre_stimulus_mean = np.mean(trial_spikes_pre)
                    post_stimulus_mean = np.mean(trial_spikes_post)
                    
                    #convert to Hz as the data is in ms
                    pre_stimulus_mean = pre_stimulus_mean * 1000
                    
                    post_stimulus_mean = post_stimulus_mean * 1000
                    
                    #store the data in the df_metrics dataframe
                    df_metrics = df_metrics.append({'unit_id': unit_id, 'orientation': orientation, 'temporal_frequency': temporal_freq, 'pre_stimulus_mean': pre_stimulus_mean, 'post_stimulus_mean': post_stimulus_mean}, ignore_index=True)

    return df_metrics

#test the function
df_metrics = get_metrics(715093703, spikes, pre=-0.05, post= 0.1)
    

  df_metrics = df_metrics.append({'unit_id': unit_id, 'orientation': orientation, 'temporal_frequency': temporal_freq, 'pre_stimulus_mean': pre_stimulus_mean, 'post_stimulus_mean': post_stimulus_mean}, ignore_index=True)
  df_metrics = df_metrics.append({'unit_id': unit_id, 'orientation': orientation, 'temporal_frequency': temporal_freq, 'pre_stimulus_mean': pre_stimulus_mean, 'post_stimulus_mean': post_stimulus_mean}, ignore_index=True)
  df_metrics = df_metrics.append({'unit_id': unit_id, 'orientation': orientation, 'temporal_frequency': temporal_freq, 'pre_stimulus_mean': pre_stimulus_mean, 'post_stimulus_mean': post_stimulus_mean}, ignore_index=True)
  df_metrics = df_metrics.append({'unit_id': unit_id, 'orientation': orientation, 'temporal_frequency': temporal_freq, 'pre_stimulus_mean': pre_stimulus_mean, 'post_stimulus_mean': post_stimulus_mean}, ignore_index=True)
  df_metrics = df_metrics.append({'unit_id': unit_id, 'orientation': orientation, 'temporal_frequency': temporal_fre

In [43]:
df_metrics.head()

Unnamed: 0,unit_id,orientation,temporal_frequency,pre_stimulus_mean,post_stimulus_mean
0,950930145.0,315.0,4.0,0.0,0.0
1,950930145.0,315.0,4.0,0.0,0.0
2,950930145.0,315.0,4.0,0.0,0.0
3,950930145.0,315.0,4.0,0.0,0.0
4,950930145.0,315.0,4.0,0.0,0.0


In [22]:
#what are the unique orientations and temporal frequencies from df_metrics
unique_orientations = df_metrics['orientation'].unique()
#print the unique orientations
print(unique_orientations)

#print the unique temporal frequencies
unique_temporal_frequency = df_metrics['temporal_frequency'].unique()
print(unique_temporal_frequency)

#find the max mean firing of a unqiue unit 
max_mean_firing = df_metrics.groupby(['unit_id']).max()
max_mean_firing

[315.  90. 225. 135.   0. 270. 180.  45.]
[ 4.  8.  2.  1. 15.]


Unnamed: 0_level_0,orientation,temporal_frequency,pre_stimulus_mean,post_stimulus_mean
unit_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
950930145.0,315.0,15.0,0.02,0.014667
950930237.0,315.0,15.0,0.02,0.03
950930340.0,315.0,15.0,0.02,0.02
950930407.0,315.0,15.0,0.08,0.12
950930423.0,315.0,15.0,0.03,0.02
950930437.0,315.0,15.0,0.01,0.005
950930522.0,315.0,15.0,0.04,0.05
950930795.0,315.0,15.0,0.02,0.04
950930866.0,315.0,15.0,0.008,0.020769
950930964.0,315.0,15.0,0.08,0.16


In [None]:
def get_spikes_append_waveformlabels_2_unit_table(session_id, region):
     
     #input: session id, region of interest
        #session id: int, specific session id of interest
        #region of interest: string, specific region of interest
        
    #output: good unit table with the cell type column
    #example: good_unit_table = get_spikes_append_waveformlabels(715093703, "VISp")
    
    session = cache.get_session_data(session_id) #creates a session object
    good_unit_table = session.units.query('snr > 2.5 & isi_violations < 0.5 & amplitude_cutoff < 0.1 & presence_ratio > 0.9 & ecephys_structure_acronym == @region') #creates a list of good units for the region of interest
    good_unit_table['cell_type'] = np.where(good_unit_table['waveform_duration'] > 0.4, 'RS', 'FS') #creates a new column in the good unit table that labels the units as RS or FS based on the waveform duration,
    #if the waveform duration is greater than 0.4 it is labeled as RS and if it is less than 0.4 it is labeled as FS
    
    return good_unit_table

#test the function
good_unit_table = get_spikes_append_waveformlabels_2_unit_table(session_id, region)

In [None]:
#print the the number of FS and RS units in the good unit table
print(good_unit_table['cell_type'].value_counts()) #print the number of FS and RS units in the good unit table

In [None]:
def get_ecephyssession_units_object(session_id, region): 
    
    #input: session id, region of interest
        #session id: int, specific session id of interest
        #region of interest: string, specific region of interest
        
    #output: session object
    
    #example: test_session = get_ecephyssession_object(715093703, "VISp")
    
    session_units = cache.get_units.query(session_id) #creates a session object  
    
    passed_qc_unit_ids = session_units.query('snr > 2.5 & isi_violations < 0.5 & amplitude_cutoff < 0.1 & presence_ratio > 0.9 & ecephys_structure_acronym == @region').index.values.tolist() #creates a list of unit ids that passed qc
    
    
    return passed_qc_unit_ids


#test the function
test = get_ecephyssession_units_object(715093703, "VISp")
    
    

In [None]:
session.structurewise_unit_counts #gives the number of units in each brain area for a given session

In [None]:
# take in session id and allow the user to access the meta data for that session 
def get_session_data(session_id):
    session_metadata = cache.get_session_data(session_id).metadata
    return session_metadata 

In [None]:
specific_id = session_1.get('ecephys_session_id') #use the get method to access the value associated with the key 'ecephys_session_id'

#print the value associated with the key 'ecephys_session_id'
print('ecephys_session_id for the first index of session_1: ' + str(session_1.get('ecephys_session_id')))

In [None]:
#using the specific session id, access the session data for that session 
session_1_data = cache.get_session_data(specific_id)
print('Number of units in the session table: ' + str(len(session_1_data.units))) #how many units are there in the session table?

In [None]:
dg = session_1_data.get_stimulus_table('drifting_gratings')
dg[dg.orientation == 90]

In [None]:
#define regions of interest 
region_of_interest = ['VISp']; #visual cortex, primary, will be used for now and will be a list of strings 

#prepare the time bins for computing the PSTH 
time_step = 0.05 # in seconds, will be a user input later 
pre_time = -0.5 # in seconds, will be a user input later
time_bins = np.arange(pre_time, 1.5, time_step) #will be a user input later, pre_time is the start of the time bins, 1.5 is the end of the time bins, time_step is the step size between each time bin 

session_1_data.structurewise_unit_counts #this is a pandas dataframe object that contains the number of units in each brain area for the session

temporal_freqs = session_1_data.get_stimulus_table('drifting_gratings')['temporal_frequency'].unique()

In [None]:
session_1_data.unit?

In [None]:
presentations = session_1_data.get_stimulus_table(["drifting_gratings"]) #this is a pandas dataframe object that contains the stimulus table for drifting gratings
print(presentations['orientation'].unique()) #this prints the unique orientation values in the stimulus table for drifting gratings

presentations


def sort_orientation_and_temporal_freq(session_data): 
    #'sort_orientation_and_temporal_freq' will sort the stimulus table for drifting gratings by orientation and temporal frequency values
    
    #input: session_data, a session data object from the AllenSDK of a single session
    #output: a sorted stimulus table for drifting gratings
    
    #create a list of unique orientation values in the stimulus table
    presentations = session_data.get_stimulus_table('drifting_gratings')['orientation'].unique() 

    #create a list of unique temporal frequency values in the stimulus table
    temporal_freqs = session_data.get_stimulus_table('drifting_gratings')['temporal_frequency'].unique()
    
    #sort the dataframe for each unique combination of orientation and temporal frequency values
    for orientation in presentations: #iterate over each unique orientation value in the stimulus table
        for temporal_freq in temporal_freqs: #iterate over each unique temporal frequency value in the stimulus table 
    

In [None]:
print(f'{session_1_data.units.shape[0]} units total') #print the number of units in the session table
units_with_very_high_snr = session_1_data.units[session_1_data.units['snr'] > 4] #create a dataframe of units with snr > 4
print(f'{units_with_very_high_snr.shape[0]} units have snr > 4') #print the number of units with snr > 4

#plot the distribution of snr values for all units in the session table 
plt.hist(session_1_data.units['snr'], bins=100)
plt.xlabel('snr')
plt.ylabel('number of units')
plt.title('Distribution of snr values for all units in the session table')
plt.show()

#test for bimodality of snr values in the session table
from scipy.stats import kstest
kstest(session_1_data.units['snr'], 'norm') #if the p-value is less than 0.05, then the distribution is not normal

#plot the distribution of snr values for units with snr > 4
plt.hist(units_with_very_high_snr['snr'], bins=100)
plt.xlabel('snr')
plt.ylabel('number of units')
plt.title('Distribution of snr values for units with snr > 4')
plt.show()

#test for bimodality of snr values for units with snr > 4 
kstest(units_with_very_high_snr['snr'], 'norm') #if the p-value is less than 0.05, then the distribution is not normal



In [None]:
#create a function that will iterate over the session_ids list, access the session data for each session id, and store the snr values for each session 
# pandas dataframe 

def get_snr_values(session_ids):
    snr_values = pd.DataFrame(columns=['session_id', 'unit_id', 'snr']) #create an empty pandas dataframe with columns 'session_id', 'unit_id', and 'snr'
    
    for session_id in session_ids: #iterate over the session_ids list
        session_data = cache.get_session_data(session_id) #access the session data for each session id
        session_units = session_data.units[['snr']] #create a dataframe with the snr values for each unit in the session
        session_units['unit_id'] = session_units.index #add a column 'unit_id' with the index values of the session_units dataframe
        session_units['session_id'] = session_id #add a column 'session_id' with the current session_id value
        session_units = session_units[['session_id', 'unit_id', 'snr']] #reorder the columns of the session_units dataframe
        snr_values = snr_values.append(session_units) #append the session_units dataframe to the snr_values dataframe
    return snr_values

#use the function to create a dataframe with the snr values for each unit in the session table
snr_values = get_snr_values(session_ids)

#plot the distribution of snr values for all units in the session table
plt.hist(snr_values['snr'], bins=100) #took 39m to run 

In [None]:
#save the snr_values dataframe as a csv file to the current working directory
snr_values.to_csv('snr_values_allunits.csv', index=False) 

In [None]:
# We're going to build an array of spike counts surrounding stimulus presentation onset
# To do that, we will need to specify some bins (in seconds, relative to stimulus onset)
time_bin_edges = np.linspace(-0.01, 0.4, 200)

# do the the above but for the sessin_1 data
flash_250_ms_stimulus_presentation_ids = session_1_data.stimulus_presentations[
    session_1_data.stimulus_presentations['stimulus_name'] == 'flashes'
].index.values

# and get a set of units with only decent snr
decent_snr_unit_ids = session_1_data.units[
    session_1_data.units['snr'] >= 1.5
].index.values

spike_counts_da = session_1_data.presentationwise_spike_counts(
    bin_edges=time_bin_edges,
    stimulus_presentation_ids=flash_250_ms_stimulus_presentation_ids,
    unit_ids=decent_snr_unit_ids
)

spike_counts_da 

#determine the number of unique units in the spike_counts_da dataframe
print('Number of unique units in the spike_counts_da dataframe: ' + str(len(spike_counts_da.unit_id.unique())))


In [None]:
mean_spike_counts = spike_counts_da.mean(dim='stimulus_presentation_id')
mean_spike_counts

In [None]:
from allensdk.brain_observatory.ecephys.visualization import plot_mean_waveforms, plot_spike_counts, raster_plot

plot_spike_counts(
    mean_spike_counts,
    mean_spike_counts['time_relative_to_stimulus_onset'],
    'mean spike count',
    'mean spike counts on flash_250_ms presentations'
)
plt.show() 

In [None]:
# use the session_1 data
units_of_interest = decent_snr_unit_ids[:35] #take the first 35 units with decent snr values

waveforms = {uid: session_1_data.mean_waveforms[uid] for uid in units_of_interest} #create a dictionary with the mean waveform for each unit of interest 
peak_channels = {uid: session_1_data.units.loc[uid, 'peak_channel_id'] for uid in units_of_interest} #create a dictionary with the peak channel for each unit of interest

# plot the mean waveform on each unit's peak channel for each unit of interest
plot_mean_waveforms(waveforms, units_of_interest, peak_channels) 
plt.show()


In [None]:
#plot the first waveform in the waveforms dictionary to visualize how the trough to peak amplitude is calculated 
plt.plot(waveforms[units_of_interest[0]].values[0])
plt.xlabel('sample number')
plt.ylabel('amplitude (microvolts)')
plt.title('First waveform in the waveforms dictionary')
plt.show()

#now overlay what the trough and peak values are on the first waveform in the waveforms dictionary
plt.plot(waveforms[units_of_interest[0]].values[0])
plt.axhline(y=waveforms[units_of_interest[0]].values[0].min(), color='r', linestyle='-')
plt.axhline(y=waveforms[units_of_interest[0]].values[0].max(), color='g', linestyle='-')
plt.xlabel('sample number')
plt.ylabel('amplitude (microvolts)')
plt.title('First waveform in the waveforms dictionary with trough and peak values')
plt.show()



In [None]:
#grab the first waveform in the waveforms dictionary to test the function
waveform = waveforms[units_of_interest[0]].values[0]

#find the max absolute in the waveform and then divide the waveform by that value
normalized_waveform = waveform / np.max(np.abs(waveform)) #normalize the waveform based on absolute max value 

#plot the normalized waveform and the original waveform

plt.plot(normalized_waveform)
plt.xlabel('sample number')
plt.ylabel('amplitude (microvolts)')
plt.title('Normalized waveform and original waveform')
plt.show()




In [None]:
#grab the first waveform in the waveforms dictionary to test the function
waveform = waveforms[units_of_interest[0]].values[0]

#find the max absolute in the waveform and then divide the waveform by that value
normalized_waveform = waveform / np.max(np.abs(waveform)) #normalize the waveform based on absolute max value 

# Normalizing
trough_location = np.where(normalized_waveform == np.min(normalized_waveform))[0][0]
max_after_trough = np.max(normalized_waveform[trough_location:])
max_after_trough_location = np.where(normalized_waveform == max_after_trough)[0][0]
max_before_trough = np.max(normalized_waveform[:trough_location])
max_before_trough_location = np.where(normalized_waveform == max_before_trough)[0][0]
difference = (max_after_trough_location - trough_location)/30 # ms trough to peak

#based on the normalize waveform, plot each of the values that will be calculated in the function on the waveform
plt.plot(normalized_waveform)
#add the name of the variable and the value of the variable to the plot
plt.text(0, 0.8, 'trough_location = ' + str(trough_location))

plt.axvline(x=trough_location, color='r', linestyle='-')
plt.axvline(x=max_after_trough_location, color='g', linestyle='-')

plt.xlabel('sample number')
plt.ylabel('amplitude (microvolts)')
plt.title('Normalized waveform with trough and peak values')
plt.show()

    

In [None]:
#what are the unique number of session ids in the waveform_values dictionary?
print('Number of unique session ids in the waveform_values dictionary: ' + str(len(waveform_values.keys())))


In [None]:
#grab the first waveform in the waveforms dictionary to test the function
waveform = waveforms[units_of_interest[0]].values[0]

#find the max absolute in the waveform and then divide the waveform by that value
normalized_waveform = waveform / np.max(np.abs(waveform)) #normalize the waveform based on absolute max value 

# Normalizing
trough_location = np.where(normalized_waveform == np.min(normalized_waveform))[0][0]
max_after_trough = np.max(normalized_waveform[trough_location:])
max_after_trough_location = np.where(normalized_waveform == max_after_trough)[0][0]
max_before_trough = np.max(normalized_waveform[:trough_location])
max_before_trough_location = np.where(normalized_waveform == max_before_trough)[0][0]
difference = (max_after_trough_location - trough_location)/30 # ms trough to peak

#create a function to iterative enter the dictionary waveforms_values with the values for each unit of interest and use this key to access session_1_data.units 
def get_waveform_from_all_sessions(input_dict):
    #iterate over the keys in the input_dict
    for session_id in input_dict.keys():
        #iterate over the the cell ids in the input_dict
        for unit_id in input_dict[session_id].keys(): 
            
            #store the mean waveform within in 'raw_waveform' key 
            input_dict[session_id][unit_id]['raw_waveform'] = session_1_data.mean_waveforms[unit_id].values[0]
            
            #find the max absolute in the waveform and then divide the waveform by that value
            normalized_waveform = input_dict[session_id][unit_id]['raw_waveform'] / np.max(np.abs(input_dict[session_id][unit_id]['raw_waveform'])) #normalize the waveform based on absolute max value
            
            #store the normalized waveform within the 'normalized_waveform' key
            input_dict[session_id][unit_id]['normalized_waveform'] = normalized_waveform
            
            #perform calculations on the normalized waveform to find the trough_to_peak_amplitude, trough_to_peak_time, trough_time, peak_time, trough_amplitude, and peak_amplitude
            trough_location = np.where(normalized_waveform == np.min(normalized_waveform))[0][0]
            max_after_trough = np.max(normalized_waveform[trough_location:])
            max_after_trough_location = np.where(normalized_waveform == max_after_trough)[0][0]
            max_before_trough = np.max(normalized_waveform[:trough_location])
            max_before_trough_location = np.where(normalized_waveform == max_before_trough)[0][0]
            difference = (max_after_trough_location - trough_location)/30 # ms trough to peak
            
            #store the trough_to_peak_amplitude, trough_to_peak_time, trough_time, peak_time, trough_amplitude, and peak_amplitude values within the input_dict
            input_dict[session_id][unit_id]['trough_to_peak_amplitude'] = difference
            input_dict[session_id][unit_id]['trough_to_peak_time'] = difference
            input_dict[session_id][unit_id]['trough_time'] = trough_location
            input_dict[session_id][unit_id]['peak_time'] = max_after_trough_location
            input_dict[session_id][unit_id]['trough_amplitude'] = max_before_trough
            input_dict[session_id][unit_id]['peak_amplitude'] = max_after_trough
            
            #classify the cell type based on difference value 
            if difference < 0.4:
                input_dict[session_id][unit_id]['cell_type'] = 'FS'
            elif difference > 0.4:
                input_dict[session_id][unit_id]['cell_type'] = 'RS'
            
    return input_dict

#pull out the first session id in the waveform_values dictionary to test the function
session_id = list(waveform_values.keys())[0]

#use this session id to test the function 
waveform_values_test = get_waveform_from_all_sessions({session_id: waveform_values[session_id]})

#access the content of the session_id key in the waveform_values_test dictionary
waveform_values_test[session_id]

#plot the normalized waveform for the first unit in the session_id key in the waveform_values_test dictionary
plt.plot(waveform_values_test[session_id][units_of_interest[0]]['normalized_waveform'])
plt.xlabel('sample number')
plt.ylabel('amplitude (microvolts)')
plt.title('Normalized waveform for the first unit in the session_id key in the waveform_values_test dictionary')
plt.show()

#what are the number of keys in the first entry of the session_id key in the waveform_values_test dictionary?
print('Number of keys in the first entry of the session_id key in the waveform_values_test dictionary: ' + str(len(waveform_values_test[session_id][units_of_interest[0]].keys())))            
            
            
            
            