In [None]:
# NOTE(review): these look like the `%load_ext autoreload` / `%autoreload`
# IPython magics with the `%` (and the underscore) stripped by the export;
# as written, the first line is not valid Python or IPython — confirm.
load ext autoreload
autoreload

In [None]:
import numpy as np
from typing import Dict, List, Set

import neuropixels_data_sep_2020 as nd
import spikeextractors as se
import numpy as np
from scipy.stats import wasserstein_distance
from labbox_ephys import prepare_snippets_h5
import kachery as ka
import h5py



In [None]:
def find_channel_neighborhoods(recording, max_dist) -> Dict[int, List[int]]:
    """Map every channel to the channels lying within ``max_dist`` of it.

    Distances are Euclidean, computed from each channel's ``'location'``
    property. The comparison is strict (``< max_dist``), so a channel is in
    its own neighborhood exactly when ``max_dist > 0``.

    Args:
        recording: extractor exposing ``get_channel_ids()`` and
            ``get_channel_property(channel_id=..., property_name=...)``.
        max_dist: neighborhood radius, in the units of the channel locations.

    Returns:
        Dict from channel id (as ``int``) to the list of neighboring channel
        ids, listed in the order the recording reports its channels.
    """
    ids = [int(cid) for cid in recording.get_channel_ids()]
    positions = {
        cid: np.array(recording.get_channel_property(channel_id=cid, property_name='location'))
        for cid in ids
    }
    return {
        home: [other for other in ids if np.linalg.norm(positions[other] - positions[home]) < max_dist]
        for home in ids
    }

In [None]:
def find_unit_peak_channels_hd5(h5_snippet_uri, unit_ids):
    """Find, for each unit, the channel where its average waveform is largest.

    The snippets file (as produced by ``labbox_ephys.prepare_snippets_h5``) is
    fetched with ``ka.load_file`` and read with h5py. For each unit, the
    snippets stored under ``unit_waveforms/<unit_id>/waveforms`` are averaged
    over axis 0, and the channel whose average has the largest peak-to-trough
    amplitude (max - min along axis 1) is selected.

    NOTE: this differs from ``fetch_average_waveform_plot_data``, which picks
    the channel with the largest amplitude measured from 0; the two can
    disagree for asymmetric waveforms.

    Args:
        h5_snippet_uri: kachery URI of the snippets HDF5 file.
        unit_ids: iterable of unit ids to look up.

    Returns:
        Dict mapping each unit id to its peak channel id as a Python ``int``.

    Raises:
        KeyError: if the file lacks the waveforms or channel ids for a unit.
    """
    h5_path = ka.load_file(h5_snippet_uri)
    unit_maximum_channels = {}
    with h5py.File(h5_path, 'r') as f:
        for unit_id in unit_ids:
            group_path = f'unit_waveforms/{unit_id}'
            waveforms_ds = f.get(f'{group_path}/waveforms')
            channel_ids_ds = f.get(f'{group_path}/channel_ids')
            # h5py's .get returns None for a missing path; report which unit is
            # missing instead of letting numpy fail on np.array(None).
            if waveforms_ds is None or channel_ids_ds is None:
                raise KeyError(f'No snippet data for unit {unit_id} in {h5_snippet_uri}')
            waveforms = np.array(waveforms_ds)
            channel_ids = np.array(channel_ids_ds)
            # Row i of the average is taken to belong to channel_ids[i]
            # (assumes waveforms is (snippets, channels, samples) — TODO confirm).
            average_waveform = np.mean(waveforms, axis=0)
            # Peak-to-trough amplitude per channel.
            channel_amplitudes = np.ptp(average_waveform, axis=1)
            max_channel_id = channel_ids[int(np.argmax(channel_amplitudes))]
            # Cast so values match the Dict[int, int] annotations used by callers.
            unit_maximum_channels[unit_id] = int(max_channel_id)
    return unit_maximum_channels

In [None]:
def find_local_unit_pairs(sorting1, sorting2, snippet_uri_1, snippet_uri_2, channel_neighborhoods) -> Dict[int, Set[int]]:
    """Pair each unit of ``sorting1`` with the nearby units of ``sorting2``.

    A ``sorting2`` unit is "local" to a ``sorting1`` unit when its peak channel
    lies in the channel neighborhood of the ``sorting1`` unit's peak channel.
    Comparisons are only meaningful between units of the two different
    sortings, and the distances computed later are symmetric, so iterating
    over ``sorting1`` alone yields the complete set of pairs.

    Args:
        sorting1, sorting2: sorting extractors exposing ``get_unit_ids()``.
        snippet_uri_1, snippet_uri_2: snippets-file URIs for each sorting,
            used to find each unit's peak channel.
        channel_neighborhoods: dict from channel id to its neighboring channel
            ids (e.g. from ``find_channel_neighborhoods``).

    Returns:
        Dict from every ``sorting1`` unit id to the (possibly empty) set of
        ``sorting2`` unit ids whose peak channel is in its neighborhood.

    Raises:
        KeyError: if a ``sorting1`` unit's peak channel is not a key of
            ``channel_neighborhoods``.
    """
    s1_ids = sorting1.get_unit_ids()
    s2_ids = sorting2.get_unit_ids()
    s1_peak_channels: Dict[int, int] = find_unit_peak_channels_hd5(snippet_uri_1, s1_ids)
    s2_peak_channels: Dict[int, int] = find_unit_peak_channels_hd5(snippet_uri_2, s2_ids)

    # Invert sorting2's unit -> peak-channel map in a single pass rather than
    # rescanning every unit for each unit's channel.
    s2_units_per_peak_channel: Dict[int, Set[int]] = {}
    for s2_unit, channel in s2_peak_channels.items():
        s2_units_per_peak_channel.setdefault(channel, set()).add(s2_unit)

    pairs: Dict[int, Set[int]] = {}
    for s1_unit in s1_ids:
        peak = s1_peak_channels[s1_unit]
        # Fresh set per unit; the sets stored in s2_units_per_peak_channel are
        # only read, never aliased into the result.
        local_units: Set[int] = set()
        for channel in channel_neighborhoods[peak]:
            local_units |= s2_units_per_peak_channel.get(channel, set())
        pairs[s1_unit] = local_units
    return pairs


In [None]:
def count_matching_events(times1, times2, delta=10):
    """Count coincident spikes between two spike trains.

    Both trains are merged and sorted by time. A candidate match is a pair of
    spikes that meets all three conditions:
      * adjacent in the merged order,
      * from different trains,
      * at most ``delta`` apart.

    A spike may belong to at most one match. Within each chain of consecutive
    candidate links, the maximum number of non-overlapping pairs is counted: a
    chain of ``L`` links gives ``ceil(L / 2)`` matches.

    Counting each whole chain as a single match would undercount whenever
    spikes are closer than ``delta``. For example, a train compared with itself
    would report fewer matches than it has spikes.

    Args:
        times1, times2: 1-D array-like spike times (numpy arrays or lists).
        delta: maximum separation, inclusive, in the same units as the times.

    Returns:
        Number of matched pairs as an ``int``; never more than
        ``min(len(times1), len(times2))``.
    """
    times1 = np.asarray(times1)
    times2 = np.asarray(times2)
    times_concat = np.concatenate((times1, times2))
    # Label each spike with the train it came from (1 or 2).
    membership = np.concatenate((np.full(len(times1), 1, dtype=np.int8),
                                 np.full(len(times2), 2, dtype=np.int8)))
    # Stable sort: among equal times, times1 spikes precede times2 spikes, so
    # the result does not depend on the sort algorithm's tie-breaking.
    order = np.argsort(times_concat, kind='stable')
    times_sorted = times_concat[order]
    membership_sorted = membership[order]

    # Link i joins merged spikes i and i+1. Keep links that are close enough
    # and cross between the two trains. Gaps are non-negative because the
    # times are sorted, so unsigned dtypes cannot underflow here.
    gaps = times_sorted[1:] - times_sorted[:-1]
    links = np.flatnonzero((gaps <= delta) & (membership_sorted[:-1] != membership_sorted[1:]))
    if links.size == 0:
        return 0

    # Split the candidate links into runs of consecutive indices (chains).
    breaks = np.flatnonzero(np.diff(links) != 1) + 1
    run_bounds = np.concatenate(([0], breaks, [links.size]))
    run_lengths = np.diff(run_bounds)
    # A maximum matching on a path with L edges has ceil(L / 2) edges.
    return int(np.sum((run_lengths + 1) // 2))

def count_unmatched_events(times1, times2, matched_events_count):
    """Return how many spikes of each train were left unmatched.

    Without ground truth, this is the symmetric analogue of precision/recall.
    Neither sorting is assumed correct, so all that can be reported is how
    many events each one found that the other did not. Confidence thresholds
    could refine this, but the underlying sorters do not currently provide
    them.

    Returns:
        Tuple ``(unmatched_in_times1, unmatched_in_times2)``.
    """
    return tuple(len(train) - matched_events_count for train in (times1, times2))

def jaccard_index(times1, times2, match_count):
    """Jaccard similarity |A n B| / |A U B| of two spike trains.

    ``match_count`` plays the role of |A n B| (e.g. from
    ``count_matching_events``), and |A U B| = |A| + |B| - |A n B|.

    Two empty trains have an empty union. Following the convention of
    ``scipy.spatial.distance.jaccard`` (distance 0 for two all-zero inputs),
    they are treated as identical: 1.0 is returned instead of raising
    ``ZeroDivisionError``.
    """
    all_spikes_count = len(times1) + len(times2) - match_count
    if all_spikes_count == 0:
        return 1.0
    return match_count / all_spikes_count

def jaccard_distance(times1, times2, match_count):
    """Jaccard distance between two spike trains: one minus ``jaccard_index``.

    0 means every spike was matched; 1 means no spikes matched.
    """
    similarity = jaccard_index(times1, times2, match_count)
    return 1 - similarity

In [None]:
from scipy.stats import wasserstein_distance
# See https://en.wikipedia.org/wiki/Earth_mover%27s_distance and https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.wasserstein_distance.html
def distance_score(times1, times2):
    """Earth mover's (1-D Wasserstein) distance between two spike-time distributions.

    Each train is treated as an unweighted empirical distribution of spike
    times. ``scipy.stats.wasserstein_distance`` also accepts per-value
    weights, so a sorter that reported per-spike confidence could weight the
    observations by it.
    """
    score = wasserstein_distance(u_values=times1, v_values=times2)
    return score

In [None]:
def compare_pairs(pairs, sorting1, sorting2, delta=10):
    """Score every candidate (sorting1 unit, sorting2 unit) pair.

    Args:
        pairs: dict from sorting1 unit id to an iterable of sorting2 unit ids
            to compare it with (e.g. from ``find_local_unit_pairs``).
        sorting1, sorting2: sorting extractors exposing
            ``get_unit_spike_train(unit_id)``.
        delta: coincidence window passed to ``count_matching_events``
            (default 10, in the same units as the spike times).

    Returns:
        Dict mapping ``(s1_unit, s2_unit)`` to a tuple
        ``(wasserstein_distance, jaccard_distance)``. Only the pairs listed in
        ``pairs`` appear, so this is a sparse representation of the full
        unit-by-unit comparison.
    """
    # A sorting2 unit may appear in several sorting1 units' candidate sets;
    # fetch each of its spike trains only once.
    s2_trains = {}
    distances = {}
    for s1unit, s2units in pairs.items():
        times1 = sorting1.get_unit_spike_train(s1unit)
        for s2unit in s2units:
            if s2unit not in s2_trains:
                s2_trains[s2unit] = sorting2.get_unit_spike_train(s2unit)
            times2 = s2_trains[s2unit]
            matches = count_matching_events(times1, times2, delta=delta)
            distances[(s1unit, s2unit)] = (
                distance_score(times1, times2),
                jaccard_distance(times1, times2, matches),
            )
    return distances
    

In [None]:
import neuropixels_data_sep_2020 as nd; import spikeextractors as se
# Load a recording by id through neuropixels_data_sep_2020 (without downloading).
# The id suggests an excerpt of channels 0-7, 10 s long.
# NOTE(review): `recording` does not appear to be used by the cells below,
# which build `recording1` from a kachery URI instead — possibly a leftover.
recording_id = 'cortexlab-single-phase-3 (ch 0-7, 10 sec)'
recording = nd.load_recording(recording_id, download=False)


In [None]:
# Sample usage
# Synthetic dataset: compare two sortings of the same recording at two
# neighborhood radii, then compare one sorting against itself.

import kachery_p2p as kp
recording_uri = 'sha1dir://fb52d510d2543634e247e0d2d1d4390be9ed9e20.synth_magland/datasets_noise10_K10_C4/001_synth'
recording1 = nd.LabboxEphysRecordingExtractor(recording_uri, download=True)
sorting1_uri = 'sha1://3e411871054e5f3f1a4fabc61291db9d835c5201/3e411871054e5f3f1a4fabc61291db9d835c5201/firings.mda'
sorting2_uri = 'sha1://88c6b5899d74e0a75735e3dcaead70b739c673e7/88c6b5899d74e0a75735e3dcaead70b739c673e7/firings.mda'

# Snippet files are used only to find each unit's peak channel.
sorting1_snippets_uri = 'sha1://727d9ef566cfa4ca66a1e5153f2bcf13d90d3977/snippets.h5?manifest=60c2432b24811eaeb595248efc06d00a2bd0c0b0'
sorting2_snippets_uri = 'sha1://f78cffd6351d2c0fccb529dc1e997b86cec9ce63/snippets.h5?manifest=9602da6fbd2247f996553cc6f3cbff54b4313c46'

# Fetch the firings files via kachery_p2p before constructing the extractors.
kp.load_file(sorting1_uri)
kp.load_file(sorting2_uri)
sorting1 = nd.LabboxEphysSortingExtractor(sorting1_uri) # units 0 - 10
sorting2 = nd.LabboxEphysSortingExtractor(sorting2_uri) # units 1 - 10
# Can check these with sorting1.get_unit_ids() and sorting1.get_unit_spike_train(0)

# Neighborhoods with large radius: everything is in everything's cluster
# (radius is in the units of the channel 'location' property).
big_neighborhoods = find_channel_neighborhoods(recording1, 50)
big_pairs = find_local_unit_pairs(sorting1, sorting2, sorting1_snippets_uri, sorting2_snippets_uri, big_neighborhoods) # naturally does full outer product
bigvals = compare_pairs(big_pairs, sorting1, sorting2)

# Smaller radius: yields something like 0 -> 0, 1; 1->0, 1, 2; 2->1, 2, 3; 3->2, 3
# (presumably channel -> neighboring channels, for a 4-channel probe — confirm)
small_neighborhoods = find_channel_neighborhoods(recording1, 2)
small_pairs = find_local_unit_pairs(sorting1, sorting2, sorting1_snippets_uri, sorting2_snippets_uri, small_neighborhoods) # does NOT compare everything to everything
smallvals = compare_pairs(small_pairs, sorting1, sorting2)

# Compare to self (to get a little insight into comparison statistics)
mirror_pairs = find_local_unit_pairs(sorting1, sorting1, sorting1_snippets_uri, sorting1_snippets_uri, small_neighborhoods)
mirrorvals = compare_pairs(mirror_pairs, sorting1, sorting1)

# Each printed value is a dict (s1_unit, s2_unit) -> (wasserstein distance,
# jaccard distance) over the candidate pairs only, not a dense matrix.
print(f"""
Big confusion matrix:
{bigvals}

More local confusion matrix:
{smallvals}

Self-Comparison:
{mirrorvals}
""")

In [None]:
# On real-scale data
# Compare the Kilosort and SpyKingCircus sortings of the cortexlab recording,
# then compare the Kilosort sorting against itself as a sanity check.

import kachery_p2p as kp
import time

# Pairs whose Wasserstein distance exceeds this are skipped when printing examples.
MAX_PRINTED_WASSERSTEIN = 100000
# Maximum number of self-comparison examples to print.
MAX_PRINTED_EXAMPLES = 50

recording_uri = 'sha1://8b222e25bc4d9c792e4490ca322b5338e0795596/cortexlab-single-phase-3.json'
recording1 = nd.LabboxEphysRecordingExtractor(recording_uri, download=True)
sorting1_uri = 'sha1://b0ab8219bde481e029b69431d85e3e08bb833851/file.json' # Kilosort for this recording
sorting2_uri = 'sha1://5c72c264f220bf36db4352b0f59380f5e7460bd8/file.json' # SpyKingCircus for this recording
# precomputed and returned by labbox_ephys.prepare_snippets_h5 and now available in my kachery store.
# Could also be presented by hither, but labbox caches and stores them anyway.
sorting1_snippets_uri = 'sha1://5c66150c72758c3e16ba56f1585e260bf1a3328d/snippets.h5?manifest=41a009f1bde0c2c677cb680960f72b78e15aa514'
sorting2_snippets_uri = 'sha1://c1281a2d4a47de8ff845974a88abbbda072d5c6e/snippets.h5?manifest=862ff23ab8ec58e85997642c3dbe32ff3a4636a0'
kp.load_file(sorting1_uri)
kp.load_file(sorting2_uri)

sorting1 = nd.LabboxEphysSortingExtractor(sorting1_uri)
sorting2 = nd.LabboxEphysSortingExtractor(sorting2_uri)
# Can check these with sorting1.get_unit_ids() and sorting1.get_unit_spike_train(0)

start_time = time.perf_counter()
# Radius 50 in channel-location units.
# NOTE(review): on a full probe this is probably NOT "everything in everything's
# neighborhood" as in the synthetic example — confirm against the geometry.
neighborhoods = find_channel_neighborhoods(recording1, 50)
pairs = find_local_unit_pairs(sorting1, sorting2, sorting1_snippets_uri, sorting2_snippets_uri, neighborhoods)
vals = compare_pairs(pairs, sorting1, sorting2)
elapsed = time.perf_counter() - start_time
print(f'Computing neighborhoods, peak channels, and conf matrix for 2 sortings took {elapsed} s.')

# Compare to self (to get a little insight into comparison statistics).
# Pair finding is deliberately excluded from this timing.
mirror_pairs = find_local_unit_pairs(sorting1, sorting1, sorting1_snippets_uri, sorting1_snippets_uri, neighborhoods)
start_time = time.perf_counter()
mirrorvals = compare_pairs(mirror_pairs, sorting1, sorting1)
elapsed = time.perf_counter() - start_time
print(f'Computing reflection conf matrix alone took {elapsed} s.')

# vals / mirrorvals map (s1_unit, s2_unit) -> (wasserstein distance, jaccard distance).
print('Some arbitrary low-difference pairs from the kilosort-spyking comparison:')
for pair, scores in vals.items():
    if scores[0] > MAX_PRINTED_WASSERSTEIN:
        continue
    print(f'{pair}: {scores}')

# A unit compared with itself should have Jaccard distance 0; list those that don't.
# Pairs of two *different* neighboring units are skipped: a nonzero distance is
# expected for them and says nothing about stability.
print('Some instances where the Jaccard self-comparison is unstable:')
printed_count = 0
for pair, scores in mirrorvals.items():
    if pair[0] != pair[1]:
        continue
    if scores[0] > MAX_PRINTED_WASSERSTEIN or scores[1] == 0:
        continue
    print(f'{pair}: {scores}')
    printed_count += 1
    if printed_count >= MAX_PRINTED_EXAMPLES:
        break
    