<a href="https://colab.research.google.com/github/magland/spikeforest_batch_run/blob/master/notebooks/spikeforest_analysis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# SpikeForest analysis

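This notebook contains a complete SpikeForest analysis of the ground-truth study sets (bionet, bionet32c, magland_synth, mearec_tetrode and mearec_neuronexus). Execute the first few cells (package installation, imports and kbucket configuration), then skip down to the section of interest below.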

In [0]:
# Only run this cell if you are running this on a hosted runtime that does not have these packages installed
# %%capture is used to suppress the output... this should take up to a minute to complete
%%capture
!pip install spikeforest
!pip install git+https://github.com/magland/spikeforest_batch_run

In [0]:
# Import the python packages -- autoreload is used for development purposes
%load_ext autoreload
%autoreload 2

import spikeforest as sf
from kbucket import client as kb

In [0]:
## Read-only kbucket access: use this if you only want to browse the results.
## (Alternative to the read/write configuration cell below.)
sf.kbucketConfigRemote(name="spikeforest1-readonly")

In [0]:
## Read/write kbucket access (prompts for a password): use this if you are preparing
## the studies or the processing batches. (Alternative to the read-only cell above.)
sf.kbucketConfigRemote(ask_password=True, name="spikeforest1-readwrite")

## Prepare recordings

In [0]:
def read_text_file(path):
  """Return the full text of `path` after realizing it locally via kbucket.

  Raises:
    Exception: if kbucket cannot realize the file (realizeFile returns None).
  """
  local_path = kb.realizeFile(path)
  if local_path is None:
    raise Exception('Unable to realize file: ' + path)
  with open(local_path, 'r') as fp:
    contents = fp.read()
  return contents
  
def prepare_bionet_studies(*,basedir,channels, study_set_name='bionet'):
  """Build study and recording descriptors for the bionet study set.

  One study is created for each of bionet_drift, bionet_shuffle and
  bionet_static under `<basedir>/bionet/`. Each study's description is its
  readme.txt, and every subdirectory of a study becomes one recording.

  Args:
    basedir: root of the ground-truth tree (e.g. a kbucket:// URL).
    channels: channel indices stored on every recording. When non-empty, the
      ground-truth units are selected with
      sf.sf_batch.select_units_on_channels (from firings_true.mda) and stored
      as the recording's `units_true`.
    study_set_name: value stored in each study's `study_set` field.

  Returns:
    (studies, recordings): two lists of plain dicts, suitable for kb.saveObject.

  Raises:
    Exception: if a study readme cannot be realized or a study directory
      cannot be listed.
  """
  studies=[]
  recordings=[]
  for study_name in ['bionet_drift','bionet_shuffle','bionet_static']:
    study_dir=basedir+'/bionet/'+study_name
    description=read_text_file(study_dir+'/readme.txt')
    studies.append(dict(
        name=study_name,
        study_set=study_set_name,
        directory=study_dir,
        description=description
    ))
    dd=kb.readDir(study_dir)
    # NOTE(review): assumes kb.readDir returns None when the listing fails (as
    # kb.realizeFile does for files). Fail with a clear message rather than an
    # opaque TypeError on dd['dirs'].
    if dd is None:
      raise Exception('Unable to read directory: '+study_dir)
    for dsname in dd['dirs']:
      dsdir='{}/{}'.format(study_dir,dsname)
      rec0=dict(
          name=dsname,
          study=study_name,
          description='',
          directory=dsdir,
          channels=channels
      )
      if len(channels)>0:
        # Keep only the true units selected for these channels.
        rec0['units_true']=sf.sf_batch.select_units_on_channels(
            recording_dir=dsdir,
            firings=dsdir+'/firings_true.mda',
            channels=channels
        )
      recordings.append(rec0)
  return studies, recordings


def prepare_bionet32c_studies(*,basedir):
  """Bionet studies using channels 0-31, saved under the 'bionet32c' study set.

  Returns:
    (studies, recordings) exactly as produced by prepare_bionet_studies.
  """
  return prepare_bionet_studies(
      basedir=basedir,
      channels=list(range(32)),
      study_set_name='bionet32c'
  )


def prepare_magland_synth_studies(*,basedir):
  """Build study and recording descriptors for the magland_synth study set.

  Each `datasets_*` directory listed below (under `<basedir>/magland_synth/`)
  becomes one study named `magland_synth_<suffix>`, where <suffix> is the
  directory name without its `datasets_` prefix. Every subdirectory of a study
  becomes one recording. All studies share the study set's readme.txt as
  their description.

  Returns:
    (studies, recordings): two lists of plain dicts, suitable for kb.saveObject.

  Raises:
    Exception: if the readme cannot be realized or a study directory cannot
      be listed.
  """
  study_set_name='magland_synth'
  dataset_prefix='datasets_'
  names=[
      'datasets_noise10_K10_C4','datasets_noise10_K10_C8',
      'datasets_noise10_K20_C4','datasets_noise10_K20_C8',
      'datasets_noise20_K10_C4','datasets_noise20_K10_C8',
      'datasets_noise20_K20_C4','datasets_noise20_K20_C8',
  ]
  description=read_text_file(basedir+'/magland_synth/readme.txt')
  studies=[]
  recordings=[]
  for name in names:
    # Study name = set name + directory name without the 'datasets_' prefix.
    study_name='magland_synth_'+name[len(dataset_prefix):]
    study_dir=basedir+'/magland_synth/'+name
    studies.append(dict(
        name=study_name,
        study_set=study_set_name,
        directory=study_dir,
        description=description
    ))
    dd=kb.readDir(study_dir)
    # NOTE(review): assumes kb.readDir returns None when the listing fails (as
    # kb.realizeFile does for files). Fail with a clear message instead.
    if dd is None:
      raise Exception('Unable to read directory: '+study_dir)
    for dsname in dd['dirs']:
      recordings.append(dict(
          name=dsname,
          study=study_name,
          directory='{}/{}'.format(study_dir,dsname),
          description='One of the recordings in the {} study'.format(study_name)
      ))
  return studies, recordings

def prepare_mearec_tetrode_studies(*,basedir):
  """Build study and recording descriptors for the mearec_tetrode study set.

  Each `datasets_*` directory listed below (under
  `<basedir>/mearec_synth/tetrode/`) becomes one study named
  `mearec_tetrode_<suffix>`, where <suffix> is the directory name without its
  `datasets_` prefix. Every subdirectory of a study becomes one recording.
  All studies share the study set's readme.txt as their description.

  Returns:
    (studies, recordings): two lists of plain dicts, suitable for kb.saveObject.

  Raises:
    Exception: if the readme cannot be realized or a study directory cannot
      be listed.
  """
  study_set_name='mearec_tetrode'
  dataset_prefix='datasets_'
  names=[
      'datasets_noise10_K10_C4','datasets_noise10_K20_C4',
      'datasets_noise20_K10_C4','datasets_noise20_K20_C4',
  ]
  description=read_text_file(basedir+'/mearec_synth/tetrode/readme.txt')
  studies=[]
  recordings=[]
  for name in names:
    # Study name = set name + directory name without the 'datasets_' prefix.
    study_name='mearec_tetrode_'+name[len(dataset_prefix):]
    study_dir=basedir+'/mearec_synth/tetrode/'+name
    studies.append(dict(
        name=study_name,
        study_set=study_set_name,
        directory=study_dir,
        description=description
    ))
    dd=kb.readDir(study_dir)
    # NOTE(review): assumes kb.readDir returns None when the listing fails (as
    # kb.realizeFile does for files). Fail with a clear message instead.
    if dd is None:
      raise Exception('Unable to read directory: '+study_dir)
    for dsname in dd['dirs']:
      recordings.append(dict(
          name=dsname,
          study=study_name,
          directory='{}/{}'.format(study_dir,dsname),
          description='One of the recordings in the {} study'.format(study_name)
      ))
  return studies, recordings

def prepare_mearec_neuronexus_studies(*,basedir):
  """Build study and recording descriptors for the mearec_neuronexus study set.

  Each `datasets_*` directory listed below (under
  `<basedir>/mearec_synth/neuronexus/`) becomes one study named
  `mearec_neuronexus_<suffix>`, where <suffix> is the directory name without
  its `datasets_` prefix. Every subdirectory of a study becomes one recording.
  All studies share the study set's readme.txt as their description.

  Returns:
    (studies, recordings): two lists of plain dicts, suitable for kb.saveObject.

  Raises:
    Exception: if the readme cannot be realized or a study directory cannot
      be listed.
  """
  study_set_name='mearec_neuronexus'
  dataset_prefix='datasets_'
  names=[
      'datasets_noise10_K10_C32','datasets_noise10_K20_C32','datasets_noise10_K40_C32',
      'datasets_noise20_K10_C32','datasets_noise20_K20_C32','datasets_noise20_K40_C32',
  ]
  description=read_text_file(basedir+'/mearec_synth/neuronexus/readme.txt')
  studies=[]
  recordings=[]
  for name in names:
    # Study name = set name + directory name without the 'datasets_' prefix.
    study_name='mearec_neuronexus_'+name[len(dataset_prefix):]
    study_dir=basedir+'/mearec_synth/neuronexus/'+name
    studies.append(dict(
        name=study_name,
        study_set=study_set_name,
        directory=study_dir,
        description=description
    ))
    dd=kb.readDir(study_dir)
    # NOTE(review): assumes kb.readDir returns None when the listing fails (as
    # kb.realizeFile does for files). Fail with a clear message instead.
    if dd is None:
      raise Exception('Unable to read directory: '+study_dir)
    for dsname in dd['dirs']:
      recordings.append(dict(
          name=dsname,
          study=study_name,
          directory='{}/{}'.format(study_dir,dsname),
          description='One of the recordings in the {} study'.format(study_name)
      ))
  return studies, recordings

In [0]:
# Root of the ground-truth data tree read by the prepare_* functions.
# The commented path is an alternative filesystem location (presumably a
# cluster-local copy) -- swap it in when running where that path is mounted.
basedir = 'kbucket://15734439d8cf/groundtruth'
# basedir = '/mnt/ceph/users/jjun/groundtruth'

In [0]:
# bionet: attach channels 0-7 to every recording and save the descriptors.
channels = list(range(8))
studies, recordings = prepare_bionet_studies(basedir=basedir, channels=channels)
kb.saveObject(dict(studies=studies, recordings=recordings), key={'name': 'spikeforest_bionet_recordings'})

In [0]:
# bionet32c: same bionet studies, using channels 0-31.
studies, recordings = prepare_bionet32c_studies(basedir=basedir)
kb.saveObject(dict(studies=studies, recordings=recordings), key={'name': 'spikeforest_bionet32c_recordings'})

In [0]:
# Prepare and save the mearec_tetrode recordings too. The summarize and sorting
# batches for 'spikeforest_mearec_tetrode_recordings' depend on this object,
# but nothing else in the visible notebook creates it.
studies,recordings=prepare_mearec_tetrode_studies(basedir=basedir)
kb.saveObject(dict(studies=studies,recordings=recordings),key=dict(name='spikeforest_mearec_tetrode_recordings'))

studies,recordings=prepare_mearec_neuronexus_studies(basedir=basedir)
kb.saveObject(dict(studies=studies,recordings=recordings),key=dict(name='spikeforest_mearec_neuronexus_recordings'))

In [0]:
# magland_synth: save the synthetic study set's descriptors.
studies, recordings = prepare_magland_synth_studies(basedir=basedir)
kb.saveObject(dict(studies=studies, recordings=recordings), key={'name': 'spikeforest_magland_synth_recordings'})

## Create summarize recordings batches

In [0]:
def create_summarize_recordings_batch(*,recordings_name,batch_name):
  """Create and save a batch containing one `summarize_recording` job per recording.

  Args:
    recordings_name: kbucket name of a saved studies/recordings object
      (as written by the "Prepare recordings" cells).
    batch_name: key under which the batch (a dict with a 'jobs' list) is saved.
  """
  print('Creating summarize_recordings batch: '+batch_name)
  SF=sf.SFData()
  SF.loadRecordings(key=dict(name=recordings_name))

  # Walk studies, then recordings, in the order SFData reports them.
  jobs=[
      dict(command='summarize_recording', label=R.name(), recording=R.getObject())
      for study in map(SF.study, SF.studyNames())
      for R in map(study.recording, study.recordingNames())
  ]
  print('Number of jobs: {}'.format(len(jobs)))
  kb.saveObject(key=dict(batch_name=batch_name),object=dict(jobs=jobs))

In [0]:
# One summarize_recordings batch per study set.
create_summarize_recordings_batch(
    recordings_name='spikeforest_bionet_recordings',
    batch_name='summarize_recordings_bionet')
create_summarize_recordings_batch(
    recordings_name='spikeforest_magland_synth_recordings',
    batch_name='summarize_recordings_magland_synth')
create_summarize_recordings_batch(
    recordings_name='spikeforest_mearec_tetrode_recordings',
    batch_name='summarize_recordings_mearec_tetrode')

create_summarize_recordings_batch(
    recordings_name='spikeforest_mearec_neuronexus_recordings',
    batch_name='summarize_recordings_mearec_neuronexus')
create_summarize_recordings_batch(
    recordings_name='spikeforest_bionet32c_recordings',
    batch_name='summarize_recordings_bionet32c')


To run these batches, go to a computer with resources somewhere and run something like:

```
bin/sf_run_batch [name_of_batch] --run_prefix "srun -c 2 -n 40"
```

On the Flatiron cluster, using "srun" requires you to first run the following:
```
module load srun
module load matlab
```
To use GPU, run
```
bin/sf_run_batch [name_of_batch] --run_prefix "srun -c 2 -n 40 --gres=gpu:2 -p gpu"
```
  
where bin/sf_run_batch is found in the spikeforest_batch_run repository.

Alternatively, you can test-run a batch directly in this notebook using the code cell at the end of this section.


To clear the results of a previous run of a batch, run
```
bin/sf_run_batch_command clear [name_of_batch]
```



In [0]:
## Note: usually you would not run this cell -- see the note above.
## It runs the summarize_recordings_bionet batch from inside this notebook
## process (prepare, run, assemble) instead of on a machine with more resources.

import spikeforest_batch_run as sbr
# Execute prepareBatch once (serially)
sbr.prepareBatch(batch_name='summarize_recordings_bionet')

# Execute runBatch many times in parallel
# (here it is called only once, from this single notebook process)
sbr.runBatch(batch_name='summarize_recordings_bionet')

# Execute assembleBatchResults once (serially)
sbr.assembleBatchResults(batch_name='summarize_recordings_bionet')



## Browse recordings

In [0]:
# Load bionet and magland_synth recordings plus their summarize_recordings results.
SF=sf.SFData()
SF.loadRecordings(key={'name': 'spikeforest_bionet_recordings'})
SF.loadRecordings(key={'name': 'spikeforest_magland_synth_recordings'})
SF.loadProcessingBatch(key={'batch_name': 'summarize_recordings_bionet', 'name': 'job_results'})
SF.loadProcessingBatch(key={'batch_name': 'summarize_recordings_magland_synth', 'name': 'job_results'})

In [0]:
# Interactive recording selector. The next cell reads its selection through X.
X=sf.SFSelectWidget(mode='recording', sfdata=SF)
display(X)

In [0]:
# Show the recording currently selected in the widget X (previous cell):
# its 'timeseries' and 'waveforms_true' plots, then its true-unit info table.
R=X.recording()
display(R.plot('timeseries'))
display(R.plot('waveforms_true'))
display(R.trueUnitsInfo())

In [0]:
# List the plot names for the selected recording (presumably the names accepted by R.plot()).
R.plotNames()

## Create spike sorting batches

In [0]:
# Same data as the "Browse recordings" section: bionet and magland_synth
# recordings plus their summarize_recordings results.
SF=sf.SFData()
SF.loadRecordings(key={'name': 'spikeforest_bionet_recordings'})
SF.loadRecordings(key={'name': 'spikeforest_magland_synth_recordings'})
SF.loadProcessingBatch(key={'batch_name': 'summarize_recordings_bionet', 'name': 'job_results'})
SF.loadProcessingBatch(key={'batch_name': 'summarize_recordings_magland_synth', 'name': 'job_results'})

In [0]:
# Sorter configurations used by the sorting batches. Each one has a display
# `name` (used in job labels), the `processor_name` to run, and its `params`.

# MountainSort4 with detect_threshold=3.
sorter_ms4_thr3={
    'name': 'MountainSort4-thr3',
    'processor_name': 'MountainSort4',
    'params': {
        'detect_sign': -1,
        'adjacency_radius': 100,
        'detect_threshold': 3,
    },
}

# IronClust with the tetrode parameter template.
sorter_irc_tetrode={
    'name': 'IronClust-tetrode',
    'processor_name': 'IronClust',
    'params': {
        'detect_sign': -1,
        'adjacency_radius': 100,
        'detect_threshold': 5,
        'prm_template_name': "tetrode_template.prm",
    },
}

# IronClust with the drift parameter template.
sorter_irc_drift={
    'name': 'IronClust-drift',
    'processor_name': 'IronClust',
    'params': {
        'detect_sign': -1,
        'adjacency_radius': 100,
        'prm_template_name': "template_drift.prm",
    },
}

sorter_sc={
    'name': 'SpykingCircus',
    'processor_name': 'SpykingCircus',
    'params': {
        'detect_sign': -1,
        'adjacency_radius': 100,
    },
}

sorter_ks={
    'name': 'KiloSort',
    'processor_name': 'KiloSort',
    'params': {
        'detect_sign': -1,
        'adjacency_radius': 100,
    },
}

In [0]:
def create_sorting_batch(*,recordings_name,batch_name,sorters):
  """Create and save a batch with one `sort_recording` job per (recording, sorter) pair.

  Args:
    recordings_name: kbucket name of a saved studies/recordings object.
    batch_name: key under which the batch (a dict with a 'jobs' list) is saved.
    sorters: sorter configuration dicts; each sorter's 'name' prefixes the job label.
  """
  print('Creating sorting batch: '+batch_name)
  SF=sf.SFData()
  SF.loadRecordings(key=dict(name=recordings_name))

  # Jobs are ordered by study, then recording, then sorter.
  jobs=[
      dict(
          command='sort_recording',
          label=sorter['name']+': '+R.name(),
          recording=R.getObject(),
          sorter=sorter
      )
      for study in map(SF.study, SF.studyNames())
      for R in map(study.recording, study.recordingNames())
      for sorter in sorters
  ]
  print('Number of jobs: {}'.format(len(jobs)))
  kb.saveObject(key=dict(batch_name=batch_name),object=dict(jobs=jobs))

In [0]:
# One sorting batch per (study set, sorter). Tetrode study sets (magland_synth,
# mearec_tetrode -- all C4 datasets) use the tetrode IronClust configuration;
# the probe study sets use the drift configuration.
create_sorting_batch(recordings_name='spikeforest_magland_synth_recordings',batch_name='ms4_magland_synth',sorters=[sorter_ms4_thr3])
create_sorting_batch(recordings_name='spikeforest_magland_synth_recordings',batch_name='irc_magland_synth',sorters=[sorter_irc_tetrode])
create_sorting_batch(recordings_name='spikeforest_magland_synth_recordings',batch_name='sc_magland_synth',sorters=[sorter_sc])
create_sorting_batch(recordings_name='spikeforest_magland_synth_recordings',batch_name='ks_magland_synth',sorters=[sorter_ks])

create_sorting_batch(recordings_name='spikeforest_bionet_recordings',batch_name='ms4_bionet',sorters=[sorter_ms4_thr3])
create_sorting_batch(recordings_name='spikeforest_bionet_recordings',batch_name='irc_bionet',sorters=[sorter_irc_drift])
create_sorting_batch(recordings_name='spikeforest_bionet_recordings',batch_name='sc_bionet',sorters=[sorter_sc])
create_sorting_batch(recordings_name='spikeforest_bionet_recordings',batch_name='ks_bionet',sorters=[sorter_ks])

create_sorting_batch(recordings_name='spikeforest_bionet32c_recordings',batch_name='ms4_bionet32c',sorters=[sorter_ms4_thr3])
create_sorting_batch(recordings_name='spikeforest_bionet32c_recordings',batch_name='irc_bionet32c',sorters=[sorter_irc_drift])
create_sorting_batch(recordings_name='spikeforest_bionet32c_recordings',batch_name='sc_bionet32c',sorters=[sorter_sc])
create_sorting_batch(recordings_name='spikeforest_bionet32c_recordings',batch_name='ks_bionet32c',sorters=[sorter_ks])

create_sorting_batch(recordings_name='spikeforest_mearec_tetrode_recordings',batch_name='ms4_mearec_tetrode',sorters=[sorter_ms4_thr3])
create_sorting_batch(recordings_name='spikeforest_mearec_tetrode_recordings',batch_name='irc_mearec_tetrode',sorters=[sorter_irc_tetrode])
create_sorting_batch(recordings_name='spikeforest_mearec_tetrode_recordings',batch_name='sc_mearec_tetrode',sorters=[sorter_sc])
create_sorting_batch(recordings_name='spikeforest_mearec_tetrode_recordings',batch_name='ks_mearec_tetrode',sorters=[sorter_ks])

create_sorting_batch(recordings_name='spikeforest_mearec_neuronexus_recordings',batch_name='ms4_mearec_neuronexus',sorters=[sorter_ms4_thr3])
create_sorting_batch(recordings_name='spikeforest_mearec_neuronexus_recordings',batch_name='irc_mearec_neuronexus',sorters=[sorter_irc_drift])
create_sorting_batch(recordings_name='spikeforest_mearec_neuronexus_recordings',batch_name='sc_mearec_neuronexus',sorters=[sorter_sc])
create_sorting_batch(recordings_name='spikeforest_mearec_neuronexus_recordings',batch_name='ks_mearec_neuronexus',sorters=[sorter_ks])

In [0]:
# Create the sorting batches programmatically. Tetrode study sets are paired with
# IronClust-tetrode and the probe study sets with IronClust-drift. Batches are
# saved under the same '<sorter>_<study set>' names as the explicit calls above.

# Short sorter names used as batch-name prefixes, in the same order as the
# sorter configuration lists below.
vs_sorters = ['ms4', 'irc', 'sc', 'ks']
v_sorters_tetrode = [sorter_ms4_thr3, sorter_irc_tetrode, sorter_sc, sorter_ks]
v_sorters_siprobe = [sorter_ms4_thr3, sorter_irc_drift, sorter_sc, sorter_ks]
vs_recordings_tetrode = ['magland_synth', 'mearec_tetrode']  
# presumably "siprobe" = silicon-probe study sets
vs_recordings_siprobe = ['bionet', 'bionet32c', 'mearec_neuronexus']

for recording in vs_recordings_tetrode:
  for sorter, s_sorter in zip(v_sorters_tetrode, vs_sorters):
    recordings_name = 'spikeforest_{}_recordings'.format(recording)
    batch_name = '{}_{}'.format(s_sorter, recording)
    #print(recordings_name, batch_name)
    create_sorting_batch(recordings_name=recordings_name, batch_name=batch_name, sorters=[sorter])
  print()
    
for recording in vs_recordings_siprobe:
  for sorter, s_sorter in zip(v_sorters_siprobe, vs_sorters):
    recordings_name = 'spikeforest_{}_recordings'.format(recording)
    batch_name = '{}_{}'.format(s_sorter, recording)
    #print(recordings_name, batch_name)    
    create_sorting_batch(recordings_name=recordings_name, batch_name=batch_name, sorters=[sorter])
  print()


To run these sorting batches, follow the instructions above.

## Browse sorting results

In [0]:
# Load recordings, their summaries, and the available sorting results for the
# bionet, magland_synth and mearec_tetrode study sets.
SF=sf.SFData()
SF.loadRecordings(key=dict(name='spikeforest_bionet_recordings'))
SF.loadRecordings(key=dict(name='spikeforest_magland_synth_recordings'))
SF.loadRecordings(key=dict(name='spikeforest_mearec_tetrode_recordings'))

SF.loadProcessingBatch(key=dict(batch_name='summarize_recordings_bionet',name='job_results'))
SF.loadProcessingBatch(key=dict(batch_name='summarize_recordings_magland_synth',name='job_results'))
SF.loadProcessingBatch(key=dict(batch_name='summarize_recordings_mearec_tetrode',name='job_results'))

SF.loadProcessingBatch(key=dict(batch_name='ms4_magland_synth',name='job_results'))
SF.loadProcessingBatch(key=dict(batch_name='sc_magland_synth',name='job_results'))
SF.loadProcessingBatch(key=dict(batch_name='irc_magland_synth',name='job_results'))
SF.loadProcessingBatch(key=dict(batch_name='ks_magland_synth',name='job_results'))

SF.loadProcessingBatch(key=dict(batch_name='ms4_mearec_tetrode',name='job_results'))
#SF.loadProcessingBatch(key=dict(batch_name='sc_mearec_tetrode',name='job_results'))
SF.loadProcessingBatch(key=dict(batch_name='irc_mearec_tetrode',name='job_results'))
SF.loadProcessingBatch(key=dict(batch_name='ks_mearec_tetrode',name='job_results'))

SF.loadProcessingBatch(key=dict(batch_name='ms4_bionet',name='job_results'))
# NOTE(review): the trailing comment says Spyking Circus is not working yet, yet
# this load is active (the aggregation section does not load sc_bionet).
# Confirm the sc_bionet job results exist before relying on this line.
SF.loadProcessingBatch(key=dict(batch_name='sc_bionet',name='job_results')) ## Spyking circus not working yet -- need to put into singularity container
SF.loadProcessingBatch(key=dict(batch_name='irc_bionet',name='job_results'))
#SF.loadProcessingBatch(key=dict(batch_name='ks_bionet',name='job_results'))



In [0]:
# Interactive sorting-result selector. The next cell reads its selection through X.
X=sf.SFSelectWidget(mode='sorting_result', sfdata=SF)
display(X)

In [0]:
# Show the sorting result currently selected in the widget X (previous cell):
# its 'unit_waveforms' and 'autocorrelograms' plots, then its comparison with ground truth.
R=X.sortingResult()
display(R.plot('unit_waveforms'))
display(R.plot('autocorrelograms'))
display(R.comparisonWithTruth())

## Aggregate sorting results

In [0]:
# Load bionet recordings, their summaries, and the MountainSort4 / IronClust results.
SF=sf.SFData()
SF.loadRecordings(key={'name': 'spikeforest_bionet_recordings'})
SF.loadProcessingBatch(key={'batch_name': 'summarize_recordings_bionet', 'name': 'job_results'})
SF.loadProcessingBatch(key={'batch_name': 'ms4_bionet', 'name': 'job_results'})
# sc_bionet is intentionally not loaded: Spyking Circus is not working yet
# (it needs to be put into a singularity container).
SF.loadProcessingBatch(key={'batch_name': 'irc_bionet', 'name': 'job_results'})

In [0]:
import pandas as pd
import random
import altair as alt
# Render Altair charts with the 'colab' renderer so they display in Google Colab.
# NOTE(review): presumably a different renderer is needed outside Colab -- confirm.
alt.renderers.enable('colab')

# Accumulate the sorting results
def accumulate_comparison_with_ground_truth(*,SF,studies,sorter_name,fieldnames):
  """Flatten per-unit ground-truth comparisons into a list of records.

  For every recording of every study, each entry of the sorter's
  comparison-with-truth table (JSON format) becomes one record, annotated
  with the SNR of the matching true unit from the recording's trueUnitsInfo.

  Args:
    SF: not used by this function.
    studies: study objects providing recordingNames() and recording(name).
    sorter_name: name of the sorter whose sorting results are compared.
    fieldnames: comparison columns to copy into each record, converted to float.

  Returns:
    A list of dicts with keys 'recording_name', 'unit_id', 'snr', then one
    key per entry of `fieldnames`.

  Raises:
    KeyError: if a compared 'Unit ID' has no entry in trueUnitsInfo.
  """
  records=[]
  for study in studies:
    study_recordings=[study.recording(rname) for rname in study.recordingNames()]
    for R in study_recordings:
      comparison=R.sortingResult(sorter_name).comparisonWithTruth(format='json')
      snr_lookup={info['unit_id']: info['snr'] for info in R.trueUnitsInfo(format='json')}
      for row in comparison.values():
        rec={'recording_name': R.name(), 'unit_id': row['Unit ID']}
        rec['snr']=snr_lookup[rec['unit_id']]
        for fieldname in fieldnames:
          rec[fieldname]=float(row[fieldname])
        records.append(rec)
  return records

def show_accuracy_plot(*,SF,study_name,sorter_name,title):
  """Display an interactive scatter plot of per-unit accuracy vs. true-unit SNR.

  Args:
    SF: SFData object holding the study `study_name`.
    study_name: study whose recordings are plotted.
    sorter_name: sorter whose comparison-with-truth results are plotted.
    title: chart title.

  Points are colored by recording name, and the chart is shown via display().
  """
  records=accumulate_comparison_with_ground_truth(
      SF=SF,
      studies=[SF.study(study_name)],
      sorter_name=sorter_name,
      fieldnames=['Accuracy']
  )

  chart=(
      alt.Chart(pd.DataFrame(records),title=title)
      .mark_point()
      .encode(x='snr',y='Accuracy',color='recording_name',tooltip='recording_name')
      .interactive()
  )
  display(chart)

In [0]:
import vdomr as vd

class SelectBox(vd.Component):
    """A vdomr drop-down (HTML <select>) component with change callbacks.

    Args:
        options: option labels to show (any iterable). Defaults to no options.
            The first option is selected initially.
    """
    def __init__(self,options=None):
        # None instead of a mutable [] default, so instances never share a list.
        vd.Component.__init__(self)
        self._on_change_handlers=[]
        self._value=None
        self.setOptions(options)

    def setOptions(self,options):
        """Replace the options and re-render.

        The current value is kept if it is still among the new options.
        Otherwise the first option is selected, or None when there are none.
        """
        # Copy into a fresh list so the widget never aliases a caller's list
        # that may be mutated later. None is treated as "no options".
        self._options=list(options) if options is not None else []
        if self._value not in self._options:
            self._value=self._options[0] if self._options else None
        self.refresh()

    def value(self):
        """Return the currently selected option (None when there are no options)."""
        return self._value

    def setValue(self,value):
        """Select `value` programmatically and re-render (onChange handlers are not called)."""
        self._value=value
        self.refresh()

    def onChange(self,handler):
        """Register `handler`, called as handler(value=...) when the user changes the selection."""
        self._on_change_handlers.append(handler)

    def _on_change(self,value):
        # Callback wired to the rendered <select>: record the new value, then notify handlers.
        self._value=value
        for handler in self._on_change_handlers:
            handler(value=value)

    def render(self):
        """Build the <select> element, marking the current value as selected."""
        opts=[]
        for option in self._options:
            if option==self._value:
                opts.append(vd.option(option,selected='selected'))
            else:
                opts.append(vd.option(option))
        return vd.select(opts,onchange=self._on_change)

In [0]:
# Drop-downs that choose the study and sorter for the accuracy plot below.
# NOTE: the sorter names are hard-coded. They must match sorters whose batches
# are loaded in this section (ms4_bionet -> 'MountainSort4-thr3',
# irc_bionet -> 'IronClust-drift').
STUDY=SelectBox(options=SF.studyNames())
SORTER=SelectBox(options=['MountainSort4-thr3','IronClust-drift'])
display(STUDY)
display(SORTER)

In [0]:
# Show the sorter currently chosen in the SORTER drop-down.
print(SORTER.value())

In [0]:
# Plot accuracy vs. SNR for the current drop-down selections. The selections are
# read when this cell runs, so re-run it after changing them.
show_accuracy_plot(
    SF=SF,
    study_name=STUDY.value(),
    sorter_name=SORTER.value(),
    title=SORTER.value()+' '+STUDY.value()
)

In [0]:
# NOTE(review): looks like a debugging leftover -- it only prints where the
# google.colab `output` module lives and is not used by the analysis. The import
# presumably fails outside Colab.
from google.colab import output
print(output.__file__)