Skip to content

Commit

Permalink
Expanded exceptions handling (#88)
Browse files Browse the repository at this point in the history
  • Loading branch information
natsunlee committed May 10, 2024
1 parent 2dfc6bb commit defb1a8
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 15 deletions.
12 changes: 12 additions & 0 deletions flowmancer/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,15 @@ def add_error(self, field: str, msg: str) -> None:
@property
def errors(self) -> List[Dict[str, str]]:
return self._errors


class VarFormatError(Exception):
pass


class TaskClassNotFoundError(Exception):
pass


class ModuleLoadError(Exception):
pass
3 changes: 3 additions & 0 deletions flowmancer/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
StdOutLogWriterWrapper,
TaskLogWriterWrapper,
)
from .exceptions import TaskClassNotFoundError
from .task import Task, _task_classes


Expand Down Expand Up @@ -142,6 +143,8 @@ def get_task_class(self) -> Type[Task]:
if inspect.isclass(self.task_class) and issubclass(self.task_class, Task):
return self.task_class
elif type(self.task_class) == str:
if self.task_class not in _task_classes:
raise TaskClassNotFoundError(self.task_class)
return _task_classes[self.task_class]
else:
raise TypeError('The `task_class` param must be either an extension of `Task` or the string name of one.')
Expand Down
26 changes: 11 additions & 15 deletions flowmancer/flowmancer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,11 @@
from .exceptions import (
CheckpointInvalidError,
ExtensionsDirectoryNotFoundError,
ModuleLoadError,
NotAPackageError,
NoTasksLoadedError,
TaskValidationError,
VarFormatError,
)
from .executor import Executor
from .extensions.extension import Extension, _extension_classes
Expand Down Expand Up @@ -62,7 +64,7 @@ def _create_loop():
loop.close()


def _load_extensions_path(path: str, package_chain: Optional[List[str]] = None):
def _load_extensions_path(path: str, package_chain: Optional[List[str]] = None) -> None:
if not path.startswith('/'):
path = os.path.abspath(
os.path.join(
Expand All @@ -79,6 +81,7 @@ def _load_extensions_path(path: str, package_chain: Optional[List[str]] = None):
raise NotAPackageError(f"Only packages (directories) are allowed. The following is not a dir: '{path}'")
if not os.path.exists(os.path.join(path, '__init__.py')):
print(f"WARNING: The '{path}' dir is not a package (no __init__.py file found). Modules will not be imported.")
return None

if not package_chain:
package_chain = [os.path.basename(path)]
Expand All @@ -88,9 +91,7 @@ def _load_extensions_path(path: str, package_chain: Optional[List[str]] = None):
print(f"Loading Module: {'.'.join(package_chain+[x.name])}")
importlib.import_module('.'.join(package_chain+[x.name]))
except Exception as e:
print(
f"WARNING: Skipping import for '{'.'.join(package_chain+[x.name])}' due to {type(e).__name__}: {str(e)}"
)
raise ModuleLoadError(f"Error loading '{'.'.join(package_chain+[x.name])}': {e}")
if x.ispkg:
_load_extensions_path(os.path.join(path, x.name), package_chain+[x.name])

Expand Down Expand Up @@ -153,21 +154,16 @@ def start(
except ValidationError as e:
if raise_exception_on_failure:
raise
print('Errors exist in the provided JobDefinition:')
print('ERROR: Errors exist in the provided JobDefinition:')
error_list = json.loads(e.json())
for err in error_list:
print(f' - {".".join(err["loc"])}: {err["msg"]}')
return 1
except TaskValidationError as e:
if raise_exception_on_failure:
raise
print(e)
return 2
except NoTasksLoadedError as e:
except Exception as e:
if raise_exception_on_failure:
raise
print(e)
return 3
print(f'ERROR: {e}')
return 99
finally:
os.chdir(orig_cwd)

Expand All @@ -180,7 +176,7 @@ def _validate_tasks(self) -> None:
for ve in json.loads(e.json()):
err.add_error(f'tasks.{n}.parameters.{".".join(ve["loc"])}', ve['msg'])
except Exception as e:
err.add_error(f'tasks.{n}.parameters.{".".join(ve["loc"])}', str(e))
err.add_error(f'tasks.{n}', repr(e))
if err.errors:
raise err

Expand Down Expand Up @@ -243,7 +239,7 @@ def _process_cmd_args(
for v in args.jobdef_vars:
parts = v.split('=')
if len(parts) <= 1:
raise ValueError('`var` arguments must follow the pattern: <key>=<value>')
raise VarFormatError('`var` arguments must follow the pattern: <key>=<value>')
self.set_jobdef_var(parts[0], '='.join(parts[1:]))

if args.jobdef:
Expand Down

0 comments on commit defb1a8

Please sign in to comment.