In [1]:
# This example runs multiple recording sorting using a local computer
# Created by James Jun and Jeremy Magland on Feb 28, 2019

# prerequisits
# $ pip install ml_ms4alg
# $ conda install -c conda-forge ipywidgets
# $ jupyter labextension install @jupyter-widgets/jupyterlab-manager

# please ignore the warning when running MountainSort4
#   RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import spikeforest_analysis as sa
from spikeforest import spikeextractors as se
import os
import shutil
import spikeforest as sf
import numpy as np
from spikesorters import IronClust, MountainSort4
from spikeforest import spiketoolkit as st
from spikeforest import spikewidgets as sw
import ipywidgets as widgets
import matplotlib.pyplot as plt


In [4]:
# select a data source and sorter
# you may change the data source index and sorter index
v_datasource = ['generate', 'download']
v_sorter = ['MountainSort4', 'IronClust']

widget1 = widgets.Dropdown(
    options=v_sorter, 
    index=0, description='Spike sorters')
display(widget1)

widget2 = widgets.Dropdown(
    options=v_datasource, 
    index=0, description='Data source')
display(widget2)

Dropdown(description='Spike sorters', options=('MountainSort4', 'IronClust'), value='MountainSort4')

Dropdown(description='Data source', options=('generate', 'download'), value='generate')

In [5]:
# create a parameters dictionary

params = dict(
    sorter = v_sorter[widget1.index],
    datasource = v_datasource[widget2.index],
    in_path = 'recordings/example1',
    out_path = 'sortings/example1',
    num_jobs = 4,
    num_workers = 4,
    )

In [6]:
# Define sorters

def irc(recpath):
    return IronClust.createJob(
            recording_dir=recpath,
            firings_out={'ext':'.mda'},
            detect_sign=-1,
            adjacency_radius=100,
            prm_template_name='static',
            _force_run=True,
            )

def ms4(recpath):
    return MountainSort4.createJob(
            recording_dir=recpath,
            firings_out={'ext':'.mda'},
            detect_sign=-1,
            adjacency_radius=100,
            _force_run=True
            )

v_sorters = dict(IronClust=irc, MountainSort4=ms4)

In [7]:
# get recording

recpath = params['in_path']
savepath = params['out_path']

# delete previous recording
if os.path.exists(recpath): shutil.rmtree(recpath)
if not os.path.exists(recpath): os.makedirs(recpath)
if not os.path.exists(savepath): os.makedirs(savepath)
        
if params['datasource'] is 'generate':
    # generate recording
    rx, sx_true = se.example_datasets.toy_example1(
        duration=600, num_channels=4, samplerate=30000, K=10)
else:
    # download recording
    kpath = 'kbucket://15734439d8cf/groundtruth/magland_synth/datasets_noise10_K10_C4/001_synth'
    rx = se.MdaRecordingExtractor(kpath, download=True)
    sx_true = se.MdaSortingExtractor(kpath + '/firings_true.mda')   
    
se.MdaRecordingExtractor.writeRecording(
    recording=rx, save_path=recpath)
se.MdaSortingExtractor.writeSorting(
    sorting=sx_true, save_path=os.path.join(savepath, 'firings_true.mda'))

In [8]:
# create a batch (`jobs`) and execute batch

import mlprocessors as mlpr

jobs=[]
for iJob in range(0, params['num_jobs']):
    job=v_sorters[params['sorter']](recpath)
    jobs.append(job)
    

mlpr.executeBatch(jobs=jobs, num_workers=params['num_workers'])

RUNNING: python3 /tmp/tmp0uj_jpd5/execute.py
RUNNING: python3 /tmp/tmpi0_4u928/execute.py
RUNNING: python3 /tmp/tmpjbf3ovm5/execute.py
RUNNING: python3 /tmp/tmp0erc7zcf/execute.py
::::::::::::::::::::::::::::: MountainSort4
Traceback (most recent call last):
  File "/tmp/tmp0uj_jpd5/execute.py", line 8, in <module>
    main()
  File "/tmp/tmp0uj_jpd5/execute.py", line 5, in main
    MountainSort4.execute(_cache=True, _force_run=True, _keep_temp_files=None, _container=None, recording_dir='/home/jamesjun/src/spikeforest/docs/example_notebooks/recordings/example1', firings_out='/tmp/tmp0uj_jpd5/output_firings_out.mda', detect_sign=-1, adjacency_radius=100, freq_min=300, freq_max=6000, whiten=True, clip_size=50, detect_threshold=3, detect_interval=10, noise_overlap_threshold=0.15)
  File "/home/jamesjun/src/spikeforest/mountaintools/mlprocessors/core.py", line 505, in execute
    return execute(proc, **kwargs)
  File "/home/jamesjun/src/spikeforest/mountaintools/mlprocessors/execute.py", l

Exception: Non-zero return code when running processor job

In [10]:
jobs[0]

{'command': 'execute_mlprocessor',
 'label': 'MountainSort4 (version: 4.2.0) (container: None)',
 'processor_name': 'MountainSort4',
 'processor_version': '4.2.0',
 'processor_class_name': 'MountainSort4',
 'processor_code': 'sha1://bd3eca8278d957ce8f8b4583fef7d70da4748783/code.json',
 'container': None,
 'inputs': {'recording_dir': '/home/jamesjun/src/spikeforest/docs/example_notebooks/recordings/example1'},
 'outputs': {'firings_out': {'ext': '.mda'}},
 'parameters': {'detect_sign': -1,
  'adjacency_radius': 100,
  'freq_min': 300,
  'freq_max': 6000,
  'whiten': True,
  'clip_size': 50,
  'detect_threshold': 3,
  'detect_interval': 10,
  'noise_overlap_threshold': 0.15},
 '_force_run': True,
 '_cache': True}

In [9]:
# assemble sorting outputs

v_sx = []
for iJob in range(0, len(jobs)):
    firings_out_ = jobs[iJob]['result']['outputs']['firings_out']
    sx_ = se.MdaSortingExtractor(firings_out_)
    v_sx.append(sx_)

KeyError: 'result'

In [None]:
# compute SNR of the ground truth units

path_json_out = os.path.join(params['out_path'], 'summary_true.mda')
sa.compute_units_info.ComputeUnitsInfo.execute(
    recording_dir = params['in_path'],
    firings = os.path.join(params['out_path'], 'firings_true.mda'),
    json_out = path_json_out,
    _force_run = True
    )

import json
with open(path_json_out) as f:
    snr_json = json.load(f)
unit_snrs = [x['snr'] for x in snr_json]
unit_ids = [x['unit_id'] for x in snr_json]
sx_true.setUnitsProperty(property_name='snr', 
                         values=unit_snrs, unit_ids=unit_ids)

In [None]:
# display sorting comparison tables

v_comparison = []
for iJob in range(0, len(v_sx)):
    sx_ = v_sx[iJob]
    comparison=st.comparison.SortingComparison(
        sorting1=sx_true, sorting1_name='true',
        sorting2=sx_, sorting2_name=params['sorter'],
        )
    comparison_table = sw.SortingComparisonTable(comparison=comparison)
    v_comparison.append(comparison)
    
    print('sorting output for {} in job {}:'.format(
        params['sorter'], iJob))
    comparison_table.display()

In [None]:
# plot SNR vs accuracy for the first sorting output

sw.SortingAccuracyWidget(
    sorting_comparison=v_comparison[0],
    property_name='snr',
    ).plot()

In [None]:
# plot SNR vs accuracy for all sorting output

def plot_comparisons(v_comparison, params):
    v_snr = []
    v_accuracy = []
    for iJob in range(0, len(v_comparison)):
    #for iJob in range(1,2):
        SC = v_comparison[iJob]
        units = SC.getSorting1().getUnitIds()
        accuracy = [SC.getAgreementFraction(unit) for unit in units]
        snr = SC.getSorting1().getUnitsProperty(unit_ids=units, property_name='snr')
        v_snr.append(snr)
        v_accuracy.append(accuracy)

    plt.plot(v_snr, v_accuracy, '.')
    plt.xlabel('SNR')
    plt.ylabel('Accuracy')
    plt.title('Sorted by {} on {} recordings'.format(
        params['sorter'], params['num_jobs']))
    plt.ylim(0,1)
    plt.show()                

In [None]:
plot_comparisons(v_comparison, params)

In [None]:
plot_comparisons(v_comparison, params)