In [None]:
# This example runs a batch of spike-sorting jobs on a local computer
# Created by James Jun on Feb 27, 2019

# prerequisites
# $ pip install ml_ms4alg
# $ conda install -c conda-forge ipywidgets

# please ignore the following warning when running MountainSort4:
#   RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192


In [None]:
# Enable IPython's autoreload extension so edits to imported modules are
# picked up without restarting the kernel; mode 2 reloads all modules
# before each cell is executed.
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import shutil

import spikeforest_analysis as sa
import spikeextractors as se
import spikeforest as sf
import numpy as np
from spikesorters import IronClust, MountainSort4
import spiketoolkit as st
import spikewidgets as sw
import ipywidgets as widgets
import matplotlib.pyplot as plt
import mlprocessors as mlpr


In [None]:
# Pick the spike sorter and the data source from two drop-down menus.
# Each menu starts on its first entry; change the selection before running
# the cells that follow (the params cell reads the selected indices).
v_datasource = ['generate', 'download']
v_sorter = ['MountainSort4', 'IronClust']

widget1 = widgets.Dropdown(options=v_sorter, index=0,
                           description='Spike sorters')
widget2 = widgets.Dropdown(options=v_datasource, index=0,
                           description='Data source')

display(widget1, widget2)

In [None]:
# Run configuration in one place: the sorter and data source come from the
# drop-down selections; paths are relative to the notebook directory.
params = {
    'sorter': v_sorter[widget1.index],
    'datasource': v_datasource[widget2.index],
    'in_path': 'recordings/example1',
    'out_path': 'sortings/example1',
    'num_jobs': 4,
    'num_workers': 4,
}

In [None]:
# Sorter job factories. Each takes a recording directory and returns a job
# (run later with `mlpr.executeBatch`) whose firings are written as .mda.
# Both sorters share the same detection settings; IronClust additionally
# selects the 'static' parameter template.

def _job_kwargs(recpath, **extra):
    """Build the keyword arguments shared by every sorter job.

    Sorter-specific `extra` options are inserted before `_force_run`, so the
    argument order is the same as spelling each call out by hand.
    """
    kwargs = {
        'recording_dir': recpath,
        'firings_out': {'ext': '.mda'},
        'detect_sign': -1,
        'adjacency_radius': 100,
    }
    kwargs.update(extra)
    # presumably forces re-execution instead of reusing a cached result
    kwargs['_force_run'] = True
    return kwargs


def irc(recpath):
    """Create an IronClust job for the recording in `recpath`."""
    return IronClust.createJob(
        **_job_kwargs(recpath, prm_template_name='static'))


def ms4(recpath):
    """Create a MountainSort4 job for the recording in `recpath`."""
    return MountainSort4.createJob(**_job_kwargs(recpath))


# Map the sorter names offered in the drop-down to their job factories.
v_sorters = {'IronClust': irc, 'MountainSort4': ms4}

In [None]:
# Prepare the input recording and its ground-truth sorting on disk.
# The recording is either synthesized locally ('generate') or downloaded
# from kbucket ('download'), according to params['datasource'].

recpath = params['in_path']
savepath = params['out_path']

# Start from an empty recording directory so files from a previous run are
# never mixed with the new recording; the output directory is reused.
if os.path.exists(recpath):
    shutil.rmtree(recpath)
os.makedirs(recpath)
os.makedirs(savepath, exist_ok=True)

# Compare strings by value (`==`); `is` tests object identity, which for
# equal strings is an interpreter implementation detail.
if params['datasource'] == 'generate':
    # toy dataset: duration=600 (presumably seconds), 4 channels,
    # samplerate=30000, K=10 (presumably the number of ground-truth units)
    rx, sx_true = se.example_datasets.toy_example1(
        duration=600, num_channels=4, samplerate=30000, K=10)
else:
    # download a synthetic ground-truth recording and its true firings
    kpath = 'kbucket://15734439d8cf/groundtruth/magland_synth/datasets_noise10_K10_C4/001_synth/'
    rx = se.MdaRecordingExtractor(kpath, download=True)
    sx_true = se.MdaSortingExtractor(kpath + 'firings_true.mda')

# Persist both so the sorter jobs and the SNR computation can read them.
se.MdaRecordingExtractor.writeRecording(
    recording=rx, save_path=recpath)
se.MdaSortingExtractor.writeSorting(
    sorting=sx_true, save_path=os.path.join(savepath, 'firings_true.mda'))

In [None]:
# create a batch (`jobs`) and execute batch

import mlprocessors as mlpr

jobs=[]
for iJob in range(0, params['num_jobs']):
    job=v_sorters[params['sorter']](recpath)
    jobs.append(job)
    

%time mlpr.executeBatch(jobs=jobs, num_workers=params['num_workers'])

In [None]:
# Load each job's firings output as a sorting extractor, keeping job order.
v_sx = [
    se.MdaSortingExtractor(job['result']['outputs']['firings_out'])
    for job in jobs
]

In [None]:
# Compute per-unit info (including SNR) for the ground-truth units and attach
# the SNR to `sx_true` as the 'snr' unit property used by the accuracy plots.
# The summary is written as JSON (it is read back with json.load), so it gets
# a .json extension rather than a misleading .mda one.

path_json_out = os.path.join(params['out_path'], 'summary_true.json')
sa.compute_units_info.ComputeUnitsInfo.execute(
    recording_dir=params['in_path'],
    firings=os.path.join(params['out_path'], 'firings_true.mda'),
    json_out=path_json_out,
    _force_run=True,
)

with open(path_json_out) as f:
    snr_json = json.load(f)
# assumes a list of per-unit dicts carrying 'unit_id' and 'snr' keys
unit_snrs = [unit_info['snr'] for unit_info in snr_json]
unit_ids = [unit_info['unit_id'] for unit_info in snr_json]
sx_true.setUnitsProperty(property_name='snr',
                         values=unit_snrs, unit_ids=unit_ids)

In [None]:
# Compare every sorting output with the ground truth and display one table
# per job; the comparison objects are collected for the plots below.

v_comparison = []
for job_index, sorting_out in enumerate(v_sx):
    comparison = st.comparison.SortingComparison(
        sorting1=sx_true, sorting1_name='true',
        sorting2=sorting_out, sorting2_name=params['sorter'],
    )
    table = sw.SortingComparisonTable(comparison=comparison)
    v_comparison.append(comparison)

    header = 'sorting output for {} in job {}:'.format(
        params['sorter'], job_index)
    print(header)
    table.display()

In [None]:
# Plot ground-truth SNR against accuracy for the first job's sorting output.

first_comparison = v_comparison[0]
accuracy_widget = sw.SortingAccuracyWidget(
    sorting_comparison=first_comparison,
    property_name='snr',
)
accuracy_widget.plot()

In [None]:
# plot SNR vs accuracy for all sorting output

def plot_comparisons(v_comparison, params):
    """Scatter-plot ground-truth unit SNR against accuracy for every job.

    Parameters
    ----------
    v_comparison : list
        Sorting comparisons, one per job. Each comparison's ``sorting1`` is
        the ground truth and must carry an ``'snr'`` unit property.
    params : dict
        Run parameters; ``params['sorter']`` is shown in the title.

    Each comparison is drawn as its own labelled series so that outputs of
    different jobs can be told apart. Prints a message and returns without
    plotting when `v_comparison` is empty.
    """
    if not v_comparison:
        print('No sorting comparisons to plot.')
        return

    fig, ax = plt.subplots()
    for job_index, comparison in enumerate(v_comparison):
        true_sorting = comparison.getSorting1()
        units = true_sorting.getUnitIds()
        accuracy = [comparison.getAgreementFraction(unit) for unit in units]
        snr = true_sorting.getUnitsProperty(unit_ids=units, property_name='snr')
        ax.plot(snr, accuracy, '.', label='job {}'.format(job_index))

    ax.set_xlabel('SNR')
    ax.set_ylabel('Accuracy')
    # Count what was actually plotted rather than trusting params['num_jobs'].
    ax.set_title('Sorted by {} on {} recordings'.format(
        params['sorter'], len(v_comparison)))
    ax.set_ylim(0, 1)
    ax.legend()
    plt.show()

In [None]:
plot_comparisons(v_comparison=v_comparison, params=params)

In [None]:
# Duplicate `plot_comparisons(v_comparison, params)` call removed: the
# previous cell already draws this figure with the same inputs.