In [71]:
# Imports.
import networkx as nx
%matplotlib notebook
import matplotlib.pyplot as plt
import os
import pandas as pd
import sys
import tensorflow_model_analysis as tfma  # requires numpy 1.15.4.
import tensorflow_data_validation as tfdv
import time

from ml_metadata.proto import metadata_store_pb2

In [6]:
# Import shared utils.
# NOTE(review): later cells call `get_metadata_store` and read the
# `TFX_ARTIFACT_MODEL_EVAL` / `TFX_ARTIFACT_EXAMPLE_STATS` constants, none of
# which are defined in this notebook -- presumably they are provided by this
# run of `mlmd_utils.ipynb`; confirm.
%run 'mlmd_utils.ipynb'

In [70]:
# Install pip packages required for query workflows.
# Each package is installed with the kernel's own interpreter
# (`sys.executable`), so `sys` must already be imported when this cell runs.
# NOTE(review): `networkx` and `matplotlib` are imported by the first cell of
# this notebook, so on a fresh environment this cell has to be executed before
# that one.
!{sys.executable} -m pip install -q networkx
!{sys.executable} -m pip install -q matplotlib
# NOTE(review): the ml_metadata wheel is read from a user-specific absolute
# path and is a cp27 (Python 2.7) linux build -- it will not install elsewhere.
!{sys.executable} -m pip install -q /google/data/rw/users/ma/martinz/ml_metadata-0.12.0.dev0-cp27-cp27mu-linux_x86_64.whl
!{sys.executable} -m pip install -q tensorboard
!{sys.executable} -m pip install -q papermill --user
# `pm` is used by `display_tensorboard` to execute the spawn notebook.
import papermill as pm

In [8]:
# Setup a read-only connection to the ML-Metadata store.
# `store` is a notebook-level global: every query helper defined in the cells
# below reads it directly instead of taking the store as a parameter.
# NOTE(review): `get_metadata_store` is not defined in this notebook --
# presumably it comes from the `%run 'mlmd_utils.ipynb'` cell; confirm.
store = get_metadata_store(
    connection_mode=metadata_store_pb2.SqliteMetadataSourceConfig.READONLY)

In [9]:
# Utils to display artifacts, executions.
def _get_value_str(p):
    """Returns a string representation of a metadata_store_pb2.Value object."""
    if p.int_value:
        return str(p.int_value)
    if p.string_value:
        return p.string_value
    if p.double_value:
        return str(p.double_value)
    return ''


def get_df_from_artifacts_or_executions(objects, is_artifact):
    """Builds a `pd.DataFrame` with one row per artifact/execution in `objects`.

    Rows are indexed by object id (the index is named 'ID'). Columns are the
    upper-cased property and custom-property names, plus 'URI' when
    `is_artifact` is true. Cells an object has no value for are shown as '-'.
    """
    rows_by_id = {}
    for obj in objects:
        row = {'URI': obj.uri} if is_artifact else {}
        # Custom properties are applied last, so they win on a name clash.
        for prop_map in (obj.properties, obj.custom_properties):
            for name in prop_map:
                row[name.upper()] = _get_value_str(prop_map[name])
        rows_by_id[obj.id] = row
    frame = pd.DataFrame.from_dict(data=rows_by_id, orient='index')
    frame = frame.fillna('-')
    frame.index.name = 'ID'
    return frame


def _get_df_from_single_artifact_or_execution(obj, is_artifact):
    """Returns a `pd.DataFrame` based on properties of an artifact/execution `obj`."""
    data = {}
    if is_artifact:
        data['URI'] = obj.uri
    for p in obj.properties:
        data[p.upper()] = _get_value_str(obj.properties[p])
    for p in obj.custom_properties:
        data[p.upper()] = _get_value_str(obj.custom_properties[p])
    return pd.DataFrame.from_dict(data=data, orient='index', columns=['']).fillna('-')


def get_artifact_df(artifact_id):
    """Looks up `artifact_id` in `store` and returns its fields as a `pd.DataFrame`."""
    matches = store.get_artifacts_by_id([artifact_id])
    # Unpacking enforces that exactly one artifact came back for the single id.
    (artifact,) = matches
    return _get_df_from_single_artifact_or_execution(artifact, is_artifact=True)


def get_execution_df(execution_id):
    """Looks up `execution_id` in `store` and returns its fields as a `pd.DataFrame`."""
    matches = store.get_executions_by_id([execution_id])
    # Unpacking enforces that exactly one execution came back for the single id.
    (execution,) = matches
    return _get_df_from_single_artifact_or_execution(execution, is_artifact=False)


def get_artifacts_of_type_df(type_name):
    """Returns a `pd.DataFrame` of all artifacts whose type is `type_name`."""
    artifacts = store.get_artifacts_by_type(type_name)
    return get_df_from_artifacts_or_executions(artifacts, is_artifact=True)


def get_executions_of_type_df(type_name):
    """Returns a `pd.DataFrame` of all executions whose type is `type_name`."""
    executions = store.get_executions_by_type(type_name)
    return get_df_from_artifacts_or_executions(executions, is_artifact=False)

In [66]:
# Lineage query utils.

def get_input_artifact(artifact_id, input_type_name):
    """Returns the artifact of type `input_type_name` that directly/indirectly generated `artifact_id`.

    The lineage is searched depth-first, walking from an artifact to the
    executions that declared it as an output (DECLARED_OUTPUT events) and from
    those executions to the artifacts they declared as inputs (DECLARED_INPUT
    events).

    Args:
        artifact_id: id of the artifact whose ancestry is searched.
        input_type_name: artifact type name to look for.

    Returns:
        The first matching upstream artifact found, or None if there is none.
    """
    # Artifacts whose ancestry has already been (or is being) explored. This
    # avoids re-walking shared ancestors in diamond-shaped lineage and protects
    # against unbounded recursion should the store ever contain a cycle.
    visited = set()

    def _search(current_id):
        if current_id in visited:
            return None
        visited.add(current_id)
        for a_event in store.get_events_by_artifact_ids([current_id]):
            if a_event.type != metadata_store_pb2.Event.DECLARED_OUTPUT:
                continue
            # The event already carries the execution id, so the execution
            # itself does not need to be fetched.
            e_events = store.get_events_by_execution_ids([a_event.execution_id])
            for e_event in e_events:
                if e_event.type != metadata_store_pb2.Event.DECLARED_INPUT:
                    continue
                [artifact] = store.get_artifacts_by_id([e_event.artifact_id])
                [artifact_type] = store.get_artifact_types_by_id([artifact.type_id])
                if artifact_type.name == input_type_name:
                    return artifact
                found = _search(artifact.id)
                if found is not None:
                    return found
        return None

    return _search(artifact_id)


def get_output_artifact(artifact_id, output_type_name):
    """Returns the artifact of type `output_type_name` that was directly/indirectly generated from `artifact_id`.

    The lineage is searched depth-first, walking from an artifact to the
    executions that declared it as an input (DECLARED_INPUT events) and from
    those executions to the artifacts they declared as outputs
    (DECLARED_OUTPUT events).

    Args:
        artifact_id: id of the artifact whose descendants are searched.
        output_type_name: artifact type name to look for.

    Returns:
        The first matching downstream artifact found, or None if there is none.
    """
    # Artifacts whose descendants have already been (or are being) explored.
    # This avoids re-walking shared descendants in diamond-shaped lineage and
    # protects against unbounded recursion should the store contain a cycle.
    visited = set()

    def _search(current_id):
        if current_id in visited:
            return None
        visited.add(current_id)
        for a_event in store.get_events_by_artifact_ids([current_id]):
            if a_event.type != metadata_store_pb2.Event.DECLARED_INPUT:
                continue
            # The event already carries the execution id, so the execution
            # itself does not need to be fetched.
            e_events = store.get_events_by_execution_ids([a_event.execution_id])
            for e_event in e_events:
                if e_event.type != metadata_store_pb2.Event.DECLARED_OUTPUT:
                    continue
                [artifact] = store.get_artifacts_by_id([e_event.artifact_id])
                [artifact_type] = store.get_artifact_types_by_id([artifact.type_id])
                if artifact_type.name == output_type_name:
                    return artifact
                found = _search(artifact.id)
                if found is not None:
                    return found
        return None

    return _search(artifact_id)
            

def get_execution_for_output_artifact(artifact_id, execution_type_name):
    """Returns the execution of type `execution_type_name` that generated `artifact_id`.

    Only executions linked to the artifact through a DECLARED_OUTPUT event are
    considered. Returns None when no such execution has the requested type.
    """
    for event in store.get_events_by_artifact_ids([artifact_id]):
        if event.type == metadata_store_pb2.Event.DECLARED_OUTPUT:
            (execution,) = store.get_executions_by_id([event.execution_id])
            (execution_type,) = store.get_execution_types_by_id([execution.type_id])
            if execution_type.name == execution_type_name:
                return execution
    return None


In [1]:
def display_tfma_analysis(model_id, slicing_column=None):
    """Renders the TFMA slicing-metrics view for `model_id`.

    The eval result is loaded from the URI of the `TFX_ARTIFACT_MODEL_EVAL`
    artifact found downstream of the model; `slicing_column` is forwarded to
    `tfma.view.render_slicing_metrics`.
    """
    eval_artifact = get_output_artifact(model_id, TFX_ARTIFACT_MODEL_EVAL)
    eval_result = tfma.load_eval_result(eval_artifact.uri)
    return tfma.view.render_slicing_metrics(eval_result, slicing_column=slicing_column)

def compare_tfma_analysis(model_id, other_model_id):
    """Renders a TFMA time-series view comparing `model_id` with `other_model_id`.

    For each model, the eval result is loaded from the URI of the
    `TFX_ARTIFACT_MODEL_EVAL` artifact found downstream of it; the two results
    are combined in model-centric mode.
    """
    loaded_results = []
    for current_model_id in (model_id, other_model_id):
        eval_artifact = get_output_artifact(current_model_id, TFX_ARTIFACT_MODEL_EVAL)
        loaded_results.append(tfma.load_eval_result(eval_artifact.uri))
    eval_results = tfma.make_eval_results(
        loaded_results, tfma.constants.MODEL_CENTRIC_MODE)
    overall_slice = tfma.slicer.slicer.SingleSliceSpec()
    return tfma.view.render_time_series(eval_results, overall_slice)

def display_data_stats_for_model(model_id, other_model_id=None):
    """Visualizes stats for data that generated `model_id` and optionally `other_model_id`.

    For each model, the upstream `TFX_ARTIFACT_EXAMPLE_STATS` artifact is
    located and its statistics are loaded from the "/eval" sub-path of the
    artifact URI. When `other_model_id` is given, both sets of statistics are
    shown side by side.

    Args:
        model_id: id of the model artifact whose input data stats are shown.
        other_model_id: optional id of a second model artifact to compare with.
    """
    def _load_eval_statistics(current_model_id):
        # Both sides must be read from the same sub-path; previously the
        # comparison side was loaded from the bare artifact URI.
        stats_artifact = get_input_artifact(current_model_id, TFX_ARTIFACT_EXAMPLE_STATS)
        return tfdv.load_statistics(stats_artifact.uri + "/eval")

    lhs_statistics = _load_eval_statistics(model_id)
    rhs_statistics = None
    if other_model_id:
        rhs_statistics = _load_eval_statistics(other_model_id)
    tfdv.visualize_statistics(
        lhs_statistics,
        rhs_statistics=rhs_statistics,
        lhs_name='Model {}\'s data'.format(model_id),
        rhs_name='Model {}\'s data'.format(other_model_id) if other_model_id else None)
    
def display_tensorboard(model_id, other_model_id=None):
    """Opens up Tensorboard for `model_id` and optionally `other_model_id`.

    TensorBoard is started by executing `spawn_tensorboard.ipynb` through
    papermill, with the models' URIs as the logdir and a log file under $HOME
    that the spawned notebook is told to write to (`tb_run_log`).

    Args:
        model_id: id of the model artifact to show.
        other_model_id: optional id of a second model artifact to show.

    Returns:
        The 4th space-separated token of the first log line containing
        'TensorBoard' (presumably the served URL -- TODO confirm against the
        TensorBoard startup message), or None if no such line was logged.
    """
    # The conditional must be parenthesized: without it, it applies to the
    # whole concatenation and a single model yields an empty id list.
    model_ids = [str(model_id)] + ([str(other_model_id)] if other_model_id else [])
    log_filename = os.path.join(
        os.environ['HOME'],
        'tensorboard_model_ids_{}_log.txt'.format('-'.join(model_ids)))

    [model] = store.get_artifacts_by_id([model_id])
    logdir_arg = 'model_{}:{}'.format(model_id, model.uri)
    if other_model_id:
        [other_model] = store.get_artifacts_by_id([other_model_id])
        logdir_arg += ',model_{}:{}'.format(other_model_id, other_model.uri)

    pm.execute_notebook(
        'spawn_tensorboard.ipynb',
        'spawn_tensorboard_output.ipynb',
        parameters=dict(tb_logdir=logdir_arg, tb_run_log=log_filename),
        progress_bar=False)
    time.sleep(5)  # Give it some time for log_filename to be flushed.
    with open(log_filename) as log_file:
        for line in log_file:
            if 'TensorBoard' in line:
                return line.split(' ')[3]
    return None

In [10]:
# Internal utils used to compute lineage DAG.

def _find_upstream_executions(artifact_id):
    """Returns the ids of executions linked to `artifact_id` by an output event.

    Both DECLARED_OUTPUT and OUTPUT events count; ids are returned in the
    order the store reports the events.
    """
    output_event_types = [
        metadata_store_pb2.Event.DECLARED_OUTPUT, metadata_store_pb2.Event.OUTPUT]
    return [
        event.execution_id
        for event in store.get_events_by_artifact_ids([artifact_id])
        if event.type in output_event_types
    ]

def _find_upstream_artifacts(execution_id):
    """Returns the ids of artifacts linked to `execution_id` by an input event.

    Both DECLARED_INPUT and INPUT events count; ids are returned in the order
    the store reports the events.
    """
    input_event_types = [
        metadata_store_pb2.Event.DECLARED_INPUT, metadata_store_pb2.Event.INPUT]
    return [
        event.artifact_id
        for event in store.get_events_by_execution_ids([execution_id])
        if event.type in input_event_types
    ]

def _add_node_attribute(g, node_id, depth, is_artifact):
    """Adds (or updates) the graph node for `node_id` and labels it with its type name.

    Artifacts use `node_id` as the graph node id, executions use the negated
    id. The node gets `depth` and `is_artifact` attributes, plus a `_label_`
    attribute holding the name of its artifact/execution type as looked up in
    `store`.
    """
    gnode_id = node_id if is_artifact else -node_id
    g.add_node(gnode_id, depth=depth, is_artifact=is_artifact)
    if is_artifact:
        (artifact,) = store.get_artifacts_by_id([node_id])
        (node_type,) = store.get_artifact_types_by_id([artifact.type_id])
    else:
        (execution,) = store.get_executions_by_id([node_id])
        (node_type,) = store.get_execution_types_by_id([execution.type_id])
    g.nodes[gnode_id]['_label_'] = "" + node_type.name

def _add_parents(g, node_id, is_artifact, depth, max_depth=None):
    """Recursively adds `node_id` and everything upstream of it to graph `g`.

    Artifacts and executions alternate along the recursion: an artifact's
    parents are the executions from `_find_upstream_executions`, an
    execution's parents are the artifacts from `_find_upstream_artifacts`.
    Edges point from the upstream node to the downstream node.

    Args:
        g: networkx DiGraph being populated (mutated in place).
        node_id: store id of the artifact or execution to add.
        is_artifact: whether `node_id` refers to an artifact.
        depth: depth recorded on the node; each recursion level passes
            `depth + 1`.
        max_depth: if not None, nodes with `depth > max_depth` are still added
            to the graph but are not expanded further.
    """
    # if it is not an artifact, use negative gnode id
    gnode_id = node_id if is_artifact else -1 * node_id
    # This runs before the "already expanded" check below, so revisiting a
    # node repeats the store lookups and overwrites its `depth` attribute with
    # the depth of the latest visit.
    _add_node_attribute(g, node_id, depth, is_artifact)
    # Incoming edges are only ever added by the expansion loops below, so a
    # node that already has some was expanded on an earlier visit.
    if gnode_id in g and len(g.in_edges(gnode_id)) > 0: 
        return
    if max_depth is not None and depth > max_depth:
        return
    if is_artifact:
        for e_id in _find_upstream_executions(node_id):
            # Edge: producing execution (negative id) -> this artifact.
            g.add_edge(e_id * -1, node_id)
            _add_parents(g, e_id, not is_artifact, depth + 1, max_depth)
    else:
        for a_id in _find_upstream_artifacts(node_id):
            # Edge: input artifact -> this execution (negative id).
            g.add_edge(a_id, node_id * -1)
            _add_parents(g, a_id, not is_artifact, depth + 1, max_depth)

def _construct_artifact_lineage(artifact_id, max_depth=None):
    """Builds the lineage of `artifact_id` as a networkx DiGraph.

    The graph records the queried id in its `query_artifact_id` graph
    attribute. When `max_depth` is given and is not positive, the graph is
    returned empty; otherwise it is populated starting from the artifact at
    depth 1.
    """
    lineage = nx.DiGraph(query_artifact_id=artifact_id)
    if max_depth is not None and not max_depth > 0:
        return lineage
    _add_parents(lineage, artifact_id, True, 1, max_depth)
    return lineage

In [11]:
# Generic utils to get and plot artifact lineage.

def get_artifact_lineage(artifact_id, max_depth=None):
    """Returns the lineage of `artifact_id` as a DAG (a networkx DiGraph).

    `max_depth` is passed through to `_construct_artifact_lineage`, which
    does the actual work.
    """
    lineage_dag = _construct_artifact_lineage(artifact_id, max_depth)
    return lineage_dag

def plot_artifact_lineage(g):
    """Plots a lineage graph with networkx and matplotlib.

    Artifacts are placed left to right (ordered by id) on a lower row and
    executions on an upper row, giving a bipartite layout. Artifact nodes are
    drawn in cyan and execution nodes in magenta, each labelled with its id.
    Next to every real node a white auxiliary "label anchor" node is drawn
    whose label is the real node's '_label_' attribute.

    Args:
        g: a networkx DiGraph in which artifact nodes use a positive id,
            execution nodes a negative id, and every node carries a '_label_'
            attribute (the shape produced by `get_artifact_lineage`). An empty
            graph is accepted and results in nothing being drawn.
    """
    if len(g) == 0:
        # Nothing to draw (e.g. a lineage requested with max_depth=0).
        return

    # Real nodes occupy ids in (-label_anchor_id, label_anchor_id); label
    # anchors live outside that range. Keep the historical offset of 10000 but
    # grow it when the graph contains larger ids, so the ranges never overlap.
    label_anchor_id = max(10000, max(abs(n) for n in g.nodes) + 1)

    # make a copy of the graph; add auxiliary label-anchor nodes
    dag = g.copy(as_view=False)
    for node_id in g.nodes:
        if node_id > 0:
            dag.add_node(label_anchor_id + node_id)
        else:
            dag.add_node(node_id - label_anchor_id)

    # Assign a color and label to every node (in graph iteration order, which
    # is the order nx.draw uses) and collect the real artifact/execution nodes.
    node_color = []
    node_labels = {}
    a_nodes = []
    e_nodes = []
    for node_id in dag.nodes:
        if node_id > 0 and node_id < label_anchor_id:
            node_color.append('c')
            node_labels[node_id] = abs(node_id)
            a_nodes.append(node_id)
        elif node_id > 0 and node_id >= label_anchor_id:
            node_color.append('w')
            node_labels[node_id] = dag.nodes[node_id - label_anchor_id]['_label_']
        elif node_id < 0 and node_id > -1 * label_anchor_id:
            node_color.append('m')
            node_labels[node_id] = abs(node_id)
            e_nodes.append(node_id)
        else:
            node_color.append('w')
            node_labels[node_id] = dag.nodes[node_id + label_anchor_id]['_label_']

    a_nodes.sort(key=abs)
    e_nodes.sort(key=abs)

    # Lay out both rows; each label anchor sits slightly below (artifacts) or
    # above (executions) the node it describes.
    pos = {}
    a_node_y = 0
    e_node_y = 0.035
    a_offset = -0.5 if len(a_nodes) % 2 == 0 else 0
    e_offset = -0.5 if len(e_nodes) % 2 == 0 else 0
    # NOTE: expressions kept exactly as before so the layout is unchanged.
    a_node_x_min = -1 * len(a_nodes)/2 + a_offset
    e_node_x_min = -1 * len(e_nodes)/2 + e_offset
    a_node_x = a_node_x_min
    e_node_x = e_node_x_min
    node_step = 1
    for a_id in a_nodes:
        pos[a_id] = [a_node_x, a_node_y]
        pos[a_id + label_anchor_id] = [a_node_x, a_node_y - 0.01]
        a_node_x += node_step
    for e_id in e_nodes:
        pos[e_id] = [e_node_x, e_node_y]
        pos[e_id - label_anchor_id] = [e_node_x, e_node_y + 0.01]
        e_node_x += node_step

    nx.draw(dag, pos=pos,
            node_size=500, node_color=node_color,
            labels=node_labels, node_shape='o', font_size=8.3, label="abc")

    # Hand-drawn legend placed between the two rows, near the right edge.
    legend_x = max(a_node_x, e_node_x) - 0.85
    legend_y = 0.02
    a_bbox_props = dict(boxstyle="square,pad=0.3", fc="c", ec="b", lw=0)
    plt.text(legend_x - 0.0025, legend_y, "  Artifacts  ", bbox=a_bbox_props)
    e_bbox_props = dict(boxstyle="square,pad=0.3", fc="m", ec="b", lw=0)
    plt.text(legend_x - 0.0025, legend_y - 0.007, "Executions", bbox=e_bbox_props)

    # max(..., 1) guards the division for a graph without artifact nodes.
    x_lim_left = max(-2 - 1.5 / max(len(a_nodes), 1),
                     min(a_node_x_min, e_node_x_min) - 1.0)
    x_lim_right = max(a_node_x, e_node_x) + 0.1
    plt.xlim(x_lim_left, x_lim_right)

    plt.show()

def get_and_plot_artifact_lineage(artifact_id, max_depth=None):
    """Builds the lineage DAG of `artifact_id` (limited by `max_depth`) and plots it."""
    lineage = get_artifact_lineage(artifact_id, max_depth=max_depth)
    plot_artifact_lineage(lineage)