In [1]:
# %load_ext autoreload
# %autoreload 2
from pathlib import Path

from plantseg.workflows import WorkflowHandler
from plantseg.workflows.io_tasks import (
    export_image_task,
    gaussian_smoothing_task,
    import_image_task,
    mock_task1,
    mock_task2,
)
from plantseg.workflows.workflow_handler import workflow_handler

In [2]:
# Display the `__plantseg_task__` attribute carried by the imported task function.
import_image_task.__plantseg_task__

'import_image_task'

{'__workflow_func__': 'import_image_task'}

In [2]:
# Presumably resets the recorded graph so that re-running this cell does not stack
# tasks on top of a previous run — TODO confirm what clean_dag() clears.
workflow_handler.clean_dag()
# NOTE(review): absolute, user-specific path; this cell only runs on a machine that has this file.
path = Path("/home/lcerrone/data/ovule_sample.h5")

# Read two entries from the same file: key "raw" as a raw image, key "label" as a label image.
image = import_image_task(input_path=path, key="raw", image_name="ovule_sample", image_type="raw", stack_layout="ZXY")
label = import_image_task(
    input_path=path, key="label", image_name="ovule_sample_label", image_type="label", stack_layout="ZXY"
)

# Mock tasks chained on the imported images; image3 and image4 are not used again in this cell.
image2, image3 = mock_task1(image=image)
image4 = mock_task2(image=image2, image2=label)
# Gaussian smoothing is applied to the raw image (not to the mock outputs).
smooth_image = gaussian_smoothing_task(image=image, sigma=1.0)

# Export the smoothed image into the directory of the input file, using the
# image's own name as the output file name.
export_image_task(
    image=smooth_image,
    output_directory=path.parent,
    output_file_name=smooth_image.name,
    custom_key="smoothed",
    scale_to_origin=True,
    file_format="tiff",
    dtype="uint16",
)

In [3]:
# Write the workflow held by workflow_handler to "dag.yaml" (relative to the current
# working directory); the SerialRunner cell further down reads this same file name.
workflow_handler.save_to_yaml("dag.yaml")

In [4]:
# Print every task currently stored in the handler's DAG, one per line.
for task in workflow_handler.dag.list_tasks:
    print(task)

func='import_image_workflow' images_inputs={'input_path': 'input_path_0'} parameters={'key': 'raw', 'image_name': 'ovule_sample', 'image_type': 'raw', 'stack_layout': 'ZXY'} list_private_parameters=['image_type', 'stack_layout'] outputs=['ovule_sample'] node_type=<NodeType.ROOT: 'root'> id=UUID('ddf0e828-47bd-413b-a721-59a0cc8ad2a7')
func='gaussian_smoothing_workflow' images_inputs={'image': 'ovule_sample'} parameters={'sigma': 1.0} list_private_parameters=[] outputs=['ovule_sample_smoothed'] node_type=<NodeType.NODE: 'node'> id=UUID('87901597-7dc8-4736-82de-4d9fdf3a18ad')
func='export_image_workflow' images_inputs={'image': 'ovule_sample_smoothed', 'output_directory': 'output_directory_0', 'output_file_name': 'output_file_name_0'} parameters={'custom_key': 'smoothed', 'scale_to_origin': True, 'file_format': 'tiff', 'dtype': 'uint16'} list_private_parameters=[] outputs=[] node_type=<NodeType.LEAF: 'leaf'> id=UUID('2724af3c-f670-4e01-b2a7-f03653b8a966')


In [5]:
from plantseg.workflows.workflow_handler import DAG, Task


class SerialRunner:
    def __init__(self, dag_path: str | Path):
        if isinstance(dag_path, str):
            dag_path = Path(dag_path)

        self.dag_path = dag_path

        if not dag_path.exists():
            raise FileNotFoundError(f"File {dag_path} not found")

        self.func_registry = WorkflowHandler().func_registry
        print(self.func_registry.list_funcs())

    def find_next_task(self, dag: DAG, var_set: set[str]):
        for task in dag.list_tasks:
            required_inputs = set(task.images_inputs.values())
            if required_inputs.issubset(var_set):
                dag.list_tasks.remove(task)
                return task
        return None

    def run_task(self, task: Task, var_space: dict):
        # Get inputs from var_space
        inputs = {}
        for name, image_name in task.images_inputs.items():
            inputs[name] = var_space[image_name]

        # run the task
        func = self.func_registry.get_func(task.func)
        outputs = func(**inputs, **task.parameters)

        # Save outputs in var_space
        for i, name in enumerate(task.outputs):
            if isinstance(outputs, tuple):
                var_space[name] = outputs[i]
            else:
                var_space[name] = outputs

        return var_space

    def clean_var_space(self, dag: DAG, var_space: dict):
        all_remaining_required_inputs = set()
        for task in dag.list_tasks:
            required_inputs = set(task.images_inputs.values())
            all_remaining_required_inputs = all_remaining_required_inputs.union(required_inputs)

        list_key_to_delete = []
        for var in var_space.keys():
            if var not in all_remaining_required_inputs:
                list_key_to_delete.append(var)

        for key in list_key_to_delete:
            del var_space[key]
        return var_space

    def _parse_input(self, inputs: dict[str, str] | list[dict[str, str]]) -> list[dict]:
        if isinstance(inputs, dict):
            inputs = [inputs]

        return inputs

    def run(self, inputs: dict[str, str]):
        dag = WorkflowHandler().from_yaml(self.dag_path)._dag

        var_space = {}
        for key in dag.list_inputs:
            if key not in inputs:
                raise ValueError(f"Missing input variable {key}")
            var_space[key] = inputs[key]

        while dag.list_tasks:
            # Find next task to run
            next_task = self.find_next_task(dag, set(var_space.keys()))
            if next_task is None:
                raise ValueError("No task to run next, the computation graph might be corrupted")

            # Run the task
            var_space = self.run_task(next_task, var_space)

            # Remove from var_space the variables that are not needed anymore
            var_space = self.clean_var_space(dag, var_space)

        if var_space:
            raise ValueError("Some variables are still in the var_space, the computation graph might be corrupted")
        return True


# NOTE(review): absolute, user-specific path (the same file used when building the workflow).
path = Path("/home/lcerrone/data/ovule_sample.h5")

# Reads the file written by workflow_handler.save_to_yaml("dag.yaml").
runner = SerialRunner("dag.yaml")
# Concrete values for the DAG's input names; run() raises ValueError if a name
# listed in the DAG's list_inputs is missing from this dict.
var_space = {
    'input_path_0': path,
    'input_path_1': path,
    'output_directory_0': path.parent,
    'output_file_name_0': "test2",
}

runner.run(inputs=var_space)

['import_image_workflow', 'export_image_workflow', 'gaussian_smoothing_workflow', 'mock_task1', 'mock_task2']


True

In [13]:
# NOTE(review): this cell looks like an earlier prototype of the loop in SerialRunner.run.
# It uses `yaml`, `find_next_task`, `schedule_task`, `task_handler` and `clean_var_space`,
# none of which are imported or defined in the visible part of this notebook, so it
# presumably fails on a fresh kernel — confirm, then remove it or port it to SerialRunner.
with open("dag.yaml", "r") as f:
    # NOTE(review): yaml.safe_load is the safer choice if this file could ever come from
    # an untrusted source; here it is a file the notebook wrote itself.
    dag_dict = yaml.load(f, Loader=yaml.FullLoader)

dag = DAG(**dag_dict)

# NOTE(review): these key names ('path_0', 'directory_0') differ from the
# 'input_path_0' / 'output_directory_0' names used with SerialRunner — presumably an
# older naming scheme; verify against the current dag.yaml.
var_space = {
    'path_0': path,
    'path_1': path,
    'directory_0': path.parent,
}

while dag.list_tasks:
    next_task = find_next_task(dag, set(var_space.keys()))
    # Stops silently when no pending task has all of its inputs available.
    if next_task is None:
        break

    print(next_task.func)
    print("before task", var_space)
    var_space = schedule_task(next_task, var_space, task_handler._funcs)
    var_space = clean_var_space(dag, var_space)
    print("after task", var_space)

import_image_workflow
before task {'path_0': PosixPath('/home/lcerrone/data/ovule_sample.h5'), 'path_1': PosixPath('/home/lcerrone/data/ovule_sample.h5'), 'directory_0': PosixPath('/home/lcerrone/data')}
after task {'path_0': PosixPath('/home/lcerrone/data/ovule_sample.h5'), 'directory_0': PosixPath('/home/lcerrone/data'), 'ovule_sample': <plantseg.image.Image object at 0x7fa524bf67d0>}
import_image_workflow
before task {'path_0': PosixPath('/home/lcerrone/data/ovule_sample.h5'), 'directory_0': PosixPath('/home/lcerrone/data'), 'ovule_sample': <plantseg.image.Image object at 0x7fa524bf67d0>}
after task {'directory_0': PosixPath('/home/lcerrone/data'), 'ovule_sample': <plantseg.image.Image object at 0x7fa465ee7cd0>}
task1
before task {'directory_0': PosixPath('/home/lcerrone/data'), 'ovule_sample': <plantseg.image.Image object at 0x7fa465ee7cd0>}
after task {'directory_0': PosixPath('/home/lcerrone/data'), 'ovule_sample': <plantseg.image.Image object at 0x7fa465ee7cd0>, 'ovule_sample_2'

In [12]:
# Display the current contents of var_space.
var_space

{'path_import_image_workflow_8279316341258423075': PosixPath('/home/lcerrone/data/ovule_sample.h5'),
 'path_import_image_workflow_6773622813115396525': PosixPath('/home/lcerrone/data/ovule_sample.h5'),
 'directory_export_image_workflow_-8523320658921002950': PosixPath('/home/lcerrone/data')}

['path_import_image_workflow_8279316341258423075',
 'path_import_image_workflow_6773622813115396525',
 'directory_export_image_workflow_-8523320658921002950']