diff --git a/pydra/engine/core.py b/pydra/engine/core.py index 9a5724cf29..0aebb4ca9f 100644 --- a/pydra/engine/core.py +++ b/pydra/engine/core.py @@ -14,7 +14,7 @@ from tempfile import mkdtemp from . import state -from . import auxiliary as aux +from . import helpers_state as hlpst from .specs import File, BaseSpec, RuntimeSpec, Result, SpecInfo, LazyField, TaskHook from .helpers import ( make_klass, @@ -54,11 +54,6 @@ class TaskBase: _cache_dir = None # Working directory in which to operate _references = None # List of references for a task - # dj: do we need it?? - input_spec = BaseSpec - output_spec = BaseSpec - - # TODO: write state should be removed def __init__( self, name: str, @@ -351,7 +346,7 @@ def split(self, splitter, **kwargs): self.inputs = dc.replace(self.inputs, **kwargs) # dj:??, check if I need it self.state_inputs = kwargs - splitter = aux.change_splitter(splitter, self.name) + splitter = hlpst.add_name_splitter(splitter, self.name) if self.state: raise Exception("splitter has been already set") else: @@ -376,7 +371,7 @@ def combine(self, combiner): def get_input_el(self, ind): """collecting all inputs required to run the node (for specific state element)""" if ind is not None: - # TODO: doesnt work properly for more cmplicated wf + # TODO: doesnt work properly for more cmplicated wf (check if still an issue) state_dict = self.state.states_val[ind] input_ind = self.state.inputs_ind[ind] inputs_dict = {} @@ -417,7 +412,7 @@ def done(self): def _combined_output(self): combined_results = [] - for (gr, ind_l) in self.state.final_groups_mapping.items(): + for (gr, ind_l) in self.state.final_combined_ind_mapping.items(): combined_results.append([]) for ind in ind_l: result = load_result(self.checksum_states(ind), self.cache_locations) diff --git a/pydra/engine/auxiliary.py b/pydra/engine/helpers_state.py similarity index 69% rename from pydra/engine/auxiliary.py rename to pydra/engine/helpers_state.py index 8baab495b4..b6968a343d 100644 --- a/pydra/engine/auxiliary.py +++ b/pydra/engine/helpers_state.py @@ -1,3 +1,5 @@ +""" additional functions used mostly by the State class """ + import itertools from functools import reduce from copy import deepcopy @@ -6,8 +8,6 @@ logger = logging.getLogger("pydra") -# dj: might create a new class or move to State - # Function to change user provided splitter to "reverse polish notation" used in State def splitter2rpn(splitter, other_states=None, state_fields=True): @@ -39,7 +39,7 @@ def _ordering( node_nm, other_states.keys() ) ) - splitter_mod = change_splitter( + splitter_mod = add_name_splitter( splitter=other_states[node_nm][0].splitter_final, name=node_nm ) if state_fields: @@ -54,7 +54,7 @@ def _ordering( node_nm, other_states.keys() ) ) - splitter_mod = change_splitter( + splitter_mod = add_name_splitter( splitter=other_states[node_nm][0].splitter_final, name=node_nm ) if state_fields: @@ -77,7 +77,7 @@ def _ordering( node_nm, other_states.keys() ) ) - splitter_mod = change_splitter( + splitter_mod = add_name_splitter( splitter=other_states[node_nm][0].splitter_final, name=node_nm ) if state_fields: @@ -92,7 +92,7 @@ def _ordering( node_nm, other_states.keys() ) ) - splitter_mod = change_splitter( + splitter_mod = add_name_splitter( splitter=other_states[node_nm][0].splitter_final, name=node_nm ) if state_fields: @@ -115,7 +115,7 @@ def _ordering( node_nm, other_states.keys() ) ) - splitter_mod = change_splitter( + splitter_mod = add_name_splitter( splitter=other_states[node_nm][0].splitter_final, name=node_nm ) if state_fields: @@ 
-162,148 +162,6 @@ def _iterate_list(element, sign, other_states, output_splitter, state_fields=Tru # functions used in State to know which element should be used for a specific axis -def splitting_axis(state_inputs, splitter_rpn): - """Having inputs and splitter (in rpn notation), functions returns the axes of output for every input.""" - axis_for_input = {} - stack = [] - # to remember current axis - current_axis = None - # to remember shapes and axes for partial results - out_shape = {} - out_axes = {} - # to remember imput names for partial results - out_inputname = {} - for el in splitter_rpn: - # scalar splitter - if el == ".": - right = stack.pop() - left = stack.pop() - # when both, left and right, are already products of partial splitter - if left.startswith("OUT") and right.startswith("OUT"): - if out_shape[left] != out_shape[right]: - raise Exception( - "arrays for scalar operations should have the same size" - ) - current_inputs = out_inputname[left] + out_inputname[right] - # when left is already product of partial splitter - elif left.startswith("OUT"): - if ( - state_inputs[right].shape == out_shape[left] - ): # todo:should we allow for one-element array? - axis_for_input[right] = out_axes[left] - else: - raise Exception( - "arrays for scalar operations should have the same size" - ) - current_inputs = out_inputname[left] + [right] - # when right is already product of partial splitter - elif right.startswith("OUT"): - if state_inputs[left].shape == out_shape[right]: - axis_for_input[left] = out_axes[right] - else: - raise Exception( - "arrays for scalar operations should have the same size" - ) - current_inputs = out_inputname[right] + [left] - - else: - if state_inputs[right].shape == state_inputs[left].shape: - current_axis = list(range(state_inputs[right].ndim)) - current_shape = state_inputs[left].shape - axis_for_input[left] = current_axis - axis_for_input[right] = current_axis - current_inputs = [left, right] - else: - raise Exception( - "arrays for scalar operations should have the same size" - ) - # adding partial output to the stack - stack.append("OUT_{}".format(len(out_shape))) - out_inputname["OUT_{}".format(len(out_shape))] = current_inputs - out_axes["OUT_{}".format(len(out_shape))] = current_axis - out_shape["OUT_{}".format(len(out_shape))] = current_shape - - # outer splitter - elif el == "*": - right = stack.pop() - left = stack.pop() - # when both, left and right, are already products of partial splitter - if left.startswith("OUT") and right.startswith("OUT"): - # changing all axis_for_input for inputs from right - for key in out_inputname[right]: - axis_for_input[key] = [ - i + len(out_axes[left]) for i in axis_for_input[key] - ] - current_axis = out_axes[left] + [ - i + (out_axes[left][-1] + 1) for i in out_axes[right] - ] - current_shape = tuple([i for i in out_shape[left] + out_shape[right]]) - current_inputs = out_inputname[left] + out_inputname[right] - # when left is already product of partial splitter - elif left.startswith("OUT"): - axis_for_input[right] = [ - i + (out_axes[left][-1] + 1) - for i in range(state_inputs[right].ndim) - ] - current_axis = out_axes[left] + axis_for_input[right] - current_shape = tuple( - [i for i in out_shape[left] + state_inputs[right].shape] - ) - current_inputs = out_inputname[left] + [right] - # when right is already product of partial splitter - elif right.startswith("OUT"): - # changing all axis_for_input for inputs from right - for key in out_inputname[right]: - axis_for_input[key] = [ - i + 
state_inputs[left].ndim for i in axis_for_input[key] - ] - axis_for_input[left] = [ - i - len(out_shape[right]) + (out_axes[right][-1] + 1) - for i in range(state_inputs[left].ndim) - ] - current_axis = out_axes[right] + [ - i + (out_axes[right][-1] + 1) - for i in range(state_inputs[left].ndim) - ] - current_shape = tuple( - [i for i in state_inputs[left].shape + out_shape[right]] - ) - current_inputs = out_inputname[right] + [left] - else: - axis_for_input[left] = list(range(state_inputs[left].ndim)) - axis_for_input[right] = [ - i + state_inputs[left].ndim for i in range(state_inputs[right].ndim) - ] - current_axis = axis_for_input[left] + axis_for_input[right] - current_shape = tuple( - [i for i in state_inputs[left].shape + state_inputs[right].shape] - ) - current_inputs = [left, right] - # adding partial output to the stack - stack.append("OUT_{}".format(len(out_shape))) - out_inputname["OUT_{}".format(len(out_shape))] = current_inputs - out_axes["OUT_{}".format(len(out_shape))] = current_axis - out_shape["OUT_{}".format(len(out_shape))] = current_shape - - # just a name of input - else: - stack.append(el) - - if len(stack) == 0: - pass - elif len(stack) > 1: - raise Exception("exception from splitting_axis") - elif not stack[0].startswith("OUT"): - current_axis = [i for i in range(state_inputs[stack[0]].ndim)] - axis_for_input[stack[0]] = current_axis - - if current_axis: - ndim = max(current_axis) + 1 - else: - ndim = 0 - return axis_for_input, ndim - - def converter_groups_to_input(group_for_inputs): """ Having axes for all the input fields, @@ -321,133 +179,8 @@ def converter_groups_to_input(group_for_inputs): return input_for_axis, ngr -# TODO: currently not used -def groups_stack_input(group_for_inputs, groups_stack): - """ function that helps testing groups_stack_final - returns groups_stack_final with input names - """ - inputs_for_groups = converter_groups_to_input(group_for_inputs)[0] - groups_stack_input = [] - for stack_el in groups_stack: - stack_el_inp = [] - for gr in stack_el: - stack_el_inp += inputs_for_groups[gr] - groups_stack_input.append(stack_el_inp) - return groups_stack_input - - -# TODO: not used currently, think if I need it -def converting_axis2input(axis_for_input, ndim, state_inputs=None): - """ Having axes for all the input fields, the function returns fields for each axis. """ - input_for_axis = [] - shape = [] - for i in range(ndim): - input_for_axis.append([]) - shape.append(0) - - for inp, axis in axis_for_input.items(): - for (i, ax) in enumerate(axis): - input_for_axis[ax].append(inp) - if state_inputs is not None: - shape[ax] = state_inputs[inp].shape[i] - - if state_inputs is not None: - return input_for_axis, shape - else: - return input_for_axis - - # function used in State if combiner -# TODO: not used currently, think if I need it -def matching_input_from_splitter(splitter_rpn): - """similar to splitting_axis, but without state_input, - finding inputs that are for the same axes. - can't find the final dimensions without inputs. 
- """ - axes_for_inputs = {} - output_inputs = {} - stack_inp = [] - for el in splitter_rpn: - if el == ".": - right, left = stack_inp.pop(), stack_inp.pop() - out_nm = "OUT{}".format(len(output_inputs)) - if left.startswith("OUT") and right.startswith("OUT"): - output_inputs[out_nm] = output_inputs[left] + output_inputs[right] - axes_for_inputs[out_nm] = axes_for_inputs[left].copy() - elif right.startswith("OUT"): - output_inputs[out_nm] = output_inputs[right] + [left] - axes_for_inputs[out_nm] = axes_for_inputs[right].copy() - axes_for_inputs[left] = axes_for_inputs[right].copy() - elif left.startswith("OUT"): - output_inputs[out_nm] = output_inputs[left] + [right] - axes_for_inputs[out_nm] = axes_for_inputs[left].copy() - axes_for_inputs[right] = axes_for_inputs[left].copy() - else: - output_inputs[out_nm] = [left, right] - axes_for_inputs[left] = [0] - axes_for_inputs[out_nm] = [0] - axes_for_inputs[right] = [0] - stack_inp.append(out_nm) - elif el == "*": - right, left = stack_inp.pop(), stack_inp.pop() - out_nm = "OUT{}".format(len(output_inputs)) - if left.startswith("OUT") and right.startswith("OUT"): - output_inputs[out_nm] = output_inputs[left] + output_inputs[right] - for inp in output_inputs[right] + [right]: - axes_for_inputs[inp] = [ - i + len(axes_for_inputs[left]) for i in axes_for_inputs[inp] - ] - axes_for_inputs[out_nm] = axes_for_inputs[left] + axes_for_inputs[right] - elif right.startswith("OUT"): - output_inputs[out_nm] = output_inputs[right] + [left] - axes_for_inputs[left] = [min(axes_for_inputs[right]) - 1] - axes_for_inputs[out_nm] = axes_for_inputs[left] + axes_for_inputs[right] - elif left.startswith("OUT"): - output_inputs[out_nm] = output_inputs[left] + [right] - axes_for_inputs[right] = [max(axes_for_inputs[left]) + 1] - axes_for_inputs[out_nm] = axes_for_inputs[left] + axes_for_inputs[right] - else: - output_inputs[out_nm] = [left, right] - axes_for_inputs[left] = [0] - axes_for_inputs[right] = [1] - axes_for_inputs[out_nm] = [0, 1] - stack_inp.append(out_nm) - else: - stack_inp.append(el) - - # checking if at the end I have only one element - if len(stack_inp) == 1 and stack_inp[0].startswith("OUT"): - pass - elif len(stack_inp) == 1: - axes_for_inputs[stack_inp[0]] = [0] - else: - raise Exception("something wrong with the splittper") - - # removing "OUT*" elements - axes_for_inputs = dict( - (key, val) - for (key, val) in axes_for_inputs.items() - if not key.startswith("OUT") - ) - - # checking if I have any axes below 0 - all_axes = [] - for _, val in axes_for_inputs.items(): - all_axes += val - min_ax = min(all_axes) - # moving all axes in case min_ax <0 , so everything starts from 0 - if min_ax < 0: - axes_for_inputs = dict( - (key, [v + abs(min_ax) for v in val]) - for (key, val) in axes_for_inputs.items() - ) - - # dimensions - ndim = len(set(all_axes)) - - return axes_for_inputs, ndim - def remove_inp_from_splitter_rpn(splitter_rpn, inputs_to_remove): """modifying splitter_rpn: removing inputs due to combining""" @@ -524,7 +257,7 @@ def rpn2splitter(splitter_rpn): # used in the Node to change names in a splitter and combiner -def change_combiner(combiner, name): +def add_name_combiner(combiner, name): combiner_changed = [] for comb in combiner: if "." 
not in comb: @@ -534,7 +267,7 @@ def change_combiner(combiner, name): return combiner_changed -def change_splitter(splitter, name): +def add_name_splitter(splitter, name): """changing names of splitter: adding names of the node""" if isinstance(splitter, str): return _add_name([splitter], name)[0] @@ -604,8 +337,7 @@ def input_shape(in1): return tuple(shape) -# dj: changing the function so it takes splitter_rpn -def _splits(splitter_rpn, inputs, inner_inputs=None): +def splits(splitter_rpn, inputs, inner_inputs=None): """ Process splitter rpn from left to right """ stack = [] @@ -724,7 +456,7 @@ def _splits(splitter_rpn, inputs, inner_inputs=None): # dj: TODO: do I need keys? -def _splits_groups(splitter_rpn, combiner=None, inner_inputs=None): +def splits_groups(splitter_rpn, combiner=None, inner_inputs=None): """ Process splitter rpn from left to right """ if not splitter_rpn: @@ -859,8 +591,6 @@ def _splits_groups(splitter_rpn, combiner=None, inner_inputs=None): def _single_op_splits( op_single, inputs, inner_inputs, shapes_var, previous_states_ind, keys_fromLeftSpl ): - import numpy as np - if op_single.startswith("_"): return ( previous_states_ind[op_single][0], @@ -954,9 +684,7 @@ def map_splits(split_iter, inputs): yield {k: list(flatten(ensure_list(inputs[k])))[v] for k, v in split.items()} -""" Functions for merging and completing splitters in states. - Used only in State, could be moved to that class -""" +# Functions for merging and completing splitters in states. def connect_splitters(splitter, other_states): diff --git a/pydra/engine/state.py b/pydra/engine/state.py index f733a4ef54..79a61ceb79 100644 --- a/pydra/engine/state.py +++ b/pydra/engine/state.py @@ -1,27 +1,76 @@ from copy import deepcopy -from . import auxiliary as aux +from . import helpers_state as hlpst from .specs import BaseSpec class State: + """ A class that specifies the state of a task; + it is only used when a task has a splitter. + It contains all information about the splitter, combiner, final splitter, + and input values for specific task states + (specified by the splitter and the input). + It also contains information about the final groups and the final splitter + if a combiner is available. + + Attributes: + name (str): name of the state that is the same as name of the task + splitter (str, tuple, list): can be a str (name of a single input), + tuple for scalar splitter, or list for outer splitter + splitter_rpn_compact (list): splitter in RPN notation, using a compact + notation for splitter from previous states, e.g. _NA + splitter_rpn (list): splitter represented in RPN (Reverse Polish Notation), + unwrapping splitters from previous states + combiner (list): list of fields that should be combined + (order is not important) + splitter_final: final splitter that includes the combining process + other_states (dict): used to create connections with previous states + {name of a previous state: + (previous state, input from the current state needed for the connection)} + inner_inputs (dict): used to create connections with previous states + {"{self.name}.input name for current inp": previous state} + states_ind (list(dict)): dictionary for every state that contains + indices for all state inputs (i.e. inputs that are part of the splitter) + states_val (list(dict)): dictionary for every state that contains + values for all state inputs (i.e. inputs that are part of the splitter) + inputs_ind (list(dict)): dictionary for every state that contains + indices for all task inputs (i.e. 
inputs that are relevant + for the current task, can be outputs from previous nodes) + group_for_inputs (dict): specifying groups (axes) for each input field + (depends on the splitter) + group_for_inputs_final (dict): specifying final groups (axes) + for each input field (depends on the splitter and combiner) + groups_stack_final (list): specifies the stack of groups/axes (used to + determine which field could be combined) + final_combined_ind_mapping (dict): mapping between final indices + after combining and partial indices of the results + """ + def __init__(self, name, splitter=None, combiner=None, other_states=None): + """ + :param name (str): name (should be the same as task name) + :param splitter (str, tuple or list): splitter of a task + :param combiner (str, list): field/fields used to combine results + :param other_states (dict): {name of a previous state: + (previous state, input from the current state needed for the connection)} + """ self.name = name - self.other_states = other_states + if not other_states: + self.other_states = {} + else: + self.other_states = other_states self.splitter = splitter self.connect_splitters() self.combiner = combiner - if not self.other_states: - self.other_states = {} self.inner_inputs = {} for name, (st, inp) in self.other_states.items(): - if f"_{st.name}" in self.splitter_rpn_nost: + if f"_{st.name}" in self.splitter_rpn_compact: self.inner_inputs[f"{self.name}.{inp}"] = st self.set_input_groups() self.set_splitter_final() self.states_val = [] self.inputs_ind = [] - self.final_groups_mapping = {} + self.final_combined_ind_mapping = {} def __str__(self): return f"State for {self.name} with a splitter: {self.splitter} and combiner: {self.combiner}" @@ -33,34 +82,36 @@ def splitter(self): @splitter.setter def splitter(self, splitter): if splitter: - self._splitter = aux.change_splitter(splitter, self.name) - self.splitter_rpn = aux.splitter2rpn( + self._splitter = hlpst.add_name_splitter(splitter, self.name) + self.splitter_rpn = hlpst.splitter2rpn( deepcopy(self._splitter), other_states=self.other_states ) - self.splitter_rpn_nost = aux.splitter2rpn( + self.splitter_rpn_compact = hlpst.splitter2rpn( deepcopy(self._splitter), other_states=self.other_states, state_fields=False, ) - for spl in self.splitter_rpn_nost: - if ( + # checking that all fields in splitter are either fields of current state, + # i.e. {self.name}.input + # or entire splitter from previous state, e.g. _NA + for spl in self.splitter_rpn_compact: + if not ( spl in [".", "*"] or spl.startswith("_") or spl.split(".")[0] == self.name ): - pass - else: raise Exception( "can't include {} in the splitter, consider using _{}".format( spl, spl.split(".")[0] ) ) + # splitter_final will take into account a combiner self.splitter_final = self._splitter else: self._splitter = None self.splitter_final = None self.splitter_rpn = [] - self.splitter_rpn_nost = [] + self.splitter_rpn_compact = [] @property def combiner(self): @@ -75,12 +126,14 @@ def combiner(self, combiner): combiner = [combiner] elif type(combiner) is not list: raise Exception("combiner should be a string or a list") - self._combiner = aux.change_combiner(combiner, self.name) + self._combiner = hlpst.add_name_combiner(combiner, self.name) if set(self._combiner) - set(self.splitter_rpn): raise Exception("all combiners should be in the splitter") + # combiners from the current fields: i.e. 
{self.name}.input self._right_combiner = [ comb for comb in self._combiner if self.name in comb ] + # combiners from the previous states self._left_combiner = list(set(self._combiner) - set(self._right_combiner)) else: self._combiner = [] @@ -90,44 +143,44 @@ def combiner(self, combiner): def connect_splitters(self): """ connect splitters from previous nodes, - evaluate Left (previous nodes) and Right (current node) parts + evaluate Left (the part from previous states) and Right (current state) parts """ if self.other_states: - self.splitter, self._left_splitter, self._right_splitter = aux.connect_splitters( + self.splitter, self._left_splitter, self._right_splitter = hlpst.connect_splitters( splitter=self.splitter, other_states=self.other_states ) # left rpn part, but keeping the names of the nodes, e.g. [_NA, _NB, *] - self._left_splitter_rpn_nost = aux.splitter2rpn( + self._left_splitter_rpn_compact = hlpst.splitter2rpn( deepcopy(self._left_splitter), other_states=self.other_states, state_fields=False, ) - self._left_splitter_rpn = aux.splitter2rpn( + self._left_splitter_rpn = hlpst.splitter2rpn( deepcopy(self._left_splitter), other_states=self.other_states ) else: # if other_states is empty there is only Right part self._left_splitter = None - self._left_splitter_rpn_nost = [] + self._left_splitter_rpn_compact = [] self._left_splitter_rpn = [] self._right_splitter = self.splitter - self._right_splitter_rpn = aux.splitter2rpn( + self._right_splitter_rpn = hlpst.splitter2rpn( deepcopy(self._right_splitter), other_states=self.other_states ) def set_splitter_final(self): """evaluate a final splitter after combining""" - _splitter_rpn_final = aux.remove_inp_from_splitter_rpn( + _splitter_rpn_final = hlpst.remove_inp_from_splitter_rpn( deepcopy(self.splitter_rpn), self.right_combiner_all + self.left_combiner_all, ) - self.splitter_final = aux.rpn2splitter(_splitter_rpn_final) - self.splitter_rpn_final = aux.splitter2rpn( + self.splitter_final = hlpst.rpn2splitter(_splitter_rpn_final) + self.splitter_rpn_final = hlpst.splitter2rpn( self.splitter_final, other_states=self.other_states ) def set_input_groups(self): """evaluate groups, especially the final groups that address the combiner""" - keys_f, group_for_inputs_f, groups_stack_f, combiner_all = aux._splits_groups( + keys_f, group_for_inputs_f, groups_stack_f, combiner_all = hlpst.splits_groups( self._right_splitter_rpn, combiner=self._right_combiner, inner_inputs=self.inner_inputs, @@ -158,16 +211,16 @@ def merge_previous_states(self): self.keys_final = [] self.left_combiner_all = [] if self._left_combiner: - _, _, _, self._left_combiner = aux._splits_groups( + _, _, _, self._left_combiner = hlpst.splits_groups( self._left_splitter_rpn, combiner=self._left_combiner ) - for i, left_nm in enumerate(self._left_splitter_rpn_nost): + for i, left_nm in enumerate(self._left_splitter_rpn_compact): if left_nm in ["*", "."]: continue if ( - i + 1 < len(self._left_splitter_rpn_nost) - and self._left_splitter_rpn_nost[i + 1] == "." + i + 1 < len(self._left_splitter_rpn_compact) + and self._left_splitter_rpn_compact[i + 1] == "." 
): last_gr = last_gr - 1 st = self.other_states[left_nm[1:]][0] @@ -178,7 +231,7 @@ def merge_previous_states(self): if st_combiner: # keys and groups from previous states # after taking into account combiner from current state - keys_f_st, group_for_inputs_f_st, groups_stack_f_st, combiner_all_st = aux._splits_groups( + keys_f_st, group_for_inputs_f_st, groups_stack_f_st, combiner_all_st = hlpst.splits_groups( st.splitter_rpn_final, combiner=st_combiner, inner_inputs=st.inner_inputs, @@ -235,7 +288,7 @@ def prepare_states(self, inputs): and state values (specific elements from inputs that can be used running interfaces) """ if isinstance(inputs, BaseSpec): - self.inputs = aux.inputs_types_to_dict(self.name, inputs) + self.inputs = hlpst.inputs_types_to_dict(self.name, inputs) else: self.inputs = inputs if self.other_states: @@ -249,58 +302,60 @@ def prepare_states(self, inputs): self.prepare_states_val() def prepare_states_ind(self): - """using aux._splits to calculate a list of dictionaries with state indices""" + """using hlpst.splits to calculate a list of dictionaries with state indices""" # removing elements that are connected to inner splitter - # (they will be taken into account in aux._splits anyway) + # (they will be taken into account in hlpst.splits anyway) # _comb part will be used in prepare_states_combined_ind elements_to_remove = [] elements_to_remove_comb = [] for name, (st, inp) in self.other_states.items(): if ( "{}.{}".format(self.name, inp) in self.splitter_rpn - and "_{}".format(name) in self.splitter_rpn_nost + and "_{}".format(name) in self.splitter_rpn_compact ): elements_to_remove.append("_{}".format(name)) if "{}.{}".format(self.name, inp) not in self.combiner: elements_to_remove_comb.append("_{}".format(name)) - partial_rpn = aux.remove_inp_from_splitter_rpn( - deepcopy(self.splitter_rpn_nost), elements_to_remove + partial_rpn = hlpst.remove_inp_from_splitter_rpn( + deepcopy(self.splitter_rpn_compact), elements_to_remove ) - values_out_pr, keys_out_pr, _, kL = aux._splits( + values_out_pr, keys_out_pr, _, kL = hlpst.splits( partial_rpn, self.inputs, inner_inputs=self.inner_inputs ) values_pr = list(values_out_pr) self.ind_l = values_pr self.keys = keys_out_pr - self.states_ind = list(aux.iter_splits(values_pr, self.keys)) + self.states_ind = list(hlpst.iter_splits(values_pr, self.keys)) self.keys_final = self.keys if self.combiner: self.prepare_states_combined_ind(elements_to_remove_comb) else: self.ind_l_final = self.ind_l self.keys_final = self.keys - self.final_groups_mapping = {i: [i] for i in range(len(self.states_ind))} + self.final_combined_ind_mapping = { + i: [i] for i in range(len(self.states_ind)) + } self.states_ind_final = self.states_ind return self.states_ind def prepare_states_combined_ind(self, elements_to_remove_comb): """preparing the final list of dictionaries with indices after combiner""" - partial_rpn_nost = aux.remove_inp_from_splitter_rpn( - deepcopy(self.splitter_rpn_nost), elements_to_remove_comb + partial_rpn_compact = hlpst.remove_inp_from_splitter_rpn( + deepcopy(self.splitter_rpn_compact), elements_to_remove_comb ) # combiner can have parts from the left splitter, so have to have rpn with states - partial_rpn = aux.splitter2rpn( - aux.rpn2splitter(partial_rpn_nost), other_states=self.other_states + partial_rpn = hlpst.splitter2rpn( + hlpst.rpn2splitter(partial_rpn_compact), other_states=self.other_states ) - combined_rpn = aux.remove_inp_from_splitter_rpn( + combined_rpn = hlpst.remove_inp_from_splitter_rpn( deepcopy(partial_rpn), 
self.right_combiner_all + self.left_combiner_all ) # TODO: create a function for this!! if combined_rpn: - val_r, key_r, _, _ = aux._splits( + val_r, key_r, _, _ = hlpst.splits( combined_rpn, self.inputs, inner_inputs=self.inner_inputs ) values = list(val_r) @@ -315,23 +370,27 @@ def prepare_states_combined_ind(self, elements_to_remove_comb): self.keys_final = keys_out # groups after combiner ind_map = { - tuple(aux.flatten(tup, max_depth=10)): ind + tuple(hlpst.flatten(tup, max_depth=10)): ind for ind, tup in enumerate(self.ind_l_final) } - self.final_groups_mapping = {i: [] for i in range(len(self.ind_l_final))} + self.final_combined_ind_mapping = { + i: [] for i in range(len(self.ind_l_final)) + } for ii, st in enumerate(self.states_ind): ind_f = tuple([st[k] for k in self.keys_final]) - self.final_groups_mapping[ind_map[ind_f]].append(ii) + self.final_combined_ind_mapping[ind_map[ind_f]].append(ii) else: self.ind_l_final = values self.keys_final = keys_out # should be 0 or None? - self.final_groups_mapping = {0: list(range(len(self.states_ind)))} - self.states_ind_final = list(aux.iter_splits(self.ind_l_final, self.keys_final)) + self.final_combined_ind_mapping = {0: list(range(len(self.states_ind)))} + self.states_ind_final = list( + hlpst.iter_splits(self.ind_l_final, self.keys_final) + ) def prepare_states_val(self): """evaluate states values having states indices""" - self.states_val = list(aux.map_splits(self.states_ind, self.inputs)) + self.states_val = list(hlpst.map_splits(self.states_ind, self.inputs)) return self.states_val def prepare_inputs(self): @@ -341,14 +400,14 @@ def prepare_inputs(self): # removing elements that come from connected states elements_to_remove = [ spl - for spl in self.splitter_rpn_nost + for spl in self.splitter_rpn_compact if spl[1:] in self.other_states.keys() ] - partial_rpn = aux.remove_inp_from_splitter_rpn( - deepcopy(self.splitter_rpn_nost), elements_to_remove + partial_rpn = hlpst.remove_inp_from_splitter_rpn( + deepcopy(self.splitter_rpn_compact), elements_to_remove ) if partial_rpn: - values_inp, keys_inp, _, _ = aux._splits( + values_inp, keys_inp, _, _ = hlpst.splits( partial_rpn, self.inputs, inner_inputs=self.inner_inputs ) inputs_ind = values_inp @@ -362,7 +421,7 @@ def prepare_inputs(self): keys_inp_prev = [] inputs_ind_prev = [] connected_to_inner = [] - for ii, el in enumerate(self._left_splitter_rpn_nost): + for ii, el in enumerate(self._left_splitter_rpn_compact): if el in ["*", "."]: continue st, inp = self.other_states[el[1:]] @@ -376,26 +435,26 @@ def prepare_inputs(self): st_ind = range(len(st.states_ind_final)) if inputs_ind_prev: # in case the Left part has scalar parts (not very well tested) - if self._left_splitter_rpn_nost[ii + 1] == ".": - inputs_ind_prev = aux.op["."](inputs_ind_prev, st_ind) + if self._left_splitter_rpn_compact[ii + 1] == ".": + inputs_ind_prev = hlpst.op["."](inputs_ind_prev, st_ind) else: - inputs_ind_prev = aux.op["*"](inputs_ind_prev, st_ind) + inputs_ind_prev = hlpst.op["*"](inputs_ind_prev, st_ind) else: - inputs_ind_prev = aux.op["*"](st_ind) + inputs_ind_prev = hlpst.op["*"](st_ind) keys_inp_prev += ["{}.{}".format(self.name, inp)] keys_inp = keys_inp_prev + keys_inp if inputs_ind and inputs_ind_prev: - inputs_ind = aux.op["*"](inputs_ind_prev, inputs_ind) + inputs_ind = hlpst.op["*"](inputs_ind_prev, inputs_ind) elif inputs_ind: - inputs_ind = aux.op["*"](inputs_ind) + inputs_ind = hlpst.op["*"](inputs_ind) elif inputs_ind_prev: - inputs_ind = aux.op["*"](inputs_ind_prev) + inputs_ind = 
hlpst.op["*"](inputs_ind_prev) else: inputs_ind = [] # iter_splits using inputs from current state/node - self.inputs_ind = list(aux.iter_splits(inputs_ind, keys_inp)) + self.inputs_ind = list(hlpst.iter_splits(inputs_ind, keys_inp)) # removing elements that are connected to inner splitter for el in connected_to_inner: [dict.pop(el) for dict in self.inputs_ind] diff --git a/pydra/engine/tests/test_auxiliary.py b/pydra/engine/tests/test_helpers_state.py similarity index 71% rename from pydra/engine/tests/test_auxiliary.py rename to pydra/engine/tests/test_helpers_state.py index ec850864bd..8b8e52a753 100644 --- a/pydra/engine/tests/test_auxiliary.py +++ b/pydra/engine/tests/test_helpers_state.py @@ -1,6 +1,5 @@ -from .. import auxiliary as aux +from .. import helpers_state as hlpst -import numpy as np import pytest @@ -46,8 +45,8 @@ def __init__( ], ) def test_splits_groups(splitter, keys_exp, groups_exp, grstack_exp): - splitter_rpn = aux.splitter2rpn(splitter) - keys_f, groups_f, grstack_f, _ = aux._splits_groups(splitter_rpn) + splitter_rpn = hlpst.splitter2rpn(splitter) + keys_f, groups_f, grstack_f, _ = hlpst.splits_groups(splitter_rpn) assert set(keys_f) == set(keys_exp) assert groups_f == groups_exp @@ -80,8 +79,8 @@ def test_splits_groups_comb( grstack_final_exp, combiner_all_exp, ): - splitter_rpn = aux.splitter2rpn(splitter) - keys_final, groups_final, grstack_final, combiner_all = aux._splits_groups( + splitter_rpn = hlpst.splitter2rpn(splitter) + keys_final, groups_final, grstack_final, combiner_all = hlpst.splits_groups( splitter_rpn, combiner ) assert keys_final == keys_final_exp @@ -274,12 +273,12 @@ def test_splits_1b(splitter, values, keys, splits): "z": [7, 8], "x": [[10, 100], [20, 200]], } - splitter_rpn = aux.splitter2rpn(splitter) - values_out, keys_out, _, _ = aux._splits(splitter_rpn, inputs) + splitter_rpn = hlpst.splitter2rpn(splitter) + values_out, keys_out, _, _ = hlpst.splits(splitter_rpn, inputs) value_list = list(values_out) assert keys == keys_out assert values == value_list - splits_out = list(aux.map_splits(aux.iter_splits(value_list, keys_out), inputs)) + splits_out = list(hlpst.map_splits(hlpst.iter_splits(value_list, keys_out), inputs)) assert splits_out == splits @@ -296,12 +295,12 @@ def test_splits_1b(splitter, values, keys, splits): ], ) def test_splits_1c(splitter, inputs, mismatch): - splitter_rpn = aux.splitter2rpn(splitter) + splitter_rpn = hlpst.splitter2rpn(splitter) if mismatch: with pytest.raises(ValueError): - aux._splits(splitter_rpn, inputs) + hlpst.splits(splitter_rpn, inputs) else: - aux._splits(splitter_rpn, inputs) + hlpst.splits(splitter_rpn, inputs) @pytest.mark.parametrize( @@ -335,13 +334,13 @@ def test_splits_1c(splitter, inputs, mismatch): ) def test_splits_1d(splitter, values, keys, shapes, splits): inputs = {"a": [1, 2], "v": ["a", "b"], "c": [[3, 4], [5, 6]]} - splitter_rpn = aux.splitter2rpn(splitter) - values_out, keys_out, shapes_out, _ = aux._splits(splitter_rpn, inputs) + splitter_rpn = hlpst.splitter2rpn(splitter) + values_out, keys_out, shapes_out, _ = hlpst.splits(splitter_rpn, inputs) value_list = list(values_out) assert keys == keys_out assert values == value_list assert shapes == shapes_out - splits_out = list(aux.map_splits(aux.iter_splits(value_list, keys_out), inputs)) + splits_out = list(hlpst.map_splits(hlpst.iter_splits(value_list, keys_out), inputs)) assert splits_out == splits @@ -371,12 +370,12 @@ def test_splits_1e(splitter, values, keys, splits): # dj?: not sure if I like that this example works # c - is 
like an inner splitter inputs = {"a": [1, 2], "v": ["a", "b"], "c": [[3, 4], 5]} - splitter_rpn = aux.splitter2rpn(splitter) - values_out, keys_out, _, _ = aux._splits(splitter_rpn, inputs) + splitter_rpn = hlpst.splitter2rpn(splitter) + values_out, keys_out, _, _ = hlpst.splits(splitter_rpn, inputs) value_list = list(values_out) assert keys == keys_out assert values == value_list - splits_out = list(aux.map_splits(aux.iter_splits(value_list, keys_out), inputs)) + splits_out = list(hlpst.map_splits(hlpst.iter_splits(value_list, keys_out), inputs)) assert splits_out == splits @@ -413,12 +412,12 @@ def test_splits_2(splitter_rpn, inner_inputs, values, keys, splits): [["d211", "d212"], ["d221", "d222"]], ], } - values_out, keys_out, _, _ = aux._splits( + values_out, keys_out, _, _ = hlpst.splits( splitter_rpn, inputs, inner_inputs=inner_inputs ) value_list = list(values_out) assert keys == keys_out - splits_out = list(aux.map_splits(aux.iter_splits(value_list, keys_out), inputs)) + splits_out = list(hlpst.map_splits(hlpst.iter_splits(value_list, keys_out), inputs)) assert splits_out == splits @@ -438,7 +437,7 @@ def test_splits_2(splitter_rpn, inner_inputs, values, keys, splits): ], ) def test_splitter2rpn(splitter, rpn): - assert aux.splitter2rpn(splitter) == rpn + assert hlpst.splitter2rpn(splitter) == rpn @pytest.mark.parametrize( @@ -451,7 +450,7 @@ def test_splitter2rpn(splitter, rpn): ], ) def test_splitter2rpn_2(splitter, rpn): - assert aux.splitter2rpn(splitter) == rpn + assert hlpst.splitter2rpn(splitter) == rpn @pytest.mark.parametrize( @@ -470,7 +469,7 @@ def test_splitter2rpn_2(splitter, rpn): ], ) def test_rpn2splitter(splitter, rpn): - assert aux.rpn2splitter(rpn) == splitter + assert hlpst.rpn2splitter(rpn) == splitter @pytest.mark.parametrize( @@ -494,7 +493,7 @@ def test_rpn2splitter(splitter, rpn): ], ) def test_splitter2rpn_wf_splitter_1(splitter, other_states, rpn): - assert aux.splitter2rpn(splitter, other_states=other_states) == rpn + assert hlpst.splitter2rpn(splitter, other_states=other_states) == rpn @pytest.mark.parametrize( @@ -519,7 +518,8 @@ def test_splitter2rpn_wf_splitter_1(splitter, other_states, rpn): ) def test_splitter2rpn_wf_splitter_3(splitter, other_states, rpn): assert ( - aux.splitter2rpn(splitter, other_states=other_states, state_fields=False) == rpn + hlpst.splitter2rpn(splitter, other_states=other_states, state_fields=False) + == rpn ) @@ -531,185 +531,8 @@ def test_splitter2rpn_wf_splitter_3(splitter, other_states, rpn): (("a", ["b", "c"]), ("Node.a", ["Node.b", "Node.c"])), ], ) -def test_change_splitter(splitter, splitter_changed): - assert aux.change_splitter(splitter, "Node") == splitter_changed - - -@pytest.mark.parametrize( - "inputs, rpn, expected", - [ - ({"a": np.array([1, 2])}, ["a"], {"a": [0]}), - ( - {"a": np.array([1, 2]), "b": np.array([3, 4])}, - ["a", "b", "."], - {"a": [0], "b": [0]}, - ), - ( - {"a": np.array([1, 2]), "b": np.array([3, 4, 1])}, - ["a", "b", "*"], - {"a": [0], "b": [1]}, - ), - ( - {"a": np.array([1, 2]), "b": np.array([3, 4]), "c": np.array([1, 2, 3])}, - ["a", "b", ".", "c", "*"], - {"a": [0], "b": [0], "c": [1]}, - ), - ( - {"a": np.array([1, 2]), "b": np.array([3, 4]), "c": np.array([1, 2, 3])}, - ["c", "a", "b", ".", "*"], - {"a": [1], "b": [1], "c": [0]}, - ), - ( - { - "a": np.array([[1, 2], [1, 2]]), - "b": np.array([[3, 4], [3, 3]]), - "c": np.array([1, 2, 3]), - }, - ["a", "b", ".", "c", "*"], - {"a": [0, 1], "b": [0, 1], "c": [2]}, - ), - ( - { - "a": np.array([[1, 2], [1, 2]]), - "b": np.array([[3, 4], 
[3, 3]]), - "c": np.array([1, 2, 3]), - }, - ["c", "a", "b", ".", "*"], - {"a": [1, 2], "b": [1, 2], "c": [0]}, - ), - ( - { - "a": np.array([1, 2]), - "b": np.array([3, 3]), - "c": np.array([[1, 2], [3, 4]]), - }, - ["a", "b", "*", "c", "."], - {"a": [0], "b": [1], "c": [0, 1]}, - ), - ( - { - "a": np.array([1, 2]), - "b": np.array([3, 4, 5]), - "c": np.array([1, 2]), - "d": np.array([1, 2, 3]), - }, - ["a", "b", "*", "c", "d", "*", "."], - {"a": [0], "b": [1], "c": [0], "d": [1]}, - ), - ( - { - "a": np.array([1, 2]), - "b": np.array([3, 4]), - "c": np.array([1, 2, 3]), - "d": np.array([1, 2, 3]), - }, - ["a", "b", ".", "c", "d", ".", "*"], - {"a": [0], "b": [0], "c": [1], "d": [1]}, - ), - ], -) -def test_splitting_axis(inputs, rpn, expected): - res = aux.splitting_axis(inputs, rpn)[0] - print(res) - for key in inputs.keys(): - assert res[key] == expected[key] - - -def test_splitting_axis_error(): - with pytest.raises(Exception): - aux.splitting_axis( - {"a": np.array([1, 2]), "b": np.array([3, 4, 5])}, ["a", "b", "."] - ) - - -@pytest.mark.parametrize( - "inputs, axis_inputs, ndim, expected", - [ - ({"a": np.array([1, 2])}, {"a": [0]}, 1, [["a"]]), - ( - {"a": np.array([1, 2]), "b": np.array([3, 4])}, - {"a": [0], "b": [0]}, - 1, - [["a", "b"]], - ), - ( - {"a": np.array([1, 2]), "b": np.array([3, 4, 1])}, - {"a": [0], "b": [1]}, - 2, - [["a"], ["b"]], - ), - ( - {"a": np.array([1, 2]), "b": np.array([3, 4]), "c": np.array([1, 2, 3])}, - {"a": [0], "b": [0], "c": [1]}, - 2, - [["a", "b"], ["c"]], - ), - ( - {"a": np.array([1, 2]), "b": np.array([3, 4]), "c": np.array([1, 2, 3])}, - {"a": [1], "b": [1], "c": [0]}, - 2, - [["c"], ["a", "b"]], - ), - ( - { - "a": np.array([[1, 2], [1, 2]]), - "b": np.array([[3, 4], [3, 3]]), - "c": np.array([1, 2, 3]), - }, - {"a": [0, 1], "b": [0, 1], "c": [2]}, - 3, - [["a", "b"], ["a", "b"], ["c"]], - ), - ( - { - "a": np.array([[1, 2], [1, 2]]), - "b": np.array([[3, 4], [3, 3]]), - "c": np.array([1, 2, 3]), - }, - {"a": [1, 2], "b": [1, 2], "c": [0]}, - 3, - [["c"], ["a", "b"], ["a", "b"]], - ), - ], -) -def test_converting_axis2input(inputs, axis_inputs, ndim, expected): - assert ( - aux.converting_axis2input( - state_inputs=inputs, axis_for_input=axis_inputs, ndim=ndim - )[0] - == expected - ) - - -@pytest.mark.parametrize( - "rpn, expected, ndim", - [ - (["a"], {"a": [0]}, 1), - (["a", "b", "."], {"a": [0], "b": [0]}, 1), - (["a", "b", "*"], {"a": [0], "b": [1]}, 2), - (["a", "b", ".", "c", "*"], {"a": [0], "b": [0], "c": [1]}, 2), - (["c", "a", "b", ".", "*"], {"a": [1], "b": [1], "c": [0]}, 2), - (["a", "b", ".", "c", "*"], {"a": [0], "b": [0], "c": [1]}, 2), - (["c", "a", "b", ".", "*"], {"a": [1], "b": [1], "c": [0]}, 2), - (["a", "b", "*", "c", "."], {"a": [0], "b": [1], "c": [0, 1]}, 2), - ( - ["a", "b", "*", "c", "d", "*", "."], - {"a": [0], "b": [1], "c": [0], "d": [1]}, - 2, - ), - ( - ["a", "b", ".", "c", "d", ".", "*"], - {"a": [0], "b": [0], "c": [1], "d": [1]}, - 2, - ), - ], -) -def test_matching_input_from_splitter(rpn, expected, ndim): - res = aux.matching_input_from_splitter(rpn) - print(res) - for key in expected.keys(): - assert res[0][key] == expected[key] - assert res[1] == ndim +def test_addname_splitter(splitter, splitter_changed): + assert hlpst.add_name_splitter(splitter, "Node") == splitter_changed @pytest.mark.parametrize( @@ -732,7 +555,7 @@ def test_remove_inp_from_splitter_rpn( splitter_rpn, input_to_remove, final_splitter_rpn ): assert ( - aux.remove_inp_from_splitter_rpn(splitter_rpn, input_to_remove) + 
hlpst.remove_inp_from_splitter_rpn(splitter_rpn, input_to_remove) == final_splitter_rpn ) @@ -745,7 +568,7 @@ def test_remove_inp_from_splitter_rpn( ], ) def test_groups_to_input(group_for_inputs, input_for_groups, ndim): - res = aux.converter_groups_to_input(group_for_inputs) + res = hlpst.converter_groups_to_input(group_for_inputs) assert res[0] == input_for_groups assert res[1] == ndim @@ -809,7 +632,7 @@ def test_groups_to_input(group_for_inputs, input_for_groups, ndim): def test_connect_splitters( splitter, other_states, expected_splitter, expected_left, expected_right ): - updated_splitter, left_splitter, right_splitter = aux.connect_splitters( + updated_splitter, left_splitter, right_splitter = hlpst.connect_splitters( splitter, other_states ) assert updated_splitter == expected_splitter @@ -834,18 +657,4 @@ def test_connect_splitters( ) def test_connect_splitters_exception(splitter, other_states): with pytest.raises(Exception): - aux.connect_splitters(splitter, other_states) - - -@pytest.mark.parametrize( - "group_for_inputs, groups_stack, groups_stack_input_exp", - [ - ({"a": 0, "b": 1}, [[0, 1]], [["a", "b"]]), - ({"a": 0, "b": 1, "c": [0, 1]}, [[0, 1]], [["a", "b", "c"]]), - ({"a": 0, "b": 1}, [[0], [1]], [["a"], ["b"]]), - ], -) -def test_groups_stack_input(group_for_inputs, groups_stack, groups_stack_input_exp): - groups_stack_input = aux.groups_stack_input(group_for_inputs, groups_stack) - for i, grs in enumerate(groups_stack_input): - assert set(grs) == set(groups_stack_input_exp[i]) + hlpst.connect_splitters(splitter, other_states) diff --git a/pydra/engine/tests/test_state.py b/pydra/engine/tests/test_state.py index e1d09f79ca..41e071382a 100644 --- a/pydra/engine/tests/test_state.py +++ b/pydra/engine/tests/test_state.py @@ -1023,12 +1023,12 @@ def test_state_connect_combine_1(): ] assert st1.states_ind_final == [{"NA.b": 0}, {"NA.b": 1}] assert st1.keys_final == ["NA.b"] - assert st1.final_groups_mapping == {0: [0, 2], 1: [1, 3]} + assert st1.final_combined_ind_mapping == {0: [0, 2], 1: [1, 3]} assert st2.states_ind == [{"NA.b": 0}, {"NA.b": 1}] assert st2.states_val == [{"NA.b": 10}, {"NA.b": 20}] assert st2.keys_final == ["NA.b"] - assert st2.final_groups_mapping == {0: [0], 1: [1]} + assert st2.final_combined_ind_mapping == {0: [0], 1: [1]} st2.prepare_inputs() assert st2.inputs_ind == [{"NB.c": 0}, {"NB.c": 1}] @@ -1070,7 +1070,7 @@ def test_state_connect_combine_2(): ] assert st1.states_ind_final == [{"NA.b": 0}, {"NA.b": 1}] assert st1.keys_final == ["NA.b"] - assert st1.final_groups_mapping == {0: [0, 2], 1: [1, 3]} + assert st1.final_combined_ind_mapping == {0: [0, 2], 1: [1, 3]} assert st2.states_ind == [ {"NA.b": 0, "NB.d": 0}, @@ -1085,7 +1085,7 @@ def test_state_connect_combine_2(): {"NA.b": 20, "NB.d": 1}, ] assert st2.keys_final == ["NA.b", "NB.d"] - assert st2.final_groups_mapping == {0: [0], 1: [1], 2: [2], 3: [3]} + assert st2.final_combined_ind_mapping == {0: [0], 1: [1], 2: [2], 3: [3]} st2.prepare_inputs() assert st2.inputs_ind == [ @@ -1134,7 +1134,7 @@ def test_state_connect_combine_3(): ] assert st1.states_ind_final == [{"NA.b": 0}, {"NA.b": 1}] assert st1.keys_final == ["NA.b"] - assert st1.final_groups_mapping == {0: [0, 2], 1: [1, 3]} + assert st1.final_combined_ind_mapping == {0: [0, 2], 1: [1, 3]} assert st2.states_ind == [ {"NA.b": 0, "NB.d": 0}, @@ -1150,7 +1150,7 @@ def test_state_connect_combine_3(): ] assert st2.states_ind_final == [{"NA.b": 0}, {"NA.b": 1}] assert st2.keys_final == ["NA.b"] - assert st2.final_groups_mapping == {0: [0, 
1], 1: [2, 3]} + assert st2.final_combined_ind_mapping == {0: [0, 1], 1: [2, 3]} st2.prepare_inputs() assert st2.inputs_ind == [ diff --git a/pydra/engine/tests/test_workflow.py b/pydra/engine/tests/test_workflow.py index c74fe29625..e45a4412e4 100644 --- a/pydra/engine/tests/test_workflow.py +++ b/pydra/engine/tests/test_workflow.py @@ -1397,7 +1397,7 @@ def test_wf_nostate_cachelocations(plugin, tmpdir): # checking execution time assert t1 > 3 - assert t2 < 0.1 + assert t2 < 0.5 # checking if the second wf didn't run again assert wf1.output_dir.exists() @@ -1456,7 +1456,7 @@ def test_wf_state_cachelocations(plugin, tmpdir): # checking execution time assert t1 > 3 - assert t2 < 0.1 + assert t2 < 0.5 # checking all directories assert wf1.output_dir @@ -1524,7 +1524,7 @@ def test_wf_state_cachelocations_updateinp(plugin, tmpdir): # checking execution time assert t1 > 3 - assert t2 < 0.3 + assert t2 < 0.5 # checking all directories assert wf1.output_dir @@ -1755,7 +1755,7 @@ def test_wf_ndstate_cachelocations(plugin, tmpdir): # checking execution time assert t1 > 3 - assert t2 < 0.1 + assert t2 < 0.5 # checking all directories assert wf1.output_dir.exists() @@ -1817,7 +1817,7 @@ def test_wf_ndstate_cachelocations_updatespl(plugin, tmpdir): # checking execution time assert t1 > 3 - assert t2 < 0.1 + assert t2 < 0.5 # checking all directories assert wf1.output_dir.exists()
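
Illustration only (not part of the patch): a minimal sketch of how the renamed helpers behave. The add_name_splitter pair is taken from the test parametrization above; the splitter2rpn output is inferred from the splitter conventions described in the State docstring and should be checked against test_splitter2rpn.

from pydra.engine import helpers_state as hlpst

# add_name_splitter prefixes every splitter field with the task name
assert hlpst.add_name_splitter(("a", ["b", "c"]), "Node") == ("Node.a", ["Node.b", "Node.c"])

# splitter2rpn converts the nested splitter to Reverse Polish Notation,
# where "." marks a scalar split and "*" an outer split
print(hlpst.splitter2rpn(("Node.a", ["Node.b", "Node.c"])))
# expected (unverified): ['Node.a', 'Node.b', 'Node.c', '*', '.']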