diff --git a/ibllib/io/extractors/base.py b/ibllib/io/extractors/base.py index c1b46b22e..cfc9557f4 100644 --- a/ibllib/io/extractors/base.py +++ b/ibllib/io/extractors/base.py @@ -29,9 +29,16 @@ class BaseExtractor(abc.ABC): """ session_path = None + """pathlib.Path: Absolute path of session folder.""" + save_names = None + """tuple of str: The filenames of each extracted dataset, or None if array should not be saved.""" + var_names = None + """tuple of str: A list of names for the extracted variables. These become the returned output keys.""" + default_path = Path('alf') # relative to session + """pathlib.Path: The default output folder relative to `session_path`.""" def __init__(self, session_path=None): # If session_path is None Path(session_path) will fail @@ -127,6 +134,8 @@ class BaseBpodTrialsExtractor(BaseExtractor): bpod_trials = None settings = None task_collection = None + frame2ttl = None + audio = None def extract(self, bpod_trials=None, settings=None, **kwargs): """ diff --git a/ibllib/io/extractors/biased_trials.py b/ibllib/io/extractors/biased_trials.py index 16d8f8111..e2912d11e 100644 --- a/ibllib/io/extractors/biased_trials.py +++ b/ibllib/io/extractors/biased_trials.py @@ -183,6 +183,8 @@ class EphysTrials(BaseBpodTrialsExtractor): def _extract(self, extractor_classes=None, **kwargs) -> dict: base = [GoCueTriggerTimes, StimOnTriggerTimes, ItiInTimes, StimOffTriggerTimes, StimFreezeTriggerTimes, ErrorCueTriggerTimes, TrialsTableEphys, IncludedTrials, PhasePosQuiescence] + # Get all detected TTLs. These are stored for QC purposes + self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials) # Exclude from trials table out, _ = run_extractor_classes(base, session_path=self.session_path, bpod_trials=self.bpod_trials, settings=self.settings, save=False, task_collection=self.task_collection) diff --git a/ibllib/io/extractors/camera.py b/ibllib/io/extractors/camera.py index 7612c3e9e..93554c86a 100644 --- a/ibllib/io/extractors/camera.py +++ b/ibllib/io/extractors/camera.py @@ -513,12 +513,16 @@ def attribute_times(arr, events, tol=.1, injective=True, take='first'): Returns ------- numpy.array - An array the same length as `events`. + An array the same length as `events` containing indices of `arr` corresponding to each + event. """ if (take := take.lower()) not in ('first', 'nearest', 'after'): raise ValueError('Parameter `take` must be either "first", "nearest", or "after"') stack = np.ma.masked_invalid(arr, copy=False) stack.fill_value = np.inf + # If there are no invalid values, the mask is False so let's ensure it's a bool array + if stack.mask is np.bool_(0): + stack.mask = np.zeros(arr.shape, dtype=bool) assigned = np.full(events.shape, -1, dtype=int) # Initialize output array min_tol = 0 if take == 'after' else -tol for i, x in enumerate(events): diff --git a/ibllib/io/extractors/ephys_fpga.py b/ibllib/io/extractors/ephys_fpga.py index 74ac1e551..187d216f6 100644 --- a/ibllib/io/extractors/ephys_fpga.py +++ b/ibllib/io/extractors/ephys_fpga.py @@ -1,14 +1,47 @@ -"""Data extraction from raw FPGA output -Complete FPGA data extraction depends on Bpod extraction +"""Data extraction from raw FPGA output. + +The behaviour extraction happens in the following stages: + + 1. The NI DAQ events are extracted into a map of event times and TTL polarities. + 2. The Bpod trial events are extracted from the raw Bpod data, depending on the task protocol. + 3. 
As protocols may be chained together within a given recording, the period of a given task + protocol is determined using the 'spacer' DAQ signal (see `get_protocol_period`). + 4. Physical behaviour events such as stim on and reward time are separated out by TTL length or + sequence within the trial. + 5. The Bpod clock is sync'd with the FPGA using one of the extracted trial events. + 6. The Bpod software events are then converted to FPGA time. + +Examples +-------- +For simple extraction, use the FPGATrials class: + +>>> extractor = FpgaTrials(session_path) +>>> trials, _ = extractor.extract(update=False, save=False) + +Notes +----- +Sync extraction in this module only supports FPGA data acquired with an NI DAQ as part of a +Neuropixels recording system, however a sync and channel map extracted from a different DAQ format +can be passed to the FpgaTrials class. + +See Also +-------- +For dynamic pipeline sessions it is best to call the extractor via the BehaviorTask class. + +TODO notes on subclassing various methods of FpgaTrials for custom hardware. """ -from collections import OrderedDict import logging +from itertools import cycle from pathlib import Path import uuid import re +import warnings +from functools import partial import matplotlib.pyplot as plt +from matplotlib.colors import TABLEAU_COLORS import numpy as np +from packaging import version import spikeglx import neurodsp.utils @@ -21,17 +54,22 @@ from ibllib.io.extractors.bpod_trials import extract_all as bpod_extract_all import ibllib.io.extractors.base as extractors_base from ibllib.io.extractors.training_wheel import extract_wheel_moves -import ibllib.plots as plots +from ibllib import plots from ibllib.io.extractors.default_channel_maps import DEFAULT_MAPS _logger = logging.getLogger(__name__) -SYNC_BATCH_SIZE_SECS = 100 # number of samples to read at once in bin file for sync +SYNC_BATCH_SIZE_SECS = 100 +"""int: Number of samples to read at once in bin file for sync.""" + WHEEL_RADIUS_CM = 1 # stay in radians +"""float: The radius of the wheel used in the task. 
A value of 1 ensures units remain in radians.""" + WHEEL_TICKS = 1024 +"""int: The number of encoder pulses per channel for one complete rotation.""" -BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150 # throws an error if bpod to fpga clock drift is higher -F2TTL_THRESH = 0.01 # consecutive pulses with less than this threshold ignored +BPOD_FPGA_DRIFT_THRESHOLD_PPM = 150 +"""int: Throws an error if Bpod to FPGA clock drift is higher than this value.""" CHMAPS = {'3A': {'ap': @@ -62,10 +100,11 @@ {'imec_sync': 6} }, } +"""dict: The default channel indices corresponding to various devices for different recording systems.""" def data_for_keys(keys, data): - """Check keys exist in 'data' dict and contain values other than None""" + """Check keys exist in 'data' dict and contain values other than None.""" return data is not None and all(k in data and data.get(k, None) is not None for k in keys) @@ -157,6 +196,8 @@ def _assign_events_bpod(bpod_t, bpod_polarities, ignore_first_valve=True): :param bpod_fronts: numpy vector containing polarity of fronts (1 rise, -1 fall) :param ignore_first_valve (True): removes detected valve events at indices le 2 :return: numpy arrays of times t_trial_start, t_valve_open and t_iti_in + + TODO Remove function (now using FpgaTrials._assign_events) """ TRIAL_START_TTL_LEN = 2.33e-4 # the TTL length is 0.1ms but this has proven to drift on # some bpods and this is the highest possible value that discriminates trial start from valve @@ -258,6 +299,8 @@ def _assign_events_audio(audio_t, audio_polarities, return_indices=False, displa :param display (False): for debug mode, displays the raw fronts overlaid with detections :return: numpy arrays t_ready_tone_in, t_error_tone_in :return: numpy arrays ind_ready_tone_in, ind_error_tone_in if return_indices=True + + TODO Remove function (now using FpgaTrials._assign_events) """ # make sure that there are no 2 consecutive fall or consecutive rise events assert np.all(np.abs(np.diff(audio_polarities)) == 2) @@ -285,13 +328,29 @@ def _assign_events_to_trial(t_trial_start, t_event, take='last'): """ Assign events to a trial given trial start times and event times. - Trials without an event - result in nan value in output time vector. + Trials without an event result in nan value in output time vector. The output has a consistent size with t_trial_start and ready to output to alf. - :param t_trial_start: numpy vector of trial start times - :param t_event: numpy vector of event times to assign to trials - :param take: 'last' or 'first' (optional, default 'last'): index to take in case of duplicates - :return: numpy array of event times with the same shape of trial start. + + Parameters + ---------- + t_trial_start : numpy.array + An array of start times, used to bin edges for assigning values from `t_event`. + t_event : numpy.array + An array of event times to assign to trials. + take : str {'first', 'last'}, int + 'first' takes first event > t_trial_start; 'last' takes last event < the next + t_trial_start; an int defines the index to take for events within trial bounds. The index + may be negative. + + Returns + ------- + numpy.array + An array the length of `t_trial_start` containing values from `t_event`. Unassigned values + are replaced with np.nan. + + See Also + -------- + FpgaTrials._assign_events - Assign trial events based on TTL length. 
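+
+    Examples
+    --------
+    A minimal sketch of the binning behaviour (the example values are hypothetical):
+
+    >>> t_trial_start = np.array([0., 10., 20.])
+    >>> t_event = np.array([1., 2., 11., 12.])
+    >>> _assign_events_to_trial(t_trial_start, t_event, take='last')  # -> [2., 12., nan]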
""" # make sure the events are sorted try: @@ -316,7 +375,7 @@ def _assign_events_to_trial(t_trial_start, t_event, take='last'): else: # if the index is arbitrary, needs to be numeric (could be negative if from the end) iall = np.unique(ind) minsize = take + 1 if take >= 0 else - take - # for each trial, take the takenth element if there are enough values in trial + # for each trial, take the take nth element if there are enough values in trial for iu in iall: match = t_event[iu == ind] if len(match) >= minsize: @@ -382,25 +441,39 @@ def _clean_audio(audio, display=False): return audio -def _clean_frame2ttl(frame2ttl, display=False): +def _clean_frame2ttl(frame2ttl, threshold=0.01, display=False): """ + Clean the frame2ttl events. + Frame 2ttl calibration can be unstable and the fronts may be flickering at an unrealistic pace. This removes the consecutive frame2ttl pulses happening too fast, below a threshold - of F2TTL_THRESH + of F2TTL_THRESH. + + Parameters + ---------- + frame2ttl : dict + A dictionary of frame2TTL events, with keys {'times', 'polarities'}. + threshold : float + Consecutive pulses occurring with this many seconds ignored. + display : bool + If true, plots the input TTLs and the cleaned output. + + Returns + ------- + """ dt = np.diff(frame2ttl['times']) - iko = np.where(np.logical_and(dt < F2TTL_THRESH, frame2ttl['polarities'][:-1] == -1))[0] + iko = np.where(np.logical_and(dt < threshold, frame2ttl['polarities'][:-1] == -1))[0] iko = np.unique(np.r_[iko, iko + 1]) frame2ttl_ = {'times': np.delete(frame2ttl['times'], iko), 'polarities': np.delete(frame2ttl['polarities'], iko)} if iko.size > (0.1 * frame2ttl['times'].size): _logger.warning(f'{iko.size} ({iko.size / frame2ttl["times"].size:.2%}) ' - f'frame to TTL polarity switches below {F2TTL_THRESH} secs') + f'frame to TTL polarity switches below {threshold} secs') if display: # pragma: no cover - from ibllib.plots import squares - plt.figure() - squares(frame2ttl['times'] * 1000, frame2ttl['polarities'], yrange=[0.1, 0.9]) - squares(frame2ttl_['times'] * 1000, frame2ttl_['polarities'], yrange=[1.1, 1.9]) + fig, (ax0, ax1) = plt.subplots(2, sharex=True) + plots.squares(frame2ttl['times'] * 1000, frame2ttl['polarities'], yrange=[0.1, 0.9], ax=ax0) + plots.squares(frame2ttl_['times'] * 1000, frame2ttl_['polarities'], yrange=[1.1, 1.9], ax=ax1) import seaborn as sns sns.displot(dt[dt < 0.05], binwidth=0.0005) @@ -425,9 +498,9 @@ def extract_wheel_sync(sync, chmap=None, tmin=None, tmax=None): Returns ------- - np.array + numpy.array Wheel timestamps in seconds. - np.array + numpy.array Wheel positions in radians. """ # Assume two separate edge count channels @@ -440,7 +513,7 @@ def extract_wheel_sync(sync, chmap=None, tmin=None, tmax=None): return re_ts, re_pos -def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tmin=None, tmax=None): +def extract_behaviour_sync(sync, chmap, display=False, bpod_trials=None, tmin=None, tmax=None): """ Extract task related event times from the sync. @@ -463,6 +536,8 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm ------- dict A map of trial event timestamps. 
+ + TODO Remove this function (now using FpgaTrials.extract_behaviour_sync) """ bpod = get_sync_fronts(sync, chmap['bpod'], tmin=tmin, tmax=tmax) if bpod.times.size == 0: @@ -476,6 +551,7 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm t_trial_start, t_valve_open, t_iti_in = _assign_events_bpod(bpod['times'], bpod['polarities']) if not bpod_trials: raise ValueError('No Bpod trials to align') + intervals_bpod = bpod_trials['intervals'] # If there are no detected trial start times or more than double the trial end pulses, # the trial start pulses may be too small to be detected, in which case, sync using the ini_in if t_trial_start.size == 0 or (t_trial_start.size / t_iti_in.size) < .5: @@ -486,12 +562,12 @@ def extract_behaviour_sync(sync, chmap=None, display=False, bpod_trials=None, tm # if it's drifting too much if drift > 200 and bpod_end.size != t_iti_in.size: raise err.SyncBpodFpgaException('sync cluster f*ck') - t_trial_start = fcn(bpod_trials['intervals_bpod'][:, 0]) + t_trial_start = fcn(intervals_bpod[:, 0]) else: # one issue is that sometimes bpod pulses may not have been detected, in this case # perform the sync bpod/FPGA, and add the start that have not been detected _logger.info('Attempting to align on trial start') - bpod_start = bpod_trials['intervals_bpod'][:, 0] + bpod_start = intervals_bpod[:, 0] fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps( bpod_start, t_trial_start, return_indices=True) # if it's drifting too much @@ -703,34 +779,39 @@ def get_protocol_period(session_path, protocol_number, bpod_sync): class FpgaTrials(extractors_base.BaseExtractor): - save_names = ('_ibl_trials.intervals_bpod.npy', - '_ibl_trials.goCueTrigger_times.npy', None, None, None, None, None, None, None, + save_names = ('_ibl_trials.goCueTrigger_times.npy', None, None, None, None, None, None, None, '_ibl_trials.stimOff_times.npy', None, None, None, '_ibl_trials.quiescencePeriod.npy', '_ibl_trials.table.pqt', '_ibl_wheel.timestamps.npy', '_ibl_wheel.position.npy', '_ibl_wheelMoves.intervals.npy', '_ibl_wheelMoves.peakAmplitude.npy') - var_names = ('intervals_bpod', - 'goCueTrigger_times', 'stimOnTrigger_times', + var_names = ('goCueTrigger_times', 'stimOnTrigger_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times', 'errorCue_times', 'itiIn_times', 'stimFreeze_times', 'stimOff_times', 'valveOpen_times', 'phase', 'position', 'quiescence', 'table', 'wheel_timestamps', 'wheel_position', 'wheelMoves_intervals', 'wheelMoves_peakAmplitude') - # Fields from bpod extractor that we want to re-sync to FPGA bpod_rsync_fields = ('intervals', 'response_times', 'goCueTrigger_times', 'stimOnTrigger_times', 'stimOffTrigger_times', 'stimFreezeTrigger_times', 'errorCueTrigger_times') + """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA.""" - # Fields from bpod extractor that we want to save bpod_fields = ('feedbackType', 'choice', 'rewardVolume', 'contrastLeft', 'contrastRight', - 'probabilityLeft', 'intervals_bpod', 'phase', 'position', 'quiescence') + 'probabilityLeft', 'phase', 'position', 'quiescence') + """tuple of str: Fields from bpod extractor that we want to save.""" + + sync_field = 'intervals_0' # trial start events + """str: The trial event to synchronize (must be present in extracted trials).""" - """str: The Bpod events to synchronize (must be present in sync channel map).""" - sync_field = 'intervals' + bpod = None + """dict of numpy.array: The Bpod out TTLs recorded on the DAQ. 
Used in the QC viewer plot.""" def __init__(self, *args, bpod_trials=None, bpod_extractor=None, **kwargs): - """An extractor for all ephys trial data, in FPGA time""" + """An extractor for ephysChoiceWorld trials data, in FPGA time. + + This class may be subclassed to handle moderate variations in hardware and task protocol, + however there is flexible + """ super().__init__(*args, **kwargs) self.bpod2fpga = None self.bpod_trials = bpod_trials @@ -781,7 +862,7 @@ def _update_var_names(self, bpod_fields=None, bpod_rsync_fields=None): if not self.bpod_trials: self.bpod_trials = self.bpod_extractor.extract(save=False) table_keys = alfio.AlfBunch.from_df(self.bpod_trials['table']).keys() - self.bpod_fields += (*[x for x in table_keys if x not in excluded], self.sync_field + '_bpod') + self.bpod_fields += tuple([x for x in table_keys if x not in excluded]) @staticmethod def _time_fields(trials_attr) -> set: @@ -802,72 +883,260 @@ def _time_fields(trials_attr) -> set: pattern = re.compile(fr'^[_\w]*({"|".join(FIELDS)})[_\w]*$') return set(filter(pattern.match, trials_attr)) + def load_sync(self, sync_collection='raw_ephys_data', **kwargs): + """Load the DAQ sync and channel map data. + + This method may be subclassed for novel DAQ systems. The sync must contain the following + keys: 'times' - an array timestamps in seconds; 'polarities' - an array of {-1, 1} + corresponding to TTL LOW and TTL HIGH, respectively; 'channels' - an array of ints + corresponding to channel number. + + Parameters + ---------- + sync_collection : str + The session subdirectory where the sync data are located. + kwargs + Optional arguments used by subclass methods. + + Returns + ------- + one.alf.io.AlfBunch + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. + dict + A map of channel names and their corresponding indices. + """ + return get_sync_and_chn_map(self.session_path, sync_collection) + def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', task_collection='raw_behavior_data', **kwargs) -> dict: - """Extracts ephys trials by combining Bpod and FPGA sync pulses""" - # extract the behaviour data from bpod + """Extracts ephys trials by combining Bpod and FPGA sync pulses. + + It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field` + attributes are all correct for the bpod protocol used. + + Below are the steps involved: + 0. Load sync and bpod trials, if required. + 1. Determine protocol period and discard sync events outside the task. + 2. Classify multiplexed TTL events based on length (see :meth:`FpgaTrials.build_trials`). + 3. Sync the Bpod clock to the DAQ clock using one of the assigned trial events. + 4. Assign classified TTL events to trial events based on order within the trial. + 4. Convert Bpod software event times to DAQ clock. + 5. Extract the wheel from the DAQ rotary encoder signal, if required. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. If None, the sync is loaded using the + `load_sync` method. + chmap : dict + A map of channel names and their corresponding indices. If None, the channel map is + loaded using the :meth:`FpgaTrials.load_sync` method. + sync_collection : str + The session subdirectory where the sync data are located. This is only used if the + sync or channel maps are not provided. 
+ task_collection : str + The session subdirectory where the raw Bpod data are located. This is used for loading + the task settings and extracting the bpod trials, if not already done. + protocol_number : int + The protocol number if multiple protocols were run during the session. If provided, a + spacer signal must be present in order to determine the correct period. + kwargs + Optional arguments for subclass methods to use. + + Returns + ------- + dict + A dictionary of numpy arrays with `FpgaTrials.var_names` as keys. + """ if sync is None or chmap is None: - _sync, _chmap = get_sync_and_chn_map(self.session_path, sync_collection) + _sync, _chmap = self.load_sync(sync_collection) sync = sync or _sync chmap = chmap or _chmap - if not self.bpod_trials: + if not self.bpod_trials: # extract the behaviour data from bpod self.bpod_trials, *_ = bpod_extract_all( session_path=self.session_path, task_collection=task_collection, save=False, extractor_type=kwargs.get('extractor_type')) + # Explode trials table df - trials_table = alfio.AlfBunch.from_df(self.bpod_trials.pop('table')) - table_columns = trials_table.keys() - self.bpod_trials.update(trials_table) - self.bpod_trials['intervals_bpod'] = np.copy(self.bpod_trials['intervals']) + if 'table' in self.var_names: + trials_table = alfio.AlfBunch.from_df(self.bpod_trials.pop('table')) + table_columns = trials_table.keys() + self.bpod_trials.update(trials_table) + else: + if 'table' in self.bpod_trials: + _logger.error( + '"table" found in Bpod trials but missing from `var_names` attribute and will' + 'therefore not be extracted. This is likely in error.') + table_columns = None # Get the spacer times for this protocol - if (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer + if any(arg in kwargs for arg in ('tmin', 'tmax')): + tmin, tmax = kwargs.get('tmin'), kwargs.get('tmax') + elif (protocol_number := kwargs.get('protocol_number')) is not None: # look for spacer # The spacers are TTLs generated by Bpod at the start of each protocol bpod = get_sync_fronts(sync, chmap['bpod']) tmin, tmax = get_protocol_period(self.session_path, protocol_number, bpod) else: tmin = tmax = None - # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC - fpga_trials, self.frame2ttl, self.audio, self.bpod = extract_behaviour_sync( - sync=sync, chmap=chmap, bpod_trials=self.bpod_trials, tmin=tmin, tmax=tmax) - assert self.sync_field in self.bpod_trials and self.sync_field in fpga_trials - self.bpod_trials[f'{self.sync_field}_bpod'] = np.copy(self.bpod_trials[self.sync_field]) - - # checks consistency and compute dt with bpod - self.bpod2fpga, drift_ppm, ibpod, ifpga = neurodsp.utils.sync_timestamps( - self.bpod_trials[f'{self.sync_field}_bpod'][:, 0], fpga_trials.pop(self.sync_field)[:, 0], - return_indices=True) - nbpod = self.bpod_trials[f'{self.sync_field}_bpod'].shape[0] - npfga = fpga_trials['feedback_times'].shape[0] - nsync = len(ibpod) - _logger.info(f'N trials: {nbpod} bpod, {npfga} FPGA, {nsync} merged, sync {drift_ppm} ppm') - if drift_ppm > BPOD_FPGA_DRIFT_THRESHOLD_PPM: - _logger.warning('BPOD/FPGA synchronization shows values greater than %i ppm', - BPOD_FPGA_DRIFT_THRESHOLD_PPM) - out = OrderedDict() - out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) - out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) - out.update({k: fpga_trials[k][ifpga] for k in sorted(fpga_trials.keys())}) + # Remove unnecessary data from sync + selection = 
np.logical_and( + sync['times'] <= (tmax if tmax is not None else sync['times'][-1]), + sync['times'] >= (tmin if tmin is not None else sync['times'][0]), + ) + sync = alfio.AlfBunch({k: v[selection] for k, v in sync.items()}) + _logger.debug('Protocol period from %.2fs to %.2fs (~%.0f min duration)', + *sync['times'][[0, -1]], np.diff(sync['times'][[0, -1]]) / 60) + + # Get the trial events from the DAQ sync TTLs, sync clocks and build final trials datasets + out = self.build_trials(sync=sync, chmap=chmap, **kwargs) + # extract the wheel data - wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) - from ibllib.io.extractors.training_wheel import extract_first_movement_times - if not self.settings: - self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) - min_qt = self.settings.get('QUIESCENT_PERIOD', None) - first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) - out.update({'firstMovement_times': first_move_onsets}) + if any(x.startswith('wheel') for x in self.var_names): + wheel, moves = self.get_wheel_positions(sync=sync, chmap=chmap, tmin=tmin, tmax=tmax) + from ibllib.io.extractors.training_wheel import extract_first_movement_times + if not self.settings: + self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) + min_qt = self.settings.get('QUIESCENT_PERIOD', None) + first_move_onsets, *_ = extract_first_movement_times(moves, out, min_qt=min_qt) + out.update({'firstMovement_times': first_move_onsets}) + out.update({f'wheel_{k}': v for k, v in wheel.items()}) + out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) + # Re-create trials table - trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns}) - out['table'] = trials_table.to_df() + if table_columns: + trials_table = alfio.AlfBunch({x: out.pop(x) for x in table_columns}) + out['table'] = trials_table.to_df() - out.update({f'wheel_{k}': v for k, v in wheel.items()}) - out.update({f'wheelMoves_{k}': v for k, v in moves.items()}) - out = {k: out[k] for k in self.var_names if k in out} # Reorder output + out = alfio.AlfBunch({k: out[k] for k in self.var_names if k in out}) # Reorder output assert self.var_names == tuple(out.keys()) return out + def build_trials(self, sync, chmap, display=False, **kwargs): + """ + Extract task related event times from the sync. + + The trial start times are the shortest Bpod TTLs and occur at the start of the trial. The + first trial start TTL of the session is longer and must be handled differently. The trial + start TTL is used to assign the other trial events to each trial. + + The trial end is the end of the so-called 'ITI' Bpod event TTL (classified as the longest + of the three Bpod event TTLs). Go cue audio TTLs are the shorter of the two expected audio + tones. The first of these after each trial start is taken to be the go cue time. Error + tones are longer audio TTLs and assigned as the last of such occurrence after each trial + start. The valve open Bpod TTLs are medium-length, the last of which is used for each trial. + The feedback times are times of either valve open or error tone as there should be only one + such event per trial. + + The stimulus times are taken from the frame2ttl events (with improbably high frequency TTLs + removed): the first TTL after each trial start is assumed to be the stim onset time; the + second to last and last are taken as the stimulus freeze and offset times, respectively. 
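+
+        As an illustration of the classification step (the pulse times below are made up;
+        the length ranges are the defaults described in `get_bpod_event_times`):
+
+        >>> ts = np.array([1., 1.0001, 2., 2.1, 3., 3.5])  # alternating rise/fall fronts
+        >>> pol = np.array([1, -1, 1, -1, 1, -1])
+        >>> ttls = {'trial_start': (0, 2.33e-4), 'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)}
+        >>> intervals = FpgaTrials._assign_events(ts, pol, ttls)
+        >>> # -> Nx2 interval arrays: trial_start [[1., 1.0001]], valve_open [[2., 2.1]],
+        >>> #    trial_end [[3., 3.5]], plus an empty 'unassigned' array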
+ + Parameters + ---------- + sync : dict + 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' + chmap : dict + Map of channel names and their corresponding index. Default to constant. + display : bool, matplotlib.pyplot.Axes + Show the full session sync pulses display. + + Returns + ------- + dict + A map of trial event timestamps. + """ + # Get the events from the sync. + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) + self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) + if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}: + raise ValueError( + 'Expected at least "ready_tone" and "error_tone" audio events.' + '`audio_event_ttls` kwarg may be incorrect.') + self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) + if not set(bpod_event_intervals.keys()) >= {'trial_start', 'valve_open', 'trial_end'}: + raise ValueError( + 'Expected at least "trial_start", "trial_end", and "valve_open" audio events. ' + '`bpod_event_ttls` kwarg may be incorrect.') + + t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T + fpga_events = alfio.AlfBunch({ + 'goCue_times': audio_event_intervals['ready_tone'][:, 0], + 'errorCue_times': audio_event_intervals['error_tone'][:, 0], + 'valveOpen_times': bpod_event_intervals['valve_open'][:, 0], + 'valveClose_times': bpod_event_intervals['valve_open'][:, 1], + 'itiIn_times': t_iti_in, + 'intervals_0': bpod_event_intervals['trial_start'][:, 0], + 'intervals_1': t_trial_end + }) + + # Sync the Bpod clock to the DAQ. + # NB: The Bpod extractor typically drops the final, incomplete, trial. Hence there is + # usually at least one extra FPGA event. This shouldn't affect the sync. The final trial is + # dropped after assigning the FPGA events, using the `ifpga` index. Doing this after + # assigning the FPGA trial events ensures the last trial has the correct timestamps. + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) + + if np.any(np.diff(ibpod) != 1) and self.sync_field == 'intervals_0': + # One issue is that sometimes pulses may not have been detected, in this case + # add the events that have not been detected and re-extract the behaviour sync. + # This is only really relevant for the Bpod interval events as the other TTLs are + # from devices where a missing TTL likely means the Bpod event was truly absent. 
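+            # (Summary of the steps below: Bpod trial starts missing from the DAQ are
+            # converted to DAQ time with `bpod2fpga`, merged with the detected DAQ trial
+            # starts, and the sorted result is used to assign events to trials.)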
+ _logger.warning('Missing Bpod TTLs; reassigning events using aligned Bpod start times') + bpod_start = self.bpod_trials['intervals'][:, 0] + missing_bpod = self.bpod2fpga(bpod_start[np.setxor1d(ibpod, np.arange(len(bpod_start)))]) + t_trial_start = np.sort(np.r_[fpga_events['intervals_0'][:, 0], missing_bpod]) + else: + t_trial_start = fpga_events['intervals_0'] + + # Assign the FPGA events to individual trials + fpga_trials = { + 'goCue_times': _assign_events_to_trial(t_trial_start, fpga_events['goCue_times'], take='first'), + 'errorCue_times': _assign_events_to_trial(t_trial_start, fpga_events['errorCue_times']), + 'valveOpen_times': _assign_events_to_trial(t_trial_start, fpga_events['valveOpen_times']), + 'itiIn_times': _assign_events_to_trial(t_trial_start, fpga_events['itiIn_times']), + 'stimFreeze_times': _assign_events_to_trial(t_trial_start, self.frame2ttl['times'], take=-2), + 'stimOn_times': _assign_events_to_trial(t_trial_start, self.frame2ttl['times'], take='first'), + 'stimOff_times': _assign_events_to_trial(t_trial_start, self.frame2ttl['times']) + } + + # Feedback times are valve open on correct trials and error tone in on incorrect trials + fpga_trials['feedback_times'] = np.copy(fpga_trials['valveOpen_times']) + ind_err = np.isnan(fpga_trials['valveOpen_times']) + fpga_trials['feedback_times'][ind_err] = fpga_trials['errorCue_times'][ind_err] + + out = alfio.AlfBunch() + # Add the Bpod trial events, converting the timestamp fields to FPGA time. + # NB: The trial intervals are by default a Bpod rsync field. + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + out.update({k: fpga_trials[k][ifpga] for k in fpga_trials.keys()}) + + if display: # pragma: no cover + width = 0.5 + ymax = 5 + if isinstance(display, bool): + plt.figure('Bpod FPGA Sync') + ax = plt.gca() + else: + ax = display + plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') + plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') + plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') + color_map = TABLEAU_COLORS.keys() + for (event_name, event_times), c in zip(fpga_events.items(), cycle(color_map)): + plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) + ax.legend() + ax.set_yticks([0, 1, 2, 3]) + ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) + ax.set_ylim([0, 5]) + + return out + def get_wheel_positions(self, *args, **kwargs): """Extract wheel and wheelMoves objects. @@ -875,6 +1144,493 @@ def get_wheel_positions(self, *args, **kwargs): """ return get_wheel_positions(*args, **kwargs) + def get_stimulus_update_times(self, sync, chmap, display=False, **_): + """ + Extract stimulus update times from sync. + + Gets the stimulus times from the frame2ttl channel and cleans the signal. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. + chmap : dict + A map of channel names and their corresponding indices. Must contain a 'frame2ttl' key. + display : bool + If true, plots the input TTLs and the cleaned output. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing stimulus TTL fronts. 
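+
+        Examples
+        --------
+        A usage sketch; `session_path` here is a stand-in for a real session folder:
+
+        >>> extractor = FpgaTrials(session_path)
+        >>> sync, chmap = extractor.load_sync()
+        >>> frame2ttl = extractor.get_stimulus_update_times(sync, chmap)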
+ """ + frame2ttl = get_sync_fronts(sync, chmap['frame2ttl']) + frame2ttl = _clean_frame2ttl(frame2ttl, display=display) + return frame2ttl + + def get_audio_event_times(self, sync, chmap, audio_event_ttls=None, display=False, **_): + """ + Extract audio times from sync. + + Gets the TTL times from the 'audio' channel, cleans the signal, and classifies each TTL + event by length. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. + chmap : dict + A map of channel names and their corresponding indices. Must contain an 'audio' key. + audio_event_ttls : dict + A map of event names to (min, max) TTL length. + display : bool + If true, plots the input TTLs and the cleaned output. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing audio TTL fronts. + dict + A dictionary of events (from `audio_event_ttls`) and their intervals as an Nx2 array. + """ + audio = get_sync_fronts(sync, chmap['audio']) + audio = _clean_audio(audio) + + if audio['times'].size == 0: + _logger.error('No audio sync fronts found.') + + if audio_event_ttls is None: + # For training/biased/ephys protocols, the ready tone should be below 110 ms. The error + # tone should be between 400ms and 1200ms + audio_event_ttls = {'ready_tone': (0, 0.11), 'error_tone': (0.4, 1.2)} + audio_event_intervals = self._assign_events(audio['times'], audio['polarities'], audio_event_ttls, display=display) + + return audio, audio_event_intervals + + def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): + """ + Extract Bpod times from sync. + + Gets the Bpod TTL times from the sync 'bpod' channel and classifies each TTL event by + length. NB: The first trial has an abnormal trial_start TTL that is usually mis-assigned. + This method accounts for this. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. Must contain a 'bpod' key. + chmap : dict + A map of channel names and their corresponding indices. + bpod_event_ttls : dict of tuple + A map of event names to (min, max) TTL length. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. + dict + A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. + """ + bpod = get_sync_fronts(sync, chmap['bpod']) + if bpod.times.size == 0: + raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. ' + 'Check channel maps.') + # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these + # lengths are defined by the state machine of the task protocol and therefore vary. + if bpod_event_ttls is None: + # For training/biased/ephys protocols, the trial start TTL length is 0.1ms but this has + # proven to drift on some Bpods and this is the highest possible value that + # discriminates trial start from valve. Valve open events are between 50ms to 300 ms. + # ITI events are above 400 ms. 
+ bpod_event_ttls = { + 'trial_start': (0, 2.33e-4), 'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)} + bpod_event_intervals = self._assign_events( + bpod['times'], bpod['polarities'], bpod_event_ttls, display=display) + + if 'trial_start' not in bpod_event_intervals or bpod_event_intervals['trial_start'].size == 0: + return bpod, bpod_event_intervals + + # The first trial pulse is longer and often assigned to another event. + # Here we move the earliest non-trial_start event to the trial_start array. + t0 = bpod_event_intervals['trial_start'][0, 0] # expect 1st event to be trial_start + pretrial = [(k, v[0, 0]) for k, v in bpod_event_intervals.items() if v.size and v[0, 0] < t0] + if pretrial: + (pretrial, _) = sorted(pretrial, key=lambda x: x[1])[0] # take the earliest event + dt = np.diff(bpod_event_intervals[pretrial][0, :]) * 1e3 # record TTL length to log + _logger.debug('Reassigning first %s to trial_start. TTL length = %.3g ms', pretrial, dt) + bpod_event_intervals['trial_start'] = np.r_[ + bpod_event_intervals[pretrial][0:1, :], bpod_event_intervals['trial_start'] + ] + bpod_event_intervals[pretrial] = bpod_event_intervals[pretrial][1:, :] + + return bpod, bpod_event_intervals + + @staticmethod + def _assign_events(ts, polarities, event_lengths, precedence='shortest', display=False): + """ + Classify TTL events by length. + + Outputs the synchronisation events such as trial intervals, valve opening, and audio. + + Parameters + ---------- + ts : numpy.array + Numpy vector containing times of TTL fronts. + polarities : numpy.array + Numpy vector containing polarity of TTL fronts (1 rise, -1 fall). + event_lengths : dict of tuple + A map of TTL events and the range of permissible lengths, where l0 < ttl <= l1. + precedence : str {'shortest', 'longest', 'dict order'} + In the case of overlapping event TTL lengths, assign shortest/longest first or go by + the `event_lengths` dict order. + display : bool + If true, plots the TTLs with coloured lines delineating the assigned events. + + Returns + ------- + Dict[str, numpy.array] + A dictionary of events and their intervals as an Nx2 array. + + See Also + -------- + _assign_events_to_trial - classify TTLs by event order within a given trial period. + """ + event_intervals = dict.fromkeys(event_lengths) + assert 'unassigned' not in event_lengths.keys() + + if len(ts) == 0: + return {k: np.array([[], []]).T for k in (*event_lengths.keys(), 'unassigned')} + + # make sure that there are no 2 consecutive fall or consecutive rise events + assert np.all(np.abs(np.diff(polarities)) == 2) + if polarities[0] == -1: + ts = np.delete(ts, 0) + if polarities[-1] == 1: # if the final TTL is left HIGH, insert a NaN + ts = np.r_[ts, np.nan] + # take only even time differences: i.e. 
from rising to falling fronts + dt = np.diff(ts)[::2] + + # Assign events from shortest TTL to largest + assigned = np.zeros(ts.shape, dtype=bool) + if precedence.lower() == 'shortest': + event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1])) + elif precedence.lower() == 'longest': + event_items = sorted(event_lengths.items(), key=lambda x: np.diff(x[1]), reverse=True) + elif precedence.lower() == 'dict order': + event_items = event_lengths.items() + else: + raise ValueError(f'Precedence must be one of "shortest", "longest", "dict order", got "{precedence}".') + for event, (min_len, max_len) in event_items: + _logger.debug('%s: %.4G < ttl <= %.4G', event, min_len, max_len) + i_event = np.where(np.logical_and(dt > min_len, dt <= max_len))[0] * 2 + i_event = i_event[np.where(~assigned[i_event])[0]] # remove those already assigned + event_intervals[event] = np.c_[ts[i_event], ts[i_event + 1]] + assigned[np.r_[i_event, i_event + 1]] = True + + # Include the unassigned events for convenience and debugging + event_intervals['unassigned'] = ts[~assigned].reshape(-1, 2) + + # Assert that event TTLs mutually exclusive + all_assigned = np.concatenate(list(event_intervals.values())).flatten() + assert all_assigned.size == np.unique(all_assigned).size, 'TTLs assigned to multiple events' + + # some debug plots when needed + if display: # pragma: no cover + plt.figure() + plots.squares(ts, polarities, label='raw fronts') + for event, intervals in event_intervals.items(): + plots.vertical_lines(intervals[:, 0], ymin=-0.2, ymax=1.1, linewidth=0.5, label=event) + plt.legend() + + # Return map of event intervals in the same order as `event_lengths` dict + return {k: event_intervals[k] for k in (*event_lengths, 'unassigned')} + + @staticmethod + def sync_bpod_clock(bpod_trials, fpga_trials, sync_field): + """ + Sync the Bpod clock to FPGA one using the provided trial event. + + It assumes that `sync_field` is in both `fpga_trials` and `bpod_trials`. Syncing on both + intervals is not supported so to sync on trial start times, `sync_field` should be + 'intervals_0'. + + Parameters + ---------- + bpod_trials : dict + A dictionary of extracted Bpod trial events. + fpga_trials : dict + A dictionary of TTL events extracted from FPGA sync (see `extract_behaviour_sync` + method). + sync_field : str + The trials key to use for syncing clocks. For intervals (i.e. Nx2 arrays) append the + column index, e.g. 'intervals_0'. + + Returns + ------- + function + Interpolation function such that f(timestamps_bpod) = timestamps_fpga. + float + The clock drift in parts per million. + numpy.array of int + The indices of the Bpod trial events in the FPGA trial events array. + numpy.array of int + The indices of the FPGA trial events in the Bpod trial events array. + + Raises + ------ + ValueError + The key `sync_field` was not found in either the `bpod_trials` or `fpga_trials` dicts. + """ + _logger.info(f'Attempting to align Bpod clock to DAQ using trial event "{sync_field}"') + bpod_fpga_timestamps = [None, None] + for i, trials in enumerate((bpod_trials, fpga_trials)): + if sync_field not in trials: + # handle syncing on intervals + if not (m := re.match(r'(.*)_(\d)', sync_field)): + # If missing from bpod trials, either the sync field is incorrect, + # or the Bpod extractor is incorrect. If missing from the fpga events, check + # the sync field and the `extract_behaviour_sync` method. 
+ raise ValueError( + f'Sync field "{sync_field}" not in extracted {"fpga" if i else "bpod"} events') + _sync_field, n = m.groups() + bpod_fpga_timestamps[i] = trials[_sync_field][:, int(n)] + else: + bpod_fpga_timestamps[i] = trials[sync_field] + + # Sync the two timestamps + fcn, drift, ibpod, ifpga = neurodsp.utils.sync_timestamps(*bpod_fpga_timestamps, return_indices=True) + + # If it's drifting too much throw warning or error + _logger.info('N trials: %i bpod, %i FPGA, %i merged, sync %.5f ppm', + *map(len, bpod_fpga_timestamps), len(ibpod), drift) + if drift > 200 and bpod_fpga_timestamps[0].size != bpod_fpga_timestamps[1].size: + raise err.SyncBpodFpgaException('sync cluster f*ck') + elif drift > BPOD_FPGA_DRIFT_THRESHOLD_PPM: + _logger.warning('BPOD/FPGA synchronization shows values greater than %.2f ppm', + BPOD_FPGA_DRIFT_THRESHOLD_PPM) + + return fcn, drift, ibpod, ifpga + + +class FpgaTrialsHabituation(FpgaTrials): + """Extract habituationChoiceWorld trial events from an NI DAQ.""" + + save_names = ('_ibl_trials.stimCenter_times.npy', '_ibl_trials.feedbackType.npy', '_ibl_trials.rewardVolume.npy', + '_ibl_trials.stimOff_times.npy', '_ibl_trials.contrastLeft.npy', '_ibl_trials.contrastRight.npy', + '_ibl_trials.feedback_times.npy', '_ibl_trials.stimOn_times.npy', '_ibl_trials.stimOnTrigger_times.npy', + '_ibl_trials.intervals.npy', '_ibl_trials.goCue_times.npy', '_ibl_trials.goCueTrigger_times.npy', + None, None, None, None, None) + """tuple of str: The filenames of each extracted dataset, or None if array should not be saved.""" + + var_names = ('stimCenter_times', 'feedbackType', 'rewardVolume', 'stimOff_times', 'contrastLeft', + 'contrastRight', 'feedback_times', 'stimOn_times', 'stimOnTrigger_times', 'intervals', + 'goCue_times', 'goCueTrigger_times', 'itiIn_times', 'stimOffTrigger_times', + 'stimCenterTrigger_times', 'position', 'phase') + """tuple of str: A list of names for the extracted variables. These become the returned output keys.""" + + bpod_rsync_fields = ('intervals', 'stimOn_times', 'feedback_times', 'stimCenterTrigger_times', + 'goCue_times', 'itiIn_times', 'stimOffTrigger_times', 'stimOff_times', + 'stimCenter_times', 'stimOnTrigger_times', 'goCueTrigger_times') + """tuple of str: Fields from Bpod extractor that we want to re-sync to FPGA.""" + + bpod_fields = ('feedbackType', 'rewardVolume', 'contrastLeft', 'contrastRight', 'position', 'phase') + """tuple of str: Fields from Bpod extractor that we want to save.""" + + sync_field = 'feedback_times' # valve open events + """str: The trial event to synchronize (must be present in extracted trials).""" + + def _extract(self, sync=None, chmap=None, sync_collection='raw_ephys_data', + task_collection='raw_behavior_data', **kwargs) -> dict: + """ + Extract habituationChoiceWorld trial events from an NI DAQ. + + It is essential that the `var_names`, `bpod_rsync_fields`, `bpod_fields`, and `sync_field` + attributes are all correct for the bpod protocol used. + + Unlike FpgaTrials, this class assumes different Bpod TTL events and syncs the Bpod clock + using the valve open times, instead of the trial start times. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. If None, the sync is loaded using the + `load_sync` method. + dict + A map of channel names and their corresponding indices. If None, the channel map is + loaded using the `load_sync` method. 
+ sync_collection : str + The session subdirectory where the sync data are located. This is only used if the + sync or channel maps are not provided. + task_collection : str + The session subdirectory where the raw Bpod data are located. This is used for loading + the task settings and extracting the bpod trials, if not already done. + protocol_number : int + The protocol number if multiple protocols were run during the session. If provided, a + spacer signal must be present in order to determine the correct period. + kwargs + Optional arguments for class methods, e.g. 'display', 'bpod_event_ttls'. + + Returns + ------- + dict + A dictionary of numpy arrays with `FpgaTrialsHabituation.var_names` as keys. + """ + # Version check: the ITI in TTL was added in a later version + if not self.settings: + self.settings = raw.load_settings(session_path=self.session_path, task_collection=task_collection) + iblrig_version = version.parse(self.settings.get('IBL_VERSION', '0.0.0')) + if version.parse('8.9.3') <= iblrig_version < version.parse('8.12.6'): + """A second 1s TTL was added in this version during the 'iti' state, however this is + unrelated to the trial ITI and is unfortunately the same length as the trial start TTL.""" + raise NotImplementedError('Ambiguous TTLs in 8.9.3 >= version < 8.12.6') + + trials = super()._extract(sync=sync, chmap=chmap, sync_collection=sync_collection, + task_collection=task_collection, **kwargs) + + return trials + + def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): + """ + Extract Bpod times from sync. + + Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse. + Also the first trial pulse is incorrectly assigned due to its abnormal length. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. Must contain a 'bpod' key. + chmap : dict + A map of channel names and their corresponding indices. + bpod_event_ttls : dict of tuple + A map of event names to (min, max) TTL length. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. + dict + A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. + """ + bpod = get_sync_fronts(sync, chmap['bpod']) + if bpod.times.size == 0: + raise err.SyncBpodFpgaException('No Bpod event found in FPGA. No behaviour extraction. ' + 'Check channel maps.') + # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these + # lengths are defined by the state machine of the task protocol and therefore vary. + if bpod_event_ttls is None: + # Currently (at least v8.12 and below) there is no trial start or end TTL, only an ITI pulse + bpod_event_ttls = {'trial_iti': (1, 1.1), 'valve_open': (0, 0.4)} + bpod_event_intervals = self._assign_events( + bpod['times'], bpod['polarities'], bpod_event_ttls, display=display) + + # The first trial pulse is shorter and assigned to valve_open. Here we remove the first + # valve event, prepend a 0 to the trial_start events, and drop the last trial if it was + # incomplete in Bpod. 
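+        # (Concretely, the lines below move the first detected 'valve_open' interval to the
+        # front of the 'trial_iti' intervals and remove it from the 'valve_open' array.)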
+ bpod_event_intervals['trial_iti'] = np.r_[bpod_event_intervals['valve_open'][0:1, :], + bpod_event_intervals['trial_iti']] + bpod_event_intervals['valve_open'] = bpod_event_intervals['valve_open'][1:, :] + + return bpod, bpod_event_intervals + + def build_trials(self, sync, chmap, display=False, **kwargs): + """ + Extract task related event times from the sync. + + This is called by the superclass `_extract` method. The key difference here is that the + `trial_start` LOW->HIGH is the trial end, and HIGH->LOW is trial start. + + Parameters + ---------- + sync : dict + 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' + chmap : dict + Map of channel names and their corresponding index. Default to constant. + display : bool, matplotlib.pyplot.Axes + Show the full session sync pulses display. + + Returns + ------- + dict + A map of trial event timestamps. + """ + # Get the events from the sync. + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) + self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) + self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) + if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_iti'}: + raise ValueError( + 'Expected at least "trial_iti" and "valve_open" Bpod events. `bpod_event_ttls` kwarg may be incorrect.') + + fpga_events = alfio.AlfBunch({ + 'feedback_times': bpod_event_intervals['valve_open'][:, 0], + 'valveClose_times': bpod_event_intervals['valve_open'][:, 1], + 'intervals_0': bpod_event_intervals['trial_iti'][:, 1], + 'intervals_1': bpod_event_intervals['trial_iti'][:, 0], + 'goCue_times': audio_event_intervals['ready_tone'][:, 0] + }) + n_trials = self.bpod_trials['intervals'].shape[0] + + # Sync the Bpod clock to the DAQ. + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) + + out = alfio.AlfBunch() + # Add the Bpod trial events, converting the timestamp fields to FPGA time. + # NB: The trial intervals are by default a Bpod rsync field. + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + + # Assigning each event to a trial ensures exactly one event per trial (missing events are NaN) + assign_to_trial = partial(_assign_events_to_trial, fpga_events['intervals_0']) + trials = alfio.AlfBunch({ + 'goCue_times': assign_to_trial(fpga_events['goCue_times'], take='first')[:n_trials], + 'feedback_times': assign_to_trial(fpga_events['feedback_times'])[:n_trials], + 'stimCenter_times': assign_to_trial(self.frame2ttl['times'], take=-2)[:n_trials], + 'stimOn_times': assign_to_trial(self.frame2ttl['times'], take='first')[:n_trials], + 'stimOff_times': assign_to_trial(self.frame2ttl['times'])[:n_trials], + }) + + # If stim on occurs before trial end, use stim on time. 
Likewise for trial end and stim off + to_correct = ~np.isnan(trials['stimOn_times']) & (trials['stimOn_times'] < out['intervals'][:, 0]) + if np.any(to_correct): + _logger.warning('%i/%i stim on events occurring outside trial intervals', sum(to_correct), len(to_correct)) + out['intervals'][to_correct, 0] = trials['stimOn_times'][to_correct] + to_correct = ~np.isnan(trials['stimOff_times']) & (trials['stimOff_times'] > out['intervals'][:, 1]) + if np.any(to_correct): + _logger.debug( + '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end', + sum(to_correct), len(to_correct)) + out['intervals'][to_correct, 1] = trials['stimOff_times'][to_correct] + + out.update({k: trials[k][ifpga] for k in trials.keys()}) + + if display: # pragma: no cover + width = 0.5 + ymax = 5 + if isinstance(display, bool): + plt.figure('Bpod FPGA Sync') + ax = plt.gca() + else: + ax = display + plots.squares(self.bpod['times'], self.bpod['polarities'] * 0.4 + 1, ax=ax, color='k') + plots.squares(self.frame2ttl['times'], self.frame2ttl['polarities'] * 0.4 + 2, ax=ax, color='k') + plots.squares(self.audio['times'], self.audio['polarities'] * 0.4 + 3, ax=ax, color='k') + color_map = TABLEAU_COLORS.keys() + for (event_name, event_times), c in zip(trials.to_df().items(), cycle(color_map)): + plots.vertical_lines(event_times, ymin=0, ymax=ymax, ax=ax, color=c, label=event_name, linewidth=width) + ax.legend() + ax.set_yticks([0, 1, 2, 3]) + ax.set_yticklabels(['', 'bpod', 'f2ttl', 'audio']) + ax.set_ylim([0, 4]) + + return out + def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_path=None, task_collection='raw_behavior_data', protocol_number=None, **kwargs): @@ -883,7 +1639,11 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ - sync - wheel - behaviour - - video time stamps + + These `extract_all` functions should be deprecated as they make assumptions about hardware + parameters. Additionally the FpgaTrials class now automatically loads DAQ sync files, extracts + the Bpod trials, and returns a dict instead of a tuple. Therefore this function is entirely + redundant. See the examples for the correct way to extract NI DAQ behaviour sessions. Parameters ---------- @@ -909,23 +1669,31 @@ def extract_all(session_path, sync_collection='raw_ephys_data', save=True, save_ list of pathlib.Path, None If save is True, a list of file paths to the extracted data. """ + warnings.warn( + 'ibllib.io.extractors.ephys_fpga.extract_all will be removed in future versions; ' + 'use FpgaTrials instead. 
For reliable extraction, use the dynamic pipeline behaviour tasks.', + FutureWarning) + return_extractor = kwargs.pop('return_extractor', False) # Extract Bpod trials bpod_raw = raw.load_data(session_path, task_collection=task_collection) assert bpod_raw is not None, 'No task trials data in raw_behavior_data - Exit' - bpod_trials, *_ = bpod_extract_all( + bpod_trials, bpod_wheel, *_ = bpod_extract_all( session_path=session_path, bpod_trials=bpod_raw, task_collection=task_collection, save=False, extractor_type=kwargs.get('extractor_type')) # Sync Bpod trials to FPGA sync, chmap = get_sync_and_chn_map(session_path, sync_collection) # sync, chmap = get_main_probe_sync(session_path, bin_exists=bin_exists) - trials = FpgaTrials(session_path, bpod_trials=bpod_trials) + trials = FpgaTrials(session_path, bpod_trials=bpod_trials | bpod_wheel) outputs, files = trials.extract( save=save, sync=sync, chmap=chmap, path_out=save_path, task_collection=task_collection, protocol_number=protocol_number, **kwargs) if not isinstance(outputs, dict): outputs = {k: v for k, v in zip(trials.var_names, outputs)} - return outputs, files + if return_extractor: + return outputs, files, trials + else: + return outputs, files def get_sync_and_chn_map(session_path, sync_collection): diff --git a/ibllib/io/extractors/habituation_trials.py b/ibllib/io/extractors/habituation_trials.py index 9dedbd3d5..655ea2de1 100644 --- a/ibllib/io/extractors/habituation_trials.py +++ b/ibllib/io/extractors/habituation_trials.py @@ -1,12 +1,11 @@ +"""Habituation ChoiceWorld Bpod trials extraction.""" import logging import numpy as np import ibllib.io.raw_data_loaders as raw from ibllib.io.extractors.base import BaseBpodTrialsExtractor, run_extractor_classes from ibllib.io.extractors.biased_trials import ContrastLR -from ibllib.io.extractors.training_trials import ( - FeedbackTimes, StimOnTriggerTimes, Intervals, GoCueTimes -) +from ibllib.io.extractors.training_trials import FeedbackTimes, StimOnTriggerTimes, GoCueTimes _logger = logging.getLogger(__name__) @@ -24,9 +23,24 @@ def __init__(self, *args, **kwargs): self.save_names = tuple(f'_ibl_trials.{x}.npy' if x not in exclude else None for x in self.var_names) def _extract(self) -> dict: + """ + Extract the Bpod trial events. + + The Bpod state machine for this task has extremely misleading names! The 'iti' state is + actually the delay between valve open and trial end (the stimulus is still present during + this period), and the 'trial_start' state is actually the ITI during which there is a 1s + Bpod TTL and gray screen period. + + Returns + ------- + dict + A dictionary of Bpod trial events. The keys are defined in the `var_names` attribute. + """ # Extract all trials... - # Get all stim_sync events detected + # Get all detected TTLs. 
These are stored for QC purposes + self.frame2ttl, self.audio = raw.load_bpod_fronts(self.session_path, data=self.bpod_trials) + # These are the frame2TTL pulses as a list of lists, one per trial ttls = [raw.get_port_events(tr, 'BNC1') for tr in self.bpod_trials] # Report missing events @@ -38,10 +52,49 @@ def _extract(self) -> dict: _logger.warning(f'{self.session_path}: Missing BNC1 TTLs on {n_missing} trial(s)') # Extract datasets common to trainingChoiceWorld - training = [ContrastLR, FeedbackTimes, Intervals, GoCueTimes, StimOnTriggerTimes] + training = [ContrastLR, FeedbackTimes, GoCueTimes, StimOnTriggerTimes] out, _ = run_extractor_classes(training, session_path=self.session_path, save=False, bpod_trials=self.bpod_trials, settings=self.settings, task_collection=self.task_collection) + """ + The 'trial_start' state is in fact the 1s grey screen period, therefore the first timestamp + is really the end of the previous trial and also the stimOff trigger time. The second + timestamp is the true trial start time. + """ + (_, *ends), starts = zip(*[ + t['behavior_data']['States timestamps']['trial_start'][-1] for t in self.bpod_trials] + ) + + # StimOffTrigger times + out['stimOffTrigger_times'] = np.array(ends) + + # StimOff times + """ + There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse. + If 1 or more pulses are missing, we can not be confident of assigning the correct one. + """ + out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan for sync in ttls[1:]]) + + # Trial intervals + """ + In terms of TTLs, the intervals are defined by the 'trial_start' state, however the stim + off time often happens after the trial end TTL front, i.e. after the 'trial_start' start + begins. For these trials, we set the trial end time as the stim off time. + """ + # NB: We lose the last trial because the stim off event occurs at trial_num + 1 + n_trials = out['stimOff_times'].size + out['intervals'] = np.c_[starts, np.r_[ends, np.nan]][:n_trials, :] + + to_correct = ~np.isnan(out['stimOff_times']) & (out['stimOff_times'] > out['intervals'][:, 1]) + if np.any(to_correct): + _logger.debug( + '%i/%i stim off events occurring outside trial intervals; using stim off times as trial end', + sum(to_correct), len(to_correct)) + out['intervals'][to_correct, 1] = out['stimOff_times'][to_correct] + + # itiIn times + out['itiIn_times'] = np.r_[ends, np.nan] + # GoCueTriggerTimes is the same event as StimOnTriggerTimes out['goCueTrigger_times'] = out['stimOnTrigger_times'].copy() @@ -75,38 +128,19 @@ def _extract(self) -> dict: trial_volume = [x['reward_amount'] for x in self.bpod_trials] out['rewardVolume'] = np.array(trial_volume).astype(np.float64) - # StimOffTrigger times - # StimOff occurs at trial start (ignore the first trial's state update) - out['stimOffTrigger_times'] = np.array( - [tr["behavior_data"]["States timestamps"] - ["trial_start"][0][0] for tr in self.bpod_trials[1:]] - ) - - # StimOff times - """ - There should be exactly three TTLs per trial. stimOff_times should be the first TTL pulse. - If 1 or more pulses are missing, we can not be confident of assigning the correct one. 
- """ - trigg = out['stimOffTrigger_times'] - out['stimOff_times'] = np.array([sync[0] if len(sync) == 3 else np.nan - for sync, off in zip(ttls[1:], trigg)]) - # FeedbackType is always positive out['feedbackType'] = np.ones(len(out['feedback_times']), dtype=np.int8) - # ItiIn times - out['itiIn_times'] = np.array( - [tr["behavior_data"]["States timestamps"] - ["iti"][0][0] for tr in self.bpod_trials] - ) - # Phase and position out['position'] = np.array([t['position'] for t in self.bpod_trials]) out['phase'] = np.array([t['stim_phase'] for t in self.bpod_trials]) - # NB: We lose the last trial because the stim off event occurs at trial_num + 1 - n_trials = out['stimOff_times'].size - # return [out[k][:n_trials] for k in self.var_names] + # Double-check that the early and late trial events occur within the trial intervals + idx = ~np.isnan(out['stimOn_times'][:n_trials]) + assert not np.any(out['stimOn_times'][:n_trials][idx] < out['intervals'][idx, 0]), \ + 'Stim on events occurring outside trial intervals' + + # Truncate arrays and return in correct order return {k: out[k][:n_trials] for k in self.var_names} diff --git a/ibllib/io/extractors/mesoscope.py b/ibllib/io/extractors/mesoscope.py index 4def5ed3a..e4ca6766b 100644 --- a/ibllib/io/extractors/mesoscope.py +++ b/ibllib/io/extractors/mesoscope.py @@ -2,21 +2,21 @@ import logging import numpy as np +from scipy.signal import find_peaks import one.alf.io as alfio from one.util import ensure_list from one.alf.files import session_path_parts import matplotlib.pyplot as plt -from neurodsp.utils import falls from pkg_resources import parse_version from ibllib.plots.misc import squares, vertical_lines from ibllib.io.raw_daq_loaders import (extract_sync_timeline, timeline_get_channel, correct_counter_discontinuities, load_timeline_sync_and_chmap) import ibllib.io.extractors.base as extractors_base -from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, get_sync_fronts, get_protocol_period +from ibllib.io.extractors.ephys_fpga import FpgaTrials, WHEEL_TICKS, WHEEL_RADIUS_CM, _assign_events_to_trial from ibllib.io.extractors.training_wheel import extract_wheel_moves from ibllib.io.extractors.camera import attribute_times -from ibllib.io.extractors.ephys_fpga import _assign_events_bpod +from brainbox.behavior.wheel import velocity_filtered _logger = logging.getLogger(__name__) @@ -102,103 +102,246 @@ def plot_timeline(timeline, channels=None, raw=True): class TimelineTrials(FpgaTrials): """Similar extraction to the FPGA, however counter and position channels are treated differently.""" - """one.alf.io.AlfBunch: The timeline data object""" timeline = None + """one.alf.io.AlfBunch: The timeline data object.""" + + sync_field = 'itiIn_times' # trial start events + """str: The trial event to synchronize (must be present in extracted trials).""" def __init__(self, *args, sync_collection='raw_sync_data', **kwargs): """An extractor for all ephys trial data, in Timeline time""" super().__init__(*args, **kwargs) self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline') - def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict: - if not (sync or chmap): - sync, chmap = load_timeline_sync_and_chmap( - self.session_path / sync_collection, timeline=self.timeline, chmap=chmap) + def load_sync(self, sync_collection='raw_sync_data', chmap=None, **_): + """Load the DAQ sync and channel map data. 
+ + Parameters + ---------- + sync_collection : str + The session subdirectory where the sync data are located. + chmap : dict + A map of channel names and their corresponding indices. If None, the channel map is + loaded using the :func:`ibllib.io.raw_daq_loaders.timeline_meta2chmap` method. + + Returns + ------- + one.alf.io.AlfBunch + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. + dict + A map of channel names and their corresponding indices. + """ + if not self.timeline: + self.timeline = alfio.load_object(self.session_path / sync_collection, 'DAQdata', namespace='timeline') + sync, chmap = load_timeline_sync_and_chmap( + self.session_path / sync_collection, timeline=self.timeline, chmap=chmap) + return sync, chmap + def _extract(self, sync=None, chmap=None, sync_collection='raw_sync_data', **kwargs) -> dict: + trials = super()._extract(sync, chmap, sync_collection='raw_sync_data', **kwargs) if kwargs.get('display', False): plot_timeline(self.timeline, channels=chmap.keys(), raw=True) - trials = super()._extract(sync, chmap, sync_collection, extractor_type='ephys', **kwargs) - - # If no protocol number is defined, trim timestamps based on Bpod trials intervals - trials_table = trials['table'] - bpod = get_sync_fronts(sync, chmap['bpod']) - if kwargs.get('protocol_number') is None: - tmin = trials_table.intervals_0.iloc[0] - 1 - tmax = trials_table.intervals_1.iloc[-1] - # Ensure wheel is cut off based on trials - mask = np.logical_and(tmin <= trials['wheel_timestamps'], trials['wheel_timestamps'] <= tmax) - trials['wheel_timestamps'] = trials['wheel_timestamps'][mask] - trials['wheel_position'] = trials['wheel_position'][mask] - mask = np.logical_and(trials['wheelMoves_intervals'][:, 0] >= tmin, trials['wheelMoves_intervals'][:, 0] <= tmax) - trials['wheelMoves_intervals'] = trials['wheelMoves_intervals'][mask, :] + return trials + + def get_bpod_event_times(self, sync, chmap, bpod_event_ttls=None, display=False, **kwargs): + """ + Extract Bpod times from sync. + + Unlike the superclass method. This one doesn't reassign the first trial pulse. + + Parameters + ---------- + sync : dict + A dictionary with keys ('times', 'polarities', 'channels'), containing the sync pulses + and the corresponding channel numbers. Must contain a 'bpod' key. + chmap : dict + A map of channel names and their corresponding indices. + bpod_event_ttls : dict of tuple + A map of event names to (min, max) TTL length. + + Returns + ------- + dict + A dictionary with keys {'times', 'polarities'} containing Bpod TTL fronts. + dict + A dictionary of events (from `bpod_event_ttls`) and their intervals as an Nx2 array. + """ + # Assign the Bpod BNC2 events based on TTL length. The defaults are below, however these + # lengths are defined by the state machine of the task protocol and therefore vary. + if bpod_event_ttls is None: + # The trial start TTLs are often too short for the low sampling rate of the DAQ and are + # therefore not used in extraction + bpod_event_ttls = {'valve_open': (2.33e-4, 0.4), 'trial_end': (0.4, np.inf)} + bpod, bpod_event_intervals = super().get_bpod_event_times( + sync=sync, chmap=chmap, bpod_event_ttls=bpod_event_ttls, display=display, **kwargs) + + # TODO Here we can make use of the 'bpod_rising_edge' channel, if available + return bpod, bpod_event_intervals + + def build_trials(self, sync=None, chmap=None, **kwargs): + """ + Extract task related event times from the sync. 
+ + The two major differences are that the sampling rate is lower for imaging so the short Bpod + trial start TTLs are often absent. For this reason, the sync happens using the ITI_in TTL. + + Second, the valve used at the mesoscope has a way to record the raw voltage across the + solenoid, giving a more accurate readout of the valve's activity. If the reward_valve + channel is present on the DAQ, this is used to extract the valve open times. + + Parameters + ---------- + sync : dict + 'polarities' of fronts detected on sync trace for all 16 chans and their 'times' + chmap : dict + Map of channel names and their corresponding index. Default to constant. + + Returns + ------- + dict + A map of trial event timestamps. + """ + # Get the events from the sync. + # Store the cleaned frame2ttl, audio, and bpod pulses as this will be used for QC + self.frame2ttl = self.get_stimulus_update_times(sync, chmap, **kwargs) + self.audio, audio_event_intervals = self.get_audio_event_times(sync, chmap, **kwargs) + if not set(audio_event_intervals.keys()) >= {'ready_tone', 'error_tone'}: + raise ValueError( + 'Expected at least "ready_tone" and "error_tone" audio events.' + '`audio_event_ttls` kwarg may be incorrect.') + + self.bpod, bpod_event_intervals = self.get_bpod_event_times(sync, chmap, **kwargs) + if not set(bpod_event_intervals.keys()) >= {'valve_open', 'trial_end'}: + raise ValueError( + 'Expected at least "trial_end" and "valve_open" audio events. ' + '`bpod_event_ttls` kwarg may be incorrect.') + + t_iti_in, t_trial_end = bpod_event_intervals['trial_end'].T + fpga_events = alfio.AlfBunch({ + 'itiIn_times': t_iti_in, + 'intervals_1': t_trial_end, + 'goCue_times': audio_event_intervals['ready_tone'][:, 0], + 'errorTone_times': audio_event_intervals['error_tone'][:, 0] + }) + + # Sync the Bpod clock to the DAQ + self.bpod2fpga, drift_ppm, ibpod, ifpga = self.sync_bpod_clock(self.bpod_trials, fpga_events, self.sync_field) + + out = dict() + out.update({k: self.bpod_trials[k][ibpod] for k in self.bpod_fields}) + out.update({k: self.bpod2fpga(self.bpod_trials[k][ibpod]) for k in self.bpod_rsync_fields}) + + start_times = out['intervals'][:, 0] + last_trial_end = out['intervals'][-1, 1] + + def assign_to_trial(events, take='last'): + """Assign DAQ events to trials. + + Because we may not have trial start TTLs on the DAQ (because of the low sampling rate), + there may be an extra last trial that's not in the Bpod intervals as the extractor + ignores the last trial. This function trims the input array before assigning so that + the last trial's events are correctly assigned. + """ + return _assign_events_to_trial(start_times, events[events <= last_trial_end], take) + out['itiIn_times'] = assign_to_trial(fpga_events['itiIn_times'][ifpga]) + + # Extract valve open times from the DAQ + valve_driver_ttls = bpod_event_intervals['valve_open'] + correct = self.bpod_trials['feedbackType'] == 1 + # If there is a reward_valve channel, the valve has + if any(ch['name'] == 'reward_valve' for ch in self.timeline['meta']['inputs']): + # TODO Let's look at the expected open length based on calibration and reward volume + # import scipy.interpolate + # # FIXME support v7 settings? 
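# Aside: a rough sketch (with made-up calibration values) of the volume-to-open-time mapping
# referred to in the commented-out lines around this TODO. A monotonic PCHIP interpolant maps
# reward volume in microlitres onto the calibrated valve open time in milliseconds.
import numpy as np
import scipy.interpolate

calibration_ul = np.array([1.0, 1.5, 2.0, 2.5, 3.0])        # hypothetical volume per opening (uL)
calibration_open_ms = np.array([30., 45., 62., 80., 100.])  # hypothetical open durations (ms)

fcn_vol2time = scipy.interpolate.pchip(calibration_ul, calibration_open_ms)
expected_open_s = fcn_vol2time(1.5) / 1e3  # expected open duration in seconds for a 1.5 uL reward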
+ # fcn_vol2time = scipy.interpolate.pchip( + # self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_WEIGHT_PERDROP'], + # self.bpod_extractor.settings['device_valve']['WATER_CALIBRATION_OPEN_TIMES'] + # ) + # reward_time = fcn_vol2time(self.bpod_extractor.settings.get('REWARD_AMOUNT_UL')) / 1e3 + + # Use the driver TTLs to find the valve open times that correspond to the valve opening + valve_intervals, valve_open_times = self.get_valve_open_times(driver_ttls=valve_driver_ttls) + if valve_open_times.size != np.sum(correct): + _logger.warning( + 'Number of valve open times does not equal number of correct trials (%i != %i)', + valve_open_times.size, np.sum(correct)) + + out['valveOpen_times'] = assign_to_trial(valve_open_times) else: - tmin, tmax = get_protocol_period(self.session_path, kwargs['protocol_number'], bpod) - bpod = get_sync_fronts(sync, chmap['bpod'], tmin, tmax) - - self.frame2ttl = get_sync_fronts(sync, chmap['frame2ttl'], tmin, tmax) # save for later access by QC - - # Replace valve open times with those extracted from the DAQ - # TODO Let's look at the expected open length based on calibration and reward volume - assert len(bpod['times']) > 0, 'No Bpod TTLs detected on DAQ' - _, driver_out, _, = _assign_events_bpod(bpod['times'], bpod['polarities'], False) - # Use the driver TTLs to find the valve open times that correspond to the valve opening - valve_open_times = self.get_valve_open_times(driver_ttls=driver_out) - assert len(valve_open_times) == sum(trials_table.feedbackType == 1) # TODO Relax assertion - correct = trials_table.feedbackType == 1 - trials['valveOpen_times'][correct] = valve_open_times - trials_table.feedback_times[correct] = valve_open_times - - # Replace audio events - self.audio = get_sync_fronts(sync, chmap['audio'], tmin, tmax) - # Attempt to assign the go cue and error tone onsets based on TTL length - go_cue, error_cue = self._assign_events_audio(self.audio['times'], self.audio['polarities']) - - assert error_cue.size == np.sum(~correct), 'N detected error tones does not match number of incorrect trials' - assert go_cue.size <= len(trials_table), 'More go cue tones detected than trials!' - - if go_cue.size < len(trials_table): - _logger.warning('%i go cue tones missed', len(trials_table) - go_cue.size) + # Use the valve controller TTLs recorded on the Bpod channel as the reward time + out['valveOpen_times'] = assign_to_trial(valve_driver_ttls[:, 0]) + + # Stimulus times extracted the same as usual + out['stimFreeze_times'] = assign_to_trial(self.frame2ttl['times'], take=-2) + out['stimOn_times'] = assign_to_trial(self.frame2ttl['times'], take='first') + out['stimOff_times'] = assign_to_trial(self.frame2ttl['times']) + + # Audio times + error_cue = fpga_events['errorTone_times'] + if error_cue.size != np.sum(~correct): + _logger.warning( + 'N detected error tones does not match number of incorrect trials (%i != %i)', + error_cue.size, np.sum(~correct)) + go_cue = fpga_events['goCue_times'] + out['goCue_times'] = assign_to_trial(go_cue, take='first') + out['errorCue_times'] = assign_to_trial(error_cue) + + if go_cue.size > start_times.size: + _logger.warning( + 'More go cue tones detected than trials! (%i vs %i)', go_cue.size, start_times.size) + elif go_cue.size < start_times.size: """ If the error cues are all assigned and some go cues are missed it may be that some - responses were so fast that the go cue and error tone merged. + responses were so fast that the go cue and error tone merged, or the go cue TTL was too + long. 
""" + _logger.warning('%i go cue tones missed', start_times.size - go_cue.size) err_trig = self.bpod2fpga(self.bpod_trials['errorCueTrigger_times']) go_trig = self.bpod2fpga(self.bpod_trials['goCueTrigger_times']) assert not np.any(np.isnan(go_trig)) - assert err_trig.size == go_trig.size - - def first_true(arr): - """Return the index of the first True value in an array.""" - indices = np.where(arr)[0] - return None if len(indices) == 0 else indices[0] + assert err_trig.size == go_trig.size # should be length of n trials with NaNs # Find which trials are missing a go cue - _go_cue = np.full(len(trials_table), np.nan) - for i, intervals in enumerate(trials_table[['intervals_0', 'intervals_1']].values): - idx = first_true(np.logical_and(go_cue > intervals[0], go_cue < intervals[1])) - if idx is not None: - _go_cue[i] = go_cue[idx] + _go_cue = assign_to_trial(go_cue, take='first') + error_cue = assign_to_trial(error_cue) + missing = np.isnan(_go_cue) # Get all the DAQ timestamps where audio channel was HIGH raw = timeline_get_channel(self.timeline, 'audio') raw = (raw - raw.min()) / (raw.max() - raw.min()) # min-max normalize ups = self.timeline.timestamps[raw > .5] # timestamps where input HIGH - for i in np.where(np.isnan(_go_cue))[0]: - # Get the timestamp of the first HIGH after the trigger times - _go_cue[i] = ups[first_true(ups > go_trig[i])] - idx = first_true(np.logical_and( - error_cue > trials_table['intervals_0'][i], - error_cue < trials_table['intervals_1'][i])) - if np.isnan(err_trig[i]): - if idx is not None: - error_cue = np.delete(error_cue, idx) # Remove mis-assigned error tone time - else: - error_cue[idx] = ups[first_true(ups > err_trig[i])] - go_cue = _go_cue - - trials_table.feedback_times[~correct] = error_cue - trials_table.goCue_times = go_cue - return {k: trials[k] for k in self.var_names} + + # Get the timestamps of the first HIGH after the trigger times (allow up to 200ms after). + # Indices of ups directly following a go trigger, or -1 if none found (or trigger NaN) + idx = attribute_times(ups, go_trig, tol=0.2, take='after') + # Trial indices that didn't have detected goCue and now has been assigned an `ups` index + assigned = np.where(idx != -1 & missing)[0] # ignore unassigned + _go_cue[assigned] = ups[idx[assigned]] + + # Remove mis-assigned error tone times (i.e. those that have now been assigned to goCue) + error_cue_without_trig, = np.where(~np.isnan(error_cue) & np.isnan(err_trig)) + i_to_remove = np.intersect1d(assigned, error_cue_without_trig, assume_unique=True) + error_cue[i_to_remove] = np.nan + + # For those trials where go cue was merged with the error cue and therefore mis-assigned, + # we must re-assign the error cue times as the first HIGH after the error trigger. 
+ idx = attribute_times(ups, err_trig, tol=0.2, take='after') + assigned = np.where(idx != -1 & missing)[0] # ignore unassigned + error_cue[assigned] = ups[idx[assigned]] + out['goCue_times'] = _go_cue + out['errorCue_times'] = error_cue + + # Because we're not + assert np.intersect1d(out['goCue_times'], out['errorCue_times']).size == 0, \ + 'audio tones not assigned correctly; tones likely missed' + + # Feedback times + out['feedback_times'] = np.copy(out['valveOpen_times']) + ind_err = np.isnan(out['valveOpen_times']) + out['feedback_times'][ind_err] = out['errorCue_times'][ind_err] + + return out def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding='x4', tmin=None, tmax=None): """ @@ -234,7 +377,7 @@ def extract_wheel_sync(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding=' # Timeline evenly samples counter so we extract only change points d = np.diff(raw) - ind, = np.where(d.astype(int)) + ind, = np.where(~np.isclose(d, 0)) pos = raw[ind + 1] pos -= pos[0] # Start from zero pos = pos / ticks * np.pi * 2 * radius / int(coding[1]) # Convert to radians @@ -290,7 +433,7 @@ def get_wheel_positions(self, ticks=WHEEL_TICKS, radius=WHEEL_RADIUS_CM, coding= ax1.set_ylabel('DAQ wheel position / rad'), ax1.set_xlabel('Time / s') return wheel, moves - def get_valve_open_times(self, display=False, threshold=-2.5, floor_percentile=10, driver_ttls=None): + def get_valve_open_times(self, display=False, threshold=100, driver_ttls=None): """ Get the valve open times from the raw timeline voltage trace. @@ -299,44 +442,82 @@ def get_valve_open_times(self, display=False, threshold=-2.5, floor_percentile=1 display : bool Plot detected times on the raw voltage trace. threshold : float - The threshold for applying to analogue channels. - floor_percentile : float - 10% removes the percentile value of the analog trace before thresholding. This is to - avoid DC offset drift. + The threshold of voltage change to apply. The default was set by eye; units should be + Volts per sample but doesn't appear to be. driver_ttls : numpy.array An optional array of driver TTLs to use for assigning with the valve times. Returns ------- numpy.array - The detected valve open times. - - TODO extract close times too + The detected valve open intervals. + numpy.array + If driver_ttls is not None, returns an array of open times that occurred directly after + the driver TTLs. """ + WARN_THRESH = 10e-3 # open time threshold below which to log warning tl = self.timeline info = next(x for x in tl['meta']['inputs'] if x['name'] == 'reward_valve') values = tl['raw'][:, info['arrayColumn'] - 1] # Timeline indices start from 1 - offset = np.percentile(values, floor_percentile, axis=0) - idx = falls(values - offset, step=threshold) # Voltage falls when valve opens - open_times = tl['timestamps'][idx] + + # The voltage changes over ~1ms and can therefore occur over two DAQ samples at 2kHz + # making simple thresholding an issue. For this reason we convolve the signal with a + # window and detect the peaks and troughs. + if (Fs := tl['meta']['daqSampleRate']) != 2000: # e.g. 
2kHz + _logger.warning('Reward valve detection not tested with a DAQ sample rate of %i', Fs) + dt = 1e-3 # change in voltage takes ~1ms when changing valve open state + N = dt / (1 / Fs) # this means voltage change occurs over N samples + vel, _ = velocity_filtered(values, int(Fs / N)) # filtered voltage change over time + ups, _ = find_peaks(vel, height=threshold) # valve closes (-5V -> 0V) + downs, _ = find_peaks(-1 * vel, height=threshold) # valve opens (0V -> -5V) + + # Convert these times into intervals + ixs = np.argsort(np.r_[downs, ups]) # sort indices + times = tl['timestamps'][np.r_[downs, ups]][ixs] # ordered valve event times + polarities = np.r_[np.zeros_like(downs) - 1, np.ones_like(ups)][ixs] # polarity sorted + missing = np.where(np.diff(polarities) == 0)[0] # if some changes were missed insert NaN + times = np.insert(times, missing + int(polarities[0] == -1), np.nan) + if polarities[-1] == -1: # ensure ends with a valve close + times = np.r_[times, np.nan] + if polarities[0] == 1: # ensure starts with a valve open + # It seems it can start out at -5V (open), then when the reward happens it closes and + # immediately opens. In this case we insert discard the first open time. + times = np.r_[np.nan, times] + intervals = times.reshape(-1, 2) + + # Log warning of improbably short intervals + short = np.sum(np.diff(intervals) < WARN_THRESH) + if short > 0: + _logger.warning('%i valve open intervals shorter than %i ms', short, WARN_THRESH) + # The closing of the valve is noisy. Keep only the falls that occur immediately after a Bpod TTL if driver_ttls is not None: # Returns an array of open_times indices, one for each driver TTL - ind = attribute_times(open_times, driver_ttls, tol=.1, take='after') - open_times = open_times[ind[ind >= 0]] + ind = attribute_times(intervals[:, 0], driver_ttls[:, 0], tol=.1, take='after') + open_times = intervals[ind[ind >= 0], 0] # TODO Log any > 40ms? 
Difficult to report missing valve times because of calibration if display: fig, (ax0, ax1) = plt.subplots(nrows=2, sharex=True) - ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), 'k-o') + ax0.plot(tl['timestamps'], timeline_get_channel(tl, 'bpod'), color='grey', linestyle='-') if driver_ttls is not None: - vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b') - ax1.plot(tl['timestamps'], values - offset, 'k-o') + x = np.empty_like(driver_ttls.flatten()) + x[0::2] = driver_ttls[:, 0] + x[1::2] = driver_ttls[:, 1] + y = np.ones_like(x) + y[1::2] -= 2 + squares(x, y, ax=ax0, yrange=[0, 5]) + # vertical_lines(driver_ttls, ymax=5, ax=ax0, linestyle='--', color='b') + ax0.plot(open_times, np.ones_like(open_times) * 4.5, 'g*') + ax1.plot(tl['timestamps'], values, 'k-o') ax1.set_ylabel('Voltage / V'), ax1.set_xlabel('Time / s') - ax1.plot(tl['timestamps'][idx], np.zeros_like(idx), 'r*') - if driver_ttls is not None: - ax1.plot(open_times, np.zeros_like(open_times), 'g*') - return open_times + + ax2 = ax1.twinx() + ax2.set_ylabel('dV', color='grey') + ax2.plot(tl['timestamps'], vel, linestyle='-', color='grey') + ax2.plot(intervals[:, 1], np.ones(len(intervals)) * threshold, 'r*', label='close') + ax2.plot(intervals[:, 0], np.ones(len(intervals)) * threshold, 'g*', label='open') + return intervals if driver_ttls is None else (intervals, open_times) def _assign_events_audio(self, audio_times, audio_polarities, display=False): """ @@ -360,7 +541,7 @@ def _assign_events_audio(self, audio_times, audio_polarities, display=False): """ # make sure that there are no 2 consecutive fall or consecutive rise events assert np.all(np.abs(np.diff(audio_polarities)) == 2) - # take only even time differences: ie. from rising to falling fronts + # take only even time differences: i.e. 
from rising to falling fronts dt = np.diff(audio_times) onsets = audio_polarities[:-1] == 1 diff --git a/ibllib/io/raw_daq_loaders.py b/ibllib/io/raw_daq_loaders.py index add980130..8ac58c3e7 100644 --- a/ibllib/io/raw_daq_loaders.py +++ b/ibllib/io/raw_daq_loaders.py @@ -292,7 +292,7 @@ def extract_sync_timeline(timeline, chmap=None, floor_percentile=10, threshold=N # Bidirectional; extract indices where delta != 0 raw = correct_counter_discontinuities(raw) d = np.diff(raw) - ind, = np.where(d.astype(int)) + ind, = np.where(~np.isclose(d, 0)) sync.polarities = np.concatenate((sync.polarities, np.sign(d[ind]).astype('i1'))) ind += 1 else: diff --git a/ibllib/io/session_params.py b/ibllib/io/session_params.py index 5bcaf2873..fd9854455 100644 --- a/ibllib/io/session_params.py +++ b/ibllib/io/session_params.py @@ -413,7 +413,6 @@ def iter_dict(d): for d in filter(lambda x: isinstance(x, dict), v): iter_dict(d) elif isinstance(v, dict) and 'collection' in v: - print(k) # if the key already exists, append the collection name to the list if k in collection_map: clist = collection_map[k] if isinstance(collection_map[k], list) else [collection_map[k]] diff --git a/ibllib/pipes/behavior_tasks.py b/ibllib/pipes/behavior_tasks.py index 6f1c8d506..85e21c7ac 100644 --- a/ibllib/pipes/behavior_tasks.py +++ b/ibllib/pipes/behavior_tasks.py @@ -14,7 +14,7 @@ from ibllib.qc.task_metrics import HabituationQC, TaskQC from ibllib.io.extractors.ephys_passive import PassiveChoiceWorld from ibllib.io.extractors.bpod_trials import get_bpod_extractor -from ibllib.io.extractors.ephys_fpga import FpgaTrials, get_sync_and_chn_map +from ibllib.io.extractors.ephys_fpga import FpgaTrials, FpgaTrialsHabituation, get_sync_and_chn_map from ibllib.io.extractors.mesoscope import TimelineTrials from ibllib.pipes import training_status from ibllib.plots.figures import BehaviourPlots @@ -102,14 +102,61 @@ def _run_qc(self, trials_data=None, update=True): qc.extractor = TaskQCExtractor(self.session_path, lazy=True, sync_collection=self.sync_collection, one=self.one, sync_type=self.sync, task_collection=self.collection) - # Currently only the data field is accessed + # Update extractor fields qc.extractor.data = qc.extractor.rename_data(trials_data.copy()) + qc.extractor.frame_ttls = self.extractor.frame2ttl # used in iblapps QC viewer + qc.extractor.audio_ttls = self.extractor.audio # used in iblapps QC viewer + qc.extractor.settings = self.extractor.settings namespace = 'task' if self.protocol_number is None else f'task_{self.protocol_number:02}' qc.run(update=update, namespace=namespace) return qc +class HabituationTrialsNidq(HabituationTrialsBpod): + priority = 90 + job_size = 'small' + + @property + def signature(self): + signature = super().signature + signature['input_files'] = [ + ('_iblrig_taskData.raw.*', self.collection, True), + ('_iblrig_taskSettings.raw.*', self.collection, True), + (f'_{self.sync_namespace}_sync.channels.npy', self.sync_collection, True), + (f'_{self.sync_namespace}_sync.polarities.npy', self.sync_collection, True), + (f'_{self.sync_namespace}_sync.times.npy', self.sync_collection, True), + ('*wiring.json', self.sync_collection, False), + ('*.meta', self.sync_collection, True)] + return signature + + def _extract_behaviour(self, save=True, **kwargs): + """Extract the habituationChoiceWorld trial data using NI DAQ clock.""" + # Extract Bpod trials + bpod_trials, _ = super()._extract_behaviour(save=False, **kwargs) + + # Sync Bpod trials to FPGA + sync, chmap = 
get_sync_and_chn_map(self.session_path, self.sync_collection) + self.extractor = FpgaTrialsHabituation( + self.session_path, bpod_trials=bpod_trials, bpod_extractor=self.extractor) + + # NB: The stimOff times are called stimCenter times for habituation choice world + outputs, files = self.extractor.extract( + save=save, sync=sync, chmap=chmap, path_out=self.session_path.joinpath(self.output_collection), + task_collection=self.collection, protocol_number=self.protocol_number, **kwargs) + return outputs, files + + def _run_qc(self, trials_data=None, update=True, **_): + """Run and update QC. + + This adds the bpod TTLs to the QC object *after* the QC is run in the super call method. + The raw Bpod TTLs are not used by the QC however they are used in the iblapps QC plot. + """ + qc = super()._run_qc(trials_data=trials_data, update=update) + qc.extractor.bpod_ttls = self.extractor.bpod + return qc + + class TrialRegisterRaw(base_tasks.RegisterRawDataTask, base_tasks.BehaviourTask): priority = 100 job_size = 'small' @@ -286,9 +333,9 @@ def _run_qc(self, trials_data=None, update=True): else: qc = TaskQC(self.session_path, one=self.one, log=_logger) qc_extractor.wheel_encoding = 'X1' - qc_extractor.settings = self.extractor.settings - qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts( - self.session_path, task_collection=self.collection) + qc_extractor.settings = self.extractor.settings + qc_extractor.frame_ttls, qc_extractor.audio_ttls = load_bpod_fronts( + self.session_path, task_collection=self.collection) qc.extractor = qc_extractor # Aggregate and update Alyx QC fields @@ -370,14 +417,15 @@ def _run_qc(self, trials_data=None, update=False, plot_qc=False): qc = HabituationQC(self.session_path, one=self.one, log=_logger) else: qc = TaskQC(self.session_path, one=self.one, log=_logger) - qc_extractor.settings = self.extractor.settings # Add Bpod wheel data wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps']) qc_extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod qc_extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] qc_extractor.wheel_encoding = 'X4' - qc_extractor.frame_ttls = self.extractor.frame2ttl - qc_extractor.audio_ttls = self.extractor.audio + qc_extractor.frame_ttls = self.extractor.frame2ttl + qc_extractor.audio_ttls = self.extractor.audio + qc_extractor.bpod_ttls = self.extractor.bpod + qc_extractor.settings = self.extractor.settings qc.extractor = qc_extractor # Aggregate and update Alyx QC fields diff --git a/ibllib/pipes/dynamic_pipeline.py b/ibllib/pipes/dynamic_pipeline.py index bc2caaf1b..3c72853fb 100644 --- a/ibllib/pipes/dynamic_pipeline.py +++ b/ibllib/pipes/dynamic_pipeline.py @@ -230,22 +230,28 @@ def make_pipeline(session_path, **pkwargs): # - choice_world_biased # - choice_world_training # - choice_world_habituation - if 'habituation' in protocol: - registration_class = btasks.HabituationRegisterRaw - behaviour_class = btasks.HabituationTrialsBpod - compute_status = False - elif 'passiveChoiceWorld' in protocol: + if 'passiveChoiceWorld' in protocol: registration_class = btasks.PassiveRegisterRaw behaviour_class = btasks.PassiveTask compute_status = False elif sync_kwargs['sync'] == 'bpod': - registration_class = btasks.TrialRegisterRaw - behaviour_class = btasks.ChoiceWorldTrialsBpod - compute_status = True + if 'habituation' in protocol: + registration_class = btasks.HabituationRegisterRaw + behaviour_class = btasks.HabituationTrialsBpod + compute_status = False + else: + 
registration_class = btasks.TrialRegisterRaw + behaviour_class = btasks.ChoiceWorldTrialsBpod + compute_status = True elif sync_kwargs['sync'] == 'nidq': - registration_class = btasks.TrialRegisterRaw - behaviour_class = btasks.ChoiceWorldTrialsNidq - compute_status = True + if 'habituation' in protocol: + registration_class = btasks.HabituationRegisterRaw + behaviour_class = btasks.HabituationTrialsNidq + compute_status = False + else: + registration_class = btasks.TrialRegisterRaw + behaviour_class = btasks.ChoiceWorldTrialsNidq + compute_status = True else: raise NotImplementedError tasks[f'RegisterRaw_{protocol}_{i:02}'] = type(f'RegisterRaw_{protocol}_{i:02}', (registration_class,), {})( diff --git a/ibllib/pipes/ephys_preprocessing.py b/ibllib/pipes/ephys_preprocessing.py index 26cef7050..7ea845d18 100644 --- a/ibllib/pipes/ephys_preprocessing.py +++ b/ibllib/pipes/ephys_preprocessing.py @@ -694,7 +694,8 @@ def _behaviour_criterion(self): ) def _extract_behaviour(self): - dsets, out_files = ephys_fpga.extract_all(self.session_path, save=True) + dsets, out_files, self.extractor = ephys_fpga.extract_all( + self.session_path, save=True, return_extractor=True) return dsets, out_files @@ -709,8 +710,16 @@ def _run(self, plot_qc=True): qc = TaskQC(self.session_path, one=self.one, log=_logger) qc.extractor = TaskQCExtractor(self.session_path, lazy=True, one=qc.one) # Extract extra datasets required for QC - qc.extractor.data = dsets - qc.extractor.extract_data() + qc.extractor.data = qc.extractor.rename_data(dsets) + wheel_ts_bpod = self.extractor.bpod2fpga(self.extractor.bpod_trials['wheel_timestamps']) + qc.extractor.data['wheel_timestamps_bpod'] = wheel_ts_bpod + qc.extractor.data['wheel_position_bpod'] = self.extractor.bpod_trials['wheel_position'] + qc.extractor.wheel_encoding = 'X4' + qc.extractor.settings = self.extractor.settings + qc.extractor.frame_ttls = self.extractor.frame2ttl + qc.extractor.audio_ttls = self.extractor.audio + qc.extractor.bpod_ttls = self.extractor.bpod + # Aggregate and update Alyx QC fields qc.run(update=True) diff --git a/ibllib/qc/task_metrics.py b/ibllib/qc/task_metrics.py index 42361645d..d746626d5 100644 --- a/ibllib/qc/task_metrics.py +++ b/ibllib/qc/task_metrics.py @@ -1,4 +1,5 @@ -"""Behaviour QC +"""Behaviour QC. + This module runs a list of quality control metrics on the behaviour data. Examples @@ -179,20 +180,22 @@ def run(self, update=False, namespace='task', **kwargs): return outcome, results @staticmethod - def compute_session_status_from_dict(results): + def compute_session_status_from_dict(results, criteria=None): """ Given a dictionary of results, computes the overall session QC for each key and aggregates in a single value - :param results: a dictionary of qc keys containing (usually scalar) values + :param results: a dictionary of qc keys containing (usually scalar) values. + :param criteria: a dictionary of qc keys containing map of PASS, WARNING, FAIL thresholds. 
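# Aside: an example of the criteria map expected here. The threshold values mirror the
# HabituationQC override further down in this diff; the results keys and values are made up.
# Each entry maps outcome labels to the minimum mean pass fraction required for that outcome.
from ibllib.qc.task_metrics import TaskQC

criteria = {
    'default': {'PASS': 0.99, 'WARNING': 0.90, 'FAIL': 0},
    '_task_phase_distribution': {'PASS': 0.99, 'NOT_SET': 0},
}
results = {'_task_phase_distribution': 0.7, '_task_goCue_delays': 0.995}  # hypothetical mean passes
session_outcome, outcomes = TaskQC.compute_session_status_from_dict(results, criteria)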
:return: Overall session QC outcome as a string :return: A dict of QC tests and their outcomes """ indices = np.zeros(len(results), dtype=int) + criteria = criteria or TaskQC.criteria for i, k in enumerate(results): - if k in TaskQC.criteria.keys(): - indices[i] = TaskQC._thresholding(results[k], thresholds=TaskQC.criteria[k]) + if k in criteria.keys(): + indices[i] = TaskQC._thresholding(results[k], thresholds=criteria[k]) else: - indices[i] = TaskQC._thresholding(results[k], thresholds=TaskQC.criteria['default']) + indices[i] = TaskQC._thresholding(results[k], thresholds=criteria['default']) def key_map(x): return 'NOT_SET' if x < 0 else list(TaskQC.criteria['default'].keys())[x] @@ -213,14 +216,19 @@ def compute_session_status(self): # Get mean passed of each check, or None if passed is None or all NaN results = {k: None if v is None or np.isnan(v).all() else np.nanmean(v) for k, v in self.passed.items()} - session_outcome, outcomes = self.compute_session_status_from_dict(results) + session_outcome, outcomes = self.compute_session_status_from_dict(results, self.criteria) return session_outcome, results, outcomes class HabituationQC(TaskQC): - def compute(self, download_data=None): - """Compute and store the QC metrics + criteria = dict() + criteria['default'] = {'PASS': 0.99, 'WARNING': 0.90, 'FAIL': 0} # Note: WARNING was 0.95 prior to Aug 2022 + criteria['_task_phase_distribution'] = {'PASS': 0.99, 'NOT_SET': 0} # This rarely passes due to low trial num + + def compute(self, download_data=None, **kwargs): + """Compute and store the QC metrics. + Runs the QC on the session and stores a map of the metrics for each datapoint for each test, and a map of which datapoints passed for each test :return: @@ -228,7 +236,7 @@ def compute(self, download_data=None): if self.extractor is None: # If download_data is None, decide based on whether eid or session path was provided ensure_data = self.download_data if download_data is None else download_data - self.load_data(download_data=ensure_data) + self.load_data(download_data=ensure_data, **kwargs) self.log.info(f'Session {self.session_path}: Running QC on habituation data...') # Initialize checks @@ -302,6 +310,7 @@ def compute(self, download_data=None): passed[check] = (metric <= 2 * np.pi) & (metric >= 0) metrics[check] = metric + # This is not very useful as a check because there are so few trials check = prefix + 'phase_distribution' metric, _ = np.histogram(data['phase']) _, p = chisquare(metric) diff --git a/ibllib/tests/extractors/test_ephys_fpga.py b/ibllib/tests/extractors/test_ephys_fpga.py index ca211e426..fdfe27218 100644 --- a/ibllib/tests/extractors/test_ephys_fpga.py +++ b/ibllib/tests/extractors/test_ephys_fpga.py @@ -1,15 +1,11 @@ +"""Tests for ephys FPGA sync and FPGA wheel extraction.""" import unittest import tempfile from pathlib import Path -import pickle -import logging import numpy as np -from ibllib.io.extractors.training_wheel import extract_first_movement_times, infer_wheel_units from ibllib.io.extractors import ephys_fpga -from ibllib.io.extractors.training_wheel import extract_wheel_moves -import brainbox.behavior.wheel as wh import spikeglx @@ -88,189 +84,12 @@ def test_ibl_sync_maps(self): self.assertEqual(s, ephys_fpga.CHMAPS['3B']['ap']) -class TestWheelExtraction(unittest.TestCase): - - def setUp(self) -> None: - self.ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - self.pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - self.tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) - self.pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - 
- def test_x1_decoding(self): - p_ = np.array([1, 2, 1, 0]) - t_ = np.array([2, 6, 11, 15]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x1') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(p == p_)) - - def test_x4_decoding(self): - p_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 0]) / 4 - t_ = np.array([2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x4') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(np.isclose(p, p_))) - - def test_x2_decoding(self): - p_ = np.array([1, 2, 3, 4, 3, 2, 1, 0]) / 2 - t_ = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x2') - self.assertTrue(np.all(t == t_)) - self.assertTrue(np.all(p == p_)) - - -class TestExtractedWheelUnits(unittest.TestCase): - """Tests the infer_wheel_units function""" - - wheel_radius_cm = 3.1 - - def setUp(self) -> None: - """ - Create the wheel position data for testing: the positions attribute holds a dictionary of - units, each holding a dictionary of encoding types to test, e.g. - - positions = { - 'rad': { - 'X1': ..., - 'X2': ..., - 'X4': ... - }, - 'cm': { - 'X1': ..., - 'X2': ..., - 'X4': ... - } - } - :return: - """ - def x(unit, enc=int(1), wheel_radius=self.wheel_radius_cm): - radius = 1 if unit == 'rad' else wheel_radius - return 1 / ephys_fpga.WHEEL_TICKS * np.pi * 2 * radius / enc - - # A pseudo-random sequence of integrated fronts - seq = np.array([-1, 0, 1, 2, 1, 2, 3, 4, 3, 2, 1, 0, -1, -2, 1, -2]) - encs = (1, 2, 4) # Encoding types to test - units = ('rad', 'cm') # Units to test - self.positions = {unit: {f'X{e}': x(unit, e) * seq for e in encs} for unit in units} - - def test_extract_wheel_moves(self): - for unit in self.positions.keys(): - for encoding, pos in self.positions[unit].items(): - result = infer_wheel_units(pos) - self.assertEqual(unit, result[0], f'failed to determine units for {encoding}') - expected = int(ephys_fpga.WHEEL_TICKS * int(encoding[1])) - self.assertEqual(expected, result[1], - f'failed to determine number of ticks for {encoding} in {unit}') - self.assertEqual(encoding, result[2], f'failed to determine encoding in {unit}') - - -class TestWheelMovesExtraction(unittest.TestCase): - - def setUp(self) -> None: - """ - Test data is in the form ((inputs), (outputs)) where inputs is a tuple containing a - numpy array of timestamps and one of positions; outputs is a tuple of outputs from - the functions. 
For details, see help on TestWheel.setUp method in module - brainbox.tests.test_behavior - """ - pickle_file = Path(__file__).parents[3].joinpath( - 'brainbox', 'tests', 'fixtures', 'wheel_test.p') - if not pickle_file.exists(): - self.test_data = None - else: - with open(pickle_file, 'rb') as f: - self.test_data = pickle.load(f) - - # Some trial times for trial_data[1] - self.trials = { - 'goCue_times': np.array([162.5, 105.6, 55]), - 'feedback_times': np.array([164.3, 108.3, 56]) - } - - def test_extract_wheel_moves(self): - test_data = self.test_data[1] - # Wrangle data into expected form - re_ts = test_data[0][0] - re_pos = test_data[0][1] - - logger = logging.getLogger(logname := 'ibllib.io.extractors.training_wheel') - with self.assertLogs(logger, level='INFO') as cm: - wheel_moves = extract_wheel_moves(re_ts, re_pos) - self.assertEqual([f'INFO:{logname}:Wheel in cm units using X2 encoding'], cm.output) - - n = 56 # expected number of movements - self.assertTupleEqual(wheel_moves['intervals'].shape, (n, 2), - 'failed to return the correct number of intervals') - self.assertEqual(wheel_moves['peakAmplitude'].size, n) - self.assertEqual(wheel_moves['peakVelocity_times'].size, n) - - # Check the first 3 intervals - ints = np.array( - [[24.78462599, 25.22562599], - [29.58762599, 31.15062599], - [31.64262599, 31.81662599]]) - actual = wheel_moves['intervals'][:3, ] - self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') - - # Check amplitudes - actual = wheel_moves['peakAmplitude'][-3:] - expected = [0.50255486, -1.70103154, 1.00740789] - self.assertIsNone(np.testing.assert_allclose(actual, expected), 'unexpected amplitudes') - - # Check peak velocities - actual = wheel_moves['peakVelocity_times'][-3:] - expected = [175.13662599, 176.65762599, 178.57262599] - self.assertIsNone(np.testing.assert_allclose(actual, expected), 'peak times') - - # Test extraction in rad - re_pos = wh.cm_to_rad(re_pos) - with self.assertLogs(logger, level='INFO') as cm: - wheel_moves = ephys_fpga.extract_wheel_moves(re_ts, re_pos) - self.assertEqual([f'INFO:{logname}:Wheel in rad units using X2 encoding'], cm.output) - - # Check the first 3 intervals. As position thresholds are adjusted by units and - # encoding, we should expect the intervals to be identical to above - actual = wheel_moves['intervals'][:3, ] - self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') - - def test_movement_log(self): - """ - Integration test for inferring the units and decoding type for wheel data input for - extract_wheel_moves. Only expected to work for the default wheel diameter. 
- """ - ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) - pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) - pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) - logger = logging.getLogger(logname := 'ibllib.io.extractors.training_wheel') - - for unit in ['cm', 'rad']: - for i in (1, 2, 4): - encoding = 'X' + str(i) - r = 3.1 if unit == 'cm' else 1 - # print(encoding, unit) - t, p = ephys_fpga._rotary_encoder_positions_from_fronts( - ta, pa, tb, pb, ticks=1024, coding=encoding.lower(), radius=r) - expected = f'INFO:{logname}:Wheel in {unit} units using {encoding} encoding' - with self.assertLogs(logger, level='INFO') as cm: - ephys_fpga.extract_wheel_moves(t, p) - self.assertEqual([expected], cm.output) - - def test_extract_first_movement_times(self): - test_data = self.test_data[1] - wheel_moves = ephys_fpga.extract_wheel_moves(test_data[0][0], test_data[0][1]) - first, is_final, ind = extract_first_movement_times(wheel_moves, self.trials) - np.testing.assert_allclose(first, [162.48462599, 105.62562599, np.nan]) - np.testing.assert_array_equal(is_final, [False, True, False]) - np.testing.assert_array_equal(ind, [46, 18]) - - class TestEphysFPGA_TTLsExtraction(unittest.TestCase): def test_audio_ttl_wiring_camera(self): """ + Test ephys_fpga._clean_audio function. + Test removal of spurious TTLs due to a wrong wiring of the camera onto the soundcard example eid: e349a2e7-50a3-47ca-bc45-20d1899854ec """ @@ -389,14 +208,15 @@ def test_frame2ttl_flickers(self): switches under a given threshold """ DISPLAY = False # for debug purposes - diff = ephys_fpga.F2TTL_THRESH * np.array([0.5, 10]) + F2TTL_THRESH = 0.01 + diff = F2TTL_THRESH * np.array([0.5, 10]) # flicker ends with a polarity switch - downgoing pulse is removed t = np.r_[0, np.cumsum(diff[np.array([1, 1, 0, 0, 1])])] + 1 frame2ttl = {'times': t, 'polarities': np.mod(np.arange(t.size) + 1, 2) * 2 - 1} expected = {'times': np.array([1., 1.1, 1.2, 1.31]), 'polarities': np.array([1, -1, 1, -1])} - frame2ttl_ = ephys_fpga._clean_frame2ttl(frame2ttl, display=DISPLAY) + frame2ttl_ = ephys_fpga._clean_frame2ttl(frame2ttl, display=DISPLAY, threshold=F2TTL_THRESH) assert all([np.all(frame2ttl_[k] == expected[k]) for k in frame2ttl_]) # stand-alone flicker diff --git a/ibllib/tests/extractors/test_ephys_trials.py b/ibllib/tests/extractors/test_ephys_trials.py index d5483792f..7d77079af 100644 --- a/ibllib/tests/extractors/test_ephys_trials.py +++ b/ibllib/tests/extractors/test_ephys_trials.py @@ -1,15 +1,23 @@ import unittest from pathlib import Path +import pickle + import numpy as np from ibllib.io.extractors import ephys_fpga, biased_trials import ibllib.io.raw_data_loaders as raw +from ibllib.io.extractors.training_wheel import extract_first_movement_times, infer_wheel_units +from ibllib.io.extractors.training_wheel import extract_wheel_moves +import brainbox.behavior.wheel as wh class TestEphysSyncExtraction(unittest.TestCase): def test_bpod_trace_extraction(self): + """Test ephys_fpga._assign_events_bpod function. + TODO Remove this test and corresponding function. 
+ """ t_valve_open_ = np.array([117.12136667, 122.3873, 127.82903333, 140.56083333, 143.55326667, 155.29713333, 164.9186, 167.91133333, 171.39736667, 178.0305, 181.70343333]) @@ -48,6 +56,7 @@ def test_bpod_trace_extraction(self): self.assertTrue(np.all(np.isclose(t_valve_open, t_valve_open_))) def test_align_to_trial(self): + """Test ephys_fpga._assign_events_to_trial function.""" # simple test with one missing at the end t_trial_start = np.arange(0, 5) * 10 t_event = np.arange(0, 5) * 10 + 2 @@ -95,6 +104,7 @@ def test_align_to_trial(self): self.assertRaises(ValueError, ephys_fpga._assign_events_to_trial, t_trial_start, np.array([0., 2., 1.])) def test_wheel_trace_from_sync(self): + """Test ephys_fpga._rotary_encoder_positions_from_fronts function.""" pos_ = - np.array([-1, 0, -1, -2, -1, -2]) * (np.pi / ephys_fpga.WHEEL_TICKS) ta = np.array([1, 2, 3, 4, 5, 6]) tb = np.array([0.5, 3.2, 3.3, 3.4, 5.25, 5.5]) @@ -137,5 +147,184 @@ def test_get_probabilityLeft(self): self.assertTrue(all([x in [0.2, 0.5, 0.8] for x in np.unique(pLeft1)])) +class TestWheelExtraction(unittest.TestCase): + + def setUp(self) -> None: + self.ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + self.pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + self.tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) + self.pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + + def test_x1_decoding(self): + p_ = np.array([1, 2, 1, 0]) + t_ = np.array([2, 6, 11, 15]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x1') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(p == p_)) + + def test_x4_decoding(self): + p_ = np.array([1, 2, 3, 4, 5, 6, 7, 8, 7, 6, 5, 4, 3, 2, 1, 0]) / 4 + t_ = np.array([2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 15, 16, 17, 18]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x4') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(np.isclose(p, p_))) + + def test_x2_decoding(self): + p_ = np.array([1, 2, 3, 4, 3, 2, 1, 0]) / 2 + t_ = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + self.ta, self.pa, self.tb, self.pb, ticks=np.pi * 2, coding='x2') + self.assertTrue(np.all(t == t_)) + self.assertTrue(np.all(p == p_)) + + +class TestExtractedWheelUnits(unittest.TestCase): + """Tests the infer_wheel_units function""" + + wheel_radius_cm = 3.1 + + def setUp(self) -> None: + """ + Create the wheel position data for testing: the positions attribute holds a dictionary of + units, each holding a dictionary of encoding types to test, e.g. + + positions = { + 'rad': { + 'X1': ..., + 'X2': ..., + 'X4': ... + }, + 'cm': { + 'X1': ..., + 'X2': ..., + 'X4': ... 
+ } + } + :return: + """ + def x(unit, enc=int(1), wheel_radius=self.wheel_radius_cm): + radius = 1 if unit == 'rad' else wheel_radius + return 1 / ephys_fpga.WHEEL_TICKS * np.pi * 2 * radius / enc + + # A pseudo-random sequence of integrated fronts + seq = np.array([-1, 0, 1, 2, 1, 2, 3, 4, 3, 2, 1, 0, -1, -2, 1, -2]) + encs = (1, 2, 4) # Encoding types to test + units = ('rad', 'cm') # Units to test + self.positions = {unit: {f'X{e}': x(unit, e) * seq for e in encs} for unit in units} + + def test_extract_wheel_moves(self): + for unit in self.positions.keys(): + for encoding, pos in self.positions[unit].items(): + result = infer_wheel_units(pos) + self.assertEqual(unit, result[0], f'failed to determine units for {encoding}') + expected = int(ephys_fpga.WHEEL_TICKS * int(encoding[1])) + self.assertEqual(expected, result[1], + f'failed to determine number of ticks for {encoding} in {unit}') + self.assertEqual(encoding, result[2], f'failed to determine encoding in {unit}') + + +class TestWheelMovesExtraction(unittest.TestCase): + + def setUp(self) -> None: + """ + Test data is in the form ((inputs), (outputs)) where inputs is a tuple containing a + numpy array of timestamps and one of positions; outputs is a tuple of outputs from + the functions. For details, see help on TestWheel.setUp method in module + brainbox.tests.test_behavior + """ + pickle_file = Path(__file__).parents[3].joinpath( + 'brainbox', 'tests', 'fixtures', 'wheel_test.p') + if not pickle_file.exists(): + self.test_data = None + else: + with open(pickle_file, 'rb') as f: + self.test_data = pickle.load(f) + + # Some trial times for trial_data[1] + self.trials = { + 'goCue_times': np.array([162.5, 105.6, 55]), + 'feedback_times': np.array([164.3, 108.3, 56]) + } + + def test_extract_wheel_moves(self): + test_data = self.test_data[1] + # Wrangle data into expected form + re_ts = test_data[0][0] + re_pos = test_data[0][1] + + logname = 'ibllib.io.extractors.training_wheel' + with self.assertLogs(logname, level='INFO') as cm: + wheel_moves = extract_wheel_moves(re_ts, re_pos) + self.assertEqual([f'INFO:{logname}:Wheel in cm units using X2 encoding'], cm.output) + + n = 56 # expected number of movements + self.assertTupleEqual(wheel_moves['intervals'].shape, (n, 2), + 'failed to return the correct number of intervals') + self.assertEqual(wheel_moves['peakAmplitude'].size, n) + self.assertEqual(wheel_moves['peakVelocity_times'].size, n) + + # Check the first 3 intervals + ints = np.array( + [[24.78462599, 25.22562599], + [29.58762599, 31.15062599], + [31.64262599, 31.81662599]]) + actual = wheel_moves['intervals'][:3, ] + self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') + + # Check amplitudes + actual = wheel_moves['peakAmplitude'][-3:] + expected = [0.50255486, -1.70103154, 1.00740789] + self.assertIsNone(np.testing.assert_allclose(actual, expected), 'unexpected amplitudes') + + # Check peak velocities + actual = wheel_moves['peakVelocity_times'][-3:] + expected = [175.13662599, 176.65762599, 178.57262599] + self.assertIsNone(np.testing.assert_allclose(actual, expected), 'peak times') + + # Test extraction in rad + re_pos = wh.cm_to_rad(re_pos) + with self.assertLogs(logname, level='INFO') as cm: + wheel_moves = ephys_fpga.extract_wheel_moves(re_ts, re_pos) + self.assertEqual([f'INFO:{logname}:Wheel in rad units using X2 encoding'], cm.output) + + # Check the first 3 intervals. 
As position thresholds are adjusted by units and + # encoding, we should expect the intervals to be identical to above + actual = wheel_moves['intervals'][:3, ] + self.assertIsNone(np.testing.assert_allclose(actual, ints), 'unexpected intervals') + + def test_movement_log(self): + """ + Integration test for inferring the units and decoding type for wheel data input for + extract_wheel_moves. Only expected to work for the default wheel diameter. + """ + ta = np.array([2, 4, 6, 8, 12, 14, 16, 18]) + pa = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + tb = np.array([3, 5, 7, 9, 11, 13, 15, 17]) + pb = np.array([1, -1, 1, -1, 1, -1, 1, -1]) + logname = 'ibllib.io.extractors.training_wheel' + + for unit in ['cm', 'rad']: + for i in (1, 2, 4): + encoding = 'X' + str(i) + r = 3.1 if unit == 'cm' else 1 + # print(encoding, unit) + t, p = ephys_fpga._rotary_encoder_positions_from_fronts( + ta, pa, tb, pb, ticks=1024, coding=encoding.lower(), radius=r) + expected = f'INFO:{logname}:Wheel in {unit} units using {encoding} encoding' + with self.assertLogs(logname, level='INFO') as cm: + ephys_fpga.extract_wheel_moves(t, p) + self.assertEqual([expected], cm.output) + + def test_extract_first_movement_times(self): + test_data = self.test_data[1] + wheel_moves = ephys_fpga.extract_wheel_moves(test_data[0][0], test_data[0][1]) + first, is_final, ind = extract_first_movement_times(wheel_moves, self.trials) + np.testing.assert_allclose(first, [162.48462599, 105.62562599, np.nan]) + np.testing.assert_array_equal(is_final, [False, True, False]) + np.testing.assert_array_equal(ind, [46, 18]) + + if __name__ == '__main__': unittest.main(exit=False, verbosity=2)
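# End-to-end sketch of the wheel pipeline exercised by the tests above: decode rotary encoder
# fronts into a position trace, then detect wheel movements. The front times and polarities
# mirror the fixtures used in TestWheelExtraction and test_movement_log.
import numpy as np
from ibllib.io.extractors import ephys_fpga

ta = np.array([2, 4, 6, 8, 12, 14, 16, 18])  # channel A front times
pa = np.array([1, -1, 1, -1, 1, -1, 1, -1])  # channel A polarities
tb = np.array([3, 5, 7, 9, 11, 13, 15, 17])  # channel B front times
pb = np.array([1, -1, 1, -1, 1, -1, 1, -1])  # channel B polarities

# X2 decoding with a 3.1 cm wheel radius gives a position trace in cm
t, pos = ephys_fpga._rotary_encoder_positions_from_fronts(
    ta, pa, tb, pb, ticks=1024, coding='x2', radius=3.1)
wheel_moves = ephys_fpga.extract_wheel_moves(t, pos)  # keys: 'intervals', 'peakAmplitude', 'peakVelocity_times'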