# LineageTableVisualizer

In [None]:
import sagemaker
from sagemaker.lineage.visualizer import LineageTableVisualizer

# Create one SageMaker session and reuse its low-level client, so the client
# and the session handed to the visualizers share the same configuration.
sagemaker_session = sagemaker.session.Session()
sm_client = sagemaker_session.sagemaker_client

In [None]:
# Table-based lineage visualizer bound to the session created above.
viz = LineageTableVisualizer(sagemaker_session=sagemaker_session)

In [None]:
# Look up the most recently created execution of the pipeline.
pipeline_name = "pytorch-nlp-Pipeline"
execution_summaries = sm_client.list_pipeline_executions(
    PipelineName=pipeline_name,
    SortBy="CreationTime",
    SortOrder="Descending",
    MaxResults=1,
)["PipelineExecutionSummaries"]

# Fail with a readable message instead of a bare IndexError when the
# pipeline has never been executed.
if not execution_summaries:
    raise RuntimeError(f"No executions found for pipeline {pipeline_name!r}")

latest_execution_arn = execution_summaries[0]["PipelineExecutionArn"]

In [None]:
# List the steps of the selected execution in ascending order.
steps_response = sm_client.list_pipeline_execution_steps(
    PipelineExecutionArn=latest_execution_arn,
    SortOrder="Ascending",
)
pipeline_execution_steps = steps_response["PipelineExecutionSteps"]

In [None]:
# Print each step's name followed by its lineage table.
for execution_step in pipeline_execution_steps:
    step_name = execution_step["StepName"]
    print(step_name)
    lineage_table = viz.show(pipeline_execution_step=execution_step)
    display(lineage_table)

# PyVis
Note: by default, the links between execution steps are missing. Use this once SageMaker Experiments has been integrated into the pipeline workflow.

In [None]:
!pip install pyvis

In [None]:
from pyvis.network import Network
import os
import pprint as pp


class Visualizer:
    """Render a ``query_lineage`` response as an interactive PyVis graph.

    The generated HTML files are written to ``self.directory``, which is
    created on construction if it does not already exist.
    """

    def __init__(self):
        self.directory = "output"
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(self.directory, exist_ok=True)

    def render(self, query_lineage_response, scenario_name):
        """Draw the vertices and edges of a lineage query and save them as HTML.

        Args:
            query_lineage_response: Mapping with a ``"Vertices"`` list (each
                item has ``"Arn"``, ``"LineageType"`` and optionally
                ``"Type"``) and an ``"Edges"`` list (each item has
                ``"SourceArn"`` and ``"DestinationArn"``).
            scenario_name: Base name of the HTML file written inside
                ``self.directory``.

        Returns:
            Whatever ``Network.show`` returns for the generated file.
        """
        net = self.get_network()

        for vertex in query_lineage_response["Vertices"]:
            arn = vertex["Arn"]
            net.add_node(
                arn,
                label=self._node_label(vertex),
                title=self.get_title(arn),
                shape="circle",
            )

        # "Edges" may be missing if the query was issued without edges; treat
        # that as a graph with no connections rather than failing.
        for edge in query_lineage_response.get("Edges", []):
            # Argument order is kept from the original implementation:
            # the arrow is drawn from DestinationArn to SourceArn.
            net.add_edge(edge["DestinationArn"], edge["SourceArn"])

        return net.show(f"{self.directory}/{scenario_name}.html")

    @staticmethod
    def _node_label(vertex):
        """Build a node label from a vertex, tolerating a missing ``"Type"``."""
        lineage_type = vertex["LineageType"]
        vertex_type = vertex.get("Type")
        if vertex_type is None:
            # Without a type, fall back to the lineage type alone instead of
            # failing on ``None + str``.
            return lineage_type
        return f"{vertex_type}\n{lineage_type}"

    def get_title(self, arn):
        """Return the hover text shown for the node with the given ARN."""
        return f"Arn: {arn}"

    def get_name(self, arn):
        """Return the segment that follows the first ``/`` in ``arn``."""
        name = arn.split("/")[1]
        return name

    def get_network(self):
        """Create and return a directed, notebook-enabled PyVis network.

        The options configure a left-to-right hierarchical layout with
        arrows on the edges.
        """
        net = Network(height="400px", width="800px", directed=True, notebook=True)
        net.set_options(
            """
        var options = {
  "nodes": {
    "borderWidth": 3,
    "shadow": {
      "enabled": true
    },
    "shapeProperties": {
      "borderRadius": 3
    },
    "size": 11,
    "shape": "circle"
  },
  "edges": {
    "arrows": {
      "to": {
        "enabled": true
      }
    },
    "color": {
      "inherit": true
    },
    "smooth": false
  },
  "layout": {
    "hierarchical": {
      "enabled": true,
      "direction": "LR",
      "sortMethod": "directed"
    }
  },
  "physics": {
    "hierarchicalRepulsion": {
      "centralGravity": 0
    },
    "minVelocity": 0.75,
    "solver": "hierarchicalRepulsion"
  }
}
        """
        )
        # The configured network must be returned; render() adds nodes to it.
        return net

In [None]:
from sagemaker.lineage.context import Context, EndpointContext
from sagemaker.lineage.action import Action
from sagemaker.lineage.association import Association
from sagemaker.lineage.artifact import Artifact, ModelArtifact, DatasetArtifact

from sagemaker.lineage.query import (
    LineageQuery,
    LineageFilter,
    LineageSourceEnum,
    LineageEntityEnum,
    LineageQueryDirectionEnum,
)

# Find the endpoint context used as the starting point for the lineage queries.
# The ARN below is a template: fill in the placeholders for your endpoint.
endpoint_arn = "arn:aws:sagemaker:<region>:<>:endpoint/lambda-deploy-endpoint"

# Only the first matching context is needed, so take it lazily instead of
# materialising every result with list(...).
contexts = Context.list(source_uri=endpoint_arn)
first_context = next(iter(contexts), None)
if first_context is None:
    # Clearer than the bare IndexError raised by indexing an empty list.
    raise RuntimeError(f"No lineage context found for endpoint {endpoint_arn!r}")

context_name = first_context.context_name
endpoint_context = EndpointContext.load(context_name=context_name)

In [None]:
# Graph APIs
# Call the boto3 `query_lineage` API starting from the endpoint context,
# then plot the response with the PyVis-based Visualizer.
start_arns = [endpoint_context.context_arn]
query_response = sm_client.query_lineage(
    StartArns=start_arns,
    Direction="Ascendants",
    IncludeEdges=True,
)

viz = Visualizer()
viz.render(query_response, "Endpoint")