Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 48 additions & 22 deletions pydra/engine/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def __init__(
for field in dc.fields(klass)
if field.name not in ["_func", "_graph_checksums"]
]
# dictionary to save the connections with lazy fields
self.inp_lf = {}
self.state = None
self._output = {}
self._result = {}
Expand Down Expand Up @@ -124,8 +126,6 @@ def __init__(
self.allow_cache_override = True
self._checksum = None

# dictionary of results from tasks
self.results_dict = {}
self.plugin = None
self.hooks = TaskHook()

Expand Down Expand Up @@ -165,6 +165,10 @@ def version(self):
def checksum(self):
"""calculating checksum
"""
# if checksum is called before run, _graph_checksums is not ready yet
if is_workflow(self) and self.inputs._graph_checksums is None:
self.inputs._graph_checksums = [nd.checksum for nd in self.graph_sorted]

input_hash = self.inputs.hash
if self.state is None:
self._checksum = create_checksum(self.__class__.__name__, input_hash)
Expand All @@ -176,6 +180,30 @@ def checksum(self):
)
return self._checksum

def checksum_states(self, state_index=None):
    """Calculate the checksum for one specific state or for all of the states.

    For a single state, the inputs are deep-copied and every field listed in
    ``self.state.inputs_ind[state_index]`` is replaced by the element of its
    list selected for that state; the checksum is then created from the
    class name and the hash of the modified copy.

    Can be used only for tasks with a state.

    Parameters
    ----------
    state_index : int, optional
        Index of the state to calculate the checksum for. If ``None``,
        checksums for all of the states are calculated.

    Returns
    -------
    The checksum of the selected state, or a list with one checksum per
    entry of ``self.state.inputs_ind`` when ``state_index`` is ``None``.

    Raises
    ------
    Exception
        If the task has no state (no splitter was used).
    """
    # Guarding both paths: without this check the "all states" path fails
    # with an opaque AttributeError on ``None.inputs_ind``.
    if self.state is None:
        if state_index is not None:
            raise Exception("can't use state_index if no splitter is used")
        raise Exception("checksum_states can be used only for tasks with a state")

    if state_index is None:
        return [
            self.checksum_states(state_index=ind)
            for ind in range(len(self.state.inputs_ind))
        ]

    # working on a copy, so the original (list) inputs stay untouched
    inputs_copy = deepcopy(self.inputs)
    for key, ind in self.state.inputs_ind[state_index].items():
        # the second dot-separated component of the key is used as the field
        # name (presumably keys look like "<task name>.<field name>" - confirm)
        field_name = key.split(".")[1]
        setattr(inputs_copy, field_name, getattr(inputs_copy, field_name)[ind])
    return create_checksum(self.__class__.__name__, inputs_copy.hash)

def set_state(self, splitter, combiner=None):
if splitter is not None:
self.state = state.State(
Expand Down Expand Up @@ -226,14 +254,7 @@ def cache_locations(self, locations):
@property
def output_dir(self):
if self.state:
if self.results_dict:
return [
self._cache_dir / res[1] for (_, res) in self.results_dict.items()
]
else:
raise Exception(
f"output_dir not available, will be ready after running {self.name}"
)
return [self._cache_dir / checksum for checksum in self.checksum_states()]
else:
return self._cache_dir / self.checksum

Expand Down Expand Up @@ -399,7 +420,7 @@ def _combined_output(self):
for (gr, ind_l) in self.state.final_groups_mapping.items():
combined_results.append([])
for ind in ind_l:
result = load_result(self.results_dict[ind][1], self.cache_locations)
result = load_result(self.checksum_states(ind), self.cache_locations)
if result is None:
return None
combined_results[gr].append(result)
Expand All @@ -419,10 +440,8 @@ def result(self, state_index=None):
return self._combined_output()
else:
results = []
for (ii, val) in enumerate(self.state.states_val):
result = load_result(
self.results_dict[ii][1], self.cache_locations
)
for checksum in self.checksum_states():
result = load_result(checksum, self.cache_locations)
if result is None:
return None
results.append(result)
Expand All @@ -431,19 +450,25 @@ def result(self, state_index=None):
if self.state.combiner:
return self._combined_output()[state_index]
result = load_result(
self.results_dict[state_index][1], self.cache_locations
self.checksum_states(state_index), self.cache_locations
)
return result
else:
if state_index is not None:
raise ValueError("Task does not have a state")
if self.results_dict:
checksum = self.results_dict[None][1]
else:
checksum = self.checksum
checksum = self.checksum
result = load_result(checksum, self.cache_locations)
return result

def _reset(self):
    """Put the saved LazyField connections back onto the inputs.

    Every input field whose name is a key of ``self.inp_lf`` is set back to
    the value stored there. For a workflow, the same reset is then applied
    to each node of its graph.
    """
    saved_connections = self.inp_lf
    for name in (fld.name for fld in dc.fields(self.inputs)):
        if name not in saved_connections:
            continue
        setattr(self.inputs, name, saved_connections[name])

    if not is_workflow(self):
        return
    # a workflow also resets all of the nodes it contains
    for node in self.graph.nodes:
        node._reset()


class Workflow(TaskBase):
def __init__(
Expand Down Expand Up @@ -534,6 +559,8 @@ def create_connections(self, task):
for field in dc.fields(task.inputs):
val = getattr(task.inputs, field.name)
if isinstance(val, LazyField):
# saving all connections with LazyFields
task.inp_lf[field.name] = val
# adding an edge to the graph if the task is expecting output from a different task
if val.name != self.name:
# checking if the connection is already in the graph
Expand All @@ -558,7 +585,7 @@ def create_connections(self, task):
task.state = state.State(task.name, other_states=other_states)

async def _run(self, submitter=None, **kwargs):
self.inputs = dc.replace(self.inputs, **kwargs)
# self.inputs = dc.replace(self.inputs, **kwargs) don't need it?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question as before.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we do not allow running a workflow without a submitter, so I believe this can't be used

checksum = self.checksum
lockfile = self.cache_dir / (checksum + ".lock")
# Eagerly retrieve cached
Expand Down Expand Up @@ -610,7 +637,6 @@ async def _run_task(self, submitter):
if not submitter:
raise Exception("Submitter should already be set.")
# at this point Workflow is stateless so this should be fine
self.results_dict[None] = (None, self.checksum)
await submitter._run_workflow(self)

def set_output(self, connections):
Expand Down
1 change: 1 addition & 0 deletions pydra/engine/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, name, splitter=None, combiner=None, other_states=None):
self.set_input_groups()
self.set_splitter_final()
self.states_val = []
self.inputs_ind = []
self.final_groups_mapping = {}

def __str__(self):
Expand Down
6 changes: 4 additions & 2 deletions pydra/engine/submitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,15 @@ def __call__(self, runnable, cache_locations=None):
self.loop.run_until_complete(self.submit_workflow(runnable))
else:
self.loop.run_until_complete(self.submit(runnable, wait=True))
if is_workflow(runnable):
# resetting all connections with LazyFields
runnable._reset()
return runnable.result()

async def submit_workflow(self, workflow):
"""Distributes or initiates workflow execution"""
if workflow.plugin and workflow.plugin != self.plugin:
# dj: this is not tested!!!
await self.worker.run_el(workflow)
else:
await workflow._run(self)
Expand Down Expand Up @@ -81,8 +85,6 @@ async def submit(self, runnable, wait=False):
)
for sidx in range(len(runnable.state.states_val)):
job = runnable.to_job(sidx)
job.results_dict[None] = (sidx, job.checksum)
runnable.results_dict[sidx] = (None, job.checksum)
logger.debug(
f'Submitting runnable {job}{str(sidx) if sidx is not None else ""}'
)
Expand Down
24 changes: 13 additions & 11 deletions pydra/engine/tests/test_node_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,17 +160,6 @@ def test_odir_init():
assert nn.output_dir


def test_odir_init_error():
    """Check that output_dir raises an error for a task with a state
    when the task has no result yet (i.e. before running).
    """
    # NOTE(review): indentation was lost in the source view; the final
    # assertion is assumed to sit outside the ``with`` block - confirm.
    task = fun_addtwo(name="NA").split(splitter="a", a=[3, 5])

    with pytest.raises(Exception) as excinfo:
        assert task.output_dir
    message = str(excinfo.value)
    assert "output_dir not available" in message


# Tests for tasks without state (i.e. no splitter)


Expand Down Expand Up @@ -224,6 +213,19 @@ def test_task_nostate_1_call_plug(plugin):
assert nn.output_dir.exists()


def test_task_nostate_1_call_updateinp():
    """Task without a splitter whose input is updated when it is called.

    The task is created with ``a=30`` but called with ``a=3``; the stored
    result is expected to have ``out == 5``.
    """
    task = fun_addtwo(name="NA", a=30)
    # the input is updated at call time
    task(a=3)

    # the result has to reflect the updated input
    outcome = task.result()
    assert outcome.output.out == 5
    # the output directory is expected to exist after the call
    assert task.output_dir.exists()


@pytest.mark.parametrize("plugin", Plugins)
def test_task_nostate_2(plugin):
""" task with a list as an input, but no splitter"""
Expand Down
Loading