Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for positional args in pipeline.get_config() #2478

Merged
merged 2 commits into from
May 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions haystack/nodes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,6 @@ def wrapper_exportable_to_yaml(self, *args, **kwargs):
# Call the actuall __init__ function with all the arguments
init_func(self, *args, **kwargs)

# Warn for unnamed input params - should be rare
if args:
logger.warning(
"Unnamed __init__ parameters will not be saved to YAML if Pipeline.save_to_yaml() is called!"
)
# Create the configuration dictionary if it doesn't exist yet
if not self._component_config:
self._component_config = {"params": {}, "type": type(self).__name__}
Expand All @@ -46,6 +41,14 @@ def wrapper_exportable_to_yaml(self, *args, **kwargs):
for k, v in kwargs.items():
self._component_config["params"][k] = v

# Store unnamed input parameters in self._component_config too by inferring their names
sig = inspect.signature(init_func)
parameter_names = list(sig.parameters.keys())
# we can be sure that the first one is always "self"
arg_names = parameter_names[1 : 1 + len(args)]
for arg, arg_name in zip(args, arg_names):
self._component_config["params"][arg_name] = arg

return wrapper_exportable_to_yaml


Expand Down
12 changes: 4 additions & 8 deletions test/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,21 +386,17 @@ def __init__(self, param: int):
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}


def test_get_config_custom_node_with_positional_params(caplog):
def test_get_config_custom_node_with_positional_params():
class CustomNode(MockNode):
def __init__(self, param: int = 1):
super().__init__()
self.param = param

pipeline = Pipeline()
with caplog.at_level(logging.WARNING):
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])
assert (
"Unnamed __init__ parameters will not be saved to YAML "
"if Pipeline.save_to_yaml() is called" in caplog.text
)
pipeline.add_node(CustomNode(10), name="custom_node", inputs=["Query"])

assert len(pipeline.get_config()["components"]) == 1
assert pipeline.get_config()["components"][0]["params"] == {}
assert pipeline.get_config()["components"][0]["params"] == {"param": 10}


def test_generate_code_simple_pipeline():
Expand Down