Skip to content

Commit

Permalink
Implement conditional subworkflow steps
Browse files Browse the repository at this point in the history
for mapped-over subworkflows and simple subworkflows
  • Loading branch information
mvdbeek committed Nov 10, 2022
1 parent 3ea3ec6 commit 15bdadb
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 18 deletions.
48 changes: 31 additions & 17 deletions lib/galaxy/workflow/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,9 @@ def _find_collections_to_match(self, progress, step, all_inputs):
collections_to_match = matching.CollectionsToMatch()
dataset_collection_type_descriptions = self.trans.app.dataset_collection_manager.collection_type_descriptions

if progress.when:
collections_to_match.add("when_source", progress.when)

for input_dict in all_inputs:
name = input_dict["name"]
data = progress.replacement_for_input(step, input_dict)
Expand Down Expand Up @@ -445,6 +448,17 @@ def _find_collections_to_match(self, progress, step, all_inputs):

return collections_to_match

@staticmethod
def create_when_source_param():
return dict(
name="when_source",
label="when_source",
multiple=False,
input_type="parameter",
optional=False,
type="boolean",
)


class SubWorkflowModule(WorkflowModule):
# Two step improvements to build runtime inputs for subworkflow modules
Expand Down Expand Up @@ -599,10 +613,18 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
inputs, etc...
"""
step = invocation_step.workflow_step
collection_info = self.compute_collection_info(progress, step, self.get_all_inputs())
all_inputs = self.get_all_inputs()
if step.input_connections_by_name.get("when_source"):
param_dict = self.create_when_source_param()
all_inputs.append(param_dict)
# Maybe make this lazy with a callback ?
when = progress.replacement_for_input(step, param_dict)

collection_info = self.compute_collection_info(progress, step, all_inputs)
structure = collection_info.structure if collection_info else None

subworkflow_invoker = progress.subworkflow_invoker(
trans, step, use_cached_job=use_cached_job, subworkflow_structure=structure
trans, step, use_cached_job=use_cached_job, subworkflow_structure=structure, when=when
)
subworkflow_invoker.invoke()
subworkflow = subworkflow_invoker.workflow
Expand Down Expand Up @@ -1640,18 +1662,6 @@ def get_errors(self, include_tool_id=False, **kwargs):
def get_inputs(self):
return self.tool.inputs if self.tool else {}

def get_conditional_param(self, inputs):
inputs.append(
dict(
name="when_source",
label="when_source",
multiple=False,
input_type="parameter",
optional=False,
type="boolean",
)
)

def get_all_inputs(self, data_only=False, connectable_only=False):
if data_only and connectable_only:
raise Exception("Must specify at most one of data_only and connectable_only as True.")
Expand Down Expand Up @@ -1944,8 +1954,8 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
del tool_state.inputs[RUNTIME_STEP_META_STATE_KEY]

all_inputs = self.get_all_inputs()
if step.input_connections_by_name.get("when_source"):
self.get_conditional_param(all_inputs)
if progress.when is not None or step.input_connections_by_name.get("when_source"):
all_inputs.append(self.create_when_source_param())
tool_inputs["when_source"] = ConditionalStepWhen(None, {"name": "when_source", "type": "boolean"})
all_inputs_by_name = {}
for input_dict in all_inputs:
Expand Down Expand Up @@ -1974,6 +1984,8 @@ def callback(input, prefixed_name, **kwargs):
replacement: Union[model.Dataset, NoReplacement] = NO_REPLACEMENT
if iteration_elements and prefixed_name in iteration_elements: # noqa: B023
replacement = iteration_elements[prefixed_name] # noqa: B023
elif prefixed_name == "when_source" and progress.when is not None:
replacement = progress.when
else:
replacement = progress.replacement_for_input(step, input_dict)

Expand All @@ -1990,7 +2002,9 @@ def callback(input, prefixed_name, **kwargs):
replacement = json.load(f)
found_replacement_keys.add(prefixed_name) # noqa: B023

if isinstance(input, ConditionalStepWhen) and replacement is False:
# bool cast should be fine, can only have true/false on ConditionalStepWhen
# also terrible of course and it's not needed for API requests
if isinstance(input, ConditionalStepWhen) and bool(replacement) is False:
raise SkipWorkflowStepEvaluation

return replacement
Expand Down
11 changes: 10 additions & 1 deletion lib/galaxy/workflow/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ def __init__(
use_cached_job: bool = False,
replacement_dict: Optional[Dict[str, str]] = None,
subworkflow_structure=None,
when=None,
) -> None:
self.outputs: Dict[int, Any] = {}
self.module_injector = module_injector
Expand All @@ -318,6 +319,7 @@ def __init__(
self.use_cached_job = use_cached_job
self.replacement_dict = replacement_dict or {}
self.subworkflow_structure = subworkflow_structure
self.when = when

@property
def maximum_jobs_to_schedule_or_none(self) -> Optional[int]:
Expand Down Expand Up @@ -532,11 +534,16 @@ def subworkflow_invoker(
step: "WorkflowStep",
use_cached_job: bool = False,
subworkflow_structure=None,
when=None,
) -> WorkflowInvoker:
subworkflow_invocation = self._subworkflow_invocation(step)
workflow_run_config = workflow_request_to_run_config(subworkflow_invocation, use_cached_job)
subworkflow_progress = self.subworkflow_progress(
subworkflow_invocation, step, workflow_run_config.param_map, subworkflow_structure=subworkflow_structure
subworkflow_invocation,
step,
workflow_run_config.param_map,
subworkflow_structure=subworkflow_structure,
when=when,
)
subworkflow_invocation = subworkflow_progress.workflow_invocation
return WorkflowInvoker(
Expand All @@ -552,6 +559,7 @@ def subworkflow_progress(
step: "WorkflowStep",
param_map: Dict,
subworkflow_structure=None,
when=None,
) -> "WorkflowProgress":
subworkflow = subworkflow_invocation.workflow
subworkflow_inputs = {}
Expand Down Expand Up @@ -580,6 +588,7 @@ def subworkflow_progress(
use_cached_job=self.use_cached_job,
replacement_dict=self.replacement_dict,
subworkflow_structure=subworkflow_structure,
when=when,
)

def _recover_mapping(self, step_invocation: WorkflowInvocationStep) -> None:
Expand Down
111 changes: 111 additions & 0 deletions lib/galaxy_test/api/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,59 @@ def test_run_workflow_simple_conditional_step(self):
if step["workflow_step_label"] == "cat1":
assert sum(1 for j in step["jobs"] if j["state"] == "skipped") == 1

    def test_run_workflow_subworkflow_conditional_step(self):
        """A subworkflow step guarded by ``when:`` with a false boolean input
        should have its inner tool job skipped rather than executed.
        """
        with self.dataset_populator.test_history() as history_id:
            summary = self._run_workflow(
                """class: GalaxyWorkflow
inputs:
  should_run:
    type: boolean
  some_file:
    type: data
steps:
  subworkflow:
    run:
      class: GalaxyWorkflow
      inputs:
        some_file:
          type: data
      steps:
        a_tool_step:
          tool_id: cat1
          in:
            input1: some_file
    in:
      some_file: some_file
    outputs:
      inner_out: a_tool_step/out_file1
    when:
      source: should_run
outputs:
  outer_output:
    outputSource: subworkflow/inner_out
""",
                test_data="""
some_file:
  value: 1.bed
  type: File
should_run:
  value: false
  type: raw
""",
                history_id=history_id,
                wait=True,
                assert_ok=True,
            )
            # The subworkflow step is the last step of the outer invocation; fetch
            # its nested invocation id so the inner jobs can be inspected.
            invocation_details = self.workflow_populator.get_invocation(summary.invocation_id, step_details=True)
            subworkflow_invocation_id = invocation_details["steps"][-1]["subworkflow_invocation_id"]
            # NOTE(review): workflow_id appears unused for the lookup here (literal
            # "whatever") — presumably only invocation_id matters; confirm in populator.
            self.workflow_populator.wait_for_invocation_and_jobs(
                history_id=history_id, workflow_id="whatever", invocation_id=subworkflow_invocation_id
            )
            invocation_details = self.workflow_populator.get_invocation(subworkflow_invocation_id, step_details=True)
            # should_run is false, so the single inner cat1 job must be skipped.
            for step in invocation_details["steps"]:
                if step["workflow_step_label"] == "a_tool_step":
                    assert sum(1 for j in step["jobs"] if j["state"] == "skipped") == 1

def test_run_workflow_conditional_step_map_over_expression_tool(self):
with self.dataset_populator.test_history() as history_id:
summary = self._run_workflow(
Expand Down Expand Up @@ -1971,6 +2024,64 @@ def test_run_workflow_conditional_step_map_over_expression_tool(self):
)
assert dataset_details["file_ext"] == "expression.json", dataset_details

    def test_run_workflow_conditional_subworkflow_step_map_over_expression_tool(self):
        """Map a conditional subworkflow over a collection of booleans: the
        ``true`` element should run and the ``false`` element should be skipped,
        while the outer output collection still materializes both elements.
        """
        with self.dataset_populator.test_history() as history_id:
            summary = self._run_workflow(
                """
class: GalaxyWorkflow
inputs:
  boolean_input_files: collection
steps:
  create_list_of_boolean:
    tool_id: param_value_from_file
    in:
      input1: boolean_input_files
    state:
      param_type: boolean
  subworkflow:
    run:
      class: GalaxyWorkflow
      inputs:
        boolean_input_file: data
      steps:
        consume_expression_parameter:
          tool_id: cat1
          in:
            input1: boolean_input_file
          out:
            out_file1:
              change_datatype: txt
      outputs:
        inner_output:
          outputSource: consume_expression_parameter/out_file1
    in:
      boolean_input_file: boolean_input_files
    when:
      source: create_list_of_boolean/boolean_param
outputs:
  outer_output:
    outputSource: subworkflow/inner_output
test_data:
  boolean_input_files:
    collection_type: list
    elements:
      - identifier: true
        content: true
      - identifier: false
        content: false
""",
                history_id=history_id,
            )
            invocation_details = self.workflow_populator.get_invocation(summary.invocation_id, step_details=True)
            assert "outer_output" in invocation_details["output_collections"]
            outer_output = invocation_details["output_collections"]["outer_output"]
            outer_hdca = self.dataset_populator.get_history_collection_details(
                history_id, content_id=outer_output["id"]
            )
            # Two mapped-over elements → two jobs total: the true branch ran (ok),
            # the false branch was conditionally skipped.
            assert outer_hdca["job_state_summary"]["all_jobs"] == 2
            assert outer_hdca["job_state_summary"]["ok"] == 1
            assert outer_hdca["job_state_summary"]["skipped"] == 1

def test_run_workflow_conditional_step_map_over_expression_tool_pick_value(self):
with self.dataset_populator.test_history() as history_id:
summary = self._run_workflow(
Expand Down

0 comments on commit 15bdadb

Please sign in to comment.