In [1]:
from __future__ import print_function

# Install pip packages required for query workflows.
!pip install -q networkx
!pip install -q matplotlib
!pip install -q tensorboard

!pip install -q ml_metadata
# !pip install -q /google/data/rw/users/ma/martinz/ml_metadata-0.12.0.dev0-cp27-cp27mu-linux_x86_64.whl

# Imports.
import os
import sys
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import tensorflow_model_analysis as tfma  # requires numpy 1.15.4.
from IPython.display import display, display_html
from ml_metadata.proto import metadata_store_pb2
from ml_metadata.metadata_store import metadata_store

# Constants for TFX Artifact types.
# These are compared against ArtifactType.name values read from the metadata
# store, so they must match the artifact type names the TFX pipeline
# registered. NOTE(review): names look specific to the TFX version that
# produced this store — confirm when pointing at a store from another version.
TFX_ARTIFACT_EXAMPLES = 'ExamplesPath'
TFX_ARTIFACT_SCHEMA = 'SchemaPath'
TFX_ARTIFACT_EXAMPLE_STATS = 'ExampleStatisticsPath'
TFX_ARTIFACT_EXAMPLE_VALIDATION = 'ExampleValidationPath'
TFX_ARTIFACT_TRANSFORMED_EXAMPLES = 'TransformPath'
TFX_ARTIFACT_MODEL = 'ModelExportPath'
TFX_ARTIFACT_MODEL_EVAL = 'ModelEvalPath'

# Constants for TFX Execution types.
# Compared against ExecutionType.name values and passed as the type name to
# store queries, so they must match the execution type names registered by
# the pipeline.
TFX_EXECUTION_EXAMPLE_GEN = 'examples_gen'
TFX_EXECUTION_STATISTICS_GEN = 'statistics_gen'
TFX_EXECUTION_SCHEMA_GEN = 'schema_gen'
TFX_EXECUTION_EXAMPLE_VALIDATION = 'example_validation'
TFX_EXECUTION_TRANSFORM = 'transform'
TFX_EXECUTION_TRAINER = 'trainer'
TFX_EXECUTION_EVALUATOR = 'evaluator'

def _make_default_sqlite_uri():
    return '/'.join([
        os.environ['HOME'],
        "airflow/data/tfx_example/pipelines/tfx_example_pipeline_DAG/metadata.db",
    ])


def get_metadata_store(
    filename_uri='',
    connection_mode=metadata_store_pb2.SqliteMetadataSourceConfig.UNKNOWN,
    reset=False):
    """Returns a metadata_store.MetadataStore handle to a SQLite backend.

    Args:
      filename_uri: Path of the SQLite database file. When empty, the default
        TFX example pipeline DB under $HOME is used.
      connection_mode: A `SqliteMetadataSourceConfig` connection mode
        (e.g. READONLY, READWRITE).
      reset: If True, delete the existing database file (when present) before
        connecting, so the store is opened on a fresh file.
        NOTE(review): after a reset the file no longer exists, so the chosen
        connection_mode must allow creating it (READONLY presumably will not).

    Returns:
      A `metadata_store.MetadataStore` connected to the SQLite file.
    """
    db_path = filename_uri or _make_default_sqlite_uri()
    if reset and os.path.exists(db_path):
        # Honor `reset`: drop the old DB file so the store starts empty.
        os.remove(db_path)
    config = metadata_store_pb2.ConnectionConfig()
    config.sqlite.filename_uri = db_path
    config.sqlite.connection_mode = connection_mode
    return metadata_store.MetadataStore(config)


def delete_sqlite_db(filename_uri=''):
    """Deletes the SQLite metadata DB file.

    Removes `filename_uri` when it is non-empty, otherwise the default TFX
    example pipeline DB path. OSError propagates if the file is missing.
    """
    target_path = filename_uri if filename_uri else _make_default_sqlite_uri()
    os.remove(target_path)


def update_airflow_db_airtifacts_uri(extracted_dir=''):
    """Re-roots artifact URIs in an extracted Airflow metadata DB.

    After an Airflow data directory has been copied/extracted to a new
    location, every artifact URI that contains an `airflow/` path component
    is rewritten to `extracted_dir` joined with the part after that component.

    Args:
      extracted_dir: Directory the Airflow data dir was extracted into, e.g.
        for the layout
            /home/alice/airflow/
                data/
                   taxi_data/
                   tfx/
        pass extracted_dir = '/home/alice/airflow/'. The trailing slash is
        optional.
    """
    db_path = os.path.join(
        extracted_dir,
        "data/tfx_example/pipelines/tfx_example_pipeline_DAG/metadata.db")
    rw_store = get_metadata_store(
        filename_uri=db_path,
        connection_mode=metadata_store_pb2.SqliteMetadataSourceConfig.READWRITE)
    updated = []
    for artifact in rw_store.get_artifacts():
        # Match only a whole 'airflow' path component, and only its first
        # occurrence, so 'airflow' inside other names (e.g. a user name) or
        # repeated later in the path does not mangle the URI. The leading '/'
        # lets a relative URI that starts with 'airflow/' match too.
        _, sep, relative = ('/' + artifact.uri).partition('/airflow/')
        if sep:
            artifact.uri = os.path.join(extracted_dir, relative)
            updated.append(artifact)
    if updated:
        # Write all rewritten artifacts back in a single call.
        rw_store.put_artifacts(updated)

            
# Utils to display artifacts, executions.
def _get_value_str(p):
    """Returns a string representation of a metadata_store_pb2.Value object."""
    if p.int_value:
        return str(p.int_value)
    if p.string_value:
        return p.string_value
    if p.double_value:
        return str(p.double_value)
    return ''


def get_df_from_artifacts_or_executions(objects, is_artifact):
    """Builds a `pd.DataFrame` with one row per artifact/execution in `objects`.

    Rows are indexed by object id (index named 'ID'). Columns are the
    upper-cased property and custom-property names, plus 'URI' for artifacts;
    custom properties win over same-named properties. Cells with no value for
    an object are filled with '-'.
    """
    rows_by_id = {}
    for obj in objects:
        row = {'URI': obj.uri} if is_artifact else {}
        row.update(
            (name.upper(), _get_value_str(obj.properties[name]))
            for name in obj.properties)
        row.update(
            (name.upper(), _get_value_str(obj.custom_properties[name]))
            for name in obj.custom_properties)
        rows_by_id[obj.id] = row
    frame = pd.DataFrame.from_dict(data=rows_by_id, orient='index').fillna('-')
    frame.index.name = 'ID'
    return frame


def _get_df_from_single_artifact_or_execution(obj, is_artifact):
    """Builds a one-column `pd.DataFrame` describing a single artifact/execution.

    The index holds 'URI' (artifacts only) and the upper-cased property and
    custom-property names; the single, unnamed column holds their string
    values. Missing values are filled with '-'.
    """
    fields = {}
    if is_artifact:
        fields['URI'] = obj.uri
    for property_map in (obj.properties, obj.custom_properties):
        for name in property_map:
            fields[name.upper()] = _get_value_str(property_map[name])
    return pd.DataFrame.from_dict(
        data=fields, orient='index', columns=['']).fillna('-')


def get_artifact_df(artifact_id):
    """Returns a one-column `pd.DataFrame` describing artifact `artifact_id`.

    Reads from the notebook-global `store`. Raises ValueError (from the
    unpacking) unless exactly one artifact is returned.
    """
    found = store.get_artifacts_by_id([artifact_id])
    (artifact,) = found
    return _get_df_from_single_artifact_or_execution(artifact, is_artifact=True)


def get_execution_df(execution_id):
    """Returns a one-column `pd.DataFrame` describing execution `execution_id`.

    Reads from the notebook-global `store`. Raises ValueError (from the
    unpacking) unless exactly one execution is returned.
    """
    found = store.get_executions_by_id([execution_id])
    (execution,) = found
    return _get_df_from_single_artifact_or_execution(execution, is_artifact=False)


def get_artifacts_of_type_df(type_name):
    """Returns a `pd.DataFrame` of all artifacts whose type is named `type_name`."""
    artifacts = store.get_artifacts_by_type(type_name)
    return get_df_from_artifacts_or_executions(artifacts, is_artifact=True)


def get_executions_of_type_df(type_name):
    """Returns a `pd.DataFrame` of all executions whose type is named `type_name`."""
    executions = store.get_executions_by_type(type_name)
    return get_df_from_artifacts_or_executions(executions, is_artifact=False)

# Specialized util methods to query lineage for models.

def get_trainer_run_id_for_model(model_id):
    """Returns the id of the TFX_EXECUTION_TRAINER execution that generated `model_id`.

    Inspects output events (DECLARED_OUTPUT or OUTPUT) recorded for the model
    artifact in the notebook-global `store`.

    Args:
      model_id: Id of a model artifact.

    Returns:
      The trainer execution id, or None if no trainer execution output it.

    Raises:
      ValueError: If more than one distinct trainer execution output it.
    """
    output_event_types = (metadata_store_pb2.Event.DECLARED_OUTPUT,
                          metadata_store_pb2.Event.OUTPUT)
    trainer_run_ids = []
    seen_execution_ids = set()
    for event in store.get_events_by_artifact_ids([model_id]):
        # Skip non-output events, and repeated events for an execution already
        # examined (it may log both DECLARED_OUTPUT and OUTPUT).
        if (event.type not in output_event_types
                or event.execution_id in seen_execution_ids):
            continue
        seen_execution_ids.add(event.execution_id)
        [execution] = store.get_executions_by_id([event.execution_id])
        [execution_type] = store.get_execution_types_by_id([execution.type_id])
        if execution_type.name == TFX_EXECUTION_TRAINER:
            trainer_run_ids.append(execution.id)
    if len(trainer_run_ids) > 1:
        # Ids are ints; convert before joining (str.join only accepts strings).
        raise ValueError('Multiple trainer runs {} generated model artifact {}'.format(
            ','.join(str(run_id) for run_id in trainer_run_ids), model_id))
    return trainer_run_ids[0] if trainer_run_ids else None


def get_tfma_eval_result_for_model(model_id):
    """Returns an artifact of type TFX_ARTIFACT_MODEL_EVAL for given `model_id`.

    Finds a TFX_EXECUTION_EVALUATOR execution that took the model as input
    (DECLARED_INPUT or INPUT event), then returns the TFX_ARTIFACT_MODEL_EVAL
    artifact that execution output (DECLARED_OUTPUT or OUTPUT event). Queries
    the notebook-global `store`.

    Args:
      model_id: Id of a model artifact.

    Returns:
      The model-eval artifact, or None if no evaluator run consumed the model
      or the selected run output no model-eval artifact.

    Raises:
      ValueError: If the selected evaluator run output more than one distinct
        model-eval artifact.
    """
    input_event_types = (metadata_store_pb2.Event.DECLARED_INPUT,
                         metadata_store_pb2.Event.INPUT)
    output_event_types = (metadata_store_pb2.Event.DECLARED_OUTPUT,
                          metadata_store_pb2.Event.OUTPUT)

    # Find evaluator runs that consumed this model.
    tfma_run_ids = []
    seen_execution_ids = set()
    for event in store.get_events_by_artifact_ids([model_id]):
        if (event.type not in input_event_types
                or event.execution_id in seen_execution_ids):
            continue
        seen_execution_ids.add(event.execution_id)
        [execution] = store.get_executions_by_id([event.execution_id])
        [execution_type] = store.get_execution_types_by_id([execution.type_id])
        if execution_type.name == TFX_EXECUTION_EVALUATOR:
            tfma_run_ids.append(execution.id)
    if not tfma_run_ids:
        return None
    # NOTE(review): if the model was evaluated more than once, only the first
    # evaluator run returned by the store is used.
    tfma_run_id = tfma_run_ids[0]

    # Collect the model-eval artifacts output by that evaluator run.
    tfma_eval_results = []
    seen_artifact_ids = set()
    for event in store.get_events_by_execution_ids([tfma_run_id]):
        if (event.type not in output_event_types
                or event.artifact_id in seen_artifact_ids):
            continue
        seen_artifact_ids.add(event.artifact_id)
        [artifact] = store.get_artifacts_by_id([event.artifact_id])
        [artifact_type] = store.get_artifact_types_by_id([artifact.type_id])
        if artifact_type.name == TFX_ARTIFACT_MODEL_EVAL:
            tfma_eval_results.append(artifact)
    if len(tfma_eval_results) > 1:
        # Artifact ids are ints; convert them before joining.
        raise ValueError('Multiple tfma eval results {} for tfma run {} of model {}'.format(
            ','.join(str(r.id) for r in tfma_eval_results), tfma_run_id, model_id))
    return tfma_eval_results[0] if tfma_eval_results else None

# Internal utils used to compute lineage DAG.

def _find_upstream_executions(artifact_id):
    """Returns ids of executions that output `artifact_id`.

    An execution counts as upstream when the global `store` holds a
    DECLARED_OUTPUT or OUTPUT event linking it to the artifact. One id is
    returned per matching event, in the order the store returns them
    (duplicates are kept).
    """
    output_types = (metadata_store_pb2.Event.DECLARED_OUTPUT,
                    metadata_store_pb2.Event.OUTPUT)
    events = store.get_events_by_artifact_ids([artifact_id])
    return [event.execution_id for event in events if event.type in output_types]


def _find_upstream_artifacts(execution_id):
    """Returns ids of artifacts consumed by execution `execution_id`.

    An artifact counts as upstream when the global `store` holds a
    DECLARED_INPUT or INPUT event linking it to the execution. One id is
    returned per matching event, in the order the store returns them
    (duplicates are kept).
    """
    input_types = (metadata_store_pb2.Event.DECLARED_INPUT,
                   metadata_store_pb2.Event.INPUT)
    events = store.get_events_by_execution_ids([execution_id])
    return [event.artifact_id for event in events if event.type in input_types]


def _add_node_attribute(g, node_id, depth, is_artifact):
    """Adds (or updates) the graph node for an artifact or execution.

    Artifacts are keyed by their id and executions by the negated id, which
    keeps the two apart assuming store ids are positive. The node gets
    `depth` and `is_artifact` attributes plus a `_label_` of the form
    "<id>\\n<type name>", with the type name looked up in the global `store`.
    """
    gnode_id = node_id if is_artifact else -1 * node_id
    g.add_node(gnode_id, depth=depth, is_artifact=is_artifact)
    if is_artifact:
        [entity] = store.get_artifacts_by_id([node_id])
        [entity_type] = store.get_artifact_types_by_id([entity.type_id])
    else:
        [entity] = store.get_executions_by_id([node_id])
        [entity_type] = store.get_execution_types_by_id([entity.type_id])
    g.nodes[gnode_id]['_label_'] = str(node_id) + "\n" + entity_type.name

    
def _add_parents(g, node_id, is_artifact, depth, max_depth=None):
    """Recursively adds `node_id` and its upstream lineage to graph `g`.

    Walks backwards through the global `store`: an artifact's producing
    executions, then the artifacts those executions consumed, and so on.
    Edges point from upstream to downstream (producer -> consumer).
    Execution nodes are keyed by their negated id, artifacts by their id.

    Args:
      g: networkx DiGraph being built; mutated in place.
      node_id: Store id of the artifact or execution to add.
      is_artifact: True if `node_id` is an artifact id, False for an execution.
      depth: Distance from the queried artifact (the top-level call uses 1).
      max_depth: If not None, a node deeper than this is still added (with
        the edge to its downstream node) but its parents are not expanded.
    """
    # Attributes are (re)written on every visit, so a node reached along
    # several paths keeps the depth of the most recent visit.
    _add_node_attribute(g, node_id, depth, is_artifact)
    gnode_id = node_id if is_artifact else -1 * node_id
    # The node was just added above, so `gnode_id in g` is always True; the
    # effective check is "parents already expanded" (node has in-edges).
    if gnode_id in g and len(g.in_edges(gnode_id)) > 0: 
        return
    if max_depth is not None and depth > max_depth:
        return
    if is_artifact:
        # Artifact: parents are the executions that output it.
        for e_id in _find_upstream_executions(node_id):
            g.add_edge(e_id * -1, node_id)
            _add_parents(g, e_id, not is_artifact, depth + 1, max_depth)
    else:
        # Execution: parents are the artifacts it consumed.
        for a_id in _find_upstream_artifacts(node_id):
            g.add_edge(a_id, node_id * -1)
            _add_parents(g, a_id, not is_artifact, depth + 1, max_depth)
    

def _construct_artifact_lineage(artifact_id, max_depth=None):
    """Builds a networkx DiGraph of the upstream lineage of `artifact_id`.

    The graph's `query_artifact_id` attribute records the queried artifact.
    A non-positive `max_depth` yields a graph with no nodes.
    """
    lineage = nx.DiGraph(query_artifact_id=artifact_id)
    if max_depth is not None and max_depth <= 0:
        return lineage
    _add_parents(lineage, artifact_id, True, 1, max_depth)
    return lineage


# Generic utils to get and plot artifact lineage.

def get_artifact_lineage(artifact_id, max_depth=None):
    """Returns the upstream lineage of `artifact_id` as a networkx DiGraph.

    Args:
      artifact_id: Id of the artifact to trace.
      max_depth: Optional limit on how far upstream to expand; None means
        no limit.
    """
    lineage_dag = _construct_artifact_lineage(artifact_id, max_depth=max_depth)
    return lineage_dag


def plot_artifact_lineage(dag):
    """Plots a lineage DAG with networkx and matplotlib.

    Layout: artifacts are drawn on one horizontal row (y=0) and executions on
    a second row (y=0.2). Within each row, nodes are ordered left to right by
    decreasing `depth` (most upstream on the left), spaced 1 apart and
    centered on x=0. Artifacts are drawn in cyan, executions in red.

    Args:
      dag: A networkx DiGraph as returned by `get_artifact_lineage`; every
        node must carry `is_artifact`, `depth` and `_label_` attributes.
    """
    # Colors are a list (one entry per node, in dag.nodes order, which is the
    # order nx.draw uses); matplotlib 3 rejects a string of color characters.
    node_colors = []
    node_labels = {}
    artifact_nodes = []
    execution_nodes = []
    for node_id in dag.nodes:
        # `dag.nodes[...]` replaces `dag.node[...]`, which networkx 2.4 removed.
        attrs = dag.nodes[node_id]
        node_labels[node_id] = attrs['_label_']
        if attrs['is_artifact']:
            node_colors.append('c')
            artifact_nodes.append(node_id)
        else:
            node_colors.append('r')
            execution_nodes.append(node_id)

    def _by_depth_descending(node_id):
        return -dag.nodes[node_id]['depth']

    pos = {}
    for row_nodes, row_y in ((artifact_nodes, 0), (execution_nodes, 0.2)):
        row_nodes.sort(key=_by_depth_descending)
        # n nodes spaced 1 apart and centered on 0 span [-(n-1)/2, (n-1)/2];
        # the float divisor keeps this correct under Python 2 division too.
        x = -(len(row_nodes) - 1) / 2.0
        for node_id in row_nodes:
            pos[node_id] = [x, row_y]
            x += 1

    nx.draw(dag, pos=pos, node_size=3000, node_color=node_colors,
            labels=node_labels, node_shape='8')
    plt.show()

    
def get_and_plot_artifact_lineage(artifact_id, max_depth=None):
    """Computes the upstream lineage DAG of `artifact_id` and plots it."""
    lineage_dag = get_artifact_lineage(artifact_id, max_depth=max_depth)
    plot_artifact_lineage(lineage_dag)



[33mDEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 won't be maintained after that date. A future version of pip will drop support for Python 2.7.[0m
[33mDEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 won't be maintained after that date. A future version of pip will drop support for Python 2.7.[0m
[33mDEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 won't be maintained after that date. A future version of pip will drop support for Python 2.7.[0m
[33mDEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 won't be maintained after that date. A future version of pip will drop support for Python 2.7.[0m


In [2]:
# Set up a read-only connection to the ML-Metadata store at the default path.
# The query helpers defined in the first cell read this global `store`.
store = get_metadata_store(
    filename_uri='',
    connection_mode=metadata_store_pb2.SqliteMetadataSourceConfig.READONLY,
)

In [3]:
# Show every TFX Trainer execution recorded in the store, one row per run.
display(get_executions_of_type_df(type_name=TFX_EXECUTION_TRAINER))

Unnamed: 0_level_0,EVAL_STEPS,LOG_ROOT,WARM_START_FROM,WARM_STARTING,STATE,CHECKSUM_MD5,MODULE_FILE,TRAIN_STEPS
ID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1
528,5000,/var/tmp/tfx/logs/tfx_example_pipeline_DAG/tfx...,,True,new,74d80d9bc7891faf34fc3a5a5423d0c6,/usr/local/google/home/robertcrowe/airflow/plu...,10000
530,5000,/var/tmp/tfx/logs/tfx_example_pipeline_DAG/tfx...,,True,new,74d80d9bc7891faf34fc3a5a5423d0c6,/usr/local/google/home/robertcrowe/airflow/plu...,10000
532,5000,/var/tmp/tfx/logs/tfx_example_pipeline_DAG/tfx...,,True,new,74d80d9bc7891faf34fc3a5a5423d0c6,/usr/local/google/home/robertcrowe/airflow/plu...,10000
534,5000,/var/tmp/tfx/logs/tfx_example_pipeline_DAG/tfx...,,True,new,74d80d9bc7891faf34fc3a5a5423d0c6,/usr/local/google/home/robertcrowe/airflow/plu...,10000
536,5000,/var/tmp/tfx/logs/tfx_example_pipeline_DAG/tfx...,,True,new,74d80d9bc7891faf34fc3a5a5423d0c6,/usr/local/google/home/robertcrowe/airflow/plu...,10000
539,5000,/var/tmp/tfx/logs/tfx_example_pipeline_DAG/tfx...,,True,new,74d80d9bc7891faf34fc3a5a5423d0c6,/usr/local/google/home/robertcrowe/airflow/plu...,10000
541,5000,/var/tmp/tfx/logs/tfx_example_pipeline_DAG/tfx...,,True,new,74d80d9bc7891faf34fc3a5a5423d0c6,/usr/local/google/home/robertcrowe/airflow/plu...,10000
543,5000,/var/tmp/tfx/logs/tfx_example_pipeline_DAG/tfx...,,True,new,74d80d9bc7891faf34fc3a5a5423d0c6,/usr/local/google/home/robertcrowe/airflow/plu...,10000
545,5000,/var/tmp/tfx/logs/tfx_example_pipeline_DAG/tfx...,,True,new,74d80d9bc7891faf34fc3a5a5423d0c6,/usr/local/google/home/robertcrowe/airflow/plu...,10000
547,5000,/var/tmp/tfx/logs/tfx_example_pipeline_DAG/tfx...,,True,new,74d80d9bc7891faf34fc3a5a5423d0c6,/usr/local/google/home/robertcrowe/airflow/plu...,10000
