From 7aaaaa4efa9c1190d744281ac6d905c43ea945b1 Mon Sep 17 00:00:00 2001 From: Miles Wells Date: Fri, 15 Dec 2023 16:20:35 +0200 Subject: [PATCH] Add get_trials_tasks function Fix tests flake8 Fix test: double temp dir Logging to investigate failing test on Github workflow build Skip test of Github build --- ibllib/io/extractors/base.py | 161 +++++++++++++++++++------ ibllib/io/extractors/camera.py | 10 +- ibllib/pipes/dynamic_pipeline.py | 70 ++++++++++- ibllib/pipes/misc.py | 1 + ibllib/pipes/training_preprocessing.py | 2 +- ibllib/tests/fixtures/utils.py | 22 +++- ibllib/tests/test_dynamic_pipeline.py | 104 +++++++++++++++- 7 files changed, 319 insertions(+), 51 deletions(-) diff --git a/ibllib/io/extractors/base.py b/ibllib/io/extractors/base.py index e49d47980..13f594217 100644 --- a/ibllib/io/extractors/base.py +++ b/ibllib/io/extractors/base.py @@ -1,4 +1,5 @@ """Base Extractor classes. + A module for the base Extractor classes. The Extractor, given a session path, will extract the processed data from raw hardware files and optionally save them. """ @@ -10,7 +11,6 @@ import numpy as np import pandas as pd -from one.alf.files import get_session_path from ibllib.io import raw_data_loaders as raw from ibllib.io.raw_data_loaders import load_settings, _logger @@ -162,7 +162,8 @@ def extract(self, bpod_trials=None, settings=None, **kwargs): def run_extractor_classes(classes, session_path=None, **kwargs): """ - Run a set of extractors with the same inputs + Run a set of extractors with the same inputs. + :param classes: list of Extractor class :param save: True/False :param path_out: (defaults to alf path) @@ -195,12 +196,30 @@ def run_extractor_classes(classes, session_path=None, **kwargs): def _get_task_types_json_config(): + """ + Return the extractor types map. + + This function is only used for legacy sessions, i.e. 
those without an experiment description + file and will be removed in favor of :func:`_get_task_extractor_map`, which directly returns + the Bpod extractor class name. The experiment description file cuts out the need for pipeline + name identifiers. + + Returns + ------- + Dict[str, str] + A map of task protocol to task extractor identifier, e.g. 'ephys', 'habituation', etc. + + See Also + -------- + _get_task_extractor_map - returns a map of task protocol to Bpod trials extractor class name. + """ with open(Path(__file__).parent.joinpath('extractor_types.json')) as fp: task_types = json.load(fp) try: # look if there are custom extractor types in the personal projects repo import projects.base custom_extractors = Path(projects.base.__file__).parent.joinpath('extractor_types.json') + _logger.debug('Loading extractor types from %s', custom_extractors) with open(custom_extractors) as fp: custom_task_types = json.load(fp) task_types.update(custom_task_types) @@ -210,8 +229,28 @@ def _get_task_types_json_config(): def get_task_protocol(session_path, task_collection='raw_behavior_data'): + """ + Return the task protocol name from task settings. + + If the session path and/or task collection do not exist, the settings file is missing or + otherwise can not be parsed, or if the 'PYBPOD_PROTOCOL' key is absent, None is returned. + A warning is logged if the session path or settings file doesn't exist. An error is logged if + the settings file can not be parsed. + + Parameters + ---------- + session_path : str, pathlib.Path + The absolute session path. + task_collection : str + The session path directory containing the task settings file. + + Returns + ------- + str or None + The Pybpod task protocol name or None if not found. 
+ """ try: - settings = load_settings(get_session_path(session_path), task_collection=task_collection) + settings = load_settings(session_path, task_collection=task_collection) except json.decoder.JSONDecodeError: _logger.error(f'Can\'t read settings for {session_path}') return @@ -223,11 +262,26 @@ def get_task_protocol(session_path, task_collection='raw_behavior_data'): def get_task_extractor_type(task_name): """ - Returns the task type string from the full pybpod task name: - _iblrig_tasks_biasedChoiceWorld3.7.0 returns "biased" - _iblrig_tasks_trainingChoiceWorld3.6.0 returns "training' - :param task_name: - :return: one of ['biased', 'habituation', 'training', 'ephys', 'mock_ephys', 'sync_ephys'] + Returns the task type string from the full pybpod task name. + + Parameters + ---------- + task_name : str + The complete task protocol name from the PYBPOD_PROTOCOL field of the task settings. + + Returns + ------- + str + The extractor type identifier. Examples include 'biased', 'habituation', 'training', + 'ephys', 'mock_ephys' and 'sync_ephys'. + + Examples + -------- + >>> get_task_extractor_type('_iblrig_tasks_biasedChoiceWorld3.7.0') + 'biased' + + >>> get_task_extractor_type('_iblrig_tasks_trainingChoiceWorld3.6.0') + 'training' """ if isinstance(task_name, Path): task_name = get_task_protocol(task_name) @@ -245,16 +299,30 @@ def get_task_extractor_type(task_name): def get_session_extractor_type(session_path, task_collection='raw_behavior_data'): """ - From a session path, loads the settings file, finds the task and checks if extractors exist - task names examples: - :param session_path: - :return: bool + Infer trials extractor type from task settings. + + From a session path, loads the settings file, finds the task and checks if extractors exist. + Examples include 'biased', 'habituation', 'training', 'ephys', 'mock_ephys', and 'sync_ephys'. + Note this should only be used for legacy sessions, i.e. those without an experiment description + file. 
+ + Parameters + ---------- + session_path : str, pathlib.Path + The session path for which to determine the pipeline. + task_collection : str + The session path directory containing the raw task data. + + Returns + ------- + str or False + The task extractor type, e.g. 'biased', 'habituation', 'ephys', or False if unknown. """ - settings = load_settings(session_path, task_collection=task_collection) - if settings is None: - _logger.error(f'ABORT: No data found in "{task_collection}" folder {session_path}') + task_protocol = get_task_protocol(session_path, task_collection=task_collection) + if task_protocol is None: + _logger.error(f'ABORT: No task protocol found in "{task_collection}" folder {session_path}') return False - extractor_type = get_task_extractor_type(settings['PYBPOD_PROTOCOL']) + extractor_type = get_task_extractor_type(task_protocol) if extractor_type: return extractor_type else: @@ -263,9 +331,22 @@ def get_session_extractor_type(session_path, task_collection='raw_behavior_data' def get_pipeline(session_path, task_collection='raw_behavior_data'): """ - Get the pre-processing pipeline name from a session path - :param session_path: - :return: + Get the pre-processing pipeline name from a session path. + + Note this is only suitable for legacy sessions, i.e. those without an experiment description + file. This function will be removed in the future. + + Parameters + ---------- + session_path : str, pathlib.Path + The session path for which to determine the pipeline. + task_collection : str + The session path directory containing the raw task data. + + Returns + ------- + str + The pipeline name inferred from the extractor type, e.g. 'ephys', 'training', 'widefield'. 
""" stype = get_session_extractor_type(session_path, task_collection=task_collection) return _get_pipeline_from_task_type(stype) @@ -273,18 +354,29 @@ def get_pipeline(session_path, task_collection='raw_behavior_data'): def _get_pipeline_from_task_type(stype): """ - Returns the pipeline from the task type. Some tasks types directly define the pipeline - :param stype: session_type or task extractor type - :return: + Return the pipeline from the task type. + + Some task types directly define the pipeline. Note this is only suitable for legacy sessions, + i.e. those without an experiment description file. This function will be removed in the future. + + Parameters + ---------- + stype : str + The session type or task extractor type, e.g. 'habituation', 'ephys', etc. + + Returns + ------- + str + A task pipeline identifier. """ if stype in ['ephys_biased_opto', 'ephys', 'ephys_training', 'mock_ephys', 'sync_ephys']: return 'ephys' elif stype in ['habituation', 'training', 'biased', 'biased_opto']: return 'training' - elif 'widefield' in stype: + elif isinstance(stype, str) and 'widefield' in stype: return 'widefield' else: - return stype + return stype or '' def _get_task_extractor_map(): @@ -293,7 +385,7 @@ def _get_task_extractor_map(): Returns ------- - dict(str, str) + Dict[str, str] A map of task protocol to Bpod trials extractor class. """ FILENAME = 'task_extractor_map.json' @@ -315,26 +407,26 @@ def get_bpod_extractor_class(session_path, task_collection='raw_behavior_data'): """ Get the Bpod trials extractor class associated with a given Bpod session. + Note that unlike :func:`get_session_extractor_type`, this function maps directly to the Bpod + trials extractor class name. This is hardware invariant and is purely to determine the Bpod only + trials extractor. + Parameters ---------- session_path : str, pathlib.Path The session path containing Bpod behaviour data. task_collection : str - The session_path subfolder containing the Bpod settings file.
+ The session_path sub-folder containing the Bpod settings file. Returns ------- str The extractor class name. """ - # Attempt to load settings files - settings = load_settings(session_path, task_collection=task_collection) - if settings is None: - raise ValueError(f'No data found in "{task_collection}" folder {session_path}') - # Attempt to get task protocol - protocol = settings.get('PYBPOD_PROTOCOL') + # Attempt to get protocol name from settings file + protocol = get_task_protocol(session_path, task_collection=task_collection) if not protocol: - raise ValueError(f'No task protocol found in {session_path/task_collection}') + raise ValueError(f'No task protocol found in {Path(session_path) / task_collection}') return protocol2extractor(protocol) @@ -342,7 +434,8 @@ def protocol2extractor(protocol): """ Get the Bpod trials extractor class associated with a given Bpod task protocol. - The Bpod task protocol can be found in the 'PYBPOD_PROTOCOL' field of _iblrig_taskSettings.raw.json. + The Bpod task protocol can be found in the 'PYBPOD_PROTOCOL' field of the + _iblrig_taskSettings.raw.json file. Parameters ---------- diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index 93554c86a..a44010821 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -1,4 +1,5 @@ """ Camera extractor functions. + This module handles extraction of camera timestamps for both Bpod and DAQ. """ import logging @@ -29,7 +30,7 @@ def extract_camera_sync(sync, chmap=None): """ - Extract camera timestamps from the sync matrix + Extract camera timestamps from the sync matrix. :param sync: dictionary 'times', 'polarities' of fronts detected on sync trace :param chmap: dictionary containing channel indices. Default to constant. @@ -45,7 +46,8 @@ def extract_camera_sync(sync, chmap=None): def get_video_length(video_path): """ - Returns video length + Returns video length. 
+ :param video_path: A path to the video :return: """ @@ -58,9 +60,7 @@ def get_video_length(video_path): class CameraTimestampsFPGA(BaseExtractor): - """ - Extractor for videos using DAQ sync and channel map. - """ + """Extractor for videos using DAQ sync and channel map.""" def __init__(self, label, session_path=None): super().__init__(session_path) diff --git a/ibllib/pipes/dynamic_pipeline.py b/ibllib/pipes/dynamic_pipeline.py index d95497380..ec4228256 100644 --- a/ibllib/pipes/dynamic_pipeline.py +++ b/ibllib/pipes/dynamic_pipeline.py @@ -13,7 +13,7 @@ import spikeglx import ibllib.io.session_params as sess_params -import ibllib.io.extractors.base +from ibllib.io.extractors.base import get_pipeline, get_session_extractor_type import ibllib.pipes.tasks as mtasks import ibllib.pipes.base_tasks as bstasks import ibllib.pipes.widefield_tasks as wtasks @@ -45,7 +45,7 @@ def acquisition_description_legacy_session(session_path, save=False): dict The legacy acquisition description. """ - extractor_type = ibllib.io.extractors.base.get_session_extractor_type(session_path=session_path) + extractor_type = get_session_extractor_type(session_path) etype2protocol = dict(biased='choice_world_biased', habituation='choice_world_habituation', training='choice_world_training', ephys='choice_world_recording') dict_ad = get_acquisition_description(etype2protocol[extractor_type]) @@ -130,7 +130,7 @@ def make_pipeline(session_path, **pkwargs): ---------- session_path : str, Path The absolute session path, i.e. '/path/to/subject/yyyy-mm-dd/nnn'. - **pkwargs + pkwargs Optional arguments passed to the ibllib.pipes.tasks.Pipeline constructor. 
Returns @@ -147,7 +147,7 @@ def make_pipeline(session_path, **pkwargs): if not acquisition_description: raise ValueError('Experiment description file not found or is empty') devices = acquisition_description.get('devices', {}) - kwargs = {'session_path': session_path} + kwargs = {'session_path': session_path, 'one': pkwargs.get('one')} # Registers the experiment description file tasks['ExperimentDescriptionRegisterRaw'] = type('ExperimentDescriptionRegisterRaw', @@ -430,3 +430,65 @@ def load_pipeline_dict(path): task_list = yaml.full_load(file) return task_list + + +def get_trials_tasks(session_path, one=None): + """ + Return a list of pipeline trials extractor task objects for a given session. + + This function supports both legacy and dynamic pipeline sessions. + + Parameters + ---------- + session_path : str, pathlib.Path + An absolute path to a session. + one : one.api.One + An ONE instance. + + Returns + ------- + list of pipes.tasks.Task + A list of task objects for the provided session. 
+ + """ + # Check for an experiment.description file; ensure downloaded if possible + if one and one.to_eid(session_path): # to_eid returns None if session not registered + one.load_datasets(session_path, ['_ibl_experiment.description'], download_only=True, assert_present=False) + experiment_description = sess_params.read_params(session_path) + + # If experiment description file then use this to make the pipeline + if experiment_description is not None: + tasks = [] + pipeline = make_pipeline(session_path, one=one) + trials_tasks = [t for t in pipeline.tasks if 'Trials' in t] + for task in trials_tasks: + t = pipeline.tasks.get(task) + t.__init__(session_path, **t.kwargs) + tasks.append(t) + else: + # Otherwise default to old way of doing things + pipeline = get_pipeline(session_path) + if pipeline == 'training': + from ibllib.pipes.training_preprocessing import TrainingTrials + tasks = [TrainingTrials(session_path, one=one)] + elif pipeline == 'ephys': + from ibllib.pipes.ephys_preprocessing import EphysTrials + tasks = [EphysTrials(session_path, one=one)] + else: + try: + # try to find a custom extractor in the personal projects extraction class + import projects.base + task_type = get_session_extractor_type(session_path) + assert (PipelineClass := projects.base.get_pipeline(task_type)) + pipeline = PipelineClass(session_path, one=one) + trials_task_name = next((task for task in pipeline.tasks if 'Trials' in task), None) + assert trials_task_name, (f'No "Trials" tasks for custom pipeline ' + f'"{pipeline.name}" with extractor type "{task_type}"') + task = pipeline.tasks.get(trials_task_name) + task(session_path) + tasks = [task] + except (ModuleNotFoundError, AssertionError) as ex: + _logger.warning('Failed to get trials tasks: %s', ex) + tasks = [] + + return tasks diff --git a/ibllib/pipes/misc.py b/ibllib/pipes/misc.py index 39871ad00..27e1df5b3 100644 --- a/ibllib/pipes/misc.py +++ b/ibllib/pipes/misc.py @@ -38,6 +38,7 @@ def subjects_data_folder(folder: 
Path, rglob: bool = False) -> Path: """Given a root_data_folder will try to find a 'Subjects' data folder. + If Subjects folder is passed will return it directly.""" if not isinstance(folder, Path): folder = Path(folder) diff --git a/ibllib/pipes/training_preprocessing.py b/ibllib/pipes/training_preprocessing.py index db41f8992..ad2172809 100644 --- a/ibllib/pipes/training_preprocessing.py +++ b/ibllib/pipes/training_preprocessing.py @@ -19,7 +19,7 @@ from ibllib.qc.task_extractors import TaskQCExtractor _logger = logging.getLogger(__name__) -warnings.warn('`pipes.training_preprocessing` to be removed in favour of dynamic pipeline') +warnings.warn('`pipes.training_preprocessing` to be removed in favour of dynamic pipeline', FutureWarning) # level 0 diff --git a/ibllib/tests/fixtures/utils.py b/ibllib/tests/fixtures/utils.py index ac7ac5f71..f536875d0 100644 --- a/ibllib/tests/fixtures/utils.py +++ b/ibllib/tests/fixtures/utils.py @@ -216,7 +216,7 @@ def create_fake_raw_behavior_data_folder( ): """Create the folder structure for a raw behaviour session. - Creates a raw_behavior_data folder and optionally, touches some files and writes a experiment + Creates a raw_behavior_data folder and optionally, touches some files and writes an experiment description stub to a `_devices` folder. Parameters @@ -304,8 +304,26 @@ def create_fake_raw_behavior_data_folder( def populate_task_settings(fpath: Path, patch: dict): - with fpath.open("w") as f: + """ + Populate a task settings JSON file. + + Parameters + ---------- + fpath : pathlib.Path + A path to a raw task settings folder or the full settings file path. + patch : dict + The settings dict to write to file. + + Returns + ------- + pathlib.Path + The full settings file path. 
+ """ + if fpath.is_dir(): + fpath /= '_iblrig_taskSettings.raw.json' + with fpath.open('w') as f: json.dump(patch, f, indent=1) + return fpath def create_fake_complete_ephys_session( diff --git a/ibllib/tests/test_dynamic_pipeline.py b/ibllib/tests/test_dynamic_pipeline.py index 8b32b5ff6..41420c674 100644 --- a/ibllib/tests/test_dynamic_pipeline.py +++ b/ibllib/tests/test_dynamic_pipeline.py @@ -1,15 +1,22 @@ import tempfile from pathlib import Path import unittest +from unittest import mock from itertools import chain +import yaml + import ibllib.tests -from ibllib.pipes import dynamic_pipeline +import ibllib.pipes.dynamic_pipeline as dyn +from ibllib.pipes.tasks import Pipeline, Task +from ibllib.pipes import ephys_preprocessing +from ibllib.pipes import training_preprocessing from ibllib.io import session_params +from ibllib.tests.fixtures.utils import populate_task_settings def test_read_write_params_yaml(): - ad = dynamic_pipeline.get_acquisition_description('choice_world_recording') + ad = dyn.get_acquisition_description('choice_world_recording') with tempfile.TemporaryDirectory() as td: session_path = Path(td) session_params.write_params(session_path, ad) @@ -21,14 +28,14 @@ class TestCreateLegacyAcqusitionDescriptions(unittest.TestCase): def test_legacy_biased(self): session_path = Path(ibllib.tests.__file__).parent.joinpath('extractors', 'data', 'session_biased_ge5') - ad = dynamic_pipeline.acquisition_description_legacy_session(session_path) + ad = dyn.acquisition_description_legacy_session(session_path) protocols = list(chain(*map(dict.keys, ad.get('tasks', [])))) self.assertCountEqual(['biasedChoiceWorld'], protocols) self.assertEqual(1, len(ad['devices']['cameras'])) def test_legacy_ephys(self): session_path = Path(ibllib.tests.__file__).parent.joinpath('extractors', 'data', 'session_ephys') - ad_ephys = dynamic_pipeline.acquisition_description_legacy_session(session_path) + ad_ephys = dyn.acquisition_description_legacy_session(session_path) 
self.assertEqual(2, len(ad_ephys['devices']['neuropixel'])) self.assertEqual(3, len(ad_ephys['devices']['cameras'])) protocols = list(chain(*map(dict.keys, ad_ephys.get('tasks', [])))) @@ -36,7 +43,94 @@ def test_legacy_ephys(self): def test_legacy_training(self): session_path = Path(ibllib.tests.__file__).parent.joinpath('extractors', 'data', 'session_training_ge5') - ad = dynamic_pipeline.acquisition_description_legacy_session(session_path) + ad = dyn.acquisition_description_legacy_session(session_path) protocols = list(chain(*map(dict.keys, ad.get('tasks', [])))) self.assertCountEqual(['trainingChoiceWorld'], protocols) self.assertEqual(1, len(ad['devices']['cameras'])) + + +class TestGetTrialsTasks(unittest.TestCase): + """Test pipes.dynamic_pipeline.get_trials_tasks function.""" + + def setUp(self): + tmpdir = tempfile.TemporaryDirectory() + self.addCleanup(tmpdir.cleanup) + # The github CI root dir contains an alias/symlink so we must resolve it + self.tempdir = Path(tmpdir.name).resolve() + self.session_path_dynamic = self.tempdir / 'subject' / '2023-01-01' / '001' + self.session_path_dynamic.mkdir(parents=True) + description = {'version': '1.0.0', + 'sync': {'nidq': {'collection': 'raw_ephys_data', 'extension': 'bin', 'acquisition_software': 'spikeglx'}}, + 'tasks': [ + {'ephysChoiceWorld': {'task_collection': 'raw_task_data_00'}}, + {'passiveChoiceWorld': {'task_collection': 'raw_task_data_01'}}, + ]} + with open(self.session_path_dynamic / '_ibl_experiment.description.yaml', 'w') as fp: + yaml.safe_dump(description, fp) + + self.session_path_legacy = self.session_path_dynamic.with_name('002') + (collection := self.session_path_legacy.joinpath('raw_behavior_data')).mkdir(parents=True) + self.settings = {'IBLRIG_VERSION': '7.2.2', 'PYBPOD_PROTOCOL': '_iblrig_tasks_ephysChoiceWorld'} + self.settings_path = populate_task_settings(collection, self.settings) + + def test_get_trials_tasks(self): + """Test pipes.dynamic_pipeline.get_trials_tasks function.""" + # 
A dynamic pipeline session + tasks = dyn.get_trials_tasks(self.session_path_dynamic) + self.assertEqual(2, len(tasks)) + self.assertEqual('raw_task_data_00', tasks[0].collection) + + # Check behaviour with ONE + one = mock.MagicMock() + one.offline = False + one.alyx = mock.MagicMock() + one.alyx.cache_mode = None # sneaky hack as this is checked by the pipeline somewhere + tasks = dyn.get_trials_tasks(self.session_path_dynamic, one) + self.assertEqual(2, len(tasks)) + one.load_datasets.assert_called() # check that description file is checked on disk + + # An ephys session + tasks = dyn.get_trials_tasks(self.session_path_legacy) + self.assertEqual(1, len(tasks)) + self.assertIsInstance(tasks[0], ephys_preprocessing.EphysTrials) + + # A training session + self.settings['PYBPOD_PROTOCOL'] = '_iblrig_tasks_trainingChoiceWorld' + populate_task_settings(self.settings_path, self.settings) + + tasks = dyn.get_trials_tasks(self.session_path_legacy, one=one) + self.assertEqual(1, len(tasks)) + self.assertIsInstance(tasks[0], training_preprocessing.TrainingTrials) + self.assertIs(tasks[0].one, one, 'failed to assign ONE instance to task') + + # A personal project + self.settings['PYBPOD_PROTOCOL'] = '_misc_foobarChoiceWorld' + populate_task_settings(self.settings_path, self.settings) + + m = mock.MagicMock() # Mock the project_extractors repo + m.base.__file__ = str(self.tempdir / 'base.py') + # Create the personal project extractor types map + task_type_map = {'_misc_foobarChoiceWorld': 'foobar'} + extractor_types_path = Path(m.base.__file__).parent.joinpath('extractor_types.json') + populate_task_settings(extractor_types_path, task_type_map) + # Simulate the instantiation of the personal project module's pipeline class + pipeline = mock.Mock(spec=Pipeline) + pipeline.name = 'custom' + task_mock = mock.Mock(spec=Task) + pipeline.tasks = {'RegisterRaw': mock.MagicMock(), 'FooBarTrials': task_mock} + m.base.get_pipeline().return_value = pipeline + with 
mock.patch.dict('sys.modules', projects=m): + """For unknown reasons this method of mocking the personal projects repo (which is + imported within various functions) fails on the Github test builds. This we check + here and skip the rest of the test if patch didn't work.""" + try: + import projects.base + assert isinstance(projects.base, mock.Mock) + except (AssertionError, ModuleNotFoundError): + self.skipTest('Failed to mock projects module import') + tasks = dyn.get_trials_tasks(self.session_path_legacy) + self.assertEqual(1, len(tasks)) + task_mock.assert_called_once_with(self.session_path_legacy) + # Should handle absent trials tasks + pipeline.tasks.pop('FooBarTrials') + self.assertEqual([], dyn.get_trials_tasks(self.session_path_legacy))