In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# default_exp services.all

In [None]:
# export
from dataclasses import dataclass, asdict
from functools import partial
from typing import Any, List, Tuple, Iterator, Optional, Dict
from yaml import safe_load
from networkx import (
    DiGraph,
    compose,
    is_directed_acyclic_graph,
    relabel_nodes,
    topological_sort,
    descendants,
    number_of_nodes,
)
from flow.domain.model import Task, Config, WorkflowDefinition, Workflow
from flow.adapters.templater import *

# Workflow Graph

> Build a dependency graph from a workflow definition and render it into Airflow DAG source using templates.

## Helper: Visualize a Graph

In [None]:
from IPython.display import Image, display
from networkx.drawing.nx_pydot import to_pydot


def view_pydot(pdot):
    """Rasterize a pydot graph to PNG and show it inline in the notebook."""
    display(Image(pdot.create_png()))

## Render Workflow

In [None]:
# export
def render(templater: Templater, template, task: dict) -> str:
    """Render one template with the given data through a templater.

    Parameters
    ----------
    templater
        Object exposing ``render(template=..., content=...)``,
        e.g. ``JinjaTemplater``.
    template
        Template identifier, e.g. ``'operator.txt'``.
    task
        Mapping handed to the template as ``content``.

    Returns
    -------
    str
        Whatever ``templater.render`` returns (the rendered text).
    """
    arguments = {"template": template, "content": task}
    return templater.render(**arguments)

### Test Render

In [None]:
# Default configuration for Airflow Tasks and Oozie Actions
config = {
    "ssh": {
        "BashOperator": {
            "task_id": "sleep",
            "bash_command": "sleep",
            "retries": 3
        },
        "imports": {
            # module: List[object]
            "airflow.operators.bash_operator": ["BashOperator"]
        }
    }
}

# Workflow-level data; not used by the renders below.
workflow = {
    "name": 'airflow',
    "imports": '', #For each task, search the config for required imports and then merge
    "default_args": {},
    "dag_args": {},
    "tasks": {},
}

# A single task; its operator args mirror config["ssh"]["BashOperator"].
task = {
    "name": "task2",
    "dependencies": ["task1"],
    "type": "BashOperator",
    "args": {
        "task_id": "sleep",
        "bash_command": "sleep",
        "retries": 3
    }
}

# Scaffold data: DAG-level imports, default_args and dag_args.
# NOTE(review): values such as "'airflow'" and 'datetime(2018, 5, 26)' are
# Python source strings, presumably emitted verbatim by the template — confirm.
scaffold = {
    "metadata": "",
    "imports": {
        'module1': ['object1', 'object2'],
        "module2": [],
        # Listed once: a repeated dict key silently keeps only its last value.
        "module3": ["*"],
    },
    "default_args": {
        'owner': "'airflow'",
        'depends_on_past': False,
        'start_date': 'datetime(2018, 5, 26)',
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': 1,
        'retry_delay': 'timedelta(minutes=5)',
    },
    "dag_args": {
        'description': "'A simple tutorial DAG'",
        'schedule_interval': "'@daily'",
    },
}

templater = JinjaTemplater()

# Test BashOperator Task is rendered correctly
print(render(templater, 'operator.txt', task))

# Test Scaffold is rendered correctly
# print(render(templater, 'scaffold.txt', scaffold))

In [None]:
class FakeTemplater(list):
    """Test double for a templater that records calls instead of rendering.

    The instance is itself the list of call records, so a test can compare it
    directly against the expected call sequence.
    """

    def get_template(self, filename):
        """Record a template lookup as ``('GET', filename)``; returns None."""
        self.extend([('GET', filename)])

    def render(self, template, content):
        """Record a lookup, then ``('RENDER', template, content)``; returns None."""
        self.get_template(template)
        record = ('RENDER', template, content)
        self.extend([record])

def test_render_airflow_operator():
    """render() should fetch 'operator.txt' once and pass the task dict through untouched."""
    template = 'operator.txt'
    task = {
        "name": "task2",
        "dependencies": ["task1"],
        "type": "BashOperator",
        "args": {
            "task_id": "sleep",
            "bash_command": "sleep",
            "retries": 3
        }
    }
    templater = FakeTemplater()

    render(templater, template, task)

    expected_calls = [('GET', template), ('RENDER', template, task)]
    assert templater == expected_calls

def test_render_airflow_scaffold():
    """render() should fetch 'scaffold.txt' once and pass the scaffold dict through untouched."""
    templater = FakeTemplater()
    template = 'scaffold.txt'
    scaffold = {
        "imports": {
            "module1": ['object1', 'object2'],
            "module2": [],
            # Listed once: a repeated dict key silently keeps only its last value.
            "module3": ["*"],
        }
    }
    render(templater, template, scaffold)
    assert templater == [('GET', template), ('RENDER', template, scaffold)]

In [None]:
# Run the render tests; each raises AssertionError on failure.
test_render_airflow_operator()
test_render_airflow_scaffold()

## Workflow and Config data

In [None]:
#export
def load(read, build, path: str):
    """Read a structured text file and pass its top-level mapping to a builder.

    Parameters
    ----------
    read
        Callable taking an open text file and returning the parsed data,
        e.g. ``yaml.safe_load``. The result must be a mapping.
    build
        Callable invoked as ``build(**data)``.
    path
        Path of the file to read.

    Returns
    -------
    Whatever ``build`` returns.

    Raises
    ------
    TypeError
        If ``read`` returns None (e.g. an empty YAML document), or if ``build``
        rejects the keyword arguments.
    """
    # Explicit encoding: the default is locale-dependent (e.g. cp1252 on Windows).
    with open(path, 'r', encoding='utf-8') as f:
        data = read(f)
    if data is None:
        # Without this check, `build(**None)` fails with a message that omits the file.
        raise TypeError(f"No data found in {path}; expected a mapping of builder arguments")
    return build(**data)


## Workflow

In [None]:
#export
def build_workflow(builder, configuration, name, tasks) -> "Workflow":
    """Construct a workflow from definition tasks/subtasks and a configuration.

    The builder is instantiated with the workflow name and the configuration's
    task types. Its ``build`` method is then applied to the tasks.

    Parameters
    ----------
    builder
        Workflow class: called as ``builder(name, task_types)``. The instance
        must provide ``build(tasks)``.
    configuration
        Workflow configuration exposing a ``task_types`` attribute
    name
        Workflow name
    tasks
        Workflow tasks, forwarded unchanged to ``build``

    Returns
    -------
    Workflow
        Result of ``build``: a Workflow object whose list of Tasks is used to
        generate the graph
    """
    task_types = configuration.task_types
    instance = builder(name, task_types)
    return instance.build(tasks)

## Workflow Graph

In [None]:
#export
def build_graph(workflow: Workflow) -> DiGraph:
    """Build the dependency graph of a workflow.

    Nodes are the workflow's Task objects. An edge ``(a, b)`` means task ``b``
    lists task ``a`` (by name) in its ``dependencies``.

    Parameters
    ----------
    workflow
        Workflow whose ``tasks`` each expose ``name`` and ``dependencies``.

    Returns
    -------
    DiGraph
        The workflow DAG in its raw, un-optimized state.

    Raises
    ------
    WorkflowGraphError
        If any dependency names a task that is not part of the workflow.
    """
    # Add all tasks to a map of string -> task
    task_dict = {task.name: task for task in workflow.tasks}

    # Validate before building edges; otherwise the task_dict lookup below
    # raises a bare KeyError and this check could never fire.
    # NOTE(review): WorkflowGraphError is not defined in this notebook;
    # presumably it comes from the star import — confirm.
    dep_tasks = {
        dependency
        for task in workflow.tasks
        for dependency in task.dependencies
    }
    missing = dep_tasks.difference(task_dict)
    if missing:
        raise WorkflowGraphError(f"Missing task for dependencies: {missing}")

    input_graph = DiGraph()

    # Edges first, then nodes: keeps the original node insertion order, which
    # determines tie-breaking in topological_sort (and hence render order).
    input_graph.add_edges_from(
        (task_dict[dependency], task)
        for task in workflow.tasks
        for dependency in task.dependencies
    )
    # Ensures tasks without any dependency relation are still present.
    input_graph.add_nodes_from(workflow.tasks)

    return input_graph


In [None]:
def draw(G: DiGraph, output_dir: Optional[str] = None):
    """Convert a task graph to a pydot graph labelled by task name.

    Parameters
    ----------
    G
        Graph whose nodes expose a ``name`` attribute (e.g. from ``build_graph``).
    output_dir
        Despite its name, this is the destination *file* path passed to
        ``pdot.write_png``. When it is None or empty, nothing is written.

    Returns
    -------
    The pydot graph, e.g. for display with ``view_pydot``.
    """
    # Re-label nodes with task names so the rendered graph is readable.
    # NOTE(review): tasks that share a name would be merged into one node.
    labelled = relabel_nodes(G, {task: task.name for task in G.nodes()})
    pdot = to_pydot(labelled)

    # Save the image only when a destination was given.
    if output_dir:
        pdot.write_png(output_dir)

    return pdot

In [None]:
# # hide

# # Visualize the graph
# viz = draw(workflow_graph)
# view_pydot(viz)

## Render Workflow from Graph

In [None]:
#export
# Load workflow definition and configuration.
# NOTE(review): this cell is exported, so importing the generated module reads
# these files relative to the current working directory.
WORKFLOW_DEFINITION_PATH = '../temp/workflow.yaml'
WORKFLOW_CONFIGURATION_PATH = '../temp/config.yaml'

load_definition = partial(load, safe_load, WorkflowDefinition.build)
load_configuration = partial(load, safe_load, Config.build)

workflow_definition = load_definition(WORKFLOW_DEFINITION_PATH)
workflow_configuration = load_configuration(WORKFLOW_CONFIGURATION_PATH)

# Build workflow object and derive its dependency graph.
build = partial(build_workflow,
                Workflow,
                workflow_configuration)

workflow = build(workflow_definition.name, workflow_definition.tasks)

workflow_graph = build_graph(workflow)

# Render the DAG scaffold (imports, default_args, dag_args) from the definition.
templater = JinjaTemplater()

render_scaffold = partial(render, templater, 'scaffold.txt')
# TODO: populate "imports" with the set of all imports required by the workflow's tasks.
scaffold = {"imports": {},
            "default_args": workflow_definition.default_args,
            "dag_args": workflow_definition.dag_args}

rendered_scaffold = render_scaffold(scaffold)

# Render one operator per task, in dependency (topological) order.
render_operator = partial(render, templater, 'operator.txt')
rendered_tasks = [render_operator(task.todict())
                  for task in topological_sort(workflow_graph)]

# Combine scaffold and operators; subDAGs are rendered separately below.
dag_definition = '\n'.join([rendered_scaffold, '\n'] + rendered_tasks)


# Build subworkflows and derive their graphs.
subworkflows = {name: build(name, subtask)
                for (name, subtask)
                in workflow_definition.subtasks.items()}

subworkflow_graphs = {name: build_graph(subworkflow)
                      for (name, subworkflow)
                      in subworkflows.items()}

# Render each subworkflow's subtasks (in dependency order) into one string per
# subworkflow. A flat comprehension keyed only by `name` would overwrite every
# rendered subtask except the last one.
render_subtask = partial(render, templater, 'subtask.txt')
render_subtasks = {
    name: '\n'.join(render_subtask(subtask.todict())
                    for subtask in topological_sort(subworkflow_graph))
    for (name, subworkflow_graph) in subworkflow_graphs.items()
}

In [None]:
# Print the rendered output stored for each subworkflow (names are not needed here).
for rendered_subtask in render_subtasks.values():
    print(rendered_subtask)

In [None]:
# Display the raw subtask definitions (subworkflow name -> tasks) the subworkflows were built from.
workflow_definition.subtasks

In [None]:
# print() rather than a bare expression so newlines in the generated DAG source render literally.
print(dag_definition)