In [1]:
# Goal of this notebook: deconstruct the code from the schemas and re-implement it on a smaller scale, using a recording imported from an extractor.


In [None]:

@schema  # register the table with the schema, i.e. a group of related tables
class SpikeSortingArtifactDetectionParameters(dj.Manual):
    # Manual table: rows are added explicitly by the user. Each row pairs a name with a
    # dictionary of parameters that get_no_artifact_times() below is meant to be called with.
    definition = """
    # Table for holding parameters related to artifact detection
    # Note that
    artifact_parameter_name: varchar(200) #name for this set of parameters 
    ---
    parameter_dict: BLOB    # dictionary of parameters for get_no_artifact_times() function
    """

    def insert_default(self):
        """Insert the default artifact parameter set, named 'none'.

        This is an ordinary method: it only runs when it is called explicitly, not when
        the table class is instantiated. The 'none' set has skip=True and both thresholds
        negative, so get_no_artifact_times() returns the full recording interval for it.
        The row is inserted with skip_duplicates=True.
        """
        param_dict = {
            'skip': True,
            'zscore_thresh': -1.0,
            'amplitude_thresh': -1.0,
            'proportion_above_thresh': -1.0,
            # NOTE(review): the original note here read "1 ms at 30 KHz", which implies a
            # length in samples, while get_no_artifact_times() documents zero_window_len in
            # milliseconds. Harmless for this set (thresholds are disabled) - confirm units.
            'zero_window_len': 30,
        }
        self.insert1({'artifact_parameter_name': 'none', 'parameter_dict': param_dict},
                     skip_duplicates=True)

    # Open questions carried over from the original walkthrough:
    # - The thresholds are keyword arguments with defaults rather than values read from
    #   parameter_dict; if a caller does not pass a value, the default in the signature is used.
    #   TODO: tie this function to the parameters stored in the table.
    # - TODO: compare with other dj.Manual tables to see how their helper methods are organised.

    def get_no_artifact_times(self, recording, zscore_thresh=-1.0, amplitude_thresh=-1.0,
                              proportion_above_thresh=1.0, zero_window_len=1.0, skip: bool=True):
        """Return the valid time intervals of a recording, excluding detected artifacts.

        A sample is treated as an artifact when, on at least
        ``round(proportion_above_thresh * n_channels)`` channels at once, the absolute
        amplitude exceeds ``amplitude_thresh`` and the absolute z-score exceeds
        ``zscore_thresh``. A negative threshold is trivially satisfied, so it is
        effectively ignored. Each artifact sample is widened by half of
        ``zero_window_len`` on either side before being excluded.

        :param recording: recording extractor; must provide ``_timestamps``,
            ``get_num_frames()``, ``get_sampling_frequency()``, ``get_traces()`` and
            ``get_channel_ids()``
        :type recording: SpikeInterface recording extractor object
        :param zscore_thresh: z-score (standard deviation) threshold for exclusion,
            defaults to -1.0
        :type zscore_thresh: float, optional
        :param amplitude_thresh: amplitude threshold for exclusion, defaults to -1.0
        :type amplitude_thresh: float, optional
        :param proportion_above_thresh: proportion of channels that must be above
            threshold for a sample to count as an artifact, defaults to 1.0
        :type proportion_above_thresh: float, optional
        :param zero_window_len: total width, in milliseconds, of the window excluded
            around each threshold crossing (half on each side), defaults to 1.0
        :type zero_window_len: float, optional
        :param skip: accepted so a stored parameter_dict can be passed through unchanged;
            it is not consulted by this function
        :type skip: bool, optional
        :return: ``[[first_timestamp, last_timestamp]]`` when neither threshold is
            positive, otherwise the result of ``get_valid_intervals`` on the kept timestamps
        :rtype: numpy array
        """
        timestamps = recording._timestamps

        # If no thresholds were specified, return the timestamps of the first and last samples.
        if zscore_thresh <= 0 and amplitude_thresh <= 0:
            return np.asarray([[timestamps[0], timestamps[recording.get_num_frames() - 1]]])

        sampling_frequency = recording.get_sampling_frequency()

        # Convert the window from milliseconds to samples (fs is in samples per second) and
        # halve it. It must be an int: a float cannot be used as a slice bound.
        half_window_points = int(np.round(sampling_frequency * zero_window_len / 1000 / 2))

        # Get the data traces.
        # NOTE(review): the axis choices below assume traces are (channels, samples) - confirm.
        data = recording.get_traces()

        # Number of electrodes that have to be above threshold, from the channel count.
        nelect_above = np.round(
            proportion_above_thresh * len(recording.get_channel_ids()))

        # Apply the amplitude threshold.
        above_a = np.abs(data) > amplitude_thresh

        # z-score each row of the data and threshold the absolute value.
        dataz = np.abs(stats.zscore(data, axis=1))
        above_z = dataz > zscore_thresh

        # Count, per column, how many rows pass both tests; argwhere gives the indices of
        # the columns that reach nelect_above and ravel flattens them into a 1-D index array.
        above_both = np.ravel(np.argwhere(
            np.sum(np.logical_and(above_z, above_a), axis=0) >= nelect_above))

        # Track exclusions in a separate boolean mask instead of overwriting timestamps with
        # a -1 sentinel, so the recording's own timestamp array is never modified.
        valid_timestamps = np.asarray(timestamps)
        keep = np.ones(len(valid_timestamps), dtype=bool)
        for a in above_both:
            # Clamp the start at 0: a negative start would index from the end of the array.
            keep[max(a - half_window_points, 0):a + half_window_points] = False

        # get_valid_intervals is defined elsewhere; 1.5 and 0.001 are passed through unchanged.
        # TODO: look up what those two arguments mean.
        return get_valid_intervals(valid_timestamps[keep], sampling_frequency, 1.5, 0.001)




In [2]:
#Okay first, make dictionary with parameters

In [5]:
# Individual settings for the 'none' parameter set; kept as separate names so later
# cells can refer to them directly.
skip = True
zscore_thresh = -1.0
amplitude_thresh = -1.0
proportion_above_thresh = -1.0
zero_window_len = 30  # 1 ms at 30 KHz, but this is of course skipped

# Collect the settings in one dictionary, in the same key order as the table's default row.
param_dict = {
    'skip': skip,
    'zscore_thresh': zscore_thresh,
    'amplitude_thresh': amplitude_thresh,
    'proportion_above_thresh': proportion_above_thresh,
    'zero_window_len': zero_window_len,
}


#self.insert1({'artifact_parameter_name': 'none', 'parameter_dict' : param_dict},skip_duplicates=True)

In [6]:
# Display the assembled dictionary to confirm its keys and values.
param_dict

{'skip': True,
 'zscore_thresh': -1.0,
 'amplitude_thresh': -1.0,
 'proportion_above_thresh': -1.0,
 'zero_window_len': 30}

In [7]:
# Placeholder only: a plain string stands in for the recording here. get_no_artifact_times
# reads recording._timestamps and calls recording.get_num_frames(), so this value must be
# replaced by a real recording extractor object before that function can be called.
recording = 'some path'

In [None]:
def get_no_artifact_times(self, recording, zscore_thresh=-1.0, amplitude_thresh=-1.0, 
                              proportion_above_thresh=1.0, zero_window_len=1.0, skip: bool=True):
    """Stand-alone, partial copy of the table method, limited to the "no thresholds" case.

    When neither threshold is positive, returns a 1x2 array holding the first and last
    timestamps of ``recording``. Otherwise nothing has been implemented yet and the
    function returns None. ``self``, ``proportion_above_thresh``, ``zero_window_len``
    and ``skip`` are accepted but not used here.
    """
    # Open questions from the walkthrough: the full version still relies on helper functions
    # defined in the datajoint package code, and it is undecided whether to keep using numpy.
    thresholds_disabled = (zscore_thresh <= 0) and (amplitude_thresh <= 0)
    if thresholds_disabled:
        first_timestamp = recording._timestamps[0]
        last_timestamp = recording._timestamps[recording.get_num_frames() - 1]
        return np.asarray([[first_timestamp, last_timestamp]])

    # Artifact detection itself is not implemented in this cell.
    return None

In [None]:
# Restrict SpikeSortingRecording to the rows matching this NWB file name and fetch the
# 'recording_extractor_object' attribute with fetch1. This line only reads from the table
# (it does not populate it), and the fetched value is not assigned to a variable.
# NOTE(review): nwb_file_name is not defined in the cells shown here - set it before running.
(SpikeSortingRecording() & {'nwb_file_name' : nwb_file_name}).fetch1('recording_extractor_object')



In [None]:
# Then, define keys: 


# Build the lookup key in one literal. The first four values come from variables that must
# already be defined in the notebook; the last three are fixed names.
key = {
    'nwb_file_name': nwb_file_name,
    'sort_group_id': sort_group_id,
    'sort_interval_name': sort_interval_name,
    'interval_list_name': interval_list_name,
    'sorter_name': 'mountainsort4',
    'parameter_set_name': 'franklab_tetrode_hippocampus_30KHz',
    'cluster_metrics_list_name': 'franklab_cluster_metrics_09-19-2021',
}

In [None]:
# Call SpikeSortingRecording.get_filtered_recording_extractor with `key` and bind the result
# to `recording`, replacing any earlier value of that name (such as the placeholder string).
# Judging by the method name this is the filtered recording extractor - TODO confirm.
recording = SpikeSortingRecording().get_filtered_recording_extractor(key)