## NWB-Datajoint tutorial 1

**Note: make a copy of this notebook and run the copy to avoid git conflicts in the future**

This is the first in a multi-part tutorial on the NWB-Datajoint pipeline used in Loren Frank's lab, UCSF. It demonstrates how to run spike sorting within the pipeline.

If you have not done [tutorial 0](0_intro.ipynb) yet, make sure to do so before proceeding.

Let's start by importing the `nwb_datajoint` package, along with a few others. 

In [None]:
from pathlib import Path
import os
import numpy as np

import nwb_datajoint as nd

import warnings
warnings.simplefilter('ignore', category=DeprecationWarning)  # ignore datajoint+jupyter async warning

# Comment these if you have already set these environment variables
data_dir = Path('/stelmo/nwb') # CHANGE ME TO THE BASE DIRECTORY FOR DATA STORAGE ON YOUR SYSTEM
os.environ['DJ_SUPPORT_FILEPATH_MANAGEMENT'] = 'TRUE'
os.environ['NWB_DATAJOINT_BASE_DIR'] = str(data_dir)
os.environ['KACHERY_STORAGE_DIR'] = str(data_dir / 'kachery-storage')
os.environ['SPIKE_SORTING_STORAGE_DIR'] = str(data_dir / 'spikesorting')

In [None]:
# We also import a bunch of tables so that we can call them easily
from nwb_datajoint.common import (RawPosition, HeadDir, Speed, LinPos, StateScriptFile, VideoFile,
                                  DataAcquisitionDevice, CameraDevice, Probe,
                                  DIOEvents,
                                  ElectrodeGroup, Electrode, Raw, SampleCount,
                                  LFPSelection, LFP, LFPBandSelection, LFPBand,
                                  SortGroup, SpikeSorting, SpikeSorter, SpikeSorterParameters, SpikeSortingWaveformParameters, SpikeSortingParameters, SpikeSortingMetrics, CuratedSpikeSorting,
                                  FirFilter,
                                  IntervalList, SortInterval,
                                  Lab, LabMember, Institution,
                                  BrainRegion,
                                  SensorData,
                                  Session, ExperimenterList,
                                  Subject,
                                  Task, TaskEpoch,
                                  Nwbfile, AnalysisNwbfile, NwbfileKachery, AnalysisNwbfileKachery)

In this tutorial, we will continue to work with the copy of `beans20190718.nwb` that you created in tutorial 0. If you deleted it from `Session`, make sure to re-insert it before proceeding.

In [None]:
# Define the name of the file that you copied and renamed; make sure it's something unique. 
nwb_file_name = 'beans20190718.nwb'
filename, file_extension = os.path.splitext(nwb_file_name)
# This is a copy of the original nwb file, except it doesn't contain the raw data (for storage reasons)
nwb_file_name2 = filename + '_' + file_extension
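
As a quick check of the naming rule used in the cell above, the short sketch below reproduces the string manipulation on its own, without touching the database. The variable names `stem`, `extension` and `copy_name` are only illustrative.

```python
import os

# Same rule as the cell above: insert an underscore before the extension.
nwb_file_name = 'beans20190718.nwb'
stem, extension = os.path.splitext(nwb_file_name)  # extension keeps its leading dot
copy_name = stem + '_' + extension

print(stem)       # name without the extension
print(extension)  # extension, including the dot
print(copy_name)  # name under which the copy is referenced in the tables
```

Note that `os.path.splitext` keeps the dot as part of the extension, which is why the underscore ends up directly before `.nwb`.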

In [None]:
# Run if you need to reinsert the data
nd.insert_sessions(nwb_file_name)

### Spike sorting

In general, running spike sorting means making decisions about the following:
1. which electrodes to sort together (e.g. electrodes that form a tetrode should be sorted together, but tetrodes that are far apart need not be);
2. which time interval to sort (e.g. there may be a long period in the recording where nothing happens, and we might want to exclude that);
3. which spike sorter to use (e.g. Mountainsort? Kilosort? IronClust?);
4. given choice of the spike sorter in 3, which parameter set to use.

In our Datajoint framework, everything that we do is an interaction with a table. This is true for spike sorting as well: we enter our decisions about the four questions above into parameter tables, and then use those entries to populate another table that holds the result of spike sorting. Under the hood, we use a number of packages, notably `spikeinterface`, but the user need not know this; they only have to interact with the tables. This makes spike sorting straightforward. In addition, the entries in these tables serve as a record of exactly which decisions you made.

#### Define sort group
We start with the first question: which electrodes do we want to sort together? We first inspect the `Electrode` table.

In [None]:
Electrode & {'nwb_file_name': nwb_file_name2}

This recording was done with polymer probes. Here `electrode_group_name` refers to a probe. We can see that there were two probes, `0` and `1`.

In [None]:
# get unique probe id
np.unique((Electrode & {'nwb_file_name': nwb_file_name2}).fetch('electrode_group_name'))

Each probe has four shanks, as you can see:

In [None]:
# get unique shank id for the first probe
np.unique((Electrode & {'nwb_file_name': nwb_file_name2, 'electrode_group_name': 0}).fetch('probe_shank'))

Our job is to identify the electrodes that we want to sort together, and add them as a sort group in the `SortGroup` table. One natural way to do this is to set each shank as a sort group (for tetrode recordings, each tetrode can be thought of as a "shank" with four electrodes). Use the `set_group_by_shank` method for this:

In [None]:
SortGroup().set_group_by_shank(nwb_file_name2)

This generates 8 sort groups, one for each of the four shanks in the two probes.

In [None]:
SortGroup & {'nwb_file_name': nwb_file_name2}
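
To make the idea of grouping by shank concrete, here is a minimal plain-Python sketch of the same bookkeeping. It is not how `set_group_by_shank` is implemented; the electrode records, the count of 32 electrodes per shank, and the electrode numbering are assumptions made up for illustration.

```python
from collections import defaultdict

# Hypothetical electrode records: 2 probes x 4 shanks x 32 electrodes per shank.
electrodes = [
    {'electrode_group_name': probe,
     'probe_shank': shank,
     'electrode_id': probe * 128 + shank * 32 + i}
    for probe in range(2)
    for shank in range(4)
    for i in range(32)
]

# Collect electrode ids that share the same (probe, shank) pair.
by_shank = defaultdict(list)
for electrode in electrodes:
    shank_key = (electrode['electrode_group_name'], electrode['probe_shank'])
    by_shank[shank_key].append(electrode['electrode_id'])

# Number the groups consecutively, one sort group per (probe, shank) pair.
sort_groups = {
    sort_group_id: electrode_ids
    for sort_group_id, (_, electrode_ids) in enumerate(sorted(by_shank.items()))
}

print(len(sort_groups))  # one group per shank across both probes
```

With two probes of four shanks each, this yields eight groups, matching the number of entries shown in `SortGroup`.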

`SortGroup` has a *part table* called `SortGroupElectrode`. Think of this as a child table that contains information auxiliary to the parent table. As you can see, it contains two extra attributes: `electrode_group_name` and `electrode_id`.

In [None]:
SortGroup.SortGroupElectrode & {'nwb_file_name': nwb_file_name2}

What if you don't want to sort by shank? Maybe you want to select specific electrodes across shanks and sort them. To do so, you just have to manually `insert` a new entry into the `SortGroup` and `SortGroupElectrode` tables. 

In [None]:
sort_group_id = 8

In [None]:
# First we make a new entry in the SortGroup table, and give it sort_group_id of 8
SortGroup.insert1({'nwb_file_name': nwb_file_name2,
                   'sort_group_id': sort_group_id,
                   'sort_reference_electrode_id': -1},
                  skip_duplicates=True)
# Next, we associate every fourth electrode of the first shank with the sort group that we just created
SortGroup.SortGroupElectrode.insert([[nwb_file_name2, sort_group_id, 0, elec] for elec in range(0, 32, 4)],
                                    skip_duplicates=True)

Note that `insert` is a method, just like `fetch`. You can pass each entry either as a dictionary or as a list whose values follow the order of the table's attributes. We can look at the new entries we just made.
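
The two entry formats carry the same information. The sketch below converts the list-style rows used above into their dictionary form in plain Python. The attribute order in `attributes` is an assumption inferred from the list insert above, not read from the table definition.

```python
# Assumed attribute order of SortGroupElectrode (inferred from the list insert above).
attributes = ['nwb_file_name', 'sort_group_id', 'electrode_group_name', 'electrode_id']

nwb_file_name2 = 'beans20190718_.nwb'
sort_group_id = 8

# List form: values must follow the attribute order exactly.
rows_as_lists = [[nwb_file_name2, sort_group_id, 0, elec] for elec in range(0, 32, 4)]

# Dictionary form: each value is labelled, so the order no longer matters.
rows_as_dicts = [dict(zip(attributes, row)) for row in rows_as_lists]

for row in rows_as_dicts:
    print(row)
```

The dictionary form is more verbose but harder to get wrong, since a value cannot silently land in the wrong attribute.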

In [None]:
SortGroup & {'nwb_file_name' : nwb_file_name2, 'sort_group_id' : sort_group_id}

In [None]:
SortGroup.SortGroupElectrode & {'nwb_file_name': nwb_file_name2, 'sort_group_id': sort_group_id}

#### Define sort interval
Next, we make a decision about the time interval for our spike sorting. Let's re-examine `IntervalList`.

In [None]:
IntervalList & {'nwb_file_name' : nwb_file_name2}

For our example, let's use a 10-second window of the first run interval (`02_r1`) as our sort interval, starting 10 seconds after the interval begins. To do so, we first fetch the `valid_times` of this interval, define our new sort interval, and add it to the `SortInterval` table.

In [None]:
interval_list_name = '02_r1'

In [None]:
interval = (IntervalList & {'nwb_file_name' : nwb_file_name2,
                            'interval_list_name' : interval_list_name}).fetch1('valid_times')
print(interval)

In [None]:
sort_interval = np.asarray([interval[0][0]+10, interval[0][0]+20])
print(sort_interval)
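
The cell above takes the window from 10 s to 20 s after the start of the interval. The sketch below repeats that arithmetic on a made-up `valid_times` array and adds a bounds check, so that a window running past the end of the interval is caught early. The numbers, and the names `offset` and `duration`, are hypothetical.

```python
import numpy as np

# Hypothetical valid_times: one row per contiguous interval, [start, stop] in seconds.
valid_times = np.array([[1000.0, 1600.0]])

offset = 10.0    # seconds to skip at the start of the interval
duration = 10.0  # length of the sort interval in seconds

interval_start, interval_stop = valid_times[0]
sort_interval = np.asarray([interval_start + offset,
                            interval_start + offset + duration])

# Make sure the window lies entirely inside the interval it was taken from.
if not (interval_start <= sort_interval[0] < sort_interval[1] <= interval_stop):
    raise ValueError('sort interval falls outside the valid times')

print(sort_interval)
```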

In [None]:
# Check out SortInterval
SortInterval & {'nwb_file_name' : nwb_file_name2}

In [None]:
sort_interval_name = 'beans_02_r1_10s'

In [None]:
# Specify the required attributes
SortInterval.insert1({'nwb_file_name' : nwb_file_name2,
                      'sort_interval_name' : sort_interval_name,
                      'sort_interval' : sort_interval}, skip_duplicates=True)

In [None]:
# See results
SortInterval & {'nwb_file_name' : nwb_file_name2}

#### Define sorter
Next we decide which spike sorter to use. This boils down to looking at the `SpikeSorter` table and choosing the one we like. Initially, `SpikeSorter` may not be populated; in that case, we insert some sorters into it by checking which ones are available via `spikeinterface`, the package that we will be using implicitly for spike sorting.

In [None]:
SpikeSorter().insert_from_spikeinterface()

In [None]:
SpikeSorter()

For our example, we will be using `mountainsort4`.

In [None]:
sorter_name='mountainsort4'

#### Define sorter parameters
Once we have decided on a spike sorter, we have to set its parameters. Some of these parameters are common to all sorters (e.g. the frequency band used to filter the raw data before sorting begins), but most are specific to the sorter that we chose. Again, we populate the `SpikeSorterParameters` table with some default parameters for each sorter, and then we add our version as a new entry.

In [None]:
SpikeSorterParameters().insert_from_spikeinterface()

In [None]:
SpikeSorterParameters()

Define a new set of spike sorter parameters from default and add to table.

In [None]:
# Let's look at the default params
ms4_default_params = (SpikeSorterParameters & {'sorter_name' : sorter_name,
                                               'parameter_set_name' : 'default'}).fetch1()
print(ms4_default_params)

In [None]:
# Change the default params
param_dict = ms4_default_params['parameter_dict']
# We will just sort electrodes one by one
param_dict['adjacency_radius'] = 0
param_dict['curation'] = False
# Turn filter off since we will filter it prior to starting sort
param_dict['filter'] = False
# set num_workers to be the same as the number of electrodes in our sort group
param_dict['num_workers'] = len((SortGroup.SortGroupElectrode & {'nwb_file_name': nwb_file_name2,
                                                                 'sort_group_id': sort_group_id}).fetch('electrode_id'))
param_dict['verbose'] = True
# set clip size as number of samples for 2 milliseconds
param_dict['clip_size'] = int(2e-3 * (Raw & {'nwb_file_name' : nwb_file_name2}).fetch1('sampling_rate'))
param_dict['noise_overlap_threshold'] = 0
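
The pattern in the cell above is "start from the defaults, then override a few values". Below is a minimal standalone sketch of that pattern. The default values and the 30 kHz sampling rate are placeholders chosen for illustration, not values read from `SpikeSorterParameters` or `Raw`.

```python
import copy

# Placeholder defaults; the real ones come from the 'default' entry in the table.
default_params = {'adjacency_radius': -1, 'filter': True, 'clip_size': 50}

# Hypothetical sampling rate in Hz; in the notebook it is fetched from Raw.
sampling_rate = 30000.0

# Work on a copy so the fetched defaults stay untouched.
params = copy.deepcopy(default_params)
params.update({
    'adjacency_radius': 0,  # sort each electrode on its own
    'filter': False,        # data are filtered before sorting starts
})

# Clip size: number of samples in 2 ms, rounded to the nearest integer.
params['clip_size'] = int(round(2e-3 * sampling_rate))

print(params)
```

Rounding before converting to an integer guards against a product such as 59.999... being truncated to 59; plain integer conversion truncates instead.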

In [None]:
# Give a unique name here
parameter_set_name = 'test'

In [None]:
# Insert
SpikeSorterParameters.insert1({'sorter_name' : sorter_name,
                               'parameter_set_name' : parameter_set_name,
                               'parameter_dict' : param_dict}, skip_duplicates = True)

In [None]:
# Check that insert was successful
SpikeSorterParameters & {'sorter_name' : sorter_name, 'parameter_set_name' : parameter_set_name}

#### Define quality metric parameters

We're almost done. The remaining parameters control how the quality metrics used for curation are computed. We just use the default options here.

In [None]:
# we'll use `test`
SpikeSortingMetrics()

In [None]:
cluster_metrics_list_name = 'test'

#### Bringing everything together

We now collect all the decisions we made up to here and put them into the `SpikeSortingParameters` table (note: this is different from the spike sor*ter* parameters defined above).

In [None]:
# collect the params
key = dict()
key['nwb_file_name'] = nwb_file_name2
key['sort_group_id'] = sort_group_id
key['sort_interval_name'] = sort_interval_name
key['interval_list_name'] = interval_list_name
key['sorter_name'] = sorter_name
key['parameter_set_name'] = parameter_set_name
key['cluster_metrics_list_name'] = cluster_metrics_list_name
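
Because the key is assembled by hand, it is easy to leave a field out. The sketch below builds the same key as a dictionary literal and checks it against a list of expected field names. `required_fields` simply repeats the names used in the cell above; it is not read from the table definition.

```python
# Field names taken from the cell above (an assumption, not the table's actual heading).
required_fields = [
    'nwb_file_name', 'sort_group_id', 'sort_interval_name', 'interval_list_name',
    'sorter_name', 'parameter_set_name', 'cluster_metrics_list_name',
]

key = {
    'nwb_file_name': 'beans20190718_.nwb',
    'sort_group_id': 8,
    'sort_interval_name': 'beans_02_r1_10s',
    'interval_list_name': '02_r1',
    'sorter_name': 'mountainsort4',
    'parameter_set_name': 'test',
    'cluster_metrics_list_name': 'test',
}

# Report any field that is absent or left as None before inserting.
missing = [name for name in required_fields if key.get(name) is None]
if missing:
    raise ValueError(f'key is missing fields: {missing}')

print('key is complete')
```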

In [None]:
# insert
SpikeSortingParameters.insert1(key, skip_duplicates = True)

In [None]:
# inspect
SpikeSortingParameters & {'nwb_file_name' : nwb_file_name2}

#### Running spike sorting
Now we can run spike sorting. As mentioned above, this amounts to populating another table (`SpikeSorting`) from the entries of `SpikeSortingParameters`.

In [None]:
# Specify entry (otherwise runs everything in SpikeSortingParameters); `proj` gives you primary key
SpikeSorting.populate([(SpikeSortingParameters & {'nwb_file_name' : nwb_file_name2}).proj()])

In [None]:
SpikeSorting & {'nwb_file_name' : nwb_file_name2}
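
For intuition about what populating a table means, here is a plain-Python analogy. It is only a sketch of the idea (compute a result for every matching parameter entry that does not have one yet) and not DataJoint's implementation; the `populate` function, the two toy tables, and the `compute` callback are all hypothetical.

```python
def populate(params_table, results_table, compute, restriction=None):
    """Compute results for matching parameter entries that have none yet."""
    restriction = restriction or {}
    newly_computed = []
    for key in params_table:
        matches = all(key.get(name) == value for name, value in restriction.items())
        frozen_key = tuple(sorted(key.items()))  # hashable version of the key
        if matches and frozen_key not in results_table:
            results_table[frozen_key] = compute(key)
            newly_computed.append(key)
    return newly_computed


# Toy parameter table: two entries for one file, one entry for another.
params_table = [
    {'nwb_file_name': 'a_.nwb', 'sort_group_id': 0},
    {'nwb_file_name': 'a_.nwb', 'sort_group_id': 1},
    {'nwb_file_name': 'b_.nwb', 'sort_group_id': 0},
]
results_table = {}

restriction = {'nwb_file_name': 'a_.nwb'}
first = populate(params_table, results_table, lambda key: 'sorted', restriction)
second = populate(params_table, results_table, lambda key: 'sorted', restriction)

print(len(first))   # entries computed on the first call
print(len(second))  # nothing left to compute on the second call
```

Two points carry over to the real tables: restricting the call limits the work to the entries you care about, and running it again does not recompute entries that already have results.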