In [None]:
import h5py
import pickle

import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

%matplotlib notebook

## Set up hs2 objects

In [None]:
from hs2 import HSDetection, HSClustering
from probe import NeuroSeeker_128

# Recording file; the same HDF5 file also provides the pipette trace read further below.
data_path = '../datasets/TEST_HIGHPASS_INVERTED_PIP_Kampff_2015_09_03_Pair_9_0.hdf5'

Probe = NeuroSeeker_128(data_file_path=data_path)

# Baseline detection settings. The optimisation cell later searches over
# threshold / ahpthr / maxsl starting from this configuration.
default_detection_parameters = dict(
    to_localize=True,
    cutout_start=10,
    cutout_end=30,
    threshold=130,
    maa=0,
    maxsl=12,
    minsl=3,
    ahpthr=0,
)

HSD = HSDetection(Probe, **default_detection_parameters)

## Load GT spiketrain

The optimisation class expects the ground truth as a NumPy array of spike timestamps (frame indices). Here they are taken from the second column of the spike-time text file.

In [None]:
gt_path = '../datasets/HIGHPASS_INVERTED_PIP_Kampff_2015_09_03_Pair_9_0_Thresh_15_SpikesSYCL.txt'

# Integer table read from the text file; column 1 holds the ground-truth spike times.
gt_table = np.loadtxt(gt_path, dtype='int')
gt_spiketrain = gt_table[:, 1]

# Reference probe channel for the ground-truth unit (presumably the channel
# nearest the pipette-recorded cell — confirm); used for scoring and plots below.
closest_ch = 109

# Optimise both detection and clustering parameters

In [None]:
from parameter_optimisation import OptimiseParameters

# Search ranges, presumably (low, high) bounds, for each detection parameter.
detec_params_to_opt = dict(
    threshold=(50, 300),
    ahpthr=(0, 20),
    maxsl=(0, 30),
)

# Search ranges, presumably (low, high) bounds, for each clustering parameter.
clust_params_to_opt = dict(
    bandwidth=(0., 4.),
    alpha=(0., 5.),
    pca_ncomponents=(1, 10),
)

# Output file stems; later cells load '<stem>.pickle' from these names.
detect_results_name = 'result_optim_params_detect'
clust_results_name = 'result_optim_params_clust'

op = OptimiseParameters(
    gt_spiketrain, closest_ch, Probe, HSD,
    detec_params_to_opt,
    None,  # NOTE(review): meaning of this positional argument is not visible here — confirm
    clust_params_to_opt,
    optimise_detection=True,
    optimise_clustering=True,
    detec_run_schedule=(5, 1),
    clust_run_schedule=(5, 1),
    detec_outfile=detect_results_name,
    clust_outfile=clust_results_name,
)

HSC = op.run()

In [None]:
# Directory prefix for the results file (include a trailing separator if non-empty).
results_path = ''
detect_results_file = results_path + detect_results_name + '.pickle'

# NOTE: pickle.load can execute arbitrary code — only load result files you produced yourself.
with open(detect_results_file, 'rb') as fh:
    results_obj = pickle.load(fh)

## Bayesian optimisation plots

In [None]:
%matplotlib inline
from skopt.plots import plot_convergence, plot_evaluations, plot_objective

parameter_names = list(detec_params_to_opt.keys())

plt.figure(figsize=(15,3.5))
plot_convergence(results_obj, ax=plt.gca());
plot_evaluations(results_obj, dimensions=parameter_names);
plot_objective(results_obj, dimensions=parameter_names);

# Validation of Detection

In [None]:
def plot_heatmap(obj, title, highlight_ch=109, n_cols=4):
    """Plot per-channel true positives as a heatmap annotated with missed detections.

    The colour of each cell encodes how many true-positive detections the
    channel has. The number written on the cell is that channel's count of
    false negatives (missed detections).

    Parameters
    ----------
    obj : mapping
        Result object with per-channel sequences under ``'TPs'`` and ``'FNs'``.
        Each holds one entry per channel, and only ``len`` of each entry is used.
    title : str
        Figure title.
    highlight_ch : int, optional
        Channel whose annotation is drawn in magenta. The default is 109, the
        reference channel used in this notebook.
    n_cols : int, optional
        Number of consecutive channels grouped into one heatmap column. Channel
        ``i`` is drawn at x = ``i // n_cols``, y = ``i % n_cols``. The default is 4.
    """
    fig, ax = plt.subplots(figsize=(17, 3))

    tp_counts = np.array([len(tps) for tps in obj['TPs']])
    fn_counts = np.array([len(fns) for fns in obj['FNs']])

    # reshape -> one row per group of n_cols channels; transpose puts the groups on the x-axis.
    ax.imshow(tp_counts.reshape(-1, n_cols).T, cmap='jet', interpolation='bilinear')
    for i, fn in enumerate(fn_counts):
        # Compare by value: `is` on ints only works by accident of CPython's small-int cache.
        color = 'magenta' if i == highlight_ch else 'w'
        ax.text(i // n_cols, i % n_cols, fn,
                color=color,
                fontsize=10 if fn > 10**3 else 12,
                horizontalalignment='center')
    ax.set_title(title, fontsize=15)

In [None]:
%matplotlib inline
plot_heatmap(results_obj, 'Missed detections per channel')

## Count duplicate spikes

In [None]:
all_spikes = HSD.spikes
orig_count = all_spikes.shape[0]

# Flag, per detected spike, whether it shares a timestamp with an earlier spike in
# some channel neighbourhood. Neighbourhoods overlap, so summing per-channel counts
# would count the same duplicate once for every neighbourhood containing it (and
# could exceed 100%). Flagging each spike once gives a count bounded by orig_count.
is_duplicate = np.zeros(orig_count, dtype=bool)
for ch in tqdm(range(128), desc='Deduplication', unit=' channels'):
    # Positional mask of all spikes whose channel is in ch's neighbourhood
    in_neigh = all_spikes['ch'].isin(Probe.neighbors[ch]).values
    if in_neigh.any():
        neigh_t = all_spikes['t'][in_neigh]
        # Map the neighbourhood-local duplicate mask back to global row positions.
        is_duplicate[np.flatnonzero(in_neigh)[neigh_t.duplicated().values]] = True

duplicate_count = int(is_duplicate.sum())

if orig_count:
    print("Found {} ({:.2f}%) duplicate spikes.".format(duplicate_count, 100*duplicate_count/orig_count))
else:
    print("No spikes detected; nothing to deduplicate.")

## Manually inspect detection

In [None]:
# Half-open frame window [lower_lim, upper_lim) used by the manual-inspection cells below.
lower_lim, upper_lim = 0, 2000000

In [None]:
# Read the raw traces for the inspection window and lay them out as (frames, 128 channels).
# NOTE(review): assumes Probe.Read(start, stop) returns the flattened samples of frames
# [start, stop) for all channels, so row 0 of probe_raw is frame lower_lim — confirm.
probe_raw = Probe.Read(lower_lim, upper_lim).reshape(-1, 128)
# Clip the window in case the file holds fewer frames than requested. Rows are
# offset by lower_lim, so the first frame past the data is lower_lim + n_rows.
upper_lim = min(upper_lim, lower_lim + probe_raw.shape[0])

In [None]:
from probes.readUtils import readNeuroSeekerPipette

# Open read-only (h5py < 3.0 defaulted to append mode) and close the file once the
# pipette trace for [lower_lim, upper_lim) has been copied into memory. np.asarray
# materialises the data in case the reader returns a lazy, file-backed object.
with h5py.File(data_path, 'r') as data_file:
    gt_raw = np.asarray(readNeuroSeekerPipette(data_file, lower_lim, upper_lim))

In [None]:
%matplotlib notebook
%matplotlib notebook

ch = closest_ch

ch_spiketrain = HSD.spikes.loc[HSD.spikes['ch'] == ch]

plt.figure(figsize=(10, 6))

n_frames = upper_lim - lower_lim
xs = np.arange(n_frames)

# plt.xlim(200000, 300000)
plt.ylim(-4000, 6000)
plt.title("Channel {}".format(ch))
plt.xlabel("Frames")
plt.ylabel("Voltage (arbitrary units)")

# Graphs the signal on that channel
plt.plot(xs, probe_raw[lower_lim:upper_lim, ch], zorder=1, color='grey')

# # Graphs all spikes detected on channel ch from channels
trim_spikes_probe = ch_spiketrain[(lower_lim < ch_spiketrain.t) & (ch_spiketrain.t < upper_lim)]
plt.scatter(trim_spikes_probe.t, probe_raw[:, ch][trim_spikes_probe.t], zorder=2, color='r', marker='o')

# Plot pipette as well
scale_gt = 5
shift_gt = +0
plt.plot(xs, gt_raw[lower_lim:upper_lim]*scale_gt + shift_gt, zorder=0, color='lightgrey')
trim_spikes_pip = gt_spiketrain[np.logical_and([lower_lim < gt_spiketrain], [gt_spiketrain < upper_lim])[0]]
plt.scatter(trim_spikes_pip, gt_raw[trim_spikes_pip]*scale_gt + shift_gt, zorder=3, color='lightsalmon', marker='o');

# Sorting plots

In [None]:
# Directory prefix for the results file (include a trailing separator if non-empty).
clust_results_path = ''
clust_results_file = clust_results_path + clust_results_name + '.pickle'

# NOTE: pickle.load can execute arbitrary code — only load result files you produced yourself.
with open(clust_results_file, 'rb') as fh:
    clust_results = pickle.load(fh)

# Cluster id stored with the results, plus the clustering parameters saved alongside it.
c = clust_results['most_popular_cluster']
best_params = clust_results['clustering_parameters']
alpha, bandw, n_PCA = best_params['alpha'], best_params['bandwidth'], best_params['n_pca']

In [None]:
# One-line summary of the selected cluster and its clustering parameters.
cluster_summary = 'Most popular cluster={} | alpha={} | bandwidth={} | nPCA={}'.format(c, alpha, bandw, n_PCA)
print(cluster_summary)

In [None]:
%matplotlib inline
plt.figure(figsize=(15,3.5))
plot_convergence(clust_results, ax=plt.gca());

In [None]:
%matplotlib inline
ch = closest_ch
gt = clust_results['Ps']
sp = HSC.spikes[HSC.spikes['ch'].isin(Probe.neighbors[ch])]
y_lim = (24.5, 29.5)
x_lim = (0,4)

# Filter out different groups of spikes
mask =  HSC.spikes['ch'].isin(Probe.neighbors[ch]) & HSC.spikes['t'].isin(gt)
gt_sp = HSC.spikes[mask]
mask =  HSC.spikes['ch'].isin(Probe.neighbors[ch]) & HSC.spikes['t'].isin(gt) & (HSC.spikes['cl']==c)
tp_sp = HSC.spikes[mask]
mask =  HSC.spikes['ch'].isin(Probe.neighbors[ch]) & ~HSC.spikes['t'].isin(gt) & (HSC.spikes['cl']==c)
fp_sp = HSC.spikes[mask]

# Generate title strings
titles = ['All neighbour.\nspikes (n={})'.format(sp.shape[0]),
          'GT spikes\n(n={})'.format(gt_sp.shape[0]),
          'TPs (n={})'.format(len(tp_sp)),
          'FPs (n={})'.format(len(fp_sp))
         ]

plt.figure(figsize=(15,4))

for i, title in enumerate(titles):
    plt.subplot(1,4,i+1)
    plt.xlim(*x_lim)
    plt.ylim(*y_lim)
    plt.scatter(sp.x, sp.y, marker='+', color='green', s=5., alpha=0.8)
    if i>0:
        plt.scatter(gt_sp.x, gt_sp.y, marker='+', color='r', s=5., alpha=0.8)
    if i==2:
        plt.scatter(tp_sp.x, tp_sp.y, marker='+', color='blue', s=5., alpha=0.8)
    if i==3:
        plt.scatter(fp_sp.x, fp_sp.y, marker='+', color='yellow', s=5., alpha=0.8)
    plt.title(title, fontsize=24);


In [None]:
# Neighbourhood view around cluster c (the cluster id loaded from the clustering results).
# NOTE(review): presumably draws cluster c plus clusters within `radius` of it; the unit of
# radius=0.1 is not visible here — confirm in HSClustering.PlotNeighbourhood.
HSC.PlotNeighbourhood(cl=c, radius=0.1)

In [None]:
# Waveform overlay for cluster c; nshapes presumably caps how many waveforms are drawn.
fig, ax = plt.subplots(figsize=(8, 5))
HSC.PlotShapes([c], nshapes=1000, ax=ax)