In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# default_exp services.all

In [3]:
# export
from dataclasses import dataclass, asdict
from functools import partial
from typing import Any, List, Tuple, Iterator, Optional, Dict
from yaml import safe_load
from networkx import (
    DiGraph,
    compose,
    is_directed_acyclic_graph,
    relabel_nodes,
    topological_sort,
    descendants,
    number_of_nodes,
)
from flow.domain.model import Task, Config, WorkflowDefinition, Workflow
from flow.adapters.templater import *

# Workflow Graph

> API details.

## Test: Input Graph

In [None]:
from IPython.display import Image, display
from networkx.drawing.nx_pydot import to_pydot


def view_pydot(pdot):
    plt = Image(pdot.create_png())
    display(plt)

## Render Workflow

In [4]:
# export
def render(templater: Templater, template, task: dict) -> str:
    return templater.render(template = template, content = task)

### Test Render

In [5]:
# Default configuration for Airflow Tasks and Oozie Actions
config = {
    "ssh": {
        "BashOperator": {
            "task_id": "sleep",
            "bash_command": "sleep",
            "retries": 3
        },
        "imports": {
            # module: List[object]
            "airflow.operators.bash_operator": ["BashOperator"]
        }
    }
}

workflow = {
    "name": 'airflow',
    "imports": '', #For each task, search the config for required imports and then merge
    "default_args": {},
    "dag_args": {},
    "tasks": {},
}

# task extract operator data from config.
task = {
    "name": "task2",
    "dependencies": ["task1"],
    "type": "BashOperator",
    "args": {
        "task_id": "sleep",
        "bash_command": "sleep",
        "retries": 3      
    }
}

# scaffold extract import data from config
scaffold = {
    "metadata": "",
    "imports": {
        'module1': ['object1', 'object2'],
        "module2": [],
        "module3": ['object3'],
        "module3": ["*"],
    },
    "default_args": {
        'owner': "'airflow'",
        'depends_on_past': False,
        'start_date': 'datetime(2018, 5, 26)',
        'email_on_failure': False,
        'email_on_retry': False,
        'retries': 1,
        'retry_delay': 'timedelta(minutes=5)',
    },
    "dag_args": {
        'description': "'A simple tutorial DAG'",
        'schedule_interval': "'@daily'",
    },
}

templater = JinjaTemplater()

# Test BashOperator Task is rendered correctly
print(render(templater, 'operator.txt', task))

# Test Scaffold in rendered correctly
# print(render(templater, 'scaffold.txt', scaffold))

task2 = BashOperator(
    task_id = "task2",
    task_id = sleep,
    bash_command = sleep,
    retries = 3,
    dag = dag
)

task1 >> task2



In [6]:
class FakeTemplater(list):
    
    def get_template(self, filename):
        self.append(('GET', filename))
        
    def render(self, template, content):
        self.get_template(template)
        self.append(('RENDER', template, content))

def test_render_airflow_operator():
    templater = FakeTemplater()
    template = 'operator.txt'
    task = {
        "name": "task2",
        "dependencies": ["task1"],
        "type": "BashOperator",
        "args": {
            "task_id": "sleep",
            "bash_command": "sleep",
            "retries": 3      
        }
    }
    render(templater, template, task)
    assert templater == [('GET', 'operator.txt'), 
                         ('RENDER', 'operator.txt', task)]

def test_render_airflow_scaffold():
    templater = FakeTemplater()
    template = 'scaffold.txt'
    scaffold = {
        "imports": {
            "module1": ['object1', 'object2'],
            "module2": [],
            "module3": ['object3'],
            "module3": ["*"],
        }
    }
    render(templater, template, scaffold)
    assert templater == [('GET', template), ('RENDER', template, scaffold)]

In [7]:
test_render_airflow_operator()
test_render_airflow_scaffold()

## Workflow and Config data

In [8]:
#export
def load(read, build, path: str):
    with open(path, 'r') as f:
        data = read(f)
    return build(**data)


## Workflow

In [9]:
#export
def build_workflow(builder, configuration, name, tasks) -> "Workflow":
    """Build workflows from configuration and definition tasks/subtasks
    
    Parameters
    ----------
    builder
        Workflow class with `build` method
    configuration
        Workflow configuration
    name
        Workflow name
    tasks
        Workflow tasks
        
    Returns
    -------
    Workflow
        Workflow object with list of Tasks used to generate graph
    """
    return builder(name, configuration.task_types).build(tasks)

## Workflow Graph

In [10]:
#export
def build_graph(workflow: Workflow) -> DiGraph:
    # Add all tasks to a map of string -> task
    task_dict = {task.name: task for task in workflow.tasks}
    task_names = task_dict.keys()

    # DAG of the workflow in its raw/un-optimized state
    input_graph = DiGraph()

    # Add all dependencies as edges
    dependencies = [
        (task_dict[dependency], task)
        for task in workflow.tasks
        for dependency in task.dependencies
    ]
    input_graph.add_edges_from(dependencies)

    # Add all the tasks as vertices.
    input_graph.add_nodes_from(workflow.tasks)

    # Make sure all dependencies have an associated task
    dep_tasks = set(
        [
            dependency
            for task in workflow.tasks
            for dependency in task.dependencies
        ]
    )
    if not dep_tasks.issubset(task_names):
        dep = dep_tasks - task_names
        raise WorkflowGraphError(f"Missing task for dependencies: {dep}")

    return input_graph


In [11]:
def draw(G: DiGraph, output_dir: str = None):
    # re-label node with action names
    mapping = {action: action.name for action in G.nodes()}
    g = relabel_nodes(G, mapping)
    # save image
    pdot = to_pydot(g)

    if output_dir:
        pdot.write_png(output_dir)

    return pdot

In [12]:
# # hide

# # Visualize the graph
# viz = draw(workflow_graph)
# view_pydot(viz)

## Render Workflow from Graph

In [51]:
def get_task_imports(tasks):
    """Extracts imports from workflow tasks.
    
    Parameters
    ----------
    tasks : List[Task]
        workflow tasks
    
    Returns
    -------
    imports : Dict[module, Set[objects]]
        Set of object imports for every module
        
    >>> {'airflow.operators.bash_operator': {'BashOperator'},
         'airflow.operators.subdag_operator': {'SubDagOperator'}}
    """
    imports = {}
    for task in tasks:
        for (module, objs) in task.imports().items():
            # initialiase empty set for every module
            imports.setdefault(module, set())
            # add objects into set of imports for given module
            for obj in objs:
                imports[module].add(obj)
    return imports

In [203]:
#export
# Load workflow definition and configuration
load_definition = partial(load, safe_load, WorkflowDefinition.build)
load_configuration = partial(load, safe_load, Config.build)

workflow_definition = load_definition('../temp/workflow.yaml')
workflow_configuration = load_configuration('../temp/config.yaml')

# Workflow builder
build = partial(build_workflow, 
                Workflow,
                workflow_configuration)

# Templater instantiation
templater = JinjaTemplater()

In [206]:
# Build workflow object and derive graph
workflow = build(workflow_definition.name, workflow_definition.tasks)

workflow_graph = build_graph(workflow)

# Render Scaffold from workflow definition 
render_scaffold = partial(render, templater, 'scaffold.txt')
scaffold = {"name": workflow.name,
            "imports": get_task_imports(workflow.tasks),
            "default_args": workflow_definition.default_args,
            "dag_args": workflow_definition.dag_args}

rendered_scaffold = render_scaffold(scaffold)

# Render operator from workflow graph
render_operator = partial(render, templater, 'operator.txt')
rendered_tasks = [render_operator(task.todict())
                  for task in topological_sort(workflow_graph)]

# Combine Scaffold, Operator and SubDAGs
dag_definition = '\n'.join([rendered_scaffold, '\n']+rendered_tasks)


# Build subworkflows and derives graph
subworkflows = {name : build(name, subtask) 
                for (name, subtask) 
                in workflow_definition.subtasks.items()}

subworkflow_graphs = {name : build_graph(subworkflow)
                      for (name, subworkflow)
                      in subworkflows.items()}

# Render SubDags from subworkflow graph
render_subtask = partial(render, templater, 'subtask.txt')

render_subtasks = {}
for (name, subworkflow_graph) in subworkflow_graphs.items():
    render_subtasks.setdefault(name, [])
    for subtask in topological_sort(subworkflow_graph):
        render_subtasks[name].append(render_subtask(subtask.todict()))
        
render_sub_scaffold = partial(render, templater, 'subtask-scaffold.txt')

scaffolds = []
"""
[{'name': 'snapshot-entity1',
  'tasks': {'name': 'task4',
            'dependencies': ['task3'],
            'type': 'SubDagOperator',
            'args': {'retries': 3}},
  'imports': {'airflow.operators.bash_operator': {'BashOperator'}}},
 {'name': 'delta-entity1',
  'tasks': {'name': 'task4',
            'dependencies': ['task3'],
            'type': 'SubDagOperator',
            'args': {'retries': 3}},
  'imports': {'airflow.operators.bash_operator': {'BashOperator'}}}
]
"""
scaffolds = [dict(zip(
                ('name', 'tasks', 'imports', 'default_args', 'dag_args'), 
                (name, render_subtasks.get(name), get_task_imports(subworkflows.get(name).tasks),
                 workflow_definition.default_args, workflow_definition.dag_args)))
             for name, tasks 
             in render_subtasks.items()]

rendered_subworkflows = [render_sub_scaffold(subtask) for subtask in scaffolds]

In [207]:
print(dag_definition)

# -*- coding: utf-8 -*-

from airflow import DAG

from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.contrib.hooks import SSHHook
from airflow.operators.subdag_operator import SubDagOperator

DAG_NAME = "example-dag"

DEFAULT_ARGS = {
    owner: 'airflow',
    depends_on_past: False,
    start_date: datetime(2018, 5, 26),
    retries: 1,
    retry_delay: timedelta(minutes=5),
}

dag = DAG(
    default_args = DEFAULT_ARGS,
    description = 'An example DAG',
    schedule_interval = '@daily',
)


task1 = SSHOperator(
    task_id = "task1",
    ssh_hook = SSHHook(ssh_conn_id=dev_config["emr_con_id"]),
    command = 'some command',
    timeout = 30,
    retries = 4,
    retry_delay = timedelta(seconds=45),
    dag = dag
)


task2 = SSHOperator(
    task_id = "task2",
    ssh_hook = SSHHook(ssh_conn_id=dev_config["emr_con_id"]),
    command = 'some command',
    timeout = 30,
    retries = 4,
    retry_delay = timedelta(seconds=45),
    dag = dag
)

task1 >> task2



In [208]:
for subtask in rendered_subworkflows:
    print(subtask + '\n')

# -*- coding: utf-8 -*-

from datetime import datetime, timedelta

from airflow.models import Variable
form airflow.utils.dates import days_ago
from airflow import DAG

from airflow.contrib.operators.ssh_operator import SSHOperator
from airflow.contrib.hooks import SSHHook

def dag_snapshot_entity1(dag_id):

    DEFAULT_ARGS = {
        owner: 'airflow',
        depends_on_past: False,
        start_date: datetime(2018, 5, 26),
        retries: 1,
        retry_delay: timedelta(minutes=5),
    }

    dag = DAG(
        default_args = DEFAULT_ARGS,
        description = 'An example DAG',
        schedule_interval = '@daily',
    )
    
    task1 = SSHOperator(
      ssh_hook = SSHHook(ssh_conn_id=dev_config["emr_con_id"]),
      command = 'some command',
      timeout = 30,
      retries = 3,
      retry_delay = timedelta(seconds=45),
      dag = dag)

  
    task2 = SSHOperator(
      ssh_hook = SSHHook(ssh_conn_id=dev_config["emr_con_id"]),
      command = 'some command',
      timeou