70 changes: 20 additions & 50 deletions brainbox/processing.py
@@ -1,17 +1,16 @@
'''
Processes data from one form into another, e.g. taking spike times and binning them into
non-overlapping bins and convolving spike times with a gaussian kernel.
'''
"""Process data from one form into another.

For example, taking spike times and binning them into non-overlapping bins and convolving spike
times with a gaussian kernel.
"""

import numpy as np
import pandas as pd
from scipy import interpolate, sparse
from brainbox import core
from iblutil.numerical import bincount2D as _bincount2D
from iblutil.numerical import bincount2D
from iblutil.util import Bunch
import logging
import warnings
import traceback

_logger = logging.getLogger(__name__)

@@ -118,35 +117,6 @@ def sync(dt, times=None, values=None, timeseries=None, offsets=None, interp='zer
return syncd


def bincount2D(x, y, xbin=0, ybin=0, xlim=None, ylim=None, weights=None):
"""
Computes a 2D histogram by aggregating values in a 2D array.

:param x: values to bin along the 2nd dimension (c-contiguous)
:param y: values to bin along the 1st dimension
:param xbin:
scalar: bin size along 2nd dimension
0: aggregate according to unique values
array: aggregate according to exact values (count reduce operation)
:param ybin:
scalar: bin size along 1st dimension
0: aggregate according to unique values
array: aggregate according to exact values (count reduce operation)
:param xlim: (optional) 2 values (array or list) that restrict range along 2nd dimension
:param ylim: (optional) 2 values (array or list) that restrict range along 1st dimension
:param weights: (optional) defaults to None, weights to apply to each value for aggregation
:return: 3 numpy arrays MAP [ny,nx] image, xscale [nx], yscale [ny]
"""
for line in traceback.format_stack():
print(line.strip())
warning_text = """Future warning: bincount2D() is now a part of iblutil.
brainbox.processing.bincount2D will be removed in future versions.
Please replace imports with iblutil.numerical.bincount2D."""
_logger.warning(warning_text)
warnings.warn(warning_text, FutureWarning)
return _bincount2D(x, y, xbin, ybin, xlim, ylim, weights)


def compute_cluster_average(spike_clusters, spike_var):
"""
Quickish way to compute the average of some quantity across spikes in each cluster given
@@ -197,7 +167,7 @@ def bin_spikes(spikes, binsize, interval_indices=False):


def get_units_bunch(spks_b, *args):
'''
"""
Returns a bunch, where the bunch keys are keys from `spks` with labels of spike information
(e.g. unit IDs, times, features, etc.), and the values for each key are arrays with values for
each unit: these arrays are ordered and can be indexed by unit id.
@@ -223,18 +193,18 @@ def get_units_bunch(spks_b, *args):
--------
1) Create a units bunch given a spikes bunch, and get the amps for unit #4 from the units
bunch.
>>> import brainbox as bb
>>> import alf.io as aio
>>> from brainbox import processing
>>> import one.alf.io as alfio
>>> import ibllib.ephys.spikes as e_spks
(*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
>>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
>>> units_b = bb.processing.get_units_bunch(spks_b)
>>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
>>> units_b = processing.get_units_bunch(spks_b)
# Get amplitudes for unit 4.
>>> amps = units_b['amps']['4']

TODO add computation time estimate?
'''
"""

# Initialize `units`
units_b = Bunch()
@@ -261,7 +231,7 @@


def filter_units(units_b, t, **kwargs):
'''
"""
Filters units according to some parameters. **kwargs are the keyword parameters used to filter
the units.

@@ -299,24 +269,24 @@
Examples
--------
1) Filter units according to the default parameters.
>>> import brainbox as bb
>>> import alf.io as aio
>>> from brainbox import processing
>>> import one.alf.io as alfio
>>> import ibllib.ephys.spikes as e_spks
(*Note, if there is no 'alf' directory, make 'alf' directory from 'ks2' output directory):
>>> e_spks.ks2_to_alf(path_to_ks_out, path_to_alf_out)
# Get a spikes bunch, units bunch, and filter the units.
>>> spks_b = aio.load_object(path_to_alf_out, 'spikes')
>>> units_b = bb.processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
>>> spks_b = alfio.load_object(path_to_alf_out, 'spikes')
>>> units_b = processing.get_units_bunch(spks_b, ['times', 'amps', 'clusters'])
>>> T = spks_b['times'][-1] - spks_b['times'][0]
>>> filtered_units = bb.processing.filter_units(units_b, T)
>>> filtered_units = processing.filter_units(units_b, T)

2) Filter units with no minimum amplitude, a minimum firing rate of 1 Hz, and a max false
positive rate of 0.2, given a refractory period of 2 ms.
>>> filtered_units = bb.processing.filter_units(units_b, T, min_amp=0, min_fr=1)
>>> filtered_units = processing.filter_units(units_b, T, min_amp=0, min_fr=1)

TODO: `units_b` input arg could eventually be replaced by `clstrs_b` if the required metrics
are in `clstrs_b['metrics']`
'''
"""

# Set params
params = {'min_amp': 50e-6, 'min_fr': 0.5, 'max_fpr': 0.2, 'rp': 0.002} # defaults
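
With the shim removed, downstream code should import `bincount2D` from `iblutil` directly. A minimal migration sketch, assuming the `iblutil` version keeps the signature documented in the removed wrapper:

```python
import numpy as np
from iblutil.numerical import bincount2D  # replaces brainbox.processing.bincount2D

# Bin spike times (x) against cluster ids (y): 100 ms time bins along
# the 2nd dimension, one row per unique cluster (ybin=0).
x = np.array([0.01, 0.05, 0.12, 0.31])  # spike times in seconds
y = np.array([3, 3, 5, 5])              # cluster id of each spike
r, xscale, yscale = bincount2D(x, y, xbin=0.1, ybin=0)
# r.shape == (2, 4): 2 clusters x 4 time bins
```
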
15 changes: 0 additions & 15 deletions brainbox/tests/test_processing.py
@@ -1,7 +1,6 @@
from brainbox import processing, core
import unittest
import numpy as np
import datetime


class TestProcessing(unittest.TestCase):
@@ -63,15 +62,6 @@ def test_sync(self):
self.assertTrue(times2.min() >= resamp2.times.min())
self.assertTrue(times2.max() <= resamp2.times.max())

def test_bincount2D_deprecation(self):
# Timer to remove bincount2D (now in iblutil)
# Once this test fails:
# - Remove the bincount2D method in processing.py
# - Remove the import from iblutil at the top of that file
# - Delete this test
if datetime.datetime.now() > datetime.datetime(2024, 6, 30):
raise NotImplementedError

def test_compute_cluster_averag(self):
# Create fake data for 3 clusters
clust1 = np.ones(40)
@@ -103,11 +93,6 @@ def test_compute_cluster_averag(self):
self.assertEqual(avg_val[2], 0.75)
self.assertTrue(np.all(count == (40, 40, 50)))

def test_deprecations(self):
"""Ensure removal of bincount2D function."""
from datetime import datetime
self.assertTrue(datetime.today() < datetime(2024, 8, 1), 'remove brainbox.processing.bincount2D')


if __name__ == '__main__':
np.random.seed(0)
47 changes: 29 additions & 18 deletions ibllib/io/extractors/mesoscope.py
@@ -51,6 +51,7 @@ def patch_imaging_meta(meta: dict) -> dict:
for fov in meta.get('FOV', []):
if 'roiUuid' in fov:
fov['roiUUID'] = fov.pop('roiUuid')
assert 'nFrames' in meta, '"nFrames" key missing from meta data; rawImagingData.meta.json likely an old version'
return meta


@@ -753,6 +754,15 @@ def get_timeshifts(raw_imaging_meta):
Calculate the time shifts for each field of view (FOV) and the relative offsets for each
scan line.

For a 2 scan field, 2 depth recording (so 4 FOVs):

Frame 1, lines 1-512 correspond to FOV_00
Frame 1, lines 551-1062 correspond to FOV_01
Frame 2, lines 1-512 correspond to FOV_02
Frame 2, lines 551-1062 correspond to FOV_03
Frame 3, lines 1-512 correspond to FOV_00
...

Parameters
----------
raw_imaging_meta : dict
@@ -772,26 +782,27 @@
FOVs = raw_imaging_meta['FOV']

# Double-check meta extracted properly
raw_meta = raw_imaging_meta['rawScanImageMeta']
artist = raw_meta['Artist']
assert sum(x['enable'] for x in artist['RoiGroups']['imagingRoiGroup']['rois']) == len(FOVs)

# assert meta.FOV.Zs is ascending but use slice_id field. This may not be necessary but is expected.
slice_ids = np.array([fov['slice_id'] for fov in FOVs])
assert np.all(np.diff([x['Zs'] for x in FOVs]) >= 0), 'FOV depths not in ascending order'
assert np.all(np.diff(slice_ids) >= 0), 'slice IDs not ordered'
# Number of scan lines per FOV, i.e. number of Y pixels / image height
n_lines = np.array([x['nXnYnZ'][1] for x in FOVs])
n_valid_lines = np.sum(n_lines) # Number of lines imaged excluding flybacks
# Number of lines during flyback
n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1))
# The start and end indices of each FOV in the raw images
fov_start_idx = np.insert(np.cumsum(n_lines[:-1] + n_lines_per_gap), 0, 0)
fov_end_idx = fov_start_idx + n_lines
line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod']

line_indices = []
fov_time_shifts = fov_start_idx * line_period
line_time_shifts = []

for ln, s, e in zip(n_lines, fov_start_idx, fov_end_idx):
line_indices.append(np.arange(s, e))
line_time_shifts.append(np.arange(0, ln) * line_period)
# We get indices from MATLAB extracted metadata so below two lines are no longer needed
# n_valid_lines = np.sum(n_lines) # Number of lines imaged excluding flybacks
# n_lines_per_gap = int((raw_meta['Height'] - n_valid_lines) / (len(FOVs) - 1)) # N lines during flyback
line_period = raw_imaging_meta['scanImageParams']['hRoiManager']['linePeriod']
frame_time_shifts = slice_ids / raw_imaging_meta['scanImageParams']['hRoiManager']['scanFrameRate']

# Line indices are now extracted by the MATLAB function mesoscopeMetadataExtraction.m
# They are indexed from 1 so we subtract 1 to convert to zero-indexed
line_indices = [np.array(fov['lineIdx']) - 1 for fov in FOVs] # Convert to zero-indexed from MATLAB 1-indexed
assert all(lns.size == n for lns, n in zip(line_indices, n_lines)), 'unexpected number of scan lines'
# The start indices of each FOV in the raw images
fov_start_idx = np.array([lns[0] for lns in line_indices])
roi_time_shifts = fov_start_idx * line_period # The time offset for each FOV
fov_time_shifts = roi_time_shifts + frame_time_shifts
line_time_shifts = [(lns - ln0) * line_period for lns, ln0 in zip(line_indices, fov_start_idx)]

return line_indices, fov_time_shifts, line_time_shifts
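
The new implementation derives all timing from the MATLAB-extracted `lineIdx` fields. A self-contained sketch of the arithmetic, using hypothetical two-FOV metadata shaped like the docstring example (field names `slice_id` and `lineIdx` follow the code above; the numbers are illustrative):

```python
import numpy as np

# Hypothetical metadata for two same-slice FOVs, shaped like the
# docstring example: lines 1-512 and 551-1062 of each frame.
line_period = 4.2e-5     # seconds per scan line (hRoiManager.linePeriod)
scan_frame_rate = 15.0   # frames per second (hRoiManager.scanFrameRate)
fovs = [
    {'slice_id': 0, 'lineIdx': np.arange(1, 513)},     # MATLAB 1-indexed
    {'slice_id': 0, 'lineIdx': np.arange(551, 1063)},
]

line_indices = [np.array(f['lineIdx']) - 1 for f in fovs]  # zero-indexed
fov_start_idx = np.array([lns[0] for lns in line_indices])
frame_time_shifts = np.array([f['slice_id'] for f in fovs]) / scan_frame_rate
fov_time_shifts = fov_start_idx * line_period + frame_time_shifts
line_time_shifts = [(lns - ln0) * line_period
                    for lns, ln0 in zip(line_indices, fov_start_idx)]
# fov_time_shifts -> [0.0, 550 * line_period]; a FOV on slice 1 would
# additionally be offset by 1 / scan_frame_rate.
```
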
37 changes: 13 additions & 24 deletions ibllib/pipes/local_server.py
@@ -152,7 +152,7 @@ def job_creator(root_path, one=None, dry=False, rerun=False):
return pipes, all_datasets


def task_queue(mode='all', lab=None, alyx=None):
def task_queue(mode='all', lab=None, alyx=None, env=(None,)):
"""
Query waiting jobs from the specified Lab

@@ -164,12 +164,18 @@ def task_queue(mode='all', lab=None, alyx=None):
Lab name as per Alyx, otherwise try to infer from local Globus install.
alyx : one.webclient.AlyxClient
An Alyx instance.
env : list
One or more environments to filter by. See :prop:`ibllib.pipes.tasks.Task.env`.

Returns
-------
list of dict
A list of Alyx tasks associated with `lab` that have a 'Waiting' status.
"""
def predicate(task):
classe = tasks.str2class(task['executable'])
return (mode == 'all' or classe.job_size == mode) and classe.env in env

alyx = alyx or AlyxClient(cache_rest=None)
if lab is None:
_logger.debug('Trying to infer lab from globus installation')
@@ -179,28 +185,12 @@
return # if the lab is none, this will return empty tasks each time
data_repo = get_local_data_repository(alyx)
# Filter for tasks
tasks_all = alyx.rest('tasks', 'list', status='Waiting',
django=f'session__lab__name__in,{lab},data_repository__name,{data_repo}', no_cache=True)
if mode == 'all':
waiting_tasks = tasks_all
else:
small_jobs = []
large_jobs = []
for t in tasks_all:
strmodule, strclass = t['executable'].rsplit('.', 1)
classe = getattr(importlib.import_module(strmodule), strclass)
job_size = classe.job_size
if job_size == 'small':
small_jobs.append(t)
else:
large_jobs.append(t)
if mode == 'small':
waiting_tasks = small_jobs
elif mode == 'large':
waiting_tasks = large_jobs

waiting_tasks = alyx.rest('tasks', 'list', status='Waiting',
django=f'session__lab__name__in,{lab},data_repository__name,{data_repo}', no_cache=True)
# Filter tasks by size
filtered_tasks = filter(predicate, waiting_tasks)
# Order tasks by priority
sorted_tasks = sorted(waiting_tasks, key=lambda d: d['priority'], reverse=True)
sorted_tasks = sorted(filtered_tasks, key=lambda d: d['priority'], reverse=True)

return sorted_tasks

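A usage sketch for the new `env` filter (the lab name is a placeholder and a reachable Alyx instance is assumed; `'suite2p'` mirrors the environment name given to `MesoscopePreprocess` later in this diff):

```python
from ibllib.pipes.local_server import task_queue

# Default env=(None,): only tasks whose class leaves env unset,
# i.e. tasks that run in the base environment.
waiting = task_queue(mode='small', lab='examplelab')

# Also include tasks that must run in the suite2p environment.
waiting = task_queue(mode='all', lab='examplelab', env=(None, 'suite2p'))
```
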
@@ -252,8 +242,7 @@ def tasks_runner(subjects_path, tasks_dict, one=None, dry=False, count=5, time_o
if dry:
print(session_path, tdict['name'])
else:
task, dsets = tasks.run_alyx_task(tdict=tdict, session_path=session_path,
one=one, **kwargs)
task, dsets = tasks.run_alyx_task(tdict=tdict, session_path=session_path, one=one, **kwargs)
if dsets:
all_datasets.extend(dsets)
c += 1
32 changes: 21 additions & 11 deletions ibllib/pipes/mesoscope_tasks.py
@@ -193,11 +193,12 @@ def _run(self, remove_uncompressed=False, verify_output=True, clobber=False, **k


class MesoscopePreprocess(base_tasks.MesoscopeTask):
"""Run suite2p preprocessing on tif files"""
"""Run suite2p preprocessing on tif files."""

priority = 80
cpu = -1
job_size = 'large'
env = 'suite2p'

@property
def signature(self):
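
For context, a hypothetical task class showing how the `env` attribute pairs with the `task_queue` filter above (class name and body are illustrative only):

```python
from ibllib.pipes.tasks import Task

class ExampleSuite2pTask(Task):  # hypothetical, for illustration only
    priority = 60
    job_size = 'large'
    env = 'suite2p'  # matched against task_queue's env argument

    def _run(self):
        """Task logic would go here."""
        pass
```
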
@@ -426,10 +427,7 @@ def _consolidate_exptQC(exptQC):

# Concatenate frames
frameQC = np.concatenate([e['frameQC_frames'] for e in exptQC], axis=0)

# Transform to bad_frames as expected by suite2p
bad_frames = np.where(frameQC != 0)[0]

return frameQC, frameQC_names, bad_frames

def get_default_tau(self):
Expand Down Expand Up @@ -555,7 +553,7 @@ def _run(self, run_suite2p=True, rename_files=True, use_badframes=True, **kwargs
""" Metadata and parameters """
# Load metadata and make sure all metadata is consistent across FOVs
meta_files = sorted(self.session_path.glob(f'{self.device_collection}/*rawImagingData.meta.*'))
collections = set(f.parts[-2] for f in meta_files)
collections = sorted(set(f.parts[-2] for f in meta_files))
# Check there is exactly 1 meta file per collection
assert len(meta_files) == len(list(self.session_path.glob(self.device_collection))) == len(collections)
rawImagingData = [mesoscope.patch_imaging_meta(alfio.load_file_content(filepath)) for filepath in meta_files]
@@ -576,19 +574,31 @@
self.kwargs = {**self.kwargs, **db}

""" Bad frames """
# exptQC.mat contains experimenter QC values that may not affect ROI detection (e.g. noises, pauses)
qc_paths = (self.session_path.joinpath(f[1], 'exptQC.mat')
for f in self.input_files if f[0] == 'exptQC.mat')
qc_paths = sorted(map(str, filter(Path.exists, qc_paths)))
exptQC = [loadmat(p, squeeze_me=True, simplify_cells=True) for p in qc_paths]
if len(exptQC) > 0:
frameQC, frameQC_names, bad_frames = self._consolidate_exptQC(exptQC)
frameQC, frameQC_names, _ = self._consolidate_exptQC(exptQC)
else:
_logger.warning('No frame QC (exptQC.mat) files found.')
frameQC, bad_frames = np.array([], dtype='u1'), np.array([], dtype='i8')
frameQC = np.array([], dtype='u1')
frameQC_names = pd.DataFrame(columns=['qc_values', 'qc_labels'])

# If applicable, save as bad_frames.npy in first raw_imaging_folder for suite2p
if len(bad_frames) > 0 and use_badframes is True:
np.save(Path(db['data_path'][0]).joinpath('bad_frames.npy'), bad_frames)
# badframes.mat contains QC values that do affect ROI detection (e.g. no PMT, lens artefacts)
badframes = np.array([], dtype='i8')
total_frames = 0
# Ensure all indices are relative to total cumulative frames
for m, collection in zip(rawImagingData, collections):
badframes_path = self.session_path.joinpath(collection, 'badframes.mat')
if badframes_path.exists():
raw_mat = loadmat(badframes_path, squeeze_me=True, simplify_cells=True)['badframes']
badframes = np.r_[badframes, raw_mat + total_frames]
total_frames += m['nFrames']
if len(badframes) > 0 and use_badframes is True:
np.save(Path(db['data_path'][0]).joinpath('bad_frames.npy'), badframes)

""" Suite2p """
# Create alf if it doesn't exist
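
A standalone sketch of the offset arithmetic in the badframes loop above, with hypothetical frame counts, showing how per-collection indices become session-wide ones:

```python
import numpy as np

# Indices in each collection's badframes.mat are relative to that
# collection's own frames, so shift by the cumulative frame count.
badframes_per_collection = [np.array([5, 9]), np.array([0, 3])]  # hypothetical
n_frames = [1000, 800]  # from each collection's meta['nFrames']

badframes = np.array([], dtype='i8')
total_frames = 0
for raw, n in zip(badframes_per_collection, n_frames):
    badframes = np.r_[badframes, raw + total_frames]
    total_frames += n
# badframes -> array([   5,    9, 1000, 1003])
```
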
@@ -625,10 +635,10 @@ def signature(self):
(f'_{self.sync_namespace}_DAQdata.timestamps.npy', self.sync_collection, True),
(f'_{self.sync_namespace}_DAQdata.meta.json', self.sync_collection, True),
('_ibl_rawImagingData.meta.json', self.device_collection, True),
('rawImagingData.times_scanImage.npy', self.device_collection, True),
('rawImagingData.times_scanImage.npy', self.device_collection, True, True), # register raw
(f'_{self.sync_namespace}_softwareEvents.log.htsv', self.sync_collection, False), ],
'output_files': [('mpci.times.npy', 'alf/mesoscope/FOV*', True),
('mpciStack.timeshift.npy', 'alf/mesoscope/FOV*', True), ]
('mpciStack.timeshift.npy', 'alf/mesoscope/FOV*', True),]
}
return signature
