In [None]:
import os
import kachery as ka
import hither2 as hi
from spikeforest2_utils import AutoSortingExtractor, MdaRecordingExtractor
import spikeextractors as se
from compute_units_info import compute_units_info
from compare_with_truth import compare_with_truth

In [None]:
# Load the analysis object from kachery by its content address.
# NOTE(review): `fr` presumably names the kachery source to fetch from -- confirm.
a = ka.load_object(
    'sha1://94624793f05df0187450e97df1d71119ef78c9ae/analysis.json',
    fr='default_readwrite'
)
# Fail fast if nothing was loaded; the SFAnalysis wrapper created further down is built from `a`.
assert a is not None

In [None]:
# hither objects used by the job-submitting cell further down.
# The Mongo connection string is read from the environment rather than hard-coded;
# os.environ[...] raises KeyError if LABBOX_EPHYS_MONGO_URI is not set.
database = hi.Database(mongo_url=os.environ['LABBOX_EPHYS_MONGO_URI'], database='labbox')
# Remote job handler bound to the 'spikeforest_flatiron' compute resource id.
job_handler = hi.RemoteJobHandler(
    database=database,
    compute_resource_id='spikeforest_flatiron'
)
# Job cache constructed on the same database object.
job_cache = hi.JobCache(database=database)

In [None]:
class SFAnalysis:
    """Lookup helper around a SpikeForest analysis object.

    The wrapped object is a dict from which the methods read:

    * ``'StudySets'``: list of study sets, each with a ``'studies'`` list; every
      study has a ``'name'`` and a ``'recordings'`` list whose entries carry
      ``'name'``, ``'directory'`` and ``'firingsTrue'``.
    * ``'SortingResults'``: list of entries carrying ``'studyName'``,
      ``'recordingName'``, ``'sorterName'`` and ``'firings'``.
    """

    def __init__(self, path_or_dict):
        """Wrap an analysis object.

        Parameters
        ----------
        path_or_dict : str or dict
            A str is passed to ``ka.load_object``; a dict is used as-is.

        Raises
        ------
        AssertionError
            If a path was given and ``ka.load_object`` returned None.
        TypeError
            If the argument is neither a str nor a dict.
        """
        if isinstance(path_or_dict, str):
            self._obj = ka.load_object(path_or_dict)
            assert self._obj is not None, 'Unable to load analysis object'
        elif isinstance(path_or_dict, dict):
            self._obj = path_or_dict
        else:
            # TypeError is still an Exception, so existing `except Exception` callers keep working.
            raise TypeError('Invalid type for path_or_dict. Must be string or dict.')

    @staticmethod
    def _exactly_one(matches, kind, label):
        """Return the only element of `matches`; assert if there are zero or several."""
        assert len(matches) >= 1, f'Unable to get {kind} object {label}'
        assert len(matches) <= 1, f'Multiple {kind} objects found for {label}'
        return matches[0]

    def get_study_object(self, study_name):
        """Return the study dict named `study_name`, searching every study set.

        Raises AssertionError if no study or more than one study has that name.
        """
        matches = [
            study
            for study_set in self._obj['StudySets']
            for study in study_set['studies']
            if study['name'] == study_name
        ]
        return self._exactly_one(matches, 'study', study_name)

    def get_study_names(self):
        """Return the names of all studies, in study-set order."""
        return [
            study['name']
            for study_set in self._obj['StudySets']
            for study in study_set['studies']
        ]

    def get_recording_names(self, study_name):
        """Return the names of the recordings listed under `study_name`."""
        study = self.get_study_object(study_name)
        return [recording['name'] for recording in study['recordings']]

    def get_recording_object(self, study_name, recording_name):
        """Return the recording dict `recording_name` within `study_name`.

        Raises AssertionError if the study or the recording is missing or ambiguous.
        """
        study = self.get_study_object(study_name)
        matches = [r for r in study['recordings'] if r['name'] == recording_name]
        return self._exactly_one(matches, 'recording', f'{study_name} {recording_name}')

    def get_recording_path(self, study_name, recording_name):
        """Return the recording's ``'directory'`` entry."""
        return self.get_recording_object(study_name, recording_name)['directory']

    def get_recording_extractor(self, study_name, recording_name, download=False):
        """Return an ``MdaRecordingExtractor`` for the recording's directory.

        `download` is passed through to ``MdaRecordingExtractor``.
        """
        recording = self.get_recording_object(study_name, recording_name)
        return MdaRecordingExtractor(recording_directory=recording['directory'], download=download)

    def get_sorting_result_object(self, study_name, recording_name, sorter_name):
        """Return the sorting-result dict for one (study, recording, sorter) triple.

        Raises AssertionError if there is no such entry or more than one.
        """
        matches = [
            sr for sr in self._obj['SortingResults']
            if sr['studyName'] == study_name
            and sr['recordingName'] == recording_name
            and sr['sorterName'] == sorter_name
        ]
        return self._exactly_one(matches, 'sorting result', f'{study_name} {recording_name} {sorter_name}')

    def get_sorting_result_objects(self, study_names, sorter_names):
        """Return every sorting-result dict whose study and sorter are in the given collections."""
        return [
            sr for sr in self._obj['SortingResults']
            if sr['studyName'] in study_names and sr['sorterName'] in sorter_names
        ]

    def get_sorting_result_extractor(self, study_name, recording_name, sorter_name):
        """Return an ``AutoSortingExtractor`` for the sorter's ``'firings'`` entry.

        The sampling frequency is taken from the recording extractor.
        """
        recording = self.get_recording_extractor(study_name, recording_name)
        result = self.get_sorting_result_object(study_name, recording_name, sorter_name)
        return AutoSortingExtractor(result['firings'], samplerate=recording.get_sampling_frequency())

    def get_sorting_result_path(self, study_name, recording_name, sorter_name):
        """Return the sorter's ``'firings'`` entry.

        Only the analysis dict is consulted; no recording extractor is built.
        """
        return self.get_sorting_result_object(study_name, recording_name, sorter_name)['firings']

    def get_sorting_true_extractor(self, study_name, recording_name):
        """Return an ``AutoSortingExtractor`` for the recording's ``'firingsTrue'`` entry.

        The sampling frequency is taken from the recording extractor.
        """
        recording = self.get_recording_extractor(study_name, recording_name)
        recording_obj = self.get_recording_object(study_name, recording_name)
        return AutoSortingExtractor(recording_obj['firingsTrue'], samplerate=recording.get_sampling_frequency())

    def get_sorting_true_path(self, study_name, recording_name):
        """Return the recording's ``'firingsTrue'`` entry.

        Only the analysis dict is consulted; no recording extractor is built.
        """
        return self.get_recording_object(study_name, recording_name)['firingsTrue']

In [None]:
# Wrap the analysis object loaded above; SFAnalysis accepts either a kachery path (str) or a dict.
A = SFAnalysis(a)

In [None]:
# Studies and sorters to evaluate; the plotting cell further down iterates over the same lists.
study_names = ['hybrid_static_tetrode', 'hybrid_static_siprobe']
sorter_names = ['MountainSort4', 'SpykingCircus', 'IronClust', 'KiloSort', 'KiloSort2']
# Flat per-unit rows (one dict per unit), filled at the end of this cell.
true_units = []
sorted_units = []


def _index_by_unit_id(records):
    """Return a dict mapping each record's 'unit_id' to the record itself."""
    return {record['unit_id']: record for record in records}


def _collect_unit_rows(results, sorting_key, swap_fp_fn):
    """Wait for the jobs in `results` and flatten them into one dict per unit.

    Keeping these loops inside a function means their temporaries stay local and
    cannot overwrite notebook-level names such as `a` (the loaded analysis object).

    Parameters
    ----------
    results : list of dict
        Entries built by the submission loop below. Each holds 'study',
        'recording', 'sorter', the two sorting extractors under 'S' and
        'S_true', and job handles under 'units_info' and 'comparison'.
    sorting_key : str
        'S' or 'S_true': which extractor supplies the unit ids to iterate over.
    swap_fp_fn : bool
        If True, the false-positive and false-negative fields of the comparison
        are exchanged. This is used for sorted units, whose comparison was run
        with the roles of truth and sorting intentionally reversed.

    Returns
    -------
    list of dict
        One row per unit id, with a fixed key order (it becomes the DataFrame
        column order).
    """
    if swap_fp_fn:
        num_fn_key, num_fp_key, f_n_key, f_p_key = 'num_false_positives', 'num_false_negatives', 'f_p', 'f_n'
    else:
        num_fn_key, num_fp_key, f_n_key, f_p_key = 'num_false_negatives', 'num_false_positives', 'f_n', 'f_p'

    rows = []
    for result in results:
        info_by_id = _index_by_unit_id(result['units_info'].wait())
        # The comparison is a dict of records; only the values are needed here.
        comparison_by_id = _index_by_unit_id(result['comparison'].wait().values())
        for uid in result[sorting_key].get_unit_ids():
            info = info_by_id[uid]
            cmp_rec = comparison_by_id[uid]
            rows.append(dict(
                study=result['study'],
                recording=result['recording'],
                sorter=result['sorter'],
                unit_id=uid,
                snr=info['snr'],
                peak_channel=info['peak_channel'],
                num_events=info['num_events'],
                firing_rate=info['firing_rate'],
                accuracy=float(cmp_rec['accuracy']),
                best_unit=cmp_rec['best_unit'],
                matched_unit=cmp_rec['matched_unit'],
                num_matches=cmp_rec['num_matches'],
                num_false_negatives=cmp_rec[num_fn_key],
                num_false_positives=cmp_rec[num_fp_key],
                f_n=float(cmp_rec[f_n_key]),
                f_p=float(cmp_rec[f_p_key])
            ))
    return rows


with hi.config(job_cache=job_cache, job_handler=job_handler, container=True), ka.config(fr='default_readwrite'):
    sorted_unit_results = []
    true_unit_results = []
    # Phase 1: create every job first; results are only waited on in phase 2.
    for study_name in study_names:
        for sorter_name in sorter_names:
            for recname in A.get_recording_names(study_name):
                print(f'{study_name} {recname}')
                # Resolve each path once per iteration instead of once per use.
                recording_path = A.get_recording_path(study_name, recname)
                sorting_true_path = A.get_sorting_true_path(study_name, recname)
                sorting_path = A.get_sorting_result_path(study_name, recname, sorter_name)
                S_true = A.get_sorting_true_extractor(study_name, recording_name=recname)
                S = A.get_sorting_result_extractor(study_name, recording_name=recname, sorter_name=sorter_name)

                # Do we need to do this?
                ka.store_file(sorting_path, to='default_readwrite')

                sorted_units_info = compute_units_info.run(
                    recording_path=recording_path,
                    sorting_path=sorting_path
                ).set(label=f'Compute sorted units info: {study_name} {recname} {sorter_name}')
                # we are reversing true vs. sorted here, intentionally
                comparison = compare_with_truth.run(
                    sorting_path=sorting_true_path,
                    sorting_true_path=sorting_path
                ).set(label=f'Compare with truth (*): {study_name} {recname} {sorter_name}')
                sorted_unit_results.append(dict(
                    study=study_name,
                    recording=recname,
                    sorter=sorter_name,
                    S=S,
                    S_true=S_true,
                    units_info=sorted_units_info,
                    comparison=comparison
                ))

                true_units_info = compute_units_info.run(
                    recording_path=recording_path,
                    sorting_path=sorting_true_path
                ).set(label=f'Compute true units info: {study_name} {recname}')
                comparison = compare_with_truth.run(
                    sorting_path=sorting_path,
                    sorting_true_path=sorting_true_path
                ).set(label=f'Compare with truth: {study_name} {recname} {sorter_name}')
                true_unit_results.append(dict(
                    study=study_name,
                    recording=recname,
                    sorter=sorter_name,
                    S=S,
                    S_true=S_true,
                    units_info=true_units_info,
                    comparison=comparison
                ))

    # Phase 2: wait for the jobs and build the per-unit rows.
    # Sorted units use the reversed comparison, so fp/fn are switched back.
    sorted_units.extend(_collect_unit_rows(sorted_unit_results, 'S', swap_fp_fn=True))
    true_units.extend(_collect_unit_rows(true_unit_results, 'S_true', swap_fp_fn=False))

In [None]:
import pandas as pd
# Tabulate the per-unit dicts collected above: one DataFrame row per unit,
# one column per dict key.
df_sorted_units = pd.DataFrame(data=sorted_units)
df_true_units = pd.DataFrame(data=true_units)

In [None]:
# Render both tables, sorted units first, then true units.
display(df_sorted_units, df_true_units)

In [None]:
def store_dataframe(df, basename):
    """Write `df` to a CSV file in a temporary directory and store it with kachery.

    Parameters
    ----------
    df : pandas.DataFrame
        Table to serialise with ``DataFrame.to_csv`` (default options).
    basename : str
        File name to use inside the temporary directory.

    Returns
    -------
    Whatever ``ka.store_file`` returns for the written file.
    """
    with hi.TemporaryDirectory() as scratch_dir:
        csv_path = '/'.join((scratch_dir, basename))
        df.to_csv(csv_path)
        # Store while the temporary directory still exists.
        stored = ka.store_file(csv_path)
    return stored

# Store both tables as CSV under the 'default_readwrite' kachery config and
# print what store_dataframe returned for each, one per line.
with ka.config(to='default_readwrite'):
    csv_sorted_units = store_dataframe(df_sorted_units, basename='sorted_units.csv')
    csv_true_units = store_dataframe(df_true_units, basename='true_units.csv')
    print(csv_sorted_units, csv_true_units, sep='\n')

In [None]:
from IPython.core.display import display, HTML
import altair as alt
def _accuracy_scatter(data, x_field, title):
    """Return a point chart of 'accuracy' (y) against `x_field` (x) for `data`."""
    return alt.Chart(data, title=[title]).mark_point().encode(x=x_field, y='accuracy')


# One 2x2 grid per (study, sorter): sorted units on the top row, true units below;
# firing rate on the left, SNR on the right.
for study in study_names:
    for sorter in sorter_names:
        display(HTML(f'<h2>{study} ({sorter})</h2>'))

        sorted_rows = df_sorted_units[
            (df_sorted_units['sorter'] == sorter) & (df_sorted_units['study'] == study)
        ]
        true_rows = df_true_units[
            (df_true_units['sorter'] == sorter) & (df_true_units['study'] == study)
        ]

        sorted_charts = alt.hconcat(
            # Only this chart is restricted to units with snr >= 5.
            _accuracy_scatter(
                sorted_rows[(sorted_rows['snr'] >= 5)],
                'firing_rate',
                'Accuracy vs. firing rate for sorted units (snr>=5)'
            ),
            _accuracy_scatter(sorted_rows, 'snr', 'Accuracy vs. SNR for sorted units')
        )
        true_charts = alt.hconcat(
            _accuracy_scatter(true_rows, 'firing_rate', 'Accuracy vs. firing rate for true units'),
            _accuracy_scatter(true_rows, 'snr', 'Accuracy vs. SNR for true units')
        )
        alt.vconcat(sorted_charts, true_charts).display()