In [None]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
%matplotlib inline

import numpy as np
import h5py

import os, sys
import spikeextractors as se
import ast
from detection_helper_functions import getSortedDetectionInfo, getNClosestPositions
from detection_helper_functions import runThresholdDetection, getDist, evaluateDetection

In [None]:
#Set up extractors and dataset
dataset_directory = 'path-to-mea-rec-h5-file' # path to the MEArec recording (h5py.File below needs an hdf5 file)
# Open read-only: this notebook only inspects the file, and an explicit mode
# avoids depending on h5py's version-specific default (older releases could
# open the file writable).
dataset = h5py.File(dataset_directory, 'r')
recording = se.MEArecRecordingExtractor(recording_path=dataset_directory)
sorting = se.MEArecSortingExtractor(recording_path=dataset_directory)

# Collect the 'loc' annotation of every ground-truth spike train. The number of
# neurons is read from the file instead of being hard-coded, so recordings with
# a different neuron count work unchanged.
# NOTE(review): assumes the 'spiketrains' group is keyed '0'..'n-1' -- confirm.
spiketrains_group = dataset['spiketrains']
neuron_locs = []
for i in range(len(spiketrains_group)):
    annotation_hf = spiketrains_group[str(i)]['annotations']
    # The annotations are stored as the string form of a Python dict.
    annotation_dict = ast.literal_eval(annotation_hf[()])
    neuron_locs.append(annotation_dict['loc'])

#Probe channel positions and locations of the neurons in 2D numpy arrays
channel_positions = np.asarray(dataset['positions'])
neuron_locs = np.asarray(neuron_locs)

#Return the spike times and associated neurons that fired in two numpy arrays
firing_times, firing_neurons = getSortedDetectionInfo(sorting)

print("Num Channels: " + str(len(recording.getChannelIds())))
print("Num Frames: " + str(recording.getNumFrames()))
print("Firing times: " + str(firing_times))
print("Firing neurons: " + str(firing_neurons))

In [None]:
# Build the pairwise channel distance matrix (channels x channels), one row per channel.
num_channels = channel_positions.shape[0]
dist_rows = []
for channel_position in channel_positions:
    # Ask for as many neighbours as there are channels, i.e. every channel.
    closest_channels, closest_dists = getNClosestPositions(
        num_channels, channel_position, channel_positions
    )
    # Reorder the distances by ascending channel index rather than by proximity.
    # NOTE(review): presumably closest_channels holds indices into
    # channel_positions, so entry j of a row is the distance to channel j -- confirm.
    dist_rows.append(closest_dists[np.argsort(closest_channels)])
dist_matrix = np.asarray(dist_rows)

In [None]:
#Plot traces from channel and spike times from closest neuron for the given frames
channel_ids = [5]
t0 = 0
t1 = 10000

plt.figure(figsize=(20,10))
plt.plot(recording.getTraces(channel_ids=channel_ids, start_frame=t0, end_frame=t1)[0])
# Take the first of the 5 neuron positions returned for this channel's position.
closest_neurons, _ = getNClosestPositions(5, channel_positions[channel_ids[0]], neuron_locs)
closest_neuron = closest_neurons[0]
# Spike train without a frame window; not used further in this cell.
closest_neuron_st = sorting.getUnitSpikeTrain(unit_id=closest_neuron)
print(closest_neuron)
# Ground-truth spikes of that neuron restricted to the plotted frame window.
spike_times = list(sorting.getUnitSpikeTrain(unit_id=closest_neuron, start_frame=t0, end_frame=t1))
num_spikes = len(spike_times)
for spike_time in spike_times:
    # The trace is plotted from index 0, so shift spike frames by the window start.
    plt.axvline(x=spike_time - t0, color='red', linestyle='--', alpha=.5)
# Horizontal reference line at -40.
plt.axhline(-40, color='black', linestyle='dashdot')
plt.xlabel('Frames')
# Label the y-axis; a second xlabel call here would replace 'Frames' and leave
# the y-axis unlabeled.
plt.ylabel('mV')
print('Channel: ' + str(channel_ids[0]))
print('Closest Neuron: ' + str(closest_neuron))
print('Neuron ' + str(closest_neuron) + ' spiked ' + str(num_spikes) + ' times in frames ' +  str(t0) + '-' + str(t1) + ' at ' + str(spike_times))

# Overlay the 50-frame snippets taken at each ground-truth spike time.
snippets = recording.getSnippets(reference_frames=spike_times, snippet_len=50, channel_ids=channel_ids)
plt.figure()
for snippet in snippets:
    plt.plot(snippet[0])
plt.xlabel('Frames')
plt.ylabel('mV')

In [None]:
# Simple spike detection (thresholding and peak detection) over the selected
# channels and frame window.
# NOTE(review): units of the parameters below depend on runThresholdDetection,
# which is not visible here -- confirm.
channel_ids = range(100)
t0, t1 = 0, 1920000
threshold = 40
refractory_period = 10
duplicate_radius = 100

detected_firing_times, detected_channels = runThresholdDetection(
    channel_ids, t0, t1,
    threshold, refractory_period, duplicate_radius,
    dist_matrix, recording,
)
print("Num detected events: " + str(detected_firing_times.shape[0]))

In [None]:
# Score the detections against ground truth.
#   detected_firing_times: sorted numpy array of detection frames, e.g. [0, 2, 100, 240, ...]
#   detected_channels: channel on which each detected spike has its largest
#                      amplitude, e.g. [34, 12, 23, 26, 99, ...]
# Dist (in microns) from channels to unit locations for them to be considered candidates for the spike
max_neuron_channel_dist = 100
matched_events, unmatched_detections, unmatched_firings = evaluateDetection(
    detected_firing_times, detected_channels,
    firing_times, firing_neurons,
    channel_positions, neuron_locs,
    max_neuron_channel_dist,
    jitter=10,  # Frames allowed between detection and true event
)
for label, events in (
    ("True Positives: ", matched_events),
    ("False Positives: ", unmatched_detections),
    ("False Negatives: ", unmatched_firings),
):
    print(label + str(len(events)))

In [None]:
# Show the first entry of each evaluation result to illustrate its layout.
for label, events in (
    # (Detected frame, matched ground truth frame, detected channel) -- true positive
    ("First matched event: ", matched_events),
    # (Detected frame, detected channel) -- false positive
    ("First unmatched detection: ", unmatched_detections),
    # (Firing time, firing neuron) -- false negative
    ("First ground truth unmatched: ", unmatched_firings),
):
    print(label + str(events[0]))

In [None]:
#Plot unmatched detections
channel_ids = [67]
t0 = 0
t1 = 100000

# Keep the unmatched detections (detected frame, detected channel) that sit on
# the selected channel and fall inside the plotted frame window (inclusive).
bad_detections = [
    detection
    for detection in unmatched_detections
    if detection[1] == channel_ids[0] and t0 <= detection[0] <= t1
]
print("Not matched detections: "+ str(bad_detections))

plt.figure(figsize=(20,10))
plt.plot(recording.getTraces(channel_ids=channel_ids, start_frame=t0, end_frame=t1)[0])
# Neuron positions returned for this channel's position; not used further in this cell.
closest_neurons, _ = getNClosestPositions(5, channel_positions[channel_ids[0]], neuron_locs)
for bad_detection in bad_detections:
    # The trace is plotted from index 0, so shift detection frames by the window start.
    plt.axvline(x=bad_detection[0] - t0, color='red', linestyle='--', alpha=.5)

# Horizontal reference line at -40.
plt.axhline(-40, color='black', linestyle='dashdot')
# Axis labels so the figure can be read on its own.
plt.xlabel('Frames')
plt.ylabel('mV')
print('Channel: ' + str(channel_ids[0]))