# View main analysis

This notebook provides a view into a snapshot of the SpikeForest analysis. A snapshot URL may be obtained from the "Archive" section of the website, or a snapshot may be created offline using the spikeforest Python package.

Because this notebook is checked into the git repo, it is a good idea to make a working copy before running or modifying it. If you do modify and push back to the repo, please clear the outputs first.

In [None]:
# Imports

from mountaintools import client as mt
from spikeforest import SFMdaRecordingExtractor, SFMdaSortingExtractor
# from spikeforest import MainAnalysisView
import spikeforestwidgets as SFW
import vdomr as vd
import numpy as np
import pandas as pd

In [None]:
# Load the analysis snapshot object.
# A snapshot URL can be obtained from the "Archive" section of the website;
# a path to a local file can be used instead.
snapshot_path = 'sha1://d0eb11774305a926e75ad232e4a6b4a54ffed4b2/analysis.json'

# Configure mountaintools to download from the public spikeforest kachery.
# This has to happen before loadObject is asked to resolve the snapshot path.
mt.configDownloadFrom('spikeforest.public')
# A is the loaded snapshot object; it is what MainAnalysisView is built from below.
A = mt.loadObject(path=snapshot_path)

In [None]:
class MainAnalysisView():
    """Tabular view over a SpikeForest analysis snapshot.

    The snapshot is a dict that is read through the keys 'Sorters',
    'StudySets', 'SortingResults' and 'StudyAnalysisResults'.  The view
    produces one row per study and one column per sorter, where each cell
    is either an average of a per-unit metric or a count of units whose
    metric reaches a threshold.
    """

    # Maps the public metric name to the key holding the per-unit values
    # inside a study-level sorting result.
    _METRIC_KEYS = {
        'accuracy': 'accuracies',
        'precision': 'precisions',
        'recall': 'recalls',
    }

    def __init__(self, obj: dict):
        """Wrap the snapshot dict *obj* with default display settings."""
        self._obj = obj
        self._metric = 'accuracy'  # accuracy, precision, recall
        self._mode = 'average'  # average or count
        self._snr_threshold = 8
        self._metric_threshold = 0.8

    def mainTable(self):
        """Return a DataFrame with one row per study and one column per sorter.

        Cells are strings: the value rounded to two decimals, followed by
        '*' when the sorter has at least one missing (None) accuracy for
        that study, or '' when the snapshot holds no sorting results for
        that study/sorter pair.  Sorters that never appear for a study are
        left as NaN.

        Raises:
            AssertionError: if a study has no entry in 'StudyAnalysisResults'.
            Exception: if the configured metric is not a known one.
        """
        A = self._obj
        # Column order follows the snapshot's sorter list.
        columns = ['study'] + [sorter['name'] for sorter in A['Sorters']]
        table_rows = []
        for sset in A['StudySets']:
            for study in sset['studies']:
                study_name = study['name']
                sar = self._find_study_analysis_result(study_name)
                assert sar, 'No analysis result for study: {}'.format(study_name)
                row = dict(study=study_name)
                for sr in sar['sortingResults']:
                    sorter_name = sr['sorterName']
                    row[sorter_name] = self._format_cell(sar, sr, study_name, sorter_name)
                table_rows.append(row)
        # Passing the columns to the constructor (rather than selecting them
        # afterwards) keeps this working for an empty table and for sorters
        # that have no results in any study.
        return pd.DataFrame(table_rows, columns=columns)

    def setMetric(self, metric: str):
        """Select the metric to display: 'accuracy', 'precision' or 'recall'."""
        self._metric = metric

    def setMode(self, mode: str):
        """Select the cell mode: 'average' or 'count'."""
        self._mode = mode

    def _format_cell(self, sar: dict, sr: dict, study_name: str, sorter_name: str):
        """Return the display string for one study/sorter cell."""
        if not self._find_sorting_results(study_name, sorter_name):
            return ''
        if self._mode == 'average':
            val = self._compute_average(sar, sr)
        elif self._mode == 'count':
            val = self._compute_count(sar, sr)
        else:
            # Unknown modes fall back to zero rather than failing.
            val = 0
        suffix = '*' if self._count_missing(sr) > 0 else ''
        return '{}{}'.format(round(val, 2), suffix)

    def _metric_values(self, sorting_result: dict):
        """Return the per-unit values of the selected metric.

        Raises:
            Exception: if the selected metric is not a known one.
        """
        key = self._METRIC_KEYS.get(self._metric)
        if key is None:
            raise Exception('Invalid metric: {}'.format(self._metric))
        return sorting_result[key]

    def _compute_average(self, sar: dict, sorting_result: dict):
        """Average the selected metric over units at or above the SNR threshold.

        Units whose SNR or metric value is None are ignored.  Returns 0
        when no unit qualifies.
        """
        snr_threshold = self._snr_threshold
        snrs = sar['trueSnrs']
        selected = []
        for i, value in enumerate(self._metric_values(sorting_result)):
            snr = snrs[i]
            if snr is None or snr < snr_threshold:
                continue
            if value is not None:
                selected.append(value)
        if selected:
            return np.mean(selected)
        return 0

    def _compute_count(self, sar: dict, sorting_result: dict):
        """Count units whose selected metric is at or above the metric threshold.

        Units with a None value are not counted.  *sar* is accepted for
        symmetry with _compute_average and is not used.
        """
        metric_threshold = self._metric_threshold
        return sum(
            1
            for value in self._metric_values(sorting_result)
            if value is not None and value >= metric_threshold
        )

    def _find_sorting_results(self, study_name: str, sorter_name: str):
        """Return the snapshot's sorting results for one study/sorter pair."""
        return [
            SR for SR in self._obj['SortingResults']
            if SR['studyName'] == study_name and SR['sorterName'] == sorter_name
        ]

    def _find_study_analysis_result(self, study_name: str):
        """Return the first study analysis result for *study_name*, or None."""
        for x in self._obj['StudyAnalysisResults']:
            if x['studyName'] == study_name:
                return x
        return None

    def _count_missing(self, sorting_result: dict):
        """Return how many accuracy entries of *sorting_result* are None."""
        return sum(1 for x in sorting_result['accuracies'] if x is None)

In [None]:
# Build a view over the loaded snapshot A and display the per-study table.
# Mode is 'average' or 'count'; metric is 'accuracy', 'precision' or 'recall'.
V = MainAnalysisView(A)
V.setMode('average')
V.setMetric('accuracy')
display(V.mainTable())

In [None]:
# James: here is how to superimpose an updated analysis on top of an existing one:
# results for sorters named in the update are kept as they are in the update,
# and results for every other sorter are carried over from the website snapshot.

from_website = mt.loadObject(path='sha1://d0eb11774305a926e75ad232e4a6b4a54ffed4b2/analysis.json')
update = mt.loadObject(path='key://pairio/spikeforest/test1.json')
sorter_names_in_update = [s['name'] for s in update['Sorters']]

# Index the website's study analysis results by study name.  setdefault keeps
# the first entry for a name, matching a first-match linear search.
website_sar_by_study = {}
for sarW in from_website['StudyAnalysisResults']:
    website_sar_by_study.setdefault(sarW['studyName'], sarW)

for sr in from_website['SortingResults']:
    if sr['sorterName'] not in sorter_names_in_update:
        update['SortingResults'].append(sr)
for sar in update['StudyAnalysisResults']:
    sarW = website_sar_by_study.get(sar['studyName'])
    if sarW is None:
        # The study exists only in the update, so there is nothing to carry over.
        continue
    for sr in sarW['sortingResults']:
        if sr['sorterName'] not in sorter_names_in_update:
            sar['sortingResults'].append(sr)
for sorter in from_website['Sorters']:
    if sorter['name'] not in sorter_names_in_update:
        update['Sorters'].append(sorter)

# From here on the merged object is what the view cells operate on.
A = update

In [None]:
# OLD INFO

# An example command for James:
# > ./assemble_website_data.py --output_ids hybrid_janelia_irc,paired_kampff_irc --dest_key_path output.json