This repository has been archived by the owner on Jan 3, 2024. It is now read-only.

Commit

Merge pull request #37 from dstansby/callbacks
Add option of externally specified callbacks for detection
dstansby committed Mar 11, 2022
2 parents 8512788 + eeffcb4 commit 9d31b66
Showing 7 changed files with 150 additions and 13 deletions.
26 changes: 25 additions & 1 deletion cellfinder_core/classify/classify.py
@@ -1,7 +1,9 @@
import logging
from typing import Callable, Optional

import numpy as np
from imlib.general.system import get_num_processes
from tensorflow import keras

from cellfinder_core.classify.cube_generator import CubeGeneratorFromFile
from cellfinder_core.classify.tools import get_model
@@ -23,13 +25,26 @@ def main(
model_weights,
network_depth,
max_workers=3,
*,
callback: Optional[Callable[[int], None]] = None,
):

"""
Parameters
----------
callback : Callable[int], optional
A callback function that is called during classification. Called with
the batch number once that batch has been classified.
"""
if signal_array.ndim != 3:
raise IOError("Signal data must be 3D")
if background_array.ndim != 3:
raise IOError("Background data must be 3D")

if callback is not None:
callbacks = [BatchEndCallback(callback)]
else:
callbacks = None

# Too many workers doesn't increase speed, and uses huge amounts of RAM
workers = get_num_processes(
min_free_cpu_cores=n_free_cpus, n_max_processes=max_workers
@@ -61,6 +76,7 @@ def main(
use_multiprocessing=True,
workers=workers,
verbose=True,
callbacks=callbacks,
)
predictions = predictions.round()
predictions = predictions.astype("uint16")
@@ -74,3 +90,11 @@ def main(
points_list.append(cell)

return points_list


class BatchEndCallback(keras.callbacks.Callback):
def __init__(self, callback):
self._callback = callback

def on_predict_batch_end(self, batch, logs=None):
self._callback(batch)
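
The BatchEndCallback class added above is a thin adapter: it wraps a plain function in a keras.callbacks.Callback so that Keras invokes it after every predicted batch. Below is a minimal, self-contained sketch of the same mechanism; the toy model and data are illustrative assumptions, not part of this commit.

import numpy as np
from tensorflow import keras


class BatchEndCallback(keras.callbacks.Callback):
    def __init__(self, callback):
        self._callback = callback

    def on_predict_batch_end(self, batch, logs=None):
        # Keras calls this after each batch of predict(); forward the batch index
        self._callback(batch)


# Hypothetical toy model and data, only to exercise the hook
model = keras.Sequential([keras.layers.Dense(1, input_shape=(4,))])
batches = []
model.predict(
    np.zeros((32, 4)),
    batch_size=8,
    callbacks=[BatchEndCallback(batches.append)],
)
print(batches)  # [0, 1, 2, 3]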
7 changes: 4 additions & 3 deletions cellfinder_core/classify/cube_generator.py
@@ -142,9 +142,10 @@ def __check_z_scaling(self):

if self.num_planes_needed_for_cube > self.image_z_size:
raise StackSizeError(
"The number of planes provided is not sufficient "
"for any cubes to be extracted. Please check the "
"input data"
f"The number of planes provided ({self.image_z_size}) "
"is not sufficient for any cubes to be extracted "
f"(need at least {self.num_planes_needed_for_cube}). "
"Please check the input data"
)

def __remove_outlier_points(self):
45 changes: 38 additions & 7 deletions cellfinder_core/detect/detect.py
@@ -2,6 +2,7 @@
from datetime import datetime
from multiprocessing import Lock
from multiprocessing import Queue as MultiprocessingQueue
from typing import Callable

import numpy as np
from imlib.general.system import get_num_processes
@@ -56,7 +57,16 @@ def main(
artifact_keep=False,
save_planes=False,
plane_directory=None,
*,
callback: Callable[[int], None] = None,
):
"""
Parameters
----------
callback : Callable[int], optional
A callback function that is called every time a plane has finished
being processed. Called with the plane number that has finished.
"""
n_processes = get_num_processes(min_free_cpu_cores=n_free_cpus)
start_time = datetime.now()

@@ -76,11 +86,16 @@ def main(
if end_plane == -1:
end_plane = len(signal_array)
signal_array = signal_array[start_plane:end_plane]
callback = callback or (lambda *args, **kwargs: None)

workers_queue = MultiprocessingQueue(maxsize=n_processes)
workers_queue: MultiprocessingQueue = MultiprocessingQueue(
maxsize=n_processes
)
# WARNING: needs to be AT LEAST ball_z_size
mp_3d_filter_queue = MultiprocessingQueue(maxsize=ball_z_size)
for plane_id in range(n_processes):
mp_3d_filter_queue: MultiprocessingQueue = MultiprocessingQueue(
maxsize=ball_z_size
)
for _ in range(n_processes):
# placeholder for the queue to have the right size on first run
workers_queue.put(None)

@@ -95,11 +110,14 @@ def main(
ball_overlap_fraction,
start_plane,
]
output_queue = MultiprocessingQueue()
output_queue: MultiprocessingQueue = MultiprocessingQueue()
planes_done_queue: MultiprocessingQueue = MultiprocessingQueue()

# Create 3D analysis filter
mp_3d_filter = Mp3DFilter(
mp_3d_filter_queue,
output_queue,
planes_done_queue,
soma_diameter,
setup_params=setup_params,
soma_size_spread_factor=soma_spread_factor,
@@ -116,11 +134,14 @@ def main(
bf_process = multiprocessing.Process(target=mp_3d_filter.process, args=())
bf_process.start() # needs to be started before the loop
clipping_val, threshold_value = setup_tile_filtering(signal_array[0, :, :])

# Create 2D analysis filter
mp_tile_processor = MpTileProcessor(workers_queue, mp_3d_filter_queue)
prev_lock = Lock()
processes = []

# start 2D tile filter (output goes into queue for 3D analysis)
# Creates a list of (running) processes for each 2D plane
processes = []
for plane_id, plane in enumerate(signal_array):
workers_queue.get()
lock = Lock()
@@ -143,9 +164,19 @@ def main(
processes.append(p)
p.start()

processes[-1].join()
mp_3d_filter_queue.put((None, None, None)) # Signal the end
# Trigger callback when 3D filtering is done on a plane
nplanes_done = 0
while nplanes_done < len(signal_array):
callback(planes_done_queue.get(block=True))
nplanes_done += 1

# Wait for all the 2D filters to process
for p in processes:
p.join()
# Tell 3D filter that there are no more planes left
mp_3d_filter_queue.put((None, None, None))
cells = output_queue.get()
# Wait for 3D filter to finish
bf_process.join()

print(
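
The ordering logic in detect.py now hinges on the new planes_done_queue: the 3D-filter process reports each completed plane index on the queue, and the main process blocks on get() until every plane has been counted, invoking the callback once per plane. A stripped-down sketch of that parent/worker pattern follows; the names and the plane count are illustrative, not taken from the commit.

import multiprocessing


def worker(done_queue, n_planes):
    for z in range(n_planes):
        # ... filter plane z ...
        done_queue.put(z)  # report the finished plane to the parent


def run(n_planes, callback):
    done_queue = multiprocessing.Queue()
    p = multiprocessing.Process(target=worker, args=(done_queue, n_planes))
    p.start()
    n_done = 0
    while n_done < n_planes:
        callback(done_queue.get(block=True))  # fires once per finished plane
        n_done += 1
    p.join()


if __name__ == "__main__":
    run(5, lambda z: print(f"plane {z} done"))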
14 changes: 12 additions & 2 deletions cellfinder_core/detect/filters/plane/multiprocessing.py
@@ -21,6 +21,14 @@ def process(
log_sigma_size,
n_sds_above_mean_thresh,
):
"""
Parameters
----------
previous_lock : multiprocessing.Lock
Lock for the previous tile in the processing queue.
self_lock : multiprocessing.Lock
Lock for the current tile.
"""
laplace_gaussian_sigma = log_sigma_size * soma_diameter
plane = plane.T
np.clip(plane, 0, clipping_value, out=plane)
@@ -44,8 +52,10 @@
] = threshold_value
tile_mask = walker.good_tiles_mask.astype(np.uint8)

with previous_lock:
pass
# Wait for previous plane to be done
previous_lock.acquire()
previous_lock.release()

self.ball_filter_q.put((plane_id, plane, tile_mask))
self.thread_q.put(plane_id)
self_lock.release()
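
Replacing the "with previous_lock: pass" block with an explicit acquire/release pair makes the intent of the lock chain clearer: each 2D worker waits until the previous plane's result has been queued before pushing its own, so planes reach the ball filter in order even though the filtering itself runs concurrently. A minimal sketch of that chaining idea, written with threads for brevity (the real code uses multiprocessing locks and queues; everything below is illustrative):

import threading


def stage(plane_id, previous_lock, self_lock, results):
    # ... heavy per-plane work would run here, concurrently ...
    previous_lock.acquire()   # wait until the previous plane has been queued
    previous_lock.release()
    results.append(plane_id)  # stands in for ball_filter_q.put(...)
    self_lock.release()       # unblock the next plane in the chain


results = []
prev_lock = threading.Lock()  # first plane has nothing to wait for
threads = []
for plane_id in range(4):
    lock = threading.Lock()
    lock.acquire()  # the next plane blocks on this until stage() releases it
    t = threading.Thread(target=stage, args=(plane_id, prev_lock, lock, results))
    threads.append(t)
    t.start()
    prev_lock = lock

for t in threads:
    t.join()
print(results)  # always [0, 1, 2, 3] despite concurrent execution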
3 changes: 3 additions & 0 deletions cellfinder_core/detect/filters/volume/multiprocessing.py
@@ -21,6 +21,7 @@ def __init__(
self,
data_queue,
output_queue,
planes_done_queue,
soma_diameter,
soma_size_spread_factor=1.4,
setup_params=None,
@@ -34,6 +35,7 @@
):
self.data_queue = data_queue
self.output_queue = output_queue
self.planes_done_queue = planes_done_queue
self.soma_diameter = soma_diameter
self.soma_size_spread_factor = soma_size_spread_factor
self.progress_bar = None
@@ -98,6 +100,7 @@ def process(self):
" (out of bounds)"
)

self.planes_done_queue.put(self.z)
self.z += 1
if self.progress_bar is not None:
self.progress_bar.update()
21 changes: 21 additions & 0 deletions cellfinder_core/main.py
@@ -37,7 +37,23 @@ def main(
cube_height=50,
cube_depth=20,
network_depth="50",
*,
detect_callback=None,
classify_callback=None,
detect_finished_callback=None,
):
"""
Parameters
----------
detect_callback : Callable[int], optional
Called every time a plane has finished being processed during the
detection stage. Called with the plane number that has finished.
classify_callback : Callable[int], optional
Called every time tensorflow has finished classifying a point.
Called with the batch number that has just finished.
detect_finished_callback : Callable[list], optional
Called after detection is finished with the list of detected points.
"""
suppress_tf_logging(tf_suppress_log_messages)

from pathlib import Path
@@ -64,8 +80,12 @@ def main(
n_free_cpus,
log_sigma_size,
n_sds_above_mean_thresh,
callback=detect_callback,
)

if detect_finished_callback is not None:
detect_finished_callback(points)

model_weights = prep.prep_classification(
trained_model, model_weights, install_path, model, n_free_cpus
)
@@ -85,6 +105,7 @@ def main(
trained_model,
model_weights,
network_depth,
callback=classify_callback,
)
else:
logging.info("No candidates, skipping classification")
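
With these keyword-only arguments threaded through, callers of cellfinder_core.main.main can observe progress without reaching into the detection or classification internals. A hedged usage sketch follows; the arrays and voxel sizes are placeholders assumed to be loaded elsewhere, and the integration test below exercises the real data path.

from cellfinder_core.main import main


def on_plane(plane):
    print(f"detection finished plane {plane}")


def on_batch(batch):
    print(f"classification finished batch {batch}")


def on_detect_done(points):
    print(f"detection found {len(points)} candidate points")


# signal_array, background_array and voxel_sizes are assumed to exist, e.g.
# 3D image arrays and a (z, y, x) tuple of voxel sizes in microns.
main(
    signal_array,
    background_array,
    voxel_sizes,
    detect_callback=on_plane,
    classify_callback=on_batch,
    detect_finished_callback=on_detect_done,
)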
47 changes: 47 additions & 0 deletions tests/tests/test_integration/test_detection.py
@@ -2,6 +2,7 @@
from math import isclose

import imlib.IO.cells as cell_io
import numpy as np
import pytest

from cellfinder_core.main import main
@@ -18,6 +19,16 @@
DETECTION_TOLERANCE = 2


@pytest.fixture
def signal_array():
return read_with_dask(signal_data_path)


@pytest.fixture
def background_array():
return read_with_dask(background_data_path)


# FIXME: This isn't a very good example


@@ -50,3 +61,39 @@ def test_detection_full():
assert isclose(
num_cells_validation, num_cells_test, abs_tol=DETECTION_TOLERANCE
)


def test_callbacks(signal_array, background_array):
# 20 is minimum number of planes needed to find > 0 cells
signal_array = signal_array[0:20]
background_array = background_array[0:20]

planes_done = []
batches_classified = []
points_found = []

def detect_callback(plane):
planes_done.append(plane)

def classify_callback(batch):
batches_classified.append(batch)

def detect_finished_callback(points):
points_found.append(points)

main(
signal_array,
background_array,
voxel_sizes,
detect_callback=detect_callback,
classify_callback=classify_callback,
detect_finished_callback=detect_finished_callback,
)

np.testing.assert_equal(planes_done, np.arange(len(signal_array)))
np.testing.assert_equal(batches_classified, [0])

ncalls = len(points_found)
assert ncalls == 1, f"Expected 1 call to callback, got {ncalls}"
npoints = len(points_found[0])
assert npoints == 120, f"Expected 120 points, found {npoints}"
