In [None]:
# Reload imported modules automatically before executing code, so edits to the
# local helper modules imported below take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
# Third-party and spyglass imports. `nd` is an alias for the top-level spyglass
# package, so spyglass.spikesorting is reachable both as `ss` and as `nd.spikesorting`.
# NOTE(review): several names imported here (copy, plt, SortInterval,
# SpikeSortingPreprocessingParameters, SpikeSorterParameters, ArtifactDetectionParameters,
# the bare CuratedSpikeSorting, unpack_single_element) are not used in the cells visible here.
import numpy as np
import os
import copy
import matplotlib.pyplot as plt
import datajoint as dj
import spyglass as nd
from spyglass.common import (Session, IntervalList)
import spyglass.spikesorting as ss
from spyglass.spikesorting import (SortGroup, 
                                    SortInterval,
                                    SpikeSortingPreprocessingParameters,
                                    SpikeSortingRecording, 
                                    SpikeSorterParameters,
                                    SpikeSortingRecordingSelection,
                                    ArtifactDetectionParameters,
                                    ArtifactRemovedIntervalList,
                                  CuratedSpikeSorting)
# The local analysis modules imported below are loaded from this directory
# (presumably via the current working directory being on sys.path), so change into it first.
os.chdir("/home/jguidera/Src/nwb_custom_analysis/")
from spikesorting_helpers import define_sort_interval_as_interval_list, set_spikesorting_directories
from jguidera_spikesorting import return_spikesorting_params
from jguidera_reference_electrode import ReferenceElectrode, make_refs_dict
from jguidera_brain_region import SortGroupTargetedLocation
from populate_jguidera_reference_electrode import populate_jguidera_reference_electrode
from vector_helpers import unpack_single_element

# Ignore certain warnings
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)

# Set directories (the helper may return an adjusted base_dir, which replaces the one given)
base_dir = "/stelmo/nwb/"
(base_dir, raw_dir, analysis_dir, kachery_storage_dir,
 recording_dir, sorting_dir, waveforms_dir, tmp_dir) = set_spikesorting_directories(base_dir)

# Set environment variables for spyglass, kachery and figurl.
# Environment variable names are case-sensitive, and spyglass reads the upper-case
# SPYGLASS_* names, so the previously used lower-case spyglass_* names alone were ignored
# by spyglass. The lower-case names are still set so any local helper reading them keeps working.
spyglass_dir_env = {'BASE_DIR': base_dir,
                    'RECORDING_DIR': recording_dir,
                    'SORTING_DIR': sorting_dir,
                    'WAVEFORMS_DIR': waveforms_dir}
for env_suffix, env_dir in spyglass_dir_env.items():
    os.environ[f'SPYGLASS_{env_suffix}'] = str(env_dir)
    os.environ[f'spyglass_{env_suffix}'] = str(env_dir)
os.environ['KACHERY_STORAGE_DIR'] = str(kachery_storage_dir)
os.environ['KACHERY_TEMP_DIR'] = str(tmp_dir)
os.environ['DJ_SUPPORT_FILEPATH_MANAGEMENT'] = 'TRUE'
os.environ['FIGURL_CHANNEL'] = 'franklab2'

# Allow DataJoint to store native Python objects in blob attributes
dj.config["enable_python_native_blobs"] = True

In [None]:
# *** PARAMETERS ***
# NWB file
subject_id = "J16"
date = "20210531"
# Sort interval
# Interval lists to start from; periods are then excluded as indicated by the flags below.
# Other example: ["04_r2", "06_r3"]
starting_interval_list_names = ["raw data valid times"]
NO_PREMAZE = True  # True to exclude periods when rat being carried to maze
NO_HOME = True  # True to exclude home epochs
# Widen invalid intervals by this many seconds to account for small differences in the start
# and stop of what should be the same interval in different IntervalList entries
widen_exclusion_factor = .001
# Sort group
# Used as-is only when set_sort_group_ids_by_targeted_location is False; otherwise it is
# replaced by the sort groups at `targeted_locations` when sort groups are set up.
sort_group_ids = np.arange(17, 21)
set_sort_group_ids_by_targeted_location = True  # True to set sort groups based on targeted locations
targeted_locations = ["CA1"]  # other options: "OFC", "mPFC"; only used if set_sort_group_ids_by_targeted_location is True
override_previous_sort_group = False  # True to remake sort group in datajoint table
# Preprocessing and spike sorter parameters depending on brain region
parameter_set_dict = return_spikesorting_params()
# Spike sorter
sorter = 'mountainsort4'
# Cluster metrics
# NOTE(review): not used by the cells visible here
cluster_metrics_list_name = 'franklab_cluster_metrics_09-19-2021'
# Automatic curation parameters
# NOTE(review): not used by the cells visible here
automatic_curation_parameter_set_name = "none"
# Lab team
team_name = 'JG_DG'
# ******************

# Validate parameters. Specific exception types (still subclasses of Exception)
# make the failure reason clear.
if not isinstance(starting_interval_list_names, list):
    raise TypeError("starting_interval_list_names must be a list, "
                    f"got {type(starting_interval_list_names).__name__}")

# Define nwb file and check that it exists in the Session table
nwb_file_name = f"{subject_id}{date}_.nwb"
if len(Session() & {'nwb_file_name': nwb_file_name}) == 0:
    raise ValueError(f"nwb file {nwb_file_name} not in Session table")

In [None]:
# Set sort group by shank
# Set sort groups by shank (unless already set and not overriding), then populate the
# tables used to map sort groups to targeted brain regions.
populate_jguidera_reference_electrode()
if override_previous_sort_group or (len(SortGroup & {'nwb_file_name': nwb_file_name}) == 0):
    print("Setting sort group by shank")
    SortGroup().set_group_by_shank(nwb_file_name=nwb_file_name,
                                   references=make_refs_dict(nwb_file_name),
                                   omit_ref_electrode_group=True,)
                                   # omit_unitrodes=True could also be passed here
from jguidera_task_event import TaskIdentification
TaskIdentification.populate({"nwb_file_name": nwb_file_name})
from populate_jguidera_brain_region import populate_SortGroupTargetedLocation
SortGroupTargetedLocation.populate({"nwb_file_name": nwb_file_name})
if set_sort_group_ids_by_targeted_location:  # define sort groups by targeted location if desired
    # Keep the working directory at the local analysis module directory
    os.chdir("/home/jguidera/Src/nwb_custom_analysis/")
    # (SortGroupTargetedLocation is already imported in the import cell, so no re-import here)
    sort_group_id_arrays = [(SortGroupTargetedLocation() & {"nwb_file_name": nwb_file_name,
                                                            "targeted_location": targeted_location}).fetch("sort_group_id")
                            for targeted_location in targeted_locations]
    # np.concatenate raises on an empty list, so handle an empty `targeted_locations` explicitly
    sort_group_ids = (np.concatenate(sort_group_id_arrays) if sort_group_id_arrays
                      else np.array([], dtype=int))
if len(sort_group_ids) == 0:
    # Otherwise the per-sort-group loop would silently process nothing
    warnings.warn(f"No sort groups selected for {nwb_file_name}; "
                  "no sort group will be processed")

In [None]:
# Define valid times for the sort: start from `starting_interval_list_names` and
# exclude periods as indicated by NO_PREMAZE / NO_HOME.
# Populate premaze durations table in case it is needed for the exclusion
from populate_jguidera_premaze_durations import populate_PremazeDurations
from jguidera_premaze_durations import PremazeDurations
populate_PremazeDurations()
from define_interval_list import define_interval_list_through_exclusion
interval_list_name, interval_list = define_interval_list_through_exclusion(starting_interval_list_names=starting_interval_list_names,
                                                                           nwb_file_name=nwb_file_name,
                                                                           NO_PREMAZE=NO_PREMAZE,
                                                                           NO_HOME=NO_HOME,
                                                                           widen_exclusion_factor=widen_exclusion_factor)
interval_list_key = {"nwb_file_name": nwb_file_name,
                     "interval_list_name": interval_list_name}
IntervalList.insert1({**interval_list_key, "valid_times": interval_list},
                     skip_duplicates=True)
# skip_duplicates=True silently keeps an existing entry with the same name. The in-memory
# `interval_list` is used to define the sort interval, while tables reference the stored
# entry by name, so warn if the two disagree.
stored_valid_times = (IntervalList & interval_list_key).fetch1("valid_times")
if not np.array_equal(np.asarray(stored_valid_times), np.asarray(interval_list)):
    warnings.warn(f"IntervalList entry '{interval_list_name}' for {nwb_file_name} already exists "
                  "with valid_times that differ from the ones just computed; the stored entry was kept")

In [None]:
# Define sort interval 
# Register a sort interval spanning the valid-times interval list defined above
sort_interval_name, sort_interval = define_sort_interval_as_interval_list(
    interval_list_name, interval_list, nwb_file_name)

In [None]:
# Switches controlling how far the per-sort-group pipeline runs
no_metrics = False  # True to skip metrics, automatic curation and CuratedSpikeSorting
no_sorting_view = True  # True to skip adding the curation to a sortingview workspace
no_spike_sorting = False  # True to skip populating SpikeSorting (selection entries are still inserted)
only_make_ss_recording = True  # True to stop after making SpikeSortingRecording entries

In [None]:
# Pass 1: make the SpikeSortingRecording for every sort group.
# All recordings are made before any downstream step because CA1 artifact detection
# (pass 2) runs across sort groups: it builds a recording cohort from every CA1 sort group,
# which presumably requires those recordings to exist. When the two were interleaved in one
# loop, the first CA1 sort group was processed before later CA1 recordings had been made.
sort_group_targeted_location_map = SortGroupTargetedLocation().return_sort_group_targeted_location_map(nwb_file_name)
recording_keys = dict()
for sort_group_id in sort_group_ids:
    print(f"On sort group: {sort_group_id}")
    sort_group_brain_region = sort_group_targeted_location_map[sort_group_id]
    # Preprocessing parameters depend on the sort group's targeted brain region
    recording_key = {'nwb_file_name': nwb_file_name,
                     'sort_group_id': sort_group_id,
                     'sort_interval_name': sort_interval_name,
                     'preproc_params_name': parameter_set_dict["preproc_params_name"][sort_group_brain_region],
                     'interval_list_name': interval_list_name,
                     'team_name': team_name}
    SpikeSortingRecordingSelection.insert1(recording_key, skip_duplicates=True)
    SpikeSortingRecording.populate([(SpikeSortingRecordingSelection & recording_key).proj()])
    recording_keys[sort_group_id] = recording_key

# Pass 2: artifact detection, spike sorting, waveforms, metrics and curation
if not only_make_ss_recording:
    # Parameter-table inserts that do not depend on the sort group (done once, not per group)
    from populate_jguidera_artifact import populate_ArtifactDetectionParameters
    populate_ArtifactDetectionParameters()
    ss.WaveformParameters().insert_default()
    waveform_params_names = ['default_not_whitened', 'default_whitened']
    if not no_metrics:
        metric_params = {'peak_offset' : {'peak_sign' : 'neg'}}
        metric_params_name = 'DPG_just_peak_offset'
        ss.MetricParameters.insert1({'metric_params_name' : metric_params_name,
                                     'metric_params' : metric_params},
                                    skip_duplicates=True)
        # Metric parameter set to use with each waveform parameter set
        metrics_params_name_dict = {'default_not_whitened': metric_params_name,
                                    'default_whitened': "JG_DG_no_peak_offset_min_spikes"}
        ss.AutomaticCurationParameters().insert_default()
        auto_curation_params_name = 'JG_DG_AutoCuration_params'
        # Presumably: label a unit 'noise'/'reject' when the metric exceeds the threshold
        label_params = {'nn_noise_overlap' : ['>', 0.03, ['noise','reject']],
                        'isi_violation' : ['>', 1/400, ['noise','reject']]}
        ss.AutomaticCurationParameters().insert1({'auto_curation_params_name' : auto_curation_params_name,
                                                  'merge_params' : {},
                                                  'label_params' : label_params}, skip_duplicates=True)
    # Targeted location -> sort groups (used by cross-sort-group artifact detection)
    targeted_location_sgs = SortGroupTargetedLocation().return_targeted_location_sort_group_map(nwb_file_name)

    for sort_group_id in sort_group_ids:
        print(f"Processing sort group: {sort_group_id}")
        sort_group_brain_region = sort_group_targeted_location_map[sort_group_id]
        recording_key = recording_keys[sort_group_id]

        # Artifact detection
        if sort_group_brain_region in ["mPFC", "OFC"]:
            # Detect artifacts within sort groups for probes
            artifact_key = (ss.SpikeSortingRecording & recording_key).fetch1('KEY')
            artifact_key['artifact_params_name'] = parameter_set_dict["artifact"][sort_group_brain_region]
            ss.ArtifactDetectionSelection.insert1(artifact_key, skip_duplicates=True)
            ss.ArtifactDetection.populate([(ss.ArtifactDetectionSelection & artifact_key).proj()])
            # Use the name recorded for exactly this artifact_params_name. A fetch1 on
            # ArtifactRemovedIntervalList restricted only by file/sort group/sort interval
            # would fail if several artifact parameter sets had been run on this recording.
            artifact_removed_interval_list_name = (ss.ArtifactDetection & artifact_key).fetch1(
                'artifact_removed_interval_list_name')
        elif sort_group_brain_region in ["CA1"]:
            # Detect artifacts across sort groups for tetrodes
            from populate_jguidera_artifact import (populate_ArtifactDetectionAcrossSortGroupsParams,
                                                    populate_ArtifactDetectionAcrossSortGroupsSelection)
            from jguidera_artifact import (ArtifactDetectionAcrossSortGroups,
                                           ArtifactDetectionAcrossSortGroupsParams,
                                           ArtifactDetectionAcrossSortGroupsSelection)
            from populate_jguidera_spikesorting import populate_SpikeSortingRecordingCohortParams
            from jguidera_spikesorting import (SpikeSortingRecordingCohort, SpikeSortingRecordingCohortParams)
            targeted_region = "CA1"
            sg_ids = targeted_location_sgs[targeted_region]
            preproc_params_name = parameter_set_dict["preproc_params_name"][targeted_region]
            populate_SpikeSortingRecordingCohortParams(nwb_file_name,
                                                       sort_interval_name,
                                                       preproc_params_name,
                                                       sg_ids,)
            SpikeSortingRecordingCohort.populate()
            populate_ArtifactDetectionAcrossSortGroupsParams()
            populate_ArtifactDetectionAcrossSortGroupsSelection(nwb_file_name=nwb_file_name)
            spike_sorting_recording_cohort_param_name = (SpikeSortingRecordingCohortParams & {"nwb_file_name": nwb_file_name}).fetch1("spike_sorting_recording_cohort_param_name")
            ArtifactDetectionAcrossSortGroups.populate({"spike_sorting_recording_cohort_param_name": spike_sorting_recording_cohort_param_name})
            artifact_removed_interval_list_name = (ArtifactRemovedIntervalList & {"nwb_file_name": nwb_file_name,
                                                                                  "sort_group_id": sort_group_id,
                                                                                  "sort_interval_name": sort_interval_name}).fetch1("artifact_removed_interval_list_name")
        else:
            raise Exception(f"Artifact detection not specified for {sort_group_brain_region}")

        # Spike sorting
        sorter_params_name = parameter_set_dict["sorter_params_name"][sort_group_brain_region]
        sorting_key = (ss.SpikeSortingRecording & recording_key).fetch1('KEY')
        sorting_key.update({'sorter': sorter,
                            'sorter_params_name': sorter_params_name,
                            'artifact_removed_interval_list_name': artifact_removed_interval_list_name})
        ss.SpikeSortingSelection.insert1(sorting_key, skip_duplicates=True)
        if not no_spike_sorting:
            ss.SpikeSorting.populate(sorting_key)
        curation_key = ss.Curation.insert_curation(sorting_key)

        # Waveforms, one entry per waveform parameter set
        for waveform_params_name in waveform_params_names:
            waveform_key = curation_key.copy()
            waveform_key.update({'waveform_params_name': waveform_params_name})
            ss.WaveformSelection.insert1(waveform_key, skip_duplicates=True)
            ss.Waveforms.populate([(ss.WaveformSelection & waveform_key).proj()])

        if not no_metrics:
            # Quality metrics, one entry per waveform parameter set
            metric_key = curation_key.copy()
            for waveform_params_name in waveform_params_names:
                metric_key['metric_params_name'] = metrics_params_name_dict[waveform_params_name]
                metric_key['waveform_params_name'] = waveform_params_name
                ss.MetricSelection.insert1(metric_key, skip_duplicates=True)
                ss.QualityMetrics.populate([(ss.MetricSelection & metric_key).proj()])

            # Automatic curation. metric_key still holds the last waveform/metric parameter
            # set from the loop above ('default_whitened'), so curation uses those metrics.
            autocuration_key = metric_key.copy()
            autocuration_key['auto_curation_params_name'] = auto_curation_params_name
            ss.AutomaticCurationSelection.insert1(autocuration_key, skip_duplicates=True)
            ss.AutomaticCuration.populate([(ss.AutomaticCurationSelection & autocuration_key).proj()])
            auto_curation_id = (ss.AutomaticCuration & autocuration_key).fetch1('auto_curation_key')
            auto_curation_out_key = (ss.Curation & auto_curation_id).fetch1("KEY")

            # Add to sortingview workspace for manual curation
            if not no_sorting_view:
                ss.SortingviewWorkspaceSelection.insert1(auto_curation_out_key, skip_duplicates=True)
                ss.SortingviewWorkspace.populate(auto_curation_out_key)
                # Inside a loop a bare expression is not displayed, so print the URL
                print(ss.SortingviewWorkspace().url(auto_curation_out_key))

            # Populate CuratedSpikeSorting and show the resulting units
            ss.CuratedSpikeSortingSelection.insert1(auto_curation_out_key, skip_duplicates=True)
            ss.CuratedSpikeSorting.populate(auto_curation_out_key)
            print(ss.CuratedSpikeSorting.Unit & auto_curation_out_key)