In [1]:
# This example runs spike sorting on a single recording using the local computer
# Created by James Jun on Feb 26, 2019

# prerequisites
# $ pip install ml_ms4alg
# $ conda install -c conda-forge ipywidgets

# please ignore the warning when running MountainSort4
#   RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import spikeforest_analysis as sa
import spikeextractors as se
import os
import shutil
import spikeforest as sf
import numpy as np
from spikesorters import IronClust, MountainSort4
import spiketoolkit as st
import spikewidgets as sw
import ipywidgets as widgets

In [4]:
# Pick the spike sorter and the data source from the two dropdowns below.
# Change the selected entry of either dropdown to switch sorter / data source.
v_datasource = ['generate', 'download']
v_sorter = ['MountainSort4', 'IronClust']

# Dropdown for the spike sorter; the first entry is pre-selected.
widget1 = widgets.Dropdown(options=v_sorter, index=0, description='Spike sorters')
display(widget1)

# Dropdown for the data source; the first entry is pre-selected.
widget2 = widgets.Dropdown(options=v_datasource, index=0, description='Data source')
display(widget2)

Dropdown(description='Spike sorters', options=('MountainSort4', 'IronClust'), value='MountainSort4')

Dropdown(description='Data source', options=('generate', 'download'), value='generate')

In [5]:
# Gather the run settings: the current dropdown selections plus the
# input (recording) and output (sorting) paths.
params = {
    'sorter': v_sorter[widget1.index],
    'datasource': v_datasource[widget2.index],
    'in_path': 'recordings/example1',
    'out_path': 'sortings/example1',
}

In [6]:
# Define sorters

def irc(recpath, firings_out, detect_sign=-1, adjacency_radius=100,
        prm_template_name='static'):
    """Run the IronClust sorter on a recording directory.

    The sorter options used to be hard-coded; they are now keyword
    parameters whose defaults are the previous constants, so calling
    ``irc(recpath, firings_out)`` behaves exactly as before.

    Args:
        recpath: Recording directory, forwarded as ``recording_dir``.
        firings_out: Output firings file path, forwarded as ``firings_out``.
        detect_sign: Forwarded unchanged to ``IronClust.execute``.
        adjacency_radius: Forwarded unchanged to ``IronClust.execute``.
        prm_template_name: Forwarded unchanged to ``IronClust.execute``.

    Returns:
        Whatever ``IronClust.execute`` returns.
    """
    return IronClust.execute(
        recording_dir=recpath,
        firings_out=firings_out,
        detect_sign=detect_sign,
        adjacency_radius=adjacency_radius,
        prm_template_name=prm_template_name)

def ms4(recpath, firings_out, detect_sign=-1, adjacency_radius=100):
    """Run the MountainSort4 sorter on a recording directory.

    The sorter options used to be hard-coded; they are now keyword
    parameters whose defaults are the previous constants, so calling
    ``ms4(recpath, firings_out)`` behaves exactly as before.

    Args:
        recpath: Recording directory, forwarded as ``recording_dir``.
        firings_out: Output firings file path, forwarded as ``firings_out``.
        detect_sign: Forwarded unchanged to ``MountainSort4.execute``.
        adjacency_radius: Forwarded unchanged to ``MountainSort4.execute``.

    Returns:
        Whatever ``MountainSort4.execute`` returns.
    """
    return MountainSort4.execute(
        recording_dir=recpath,
        firings_out=firings_out,
        detect_sign=detect_sign,
        adjacency_radius=adjacency_radius)

# Dispatch table: sorter name (as shown in the dropdown) -> wrapper function.
v_sorters = {
    'IronClust': irc,
    'MountainSort4': ms4,
}

In [7]:
# Prepare the recording and its ground-truth sorting on disk.
recpath = params['in_path']
savepath = params['out_path']

# Start from an empty recording directory so files left by a previous run
# cannot mix with the new recording. The output directory is kept as is.
if os.path.exists(recpath):
    shutil.rmtree(recpath)
os.makedirs(recpath, exist_ok=True)
os.makedirs(savepath, exist_ok=True)

# Compare strings by value: `is` tests object identity, which only happened
# to work through string interning and triggers a SyntaxWarning on newer
# Python versions.
if params['datasource'] == 'generate':
    # generate a toy recording together with its true sorting
    rx, sx_true = se.example_datasets.toy_example1(
        duration=600, num_channels=4, samplerate=30000, K=10)
else:
    # download a recording and its true firings
    kpath = 'kbucket://15734439d8cf/groundtruth/magland_synth/datasets_noise10_K10_C4/001_synth/'
    rx = se.MdaRecordingExtractor(kpath, download=True)
    sx_true = se.MdaSortingExtractor(kpath + 'firings_true.mda')

# Write the recording to `recpath` and the true sorting next to the
# sorter output in `savepath`.
se.MdaRecordingExtractor.writeRecording(
    recording=rx, save_path=recpath)
se.MdaSortingExtractor.writeSorting(
    sorting=sx_true, save_path=os.path.join(savepath, 'firings_true.mda'))

In [8]:
# Run the selected sorter on the recording directory and load its result.

# The sorter writes its firings into the output directory.
firings_out = os.path.join(savepath, 'firings_out.mda')

# Look up the wrapper for the chosen sorter and call it right away.
v_sorters[params['sorter']](recpath, firings_out)

# Read the sorter output back as a sorting extractor.
sx = se.MdaSortingExtractor(firings_out)


::::::::::::::::::::::::::::: MountainSort4
Computing sha1 of recordings/example1/raw.mda
MLPR EXECUTING::::::::::::::::::::::::::::: MountainSort4
MountainSort4......


  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)
  return f(*args, **kwds)


Using 4 workers.
Using tmpdir: /tmp/tmpofi3fn0e
Num. workers = 4
Preparing /tmp/tmpofi3fn0e/timeseries.hdf5...
Preparing neighborhood sorters (M=4, N=18000000)...
Neighboorhood of channel 2 has 4 channels.
Neighboorhood of channel 1 has 4 channels.
Neighboorhood of channel 0 has 4 channels.
Neighboorhood of channel 3 has 4 channels.
Detecting events on channel 3 (phase1)...
Detecting events on channel 2 (phase1)...
Detecting events on channel 1 (phase1)...
Detecting events on channel 4 (phase1)...
Elapsed time for detect on neighborhood: 0:00:02.433808
Num events detected on channel 1 (phase1): 15005
Computing PCA features for channel 1 (phase1)...
Elapsed time for detect on neighborhood: 0:00:02.461442
Num events detected on channel 4 (phase1): 6438
Computing PCA features for channel 4 (phase1)...
Elapsed time for detect on neighborhood: 0:00:02.479232
Num events detected on channel 3 (phase1): 10149
Computing PCA features for channel 3 (phase1)...
Elapsed time for detect on neighborh

In [11]:
# Compare the sorter output against the true sorting and show the result.

# sorting1 is the ground truth, sorting2 is the output of the chosen sorter.
comparison = st.comparison.SortingComparison(
    sorting1=sx_true,
    sorting1_name='true',
    sorting2=sx,
    sorting2_name=params['sorter'],
)

# Build the table widget from the comparison object.
comparison_table = sw.SortingComparisonTable(comparison=comparison)

print('sorting output for {}:'.format(params['sorter']))
comparison_table.display()

sorting output for MountainSort4:


unit_id,accuracy,best_unit,matched_unit,f_n,f_p,num_matches
1,0.87,4,4,0.0,0.13,1232
2,0.99,2,2,0.0,0.01,1372
3,0.45,5,-1,0.0,0.55,622
4,0.99,8,8,0.0,0.01,1387
5,0.97,9,9,0.0,0.03,1359
6,0.97,10,10,0.0,0.03,1361
7,0.98,12,12,0.01,0.01,1393
8,0.98,13,13,0.0,0.02,1377
9,0.78,14,14,0.0,0.22,1096
10,1.0,15,15,0.0,0.0,1428
