In [3]:
from pathlib import Path
import medicine
import numpy as np
import time
import os

from si_utils import *

from spikeinterface.core import set_global_job_kwargs, get_global_job_kwargs
from spikeinterface.core.motion import Motion
from spikeinterface.sortingcomponents.motion.motion_interpolation import InterpolateMotionRecording
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.extractors import read_spikeglx, get_neo_streams

def si_preprocess_with_medicine(spikeglx_folder, IMEC=0, n_jobs=40, chunk_duration='1s'):
    """Estimate drift with MEDiCINe, apply it to one SpikeGLX probe, then run Kilosort4.

    Steps:
      1. Resolve output folders via ``make_folder_paths`` and read the raw
         ``imec{IMEC}.ap`` stream with ``read_spikeglx``.
      2. Unless MEDiCINe outputs (``motion.npy``, ``time_bins.npy``,
         ``depth_bins.npy``) are already present, detect and localize peaks and
         run MEDiCINe to estimate motion.
      3. Load the motion estimate into a ``Motion`` object and wrap the
         recording in an ``InterpolateMotionRecording``.
      4. Run Kilosort4 on the corrected recording with its own drift
         correction disabled (``do_correction=False``).

    Parameters
    ----------
    spikeglx_folder : str or Path
        Folder containing the SpikeGLX run.
    IMEC : int, default 0
        Probe index; selects the ``imec{IMEC}.ap`` stream.
    n_jobs : int, default 40
        Number of parallel workers written into SpikeInterface's global job
        kwargs. Lower this if a run fails with ``BrokenProcessPool``. That error
        means a worker process terminated abnormally. NOTE(review): memory
        exhaustion with many workers is a common cause; confirm against the
        job's memory limit.
    chunk_duration : str, default '1s'
        Chunk length handed to each worker.

    Returns
    -------
    The sorting object returned by ``run_sorter``.
    """
    # Explicit import: `si` is not defined anywhere visible in this notebook
    # (it presumably leaks in via `from si_utils import *`).
    from spikeinterface.sorters import run_sorter

    start_time = time.time()

    global_job_kwargs = dict(n_jobs=n_jobs, chunk_duration=chunk_duration, progress_bar=True)
    print(get_global_job_kwargs())
    # This changes SpikeInterface's process-wide defaults; they remain in
    # effect after this function returns.
    set_global_job_kwargs(**global_job_kwargs)

    print("Step 1: Setting up folders and reading raw recording...")
    medicine_output_dir, preprocess_folder, _figs_folder = make_folder_paths(spikeglx_folder, IMEC)
    medicine_output_dir = Path(medicine_output_dir)
    print(preprocess_folder)

    # SpikeInterface recording object to motion-correct.
    recording = read_spikeglx(spikeglx_folder, stream_name=f'imec{IMEC}.ap', load_sync_channel=False)

    motion_path = medicine_output_dir / 'motion.npy'
    time_bins_path = medicine_output_dir / 'time_bins.npy'
    depth_bins_path = medicine_output_dir / 'depth_bins.npy'

    # Check for the actual output files, not just the directory. A MEDiCINe
    # run that crashed after creating the directory would otherwise be
    # "cached" forever and make the np.load calls below fail.
    have_motion = all(p.is_file() for p in (motion_path, time_bins_path, depth_bins_path))

    if not have_motion:
        print("Step 2: Estimating motion correction...")

        # Detect and localize peaks; MEDiCINe consumes their amplitudes,
        # depths and times.
        peaks = detect_peaks(recording, method="locally_exclusive")
        peak_locations = localize_peaks(recording, peaks, method="monopolar_triangulation")

        medicine_output_dir.mkdir(parents=True, exist_ok=True)

        # NOTE(review): sample_index / fs assumes a single-segment recording
        # starting at t=0; confirm for multi-segment or offset recordings.
        medicine.run_medicine(
            peak_amplitudes=peaks['amplitude'],
            peak_depths=peak_locations['y'],
            peak_times=peaks['sample_index'] / recording.get_sampling_frequency(),
            output_dir=medicine_output_dir,
        )
    else:
        print(f"Step 2: Found existing MEDiCINe outputs in {medicine_output_dir}, skipping estimation...")

    # Load the motion estimated by MEDiCINe.
    motion = np.load(motion_path)
    time_bins = np.load(time_bins_path)
    depth_bins = np.load(depth_bins_path)

    print("Step 3: Applying motion correction...")
    motion_object = Motion(
        displacement=motion,
        temporal_bins_s=time_bins,
        spatial_bins_um=depth_bins,
    )
    recording_motion_corrected = InterpolateMotionRecording(
        recording,
        motion_object,
        border_mode='force_extrapolate',
    )

    print("Step 4: Running kilosort...")
    params_kilosort4 = {
        # Drift was already corrected above; keep Kilosort from correcting again.
        'do_correction': False,
        'bad_channels': None,  # would need to change if we choose not to delete bad channels
    }

    sorting = run_sorter(
        'kilosort4',
        recording_motion_corrected,
        folder=preprocess_folder.parent / 'kilosort4_preprocess',
        remove_existing_folder=True,
        docker_image=False,
        verbose=True,
        **params_kilosort4,
    )

    elapsed = time.time() - start_time
    print(f"Done! Drift-corrected kilosort ran successfully in {elapsed:.1f} s.")
    return sorting


In [4]:
# Session to process: SpikeGLX run name (IN1) and IMEC probe index (IN2).
IN1 = 'kendra_scrappy_0124a_g0'
IN2 = 0

# Root directory that holds every raw SpikeGLX session on the cluster.
DATA_ROOT = "/ix1/pmayo/lab_NHPdata"
spikeglx_folder = DATA_ROOT + "/" + IN1 + "/"

si_preprocess_with_medicine(spikeglx_folder, IMEC=IN2)

{'pool_engine': 'process', 'n_jobs': 40, 'chunk_duration': '1s', 'progress_bar': True, 'mp_context': None, 'max_threads_per_worker': 1}
Step 1: Setting up folders and reading raw recording...
/ix1/pmayo/lab_NHPdata/kendra_scrappy_0124a_g0/kendra_scrappy_0124a_g0_imec0/preprocess
Step 2: Estimating motion correction...


noise_level (workers: 20 processes):   0%|          | 0/20 [00:00<?, ?it/s]

detect peaks using locally_exclusive (workers: 40 processes):   0%|          | 0/6273 [00:00<?, ?it/s]

localize peaks using monopolar_triangulation (workers: 40 processes):   0%|          | 0/6273 [00:00<?, ?it/s]

Exception in thread Thread-8:
Traceback (most recent call last):
  File "/ihome/pmayo/knoneman/.conda/envs/kilosort/lib/python3.10/threading.py", line 1016, in _bootstrap_inner
    self.run()
  File "/ihome/pmayo/knoneman/.conda/envs/kilosort/lib/python3.10/concurrent/futures/process.py", line 323, in run
    self.terminate_broken(cause)
  File "/ihome/pmayo/knoneman/.conda/envs/kilosort/lib/python3.10/concurrent/futures/process.py", line 463, in terminate_broken
    work_item.future.set_exception(bpe)
  File "/ihome/pmayo/knoneman/.conda/envs/kilosort/lib/python3.10/concurrent/futures/_base.py", line 561, in set_exception
    raise InvalidStateError('{}: {!r}'.format(self._state, self))
concurrent.futures._base.InvalidStateError: CANCELLED: <Future at 0x7f43d49d70a0 state=cancelled>


BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.