In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import spikeinterface.full as si
import numpy as np
from pathlib import Path
import time
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import kachery_cloud as kcl
import figurl


%matplotlib widget

In [None]:
# Parallel-processing settings shared by the chunked spikeinterface calls below.
n_jobs = 10
job_kwargs = {"n_jobs": n_jobs, "chunk_duration": "1s", "progress_bar": True}

In [None]:
# load recording and sorting
# (the "..." paths are placeholders; fill them in before running)
recording_path = "..."
sorting_path = "..."
rec = si.load_extractor(recording_path)
sort = si.load_extractor(sorting_path)

In [None]:
# Extract waveforms into "wf_folder"; load_if_exists=True reuses a previous extraction.
we = si.extract_waveforms(
    rec,
    sort,
    folder="wf_folder",
    load_if_exists=True,
    **job_kwargs,
)

In [None]:
# Radius-based channel sparsity (radius_um=50); used to restrict templates and metrics.
sparsity = si.get_template_channel_sparsity(
    we,
    method="radius",
    radius_um=50,
)

In [None]:
# templates: per-unit mean and standard-deviation waveforms on the sparse channels
templates = {
    unit: {
        "mean": we.get_template(unit, mode="average", sparsity=sparsity),
        "std": we.get_template(unit, mode="std", sparsity=sparsity),
    }
    for unit in sort.unit_ids
}

In [None]:
# ccgs: symmetrized auto-/cross-correlograms for all unit pairs, 0.5 ms bins
ccgs, bins = si.compute_correlograms(
    sorting=sort,
    bin_ms=0.5,
    symmetrize=True,
)

In [None]:
# spike localization (monopolar triangulation), grouped per unit
# The radius (um) is passed via method_kwargs. A misspelled key ("raidus") is
# not a localization parameter, so the intended radius would not be applied.
# NOTE(review): the key name depends on the spikeinterface version
# (`local_radius_um` in 0.9x releases, `radius_um` in later ones) — TODO confirm.
locs = si.compute_spike_locations(we, method="monopolar_triangulation",
                                  method_kwargs={"local_radius_um": 100},
                                  outputs="by_unit", load_if_exists=True, **job_kwargs)

In [None]:
# spike amplitudes, grouped per unit (outputs="by_unit")
amplitudes = si.compute_spike_amplitudes(
    we,
    outputs="by_unit",
    load_if_exists=True,
    **job_kwargs,
)

In [None]:
# similarity between unit templates (fed to the UnitSimilarityMatrix view)
similarity = si.compute_template_similarity(
    we
)

In [None]:
# template metrics, computed with an upsampling factor of 10
tm = si.calculate_template_metrics(
    we,
    upsampling_factor=10,
)

In [None]:
# quality metrics: the default list plus "nearest_neighbor"
# (the full PCA metric set from si.get_quality_pca_metric_list() is not included).
# Build a new list instead of appending in place, so re-running this cell never
# mutates whatever list object the library returned.
metric_names = si.get_quality_metric_list() + ["nearest_neighbor"]

# compute PC
# Principal components are computed here presumably because "nearest_neighbor"
# relies on them — TODO confirm. `pc` itself is not used directly below.
pc = si.compute_principal_components(we, n_jobs=n_jobs, mode="by_channel_local", progress_bar=True,
                                     load_if_exists=True)

In [None]:
# Quality metrics on the sparse channels; load_if_exists=True reuses a previous result.
qm = si.compute_quality_metrics(
    we,
    metric_names=metric_names,
    sparsity=sparsity,
    n_jobs=n_jobs,
    verbose=True,
    progress_bar=True,
    load_if_exists=True,
)

In [None]:
# merge metrics: quality and template metrics side by side, inner-joined on the unit index
metrics = pd.merge(qm, tm, left_index=True, right_index=True)
metrics

In [None]:
# unit locations via monopolar triangulation, returned as a dict keyed by unit
unit_locations = si.localize_units(
    we,
    method="monopolar_triangulation",
    output="dict",
)

# Package data for SortingView

### Units table

```
type UTColumn = {
    key: string
    label: string
    dtype: string
}

type UTRow = {
    unitId: number
    values: {[key: string]: any}
}

export type UnitsTableViewData = {
    type: 'UnitsTable'
    columns: UTColumn[]
    rows: UTRow[]
}
```

In [None]:
# Collects every SortingView view-data payload, keyed by view-data type name.
sv_dict = {}

In [None]:
# Units table: one row per unit, with no extra value columns yet.
unit_rows = [dict(unitId=unit_id, values={}) for unit_id in sort.unit_ids]
unit_columns = []

units_table_view_data = {"type": "UnitsTable", "rows": unit_rows, "columns": unit_columns}

sv_dict["UnitsTableViewData"] = units_table_view_data

### Templates

```
type AverageWaveformData = {
    unitId: number | string
    channelIds: (number | string)[]
    waveform: number[][]
    waveformStdDev?: number[][]
}

export type AverageWaveformsViewData = {
    type: 'AverageWaveforms'
    averageWaveforms: AverageWaveformData[]
    samplingFrequency: number
    noiseLevel: number
    channelLocations?: {[key: string]: number[]}
}
```

In [None]:
# Average waveforms (templates) restricted to each unit's sparse channels.
# NOTE(review): `waveform` / `waveformStdDev` are sent as returned by
# `we.get_template` (presumably samples x channels); confirm the orientation
# spikesortingview expects.
average_waveforms_data = [{"unitId": u, "channelIds": sparsity[u], "waveform": t["mean"],
                           "waveformStdDev": t["std"]} for u, t in templates.items()]

# `channelLocations` is typed as {[key: string]: number[]}: ONE mapping from
# channel id to location. Keys are strings because JSON object keys must be.
locations = rec.get_channel_locations()
channel_locations = {str(ch_id): loc.astype("float32")
                     for ch_id, loc in zip(rec.channel_ids, locations)}

# noiseLevel=1 is a fixed placeholder, not derived from the data.
average_waveforms_view_data = dict(type="AverageWaveforms",
                                   averageWaveforms=average_waveforms_data,
                                   samplingFrequency=rec.get_sampling_frequency(),
                                   noiseLevel=1, channelLocations=channel_locations)
sv_dict.update(dict(AverageWaveformsViewData=average_waveforms_view_data))

### Correlograms

```
type AutocorrelogramData = {
    unitId: number
    binEdgesSec: number[]
    binCounts: number[]
}

export type AutocorrelogramsViewData = {
    type: 'Autocorrelograms'
    autocorrelograms: AutocorrelogramData[]
}


type CrosscorrelogramData = {
    unitId1: number
    unitId2: number
    binEdgesSec: number[]
    binCounts: number[]
}

export type CrosscorrelogramsViewData = {
    type: 'Crosscorrelograms'
    crosscorrelograms: CrosscorrelogramData[]
}
```

In [None]:
# Split the correlogram matrix into autocorrelograms (diagonal, i == j) and
# crosscorrelograms (upper triangle, i < j), preserving the original entry order.
auto_correlogram_data = []
cross_correlogram_data = []

# `bins` presumably holds bin edges in ms (bin_ms above); convert to seconds once
# instead of once per unit pair.
bin_edges_sec = (bins / 1000.).astype("float32")

num_units = ccgs.shape[0]
for i in range(num_units):
    auto_correlogram_data.append(dict(unitId=sort.unit_ids[i],
                                      binEdgesSec=bin_edges_sec,
                                      binCounts=ccgs[i, i].astype("int32")))
    for j in range(i + 1, num_units):
        cross_correlogram_data.append(dict(unitId1=sort.unit_ids[i],
                                           unitId2=sort.unit_ids[j],
                                           binEdgesSec=bin_edges_sec,
                                           binCounts=ccgs[i, j].astype("int32")))

autocorrelograms_view_data = dict(type="Autocorrelograms", autocorrelograms=auto_correlogram_data)
crosscorrelograms_view_data = dict(type="Crosscorrelograms", crosscorrelograms=cross_correlogram_data)

sv_dict.update(dict(AutocorrelogramsViewData=autocorrelograms_view_data))
sv_dict.update(dict(CrosscorrelogramsViewData=crosscorrelograms_view_data))

### Spike Amplitudes

```
type SAUnitData = {
    unitId: number
    spikeTimesSec: number[]
    spikeAmplitudes: number[]
}

export type SpikeAmplitudesViewData = {
    type: 'SpikeAmplitudes'
    startTimeSec: number
    endTimeSec: number
    units: SAUnitData[]
}
```

In [None]:
# Spike amplitudes per unit. `amplitudes[0]` presumably selects the first segment
# of the by-unit output, so spike trains and the time range are taken from
# segment 0 as well. This keeps the arrays aligned and avoids the default-segment
# lookup, which fails on multi-segment sortings.
sorting_fs = sort.get_sampling_frequency()
sa_unit_data = [{"unitId": u,
                 "spikeTimesSec": (sort.get_unit_spike_train(u, segment_index=0) / sorting_fs).astype("float32"),
                 "spikeAmplitudes": amps} for u, amps in amplitudes[0].items()]
spike_amplitudes_view_data = dict(type="SpikeAmplitudes",
                                  startTimeSec=0,
                                  endTimeSec=rec.get_num_samples(segment_index=0) / rec.get_sampling_frequency(),
                                  units=sa_unit_data)
sv_dict.update(dict(SpikeAmplitudesViewData=spike_amplitudes_view_data))

### Spike Locations

```
type SLUnitData = {
    unitId: number
    xLocations: number[]
    yLocations: number[]
    zLocations?: number[]
    spikeTimesSec?: number[] // optionally sync with other view
}

type SpikeLocationsViewData = {
    type: 'SpikeLocations'
    units: SLUnitData[]
    channelLocations?: {[key: string]: number[]}
}
```

In [None]:
# Per-spike x/y/z locations per unit. `locs[0]` presumably selects the first
# segment, so spike times are taken from segment 0 too, keeping the arrays aligned.
sorting_fs = sort.get_sampling_frequency()
sl_unit_data = [{"unitId": u,
                 "xLocations": loc["x"].astype("float32"),
                 "yLocations": loc["y"].astype("float32"),
                 "zLocations": loc["z"].astype("float32"),
                 "spikeTimesSec": (sort.get_unit_spike_train(u, segment_index=0) / sorting_fs).astype("float32")}
                for u, loc in locs[0].items()]
spike_locations_view_data = dict(type="SpikeLocations",
                                 startTimeSec=0,
                                 endTimeSec=rec.get_num_samples(segment_index=0) / rec.get_sampling_frequency(),
                                 units=sl_unit_data,
                                 # optional SpikeLocationsViewData field; reuses the channel
                                 # locations built in the AverageWaveforms cell
                                 channelLocations=channel_locations)
sv_dict.update(dict(SpikeLocationsViewData=spike_locations_view_data))

### Similarity

```
type UnitSimilarityMatrixViewData = {
    type: 'UnitSimilarityMatrix'
    unitIds: (number | string)[]
    similarityScores: number[][] // numpy matrix
}
```

In [None]:
# Unit-by-unit template similarity matrix. Unit ids are cast to int32
# (presumably for typed-array serialization in the viewer — TODO confirm).
unit_similarity_matrix_view_data = {
    "type": "UnitSimilarityMatrix",
    "unitIds": sort.unit_ids.astype("int32"),
    "similarityScores": similarity,
}
sv_dict["UnitSimilarityMatrixViewData"] = unit_similarity_matrix_view_data

### Unit locations

```
type ULUnitData = {
    unitId: number
    location: [number, number]
}

type UnitLocationsViewData = {
    type: 'UnitLocations'
    units: ULUnitData[]
    channelLocations?: {[key: string]: number[]}
}
```


In [None]:
# UnitLocations expects a 2D [x, y] location per unit (ULUnitData.location).
# The monopolar-triangulation output may carry extra coordinates (e.g. z),
# so only the first two are kept.
ul_unit_data = [dict(unitId=u, location=np.asarray(loc)[:2].astype("float32"))
                for u, loc in unit_locations.items()]

unit_locations_view_data = dict(type="UnitLocations", 
                                units=ul_unit_data, 
                                channelLocations=channel_locations)
sv_dict.update(dict(UnitLocationsViewData=unit_locations_view_data))

### Metrics

```
type UMMetric = {
    name: string
    metricType: 'quality' | 'template' | string
    description: string
}

type UMUnit = {
    unitId: number
    values: {[name: string]: any} // key corresponds to the name
}

export type UnitMetricsViewData = {
    type: 'UnitMetrics'
    metrics: UMMetric[]
    units: UMUnit[]
}
```

In [None]:
# Metrics view: one UMMetric per metric column and one UMUnit row per unit.
template_metric_names = si.get_template_metric_names()

# Columns excluded from the view.
skip_metrics = ['isi_violations_rate', 'isi_violations_count']
shown_metrics = [metric for metric in metrics.columns if metric not in skip_metrics]

# Template metrics are tagged "template"; everything else counts as "quality".
um_metrics = [dict(name=metric,
                   metricType="template" if metric in template_metric_names else "quality",
                   description="")
              for metric in shown_metrics]

# NaN is not valid JSON (browser JSON.parse rejects it), so missing metric values
# are sent as None (-> null). Presumably figurl.serialize_data does not convert
# NaN itself — TODO confirm.
um_units = []
for unit_id, row in metrics.iterrows():
    values = {metric: (None if pd.isna(row[metric]) else row[metric])
              for metric in shown_metrics}
    um_units.append(dict(unitId=unit_id, values=values))

unit_metrics_view_data = dict(type="UnitMetrics",
                              metrics=um_metrics, units=um_units)

sv_dict.update(dict(UnitMetricsViewData=unit_metrics_view_data))

### Save to kachery-cloud

In [None]:
def _upload_data_and_return_uri(data):
    """Serialize `data` with figurl and store it in kachery-cloud.

    Returns whatever ``kcl.store_json`` returns for the stored payload
    (used as a view's ``dataUri``).
    """
    serialized = figurl.serialize_data(data)
    return kcl.store_json(serialized)

In [None]:
# Views uploaded to kachery-cloud: (viewId, view type, key in sv_dict).
# Uploads go through `_upload_data_and_return_uri` instead of repeating the
# serialize-and-store expression for every view.
# NOTE(review): 'aw' is uploaded but not placed in the layout below — confirm
# whether it should be added to one of the boxes.
view_specs = [
    ('ut', 'UnitsTable', 'UnitsTableViewData'),
    ('aw', 'AverageWaveforms', 'AverageWaveformsViewData'),
    ('sa', 'SpikeAmplitudes', 'SpikeAmplitudesViewData'),
    ('ac', 'Autocorrelograms', 'AutocorrelogramsViewData'),
]

data = {
        'type': 'SortingLayout',
        'layout': {
            'type': 'Box',
            'direction': 'horizontal',
            'items': [
                {'type': 'View', 'viewId': 'ut'},
                {
                    'type': 'Box',
                    'direction': 'vertical',
                    'items': [
                        {'type': 'View', 'viewId': 'sa'},
                        {'type': 'View', 'viewId': 'ac'}
                    ],
                    'itemProperties': [
                        {},
                        {}
                    ]
                }
            ],
            'itemProperties': [
                {'maxSize': 150},
                {}
            ]
        },
        'views': [
            {
                'viewId': view_id,
                'type': view_type,
                'dataUri': _upload_data_and_return_uri(sv_dict[sv_key])
            }
            for view_id, view_type, sv_key in view_specs
        ]
    }

In [None]:
# Build the figurl figure from the layout above and print its shareable URL.
F = figurl.Figure(
    data=data,
    view_url='gs://figurl/spikesortingview-5',
)
url = F.url(label='Alessio test data fixed')
print(url)