diff --git a/brainbox/processing.py b/brainbox/processing.py index a20cb1222..79b769abb 100644 --- a/brainbox/processing.py +++ b/brainbox/processing.py @@ -1,17 +1,16 @@ -''' -Processes data from one form into another, e.g. taking spike times and binning them into -non-overlapping bins and convolving spike times with a gaussian kernel. -''' +"""Process data from one form into another. + +For example, taking spike times and binning them into non-overlapping bins and convolving spike +times with a gaussian kernel. +""" import numpy as np import pandas as pd from scipy import interpolate, sparse from brainbox import core -from iblutil.numerical import bincount2D as _bincount2D +from iblutil.numerical import bincount2D from iblutil.util import Bunch import logging -import warnings -import traceback _logger = logging.getLogger(__name__) @@ -118,35 +117,6 @@ def sync(dt, times=None, values=None, timeseries=None, offsets=None, interp='zer return syncd -def bincount2D(x, y, xbin=0, ybin=0, xlim=None, ylim=None, weights=None): - """ - Computes a 2D histogram by aggregating values in a 2D array. - - :param x: values to bin along the 2nd dimension (c-contiguous) - :param y: values to bin along the 1st dimension - :param xbin: - scalar: bin size along 2nd dimension - 0: aggregate according to unique values - array: aggregate according to exact values (count reduce operation) - :param ybin: - scalar: bin size along 1st dimension - 0: aggregate according to unique values - array: aggregate according to exact values (count reduce operation) - :param xlim: (optional) 2 values (array or list) that restrict range along 2nd dimension - :param ylim: (optional) 2 values (array or list) that restrict range along 1st dimension - :param weights: (optional) defaults to None, weights to apply to each value for aggregation - :return: 3 numpy arrays MAP [ny,nx] image, xscale [nx], yscale [ny] - """ - for line in traceback.format_stack(): - print(line.strip()) - warning_text = """Future warning: bincount2D() is now a part of iblutil. - brainbox.processing.bincount2D will be removed in future versions. - Please replace imports with iblutil.numerical.bincount2D.""" - _logger.warning(warning_text) - warnings.warn(warning_text, FutureWarning) - return _bincount2D(x, y, xbin, ybin, xlim, ylim, weights) - - def compute_cluster_average(spike_clusters, spike_var): """ Quickish way to compute the average of some quantity across spikes in each cluster given @@ -197,7 +167,7 @@ def bin_spikes(spikes, binsize, interval_indices=False): def get_units_bunch(spks_b, *args): - ''' + """ Returns a bunch, where the bunch keys are keys from `spks` with labels of spike information (e.g. unit IDs, times, features, etc.), and the values for each key are arrays with values for each unit: these arrays are ordered and can be indexed by unit id. @@ -223,18 +193,18 @@ def get_units_bunch(spks_b, *args): -------- 1) Create a units bunch given a spikes bunch, and get the amps for unit #4 from the units bunch. 
- >>> import brainbox as bb - >>> import alf.io as aio + >>> from brainbox import processing + >>> import one.alf.io as alfio >>> import ibllib.ephys.spikes as e_spks (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory): >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out) - >>> spks_b = aio.load_object(path_to_alf_out, 'spikes') - >>> units_b = bb.processing.get_units_bunch(spks_b) + >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes') + >>> units_b = processing.get_units_bunch(spks_b) # Get amplitudes for unit 4. >>> amps = units_b['amps']['4'] TODO add computation time estimate? - ''' + """ # Initialize `units` units_b = Bunch() @@ -261,7 +231,7 @@ def get_units_bunch(spks_b, *args): def filter_units(units_b, t, **kwargs): - ''' + """ Filters units according to some parameters. **kwargs are the keyword parameters used to filter the units. @@ -299,24 +269,24 @@ def filter_units(units_b, t, **kwargs): Examples -------- 1) Filter units according to the default parameters. - >>> import brainbox as bb - >>> import alf.io as aio + >>> from brainbox import processing + >>> import one.alf.io as alfio >>> import ibllib.ephys.spikes as e_spks (*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory): >>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out) # Get a spikes bunch, units bunch, and filter the units. - >>> spks_b = aio.load_object(path_to_alf_out, 'spikes') - >>> units_b = bb.processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters']) + >>> spks_b = alfio.load_object(path_to_alf_out, 'spikes') + >>> units_b = processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters']) >>> T = spks_b['times'][-1] - spks_b['times'][0] - >>> filtered_units = bb.processing.filter_units(units_b, T) + >>> filtered_units = processing.filter_units(units_b, T) 2) Filter units with no minimum amplitude, a minimum firing rate of 1 Hz, and a max false positive rate of 0.2, given a refractory period of 2 ms. 
- >>> filtered_units = bb.processing.filter_units(units_b, T, min_amp=0, min_fr=1) + >>> filtered_units = processing.filter_units(units_b, T, min_amp=0, min_fr=1) TODO: `units_b` input arg could eventually be replaced by `clstrs_b` if the required metrics are in `clstrs_b['metrics']` - ''' + """ # Set params params = {'min_amp': 50e-6, 'min_fr': 0.5, 'max_fpr': 0.2, 'rp': 0.002} # defaults diff --git a/brainbox/tests/test_processing.py b/brainbox/tests/test_processing.py index 970aaeb34..d136d3d1f 100644 --- a/brainbox/tests/test_processing.py +++ b/brainbox/tests/test_processing.py @@ -1,7 +1,6 @@ from brainbox import processing, core import unittest import numpy as np -import datetime class TestProcessing(unittest.TestCase): @@ -63,15 +62,6 @@ def test_sync(self): self.assertTrue(times2.min() >= resamp2.times.min()) self.assertTrue(times2.max() <= resamp2.times.max()) - def test_bincount2D_deprecation(self): - # Timer to remove bincount2D (now in iblutil) - # Once this test fails: - # - Remove the bincount2D method in processing.py - # - Remove the import from iblutil at the top of that file - # - Delete this test - if datetime.datetime.now() > datetime.datetime(2024, 6, 30): - raise NotImplementedError - def test_compute_cluster_averag(self): # Create fake data for 3 clusters clust1 = np.ones(40) @@ -103,11 +93,6 @@ def test_compute_cluster_averag(self): self.assertEqual(avg_val[2], 0.75) self.assertTrue(np.all(count == (40, 40, 50))) - def test_deprecations(self): - """Ensure removal of bincount2D function.""" - from datetime import datetime - self.assertTrue(datetime.today() < datetime(2024, 8, 1), 'remove brainbox.processing.bincount2D') - if __name__ == '__main__': np.random.seed(0) diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index b258af3ab..ebaeb7e91 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -51,6 +51,7 @@ def patch_imaging_meta(meta: dict) -> dict: for fov in meta.get('FOV', []): if 'roiUuid' in fov: fov['roiUUID'] = fov.pop('roiUuid') + assert 'nFrames' in meta, '"nFrames" key missing from meta data; rawImagingData.meta.json likely an old version' return meta @@ -753,6 +754,15 @@ def get_timeshifts(raw_imaging_meta): Calculate the time shifts for each field of view (FOV) and the relative offsets for each scan line. + For a 2 scan field, 2 depth recording (so 4 FOVs): + + Frame 1, lines 1-512 correspond to FOV_00 + Frame 1, lines 551-1062 correspond to FOV_01 + Frame 2, lines 1-512 correspond to FOV_02 + Frame 2, lines 551-1062 correspond to FOV_03 + Frame 3, lines 1-512 correspond to FOV_00 + ... + Parameters ---------- raw_imaging_meta : dict @@ -772,26 +782,27 @@ def get_timeshifts(raw_imaging_meta): FOVs = raw_imaging_meta['FOV'] # Double-check meta extracted properly - raw_meta = raw_imaging_meta['rawScanImageMeta'] - artist = raw_meta['Artist'] - assert sum(x['enable'] for x in artist['RoiGroups']['imagingRoiGroup']['rois']) == len(FOVs) - + # assert meta.FOV.Zs is ascending but use slice_id field. This may not be necessary but is expected. + slice_ids = np.array([fov['slice_id'] for fov in FOVs]) + assert np.all(np.diff([x['Zs'] for x in FOVs]) >= 0), 'FOV depths not in ascending order' + assert np.all(np.diff(slice_ids) >= 0), 'slice IDs not ordered' # Number of scan lines per FOV, i.e. 
number of Y pixels / image height n_lines = np.array([x['nXnYnZ'][1] for x in FOVs]) - n_valid_lines = np.sum(n_lines) # Number of lines imaged excluding flybacks - # Number of lines during flyback - n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1)) - # The start and end indices of each FOV in the raw images - fov_start_idx = np.insert(np.cumsum(n_lines[:-1] + n_lines_per_gap), 0, 0) - fov_end_idx = fov_start_idx + n_lines - line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod'] - - line_indices = [] - fov_time_shifts = fov_start_idx * line_period - line_time_shifts = [] - for ln, s, e in zip(n_lines, fov_start_idx, fov_end_idx): - line_indices.append(np.arange(s, e)) - line_time_shifts.append(np.arange(0, ln) * line_period) + # We get indices from MATLAB extracted metadata so below two lines are no longer needed + # n_valid_lines = np.sum(n_lines) # Number of lines imaged excluding flybacks + # n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1)) # N lines during flyback + line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod'] + frame_time_shifts = slice_ids / raw_imaging_meta['scanImageParams']['hRoiManager']['scanFrameRate'] + + # Line indices are now extracted by the MATLAB function mesoscopeMetadataExtraction.m + # They are indexed from 1 so we subtract 1 to convert to zero-indexed + line_indices = [np.array(fov['lineIdx']) - 1 for fov in FOVs] # Convert to zero-indexed from MATLAB 1-indexed + assert all(lns.size == n for lns, n in zip(line_indices, n_lines)), 'unexpected number of scan lines' + # The start indices of each FOV in the raw images + fov_start_idx = np.array([lns[0] for lns in line_indices]) + roi_time_shifts = fov_start_idx * line_period # The time offset for each FOV + fov_time_shifts = roi_time_shifts + frame_time_shifts + line_time_shifts = [(lns - ln0) * line_period for lns, ln0 in zip(line_indices, fov_start_idx)] return line_indices, fov_time_shifts, line_time_shifts diff --git a/ibllib/pipes/local_server.py b/ibllib/pipes/local_server.py index 6574f88f1..cf3f76437 100644 --- a/ibllib/pipes/local_server.py +++ b/ibllib/pipes/local_server.py @@ -152,7 +152,7 @@ def job_creator(root_path, one=None, dry=False, rerun=False): return pipes, all_datasets -def task_queue(mode='all', lab=None, alyx=None): +def task_queue(mode='all', lab=None, alyx=None, env=(None,)): """ Query waiting jobs from the specified Lab @@ -164,12 +164,18 @@ def task_queue(mode='all', lab=None, alyx=None): Lab name as per Alyx, otherwise try to infer from local Globus install. alyx : one.webclient.AlyxClient An Alyx instance. + env : list + One or more environments to filter by. See :prop:`ibllib.pipes.tasks.Task.env`. Returns ------- list of dict A list of Alyx tasks associated with `lab` that have a 'Waiting' status. 
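+
+    Examples
+    --------
+    A hypothetical call (the lab name is illustrative), assuming a configured AlyxClient instance,
+    returning small 'Waiting' tasks that either have no environment or run in the 'suite2p' one:
+
+    >>> queue = task_queue(mode='small', lab='foolab', alyx=alyx, env=(None, 'suite2p'))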
""" + def predicate(task): + classe = tasks.str2class(task['executable']) + return (mode == 'all' or classe.job_size == mode) and classe.env in env + alyx = alyx or AlyxClient(cache_rest=None) if lab is None: _logger.debug('Trying to infer lab from globus installation') @@ -179,28 +185,12 @@ def task_queue(mode='all', lab=None, alyx=None): return # if the lab is none, this will return empty tasks each time data_repo = get_local_data_repository(alyx) # Filter for tasks - tasks_all = alyx.rest('tasks', 'list', status='Waiting', - django=f'session__lab__name__in,{lab},data_repository__name,{data_repo}', no_cache=True) - if mode == 'all': - waiting_tasks = tasks_all - else: - small_jobs = [] - large_jobs = [] - for t in tasks_all: - strmodule, strclass = t['executable'].rsplit('.', 1) - classe = getattr(importlib.import_module(strmodule), strclass) - job_size = classe.job_size - if job_size == 'small': - small_jobs.append(t) - else: - large_jobs.append(t) - if mode == 'small': - waiting_tasks = small_jobs - elif mode == 'large': - waiting_tasks = large_jobs - + waiting_tasks = alyx.rest('tasks', 'list', status='Waiting', + django=f'session__lab__name__in,{lab},data_repository__name,{data_repo}', no_cache=True) + # Filter tasks by size + filtered_tasks = filter(predicate, waiting_tasks) # Order tasks by priority - sorted_tasks = sorted(waiting_tasks, key=lambda d: d['priority'], reverse=True) + sorted_tasks = sorted(filtered_tasks, key=lambda d: d['priority'], reverse=True) return sorted_tasks @@ -252,8 +242,7 @@ def tasks_runner(subjects_path, tasks_dict, one=None, dry=False, count=5, time_o if dry: print(session_path, tdict['name']) else: - task, dsets = tasks.run_alyx_task(tdict=tdict, session_path=session_path, - one=one, **kwargs) + task, dsets = tasks.run_alyx_task(tdict=tdict, session_path=session_path, one=one, **kwargs) if dsets: all_datasets.extend(dsets) c += 1 diff --git a/ibllib/pipes/mesoscope_tasks.py b/ibllib/pipes/mesoscope_tasks.py index 906ab7f27..0f83cbe44 100644 --- a/ibllib/pipes/mesoscope_tasks.py +++ b/ibllib/pipes/mesoscope_tasks.py @@ -193,11 +193,12 @@ def _run(self, remove_uncompressed=False, verify_output=True, clobber=False, **k class MesoscopePreprocess(base_tasks.MesoscopeTask): - """Run suite2p preprocessing on tif files""" + """Run suite2p preprocessing on tif files.""" priority = 80 cpu = -1 job_size = 'large' + env = 'suite2p' @property def signature(self): @@ -426,10 +427,7 @@ def _consolidate_exptQC(exptQC): # Concatenate frames frameQC = np.concatenate([e['frameQC_frames'] for e in exptQC], axis=0) - - # Transform to bad_frames as expected by suite2p bad_frames = np.where(frameQC != 0)[0] - return frameQC, frameQC_names, bad_frames def get_default_tau(self): @@ -555,7 +553,7 @@ def _run(self, run_suite2p=True, rename_files=True, use_badframes=True, **kwargs """ Metadata and parameters """ # Load metadata and make sure all metadata is consistent across FOVs meta_files = sorted(self.session_path.glob(f'{self.device_collection}/*rawImagingData.meta.*')) - collections = set(f.parts[-2] for f in meta_files) + collections = sorted(set(f.parts[-2] for f in meta_files)) # Check there is exactly 1 meta file per collection assert len(meta_files) == len(list(self.session_path.glob(self.device_collection))) == len(collections) rawImagingData = [mesoscope.patch_imaging_meta(alfio.load_file_content(filepath)) for filepath in meta_files] @@ -576,19 +574,31 @@ def _run(self, run_suite2p=True, rename_files=True, use_badframes=True, **kwargs self.kwargs = {**self.kwargs, 
**db} """ Bad frames """ + # exptQC.mat contains experimenter QC values that may not affect ROI detection (e.g. noises, pauses) qc_paths = (self.session_path.joinpath(f[1], 'exptQC.mat') for f in self.input_files if f[0] == 'exptQC.mat') qc_paths = sorted(map(str, filter(Path.exists, qc_paths))) exptQC = [loadmat(p, squeeze_me=True, simplify_cells=True) for p in qc_paths] if len(exptQC) > 0: - frameQC, frameQC_names, bad_frames = self._consolidate_exptQC(exptQC) + frameQC, frameQC_names, _ = self._consolidate_exptQC(exptQC) else: _logger.warning('No frame QC (exptQC.mat) files found.') - frameQC, bad_frames = np.array([], dtype='u1'), np.array([], dtype='i8') + frameQC = np.array([], dtype='u1') frameQC_names = pd.DataFrame(columns=['qc_values', 'qc_labels']) + # If applicable, save as bad_frames.npy in first raw_imaging_folder for suite2p - if len(bad_frames) > 0 and use_badframes is True: - np.save(Path(db['data_path'][0]).joinpath('bad_frames.npy'), bad_frames) + # badframes.mat contains QC values that do affect ROI detection (e.g. no PMT, lens artefacts) + badframes = np.array([], dtype='i8') + total_frames = 0 + # Ensure all indices are relative to total cumulative frames + for m, collection in zip(rawImagingData, collections): + badframes_path = self.session_path.joinpath(collection, 'badframes.mat') + if badframes_path.exists(): + raw_mat = loadmat(badframes_path, squeeze_me=True, simplify_cells=True)['badframes'] + badframes = np.r_[badframes, raw_mat + total_frames] + total_frames += m['nFrames'] + if len(badframes) > 0 and use_badframes is True: + np.save(Path(db['data_path'][0]).joinpath('bad_frames.npy'), badframes) """ Suite2p """ # Create alf it is doesn't exist @@ -625,10 +635,10 @@ def signature(self): (f'_{self.sync_namespace}_DAQdata.timestamps.npy', self.sync_collection, True), (f'_{self.sync_namespace}_DAQdata.meta.json', self.sync_collection, True), ('_ibl_rawImagingData.meta.json', self.device_collection, True), - ('rawImagingData.times_scanImage.npy', self.device_collection, True), + ('rawImagingData.times_scanImage.npy', self.device_collection, True, True), # register raw (f'_{self.sync_namespace}_softwareEvents.log.htsv', self.sync_collection, False), ], 'output_files': [('mpci.times.npy', 'alf/mesoscope/FOV*', True), - ('mpciStack.timeshift.npy', 'alf/mesoscope/FOV*', True), ] + ('mpciStack.timeshift.npy', 'alf/mesoscope/FOV*', True),] } return signature diff --git a/ibllib/pipes/tasks.py b/ibllib/pipes/tasks.py index c55305e11..a409ed1ea 100644 --- a/ibllib/pipes/tasks.py +++ b/ibllib/pipes/tasks.py @@ -88,6 +88,7 @@ from iblutil.util import Bunch import one.params from one.api import ONE +from one.util import ensure_list from one import webclient _logger = logging.getLogger(__name__) @@ -107,9 +108,10 @@ class Task(abc.ABC): time_elapsed_secs = None time_out_secs = 3600 * 2 # time-out after which a task is considered dead version = ibllib.__version__ - signature = {'input_files': [], 'output_files': []} # list of tuples (filename, collection, required_flag) - force = False # whether or not to re-download missing input files on local server if not present + signature = {'input_files': [], 'output_files': []} # list of tuples (filename, collection, required_flag[, register]) + force = False # whether to re-download missing input files on local server if not present job_size = 'small' # either 'small' or 'large', defines whether task should be run as part of the large or small job services + env = None # the environment name within which to run the task (NB: the 
env is not activated automatically!) def __init__(self, session_path, parents=None, taskid=None, one=None, machine=None, clobber=True, location='server', **kwargs): @@ -172,6 +174,13 @@ def run(self, **kwargs): -1: Errored -2: Didn't run as a lock was encountered -3: Incomplete + + Notes + ----- + - The `run_alyx_task` will update the Alyx Task status depending on both status and outputs + (i.e. the output of subclassed `_run` method): + Assuming a return value of 0... if Task.outputs is None, the status will be Empty; + if Task.outputs is a list (empty or otherwise), the status will be Complete. """ # if task id of one properties are not available, local run only without alyx use_alyx = self.one is not None and self.taskid is not None @@ -200,12 +209,13 @@ def run(self, **kwargs): start_time = time.time() try: setup = self.setUp(**kwargs) + self.outputs = self._input_files_to_register() _logger.info(f'Setup value is: {setup}') self.status = 0 if not setup: # case where outputs are present but don't have input files locally to rerun task # label task as complete - _, self.outputs = self.assert_expected_outputs() + _, outputs = self.assert_expected_outputs() else: # run task if self.gpu >= 1: @@ -217,8 +227,15 @@ def run(self, **kwargs): _logger.removeHandler(ch) ch.close() return self.status - self.outputs = self._run(**kwargs) + outputs = self._run(**kwargs) _logger.info(f'Job {self.__class__} complete') + if outputs is None: + # If run method returns None and no raw input files were registered, self.outputs + # should be None, meaning task will have an 'Empty' status. If run method returns + # a list, the status will be 'Complete' regardless of whether there are output files. + self.outputs = outputs if not self.outputs else self.outputs # ensure None if no inputs registered + else: + self.outputs.extend(ensure_list(outputs)) # Add output files to list of inputs to register except Exception: _logger.error(traceback.format_exc()) _logger.info(f'Job {self.__class__} errored') @@ -262,6 +279,41 @@ def register_datasets(self, **kwargs): _ = self.register_images() return self.data_handler.uploadData(self.outputs, self.version, **kwargs) + def _input_files_to_register(self, assert_all_exist=False): + """ + Return input datasets to be registered to Alyx. + + These datasets are typically raw data files and are registered even if the task fails to complete. + + Parameters + ---------- + assert_all_exist + Raise AssertionError if not all required input datasets exist on disk. + + Returns + ------- + list of pathlib.Path + A list of input files to register. 
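+
+        Examples
+        --------
+        A hypothetical signature entry (the file and collection names are illustrative) showing how
+        a truthy fourth element marks a raw input file for registration:
+
+        >>> task.input_files = [('rawImagingData.times_scanImage.npy', 'raw_imaging_data_00', True, True)]
+        >>> task._input_files_to_register()  # -> [<session_path>/raw_imaging_data_00/rawImagingData.times_scanImage.npy]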
+ """ + try: + input_files = self.input_files + except AttributeError: + raise RuntimeError('Task.setUp must be run before calling this method.') + to_register, missing = [], [] + for filename, collection, required, _ in filter(lambda f: len(f) > 3 and f[3], input_files): + filepath = self.session_path.joinpath(collection, filename) + if filepath.exists(): + to_register.append(filepath) + elif required: + missing.append(filepath) + if any(missing): + missing_str = ', '.join(map(lambda x: x.relative_to(self.session_path).as_posix(), missing)) + if assert_all_exist: + raise AssertionError(f'Missing required input files: {missing_str}') + else: + _logger.error(f'Missing required input files: {missing_str}') + return list(set(to_register) - set(missing)) + def register_images(self, **kwargs): """ Registers images to alyx database @@ -540,7 +592,8 @@ def make_graph(self, out_dir=None, show=True): def create_alyx_tasks(self, rerun__status__in=None, tasks_list=None): """ - Instantiate the pipeline and create the tasks in Alyx, then create the jobs for the session + Instantiate the pipeline and create the tasks in Alyx, then create the jobs for the session. + If the jobs already exist, they are left untouched. The re-run parameter will re-init the job by emptying the log and set the status to Waiting. @@ -682,6 +735,24 @@ def name(self): return self.__class__.__name__ +def str2class(task_executable: str): + """ + Convert task name to class. + + Parameters + ---------- + task_executable : str + A Task class name, e.g. 'ibllib.pipes.behavior_tasks.TrialRegisterRaw'. + + Returns + ------- + class + The imported class. + """ + strmodule, strclass = task_executable.rsplit('.', 1) + return getattr(importlib.import_module(strmodule), strclass) + + def run_alyx_task(tdict=None, session_path=None, one=None, job_deck=None, max_md5_size=None, machine=None, clobber=True, location='server', mode='log'): """ @@ -714,8 +785,8 @@ def run_alyx_task(tdict=None, session_path=None, one=None, job_deck=None, Returns ------- - Task - The instantiated task object that was run. + dict + The updated task dict. list of pathlib.Path A list of registered datasets. """ @@ -736,9 +807,7 @@ def run_alyx_task(tdict=None, session_path=None, one=None, job_deck=None, tdict = one.alyx.rest('tasks', 'partial_update', id=tdict['id'], data={'status': 'Held'}) return tdict, registered_dsets # creates the job from the module name in the database - exec_name = tdict['executable'] - strmodule, strclass = exec_name.rsplit('.', 1) - classe = getattr(importlib.import_module(strmodule), strclass) + classe = str2class(tdict['executable']) tkwargs = tdict.get('arguments') or {} # if the db field is null it returns None task = classe(session_path, one=one, taskid=tdict['id'], machine=machine, clobber=clobber, location=location, **tkwargs) @@ -748,7 +817,7 @@ def run_alyx_task(tdict=None, session_path=None, one=None, job_deck=None, patch_data = {'time_elapsed_secs': task.time_elapsed_secs, 'log': task.log, 'version': task.version} # if there is no data to register, set status to Empty - if task.outputs is None: + if task.outputs is None: # NB: an empty list is still considered Complete. 
             patch_data['status'] = 'Empty'
         # otherwise register data and set (provisional) status to Complete
         else:
diff --git a/ibllib/tests/extractors/test_mesoscope.py b/ibllib/tests/extractors/test_mesoscope.py
new file mode 100644
index 000000000..b9447b630
--- /dev/null
+++ b/ibllib/tests/extractors/test_mesoscope.py
@@ -0,0 +1,64 @@
+"""Tests for ibllib.io.extractors.mesoscope module."""
+import unittest
+from itertools import repeat, chain
+
+import numpy as np
+
+from ibllib.io.extractors import mesoscope
+
+
+class TestMesoscopeSyncTimeline(unittest.TestCase):
+    """Tests for MesoscopeSyncTimeline extractor class."""
+    def setUp(self) -> None:
+        """Simulate meta data for 9 FOVs at 3 different depths.
+
+        These simulated values match those from SP048/2024-02-05/001.
+        """
+        n_lines_flyback = 75
+        self.n_lines = 512
+        self.n_FOV = 9
+        n_depths = 3
+        assert self.n_FOV > n_depths and self.n_FOV % n_depths == 0
+        reps = int(self.n_FOV / n_depths)
+        start_depth = 60
+        delta_depth = 40
+        self.line_period = 4.158e-05
+
+        self.meta = {
+            'scanImageParams': {'hRoiManager': {'linePeriod': self.line_period, 'scanFrameRate': 13.6803}},
+            'FOV': []
+        }
+        nXnYnZ = [self.n_lines, self.n_lines, 1]
+        for i, slice_id in enumerate(chain.from_iterable(map(lambda x: list(repeat(x, reps)), range(n_depths)))):
+            offset = (i % n_depths) * (self.n_lines + n_lines_flyback) - ((i % n_depths) - 1)
+            offset = offset or 1  # start at 1 for MATLAB indexing
+            fov = {'slice_id': slice_id, 'Zs': start_depth + (delta_depth * slice_id),
+                   'nXnYnZ': nXnYnZ, 'lineIdx': list(range(offset, self.n_lines + offset))}
+            self.meta['FOV'].append(fov)
+
+    def test_get_timeshifts_multidepth(self):
+        """Test MesoscopeSyncTimeline.get_timeshifts method.
+
+        This tests output when given multiple FOVs at different depths. The tasks/mesoscope_tasks.py
+        module in iblscripts more thoroughly tests single-depth imaging with real data.
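+
+        Following the get_timeshifts implementation above, each expected FOV time shift asserted
+        below is (lineIdx[0] - 1) * linePeriod + slice_id / scanFrameRate; e.g. the second FOV
+        starts at zero-indexed line 586, giving 586 * 4.158e-05 = 0.02436588 s, and the fourth FOV
+        is the first at the next depth, giving 1 / 13.6803 ~= 0.0730978 s.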
+ """ + line_indices, fov_time_shifts, line_time_shifts = mesoscope.MesoscopeSyncTimeline.get_timeshifts(self.meta) + expected = [np.array(x['lineIdx']) for x in self.meta['FOV']] + self.assertTrue(np.all(x == y) for x, y in zip(expected, line_indices)) + self.assertEqual(self.n_FOV, len(fov_time_shifts)) + self.assertEqual(self.n_FOV, len(line_time_shifts)) + self.assertTrue(all(len(x) == self.n_lines for x in line_time_shifts)) + + expected = self.line_period * np.arange(self.n_lines) + for i, line_shifts in enumerate(line_time_shifts): + with self.subTest(f'FOV == {i}'): + self.assertEqual(self.n_lines, len(line_shifts)) + np.testing.assert_almost_equal(expected, line_shifts) + + # NB: The following values are fixed for the setup parameters + expected = [0., 0.02436588, 0.04873176, 0.07309781, 0.09746369, 0.12182957, 0.14619562, 0.1705615, 0.19492738] + np.testing.assert_almost_equal(expected, fov_time_shifts) + + +if __name__ == '__main__': + unittest.main() diff --git a/ibllib/tests/test_mesoscope.py b/ibllib/tests/test_mesoscope.py index 7ae3e5cd4..95afe8a4a 100644 --- a/ibllib/tests/test_mesoscope.py +++ b/ibllib/tests/test_mesoscope.py @@ -1,4 +1,4 @@ -"""Tests for ibllib.pipes.mesoscope_tasks""" +"""Tests for ibllib.pipes.mesoscope_tasks.""" import sys import unittest from unittest import mock @@ -68,6 +68,7 @@ def test_meta(self): } meta = { + 'nFrames': 2000, 'scanImageParams': {'hStackManager': {'zs': 320}, 'hRoiManager': {'scanVolumeRate': 6.8}}, 'FOV': [{'topLeftDeg': [-1, 1.3], 'topRightDeg': [3, 1.3], 'bottomLeftDeg': [-1, 5.2], @@ -256,10 +257,10 @@ class TestImagingMeta(unittest.TestCase): """Test raw imaging metadata versioning.""" def test_patch_imaging_meta(self): """Test for ibllib.io.extractors.mesoscope.patch_imaging_meta function.""" - meta = {'version': '0.1.0', 'FOV': [{'roiUuid': None}, {'roiUUID': None}]} + meta = {'version': '0.1.0', 'nFrames': 2000, 'FOV': [{'roiUuid': None}, {'roiUUID': None}]} new_meta = mesoscope.patch_imaging_meta(meta) self.assertEqual(set(chain(*map(dict.keys, new_meta['FOV']))), {'roiUUID'}) - meta = {'FOV': [ + meta = {'nFrames': 2000, 'FOV': [ dict.fromkeys(['topLeftDeg', 'topRightDeg', 'bottomLeftDeg', 'bottomRightDeg']), dict.fromkeys(['topLeftMM', 'topRightMM', 'bottomLeftMM', 'bottomRightMM']) ]} diff --git a/ibllib/tests/test_pipes.py b/ibllib/tests/test_pipes.py index 14d0643af..9b1df5c1f 100644 --- a/ibllib/tests/test_pipes.py +++ b/ibllib/tests/test_pipes.py @@ -13,15 +13,15 @@ from uuid import uuid4 from datetime import datetime +from one.webclient import AlyxClient from one.api import ONE, OneAlyx import iblutil.io.params as iopar from packaging.version import Version, InvalidVersion import ibllib.io.extractors.base import ibllib.tests.fixtures.utils as fu -from ibllib.pipes import misc +from ibllib.pipes import misc, local_server from ibllib.pipes.misc import sleepless -from ibllib.pipes import local_server from ibllib.tests import TEST_DB import ibllib.pipes.scan_fix_passive_files as fix from ibllib.pipes.base_tasks import RegisterRawDataTask @@ -41,6 +41,34 @@ def setUp(self): raw_behaviour_data.parent.joinpath('raw_session.flag').touch() fu.populate_task_settings(raw_behaviour_data, patch={'PYBPOD_PROTOCOL': 'ephys_optoChoiceWorld6.0.1'}) + @mock.patch('ibllib.pipes.local_server.get_local_data_repository') + def test_task_queue(self, lab_repo_mock): + """Test ibllib.pipes.local_server.task_queue function.""" + lab_repo_mock.return_value = 'foo_repo' + tasks = [ + {'executable': 
'ibllib.pipes.mesoscope_tasks.MesoscopePreprocess', 'priority': 80}, + {'executable': 'ibllib.pipes.ephys_preprocessing.SpikeSorting', 'priority': SpikeSorting.priority}, + {'executable': 'ibllib.pipes.base_tasks.RegisterRawDataTask', 'priority': RegisterRawDataTask.priority} + ] + alyx = mock.Mock(spec=AlyxClient) + alyx.rest.return_value = tasks + queue = local_server.task_queue(lab='foolab', alyx=alyx) + alyx.rest.assert_called() + self.assertEqual('Waiting', alyx.rest.call_args.kwargs.get('status')) + self.assertIn('foolab', alyx.rest.call_args.kwargs.get('django', '')) + self.assertIn('foo_repo', alyx.rest.call_args.kwargs.get('django', '')) + # Expect to return tasks in descending priority order, without mesoscope task (different env) + self.assertEqual([tasks[2], tasks[1]], queue) + # Expect only mesoscope task returned when relevant env passed + queue = local_server.task_queue(lab='foolab', alyx=alyx, env=('suite2p',)) + self.assertEqual([tasks[0]], queue) + # Expect no tasks as mesoscope task is a large job + queue = local_server.task_queue(mode='small', lab='foolab', alyx=alyx, env=('suite2p',)) + self.assertEqual([], queue) + # Expect only register task as it's the only small job + queue = local_server.task_queue(mode='small', lab='foolab', alyx=alyx) + self.assertEqual([tasks[2]], queue) + @mock.patch('ibllib.pipes.local_server.IBLRegistrationClient') @mock.patch('ibllib.pipes.local_server.make_pipeline') def test_job_creator(self, pipeline_mock, _): diff --git a/ibllib/tests/test_tasks.py b/ibllib/tests/test_tasks.py index 602579922..482697b39 100644 --- a/ibllib/tests/test_tasks.py +++ b/ibllib/tests/test_tasks.py @@ -1,3 +1,4 @@ +"""Test ibllib.pipes.tasks module and Task class.""" import shutil import tempfile import unittest @@ -112,6 +113,7 @@ class TaskGpuLock(ibllib.pipes.tasks.Task): def setUp(self): self.make_lock_file() self.data_handler = self.get_data_handler() + self.input_files = [] return True def _run(self, overwrite=False): @@ -356,5 +358,51 @@ def test_get_device_collection(self): self.assertEqual('raw_ephys_data/probe00', collection) +class TestTask(unittest.TestCase): + def setUp(self): + tmpdir = tempfile.TemporaryDirectory() + self.addCleanup(tmpdir.cleanup) + self.tmpdir = Path(tmpdir.name) + self.session_path = self.tmpdir.joinpath('subject', '1900-01-01', '001') + self.session_path.mkdir(parents=True) + + def test_input_files_to_register(self): + """Test for Task._input_files_to_register method.""" + task = Task00(self.session_path) + self.assertRaises(RuntimeError, task._input_files_to_register) + task.input_files = [('register.optional.ext', 'alf', False, True), + ('register.optional_foo.ext', 'alf', False, True), + ('register.required.ext', 'alf', True, True), + ('ignore.required.ext', 'alf', True), + ('ignore.optional.ext', 'alf', False)] + self.session_path.joinpath('alf').mkdir() + for f in task.input_files: + self.session_path.joinpath(f[1], f[0]).touch() + files = task._input_files_to_register(assert_all_exist=True) + expected = [self.session_path.joinpath('alf', 'register.required.ext'), + self.session_path.joinpath('alf', 'register.optional.ext'), + self.session_path.joinpath('alf', 'register.optional_foo.ext')] + self.assertCountEqual(files, expected) + expected[2].unlink() + with self.assertNoLogs(ibllib.pipes.tasks.__name__, level='ERROR'): + files = task._input_files_to_register(assert_all_exist=True) + self.assertCountEqual(files, expected[:2]) + expected[0].unlink() + with self.assertLogs(ibllib.pipes.tasks.__name__, level='ERROR'): + 
files = task._input_files_to_register(assert_all_exist=False) + self.assertEqual(files, expected[1:2]) + self.assertRaises(AssertionError, task._input_files_to_register, assert_all_exist=True) + + +class TestMisc(unittest.TestCase): + """Tests for misc functions in ibllib.pipes.tasks module.""" + + def test_str2class(self): + """Test ibllib.pipes.tasks.str2class function.""" + task_str = 'ibllib.pipes.base_tasks.ExperimentDescriptionRegisterRaw' + self.assertIs(ibllib.pipes.tasks.str2class(task_str), ExperimentDescriptionRegisterRaw) + self.assertRaises(AttributeError, ibllib.pipes.tasks.str2class, 'ibllib.pipes.base_tasks.Foo') + + if __name__ == '__main__': unittest.main(exit=False, verbosity=2)
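
Taken together, the changes above let a task declare an execution environment, a job size, and which raw input files to register independently of its outputs. A minimal sketch of a hypothetical Task subclass using these hooks (the class, file and collection names below are illustrative, not part of this diff):

from ibllib.pipes.tasks import Task


class ExampleRawRegisterTask(Task):
    """Hypothetical task illustrating the new env attribute and 4-tuple signature entries."""
    priority = 60
    job_size = 'large'  # only returned by task_queue(mode='large', ...) or mode='all'
    env = 'suite2p'     # only returned by task_queue(..., env=('suite2p',)); the env is not activated automatically

    signature = {
        'input_files': [
            # (filename, collection, required[, register]): a truthy 4th element means the file is
            # registered by _input_files_to_register even when _run produces no outputs
            ('rawImagingData.times_scanImage.npy', 'raw_imaging_data_00', True, True),
            ('exptQC.mat', 'raw_imaging_data_00', False),
        ],
        'output_files': [('mpci.times.npy', 'alf/FOV_00', True)],
    }

    def _run(self):
        # Returning a list (even an empty one) leads to a 'Complete' status on Alyx;
        # returning None leads to 'Empty' unless raw input files were registered during setUp.
        return []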