In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import spikeinterface.full as si
from figurl_tiled_image import TiledImage


sys.path.append("../utils/")

from subtract_templates import subtract_templates


%matplotlib widget

In [None]:
# Parallelization settings shared by every chunked SpikeInterface call below
job_kwargs = {"n_jobs": 10, "chunk_duration": "1s", "progress_bar": True}

In [None]:
# Path to the MEArec ground-truth recording. Set the MEAREC_FILE environment
# variable to point to a local copy instead of editing this absolute path.
mearec_file = os.environ.get(
    "MEAREC_FILE",
    "/home/alessio/Documents/data/mearec/recordings/recording_Neuropixels-128_300_int16.h5",
)

In [None]:
# Load the simulated recording together with its ground-truth sorting
(rec, sort) = si.read_mearec(mearec_file)

In [None]:
# Reorder the channels by increasing y coordinate (depth) so traces and images
# show the probe in spatial order.
locations = rec.get_channel_locations()
locations_order = locations[:, 1].argsort()
channel_ids_sorted = rec.channel_ids[locations_order]
rec_sorted = rec.channel_slice(channel_ids_sorted)

In [None]:
# Extract waveforms of every ground-truth unit on the depth-sorted recording.
# With load_if_exists=True an existing "mearec_sorted_wf" folder is reused.
# NOTE(review): presumably the cache is not invalidated if the channel order
# above changes; delete the folder after changing it.
we = si.extract_waveforms(
    rec_sorted,
    sort,
    folder="mearec_sorted_wf",
    load_if_exists=True,
    **job_kwargs,
)

In [None]:
# Number of samples before the spike peak in each extracted waveform; passed to
# subtract_templates below (presumably to align templates with spike times).
nbefore = we.nbefore

In [None]:
# Average template of each ground-truth unit, keyed by unit id.
# NOTE(review): templates follow extract_waveforms' return_scaled setting, while
# rec_sorted.get_traces() returns unscaled traces by default. Confirm that
# subtract_templates expects templates in the same units as the traces.
templates_dict = {unit_id: we.get_template(unit_id) for unit_id in sort.unit_ids}

In [None]:
# Template-subtracted version of rec_sorted, built by the local helper in
# ../utils/subtract_templates.py. It presumably removes each unit's template
# at every ground-truth spike time.
rec_sub = subtract_templates(
    rec_sorted, sort, templates_dict, nbefore, verbose=False
)

In [None]:
# First 32000 frames of the raw and template-subtracted recordings, for plotting
tr, tr_sub = (
    recording.get_traces(end_frame=32000) for recording in (rec_sorted, rec_sub)
)

In [None]:
# Overlay one depth-sorted channel before and after template subtraction to
# check visually how much spiking activity was removed.
fig, ax = plt.subplots()
ax.plot(tr[:, 50], label="raw")
ax.plot(tr_sub[:, 50], label="subtract")
ax.set_xlabel("sample index")
ax.set_ylabel("amplitude")
ax.set_title("Channel 50 (depth-sorted): raw vs template-subtracted")
ax.legend();

## Compare per-channel RMS before and after template subtraction

In [None]:
# Random data chunks (scaled traces) from both recordings, used for the RMS comparison
chunks_raw, chunks_sub = (
    si.get_random_data_chunks(recording, return_scaled=True)
    for recording in (rec_sorted, rec_sub)
)

In [None]:
def compute_rms(traces):
    """Return the root-mean-square of ``traces`` along axis 0 (one value per channel).

    Squares are accumulated in float64. This avoids integer overflow if unscaled
    int16 traces are passed and precision loss with float32 input.
    """
    traces = np.asarray(traces, dtype="float64")
    return np.sqrt(np.mean(traces ** 2, axis=0))


rms_raw = compute_rms(chunks_raw)
rms_sub = compute_rms(chunks_sub)

In [None]:
# Distribution of per-channel RMS before and after template subtraction. Removing
# spikes is expected to shift the distribution toward the noise-only level.
fig, ax = plt.subplots()
sns.kdeplot(rms_raw, ax=ax, cut=0, label="raw")
sns.kdeplot(rms_sub, ax=ax, cut=0, label="subtract")
ax.set_xlabel("per-channel RMS (scaled units, µV)")
ax.set_ylabel("density")
ax.set_title("Per-channel RMS: raw vs template-subtracted")
ax.legend();

## Figurl visualization

In [None]:
def convert_and_upload(processing_steps, labels, start_frame, num_samples, colormap="PRGn",
                       color_ranges=None, default_color_range=250,
                       num_timepoints_per_row=6000.):
    """Render a window of several recordings as layers of a figurl TiledImage and upload it.

    Parameters
    ----------
    processing_steps : sequence
        Recording-like objects. Each must provide
        ``get_traces(start_frame=..., end_frame=...)``.
    labels : sequence of str
        Layer name for each entry of ``processing_steps``, in the same order.
    start_frame : int
        First frame of the rendered window.
    num_samples : int
        Number of frames rendered, starting at ``start_frame``.
    colormap : str
        Colormap name forwarded to ``spikeinterface.widgets.array_to_image``.
    color_ranges : dict or None
        Optional mapping ``label -> color_range``. Labels missing from the mapping
        use ``default_color_range``. When None, a layer labelled 'Centered' uses
        5000 and every other layer uses ``default_color_range`` (original behaviour).
    default_color_range : float
        Color range for labels not found in ``color_ranges``.
    num_timepoints_per_row : float
        Forwarded to ``array_to_image``.

    Returns
    -------
    The URL returned by ``TiledImage.url``.

    Raises
    ------
    ValueError
        If ``processing_steps`` and ``labels`` have different lengths. Previously
        ``zip`` silently dropped the extra entries.
    """
    processing_steps = list(processing_steps)
    labels = list(labels)
    if len(processing_steps) != len(labels):
        raise ValueError(
            f"got {len(processing_steps)} processing steps but {len(labels)} labels"
        )
    if color_ranges is None:
        color_ranges = {'Centered': 5000}

    # Imported lazily, as in the original, so the notebook only needs the widgets
    # module when an upload actually happens.
    import spikeinterface.widgets as sw

    tiled_image = TiledImage(tile_size=512)
    end_frame = start_frame + num_samples

    for step, label in zip(processing_steps, labels):
        print('Processing ' + label)
        arr = step.get_traces(start_frame=start_frame, end_frame=end_frame)

        img = sw.array_to_image(arr,
                                color_range=color_ranges.get(label, default_color_range),
                                num_timepoints_per_row=num_timepoints_per_row,
                                colormap=colormap)
        tiled_image.add_layer(label, img)

    return tiled_image.url(label='SpikeInterface TiledImage example')

In [None]:
# Recordings and their layer labels for the figurl tiled image. The rendered
# window starts at frame 32000 and spans 32000 frames.
processing_steps, labels = [rec_sorted, rec_sub], ["raw", "subtract"]
start_frame = num_samples = 32000

In [None]:
# To skip re-uploading, set `url` to a previously generated figurl link, e.g.
# "https://www.figurl.org/f?v=gs://figurl/figurl-tiled-image-2&d=sha1://079d0446768eaae0cfb7dc74b8765b9ac4523fea&label=SpikeInterface%20TiledImage%20example"
url = None

In [None]:
# Upload only when no cached URL was provided in the previous cell
url = url if url is not None else convert_and_upload(
    processing_steps, labels, start_frame, num_samples, colormap="viridis"
)

In [None]:
# Show the figurl link (plain text, no quotes)
print(f"{url}")

Sample URL (generated by a previous run of this notebook):
https://figurl.org/f?v=gs://figurl/figurl-tiled-image-2&d=sha1://da2f0d1798758ed4379b1dcbb0d3d504664b4fdd&label=SpikeInterface%20TiledImage%20example

## Detect peaks

In [None]:
from spikeinterface.sortingcomponents.peak_detection import detect_peaks

In [None]:
# Noise levels in *unscaled* units. In SpikeInterface, get_noise_levels defaults
# to return_scaled=True (µV), whereas detect_peaks thresholds the unscaled traces
# it reads from the recording. Passing scaled noise would mis-set thresholds by
# the gain factor. The raw-recording noise is reused for rec_sub below so both
# detections share identical thresholds.
noise_levels = si.get_noise_levels(rec_sorted, return_scaled=False)

In [None]:
# Peaks detected on the depth-sorted raw recording
peaks_raw = detect_peaks(rec_sorted, **job_kwargs, noise_levels=noise_levels)

In [None]:
# Peaks on the template-subtracted recording, using the raw recording's noise levels
peaks_sub = detect_peaks(rec_sub, **job_kwargs, noise_levels=noise_levels)

In [None]:
# Ratio of peaks detected after vs. before template subtraction (same thresholds).
# peaks_raw are peaks detected on the raw recording, not ground-truth spikes.
if len(peaks_raw) == 0:
    print("No peaks detected on the raw recording; cannot compute the remaining fraction.")
else:
    print(
        "Fraction of detected peaks remaining after template subtraction "
        f"(subtracted / raw): {len(peaks_sub) / len(peaks_raw)}"
    )

In [None]:
# Peak-amplitude histograms before and after subtraction. They use shared bin
# edges over the displayed window so the two distributions are directly comparable.
amplitude_bins = np.linspace(-300, 0, 101)
fig, ax = plt.subplots()
ax.hist(peaks_raw["amplitude"], bins=amplitude_bins, alpha=0.6, label="raw")
ax.hist(peaks_sub["amplitude"], bins=amplitude_bins, alpha=0.6, label="subtract")
ax.set_xlim(-300, 0)
ax.set_xlabel("peak amplitude (unscaled traces)")
ax.set_ylabel("count")
ax.set_title("Detected peak amplitudes: raw vs template-subtracted")
ax.legend();