diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index f6966d9fa3..659f88a59c 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -16,6 +16,7 @@ from datetime import datetime from enum import Enum from typing import Optional, Union, List, Dict +import re from sagemaker.lineage._utils import get_resource_name_from_arn, get_module @@ -260,6 +261,8 @@ def __init__(self, graph_styles): ( self.Network, self.Options, + self.IFrame, + self.BeautifulSoup, ) = self._import_visual_modules() self.graph_styles = graph_styles @@ -300,13 +303,60 @@ def _import_visual_modules(self): get_module("pyvis") from pyvis.network import Network from pyvis.options import Options + from IPython.display import IFrame - return Network, Options + get_module("bs4") + from bs4 import BeautifulSoup + + return Network, Options, IFrame, BeautifulSoup def _node_color(self, entity): """Return node color by background-color specified in graph styles.""" return self.graph_styles[entity]["style"]["background-color"] + def _get_legend_line(self, component_name): + """Generate legend div line for each graph component in graph_styles.""" + if self.graph_styles[component_name]["isShape"] == "False": + return '
\ +
\ +
{name}
'.format( + color=self.graph_styles[component_name]["style"]["background-color"], + name=self.graph_styles[component_name]["name"], + ) + else: + return '
{shape}
\ +
\ +
{name}
'.format( + shape=self.graph_styles[component_name]["style"]["shape"], + name=self.graph_styles[component_name]["name"], + ) + + def _add_legend(self, path): + """Embed legend to html file generated by pyvis.""" + f = open(path, "r") + content = self.BeautifulSoup(f, "html.parser") + + legend = """ +
+ """ + # iterate through graph styles to get legend + for component in self.graph_styles.keys(): + legend += self._get_legend_line(component_name=component) + + legend += "
" + + legend_div = self.BeautifulSoup(legend, "html.parser") + + content.div.insert_after(legend_div) + + html = content.prettify() + + with open(path, "w", encoding="utf8") as file: + file.write(html) + def render(self, elements, path="pyvisExample.html"): """Render graph for lineage query result. @@ -325,23 +375,51 @@ def render(self, elements, path="pyvisExample.html"): display graph: The interactive visualization is presented as a static HTML file. """ - net = self.Network(height="500px", width="100%", notebook=True, directed=True) + net = self.Network(height="600px", width="82%", notebook=True, directed=True) net.set_options(self._options) # add nodes to graph for arn, source, entity, is_start_arn in elements["nodes"]: + entity_text = re.sub(r"(\w)([A-Z])", r"\1 \2", entity) + source = re.sub(r"(\w)([A-Z])", r"\1 \2", source) + account_id = re.search(r":\d{12}:", arn) + name = re.search(r"\/.*", arn) + node_info = ( + "Entity: " + + entity_text + + "\nType: " + + source + + "\nAccount ID: " + + str(account_id.group()[1:-1]) + + "\nName: " + + str(name.group()[1:]) + ) if is_start_arn: # startarn net.add_node( - arn, label=source, title=entity, color=self._node_color(entity), shape="star" + arn, + label=source, + title=node_info, + color=self._node_color(entity), + shape="star", + borderWidth=3, ) else: - net.add_node(arn, label=source, title=entity, color=self._node_color(entity)) + net.add_node( + arn, + label=source, + title=node_info, + color=self._node_color(entity), + borderWidth=3, + ) # add edges to graph for src, dest, asso_type in elements["edges"]: - net.add_edge(src, dest, title=asso_type) + net.add_edge(src, dest, title=asso_type, width=2) + + net.write_html(path) + self._add_legend(path) - return net.show(path) + return self.IFrame(path, width="100%", height="600px") class LineageQueryResult(object): @@ -391,7 +469,7 @@ def __str__(self): """ return ( - "{\n" + "{" + "\n\n".join("'{}': {},".format(key, val) for key, val in self.__dict__.items()) + 
"\n}" ) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index 555b98452e..4b9e816623 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -14,6 +14,7 @@ from __future__ import absolute_import import time import os +import re import pytest @@ -160,31 +161,56 @@ def test_graph_visualize(sagemaker_session, extract_data_from_html): "color": "#146eb4", "label": "Model", "shape": "star", - "title": "Artifact", + "title": "Entity: Artifact" + + "\nType: Model" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", graph_startarn).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", graph_startarn).group()[1:]), }, image_artifact: { "color": "#146eb4", "label": "Image", "shape": "dot", - "title": "Artifact", + "title": "Entity: Artifact" + + "\nType: Image" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", image_artifact).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", image_artifact).group()[1:]), }, dataset_artifact: { "color": "#146eb4", - "label": "DataSet", + "label": "Data Set", "shape": "dot", - "title": "Artifact", + "title": "Entity: Artifact" + + "\nType: Data Set" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", dataset_artifact).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", dataset_artifact).group()[1:]), }, modeldeploy_action: { "color": "#88c396", - "label": "ModelDeploy", + "label": "Model Deploy", "shape": "dot", - "title": "Action", + "title": "Entity: Action" + + "\nType: Model Deploy" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", modeldeploy_action).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", modeldeploy_action).group()[1:]), }, endpoint_context: { "color": "#ff9900", "label": "Endpoint", "shape": "dot", - "title": "Context", + "title": "Entity: Context" + + "\nType: Endpoint" + + "\nAccount ID: " + + str(re.search(r":\d{12}:", 
endpoint_context).group()[1:-1]) + + "\nName: " + + str(re.search(r"\/.*", endpoint_context).group()[1:]), }, } diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index 0a357eb1fc..bac5cb6cdb 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -580,7 +580,7 @@ def test_query_lineage_result_str(sagemaker_session): assert ( response_str - == "{\n'edges': [\n\t{'source_arn': 'arn1', 'destination_arn': 'arn2', 'association_type': 'Produced'}]," + == "{'edges': [\n\t{'source_arn': 'arn1', 'destination_arn': 'arn2', 'association_type': 'Produced'}]," + "\n\n'vertices': [\n\t{'arn': 'arn1', 'lineage_entity': 'Artifact', 'lineage_source': 'Endpoint', " + "'_session': }, \n\t{'arn': 'arn2', 'lineage_entity': 'Context', 'lineage_source': " + "'Model', '_session': }],\n\n'startarn': "