In [None]:
# This example sorts a single recording on the local computer.
# Created by James Jun on Feb 26, 2019

# prerequisites
#   pip install ml_ms4alg

# please ignore the warning when running MountainSort4
#   RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192


In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import spikeforest_analysis as sa
import spikeextractors as se
import os
import shutil
import sfdata as sf
import numpy as np
from spikesorters import IronClust, MountainSort4

import spiketoolkit as st
import spikewidgets as sw

In [None]:
# Choose where the recording comes from and which sorter to run by editing
# the two indices below:
#   i_datasource: 0 -> 'generate' (toy recording), 1 -> 'download'
#   i_sorter:     0 -> 'IronClust',                1 -> 'MountainSort4'
v_datasource = ['generate', 'download']
v_sorter = ['IronClust', 'MountainSort4']
i_datasource = 0
i_sorter = 1

# Run configuration consumed by the cells below.
params = {
    'datasource': v_datasource[i_datasource],
    'sorter': v_sorter[i_sorter],
    'in_path': 'recordings/example1',
    'out_path': 'sortings/example1',
}

In [None]:
# Define sorters

def irc(recpath, firings_out, detect_sign=-1, adjacency_radius=100):
    """Sort a recording with IronClust.

    Parameters
    ----------
    recpath : str
        Recording directory (in this notebook, the directory written by
        ``se.MdaRecordingExtractor.writeRecording``).
    firings_out : str
        Path of the output firings file that the sorter should produce.
    detect_sign : int, optional
        Forwarded to ``IronClust.execute``; defaults to -1 as before.
        NOTE(review): presumably selects spike polarity -- confirm in the
        spikesorters documentation.
    adjacency_radius : int or float, optional
        Forwarded to ``IronClust.execute``; defaults to 100 as before.

    Returns
    -------
    Whatever ``IronClust.execute`` returns.
    """
    return IronClust.execute(
        recording_dir=recpath,
        firings_out=firings_out,
        detect_sign=detect_sign,
        adjacency_radius=adjacency_radius,
        prm_template_name='static')

def ms4(recpath, firings_out, detect_sign=-1, adjacency_radius=100):
    """Sort a recording with MountainSort4.

    Parameters
    ----------
    recpath : str
        Recording directory (in this notebook, the directory written by
        ``se.MdaRecordingExtractor.writeRecording``).
    firings_out : str
        Path of the output firings file that the sorter should produce.
    detect_sign : int, optional
        Forwarded to ``MountainSort4.execute``; defaults to -1 as before.
        NOTE(review): presumably selects spike polarity -- confirm in the
        spikesorters documentation.
    adjacency_radius : int or float, optional
        Forwarded to ``MountainSort4.execute``; defaults to 100 as before.

    Returns
    -------
    Whatever ``MountainSort4.execute`` returns.
    """
    return MountainSort4.execute(
        recording_dir=recpath,
        firings_out=firings_out,
        detect_sign=detect_sign,
        adjacency_radius=adjacency_radius)

# Dispatch table: sorter name (as listed in `v_sorter`) -> wrapper that runs it.
v_sorters = {
    'IronClust': irc,
    'MountainSort4': ms4,
}

In [None]:
# Prepare the input recording and the ground-truth sorting.
recpath = params['in_path']
savepath = params['out_path']

# Start from an empty recording directory so files from a previous run are
# not mixed with the new recording; the output directory is only created.
if os.path.exists(recpath):
    shutil.rmtree(recpath)
os.makedirs(recpath)
os.makedirs(savepath, exist_ok=True)

# Compare strings with `==`, not `is`: identity comparison of strings only
# works by accident of interning (and is a SyntaxWarning on Python 3.8+).
if params['datasource'] == 'generate':
    # generate a toy recording together with its ground-truth sorting
    rx, sx_true = se.example_datasets.toy_example1(
        duration=600, num_channels=4, samplerate=30000, K=10)
elif params['datasource'] == 'download':
    # download the recording and its ground-truth firings
    kpath = 'kbucket://8b61fa1d5901/'
    rx = se.MdaRecordingExtractor(kpath, download=True)
    sx_true = se.MdaSortingExtractor(kpath + 'firings_true.mda')
else:
    # fail loudly instead of silently treating a typo as 'download'
    raise ValueError(
        'unknown datasource: {!r}'.format(params['datasource']))

se.MdaRecordingExtractor.writeRecording(
    recording=rx, save_path=recpath)
se.MdaSortingExtractor.writeSorting(
    sorting=sx_true, save_path=os.path.join(savepath, 'firings_true.mda'))

In [None]:
# Run spike sorting with the sorter selected in `params`.
firings_out = os.path.join(savepath, 'firings_out.mda')

# Remove output left over from a previous run first. Otherwise, if the sorter
# fails without raising, the stale file would still be loaded below and the
# validation would silently score an old result.
if os.path.exists(firings_out):
    os.remove(firings_out)

v_sorters[params['sorter']](recpath, firings_out)

if not os.path.exists(firings_out):
    raise RuntimeError(
        '{} did not write {}'.format(params['sorter'], firings_out))
sx = se.MdaSortingExtractor(firings_out)


In [None]:
# Run validation and display: compare the sorter output (sx) against the
# ground-truth sorting (sx_true).
# sorting1 is the ground truth and sorting2 the sorter output, so the names
# must follow that order; the sorter name comes from `params` instead of being
# hard-coded, so the table is labelled correctly for every sorter.
comparison = st.comparison.SortingComparison(
    sorting1=sx_true,
    sorting2=sx,
    sorting1_name='true',
    sorting2_name=params['sorter'])
comparison_table = sw.SortingComparisonTable(comparison=comparison)

print('sorting output for {}:'.format(params['sorter']))
comparison_table.display()