## Objective
Detect sequences in the pre-recording and save them to disk.

This is the same code as detect_sequences.ipynb, but the cells are separated, so it is easier to test.

## Globals setup

In [50]:
# Set the thread-count environment variables for MKL, NumExpr and OpenMP to 1.
# NOTE(review): these variables are normally read when the numerical libraries are
# first loaded, so this cell presumably has to run before the import cell — confirm.
# Keep comments on their own lines: text after the value on a %env line becomes part of the value.
%env MKL_NUM_THREADS=1
%env NUMEXPR_NUM_THREADS=1
%env OMP_NUM_THREADS=1

env: MKL_NUM_THREADS=1
env: NUMEXPR_NUM_THREADS=1
env: OMP_NUM_THREADS=1
4


In [2]:
%load_ext autoreload

In [3]:
from copy import deepcopy
from importlib import reload
from multiprocessing import Pool
import shutil
from pathlib import Path
import pickle

import matplotlib.pyplot as plt
import numpy as np
import torch

from spikeinterface.extractors import MaxwellRecordingExtractor
from tqdm import tqdm

%autoreload 2
from braindance.core.spikesorter.manuscript_code import utils
from braindance.core.spikesorter.manuscript_code import si_rec13 as F  # This forces you to manually reload every time modification happens (prevents forgetfulness errors)
# from src.sorters.base import Unit

from braindance.core.spikedetector.model2 import ModelSpikeSorter

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
# Open the Maxwell recording and cache the properties the rest of the notebook uses.
RECORDING_PATH = "data.raw.h5"
RECORDING = MaxwellRecordingExtractor(RECORDING_PATH)

sampling_frequency_hz = RECORDING.get_sampling_frequency()
SAMP_FREQ = round(sampling_frequency_hz / 1000)  # kHz, i.e. frames per millisecond
NUM_ELECS = RECORDING.get_num_channels()
ELEC_LOCS = RECORDING.get_channel_locations()

# Unit sanity check: a value above 35 means the frequency was not converted to kHz.
assert SAMP_FREQ <= 35, "SAMP_FREQ must be in kHz"

# Only 20 kHz and 30 kHz pass silently; any other rate triggers a reminder to adjust the buffers.
if SAMP_FREQ != 20 and SAMP_FREQ != 30:
    print("NEED TO CHANGE FRONT_BUFFER AND OUTPUT_WINDOW_HALF_SIZE TO MODEL'S VALUES")

In [5]:
# Pick the time window (in ms) used for sequence detection.
#   TRAINING_MS        -> window relative to the recording
#   TRACES_TRAINING_MS -> window relative to scaled_traces (starts at 50 ms)
recording_is_long = RECORDING.get_total_duration() >= 5 * 60  # at least five minutes
if not recording_is_long:
    # Short recording: use all of it.
    TRAINING_MS = (0, RECORDING.get_total_duration() * 1000)
    TRACES_TRAINING_MS = (50, RECORDING.get_total_duration() * 1000)
else:
    # Long recording: keep only the last five minutes.
    num_samples = RECORDING.get_total_samples()
    training_duration_ms = num_samples / RECORDING.get_sampling_frequency() * 1000
    five_minutes_ms = 5 * 60 * 1000
    TRAINING_MS = (training_duration_ms - five_minutes_ms, training_duration_ms)
    TRACES_TRAINING_MS = (50, five_minutes_ms)
    print("Pre-recording is greater than five minutes. Using last five minutes to detect sequences")

# Placeholder: no testing window is used for patch recordings.
TESTING_MS = (-1, -1)
# TESTING_MS = (training_duration_ms, RECORDING.get_total_duration() * 1000)  # 5 min to 10 min in recording (in ms)

Pre-recording is greater than five minutes. Using last five minutes to detect sequences


In [6]:
# ---- User inputs: output folder, model folder, detection parameters ----
ROOT_PATH = Path("/data/MEAprojects/BrainDance/braindance/core/data/test_rt_sort2")
ROOT_PATH_MODEL = ROOT_PATH.joinpath("dl_model")
MODEL_PATH = Path("/data/MEAprojects/BrainDance/braindance/core/spikedetector/model_0_4_4_5118")

# Thresholds on the model output, with the value F.sigmoid_inverse returns for each.
STRINGENT_THRESH, LOOSE_THRESH = 0.275, 0.1
STRINGENT_THRESH_LOGIT = F.sigmoid_inverse(STRINGENT_THRESH)
LOOSE_THRESH_LOGIT = F.sigmoid_inverse(LOOSE_THRESH)

INFERENCE_SCALING_NUMERATOR = 12.6

# Window sizes in frames (SAMP_FREQ is in kHz, i.e. frames per millisecond).
FRONT_BUFFER = round(SAMP_FREQ * 2)
OUTPUT_WINDOW_HALF_SIZE = round(SAMP_FREQ * 3)
PRE_MEDIAN_FRAMES = round(SAMP_FREQ * 50)

## No user inputs below
for out_dir in (ROOT_PATH, ROOT_PATH_MODEL):
    out_dir.mkdir(parents=True, exist_ok=True)

# Intermediate files, all stored in the model sub-folder.
SCALED_TRACES_PATH = ROOT_PATH_MODEL.joinpath("scaled_traces.npy")

MODEL_TRACES_PATH = ROOT_PATH_MODEL.joinpath("model_traces.npy")
MODEL_OUTPUTS_PATH = ROOT_PATH_MODEL.joinpath("model_outputs.npy")

ALL_CROSSINGS_PATH = ROOT_PATH_MODEL.joinpath("all_crossings.npy")
ELEC_CROSSINGS_IND_PATH = ROOT_PATH_MODEL.joinpath("elec_crossings_ind.npy")

print(ROOT_PATH, ROOT_PATH_MODEL, sep="\n")

/data/MEAprojects/BrainDance/braindance/core/data/test_rt_sort2
/data/MEAprojects/BrainDance/braindance/core/data/test_rt_sort2/dl_model


In [37]:
# Copy the recording-level settings onto the helper module F as module globals.
for attr_name, attr_value in {
    "RECORDING": RECORDING,
    "NUM_ELECS": NUM_ELECS,
    "ELEC_LOCS": ELEC_LOCS,
    "SAMP_FREQ": SAMP_FREQ,
    "FRONT_BUFFER": FRONT_BUFFER,
    "INFERENCE_SCALING_NUMERATOR": INFERENCE_SCALING_NUMERATOR,
    "PRE_MEDIAN_FRAMES": PRE_MEDIAN_FRAMES,
}.items():
    setattr(F, attr_name, attr_value)

In [8]:
# For RT-Sort manuscript: measure time to detect sequences
import time

class Stopwatch:
    """Accumulating timer: sums the elapsed time of every start()/stop() pair.

    Attributes:
        duration: total seconds accumulated over all completed start/stop pairs.
        start_time: ``time.perf_counter()`` reading taken by the latest start().
            It is only meaningful as a difference, not as a wall-clock timestamp.
    """

    def __init__(self):
        self.duration = 0
        self.start_time = 0
        # Tracks whether an interval is currently open, so stop() cannot
        # count an interval that was never started or count one twice.
        self._running = False

    def start(self):
        """Open a new timing interval (restarts the interval if one is already open)."""
        # perf_counter is monotonic; time.time() can jump if the system clock is adjusted.
        self.start_time = time.perf_counter()
        self._running = True

    def stop(self):
        """Close the open interval and add its length to ``duration``.

        Does nothing when no interval is open (stop() before start(), or a
        second stop() in a row), so ``duration`` is never inflated.
        """
        stop_time = time.perf_counter()
        if not self._running:
            return
        self.duration += stop_time - self.start_time
        self._running = False

stopwatch = Stopwatch()

In [10]:
# Runs the four timed stages that precede sequence detection. Each stage is
# wrapped in its own start()/stop() pair on the shared `stopwatch`.

# Stage 1: F.save_traces_mea_new is given the recording path, SCALED_TRACES_PATH and the
# TRAINING_MS window (presumably writes the traces of that window to SCALED_TRACES_PATH — confirm).
stopwatch.start()
F.save_traces_mea_new(RECORDING_PATH, SCALED_TRACES_PATH, start_ms=TRAINING_MS[0], end_ms=TRAINING_MS[1])
stopwatch.stop()

# Stage 2: load the detection model from MODEL_PATH and compile it for NUM_ELECS electrodes.
stopwatch.start()
model = ModelSpikeSorter.load(MODEL_PATH)
model.compile(NUM_ELECS, MODEL_PATH)
stopwatch.stop()

# Stage 3: run the model; F.run_dl_model receives the scaled-traces path plus the
# MODEL_TRACES_PATH and MODEL_OUTPUTS_PATH locations (presumably its outputs — confirm).
stopwatch.start()
F.run_dl_model(MODEL_PATH, SCALED_TRACES_PATH, MODEL_TRACES_PATH, MODEL_OUTPUTS_PATH)
stopwatch.stop()

# Stage 4: set the module globals on F, then call F.extract_crossings
# (presumably it reads them; NUM_ELECS, SAMP_FREQ and FRONT_BUFFER are also set in an earlier cell).
stopwatch.start()
F.NUM_ELECS = NUM_ELECS
F.SAMP_FREQ = SAMP_FREQ
F.FRONT_BUFFER = FRONT_BUFFER
F.STRINGENT_THRESH = STRINGENT_THRESH
F.STRINGENT_THRESH_LOGIT = STRINGENT_THRESH_LOGIT
F.extract_crossings(MODEL_OUTPUTS_PATH, ALL_CROSSINGS_PATH,
                    ELEC_CROSSINGS_IND_PATH)
stopwatch.stop()
# stopwatch.duration accumulates across every start()/stop() pair, so this figure
# covers all four stages above (trace extraction and model loading included).
print(f"Time to run DL model: {stopwatch.duration} seconds")

# Sanity check that there are stringent detections
print(len(np.load(ALL_CROSSINGS_PATH, allow_pickle=True)))

Alllocating memory for traces ...
Extracting traces ...


100%|██████████| 60/60 [00:08<00:00,  7.13it/s]


Loading DL model ...
Allocating memory to save model traces and outputs ...
Inference scaling: 1.8
Running model ...


100%|██████████| 49999/49999 [01:46<00:00, 467.48it/s]
100%|██████████| 5999/5999 [00:47<00:00, 125.54it/s]


Time to run DL model: 193.61896872520447 seconds
3014319


In [9]:
stopwatch.start()

# No user inputs here. Run after running DL model

# ALL_CLOSEST_ELECS[elec] holds the indices of every *other* electrode, ordered from
# nearest to farthest (Euclidean distance on the x/y electrode locations).
# Distances are computed one electrode at a time with NumPy instead of a Python loop
# over every electrode pair; values and ordering match the pairwise formula
# sqrt((x2 - x1)**2 + (y2 - y1)**2).
elec_locs_arr = np.asarray(ELEC_LOCS)
elec_xs = elec_locs_arr[:, 0]
elec_ys = elec_locs_arr[:, 1]
all_elec_ind = np.arange(NUM_ELECS)  # NUM_ELECS is used for both loops for consistency

ALL_CLOSEST_ELECS = []
for elec in range(NUM_ELECS):
    other_elecs = all_elec_ind[all_elec_ind != elec]  # skip the electrode itself
    dists = np.sqrt((elec_xs[other_elecs] - elec_xs[elec])**2
                    + (elec_ys[other_elecs] - elec_ys[elec])**2)
    ALL_CLOSEST_ELECS.append(other_elecs[np.argsort(dists)])

# Memory-map the large arrays (read-only) rather than loading them into RAM.
TRACES = np.load(MODEL_TRACES_PATH, mmap_mode="r")
FILT_TRACES = np.load(SCALED_TRACES_PATH, mmap_mode="r")  # called FILT_TRACES, but these are not actually filtered
OUTPUTS = np.load(MODEL_OUTPUTS_PATH, mmap_mode="r")
# These two files are loaded with allow_pickle=True; only load files produced by this pipeline.
ALL_CROSSINGS = np.load(ALL_CROSSINGS_PATH, allow_pickle=True)
ELEC_CROSSINGS_IND = np.load(ELEC_CROSSINGS_IND_PATH, allow_pickle=True)

# Convert each row to a tuple.
ALL_CROSSINGS = [tuple(cross) for cross in ALL_CROSSINGS]
ELEC_CROSSINGS_IND = [tuple(ind) for ind in ELEC_CROSSINGS_IND]  # [(elec's cross times ind in all_crossings)]

stopwatch.stop()

In [39]:
# Reload the helper module, then set every global on it again.
reload(F)

codetection_window = round(0.5 * SAMP_FREQ)  # Window for looking for electrode codetections
min_activity = 0.05 * (TRAINING_MS[1] - TRAINING_MS[0]) / 1000

module_globals = {
    # Recording and thresholds
    "RECORDING": RECORDING,
    "MEA": True,
    "STRINGENT_THRESH": STRINGENT_THRESH,
    "STRINGENT_THRESH_LOGIT": STRINGENT_THRESH_LOGIT,
    "LOOSE_THRESH": LOOSE_THRESH,
    "LOOSE_THRESH_LOGIT": LOOSE_THRESH_LOGIT,
    "INFERENCE_SCALING_NUMERATOR": INFERENCE_SCALING_NUMERATOR,
    # Probe description (CHANS_RMS is intentionally not set)
    "SAMP_FREQ": SAMP_FREQ,
    "NUM_ELECS": NUM_ELECS,
    "ELEC_LOCS": ELEC_LOCS,
    "ALL_CLOSEST_ELECS": ALL_CLOSEST_ELECS,
    # Window sizes
    "FRONT_BUFFER": FRONT_BUFFER,
    "OUTPUT_WINDOW_HALF_SIZE": OUTPUT_WINDOW_HALF_SIZE,
    "N_BEFORE": codetection_window,
    "N_AFTER": codetection_window,
    "MIN_ELECS_FOR_ARRAY_NOISE": max(100, round(0.1 * NUM_ELECS)),
    "MIN_ELECS_FOR_SEQ_NOISE": max(50, round(0.05 * NUM_ELECS)),
    "PRE_MEDIAN_FRAMES": PRE_MEDIAN_FRAMES,
    "MIN_ACTIVITY": min_activity,
    # If doing on new recording, these should be set after ## Full run - DL model
    "TRACES": TRACES,
    "OUTPUTS": OUTPUTS,
    "ALL_CROSSINGS": ALL_CROSSINGS,
    "ELEC_CROSSINGS_IND": ELEC_CROSSINGS_IND,
    # Different parameters for MEA
    "MIN_AMP_DIST_P": -1,
    "MAX_AMP_MEDIAN_DIFF_SPIKES": 0.65,
    "MAX_AMP_MEDIAN_DIFF_SEQUENCES": 0.65,
    "MAX_LATENCY_DIFF_SPIKES": 3.5,
    "MAX_LATENCY_DIFF_SEQUENCES": 3.5,
    "CLIP_LATENCY_DIFF": 7,
    "CLIP_AMP_MEDIAN_DIFF": 1.3,
    "MAX_ROOT_AMP_MEDIAN_STD_SPIKES": 2.5,
    "MAX_ROOT_AMP_MEDIAN_STD_SEQUENCES": np.inf,
}
for global_name, global_value in module_globals.items():
    setattr(F, global_name, global_value)

In [21]:
# Sequence detection: form clusters, reassign spikes, merge, then save the results.
stopwatch.start()

# Spike-count threshold passed to reassign_spikes and inter_merge:
# the larger of 10 and 0.05 per second of the training window (TRAINING_MS is in ms).
MIN_SPIKES = max(10, 0.05 * (TRAINING_MS[1] - TRAINING_MS[0]) / 1000)

##
all_clusters = F.form_all_clusters(TRACES_TRAINING_MS)
# utils.pickle_dump(all_clusters, ROOT_PATH / "all_clusters.pickle")
# all_clusters = utils.pickle_load(ROOT_PATH / "all_clusters.pickle")

all_clusters_reassigned = F.reassign_spikes(all_clusters, TRACES_TRAINING_MS, MIN_SPIKES)
# utils.pickle_dump(all_clusters_reassigned, ROOT_PATH / "all_clusters_reassigned.pickle")
# all_clusters_reassigned = utils.pickle_load(ROOT_PATH / "all_clusters_reassigned.pickle")

# Two merge passes: intra_merge first, then inter_merge on its result.
intra_merged_clusters = F.intra_merge(all_clusters_reassigned) 
trained_sequences = F.inter_merge(intra_merged_clusters, MIN_SPIKES)
# utils.pickle_dump(trained_sequences, ROOT_PATH / "trained_sequences.pickle")
# trained_sequences = utils.pickle_load(ROOT_PATH / "trained_sequences.pickle")
stopwatch.stop()
# NOTE(review): stopwatch.duration is cumulative over every earlier start()/stop() pair on
# this stopwatch, so the value printed here is not the detection step alone — confirm intent.
print(f"Time to detect sequences: {stopwatch.duration} seconds")

# Save data
# Pickling happens after stopwatch.stop(), so it is not included in the timing.
utils.pickle_dump(all_clusters, ROOT_PATH / "all_clusters.pickle")
utils.pickle_dump(all_clusters_reassigned, ROOT_PATH / "all_clusters_reassigned.pickle")
utils.pickle_dump(trained_sequences, ROOT_PATH / "trained_sequences.pickle")

# Alias for the final sequences.
SEQUENCES = trained_sequences


100%|██████████| 772/772 [01:11<00:00, 10.83it/s]


162 sequences before merging


100%|██████████| 59980/59980 [02:01<00:00, 494.07it/s]
100%|██████████| 134/134 [00:10<00:00, 12.72it/s]
100%|██████████| 42/42 [00:00<00:00, 47.04it/s]


46 sequences after first merging

Merged 29 with 41
Latency diff: 0.14. Amp median diff: 0.08
Amp dist p-value 0.3882
#spikes:
Merge base: 242, Add: 363, Overlaps: 1
After merging: 604

Merged 45 with 43
Latency diff: 0.18. Amp median diff: 0.12
Amp dist p-value 0.8174
#spikes:
Merge base: 751, Add: 478, Overlaps: 5
After merging: 1226

Merged 44 with [45, 43]
Latency diff: 0.25. Amp median diff: 0.09
Amp dist p-value 0.1826
#spikes:
Merge base: 781, Add: 1226, Overlaps: 12
After merging: 1988

Merged [29, 41] with 39
Latency diff: 0.24. Amp median diff: 0.12
Amp dist p-value 0.6890
#spikes:
Merge base: 604, Add: 115, Overlaps: 8
After merging: 711

Merged 31 with 35
Latency diff: 0.36. Amp median diff: 0.15
Amp dist p-value 2.6953
#spikes:
Merge base: 19, Add: 168, Overlaps: 0
After merging: 187

Merged 24 with 26
Latency diff: 0.76. Amp median diff: 0.08
Amp dist p-value 0.2090
#spikes:
Merge base: 126, Add: 65, Overlaps: 29
After merging: 163

Merged 33 with 23
Latency diff: 0.45. A

In [16]:
# Load the detection model from MODEL_PATH again; this rebinds `model` and, unlike the
# DL-model cell, does not call model.compile here.
model = ModelSpikeSorter.load(MODEL_PATH)

In [24]:
# Restore the sequences written to ROOT_PATH / "trained_sequences.pickle" by the detection cell.
# NOTE(review): presumably unpickles the file — only load pickles produced by this pipeline.
trained_sequences = utils.pickle_load(ROOT_PATH / "trained_sequences.pickle")

In [35]:
# Set global variables in .py
# This repeats the assignments of the earlier "Set global variables" cell: the module is
# reloaded, so every global is set on it again before RTSort is built.
reload(F)

F.RECORDING = RECORDING
F.MEA = True
F.STRINGENT_THRESH = STRINGENT_THRESH
F.STRINGENT_THRESH_LOGIT = STRINGENT_THRESH_LOGIT
F.LOOSE_THRESH = LOOSE_THRESH
F.LOOSE_THRESH_LOGIT = LOOSE_THRESH_LOGIT
F.INFERENCE_SCALING_NUMERATOR = INFERENCE_SCALING_NUMERATOR

# F.CHANS_RMS = CHANS_RMS
F.SAMP_FREQ = SAMP_FREQ
F.NUM_ELECS = NUM_ELECS
F.ELEC_LOCS = ELEC_LOCS

F.ALL_CLOSEST_ELECS = ALL_CLOSEST_ELECS

F.FRONT_BUFFER = FRONT_BUFFER
F.OUTPUT_WINDOW_HALF_SIZE = OUTPUT_WINDOW_HALF_SIZE

# Window for looking for electrode codetections
F.N_BEFORE = F.N_AFTER = round(0.5 * SAMP_FREQ)
F.MIN_ELECS_FOR_ARRAY_NOISE = max(100, round(0.1 * NUM_ELECS))
F.MIN_ELECS_FOR_SEQ_NOISE = max(50, round(0.05 * NUM_ELECS))
F.PRE_MEDIAN_FRAMES = PRE_MEDIAN_FRAMES

# 0.05 per second of the training window (TRAINING_MS is in ms)
F.MIN_ACTIVITY = 0.05 * (TRAINING_MS[1] - TRAINING_MS[0]) / 1000

# If doing on new recording, these should be set after ## Full run - DL model
F.TRACES = TRACES
F.OUTPUTS = OUTPUTS
F.ALL_CROSSINGS = ALL_CROSSINGS
F.ELEC_CROSSINGS_IND = ELEC_CROSSINGS_IND

# Different parameters for MEA
F.MIN_AMP_DIST_P = -1
F.MAX_AMP_MEDIAN_DIFF_SPIKES = F.MAX_AMP_MEDIAN_DIFF_SEQUENCES = 0.65
F.MAX_LATENCY_DIFF_SPIKES = F.MAX_LATENCY_DIFF_SEQUENCES = 3.5
F.CLIP_LATENCY_DIFF = 7
F.CLIP_AMP_MEDIAN_DIFF = 1.3
F.MAX_ROOT_AMP_MEDIAN_STD_SPIKES = 2.5
F.MAX_ROOT_AMP_MEDIAN_STD_SEQUENCES = np.inf
##


# Build the RTSort object from the trained sequences, the loaded model and the
# scaled-traces path, then save it under ROOT_PATH.
# Relies on `model` and `trained_sequences` set by the two preceding cells.
rt_sort = F.RTSort(trained_sequences, model, SCALED_TRACES_PATH)
rt_sort.save(ROOT_PATH / "rt_sort.pickle")

In [40]:
# Call F.assign_spikes_torch on the trained sequences with return_spikes=True; the result is
# iterated as one spike train per entry in the following cells.
# NOTE(review): the meaning of the second positional argument (None) is not visible here — confirm.
all_spike_trains = F.assign_spikes_torch(trained_sequences, None, return_spikes=True)

100%|██████████| 59990/59990 [01:51<00:00, 537.21it/s]


In [46]:
# Keep only the spike times below 60 * 1000 in each train
# (presumably the first 60 s, with times in ms — confirm).
trunc_trains = [train[train < 60 * 1000] for train in all_spike_trains]

In [47]:
# Total number of spikes across all truncated trains.
total = sum(len(trains) for trains in trunc_trains)
total

1096

In [49]:
# Save the truncated spike trains as a 1-D object array with one entry per train.
# Filling a pre-allocated object array keeps the layout stable: np.array(list, dtype=object)
# would instead build a 2-D array whenever all trains happen to have the same length.
spike_trains_obj = np.empty(len(trunc_trains), dtype=object)
for train_idx, spike_train in enumerate(trunc_trains):
    spike_trains_obj[train_idx] = spike_train
np.save(ROOT_PATH / "test_spike_trains.npy", spike_trains_obj)