
Commit 0d1b5c8
Test and fix conditional subworkflow steps where subworkflow step produces additional mapping over without being directly connected to a workflow input
mvdbeek committed Nov 17, 2022
1 parent 0ddf69e commit 0d1b5c8
Showing 7 changed files with 156 additions and 85 deletions.
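
The change in a nutshell: a `when`-gated (conditional) subworkflow step could gain extra map-over structure from a step inside the subworkflow (for example, a tool producing a dynamic nested collection) without that step being connected to a workflow input, and the per-element skip/run decisions were lost along the way. The fix threads a `when_values` list through collection matching and structure walking so that every sliced element is paired with the boolean that decides whether the conditional step runs for it. A minimal sketch of that pairing contract in plain Python (illustrative names, not Galaxy code):

elements = ["element_a", "element_b"]  # hypothetical sliced collection elements
when_values = [True, False]            # one conditional decision per element

# slice_collections() now yields (element, when_value) pairs instead of bare elements
for element, when_value in zip(elements, when_values):
    if when_value is False:
        print(f"skipping {element}")  # step is skipped for this slice
    else:
        print(f"running {element}")   # step executes for this slice
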
5 changes: 4 additions & 1 deletion lib/galaxy/model/dataset_collections/matching.py
@@ -37,7 +37,7 @@ class MatchingCollections:
overkill but I suspect in the future plugins will be subtypable for
instance so matching collections will need to make heavy use of the
dataset collection type registry managed by the dataset collections
- sevice - hence the complexity now.
+ service - hence the complexity now.
"""

def __init__(self):
@@ -46,6 +46,7 @@ def __init__(self):
self.collections = {}
self.subcollection_types = {}
self.action_tuples = {}
+ self.when_values = None

def __attempt_add_to_linked_match(self, input_name, hdca, collection_type_description, subcollection_type):
structure = get_structure(hdca, collection_type_description, leaf_subcollection_type=subcollection_type)
@@ -60,6 +61,7 @@ def __attempt_add_to_linked_match(self, input_name, hdca, collection_type_description, subcollection_type):
self.subcollection_types[input_name] = subcollection_type

def slice_collections(self):
+ self.linked_structure.when_values = self.when_values
return self.linked_structure.walk_collections(self.collections)

def subcollection_mapping_type(self, input_name):
@@ -75,6 +77,7 @@ def structure(self):
if linked_structure is None:
linked_structure = leaf
effective_structure = effective_structure.multiply(linked_structure)
+ effective_structure.when_values = self.when_values
return None if effective_structure.is_leaf else effective_structure

def map_over_action_tuples(self, input_name):
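
The matching.py changes only store and forward state: `MatchingCollections` gains a `when_values` attribute, and both `slice_collections` and the `structure` property stamp it onto the structure they hand out, so the walk can pair each slice with its decision. A simplified standalone sketch of that forwarding, with hypothetical stand-in classes:

class Structure:
    def __init__(self, leaves):
        self.leaves = leaves
        self.when_values = None

    def walk_collections(self):
        # pair each leaf with its when decision, mirroring the new contract
        for index, leaf in enumerate(self.leaves):
            yield leaf, self.when_values[index] if self.when_values else None

class Matching:
    def __init__(self, structure):
        self.linked_structure = structure
        self.when_values = None  # set by whoever evaluated the when expressions

    def slice_collections(self):
        self.linked_structure.when_values = self.when_values
        return self.linked_structure.walk_collections()

matching = Matching(Structure(["a", "b"]))
matching.when_values = [True, False]
assert list(matching.slice_collections()) == [("a", True), ("b", False)]
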
9 changes: 5 additions & 4 deletions lib/galaxy/model/dataset_collections/structure.py
@@ -63,9 +63,10 @@ def __str__(self):
class Tree(BaseTree):
children_known = True

- def __init__(self, children, collection_type_description):
+ def __init__(self, children, collection_type_description, when_values=None):
super().__init__(collection_type_description)
self.children = children
+ self.when_values = when_values

@staticmethod
def for_dataset_collection(dataset_collection, collection_type_description):
@@ -94,11 +95,11 @@ def get_element(collection):
return collection[index] # noqa: B023

if substructure.is_leaf:
- yield dict_map(get_element, collection_dict)
+ yield dict_map(get_element, collection_dict), self.when_values[index] if self.when_values else None
else:
sub_collections = dict_map(lambda collection: get_element(collection).child_collection, collection_dict)
- for element in substructure._walk_collections(sub_collections):
- yield element
+ for element, _when_value in substructure._walk_collections(sub_collections):
+ yield element, self.when_values[index] if self.when_values else None

@property
def is_leaf(self):
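
Note how the recursion propagates the decision: the nested `_walk_collections` call discards the inner `_when_value` and re-yields the current level's `self.when_values[index]`, so the decision attached at the outer mapped-over level reaches every leaf beneath it. A self-contained sketch of that propagation, using plain nested lists in place of Galaxy's collection structures:

def walk(tree, when_values=None):
    # tree is a list of leaves and/or nested lists; when_values aligns with the top level
    for index, node in enumerate(tree):
        when = when_values[index] if when_values else None
        if isinstance(node, list):
            for leaf, _inner_when in walk(node):
                yield leaf, when  # the outer decision overrides anything nested
        else:
            yield node, when

nested = [["a1", "a2"], ["b1"]]
assert list(walk(nested, when_values=[True, False])) == [("a1", True), ("a2", True), ("b1", False)]
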
2 changes: 1 addition & 1 deletion lib/galaxy/tools/execute.py
@@ -542,7 +542,7 @@ def record_success(self, execution_slice, job, outputs):
self.invocation_step.job = job

def new_collection_execution_slices(self):
- for job_index, (param_combination, dataset_collection_elements) in enumerate(
+ for job_index, (param_combination, (dataset_collection_elements, _when_value)) in enumerate(
zip(self.param_combinations, self.walk_implicit_collections())
):
completed_job = self.completed_jobs and self.completed_jobs[job_index]
38 changes: 20 additions & 18 deletions lib/galaxy/workflow/modules.py
@@ -112,7 +112,6 @@ class ConditionalStepWhen(BooleanToolParameter):
def evaluate_value_from_expressions(progress, step, execution_state, extra_step_state):
when_expression = step.when_expression
value_from_expressions = {}
- replacements = {}

if execution_state:
for key in execution_state.inputs.keys():
@@ -121,7 +120,7 @@ def evaluate_value_from_expressions(progress, step, execution_state, extra_step_state):
value_from_expressions[key] = step_input.value_from

if not value_from_expressions and when_expression is None:
- return replacements
+ return {}

hda_references = []

@@ -462,6 +461,8 @@ def compute_collection_info(self, progress, step, all_inputs):
collections_to_match = self._find_collections_to_match(progress, step, all_inputs)
# Have implicit collections...
collection_info = self.trans.app.dataset_collection_manager.match_collections(collections_to_match)
+ if collection_info and progress.subworkflow_collection_info:
+ collection_info.when_values = progress.subworkflow_collection_info.when_values
return collection_info or progress.subworkflow_collection_info

def _find_collections_to_match(self, progress, step, all_inputs):
@@ -705,11 +706,18 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
if collection_info:
iteration_elements_iter = collection_info.slice_collections()
else:
- iteration_elements_iter = [None]
+ if progress.when_values:
+ # If we have more than one item in when_values it must have come from an expression.json
+ # collection, so we'd have a collection_info instance ... I think.
+ assert len(progress.when_values) == 1, "Got more than 1 when value, this shouldn't be possible"
+ iteration_elements_iter = [(None, progress.when_values[0] if progress.when_values else None)]

when_values = []
if step.when_expression:
- for iteration_elements in iteration_elements_iter:
+ for (iteration_elements, when_value) in iteration_elements_iter:
+ if when_value is False:
+ when_values.append(when_value)
+ continue
extra_step_state = {}
for step_input in step.inputs:
step_input_name = step_input.name
@@ -724,13 +732,15 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
progress, step, execution_state={}, extra_step_state=extra_step_state
)
)
+ if collection_info:
+ collection_info.when_values = when_values

subworkflow_invoker = progress.subworkflow_invoker(
trans,
step,
use_cached_job=use_cached_job,
subworkflow_collection_info=collection_info,
- when=when_values,
+ when_values=when_values,
)
subworkflow_invoker.invoke()
subworkflow = subworkflow_invoker.workflow
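
Two behaviors are worth calling out in this hunk: a slice whose when_value is already False is recorded and skipped without re-evaluating the expression, and the collected list is both stamped onto collection_info and passed down as when_values (previously when) so the subworkflow's own steps see it. A condensed sketch of that collection loop, with a hypothetical evaluate_when standing in for evaluate_value_from_expressions:

def collect_when_values(slices, evaluate_when):
    # slices: iterable of (iteration_elements, when_value) pairs from slice_collections()
    when_values = []
    for iteration_elements, when_value in slices:
        if when_value is False:
            when_values.append(False)  # a parent already decided to skip this slice
            continue
        when_values.append(evaluate_when(iteration_elements))
    return when_values

# e.g. the expression decides the first slice; a parent already skipped the second
slices = [({"input1": "dataset_a"}, None), ({"input1": "dataset_b"}, False)]
assert collect_when_values(slices, lambda elements: True) == [True, False]
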
@@ -2070,10 +2080,12 @@ def execute(self, trans, progress, invocation_step, use_cached_job=False):
if collection_info:
iteration_elements_iter = collection_info.slice_collections()
else:
- iteration_elements_iter = [None]
+ if progress.when_values:
+ assert len(progress.when_values) == 1, "Got more than 1 when value, this shouldn't be possible"
+ iteration_elements_iter = [(None, progress.when_values[0] if progress.when_values else None)]

resource_parameters = invocation.resource_parameters
- for iteration_index, iteration_elements in enumerate(iteration_elements_iter):
+ for (iteration_elements, when_value) in iteration_elements_iter:
execution_state = tool_state.copy()
# TODO: Move next step into copy()
execution_state.inputs = make_dict_copy(execution_state.inputs)
@@ -2125,10 +2137,7 @@ def callback(input, prefixed_name, **kwargs):
message = message_template % (tool.name, unicodify(k))
raise exceptions.MessageException(message)

- when_value = None
- if progress.when and progress.when[iteration_index] is False:
- when_value = False
- elif step.when_expression:
+ if step.when_expression and when_value is not False:
extra_step_state = {}
for step_input in step.inputs:
step_input_name = step_input.name
@@ -2152,13 +2161,6 @@ def callback(input, prefixed_name, **kwargs):
value = progress.replacement_for_connection(step_input.connections[0], is_data=True)
extra_step_state[step_input_name] = value

- when_value = None
- if progress.when is not None:
- if callable(progress.when):
- when_value = progress.when()
- else:
- when_value = progress.when

if when_value is not False:
when_value = evaluate_value_from_expressions(
progress, step, execution_state=execution_state, extra_step_state=extra_step_state
12 changes: 6 additions & 6 deletions lib/galaxy/workflow/run.py
@@ -306,7 +306,7 @@ def __init__(
use_cached_job: bool = False,
replacement_dict: Optional[Dict[str, str]] = None,
subworkflow_collection_info=None,
- when=None,
+ when_values=None,
) -> None:
self.outputs: Dict[int, Any] = {}
self.module_injector = module_injector
@@ -320,7 +320,7 @@ def __init__(
self.replacement_dict = replacement_dict or {}
self.subworkflow_collection_info = subworkflow_collection_info
self.subworkflow_structure = subworkflow_collection_info.structure if subworkflow_collection_info else None
- self.when = when
+ self.when_values = when_values

@property
def maximum_jobs_to_schedule_or_none(self) -> Optional[int]:
@@ -535,7 +535,7 @@ def subworkflow_invoker(
step: "WorkflowStep",
use_cached_job: bool = False,
subworkflow_collection_info=None,
- when=None,
+ when_values=None,
) -> WorkflowInvoker:
subworkflow_invocation = self._subworkflow_invocation(step)
workflow_run_config = workflow_request_to_run_config(subworkflow_invocation, use_cached_job)
@@ -544,7 +544,7 @@
step,
workflow_run_config.param_map,
subworkflow_collection_info=subworkflow_collection_info,
- when=when,
+ when_values=when_values,
)
subworkflow_invocation = subworkflow_progress.workflow_invocation
return WorkflowInvoker(
@@ -560,7 +560,7 @@ def subworkflow_progress(
step: "WorkflowStep",
param_map: Dict,
subworkflow_collection_info=None,
- when=None,
+ when_values=None,
) -> "WorkflowProgress":
subworkflow = subworkflow_invocation.workflow
subworkflow_inputs = {}
@@ -589,7 +589,7 @@ def subworkflow_progress(
use_cached_job=self.use_cached_job,
replacement_dict=self.replacement_dict,
subworkflow_collection_info=subworkflow_collection_info,
- when=when,
+ when_values=when_values,
)

def _recover_mapping(self, step_invocation: WorkflowInvocationStep) -> None:
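
The run.py hunks are a mechanical rename, but the plural is the point: `when` read like a single decision, while `when_values` is one decision per sliced element, collapsing to a single-element list when the step is not mapped over (which is exactly what the `assert len(progress.when_values) == 1` guards in modules.py rely on). A tiny illustration of the two shapes, with made-up values:

mapped_over = [True, False, True]  # one decision per collection element
not_mapped = [False]               # the whole (unsliced) step is skipped once

for when_values in (mapped_over, not_mapped):
    if len(when_values) == 1:
        print("single decision:", when_values[0])  # unsliced execution path
    else:
        print("per-element decisions:", when_values)
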
119 changes: 64 additions & 55 deletions lib/galaxy_test/api/test_workflows.py
@@ -36,6 +36,7 @@
WorkflowPopulator,
)
from galaxy_test.base.workflow_fixtures import (
+ NESTED_WORKFLOW_WITH_CONDITIONAL_SUBWORKFLOW_AND_DISCONNECTED_MAP_OVER_SOURCE,
WORKFLOW_INPUTS_AS_OUTPUTS,
WORKFLOW_NESTED_REPLACEMENT_PARAMETER,
WORKFLOW_NESTED_RUNTIME_PARAMETER,
@@ -2085,66 +2086,74 @@ def test_run_workflow_conditional_subworkflow_step_map_over_expression_tool(self):
def test_run_workflow_conditional_subworkflow_step_map_over_expression_tool_with_extra_nesting(self):
with self.dataset_populator.test_history() as history_id:
summary = self._run_workflow(
"""
class: GalaxyWorkflow
inputs:
boolean_input_files: collection
steps:
create_list_of_boolean:
tool_id: param_value_from_file
in:
input1: boolean_input_files
state:
param_type: boolean
subworkflow:
run:
class: GalaxyWorkflow
inputs:
boolean_input_file: data
steps:
create_more_inputs:
tool_id: collection_creates_dynamic_nested
consume_expression_parameter:
tool_id: cat1
state:
input1:
$link: create_more_inputs/list_output
queries:
- input2:
$link: boolean_input_file
out:
out_file1:
change_datatype: txt
outputs:
inner_output:
outputSource: consume_expression_parameter/out_file1
in:
boolean_input_file: boolean_input_files
should_run: create_list_of_boolean/boolean_param
when: $(inputs.should_run)
outputs:
outer_output:
outputSource: subworkflow/inner_output
test_data:
boolean_input_files:
collection_type: list
elements:
- identifier: true
content: true
- identifier: false
content: false
NESTED_WORKFLOW_WITH_CONDITIONAL_SUBWORKFLOW_AND_DISCONNECTED_MAP_OVER_SOURCE,
test_data="""boolean_input_files:
collection_type: list
elements:
- identifier: true
content: true
- identifier: false
content: false
""",
history_id=history_id,
)
invocation_details = self.workflow_populator.get_invocation(summary.invocation_id, step_details=True)
assert "outer_output" in invocation_details["output_collections"]
outer_output = invocation_details["output_collections"]["outer_output"]
outer_hdca = self.dataset_populator.get_history_collection_details(
history_id, content_id=outer_output["id"]
outer_create_nested_id = invocation_details["output_collections"]["outer_create_nested"]["id"]
outer_create_nested = self.dataset_populator.get_history_collection_details(
history_id, content_id=outer_create_nested_id
)
assert outer_create_nested["job_state_summary"]["all_jobs"] == 2
assert outer_create_nested["job_state_summary"]["ok"] == 1
assert outer_create_nested["job_state_summary"]["skipped"] == 1

+ for cat1_output in ["outer_output_1", "outer_output_2"]:
+ outer_output = invocation_details["output_collections"][cat1_output]
+ outer_hdca = self.dataset_populator.get_history_collection_details(
+ history_id, content_id=outer_output["id"]
+ )
+ # You might expect 12 total jobs, 6 ok and 6 skipped,
+ # but because we're not actually running one branch of collection_creates_dynamic_nested
+ # there's no input to consume_expression_parameter.
+ # It's unclear if that's a problem or not ... probably not a major one,
+ # since we keep producing "empty" outer collections, which seems somewhat correct.
+ assert outer_hdca["job_state_summary"]["all_jobs"] == 6
+ assert outer_hdca["job_state_summary"]["ok"] == 6
+ assert outer_hdca["collection_type"] == "list:list:list"
+ elements = outer_hdca["elements"]
+ assert elements[0]["element_identifier"] == "True"
+ assert elements[0]["object"]["element_count"] == 3
+ assert elements[1]["element_identifier"] == "False"
+ assert elements[1]["object"]["element_count"] == 0

+ def test_run_workflow_conditional_subworkflow_step_map_over_expression_tool_with_extra_nesting_skip_all(self):
+ with self.dataset_populator.test_history() as history_id:
+ summary = self._run_workflow(
+ NESTED_WORKFLOW_WITH_CONDITIONAL_SUBWORKFLOW_AND_DISCONNECTED_MAP_OVER_SOURCE,
+ test_data="""boolean_input_files:
+ collection_type: list
+ elements:
+ - identifier: false
+ content: false
+ - identifier: also_false
+ content: false
+ """,
+ history_id=history_id,
)
- assert outer_hdca["job_state_summary"]["all_jobs"] == 14  # true/false * (6 inner elements + 1)
- assert outer_hdca["job_state_summary"]["ok"] == 7
- assert outer_hdca["job_state_summary"]["skipped"] == 7
+ invocation_details = self.workflow_populator.get_invocation(summary.invocation_id, step_details=True)
+ outer_create_nested_id = invocation_details["output_collections"]["outer_create_nested"]["id"]
+ outer_create_nested = self.dataset_populator.get_history_collection_details(
+ history_id, content_id=outer_create_nested_id
+ )
+ assert outer_create_nested["job_state_summary"]["all_jobs"] == 2
+ assert outer_create_nested["job_state_summary"]["skipped"] == 2

+ for cat1_output in ["outer_output_1", "outer_output_2"]:
+ outer_output = invocation_details["output_collections"][cat1_output]
+ outer_hdca = self.dataset_populator.get_history_collection_details(
+ history_id, content_id=outer_output["id"]
+ )
+ assert outer_hdca["job_state_summary"]["all_jobs"] == 0
+ assert outer_hdca["collection_type"] == "list:list:list"

def test_run_workflow_conditional_step_map_over_expression_tool_pick_value(self):
with self.dataset_populator.test_history() as history_id:
