diff --git a/WDL/Tree.py b/WDL/Tree.py index a27ba2e6..ad9f3798 100644 --- a/WDL/Tree.py +++ b/WDL/Tree.py @@ -736,7 +736,7 @@ def available_inputs(self) -> Env.Decls: for decl in self.inputs: ans = Env.bind(ans, [], decl.name, decl) - for elt in _decls_and_calls(self): + for elt in _decls_and_calls(self, exclude_outputs=True): if isinstance(elt, Decl): if self.inputs is None: ans = Env.bind(ans, [], elt.name, elt) @@ -760,7 +760,7 @@ def required_inputs(self) -> Env.Decls: if decl.expr is None and decl.type.optional is False: ans = Env.bind(ans, [], decl.name, decl) - for elt in _decls_and_calls(self): + for elt in _decls_and_calls(self, exclude_outputs=True): if isinstance(elt, Decl): if self.inputs is None and elt.expr is None and elt.type.optional is False: ans = Env.bind(ans, [], elt.name, elt) @@ -1060,11 +1060,15 @@ def load( def _decls_and_calls( - element: Union[Workflow, Scatter, Conditional] + element: Union[Workflow, Scatter, Conditional], exclude_outputs: bool = True ) -> Generator[Union[Decl, Call], None, None]: # Yield each Decl and Call in the workflow, including those nested within # scatter/conditional sections - for ch in element.children: + children = element.children + if isinstance(element, Workflow) and exclude_outputs: + children = element.inputs if element.inputs else [] + children = children + element.elements + for ch in children: if isinstance(ch, (Decl, Call)): yield ch elif isinstance(ch, (Scatter, Conditional)): diff --git a/tests/test_1doc.py b/tests/test_1doc.py index 90872370..08310df6 100644 --- a/tests/test_1doc.py +++ b/tests/test_1doc.py @@ -1279,6 +1279,22 @@ def test_multi_errors(self): except WDL.Error.MultipleValidationErrors as multi: self.assertEqual(len(multi.exceptions), 2) + def test_issue135_workflow_available_inputs(self): + # Workflow.available_inputs should not include declarations in the + # output section + doc = r""" + workflow a { + File in + output { + File out = in + } + } + """ + doc = WDL.parse_document(doc) + doc.typecheck() + self.assertEqual(len(doc.workflow.available_inputs), 1) + self.assertEqual(doc.workflow.available_inputs[0].name, "in") + class TestCycleDetection(unittest.TestCase): def test_task(self): doc = r"""