Skip to content

Commit

Permalink
docs: 📝 Add dummy example for run_blockwise()
Browse files Browse the repository at this point in the history
Also fixed docstrings for blockwise worker functions.
  • Loading branch information
rhoadesScholar committed Apr 18, 2024
1 parent fc56d0d commit 0f7a007
Show file tree
Hide file tree
Showing 12 changed files with 153 additions and 196 deletions.
28 changes: 1 addition & 27 deletions dacapo/blockwise/argmax_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,6 @@ def cli(log_level):
Args:
log_level (str): The log level to use.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> cli(log_level="INFO")
Note:
The method is implemented in the class.
"""
logging.basicConfig(level=getattr(logging, log_level.upper()))

Expand Down Expand Up @@ -69,12 +63,6 @@ def start_worker(
input_dataset (str): The input dataset.
output_container (Path | str): The output container.
output_dataset (str): The output dataset.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> start_worker(input_container="input_container", input_dataset="input_dataset", output_container="output_container", output_dataset="output_dataset")
Note:
The method is implemented in the class.
"""
# get arrays
input_array_identifier = LocalArrayIdentifier(Path(input_container), input_dataset)
Expand Down Expand Up @@ -119,13 +107,7 @@ def spawn_worker(
raw_array (Array): The raw data to predict on.
prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
Returns:
The worker to run.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> spawn_worker(model, raw_array, prediction_array_identifier)
Note:
The method is implemented in the class.
Callable: The function to run the worker.
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
Expand Down Expand Up @@ -156,15 +138,7 @@ def spawn_worker(
def run_worker():
"""
Run the worker in the given compute context.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> run_worker()
Note:
The method is implemented in the class.
"""
# Run the worker in the given compute context
compute_context.execute(command)

return run_worker
Expand Down
8 changes: 0 additions & 8 deletions dacapo/blockwise/blockwise_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ class DaCapoBlockwiseTask(Task):
Methods:
__init__:
Initialize the task.
Note:
The method is implemented in the class.
"""

def __init__(
Expand Down Expand Up @@ -58,12 +56,6 @@ def __init__(
upstream_tasks: The upstream tasks.
*args: Additional positional arguments to pass to ``worker_function``.
**kwargs: Additional keyword arguments to pass to ``worker_function``.
Raises:
ValueError: If the worker file is not a valid path.
Examples:
>>> DaCapoBlockwiseTask(worker_file="worker_file", total_roi=Roi, read_roi=Roi, write_roi=Roi, num_workers=16, max_retries=2, timeout=None, upstream_tasks=None)
Note:
The method is implemented in the class.
"""
# Load worker functions
worker_name = Path(worker_file).stem
Expand Down
29 changes: 2 additions & 27 deletions dacapo/blockwise/predict_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ def cli(log_level):
Args:
log_level (str): The log level to use for logging.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> cli(log_level="INFO")
Note:
The method is implemented in the class.
"""
logging.basicConfig(level=getattr(logging, log_level.upper()))

Expand Down Expand Up @@ -94,13 +88,6 @@ def start_worker(
input_dataset (str): The input dataset.
output_container (Path | str): The output container.
output_dataset (str): The output dataset.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> start_worker(run_name="run", iteration=0, input_container="input", input_dataset="input", output_container="output", output_dataset="output")
Note:
The method is implemented in the class.
"""
compute_context = create_compute_context()
device = compute_context.device
Expand Down Expand Up @@ -233,12 +220,8 @@ def spawn_worker(
iteration (int or None): The training iteration of the model to use for prediction.
input_array_identifier (LocalArrayIdentifier): The raw data to predict on.
output_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> spawn_worker(run_name="run", iteration=0, input_array_identifier=LocalArrayIdentifier(Path("input"), "input"), output_array_identifier=LocalArrayIdentifier(Path("output"), "output"))
Note:
The method is implemented in the class.
Returns:
Callable: The function to run the worker.
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
Expand Down Expand Up @@ -277,15 +260,7 @@ def spawn_worker(
def run_worker():
"""
Run the worker in the given compute context.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> run_worker()
Note:
The method is implemented in the class.
"""
# Run the worker in the given compute context
print("Running worker with command: ", command)
compute_context.execute(command)

Expand Down
44 changes: 0 additions & 44 deletions dacapo/blockwise/relabel_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,6 @@ def cli(log_level):
Args:
log_level (str): The log level to use.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> cli(log_level="INFO")
Note:
The method is implemented in the class.
"""
logging.basicConfig(level=getattr(logging, log_level.upper()))

Expand Down Expand Up @@ -63,12 +57,6 @@ def start_worker(
output_container (str): The output container
output_dataset (str): The output dataset
tmpdir (str): The temporary directory
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> start_worker(output_container="output_container", output_dataset="output_dataset", tmpdir="tmpdir")
Note:
The method is implemented in the class.
"""
client = daisy.Client()
array_out = open_ds(output_container, output_dataset, mode="a")
Expand Down Expand Up @@ -108,12 +96,6 @@ def relabel_in_block(array_out, old_values, new_values, block):
old_values (np.ndarray): The old values
new_values (np.ndarray): The new values
block (daisy.Block): The block
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> relabel_in_block(array_out, old_values, new_values, block)
Note:
The method is implemented in the class.
"""
a = array_out.to_ndarray(block.write_roi)
# DGA: had to add in flatten and reshape since remap (in particular indices) didn't seem to work with ndarrays for the input
Expand All @@ -131,12 +113,6 @@ def find_components(nodes, edges):
edges (np.ndarray): The edges
Returns:
List[int]: The components
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> find_components(nodes, edges)
Note:
The method is implemented in the class.
"""
# scipy
disjoint_set = DisjointSet(nodes)
Expand All @@ -153,12 +129,6 @@ def read_cross_block_merges(tmpdir):
tmpdir (str): The temporary directory
Returns:
Tuple[np.ndarray, np.ndarray]: The nodes and edges
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> read_cross_block_merges(tmpdir)
Note:
The method is implemented in the class.
"""
block_files = glob(os.path.join(tmpdir, "block_*.npz"))

Expand Down Expand Up @@ -186,12 +156,6 @@ def spawn_worker(
tmpdir (str): The temporary directory
Returns:
Callable: The function to run the worker
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> spawn_worker(output_array_identifier, tmpdir)
Note:
The method is implemented in the class.
"""
compute_context = create_compute_context()

Expand Down Expand Up @@ -220,15 +184,7 @@ def spawn_worker(
def run_worker():
"""
Run the worker in the given compute context.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> run_worker()
Note:
The method is implemented in the class.
"""
# Run the worker in the given compute context
compute_context.execute(command)

return run_worker
Expand Down
8 changes: 0 additions & 8 deletions dacapo/blockwise/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,8 @@ def run_blockwise(
Additional keyword arguments to pass to ``worker_function``.
Returns:
``Bool``.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> run_blockwise(worker_file, total_roi, read_roi, write_roi, num_workers, max_retries, timeout, upstream_tasks)
Note:
The method is implemented in the class.
"""

Expand Down Expand Up @@ -127,12 +123,8 @@ def segment_blockwise(
Additional keyword arguments to pass to ``worker_function``.
Returns:
``Bool``.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> segment_blockwise(segment_function_file, context, total_roi, read_roi, write_roi, num_workers, max_retries, timeout, upstream_tasks)
Note:
The method is implemented in the class.
"""
options = Options.instance()
if not options.runs_base_dir.exists():
Expand Down
28 changes: 2 additions & 26 deletions dacapo/blockwise/segment_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,6 @@ def cli(log_level):
Args:
log_level (str): The log level to use.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> cli(log_level="INFO")
Note:
The method is implemented in the class.
"""
logging.basicConfig(level=getattr(logging, log_level.upper()))

Expand Down Expand Up @@ -73,12 +67,6 @@ def start_worker(
tmpdir (str): The temporary directory.
function_path (str): The path to the segment function.
return_io_loop (bool): Whether to return the io loop or run it.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> start_worker(input_container="input_container", input_dataset="input_dataset", output_container="output_container", output_dataset="output_dataset", tmpdir="tmpdir", function_path="function_path")
Note:
The method is implemented in the class.
"""

print("Starting worker")
Expand Down Expand Up @@ -218,12 +206,8 @@ def spawn_worker(
model (Model): The model to use for prediction.
raw_array (Array): The raw data to predict on.
prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> spawn_worker(input_array_identifier="input_array_identifier", output_array_identifier="output_array_identifier", tmpdir="tmpdir", function_path="function_path")
Note:
The method is implemented in the class.
Returns:
Callable: The function to run the worker.
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
Expand Down Expand Up @@ -260,15 +244,7 @@ def spawn_worker(
def run_worker():
"""
Run the worker in the given compute context.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> run_worker()
Note:
The method is implemented in the class.
"""
# Run the worker in the given compute context
compute_context.execute(command)

return run_worker
Expand Down
41 changes: 3 additions & 38 deletions dacapo/blockwise/threshold_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,6 @@
default="INFO",
)
def cli(log_level):
"""
CLI for running the threshold worker.
Args:
log_level (str): The log level to use.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> cli(log_level="INFO")
Note:
The method is implemented in the class.
"""
logging.basicConfig(level=getattr(logging, log_level.upper()))


Expand Down Expand Up @@ -72,12 +60,6 @@ def start_worker(
output_container (Path | str): The output container.
output_dataset (str): The output dataset.
threshold (float): The threshold.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> start_worker(input_container="input", input_dataset="input", output_container="output", output_dataset="output", threshold=0.0)
Note:
The method is implemented in the class.
"""
# get arrays
Expand Down Expand Up @@ -119,19 +101,11 @@ def spawn_worker(
Spawn a worker to predict on a given dataset.
Args:
model (Model): The model to use for prediction.
raw_array (Array): The raw data to predict on.
prediction_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
block_shape (Tuple[int]): The shape of the blocks.
halo (Tuple[int]): The halo to use.
input_array_identifier (LocalArrayIdentifier): The raw data to predict on.
output_array_identifier (LocalArrayIdentifier): The identifier of the prediction array.
threshold (float): The threshold.
Returns:
Callable: The function to run the worker.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> spawn_worker(model, raw_array, prediction_array_identifier, block_shape, halo)
Note:
The method is implemented in the class.
"""
compute_context = create_compute_context()
if not compute_context.distribute_workers:
Expand All @@ -145,7 +119,6 @@ def spawn_worker(

# Make the command for the worker to run
command = [
# "python",
sys.executable,
path,
"start-worker",
Expand All @@ -164,15 +137,7 @@ def spawn_worker(
def run_worker():
"""
Run the worker in the given compute context.
Raises:
NotImplementedError: If the method is not implemented in the derived class.
Examples:
>>> run_worker()
Note:
The method is implemented in the class.
"""
# Run the worker in the given compute context
compute_context.execute(command)

return run_worker
Expand Down
Loading

0 comments on commit 0f7a007

Please sign in to comment.