diff --git a/brainbox/io/one.py b/brainbox/io/one.py index 5f1d6a15d..8c0925685 100644 --- a/brainbox/io/one.py +++ b/brainbox/io/one.py @@ -1,5 +1,5 @@ """Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment.""" -from dataclasses import dataclass +from dataclasses import dataclass, field import logging import os from pathlib import Path @@ -24,6 +24,8 @@ from brainbox.core import TimeSeries from brainbox.processing import sync from brainbox.metrics.single_units import quick_unit_metrics +from brainbox.behavior.wheel import interpolate_position, velocity_smoothed +from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter _logger = logging.getLogger('ibllib') @@ -1056,3 +1058,334 @@ def samples2times(self, values, direction='forward'): 'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'), } return self._sync[direction](values) + + +@dataclass +class SessionLoader: + """ + Object to load session data for a give session in the recommended way. + + Parameters + ---------- + one: one.api.ONE instance + Can be in remote or local mode (required) + session_path: string or pathlib.Path + The absolute path to the session (one of session_path or eid is required) + eid: string + database UUID of the session (one of session_path or eid is required) + + If both are provided, session_path takes precedence over eid. + + Examples + -------- + 1) Load all available session data for one session: + >>> from one.api import ONE + >>> from brainbox.io.one import SessionLoader + >>> one = ONE() + >>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/') + # Object is initiated, but no data is loaded as you can see in the data_info attribute + >>> sess_loader.data_info + name is_loaded + 0 trials False + 1 wheel False + 2 pose False + 3 motion_energy False + 4 pupil False + + # Loading all available session data, the data_info attribute now shows which data has been loaded + >>> sess_loader.load_session_data() + >>> sess_loader.data_info + name is_loaded + 0 trials True + 1 wheel True + 2 pose True + 3 motion_energy True + 4 pupil False + + # The data is loaded in pandas dataframes that you can access via the respective attributes, e.g. + >>> type(sess_loader.trials) + pandas.core.frame.DataFrame + >>> sess_loader.trials.shape + (626, 18) + # Each data comes with its own timestamps in a column called 'times' + >>> sess_loader.wheel['times'] + 0 0.134286 + 1 0.135286 + 2 0.136286 + 3 0.137286 + 4 0.138286 + ... + # For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera. + # The dataframes of all cameras are collected in a dictionary + >>> type(sess_loader.pose) + dict + >>> sess_loader.pose.keys() + dict_keys(['leftCamera', 'rightCamera', 'bodyCamera']) + >>> sess_loader.pose['bodyCamera'].columns + Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object') + # In order to control the loading of specific data by e.g. specifying parameters, use the individual loading + functions: + >>> sess_loader.load_wheel(sampling_rate=100) + """ + one: One = None + session_path: Path = '' + eid: str = '' + data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) + trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) + wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) + pose: dict = field(default_factory=dict, repr=False) + motion_energy: dict = field(default_factory=dict, repr=False) + pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False) + + def __post_init__(self): + """ + Function that runs automatically after initiation of the dataclass attributes. + Checks for required inputs, sets session_path and eid, creates data_info table. + """ + if self.one is None: + raise ValueError("An input to one is required. If not connection to a database is desired, it can be " + "a fully local instance of One.") + # If session path is given, takes precedence over eid + if self.session_path is not None and self.session_path != '': + self.eid = self.one.to_eid(self.session_path) + self.session_path = Path(self.session_path) + # Providing no session path, try to infer from eid + else: + if self.eid is not None and self.eid != '': + self.session_path = self.one.eid2path(self.eid) + else: + raise ValueError("If no session path is given, eid is required.") + + data_names = [ + 'trials', + 'wheel', + 'pose', + 'motion_energy', + 'pupil' + ] + self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names))) + + def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False): + """ + Function to load available session data into the SessionLoader object. Input parameters allow to control which + data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input + parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored + in SessionLoader.data_info + + Parameters + ---------- + trials: boolean + Whether to load all trials data into SessionLoader.trials, default is True + wheel: boolean + Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True + pose: boolean + Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose, + default is True + motion_energy: boolean + Whether to load motion energy data (whisker pad for left/right camera, body for body camera) + into SessionLoader.motion_energy, default is True + pupil: boolean + Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil, + default is True + reload: boolean + Whether to reload data that has already been loaded into this SessionLoader object, default is False + """ + load_df = self.data_info.copy() + load_df['to_load'] = [ + trials, + wheel, + pose, + motion_energy, + pupil + ] + load_df['load_func'] = [ + self.load_trials, + self.load_wheel, + self.load_pose, + self.load_motion_energy, + self.load_pupil + ] + + for idx, row in load_df.iterrows(): + if row['to_load'] is False: + _logger.debug(f"Not loading {row['name']} data, set to False.") + elif row['is_loaded'] is True and reload is False: + _logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.") + else: + try: + _logger.info(f"Loading {row['name']} data") + row['load_func']() + self.data_info.loc[idx, 'is_loaded'] = True + except BaseException as e: + _logger.warning(f"Could not load {row['name']} data.") + _logger.debug(e) + + def load_trials(self): + """ + Function to load trials data into SessionLoader.trials + """ + self.trials = self.one.load_object(self.eid, 'trials', collection='alf').to_df() + self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True + + def load_wheel(self, sampling_rate=1000, smooth_size=0.03): + """ + Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position + is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which + Gaussian smoothing is applied. + + Parameters + ---------- + sampling_rate: float + Rate at which to sample the wheel position, default is 1000 Hz + smooth_size: float + Size of Gaussian smoothing window in seconds, default is 0.03 + """ + wheel_raw = self.one.load_object(self.eid, 'wheel') + # TODO: Fix this instead of raising error? + if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]: + raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps") + # resample the wheel position and compute velocity, acceleration + self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration']) + self.wheel['position'], self.wheel['times'] = interpolate_position( + wheel_raw['timestamps'], wheel_raw['position'], freq=sampling_rate) + self.wheel['velocity'], self.wheel['acceleration'] = velocity_smoothed( + self.wheel['position'], freq=sampling_rate, smooth_size=smooth_size) + self.wheel = self.wheel.apply(np.float32) + self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True + + def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']): + """ + Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a + dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas + Dataframes with the timestamps and pose data, one row for each body part tracked for that camera. + + Parameters + ---------- + likelihood_thr: float + The position of each tracked body part come with a likelihood of that estimate for each time point. + Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set + likelihood_thr=1. Default is 0.9 + views: list + List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} + """ + # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger + self.pose = {} + for view in views: + try: + pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times']) + # Double check if video timestamps are correct length or can be fixed + times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc']) + self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr) + self.pose[f'{view}Camera'].insert(0, 'times', times_fixed) + self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True + except BaseException as e: + _logger.warning(f'Could not load pose data for {view}Camera. Skipping camera.') + _logger.debug(e) + + def load_motion_energy(self, views=['left', 'right', 'body']): + """ + Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a + dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are + pandas Dataframes with the timestamps and motion energy data. + The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad + (whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the + body (bodyMotionEnergy). + + Parameters + ---------- + views: list + List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'} + """ + names = {'left': 'whiskerMotionEnergy', + 'right': 'whiskerMotionEnergy', + 'body': 'bodyMotionEnergy'} + # empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger + self.motion_energy = {} + for view in views: + try: + me_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times']) + # Double check if video timestamps are correct length or can be fixed + times_fixed, motion_energy = self._check_video_timestamps( + view, me_raw['times'], me_raw['ROIMotionEnergy']) + self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy) + self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed) + self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True + except BaseException as e: + _logger.warning(f'Could not load motion energy data for {view}Camera. Skipping camera.') + _logger.debug(e) + + def load_licks(self): + """ + Not yet implemented + """ + pass + + def load_pupil(self, snr_thresh=5.): + """ + Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil. + + Parameters + ---------- + snr_thresh: float + An SNR is calculated from the raw and smoothed pupil diameter. If this snr < snr_thresh the data + will be considered unusable and will be discarded. + """ + # Try to load from features + feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features']) + if 'features' in feat_raw.keys(): + times_fixed, feats = self._check_video_timestamps(feat_raw['times'], feat_raw['features']) + self.pupil = feats.copy() + self.pupil.insert(0, 'times', times_fixed) + + # If unavailable compute on the fly + else: + _logger.info('Pupil diameter not available, trying to compute on the fly.') + if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0] + and 'leftCamera' in self.pose.keys()): + # If pose data is already loaded, we don't know if it was threshold at 0.9, so we need a little stunt + copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data + self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9 + dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable + self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place + else: + self.load_pose(views=['left'], likelihood_thr=0.9) + dlc_thr = self.pose['leftCamera'].copy() + + self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr) + try: + self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left') + except BaseException as e: + _logger.error("Computing smooth pupil diameter failed, saving all NaNs.") + _logger.debug(e) + self.pupil['pupilDiameter_smooth'] = np.nan + + if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])): + good_idxs = np.where( + ~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0] + snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) / + (np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs]))) + if snr < snr_thresh: + self.pupil = pd.DataFrame() + _logger.error(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.') + + def _check_video_timestamps(self, view, video_timestamps, video_data): + """ + Helper function to check for the length of the video frames vs video timestamps and fix in case + timestamps are longer than video frames. + """ + # If camera times are shorter than video data, or empty, no current fix + if video_timestamps.shape[0] < video_data.shape[0]: + if video_timestamps.shape[0] == 0: + msg = f'Camera times empty for {view}Camera.' + else: + msg = f'Camera times are shorter than video data for {view}Camera.' + _logger.warning(msg) + raise ValueError(msg) + # For pre-GPIO sessions, it is possible that the camera times are longer than the actual video. + # This is because the first few frames are sometimes not recorded. We can remove the first few + # timestamps in this case + elif video_timestamps.shape[0] > video_data.shape[0]: + video_timestamps_fixed = video_timestamps[-video_data.shape[0]:] + return video_timestamps_fixed, video_data + else: + return video_timestamps, video_data