Skip to content

Commit

Permalink
Rename "Defer errors" to "Stop on first error" (#2449)
Browse files Browse the repository at this point in the history
* Rename "Defer Errors" to "Stop on first error"

* throw_early -> fail_fast
  • Loading branch information
joeyballentine committed Jan 11, 2024
1 parent 5afa666 commit d613bed
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 30 deletions.
14 changes: 7 additions & 7 deletions backend/src/api/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,19 +607,19 @@ def add_package(
class Iterator(Generic[I]):
iter_supplier: Callable[[], Iterable[I | Exception]]
expected_length: int
defer_errors: bool = False
fail_fast: bool = True

@staticmethod
def from_iter(
iter_supplier: Callable[[], Iterable[I | Exception]],
expected_length: int,
defer_errors: bool = False,
fail_fast: bool = True,
) -> Iterator[I]:
return Iterator(iter_supplier, expected_length, defer_errors=defer_errors)
return Iterator(iter_supplier, expected_length, fail_fast=fail_fast)

@staticmethod
def from_list(
l: list[L], map_fn: Callable[[L, int], I], defer_errors: bool = False
l: list[L], map_fn: Callable[[L, int], I], fail_fast: bool = True
) -> Iterator[I]:
"""
Creates a new iterator from a list that is mapped using the given
Expand All @@ -633,11 +633,11 @@ def supplier():
except Exception as e:
yield e

return Iterator(supplier, len(l), defer_errors=defer_errors)
return Iterator(supplier, len(l), fail_fast=fail_fast)

@staticmethod
def from_range(
count: int, map_fn: Callable[[int], I], defer_errors: bool = False
count: int, map_fn: Callable[[int], I], fail_fast: bool = True
) -> Iterator[I]:
"""
Creates a new iterator the given number of items where each item is
Expand All @@ -652,7 +652,7 @@ def supplier():
except Exception as e:
yield e

return Iterator(supplier, count, defer_errors=defer_errors)
return Iterator(supplier, count, fail_fast=fail_fast)


N = TypeVar("N")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
icon="MdLoop",
inputs=[
DirectoryInput(),
BoolInput("Defer errors", default=True).with_docs(
"Ignore errors that occur during iteration and throw them after processing. Use this if you want to make sure one bad model doesn't interrupt your batch.",
BoolInput("Stop on first error", default=False).with_docs(
"Instead of collecting errors and throwing them at the end of processing, stop iteration and throw an error as soon as one occurs.",
hint=True,
),
],
Expand All @@ -49,7 +49,7 @@
)
def load_models_node(
directory: str,
defer_errors: bool,
fail_fast: bool,
) -> tuple[Iterator[tuple[NcnnModelWrapper, str, str, int]], str]:
logger.debug(f"Iterating over models in directory: {directory}")

Expand Down Expand Up @@ -81,4 +81,4 @@ def load_model(filepath_pairs: tuple[str, str], index: int):

model_files = list(zip(param_files, bin_files))

return Iterator.from_list(model_files, load_model, defer_errors), directory
return Iterator.from_list(model_files, load_model, fail_fast), directory
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
icon="MdLoop",
inputs=[
DirectoryInput(),
BoolInput("Defer errors", default=True).with_docs(
"Ignore errors that occur during iteration and throw them after processing. Use this if you want to make sure one bad model doesn't interrupt your batch.",
BoolInput("Stop on first error", default=False).with_docs(
"Instead of collecting errors and throwing them at the end of processing, stop iteration and throw an error as soon as one occurs.",
hint=True,
),
],
Expand All @@ -49,7 +49,7 @@
)
def load_models_node(
directory: str,
defer_errors: bool,
fail_fast: bool,
) -> tuple[Iterator[tuple[OnnxModel, str, str, int]], str]:
logger.debug(f"Iterating over models in directory: {directory}")

Expand All @@ -62,4 +62,4 @@ def load_model(path: str, index: int):
supported_filetypes = [".onnx"]
model_files = list_all_files_sorted(directory, supported_filetypes)

return Iterator.from_list(model_files, load_model, defer_errors), directory
return Iterator.from_list(model_files, load_model, fail_fast), directory
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
icon="MdLoop",
inputs=[
DirectoryInput(),
BoolInput("Defer errors", default=True).with_docs(
"Ignore errors that occur during iteration and throw them after processing. Use this if you want to make sure one bad model doesn't interrupt your batch.",
BoolInput("Stop on first error", default=False).with_docs(
"Instead of collecting errors and throwing them at the end of processing, stop iteration and throw an error as soon as one occurs.",
hint=True,
),
],
Expand All @@ -46,7 +46,7 @@
)
def load_models_node(
directory: str,
defer_errors: bool,
fail_fast: bool,
) -> tuple[Iterator[tuple[ModelDescriptor, str, str, int]], str]:
logger.debug(f"Iterating over models in directory: {directory}")

Expand All @@ -59,4 +59,4 @@ def load_model(path: str, index: int):
supported_filetypes = [".pt", ".pth", ".ckpt", ".safetensors"]
model_files: list[str] = list_all_files_sorted(directory, supported_filetypes)

return Iterator.from_list(model_files, load_model, defer_errors), directory
return Iterator.from_list(model_files, load_model, fail_fast), directory
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@
"Limit the number of images to iterate over. This can be useful for testing the iterator without having to iterate over all images."
)
),
BoolInput("Defer errors", default=True).with_docs(
"Ignore errors that occur during iteration and throw them after processing. Use this if you want to make sure one bad image doesn't interrupt your batch.",
BoolInput("Stop on first error", default=False).with_docs(
"Instead of collecting errors and throwing them at the end of processing, stop iteration and throw an error as soon as one occurs.",
hint=True,
),
],
Expand All @@ -60,7 +60,7 @@ def load_image_pairs_node(
directory_b: str,
use_limit: bool,
limit: int,
defer_errors: bool,
fail_fast: bool,
) -> tuple[Iterator[tuple[np.ndarray, np.ndarray, str, str, str, str, int]], str, str]:
def load_images(filepaths: tuple[str, str], index: int):
path_a, path_b = filepaths
Expand Down Expand Up @@ -90,7 +90,7 @@ def load_images(filepaths: tuple[str, str], index: int):
image_files = list(zip(image_files_a, image_files_b))

return (
Iterator.from_list(image_files, load_images, defer_errors),
Iterator.from_list(image_files, load_images, fail_fast),
directory_a,
directory_b,
)
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,8 @@ def list_glob(directory: str, globexpr: str, ext_filter: list[str]) -> list[str]
"Limit the number of images to iterate over. This can be useful for testing the iterator without having to iterate over all images."
)
),
BoolInput("Defer errors", default=True).with_docs(
"Ignore errors that occur during iteration and throw them after processing. Use this if you want to make sure one bad image doesn't interrupt your batch.",
BoolInput("Stop on first error", default=False).with_docs(
"Instead of collecting errors and throwing them at the end of processing, stop iteration and throw an error as soon as one occurs.",
hint=True,
),
],
Expand All @@ -95,7 +95,7 @@ def load_images_node(
glob_str: str,
use_limit: bool,
limit: int,
defer_errors: bool,
fail_fast: bool,
) -> tuple[Iterator[tuple[np.ndarray, str, str, int]], str]:
def load_image(path: str, index: int):
img, img_dir, basename = load_image_node(path)
Expand All @@ -115,4 +115,4 @@ def load_image(path: str, index: int):
if use_limit:
just_image_files = just_image_files[:limit]

return Iterator.from_list(just_image_files, load_image, defer_errors), directory
return Iterator.from_list(just_image_files, load_image, fail_fast), directory
6 changes: 3 additions & 3 deletions backend/src/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,10 +606,10 @@ async def update_progress():
except Aborted:
raise
except Exception as e:
if iterator_output.iterator.defer_errors:
deferred_errors.append(str(e))
else:
if iterator_output.iterator.fail_fast:
raise e
else:
deferred_errors.append(str(e))

# reset cached value
self.cache.delete_many(all_iterated_nodes)
Expand Down

0 comments on commit d613bed

Please sign in to comment.