# motion estimation in spikeinterface

In 2021, the SpikeInterface project started to implement `sortingcomponents`, a module of modular building blocks for spike sorting steps.

Here is an overview of our progress integrating motion (aka drift) estimation and correction.


This notebook is based on the open dataset from Nick Steinmetz published in 2021:
"Imposed motion datasets" from Steinmetz et al. Science 2021
https://figshare.com/articles/dataset/_Imposed_motion_datasets_from_Steinmetz_et_al_Science_2021/14024495


The motion estimation is done in several modular steps:
  1. detect peaks
  2. localize peaks:
     * **"center of mass"**
     * **"monopolar_triangulation"** by Julien Boussard and Erdem Varol
       https://openreview.net/pdf?id=ohfi44BZPC4
  3. estimate motion:
     * **rigid** or **non rigid**
     * **"decentralized"** by Erdem Varol and Julien Boussard
       DOI : 10.1109/ICASSP39728.2021.9414145
  4. compute motion corrected peak localizations for visualization


Here we will show this chain:
* **detect peaks > localize peaks with "monopolar_triangulation" > estimate motion "decentralized" (both rigid and non-rigid)**



In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from pathlib import Path
import spikeinterface.full as si
import numpy as np
import matplotlib.pyplot as plt
# large default figures: most plots below are dense scatter/raster views
plt.rcParams.update({"figure.figsize": (20, 12)})
from probeinterface.plotting import plot_probe

In [None]:
# local folder
# The default is the original author's path; set the SPIKEINTERFACE_DATA
# environment variable to point the notebook at another data folder.
import os

base_folder = Path(os.environ.get('SPIKEINTERFACE_DATA', '/Users/charlie/data/'))
dataset_folder = base_folder / 'dataset1'
preprocess_folder = base_folder / 'dataset1_preprocessed'
peak_folder = base_folder / 'dataset1_peaks'
# cache folder for peaks / localizations / displacement arrays saved below
peak_folder.mkdir(exist_ok=True)

In [None]:
# global kwargs for parallel computing; unpacked (**job_kwargs) into the
# chunked operations below: saving, peak detection and peak localization
job_kwargs = {
    'n_jobs': 8,
    'chunk_size': 30_000,
    'progress_bar': True,
}

In [None]:
# read the file
# Load the SpikeGLX recording from dataset_folder; the bare `rec` on the last
# line makes the notebook display its summary.
rec = si.read_spikeglx(dataset_folder)
rec

In [None]:
# Probe layout of the raw recording, zoomed on the y range (-150, 200).
fig, ax = plt.subplots()
si.plot_probe_map(rec, ax=ax)
ax.set_ylim(bottom=-150, top=200)

## preprocess

This takes about 4 min for 30 min of signal.

In [None]:
# Bandpass (300-6000 Hz) + global median common reference, computed once and
# cached to preprocess_folder; every run then reloads the cached recording.
if not preprocess_folder.exists():
    rec_preprocessed = si.common_reference(
        si.bandpass_filter(rec, freq_min=300., freq_max=6000.),
        reference='global',
        operator='median',
    )
    rec_preprocessed.save(folder=preprocess_folder, **job_kwargs)
rec_preprocessed = si.load_extractor(preprocess_folder)

In [None]:
# plot and check spikes on a short time window for 10 channels
# (channel ids are taken from the recording being plotted, not the raw one)
si.plot_timeseries(rec_preprocessed, time_range=(100, 110), channel_ids=rec_preprocessed.channel_ids[50:60])

## estimate noise

In [None]:
# Per-channel noise estimate, in unscaled (raw) units because
# return_scaled=False; reused as the detection threshold reference below.
noise_levels = si.get_noise_levels(rec_preprocessed, return_scaled=False)
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(noise_levels, bins=10)
ax.set_title('noise across channel')
ax.set_xlabel('noise level (unscaled)')
ax.set_ylabel('number of channels')

## detect peaks

This takes about 1 min 30 s.

In [None]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

In [None]:
# Peak detection is cached as peaks.npy so re-running the notebook is cheap.
peaks_file = peak_folder / 'peaks.npy'
if not peaks_file.exists():
    peaks = detect_peaks(
        rec_preprocessed,
        method='locally_exclusive',
        peak_sign='neg',
        detect_threshold=5,
        local_radius_um=100,
        n_shifts=5,
        noise_levels=noise_levels,
        **job_kwargs,
    )
    np.save(peaks_file, peaks)
peaks = np.load(peaks_file)
print(peaks.shape)

In [None]:
# last expression -> rich notebook display of the preprocessed recording
rec_preprocessed

## localize peaks

Here we choose **'monopolar_triangulation' with log barrier**

In [None]:
from spikeinterface.sortingcomponents.peak_localization import localize_peaks

In [None]:
# Peak localization is cached on disk; only computed when the file is missing.
locations_file = peak_folder / 'peak_locations_monopolar_triangulation_log_limit.npy'
if not locations_file.exists():
    # monopolar triangulation with the log-penalty optimizer
    peak_locations = localize_peaks(
        rec_preprocessed,
        peaks,
        method='monopolar_triangulation',
        method_kwargs=dict(
            local_radius_um=100.,
            max_distance_um=1000.,
            optimizer='minimize_with_log_penality',
        ),
        ms_before=0.3,
        ms_after=0.6,
        **job_kwargs,
    )
    np.save(locations_file, peak_locations)
    print(peak_locations.shape)
peak_locations = np.load(locations_file)

In [None]:
# Show the available fields, the number of localizations and a few examples
# instead of dumping the whole (potentially very long) array.
print(peak_locations.dtype.fields)
print(peak_locations.shape)
peak_locations[:5]

## plot on probe

In [None]:
def clip_values_for_cmap(x, low_percentile=5, high_percentile=95):
    """Clip values to a percentile range so outliers do not dominate a colormap.

    Parameters
    ----------
    x : array_like
        Values used to color a plot (here: peak amplitudes).
    low_percentile, high_percentile : float, optional
        Percentiles (0-100) used as the lower / upper clipping bounds.
        The defaults (5, 95) reproduce the original behavior.

    Returns
    -------
    numpy.ndarray
        ``x`` clipped to ``[P_low, P_high]``. NaN entries are ignored when
        computing the bounds (so a single NaN no longer turns every value
        into NaN) and remain NaN in the output.
    """
    low, high = np.nanpercentile(x, [low_percentile, high_percentile])
    return np.clip(x, low, high)

In [None]:
# Localized peaks on the probe: x/y on the left panel and, when the
# localization method provides it, z/y on the right (shared y axis).
fig, axs = plt.subplots(ncols=2, sharey=True, figsize=(15, 10))
amplitude_colors = clip_values_for_cmap(peaks['amplitude'])
scatter_style = dict(c=amplitude_colors, s=1, alpha=0.002, cmap=plt.cm.plasma)

ax = axs[0]
si.plot_probe_map(rec_preprocessed, ax=ax)
ax.scatter(peak_locations['x'], peak_locations['y'], **scatter_style)
ax.set_xlabel('x')
ax.set_ylabel('y')

if 'z' in peak_locations.dtype.fields:
    ax = axs[1]
    ax.scatter(peak_locations['z'], peak_locations['y'], **scatter_style)
    ax.set_xlabel('z')
    ax.set_xlim(0, 150)

ax.set_ylim(1800, 2500)

## plot peak depth vs time

In [None]:
# Raw (uncorrected) peak depth over time, colored by clipped amplitude;
# drift shows up as slow vertical shifts of the bands.
peak_times_s = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
fig, ax = plt.subplots()
ax.scatter(
    peak_times_s,
    peak_locations['y'],
    s=1,
    c=clip_values_for_cmap(peaks['amplitude']),
    cmap=plt.cm.plasma,
    alpha=0.25,
)
ax.set_ylim(1300, 2500)

## motion estimate : rigid with decentralized

In [None]:
from spikeinterface.sortingcomponents.motion_estimation import (
    estimate_motion,
    make_motion_histogram,
    compute_pairwise_displacement,
    compute_global_displacement
)

In [None]:
bin_um = 5            # spatial bin size of the histogram
bin_duration_s = 5.   # temporal bin size of the histogram

# time x depth histogram of peak positions (unweighted counts of peaks)
motion_histogram, temporal_bins, spatial_bins = make_motion_histogram(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_um=bin_um,
    bin_duration_s=bin_duration_s,
    weight_with_amplitude=False,
)
print(motion_histogram.shape, temporal_bins.size, spatial_bins.size)

In [None]:
# Display the motion histogram with a contrast stretch: empty bins stay at 0,
# occupied bins are min-shifted, clipped at their 95th percentile and mapped
# to [3, 23] so that weakly populated bins remain visible.
fig, ax = plt.subplots()
extent = (temporal_bins[0], temporal_bins[-1], spatial_bins[0], spatial_bins[-1])
# float buffer/values: the rescaled values are fractional, so the histogram's
# own dtype must not dictate the image dtype or the in-place divisions
motion_histogram_vis = np.zeros(motion_histogram.shape, dtype=float)
occupied = motion_histogram > 0
vals = motion_histogram[occupied].astype(float)
if vals.size:  # an all-empty histogram would make vals.min() raise
    vals -= vals.min()
    vals = np.clip(vals, 0, np.percentile(vals, 95))
    vals_max = vals.max()
    if vals_max > 0:  # all occupied bins equal -> avoid 0/0 -> NaN image
        vals /= vals_max
    motion_histogram_vis[occupied] = 3 + 20 * vals
im = ax.imshow(
    motion_histogram_vis.T,
    interpolation='nearest',
    origin='lower',
    aspect='auto',
    extent=extent,
    cmap=plt.cm.cubehelix,
)
im.set_clim(0, 30)
ax.set_ylim(1300, 2500)
ax.set_xlabel('time[s]')
ax.set_ylabel('depth[um]')


## pairwise displacement from the motion histogram


In [None]:
# Use torch for the convolutions when it can be imported, numpy otherwise.
try:
    import torch  # noqa: F401  (imported only to probe availability)
except ImportError:
    conv_engine = "numpy"
else:
    conv_engine = "torch"

# displacement between every pair of time bins, plus a per-pair weight
pairwise_displacement, pairwise_displacement_weight = compute_pairwise_displacement(
    motion_histogram,
    bin_um,
    method='conv',
    conv_engine=conv_engine,
    max_displacement_um=600,
    progress_bar=True,
)
np.save(peak_folder / 'pairwise_displacement_conv2d.npy', pairwise_displacement)



In [None]:
# Pairwise displacement matrix: both axes are time, so the extent repeats
# the temporal range for x and y.
fig, ax = plt.subplots()
extent = (temporal_bins[0], temporal_bins[-1]) * 2
im = ax.imshow(
    pairwise_displacement,
    origin='lower',
    aspect='auto',
    extent=extent,
    interpolation='nearest',
    cmap='PiYG',
)
im.set_clim(-40, 40)
ax.set_aspect('equal')
fig.colorbar(im)

## estimate motion (rigid) from the pairwise displacement

In [None]:
# inspect the pairwise displacement matrix computed above
pairwise_displacement

In [None]:
# Two ways to turn the pairwise displacement matrix into one motion trace.

# (1) gradient descent over all pairs
motion_gd = compute_global_displacement(
    pairwise_displacement,
    convergence_method='gradient_descent',
)

# (2) robust least squares on a sparse subset: pairs whose correlation is
#     above 0.6 are kept (thresholding) and weighted by that correlation
reliable_pairs = pairwise_displacement_weight > 0.6
motion_sparse_lsqr = compute_global_displacement(
    pairwise_displacement,
    sparse_mask=reliable_pairs,
    pairwise_displacement_weight=pairwise_displacement_weight,
    convergence_method='lsqr_robust',
    lsqr_robust_n_iter=20,
    robust_regression_sigma=2,
)



In [None]:
# Compare the two global-displacement solutions over time.
fig, ax = plt.subplots()
# temporal_bins has one more entry than each motion trace (presumably bin
# edges), hence the [:-1]
ax.plot(temporal_bins[:-1], motion_gd, label="convergence_method='gradient_descent'")
ax.plot(temporal_bins[:-1], motion_sparse_lsqr, label="convergence_method='lsqr_robust'")
ax.set_xlabel('time[s]')
ax.set_ylabel('displacement[um]')
ax.legend()

## motion estimation with a single function

Internally `estimate_motion()` does:
  * make_motion_histogram()
  * compute_pairwise_displacement()
  * compute_global_displacement()
  

In [None]:
from spikeinterface.sortingcomponents.motion_estimation import estimate_motion
from spikeinterface.widgets import plot_pairwise_displacement, plot_displacement

In [None]:
# Rigid motion estimation in one call (histogram -> pairwise displacement ->
# global displacement). Note: this overwrites temporal_bins / spatial_bins.
method = 'decentralized_registration'
method_kwargs = dict(
    pairwise_displacement_method='conv',
    convergence_method='gradient_descent',
    # use the convolution backend selected earlier (torch when available),
    # as the non-rigid estimate does
    conv_engine=conv_engine,
)
# Alternatives worth trying:
#   convergence_method='lsqr_robust'
#   pairwise_displacement_method='phase_cross_correlation' with convergence_method='lsqr_robust'

motion, temporal_bins, spatial_bins, extra_check = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    bin_duration_s=5.,
    bin_um=10.,
    method=method,
    method_kwargs=method_kwargs,
    non_rigid_kwargs=None,  # None -> rigid estimate
    output_extra_check=True,
    progress_bar=True,
    verbose=False,
    upsample_to_histogram_bin=False,
)

In [None]:
# pairwise displacement diagnostics stored in extra_check by the rigid estimate (ncols=4 panel layout)
plot_pairwise_displacement(motion, temporal_bins, spatial_bins, extra_check, ncols=4)

In [None]:
# rigid motion estimate drawn together with the peak histogram (with_histogram=True)
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=True)

In [None]:
# Rigid motion estimate overlaid on the raw peak depths (black dots).
peak_times_s = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
fig, ax = plt.subplots()
ax.scatter(peak_times_s, peak_locations['y'], s=1, color='k', alpha=0.05)
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=False, ax=ax)

## motion estimation non rigid


In [None]:
# Non-rigid variant of the same decentralized method: non_rigid_kwargs splits
# the depth axis into windows (see extra_check['non_rigid_windows'] below),
# presumably giving one motion trace per window.
method = 'decentralized_registration'
method_kwargs = dict(
    pairwise_displacement_method='conv',
    convergence_method='gradient_descent',
    conv_engine=conv_engine,
    batch_size=8,
    corr_threshold=0.6,
)

motion, temporal_bins, spatial_bins, extra_check = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    method=method,
    method_kwargs=method_kwargs,
    bin_duration_s=5.,
    bin_um=5.,
    non_rigid_kwargs=dict(bin_step_um=400, sigma=3),
    margin_um=-400,
    upsample_to_histogram_bin=False,
    output_extra_check=True,
    progress_bar=True,
    verbose=False,
)


In [None]:
# Profile of each non-rigid window along the histogram's spatial bins.
fig, ax = plt.subplots()
window_depths = extra_check['spatial_hist_bins'][:-1]
for window in extra_check['non_rigid_windows']:
    ax.plot(window, window_depths)

In [None]:
# pairwise displacement diagnostics stored in extra_check by the non-rigid estimate (ncols=4 panel layout)
plot_pairwise_displacement(motion, temporal_bins, spatial_bins, extra_check, ncols=4)

In [None]:
# non-rigid motion estimate drawn together with the peak histogram (with_histogram=True)
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=True)

In [None]:
# Non-rigid motion estimate overlaid on the raw peak depths (black dots).
peak_times_s = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
fig, ax = plt.subplots()
ax.scatter(peak_times_s, peak_locations['y'], s=1, color='k', alpha=0.05)
plot_displacement(motion, temporal_bins, spatial_bins, extra_check, with_histogram=False, ax=ax)
ax.set_ylim(0, 2000)

In [None]:
# One line per column of `motion` against the estimate's temporal bins
# (assigning the result keeps the Line2D list out of the cell output).
fig, ax = plt.subplots()
motion_lines = ax.plot(temporal_bins, motion)

In [None]:
# Non-rigid motion as an image (presumably rows: spatial windows, columns:
# time bins). No extent is given, so axes are in bin indices: the `extent`
# left over from earlier cells does not describe this array.
fig, ax = plt.subplots()
im = ax.imshow(
    motion.T,
    origin='lower',
    aspect='auto',
    interpolation='nearest',
    cmap='PiYG',
)
im.set_clim(-40, 40)
ax.set_aspect('equal')
fig.colorbar(im);

## upsample motion estimate to original domain and apply motion correction to peak localizations

In [None]:
from spikeinterface.sortingcomponents.motion_correction import correct_motion_on_peaks

In [None]:
# Non-rigid estimate (method / method_kwargs as last set) with bin_step_um=300,
# upsampled to the histogram bins; this is what correct_motion_on_peaks uses.
motion_up, temporal_bins, spatial_bins_up, extra_check_up = estimate_motion(
    rec_preprocessed,
    peaks,
    peak_locations=peak_locations,
    direction='y',
    method=method,
    method_kwargs=method_kwargs,
    bin_duration_s=5.,
    bin_um=5.,
    non_rigid_kwargs=dict(bin_step_um=300, sigma=3),
    margin_um=-400,
    upsample_to_histogram_bin=True,
    output_extra_check=True,
    progress_bar=True,
    verbose=False,
)

In [None]:
# Upsampled non-rigid motion as a time x depth image.
fig, ax = plt.subplots()
# Build the extent from this estimate's own bins. The previously reused
# `extent` came from the pairwise-displacement plot (time x time), which
# mislabelled the depth axis of this image.
motion_up_extent = (temporal_bins[0], temporal_bins[-1], spatial_bins_up[0], spatial_bins_up[-1])
im = ax.imshow(
    motion_up.T,
    interpolation='nearest',
    cmap='PiYG',
    origin='lower',
    aspect='auto',  # time and depth have different units: no equal aspect
    extent=motion_up_extent,
)
im.set_clim(-40, 40)
ax.set_xlabel('time[s]')
ax.set_ylabel('depth[um]')
fig.colorbar(im);

In [None]:
# Peak times in seconds, then shift each peak's y location by the upsampled
# motion at that time/depth.
sampling_frequency = rec_preprocessed.get_sampling_frequency()
times = peaks['sample_ind'] / sampling_frequency
corrected_peak_locations = correct_motion_on_peaks(
    peaks, peak_locations, times,
    motion_up, temporal_bins, spatial_bins_up,
    direction='y',
    progress_bar=False,
)

In [None]:
# Peak depth over time BEFORE motion correction.
peak_times_s = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
fig, ax = plt.subplots()
ax.scatter(
    peak_times_s,
    peak_locations['y'],
    s=1,
    c=clip_values_for_cmap(peaks['amplitude']),
    cmap=plt.cm.plasma,
    alpha=0.25,
)
ax.set_ylim(1300, 2500)
ax.set(title="unregistered localizations", xlabel="time (s)", ylabel="depth (um)");

In [None]:
# Peak depth over time AFTER motion correction: bands should now be flat.
peak_times_s = peaks['sample_ind'] / rec_preprocessed.get_sampling_frequency()
fig, ax = plt.subplots()
ax.scatter(
    peak_times_s,
    corrected_peak_locations['y'],
    s=1,
    c=clip_values_for_cmap(peaks['amplitude']),
    cmap=plt.cm.plasma,
    alpha=0.25,
)
ax.set_ylim(1300, 2500)
ax.set(title="registered localizations", xlabel="time (s)", ylabel="depth (um)");