In [1]:
import hpcflow.app as hf
from hpcflow.sdk.core.test_utils import make_workflow

from plotly.graph_objects import FigureWidget
tmp_path = r"C:\code_local\scratch\hpcflow-workflows"

# TODO: 
# - tasks should be yaxis ticks
# - centre elements in the middle
# - add element index label in the element square
# - show groups as a rectangle around elements
# - support max number of elements; use ellipses and annotate with total num elements in the task
# - option to colour elements (or add symbols or something) to show: 
#     - jobscript + submission index (if submitted)
#     - submission status?
#     - or show jobscript as a bounding polygon around elements?
#        - might have overlapping boxes if loops distributed over multiple jobscripts
# - show actions if zoomed in enough?
#   - or could stack the boxes to show multiple actions?
# - right-angled lines linking many-to-one elements?

In [2]:
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
from plotly import graph_objects


def get_empty_fig():
    fig = FigureWidget(layout={"height": 500, "dragmode": "pan"})
    fig.update_yaxes(scaleanchor="x", zeroline=False, showgrid=False)
    fig.update_xaxes(showgrid=False, zeroline=False)
    return fig


@dataclass
class ElementIterationNode:
    """Class to represent a workflow element-iteration."""

    DEFAULT_WIDTH = 0.5
    DEFAULT_CENTRE = [0, 0]
    _centre = None
    _width = None
    
    task_idx: int
    inputs: List[str]
    outputs: List[str]
    run_IDs: List[int]
    
    @property
    def centre(self):
        if not self._centre:
            self._centre = self.DEFAULT_CENTRE
        return self._centre

    @centre.setter
    def centre(self, value):
        self._centre = value

    @property
    def width(self):
        if not self._width:
            self._width = self.DEFAULT_WIDTH
        return self._width

    @width.setter
    def width(self, value):
        self._width = value


    @property
    def num_actions(self):
        return len(self.run_IDs)

    @property
    def num_inputs(self):
        return len(self.inputs)

    @property
    def num_outputs(self):
        return len(self.outputs)

    def get_height(self):
        return 0.5 * self.width

    @staticmethod
    def distribute_items(length, num_items, origin):
        interval = length / (num_items + 1)
        coords = [interval * i + origin for i in range(1, num_items + 1)]
        return coords

    def get_input_coords(self, input: Optional[str] = None):
        (top_y, right_x, bottom_y, left_x) = self._get_edges()
        inputs_y = [top_y] * self.num_inputs
        inputs_x = self.distribute_items(self.width, self.num_inputs, left_x)
        if input:
            idx = self.inputs.index(input)
            return (inputs_x[idx], inputs_y[idx])

        return (inputs_x, inputs_y)

    def get_output_coords(self, output: Optional[str] = None):
        (top_y, right_x, bottom_y, left_x) = self._get_edges()
        outputs_y = [bottom_y] * self.num_outputs
        outputs_x = self.distribute_items(self.width, self.num_outputs, left_x)
        if output:
            idx = self.outputs.index(output)
            return (outputs_x[idx], outputs_y[idx])

        return (outputs_x, outputs_y)

    def _get_edges(self):
        height = self.get_height()
        top_y = self.centre[1] + 0.5 * height
        bottom_y = self.centre[1] - 0.5 * height
        left_x = self.centre[0] - 0.5 * self.width
        right_x = self.centre[0] + 0.5 * self.width
        return (top_y, right_x, bottom_y, left_x)

    def get_dependency_trace(self, source_iter, input_name, added_input_names=None):
        if not added_input_names:
            added_input_names = set()
        sx, sy = source_iter.get_output_coords(output=input_name)
        ix, iy = self.get_input_coords(input=input_name)
        trace = graph_objects.Scatter(
            x=[sx, ix],
            y=[sy, iy],
            mode="lines",
            line={"color": "gray"},
            name=input_name,
            legendgroup=input_name,
            showlegend=input_name not in added_input_names,
        )
        return trace

    @staticmethod
    def get_exec_trace(iter_IDs: List[int], iters: List["ElementIterationNode"], sub_idx: int, js_idx: int):
        x = []
        y = []
        for i in iter_IDs:
            centre = iters[i].centre
            x.append(centre[0])
            y.append(centre[1])

        if len(x) > 1:
            marker = {"size":10, "symbol": "arrow-bar-up", "angleref": "previous"}
        else:
            marker = {"size":10, "symbol": "square"}
        trace = graph_objects.Scatter(
            x=x,
            y=y,
            mode="markers+lines",
            line={"color": "green"},
            name=f"sub. {sub_idx} - JS {js_idx}",
            marker=marker,
        )
        return trace

    def get_plot_data(self):
        # box for the iteration with point at top and bottom edges for inputs/outputs
        (top_y, right_x, bottom_y, left_x) = self._get_edges()
        shapes = [
            graph_objects.layout.Shape(
                type="rect",
                x0=left_x,
                y0=bottom_y,
                x1=right_x,
                y1=top_y,
                line={"color": "gray"},
                layer="below",
            )
        ]
        ix, iy = self.get_input_coords()
        ox, oy = self.get_output_coords()
        traces = [
            graph_objects.Scatter(
                x=ix + ox,
                y=iy + oy,
                mode="markers",
                text=self.inputs + self.outputs,
                hoverinfo="text",
                marker={"color": "black"},   
                showlegend=False,
            )
        ]
        return shapes, traces

    def plot(self):
        fig = get_empty_fig()
        shapes, traces = self.get_plot_data()
        for i in shapes:
            fig.add_shape(i)
        for i in traces:
            fig.add_trace(i)
        fig.show(config={"scrollZoom": True})


@dataclass
class RunNode:
    """Class to represent a workflow run."""

    iter_ID: int
    run_sources: Dict[str, int]  # map input names to other `RunNode`s
    local_sources: List[str]
    default_sources: List[str]


def show_2(wk):


    runs, elem_iters, elem_iters_by_task = init_nodes(wk)
    fig = get_empty_fig()
    task_names = []
    shapes = []
    traces = []
    added_input_names = set()
    for idx, (task_iID, loop_idx, iter_IDs) in enumerate(
        wk.get_iteration_task_pathway(ret_iter_IDs=True)
    ):
        task_iters = {k: v for k, v in elem_iters.items() if k in iter_IDs}
        task = wk.tasks.get(insert_ID=task_iID)        
        task_names.append(task.get_dir_name(loop_idx))
        centre_y = -idx
        for iter_idx, iter_ID_i in enumerate(task_iters):
            iter_i = elem_iters[iter_ID_i]
            centre_x = iter_idx
            iter_i.centre = [centre_x, centre_y]            
            shapes_i, traces_i = iter_i.get_plot_data()

            for run_j_ID in iter_i.run_IDs:
                run_j = runs[run_j_ID]
                for inp_name_k, source_runs_k_IDs in run_j.run_sources.items():
                    for src_run_m_ID in source_runs_k_IDs:
                        source_run_k = runs[src_run_m_ID]
                        source_run_k_iter = elem_iters[source_run_k.iter_ID]
                        # connect source run's iteration output to this iteration's input
                        dep_trace = iter_i.get_dependency_trace(source_run_k_iter, inp_name_k, added_input_names)
                        added_input_names.add(inp_name_k)
                        traces.append(dep_trace)

            shapes.extend(shapes_i)
            traces.extend(traces_i)

    # add traces to show execution order:
    for sub in wk.submissions or []:
        for js in sub.jobscripts:
            if js.is_array:
                raise NotImplementedError # need separate trace for each JS element index
            else:
                # single trace 
                # run_lst = js.EAR_ID.flatten(order="F")
                run_lst = js.all_EAR_IDs                
                iter_lst = [runs[i].iter_ID for i in run_lst]
                traces.append(ElementIterationNode.get_exec_trace(iter_lst, elem_iters, sub.index, js.index))
    
    for i in shapes:
        fig.add_shape(i)
    for i in traces:
        fig.add_trace(i)

    fig.update_layout(
        yaxis = dict(
            tickmode = 'array',            
            tickvals = list(range(0, -len(task_names), -1)),
            ticktext = task_names,
        )
    )    
    return fig


def init_nodes(wk):
    runs = {}  # by ID
    elem_iters = {}  # by ID
    elem_iters_by_task = defaultdict(list)  # elem_iter_IDs, by task ID
    for task in wk.tasks:
        for elem_i in task.elements:
            for iter_j in elem_i.iterations:
                run_IDs = []
                for act_run in iter_j.action_runs:
                    run_sources = {}
                    local_sources = []
                    default_sources = []
                    for inp_name_full, source in act_run.get_parameter_sources(
                        "inputs"
                    ).items():
                        inp_name = ".".join(inp_name_full.split(".")[1:])
                        run_sources[inp_name] = []
                        if not isinstance(source, list):
                            source = [source]                        
                        for src_i in source:                            
                            if src_i["type"] == "EAR_output":
                                run_sources[inp_name].append(src_i["EAR_ID"])
                            elif src_i["type"] == "local_input":
                                local_sources.append(inp_name)
                            elif src_i["type"] == "default":
                                default_sources.append(inp_name)
                    run_IDs.append(act_run.id_)
                    runs[act_run.id_] = RunNode(
                        iter_ID=iter_j.id_,
                        run_sources=run_sources,
                        local_sources=local_sources,
                        default_sources=default_sources,
                    )
                elem_iters_by_task[task.index].append(iter_j.id_)
                elem_iters[iter_j.id_] = ElementIterationNode(
                    task_idx=task.index,
                    inputs=list(iter_j.task.template.all_schema_input_types),
                    outputs=list(iter_j.task.template.all_schema_output_types),
                    run_IDs=run_IDs,
                )
    return runs, elem_iters, elem_iters_by_task


# fig = show_2(wk)
# fig.show(config={"scrollZoom": True})

In [3]:
wk = hf.Workflow(r"C:\code_local\scratch\hpcflow-workflows\subset_simulation_2024-05-30_163552")
show_2(wk)

FigureWidget({
    'data': [{'hoverinfo': 'text',
              'marker': {'color': 'black'},
              'mode': 'markers',
              'showlegend': False,
              'text': [dimension, x],
              'type': 'scatter',
              'uid': 'd70ce763-e29d-4aed-bfad-f3aae1ef07c7',
              'x': [0.0, 0.0],
              'y': [0.125, -0.125]},
             {'hoverinfo': 'text',
              'marker': {'color': 'black'},
              'mode': 'markers',
              'showlegend': False,
              'text': [dimension, x],
              'type': 'scatter',
              'uid': 'a10ccf4b-773e-47fe-be9d-bab3a65d51ac',
              'x': [1.0, 1.0],
              'y': [0.125, -0.125]},
             {'hoverinfo': 'text',
              'marker': {'color': 'black'},
              'mode': 'markers',
              'showlegend': False,
              'text': [dimension, x],
              'type': 'scatter',
              'uid': '2d09632e-6051-40a1-8b77-315344706327',
            

In [6]:
mcmc_wk = hf.Workflow(r"C:\code_local\scratch\matflow-workflows\subset_simulation_2024-05-16_172542")
fig = show_2(mcmc_wk)
fig.update_layout(autosize=True, height=850)
fig.write_html("wk_graph.html", config={"scrollZoom": True})

In [76]:
mwk = hf.Workflow(r"C:\code_local\scratch\matflow-workflows\tension_DAMASK_Al_2024-05-11_164202")
show_2(mwk)

In [74]:
show_1(mwk)

In [69]:
show_1(wk)

In [2]:
# wk = hf.Workflow(r"C:\code_local\scratch\matflow-workflows\MCMC_2024-05-10_231704")

In [68]:
def show_1(wk):
    elements = {}
    for task in wk.tasks:
        up_tasks = task.upstream_tasks
        down_tasks = task.downstream_tasks
        # print(f"{up_tasks=}")
        # print(f"{down_tasks=}")
        for elem in task.elements:
            # print(f"{elem=}")
            deps = elem.get_element_dependencies(as_objects=True)
            # print(f"{deps=}")
            upstream_deps = {i.id_ for i in deps if i.task in up_tasks}
            downstream_deps = {i.id_ for i in deps if i.task in down_tasks}
            
            # element dependencies from upstream tasks only:
            deps = {i.id_ for i in deps if i.task in task.upstream_tasks}
            elements[elem.id_] = {
                "coords": (-task.index, elem.index),
                "depends_on_up": upstream_deps, 
                "depends_on_down": downstream_deps,
            }
    # print(elements)

    elem_size = 0.25
    iter_dep_shift = 0.25
    fig = FigureWidget(layout={"height": 500, "dragmode": "pan"})
    fig.update_yaxes(scaleanchor="x", zeroline=False, showgrid=False)
    fig.update_xaxes(showgrid=False, zeroline=False)

    # add task names:
    for task in wk.tasks:
        fig.add_annotation(
            y=-task.index,
            x=0 - elem_size,
            text=f"Task {task.unique_name}",
            showarrow=False,
            xanchor="right",
        )

    # add elements:
    for elem in elements.values():
        x0 = elem["coords"][1] - 0.5 * elem_size 
        x1 = elem["coords"][1] + 0.5 * elem_size 
        y0 = elem["coords"][0] - 0.5 * elem_size 
        y1 = elem["coords"][0] + 0.5 * elem_size     
        fig.add_shape(type="rect", x0=x0, x1=x1, y0=y0, y1=y1)
        if elem["depends_on_up"]:
            for i in elem["depends_on_up"]:
                dep_coords = elements[i]["coords"]
                x0 = elem["coords"][1]
                y0 = elem["coords"][0] + 0.5 * elem_size
                x1 = dep_coords[1]
                y1 = dep_coords[0] - 0.5 * elem_size
                fig.add_shape(type="line", x0=x0, x1=x1, y0=y0, y1=y1)

        if elem["depends_on_down"]:
            for i in elem["depends_on_down"]:
                dep_coords = elements[i]["coords"]
                x0a = elem["coords"][1] + iter_dep_shift
                y0a = elem["coords"][0] + elem_size
                x1a = dep_coords[1] + iter_dep_shift
                y1a = dep_coords[0]
                fig.add_shape(type="line", x0=x0a, x1=x1a, y0=y0a, y1=y1a)

                x1 = elem["coords"][1]
                y1 = elem["coords"][0] + elem_size
                fig.add_shape(type="line", x0=x0a, x1=x1, y0=y0a, y1=y1)
                
                x0 = dep_coords[1] + 0.5 * elem_size
                y0 = dep_coords[0]
                fig.add_shape(type="line", x0=x0, x1=x1a, y0=y0, y1=y1a)

                x0 = elem["coords"][1]
                y0 = elem["coords"][0]+ 0.5 * elem_size
                fig.add_shape(type="line", x0=x0, x1=x1, y0=y0, y1=y1)            

    fig.show(config={'scrollZoom': True})

In [8]:
show_1(wk)

In [15]:
import plotly.graph_objects as go

# Create a sample scatter plot
fig = go.Figure(go.Scatter(x=[1, 2, 3], y=[4, 5, 6]))

# Define a function to handle relayout events
def on_relayout(layoutUpdate):
    x_range = layoutUpdate['xaxis.range']
    y_range = layoutUpdate['yaxis.range']

    # Calculate the zoom level based on the difference between the x and y ranges
    x_zoom = x_range[1] - x_range[0]
    y_zoom = y_range[1] - y_range[0]

    # Example: If zoom level is greater than 1 in both axes, add a shape to the plot
    if x_zoom > 1 and y_zoom > 1:
        fig.add_shape(type="rect",
                      x0=min(x_range), y0=min(y_range),
                      x1=max(x_range), y1=max(y_range),
                      line=dict(color="RoyalBlue"),
                      fillcolor="LightSkyBlue", opacity=0.5)

    # Example: If zoom level is less than 1 in both axes, remove any shapes from the plot
    else:
        fig.update_layout(shapes=[])

# Attach the function to the relayout event of the entire figure's layout
fig.layout.on_change(on_relayout)

# Show the plot
fig.show()

ValueError: At least one change property must be specified

### A notebook for looking into visualising workflow dependencies

**Note this required additional dependencies: `networkx`, `matplotlib` and `plotly`**

In [1]:
%load_ext autoreload
%autoreload 2

import tempfile
import numpy as np
import pytest

import networkx
import matplotlib.pyplot as plt
from plotly import graph_objects as go

from hpcflow.api import (
    hpcflow, Workflow, WorkflowTemplate, TaskSchema, Task, Parameter, InputValue, 
    SchemaInput, ValueSequence, InputSource, SchemaOutput, InputSourceType, TaskSourceType,
    ElementPropagation,
)
from hpcflow.sdk.core.utils import read_YAML
from hpcflow.sdk.core.errors import MissingInputs
from hpcflow.sdk.core.zarr_io import ZarrEncodable

hpcflow.load_config(config_dir=tempfile.gettempdir())

ModuleNotFoundError: No module named 'plotly'

In [None]:
param_p1 = Parameter("p1")
param_p2 = Parameter("p2")
param_p3 = Parameter("p3")
param_p4 = Parameter("p4")

param_p1 = SchemaInput(param_p1, default_value=1001)
param_p2 = SchemaInput(param_p2, default_value=np.array([2002, 2003]))
param_p3 = SchemaInput(param_p3, default_value=3001)
param_p4 = SchemaInput(param_p4)

In [None]:
s1 = TaskSchema("ts1", actions=[], inputs=[param_p1], outputs=[param_p3])
t1 = Task(
    schemas=s1,
    inputs=[InputValue(param_p1, 101)],
)

s2 = TaskSchema("ts2", actions=[], inputs=[param_p2, param_p3], outputs=[param_p4])
t2 = Task(
    schemas=s2,
    sequences=[ValueSequence('inputs.p2', values=[201, 202], nesting_order=1)],
)

s3 = TaskSchema("ts3", actions=[], inputs=[param_p3, param_p4])
t3 = Task(schemas=s3, nesting_order={'inputs.p3': 0, 'inputs.p4': 1})

wkt = WorkflowTemplate(name="w1", tasks=[t1, t2, t3])
wk = Workflow.from_template(wkt, path=tempfile.gettempdir(), name=wkt.name)

wk.tasks.ts1.add_elements(
    sequences=[ValueSequence('inputs.p1', values=[102, 103, 104], nesting_order=1)],
    propagate_to=[
        ElementPropagation(
            task=wk.tasks.ts2,
            nesting_order={'inputs.p2': 0, 'inputs.p3': 1}
        ),
        ElementPropagation(
            task=wk.tasks.ts3,
            nesting_order={'inputs.p3': 0, 'inputs.p4': 1},
        )
    ],
)

In [None]:
def build_element_graph(workflow):
    G = networkx.DiGraph()
    for task in workflow.tasks:
        for element in task.elements:
            G.add_node(element.global_index, task=task.index)
        
    for i in wk.elements:
        for j in i.dependent_elements:
            G.add_edge(i.global_index, j)
        
    return G

def _prepare_element_graph(G):
    task_colours = ['blue', 'green', 'red']
    node_colours = [task_colours[data["task"]] for v, data in G.nodes(data=True)]
    pos = networkx.multipartite_layout(G, subset_key='task')    
    return node_colours, pos

def show_element_graph_matplotlib(G):
    node_colours, pos = _prepare_element_graph(G)
    networkx.draw(
        G,
        pos,
        node_color=node_colours,
        with_labels=True,
    )

def show_element_graph_plotly(G):
    node_colours, pos = _prepare_element_graph(G)
    edge_x = []
    edge_y = []
    for edge in G.edges():
        x0, y0 = pos[edge[0]]
        x1, y1 = pos[edge[1]]
        edge_x.append(x0)
        edge_x.append(x1)
        edge_x.append(None)
        edge_y.append(y0)
        edge_y.append(y1)
        edge_y.append(None)

    edge_trace = go.Scatter(
        x=edge_x,
        y=edge_y,
        hoverinfo='none',
        mode='lines'
    )

    node_x = []
    node_y = []
    for node in G.nodes():
        x, y = pos[node]
        node_x.append(x)
        node_y.append(y)

    node_trace = go.Scatter(
        x=node_x,
        y=node_y,
        mode='markers',
        hoverinfo='text',
    )

    fig = go.Figure(data=[edge_trace, node_trace])
    return fig

In [None]:
G = build_element_graph(wk)

show_element_graph_matplotlib(G)
show_element_graph_plotly(G)