In [None]:
# This example sorts a single recording on the local computer
# Created by James Jun on Feb 26, 2019

# prerequisites
# $ pip install ml_ms4alg
# $ conda install -c conda-forge ipywidgets

# please ignore the warning when running MountainSort4
#   RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import json
import os
import shutil

import ipywidgets as widgets
import numpy as np

import sfdata as sf
import spikeforest_analysis as sa
from spikeforest import spikeextractors as se
from spikeforest import spiketoolkit as st
from spikeforest import spikewidgets as sw
from spikesorters import IronClust, MountainSort4

In [None]:
# Interactive selection of the spike sorter and the data source.
# Pick an entry in each dropdown before running the following cells;
# the first entry of each list is preselected (index=0).
v_datasource = ['generate locally', 'download']
v_sorter = ['MountainSort4', 'IronClust']

# Dropdown listing the available spike sorters
widget1 = widgets.Dropdown(options=v_sorter, index=0, description='Spike sorters')
display(widget1)

# Dropdown listing the available data sources
widget2 = widgets.Dropdown(options=v_datasource, index=0, description='Data source')
display(widget2)

In [None]:
# Collect the run settings in one dictionary: the sorter and data source
# currently selected in the dropdowns, plus the input/output directories.
params = {
    'sorter': v_sorter[widget1.index],
    'datasource': v_datasource[widget2.index],
    'in_path': 'recordings/example1',
    'out_path': 'sortings/example1',
}

In [None]:
# Define sorters

def irc(recpath, firings_out):
    """Run the IronClust processor on a recording directory.

    Args:
        recpath: passed to the processor as ``recording_dir``.
        firings_out: passed to the processor as ``firings_out``
            (the path the sorting result is written to, presumably an .mda file).

    Returns:
        Whatever ``IronClust.execute`` returns.
    """
    options = {
        'recording_dir': recpath,
        'firings_out': firings_out,
        'detect_sign': -1,
        'adjacency_radius': 100,
        'prm_template_name': 'static',
        # presumably forces a re-run instead of reusing a cached result -- TODO confirm
        '_force_run': True,
    }
    return IronClust.execute(**options)

def ms4(recpath, firings_out):
    """Run the MountainSort4 processor on a recording directory.

    Args:
        recpath: passed to the processor as ``recording_dir``.
        firings_out: passed to the processor as ``firings_out``
            (the path the sorting result is written to, presumably an .mda file).

    Returns:
        Whatever ``MountainSort4.execute`` returns.
    """
    return MountainSort4.execute(
            recording_dir=recpath,
            firings_out=firings_out,
            detect_sign=-1,
            adjacency_radius=100,
            # Was misspelled `_fore_run`; the other processor calls in this
            # notebook pass `_force_run`.
            _force_run=True)

# Lookup table from sorter name (as listed in the dropdown) to its wrapper function
v_sorters = {'IronClust': irc, 'MountainSort4': ms4}

In [None]:
# Get the recording and its ground-truth sorting, then write both to disk.
recpath = params['in_path']
savepath = params['out_path']

# Start from an empty recording directory; an existing output directory is kept.
if os.path.exists(recpath):
    shutil.rmtree(recpath)
os.makedirs(recpath, exist_ok=True)
os.makedirs(savepath, exist_ok=True)

# Compare by value: `is` tests object identity, which is not guaranteed to
# hold for two equal strings created in different cells.
if params['datasource'] == 'generate locally':
    # generate recording
    rx, sx_true = se.example_datasets.toy_example1(
        duration=600, num_channels=4, samplerate=30000, K=10)
else:
    # download recording
    kpath = 'kbucket://15734439d8cf/groundtruth/magland_synth/datasets_noise10_K10_C4/001_synth'
    rx = se.MdaRecordingExtractor(kpath, download=True)
    # The '/' separator was missing, which produced '.../001_synthfirings_true.mda'.
    sx_true = se.MdaSortingExtractor(kpath + '/firings_true.mda')

# Persist the recording and the ground truth so the sorters can read them from disk.
se.MdaRecordingExtractor.writeRecording(
    recording=rx, save_path=recpath)
se.MdaSortingExtractor.writeSorting(
    sorting=sx_true, save_path=os.path.join(savepath, 'firings_true.mda'))

In [None]:
# Run spike sorting with the selected sorter, then load its output.
# The numpy.ufunc RuntimeWarning printed by MountainSort4 can be ignored.

firings_out = os.path.join(savepath, 'firings_out.mda')
run_sorter = v_sorters[params['sorter']]
run_sorter(recpath, firings_out)
sx = se.MdaSortingExtractor(firings_out)


In [None]:
# Compare the sorter output against the ground truth and show the comparison table.

sorter_name = params['sorter']
comparison = st.comparison.SortingComparison(
    sorting1=sx_true,
    sorting2=sx,
    sorting1_name='true',
    sorting2_name=sorter_name)
comparison_table = sw.SortingComparisonTable(comparison=comparison)

print('sorting output for {}:'.format(sorter_name))
comparison_table.display()

In [None]:
# Compute the SNR of the ground-truth units and attach it to sx_true
# as the 'snr' unit property.

# NOTE(review): this file is parsed as JSON below although its name ends in '.mda'.
path_json_out = os.path.join(params['out_path'], 'summary_true.mda')
sa.compute_units_info.ComputeUnitsInfo.execute(
    recording_dir=params['in_path'],
    firings=os.path.join(params['out_path'], 'firings_true.mda'),
    json_out=path_json_out,
    _force_run=True,
)

# `json` is imported in the notebook's import cell.
with open(path_json_out) as f:
    snr_json = json.load(f)

# Each entry of the JSON list is read for its 'snr' and 'unit_id' fields.
unit_snrs = [unit_info['snr'] for unit_info in snr_json]
unit_ids = [unit_info['unit_id'] for unit_info in snr_json]
sx_true.setUnitsProperty(
    property_name='snr', values=unit_snrs, unit_ids=unit_ids)

In [None]:
# Plot accuracy against the 'snr' unit property for the comparison computed above.

accuracy_widget = sw.SortingAccuracyWidget(
    sorting_comparison=comparison, property_name='snr')
accuracy_widget.plot()