In [None]:
import os

import altair as alt
import hither2 as hi
import kachery as ka
import numpy as np
import pandas as pd
import spikeextractors as se
from spikeforest2_utils import AutoSortingExtractor, MdaRecordingExtractor

In [None]:
# Content-addressed (sha1) kachery URI of the SpikeForest analysis summary queried below.
ANALYSIS_URI = 'sha1://ca5fb596746c4ddca9d74df25b9bb1ecfd626f21/analysis.json'
# `fr` selects where kachery fetches from; presumably a configured remote name — confirm.
a = ka.load_object(ANALYSIS_URI, fr='default_readwrite')
assert a is not None, f'Unable to load analysis object {ANALYSIS_URI}'

In [None]:
class SFAnalysis:
    """Read-only query helper over a SpikeForest analysis object.

    The wrapped object is a nested dict; the accessors below index into its
    ``StudySets``, ``Sorters``, ``SortingResults`` and ``StudyAnalysisResults``
    top-level keys.
    """

    def __init__(self, path_or_dict):
        """Wrap an analysis object.

        Args:
            path_or_dict: a kachery path/URI string (loaded via ``ka.load_object``)
                or an already-loaded analysis dict.

        Raises:
            AssertionError: if a string is given and the object cannot be loaded.
            TypeError: if ``path_or_dict`` is neither a string nor a dict.
        """
        if isinstance(path_or_dict, str):
            self._obj = ka.load_object(path_or_dict)
            assert self._obj is not None, 'Unable to load analysis object'
        elif isinstance(path_or_dict, dict):
            self._obj = path_or_dict
        else:
            # TypeError subclasses Exception, so existing `except Exception` handlers still catch it.
            raise TypeError('Invalid type for path_or_dict. Must be string or dict.')

    def get_study_object(self, study_name):
        """Return the unique study dict named ``study_name`` (searched across all study sets).

        Raises:
            AssertionError: if no study, or more than one study, has that name.
        """
        a = [s for ss in self._obj['StudySets'] for s in ss['studies'] if s['name'] == study_name]
        assert len(a) >= 1, f'Unable to get study object {study_name}'
        assert len(a) <= 1, f'Multiple study objects found for {study_name}'
        return a[0]

    def get_study_names(self):
        """Return the names of all studies, flattened across study sets, in file order."""
        return [s['name'] for ss in self._obj['StudySets'] for s in ss['studies']]

    def get_recording_names(self, study_name):
        """Return the names of the recordings in the given study."""
        s = self.get_study_object(study_name)
        return [r['name'] for r in s['recordings']]

    def get_sorter_names(self):
        """Return the names of all sorters listed in the analysis."""
        return [s['name'] for s in self._obj['Sorters']]

    def get_recording_object(self, study_name, recording_name):
        """Return the unique recording dict ``recording_name`` within ``study_name``.

        Raises:
            AssertionError: if the study or recording is missing or ambiguous.
        """
        study_obj = self.get_study_object(study_name)
        a = [r for r in study_obj['recordings'] if r['name'] == recording_name]
        assert len(a) >= 1, f'Unable to get recording object {study_name} {recording_name}'
        assert len(a) <= 1, f'Multiple recording objects found for {study_name} {recording_name}'
        return a[0]

    def get_recording_path(self, study_name, recording_name):
        """Return the recording's ``directory`` entry."""
        recobj = self.get_recording_object(study_name, recording_name)
        return recobj['directory']

    def get_recording_extractor(self, study_name, recording_name, download=False):
        """Build an ``MdaRecordingExtractor`` for the recording's directory."""
        recobj = self.get_recording_object(study_name, recording_name)
        R = MdaRecordingExtractor(recording_directory=recobj['directory'], download=download)
        return R

    def get_sorting_result_object(self, study_name, recording_name, sorter_name):
        """Return the unique per-recording sorting result for (study, recording, sorter).

        Raises:
            AssertionError: if no match or more than one match exists.
        """
        srs = [
            sr for sr in self._obj['SortingResults']
            if sr['studyName'] == study_name
            and sr['recordingName'] == recording_name
            and sr['sorterName'] == sorter_name
        ]
        assert len(srs) >= 1, f'Unable to get sorting result object {study_name} {recording_name} {sorter_name}'
        assert len(srs) <= 1, f'Multiple sorting result objects found for {study_name} {recording_name} {sorter_name}'
        return srs[0]

    def get_sorting_result_objects(self, study_names, sorter_names):
        """Return all sorting results whose study and sorter are in the given collections.

        A bare string is treated as a one-element collection. Otherwise ``in``
        against a string would do a substring test, e.g. a study named ``'S1'``
        would wrongly match the argument ``'S10'``.
        """
        if isinstance(study_names, str):
            study_names = [study_names]
        if isinstance(sorter_names, str):
            sorter_names = [sorter_names]
        return [
            sr for sr in self._obj['SortingResults']
            if sr['studyName'] in study_names and sr['sorterName'] in sorter_names
        ]

    def get_sorting_result_extractor(self, study_name, recording_name, sorter_name):
        """Return an ``AutoSortingExtractor`` for the sorter's output, or None if it has no ``firings``.

        The recording extractor is only built (to obtain the sampling rate) once we
        know there is a firings entry to load.
        """
        srobj = self.get_sorting_result_object(study_name, recording_name, sorter_name)
        if 'firings' not in srobj:
            # No firings recorded for this sorter/recording (presumably the sort did not produce output).
            return None
        R = self.get_recording_extractor(study_name, recording_name)
        return AutoSortingExtractor(srobj['firings'], samplerate=R.get_sampling_frequency())

    def get_sorting_result_path(self, study_name, recording_name, sorter_name):
        """Return the sorting result's ``firings`` entry, or None if it has none.

        Returning None mirrors ``get_sorting_result_extractor`` for the same case.
        """
        srobj = self.get_sorting_result_object(study_name, recording_name, sorter_name)
        return srobj.get('firings')

    def get_sorting_true_extractor(self, study_name, recording_name):
        """Return an ``AutoSortingExtractor`` over the recording's ground-truth firings."""
        R = self.get_recording_extractor(study_name, recording_name)
        recobj = self.get_recording_object(study_name, recording_name)
        S = AutoSortingExtractor(recobj['firingsTrue'], samplerate=R.get_sampling_frequency())
        return S

    def get_sorting_true_path(self, study_name, recording_name):
        """Return the recording's ``firingsTrue`` (ground-truth) entry."""
        recobj = self.get_recording_object(study_name, recording_name)
        return recobj['firingsTrue']

    def get_study_analysis_result_object(self, study_name):
        """Return the study-level analysis result for ``study_name``, or None if absent.

        Raises:
            Exception: if more than one result exists for the study.
        """
        ret = [sar for sar in self._obj['StudyAnalysisResults'] if sar['studyName'] == study_name]
        if len(ret) == 0:
            return None
        elif len(ret) == 1:
            return ret[0]
        else:
            raise Exception(f'Too many results for study: {study_name}')

    def get_study_sorting_result_object(self, study_name, sorter_name):
        """Return the study-level sorting result for (study, sorter), or None if absent.

        Searches the ``sortingResults`` of every study analysis result with a matching name.

        Raises:
            Exception: if more than one result matches.
        """
        sars = [sar for sar in self._obj['StudyAnalysisResults'] if sar['studyName'] == study_name]
        ret = [r for sar in sars for r in sar['sortingResults'] if r['sorterName'] == sorter_name]
        if len(ret) == 0:
            return None
        elif len(ret) == 1:
            return ret[0]
        else:
            raise Exception(f'Too many results for study and sorter: {study_name} {sorter_name}')

In [None]:
# Wrap the already-loaded analysis dict in the query helper.
A = SFAnalysis(path_or_dict=a)

In [None]:
# Average accuracy is taken only over ground-truth units with SNR at or above this value.
SNR_THRESHOLD = 8
durations = [300, 600, 1200, 2400, 4800]
nchs = [8, 16]
sorter_names = A.get_sorter_names()

data = []

for nch in nchs:
    for dur in durations:
        for mode in ['STATIC', 'DRIFT']:
            study_name = f'LONG_{mode}_{dur}s_{nch}c'
            # Ground-truth SNRs are per study, independent of the sorter: look them up once.
            sar = A.get_study_analysis_result_object(study_name)
            assert sar is not None, f'No study analysis result for {study_name}'
            true_snrs = sar['trueSnrs']
            for sorter_name in sorter_names:
                obj = A.get_study_sorting_result_object(study_name, sorter_name)
                if obj is None:
                    # No result for this sorter on this study: record NaN so the chart
                    # shows a gap instead of the whole cell failing.
                    avg_accuracy = np.nan
                else:
                    accuracies = obj['accuracies']
                    assert len(accuracies) == len(true_snrs), (
                        f'{study_name}/{sorter_name}: {len(accuracies)} accuracies '
                        f'vs {len(true_snrs)} true units'
                    )
                    # accuracies[i] pairs with true_snrs[i]; None entries are skipped.
                    accuracies0 = [
                        acc for snr, acc in zip(true_snrs, accuracies)
                        if snr >= SNR_THRESHOLD and acc is not None
                    ]
                    # Explicit NaN avoids np.mean's empty-slice RuntimeWarning (same value).
                    avg_accuracy = np.mean(accuracies0) if accuracies0 else np.nan
                data.append(dict(
                    duration=dur,
                    num_channels=nch,
                    study=study_name,
                    sorter=sorter_name,
                    avg_accuracy=avg_accuracy,
                    mode=mode
                ))
df = pd.DataFrame(data)

In [None]:
# One bar chart per mode, faceted by sorter. The sorter count in the title is
# derived from the data rather than hard-coded, so it stays correct if the sorter list changes.
n_sorters = df['sorter'].nunique()
for mode in ['STATIC', 'DRIFT']:
    txt = 'no drift' if mode == 'STATIC' else 'with drift'
    ch = alt.Chart(
        df[df['mode'] == mode],
        title=[f'Accuracy vs. duration for {n_sorters} spike sorters ({txt})']
    ).mark_bar().encode(
        x=alt.X('duration:O', axis=alt.Axis(title='duration (sec)')),
        y=alt.Y('avg_accuracy:Q', axis=alt.Axis(format='%', title='Avg. accuracy')),
        column='sorter:N'
    )

    display(ch)

In [None]:
# All study names across every study set in the analysis.
study_names = A.get_study_names()
study_names

In [None]:
# Inspect one study-level result: MountainSort4 on the longest (4800 s) static 16-channel study.
A.get_study_sorting_result_object(
    study_name='LONG_STATIC_4800s_16c',
    sorter_name='MountainSort4',
)

In [None]:
# Sorter name of the first sorting result in the first study-level analysis result (raw dict access).
first_study_result = a['StudyAnalysisResults'][0]
first_study_result['sortingResults'][0]['sorterName']