Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
335 changes: 334 additions & 1 deletion brainbox/io/one.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Functions for loading IBL ephys and trial data using the Open Neurophysiology Environment."""
from dataclasses import dataclass
from dataclasses import dataclass, field
import logging
import os
from pathlib import Path
Expand All @@ -24,6 +24,8 @@
from brainbox.core import TimeSeries
from brainbox.processing import sync
from brainbox.metrics.single_units import quick_unit_metrics
from brainbox.behavior.wheel import interpolate_position, velocity_smoothed
from brainbox.behavior.dlc import likelihood_threshold, get_pupil_diameter, get_smooth_pupil_diameter

_logger = logging.getLogger('ibllib')

Expand Down Expand Up @@ -1056,3 +1058,334 @@ def samples2times(self, values, direction='forward'):
'reverse': interp1d(timestamps[:, 1], timestamps[:, 0], fill_value='extrapolate'),
}
return self._sync[direction](values)


@dataclass
class SessionLoader:
"""
Object to load session data for a give session in the recommended way.

Parameters
----------
one: one.api.ONE instance
Can be in remote or local mode (required)
session_path: string or pathlib.Path
The absolute path to the session (one of session_path or eid is required)
eid: string
database UUID of the session (one of session_path or eid is required)

If both are provided, session_path takes precedence over eid.

Examples
--------
1) Load all available session data for one session:
>>> from one.api import ONE
>>> from brainbox.io.one import SessionLoader
>>> one = ONE()
>>> sess_loader = SessionLoader(one=one, session_path='/mnt/s0/Data/Subjects/cortexlab/KS022/2019-12-10/001/')
# Object is initiated, but no data is loaded as you can see in the data_info attribute
>>> sess_loader.data_info
name is_loaded
0 trials False
1 wheel False
2 pose False
3 motion_energy False
4 pupil False

# Loading all available session data, the data_info attribute now shows which data has been loaded
>>> sess_loader.load_session_data()
>>> sess_loader.data_info
name is_loaded
0 trials True
1 wheel True
2 pose True
3 motion_energy True
4 pupil False

# The data is loaded in pandas dataframes that you can access via the respective attributes, e.g.
>>> type(sess_loader.trials)
pandas.core.frame.DataFrame
>>> sess_loader.trials.shape
(626, 18)
# Each data comes with its own timestamps in a column called 'times'
>>> sess_loader.wheel['times']
0 0.134286
1 0.135286
2 0.136286
3 0.137286
4 0.138286
...
# For camera data (pose, motionEnergy) the respective functions load the data into one dataframe per camera.
# The dataframes of all cameras are collected in a dictionary
>>> type(sess_loader.pose)
dict
>>> sess_loader.pose.keys()
dict_keys(['leftCamera', 'rightCamera', 'bodyCamera'])
>>> sess_loader.pose['bodyCamera'].columns
Index(['times', 'tail_start_x', 'tail_start_y', 'tail_start_likelihood'], dtype='object')
# In order to control the loading of specific data by e.g. specifying parameters, use the individual loading
functions:
>>> sess_loader.load_wheel(sampling_rate=100)
"""
one: One = None
session_path: Path = ''
eid: str = ''
data_info: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
trials: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
wheel: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)
pose: dict = field(default_factory=dict, repr=False)
motion_energy: dict = field(default_factory=dict, repr=False)
pupil: pd.DataFrame = field(default_factory=pd.DataFrame, repr=False)

def __post_init__(self):
"""
Function that runs automatically after initiation of the dataclass attributes.
Checks for required inputs, sets session_path and eid, creates data_info table.
"""
if self.one is None:
raise ValueError("An input to one is required. If not connection to a database is desired, it can be "
"a fully local instance of One.")
# If session path is given, takes precedence over eid
if self.session_path is not None and self.session_path != '':
self.eid = self.one.to_eid(self.session_path)
self.session_path = Path(self.session_path)
# Providing no session path, try to infer from eid
else:
if self.eid is not None and self.eid != '':
self.session_path = self.one.eid2path(self.eid)
else:
raise ValueError("If no session path is given, eid is required.")

data_names = [
'trials',
'wheel',
'pose',
'motion_energy',
'pupil'
]
self.data_info = pd.DataFrame(columns=['name', 'is_loaded'], data=zip(data_names, [False] * len(data_names)))

def load_session_data(self, trials=True, wheel=True, pose=True, motion_energy=True, pupil=True, reload=False):
"""
Function to load available session data into the SessionLoader object. Input parameters allow to control which
data is loaded. Data is loaded into an attribute of the SessionLoader object with the same name as the input
parameter (e.g. SessionLoader.trials, SessionLoader.pose). Information about which data is loaded is stored
in SessionLoader.data_info

Parameters
----------
trials: boolean
Whether to load all trials data into SessionLoader.trials, default is True
wheel: boolean
Whether to load wheel data (position, velocity, acceleration) into SessionLoader.wheel, default is True
pose: boolean
Whether to load pose tracking results (DLC) for each available camera into SessionLoader.pose,
default is True
motion_energy: boolean
Whether to load motion energy data (whisker pad for left/right camera, body for body camera)
into SessionLoader.motion_energy, default is True
pupil: boolean
Whether to load pupil diameter (raw and smooth) for the left/right camera into SessionLoader.pupil,
default is True
reload: boolean
Whether to reload data that has already been loaded into this SessionLoader object, default is False
"""
load_df = self.data_info.copy()
load_df['to_load'] = [
trials,
wheel,
pose,
motion_energy,
pupil
]
load_df['load_func'] = [
self.load_trials,
self.load_wheel,
self.load_pose,
self.load_motion_energy,
self.load_pupil
]

for idx, row in load_df.iterrows():
if row['to_load'] is False:
_logger.debug(f"Not loading {row['name']} data, set to False.")
elif row['is_loaded'] is True and reload is False:
_logger.debug(f"Not loading {row['name']} data, is already loaded and reload=False.")
else:
try:
_logger.info(f"Loading {row['name']} data")
row['load_func']()
self.data_info.loc[idx, 'is_loaded'] = True
except BaseException as e:
_logger.warning(f"Could not load {row['name']} data.")
_logger.debug(e)

def load_trials(self):
"""
Function to load trials data into SessionLoader.trials
"""
self.trials = self.one.load_object(self.eid, 'trials', collection='alf').to_df()
self.data_info.loc[self.data_info['name'] == 'trials', 'is_loaded'] = True

def load_wheel(self, sampling_rate=1000, smooth_size=0.03):
"""
Function to load wheel data (position, velocity, acceleration) into SessionLoader.wheel. The wheel position
is first interpolated to a uniform sampling rate. Then velocity and acceleration are computed, during which
Gaussian smoothing is applied.

Parameters
----------
sampling_rate: float
Rate at which to sample the wheel position, default is 1000 Hz
smooth_size: float
Size of Gaussian smoothing window in seconds, default is 0.03
"""
wheel_raw = self.one.load_object(self.eid, 'wheel')
# TODO: Fix this instead of raising error?
if wheel_raw['position'].shape[0] != wheel_raw['timestamps'].shape[0]:
raise ValueError("Length mismatch between 'wheel.position' and 'wheel.timestamps")
# resample the wheel position and compute velocity, acceleration
self.wheel = pd.DataFrame(columns=['times', 'position', 'velocity', 'acceleration'])
self.wheel['position'], self.wheel['times'] = interpolate_position(
wheel_raw['timestamps'], wheel_raw['position'], freq=sampling_rate)
self.wheel['velocity'], self.wheel['acceleration'] = velocity_smoothed(
self.wheel['position'], freq=sampling_rate, smooth_size=smooth_size)
self.wheel = self.wheel.apply(np.float32)
self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True

def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
"""
Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a
dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas
Dataframes with the timestamps and pose data, one row for each body part tracked for that camera.

Parameters
----------
likelihood_thr: float
The position of each tracked body part come with a likelihood of that estimate for each time point.
Estimates for time points with likelihood < likelihood_thr are set to NaN. To skip thresholding set
likelihood_thr=1. Default is 0.9
views: list
List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
"""
# empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
self.pose = {}
for view in views:
try:
pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'])
# Double check if video timestamps are correct length or can be fixed
times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc'])
self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr)
self.pose[f'{view}Camera'].insert(0, 'times', times_fixed)
self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True
except BaseException as e:
_logger.warning(f'Could not load pose data for {view}Camera. Skipping camera.')
_logger.debug(e)

def load_motion_energy(self, views=['left', 'right', 'body']):
"""
Function to load the motion energy data into SessionLoader.motion_energy. SessionLoader.motion_energy is a
dictionary where keys are the names of the cameras for which motion energy data is loaded, and values are
pandas Dataframes with the timestamps and motion energy data.
The motion energy for the left and right camera is calculated for a square roughly covering the whisker pad
(whiskerMotionEnergy). The motion energy for the body camera is calculated for a square covering much of the
body (bodyMotionEnergy).

Parameters
----------
views: list
List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
"""
names = {'left': 'whiskerMotionEnergy',
'right': 'whiskerMotionEnergy',
'body': 'bodyMotionEnergy'}
# empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
self.motion_energy = {}
for view in views:
try:
me_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['ROIMotionEnergy', 'times'])
# Double check if video timestamps are correct length or can be fixed
times_fixed, motion_energy = self._check_video_timestamps(
view, me_raw['times'], me_raw['ROIMotionEnergy'])
self.motion_energy[f'{view}Camera'] = pd.DataFrame(columns=[names[view]], data=motion_energy)
self.motion_energy[f'{view}Camera'].insert(0, 'times', times_fixed)
self.data_info.loc[self.data_info['name'] == 'motion_energy', 'is_loaded'] = True
except BaseException as e:
_logger.warning(f'Could not load motion energy data for {view}Camera. Skipping camera.')
_logger.debug(e)

def load_licks(self):
"""
Not yet implemented
"""
pass

def load_pupil(self, snr_thresh=5.):
"""
Function to load raw and smoothed pupil diameter data from the left camera into SessionLoader.pupil.

Parameters
----------
snr_thresh: float
An SNR is calculated from the raw and smoothed pupil diameter. If this snr < snr_thresh the data
will be considered unusable and will be discarded.
"""
# Try to load from features
feat_raw = self.one.load_object(self.eid, 'leftCamera', attribute=['times', 'features'])
if 'features' in feat_raw.keys():
times_fixed, feats = self._check_video_timestamps(feat_raw['times'], feat_raw['features'])
self.pupil = feats.copy()
self.pupil.insert(0, 'times', times_fixed)

# If unavailable compute on the fly
else:
_logger.info('Pupil diameter not available, trying to compute on the fly.')
if (self.data_info[self.data_info['name'] == 'pose']['is_loaded'].values[0]
and 'leftCamera' in self.pose.keys()):
# If pose data is already loaded, we don't know if it was threshold at 0.9, so we need a little stunt
copy_pose = self.pose['leftCamera'].copy() # Save the previously loaded pose data
self.load_pose(views=['left'], likelihood_thr=0.9) # Load new with threshold 0.9
dlc_thr = self.pose['leftCamera'].copy() # Save the threshold pose data in new variable
self.pose['leftCamera'] = copy_pose.copy() # Get previously loaded pose data back in place
else:
self.load_pose(views=['left'], likelihood_thr=0.9)
dlc_thr = self.pose['leftCamera'].copy()

self.pupil['pupilDiameter_raw'] = get_pupil_diameter(dlc_thr)
try:
self.pupil['pupilDiameter_smooth'] = get_smooth_pupil_diameter(self.pupil['pupilDiameter_raw'], 'left')
except BaseException as e:
_logger.error("Computing smooth pupil diameter failed, saving all NaNs.")
_logger.debug(e)
self.pupil['pupilDiameter_smooth'] = np.nan

if not np.all(np.isnan(self.pupil['pupilDiameter_smooth'])):
good_idxs = np.where(
~np.isnan(self.pupil['pupilDiameter_smooth']) & ~np.isnan(self.pupil['pupilDiameter_raw']))[0]
snr = (np.var(self.pupil['pupilDiameter_smooth'][good_idxs]) /
(np.var(self.pupil['pupilDiameter_smooth'][good_idxs] - self.pupil['pupilDiameter_raw'][good_idxs])))
if snr < snr_thresh:
self.pupil = pd.DataFrame()
_logger.error(f'Pupil diameter SNR ({snr:.2f}) below threshold SNR ({snr_thresh}), removing data.')

def _check_video_timestamps(self, view, video_timestamps, video_data):
"""
Helper function to check for the length of the video frames vs video timestamps and fix in case
timestamps are longer than video frames.
"""
# If camera times are shorter than video data, or empty, no current fix
if video_timestamps.shape[0] < video_data.shape[0]:
if video_timestamps.shape[0] == 0:
msg = f'Camera times empty for {view}Camera.'
else:
msg = f'Camera times are shorter than video data for {view}Camera.'
_logger.warning(msg)
raise ValueError(msg)
# For pre-GPIO sessions, it is possible that the camera times are longer than the actual video.
# This is because the first few frames are sometimes not recorded. We can remove the first few
# timestamps in this case
elif video_timestamps.shape[0] > video_data.shape[0]:
video_timestamps_fixed = video_timestamps[-video_data.shape[0]:]
return video_timestamps_fixed, video_data
else:
return video_timestamps, video_data