In [2]:
import spikeinterface.full as si
import numpy as np
import pickle

import submitit
from memory_profiler import memory_usage
import time
import shutil
import os

import asyncio
import gc


# Most recently submitted job (None until one has been submitted).
# checkRessources() reads it to print that job's sacct statistics.
last_job=None

In [3]:
# Length, in minutes, of the leading part of the recording that GenerateDict
# slices out and sorts.
duration_extract = 2  # minutes

In [4]:
# Known recordings. Only `file` is read by the rest of the notebook; the other
# paths keep their own names instead of being assigned to `file` and then
# silently overwritten.

# Stabilisation period.
file_stabilisation = "/crnldata/waking/audrey_hay/NPX/NPX1/VB/Expe_2024-07-22_17-29-40/NP_spikes_2024-07-22T17_29_40.raw"

# Actual recording.
file_recording = "/crnldata/waking/audrey_hay/NPX/NPX1/VB/Expe_2024-07-22_17-55-16/NP_spikes_2024-07-22T17_55_16.raw"

# Recording that is actually sorted. Same file name as the actual recording,
# presumably a copy under /mnt/data -- TODO confirm.
file = "/mnt/data/ahay/NP_spikes_2024-07-22T17_55_16.raw"

In [5]:
def GenerateDict(file, duration_extract):
    """Run Kilosort4 on the first ``duration_extract`` minutes of a recording.

    Parameters
    ----------
    file : str
        Path to the flat binary recording. It is read as 384 channels of
        ``uint16`` samples at 30 kHz.
    duration_extract : int or float
        Length, in minutes, of the leading part of the recording to sort.
        Values longer than the recording are clamped to the recording length.

    Returns
    -------
    The sorting object returned by ``si.run_sorter('kilosort4', ...)``.

    Raises
    ------
    ValueError
        If ``duration_extract`` is not strictly positive.

    Notes
    -----
    An existing ``kilosort4_output`` directory in the current working
    directory is deleted before sorting.
    """
    if duration_extract <= 0:
        raise ValueError(f'duration_extract must be > 0 minutes, got {duration_extract!r}')

    sampling_frequency = 30_000.

    # Remove the output of a previous run.
    dirpath = os.path.join(os.getcwd(), 'kilosort4_output')
    if os.path.isdir(dirpath):
        shutil.rmtree(dirpath)
        print(f'{dirpath} existed so it was deleted')

    # NOTE(review): pickle.load can execute arbitrary code; only load probe
    # files from a trusted location.
    with open('/crnldata/waking/audrey_hay/NPX/NPXprobe.pkl', 'rb') as probe_file:
        probe = pickle.load(probe_file)
    probe.set_device_channel_indices(np.arange(384))

    raw_rec = si.read_binary(file, dtype='uint16', num_channels=384,
                             sampling_frequency=sampling_frequency)
    raw_rec = raw_rec.set_probe(probe)

    # frame_slice needs an integer frame index that does not run past the end
    # of the recording, so convert minutes -> frames and clamp.
    end_frame = int(round(sampling_frequency * 60 * duration_extract))
    end_frame = min(end_frame, raw_rec.get_num_samples())
    raw_rec = raw_rec.frame_slice(0, end_frame)

    sorting = si.run_sorter('kilosort4', raw_rec, verbose=False)
    return sorting


In [11]:
def runFullSpikeSorting(file, sorting, n_jobs=40):
    """Build a sorting analyzer for the whole recording and save it to disk.

    The recording is cast to float32, band-pass filtered and common-referenced,
    then the analyzer extensions listed below are computed and the analyzer is
    saved as a binary folder.

    Parameters
    ----------
    file : str
        Path to the flat binary recording. It is read as 384 channels of
        ``uint16`` samples at 30 kHz.
    sorting
        Sorting object passed to ``si.create_sorting_analyzer``.
    n_jobs : int, optional
        Number of parallel jobs used for the chunked computations. It is also
        installed as the global spikeinterface job setting. Defaults to 40.

    Returns
    -------
    None

    Notes
    -----
    An existing ``sorting_analyzer_demo_K`` directory in the current working
    directory is deleted first; the analyzer is saved to that same path.
    """
    # Remove the output of a previous run.
    dirpath = os.path.join(os.getcwd(), 'sorting_analyzer_demo_K')
    if os.path.isdir(dirpath):
        shutil.rmtree(dirpath)
        print(f'{dirpath} existed so it was deleted')

    # NOTE(review): pickle.load can execute arbitrary code; only load probe
    # files from a trusted location.
    with open('/crnldata/waking/audrey_hay/NPX/NPXprobe.pkl', 'rb') as probe_file:
        probe = pickle.load(probe_file)
    probe.set_device_channel_indices(np.arange(384))

    raw_rec = si.read_binary(file, dtype='uint16', num_channels=384, sampling_frequency=30_000.)
    raw_rec = raw_rec.set_probe(probe)

    # Preprocessing chain: float32 -> band-pass filter -> common reference.
    rec = raw_rec.astype('float32')
    rec = si.bandpass_filter(rec)
    rec = si.common_reference(rec)

    # One definition of the job settings, used both globally and per call.
    job_kwargs = dict(n_jobs=n_jobs, progress_bar=True, chunk_duration="1s")
    si.set_global_job_kwargs(**job_kwargs)

    sorting_analyzer = si.create_sorting_analyzer(sorting, rec, sparse=True)

    sorting_analyzer.compute("random_spikes", method="uniform", max_spikes_per_unit=500)
    sorting_analyzer.compute("waveforms", **job_kwargs)
    sorting_analyzer.compute("templates", **job_kwargs)
    sorting_analyzer.compute("noise_levels")
    sorting_analyzer.compute("unit_locations", method="monopolar_triangulation")
    sorting_analyzer.compute("isi_histograms")
    sorting_analyzer.compute("correlograms", window_ms=100, bin_ms=5.)
    sorting_analyzer.compute("principal_components", n_components=3, mode='by_channel_global', whiten=True, **job_kwargs)
    sorting_analyzer.compute("quality_metrics", metric_names=["snr", "firing_rate"])
    sorting_analyzer.compute("template_similarity")
    sorting_analyzer.compute("spike_amplitudes", **job_kwargs)

    # Save to the directory that was cleared above.
    sorting_analyzer.save_as(folder=dirpath, format='binary_folder')

In [7]:
def checkRessources():
    # check node and CPU information
    print("### Node counts: \nA: currently in use \B available")
    !sinfo -o%A
    print("### CPU counts: \nA: core currently in use \nI: available \nO: unavailable (maintenance, down, etc) \nT: total")
    !sinfo -o%C
    !sinfo

    # check some stats of our last job
    if last_job is not None:
        print('### CPU time and MaxRSS of our last job (about 1000Mb should be added to your MaxRSS (Mb) in order to cover safely the memory needs of the python runtime)###')
        os.system(f'sacct -j {last_job.job_id} --format="CPUTime,MaxRSS"')

## Sort spikes for the first few minutes of recording

It is good practice to check the available resources and the current load on the cluster before submitting a job.

In [None]:
# Print current node/CPU usage (and the stats of the last job, if any).
checkRessources()

In [None]:
#it takes about 90s
gc.collect()
start_time = time.time()

executor = submitit.AutoExecutor(folder=os.getcwd()+'/si_logs/')
#executor.update_parameters(mem_gb=5, timeout_min=5, slurm_partition="CPU", cpus_per_task=50)
executor.update_parameters(mem_gb=5, timeout_min=5, slurm_partition="GPU", cpus_per_task=2)

# actually submit the job
job = executor.submit(GenerateDict, file, duration_extract)

# print the ID of your job
print("submit job" + str(job.job_id))  

# await a single result
await job.awaitable().results()
print(f"job {job.job_id} completed in " + str(time.time()-start_time) + " seconds")

last_job = job
sorting = job.result()
print(sorting)

In [None]:
#kilosort4 run time 2212.34s for 8Gb 10cpus num1 (15.74, 15.42, 1.49, 2.11)
#100%|██████████| 60/60 [39:44<00:00, 39.74s/it] 8/10 python

#32%|███▏      | 19/60 [09:26<19:40, 28.78s/it] 8/30 submitit
#32%|███▏      | 19/60 [09:13<19:55, 29.16s/it] 16/30 submitit
#60%|██████    | 36/60 [12:10<07:36, 19.02s/it] 5/30 submitit
#23%|██▎       | 14/60 [10:14<32:55, 42.95s/it] 5/10 submitit
#42%|████▏     | 25/60 [06:35<07:23, 12.66s/it] 5/50 submitit
#33%|███▎      | 20/60 [05:32<10:29, 15.75s/it] 5/50 submitit data in mnt


#GPU
#job completed: 33972 returned in 110.73936009407043 seconds 5/50
#job completed: 33973 returned in 94.93399000167847 seconds 5/10
#job completed: 33975 returned in 93.79518842697144 seconds 5/2

## Curate the clusters
Here you should make sure that you are happy with the clusters that were found.

## Sort full recording

In [None]:
# Run the full pipeline in this kernel under memory_profiler and report its
# peak memory and wall time.
print(sorting)

start_time = time.time()
mem_usage = memory_usage((runFullSpikeSorting, (file, sorting)))
end_time = time.time()

# memory_usage() reports its samples in MiB, so divide by 1024 (not 1000)
# to get GiB.
peak_mib = max(mem_usage)
print('Maximum memory usage (in MiB): %s' % peak_mib)
print('Maximum memory usage (in GiB): %s' % (peak_mib / 1024))
print('Time taken (in s): %s' % (end_time - start_time))

In [None]:
start_time = time.time()

executor = submitit.AutoExecutor(folder=os.getcwd()+'/si_logs/')
#executor.update_parameters(slurm_array_parallelism=2, mem_gb=30, timeout_min=10, slurm_partition="CPU", cpus_per_task=50)
executor.update_parameters(mem_gb=60, timeout_min=20, slurm_partition="GPU", cpus_per_task=70) #cpus_per_task

# actually submit the job
job = executor.submit(runFullSpikeSorting, file, sorting)

# print the ID of your job
print("submit job" + str(job.job_id))  

# await a single result
await job.awaitable().results()
print(f"job {job.job_id} completed in " + str(time.time()-start_time) + " seconds")

In [None]:
#job 34074 completed in 408.1813361644745 second (slurm_array_parallelism=2, mem_gb=60, timeout_min=20, slurm_partition="GPU", cpus_per_task=50)
#job 34078 completed in 437.409494638443 seconds (slurm_array_parallelism=3, mem_gb=60, timeout_min=20, slurm_partition="GPU", cpus_per_task=50)
#job 34081 completed in 423.37460565567017 seconds (slurm_array_parallelism=2, mem_gb=60, timeout_min=20, slurm_partition="GPU", slurm_gres="gpu:2", cpus_per_task=50)
#job 34085 completed in 367.0000305175781 seconds (slurm_array_parallelism=2, mem_gb=60, timeout_min=20, slurm_partition="GPU", cpus_per_task=70)
#job 34089 completed in 370.1982145309448 seconds (slurm_array_parallelism=2, mem_gb=60, timeout_min=20, slurm_partition="GPU", cpus_per_task=80)
#job 34093 completed in 355.1876621246338 seconds (mem_gb=60, timeout_min=20, slurm_partition="GPU", cpus_per_task=70)

# Mark the job submitted above as the latest one so that checkRessources()
# also prints its CPUTime / MaxRSS.
last_job = job
checkRessources()