Skip to content

Commit

Permalink
Fix for issue #15.
Browse files Browse the repository at this point in the history
  • Loading branch information
lucabaldini committed Oct 11, 2023
1 parent 729c0c1 commit ed2e38b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 15 deletions.
35 changes: 25 additions & 10 deletions hexsample/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from hxsim import HXSIM_ARGPARSER, hxsim as _hxsim


def required_args(parser : ArgumentParser) -> list:
def required_arguments(parser : ArgumentParser) -> list:
"""Return a list of the positional arguments for a given parser.
This is useful to retrieve all the default values from an ArgumentParser
Expand All @@ -49,7 +49,7 @@ def required_args(parser : ArgumentParser) -> list:
"""
return [action.dest for action in parser._actions if action.required]

def default_args(parser : ArgumentParser) -> dict:
def default_arguments(parser : ArgumentParser) -> dict:
"""Return the default arguments for a given ArgumentParser object.
If the parser has no positional arguments, this is simply achieved via a
Expand All @@ -68,13 +68,22 @@ def default_args(parser : ArgumentParser) -> dict:
---------
parser : ArgumentParser
The argument parser object for a given application.
Returns
-------
args, kwargs : list, dict
The list of positional arguments and a dictionary of all the other arguments.
"""
args = required_args(parser)
# Positional arguments.
args = required_arguments(parser)
# All parser arguments.
kwargs = vars(parser.parse_args(args))
# Strip the positional arguments from the complete list.
[kwargs.pop(key) for key in args]
return kwargs
# And return the two sets separately.
return args, kwargs

def update_args(parser : ArgumentParser, **kwargs) -> dict:
def update_arguments(parser : ArgumentParser, **kwargs) -> dict:
"""Retrieve the default option from an ArgumentParser object and update
specific keys based on arbitrary keyword arguments.
Expand All @@ -86,16 +95,22 @@ def update_args(parser : ArgumentParser, **kwargs) -> dict:
kwargs : dict
Additional keyword arguments.
"""
args = default_args(parser)
args.update(kwargs)
return args
# Retrieve the default arguments.
_args, _kwargs = default_arguments(parser)
# Loop over the kwargs passed to the function to make sure that all of them
# are recognized by the parser.
for key in kwargs:
if key not in _args and key not in _kwargs:
raise RuntimeError(f'Unknown parameter {key} passed to a pipeline component')
_kwargs.update(kwargs)
return _kwargs

def hxrecon(**kwargs):
"""Application wrapper.
"""
return _hxrecon(**update_args(HXRECON_ARGPARSER, **kwargs))
return _hxrecon(**update_arguments(HXRECON_ARGPARSER, **kwargs))

def hxsim(**kwargs):
"""Application wrapper.
"""
return _hxsim(**update_args(HXSIM_ARGPARSER, **kwargs))
return _hxsim(**update_arguments(HXSIM_ARGPARSER, **kwargs))
18 changes: 13 additions & 5 deletions tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,19 @@
def test_parsers():
"""Test the relevant ArgumentParser objects.
"""
assert pipeline.required_args(pipeline.HXSIM_ARGPARSER) == []
assert pipeline.required_args(pipeline.HXRECON_ARGPARSER) == ['infile']
assert 'infile' not in pipeline.update_args(pipeline.HXRECON_ARGPARSER)
assert 'infile' in pipeline.update_args(pipeline.HXRECON_ARGPARSER, infile='test_file')
print(pipeline.update_args(pipeline.HXSIM_ARGPARSER))
assert pipeline.required_arguments(pipeline.HXSIM_ARGPARSER) == []
assert pipeline.required_arguments(pipeline.HXRECON_ARGPARSER) == ['infile']
assert 'infile' not in pipeline.update_arguments(pipeline.HXRECON_ARGPARSER)
assert 'infile' in pipeline.update_arguments(pipeline.HXRECON_ARGPARSER, infile='test_file')
print(pipeline.update_arguments(pipeline.HXSIM_ARGPARSER))

def test_wrong_args():
"""Make sure that, when supplied with wrong parameters, the pipeline applications
are raising a RuntimeError.
"""
with pytest.raises(RuntimeError) as excinfo:
pipeline.hxsim(numevents=100, bogusparam='howdy')
print(excinfo.value)

def test_pipeline():
"""Test generating and reconstructing files.
Expand Down

0 comments on commit ed2e38b

Please sign in to comment.