In [None]:
%load_ext autoreload
%autoreload 2

import io
import os

import spikeextractors as si
import spikewidgets as sw
import spiketoolkit as st

from kbucket import client as kb
from pairio import client as pa

#from spikeforest_sort import spikeforest_sort, mountainsort4b_params, 

import spikeforest as sf

In [None]:
# Optional write credentials. When enabled, tokens are taken from the
# environment (never hard-coded): SPIKEFOREST_PAIRIO_TOKEN configures the
# pairio user/token, SPIKEFOREST_KBUCKET_TOKEN configures the kbucket upload
# share and then runs kb.testUpload(). Either variable may be absent.
connect_to_remote=False

if connect_to_remote:
    _pairio_token=os.getenv('SPIKEFOREST_PAIRIO_TOKEN')
    _kbucket_token=os.getenv('SPIKEFOREST_KBUCKET_TOKEN')

    if _pairio_token:
        pa.setConfig(user='spikeforest',token=_pairio_token)

    if _kbucket_token:
        print('Setting upload server to magland.spikeforest')
        kb.setConfig(upload_share_id='magland.spikeforest',upload_token=_kbucket_token)
        # Fail fast here if the upload credentials do not work.
        kb.testUpload()

In [None]:
# Read-side sources, set unconditionally (independent of connect_to_remote).
# NOTE(review): presumably these are the pairio collections / kbucket shares
# consulted when looking up records and files — confirm against client docs.
PAIRIO_READ_COLLECTIONS=['spikeforest']
KBUCKET_READ_SHARE_IDS=['magland.spikeforest']
pa.setConfig(collections=PAIRIO_READ_COLLECTIONS)
kb.setConfig(share_ids=KBUCKET_READ_SHARE_IDS)

In [None]:
# Sorters to run. Each entry is a dict with a display name, the spikeforest
# processor class, and the parameter dict handed to sf.sortDataset via the
# sorter entry.
# NOTE(review): detect_sign=-1 presumably selects negative-going spikes and
# adjacency_radius=-1 presumably means "use all channels" — confirm in the
# MountainSort4 documentation.
ms4_params=dict(detect_sign=-1,adjacency_radius=-1,detect_threshold=3)

sorters=[
    dict(name='MountainSort4',processor=sf.MountainSort4,params=ms4_params),
]

In [None]:
# Study to process: every subdirectory of study_dir becomes one dataset entry
# with its display name and its kbucket directory path.
study_dir='kbucket://b5ecdf1474c5/spikeforest/gen_synth_datasets/datasets_noise10_K20'
study_name='synth_jfm_noise10_K20'

study_listing=kb.readDir(study_dir)
if study_listing is None:
    # Without this check a failed listing surfaces as an opaque
    # "'NoneType' object is not subscriptable" on the next line.
    # NOTE(review): assumes readDir returns None when the directory cannot be
    # read — confirm against the kbucket client.
    raise RuntimeError('Unable to read study directory: {}'.format(study_dir))

datasets=[
    dict(name=dsname,dataset_dir='{}/{}'.format(study_dir,dsname))
    for dsname in study_listing['dirs']
]
# For a quick debugging run on a single dataset, uncomment:
#datasets=[datasets[0]]

In [None]:
from matplotlib import pyplot as plt
from PIL import Image

def save_plot(jpg_fname,quality=20):
    """Save the current matplotlib figure as a compressed JPEG, then close it.

    The figure is rendered to PNG in memory and re-encoded by PIL as RGB JPEG,
    so the file size can be controlled with ``quality``. No intermediate file
    is written to disk.

    Args:
        jpg_fname: Destination path. PIL infers the output format from its
            extension, so it should end in ``.jpg``/``.jpeg``.
        quality: JPEG quality passed to PIL (low values give smaller files).
    """
    png_buffer=io.BytesIO()
    try:
        plt.savefig(png_buffer,format='png')
    finally:
        # Release the figure even if rendering fails, so loops that produce
        # many plots do not accumulate open figures.
        plt.close()
    png_buffer.seek(0)
    # The context manager closes the PNG source; convert() returns an
    # independent, fully loaded image.
    with Image.open(png_buffer) as png_image:
        rgb_image=png_image.convert('RGB')
    rgb_image.save(jpg_fname,quality=quality)

def prepareSortingSummary(result):
    """Render and upload the unit-waveforms plot for one sorting result.

    Args:
        result: Dict with at least 'dataset_dir' (recording directory passed
            to MdaRecordingExtractor) and 'firings' (path/URL resolved through
            kb.realizeFile before loading).

    Returns:
        ``{'plots': {'unit_waveforms': <value returned by kb.uploadFile>}}``.
    """
    recording=si.MdaRecordingExtractor(dataset_directory=result['dataset_dir'])
    firings_local=kb.realizeFile(result['firings'])
    sorting=si.MdaSortingExtractor(firings_file=firings_local)

    widget=sw.UnitWaveformsWidget(recording=recording,sorting=sorting)
    widget.plot()

    # Written to the working directory and overwritten on every call.
    plot_file='unit_waveforms.jpg'
    save_plot(plot_file)
    uploaded=kb.uploadFile(plot_file)
    return {'plots':{'unit_waveforms':uploaded}}

In [None]:
def prepareComparisonWithTruth(result):
    """Compare a sorter's output against the ground-truth sorting.

    Args:
        result: Dict with 'firings' (sorter output) and 'firings_true'
            (ground truth). Either may be a kbucket/sha1 URL.

    Returns:
        ``{'table': ...}``, where the table is the SortingComparisonTable
        dataframe, transposed and converted with ``DataFrame.to_dict()``. The
        outer keys are the original dataframe's row labels.

    Raises:
        RuntimeError: If a firings file cannot be realized locally.
    """
    ret={}
    # spikeextractors cannot read kbucket URLs yet, so fetch local copies first.
    firings_path=kb.realizeFile(result['firings'])
    # Previously firings_true was passed through unrealized, so a remote
    # ground-truth URL would reach MdaSortingExtractor unreadable. Local paths
    # are passed through unchanged, as before.
    firings_true_path=result['firings_true']
    if isinstance(firings_true_path,str) and firings_true_path.startswith(('kbucket://','sha1://')):
        firings_true_path=kb.realizeFile(firings_true_path)
    # NOTE(review): assumes realizeFile returns None when the file cannot be
    # located — confirm against the kbucket client.
    if firings_path is None:
        raise RuntimeError('Unable to realize firings file: {}'.format(result['firings']))
    if firings_true_path is None:
        raise RuntimeError('Unable to realize ground-truth firings file: {}'.format(result['firings_true']))

    sorting=si.MdaSortingExtractor(firings_file=firings_path)
    sorting_true=si.MdaSortingExtractor(firings_file=firings_true_path)
    SC=st.comparison.SortingComparison(sorting_true,sorting)
    ret['table']=sw.SortingComparisonTable(comparison=SC).getDataframe().transpose().to_dict()
    return ret

In [None]:
def _sort_and_evaluate(sorter,dataset):
    """Run one sorter on one dataset and attach comparison and summary results."""
    print('SORTER:{}     DATASET: {}'.format(sorter['name'],dataset['name']))
    result=sf.sortDataset(
        sorter=sorter,
        dataset=dataset,
        _force_run=False
    )
    result['comparison_with_truth']=prepareComparisonWithTruth(result)
    result['summary']=prepareSortingSummary(result)
    return result

# Datasets form the outer loop and sorters the inner one. Results are appended
# one at a time so partial output stays inspectable if a later run fails.
results=[]
for dataset in datasets:
    for sorter in sorters:
        results.append(_sort_and_evaluate(sorter,dataset))

print('Saving results object...')
kb.saveObject(results,key=dict(name='spikeforest_results',study_name=study_name))