diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py
index f6966d9fa3..659f88a59c 100644
--- a/src/sagemaker/lineage/query.py
+++ b/src/sagemaker/lineage/query.py
@@ -16,6 +16,7 @@
from datetime import datetime
from enum import Enum
from typing import Optional, Union, List, Dict
+import re
from sagemaker.lineage._utils import get_resource_name_from_arn, get_module
@@ -260,6 +261,8 @@ def __init__(self, graph_styles):
(
self.Network,
self.Options,
+ self.IFrame,
+ self.BeautifulSoup,
) = self._import_visual_modules()
self.graph_styles = graph_styles
@@ -300,13 +303,60 @@ def _import_visual_modules(self):
get_module("pyvis")
from pyvis.network import Network
from pyvis.options import Options
+ from IPython.display import IFrame
- return Network, Options
+ get_module("bs4")
+ from bs4 import BeautifulSoup
+
+ return Network, Options, IFrame, BeautifulSoup
def _node_color(self, entity):
"""Return node color by background-color specified in graph styles."""
return self.graph_styles[entity]["style"]["background-color"]
+ def _get_legend_line(self, component_name):
+ """Generate lengend div line for each graph component in graph_styles."""
+ if self.graph_styles[component_name]["isShape"] == "False":
+ return '
'.format(
+ color=self.graph_styles[component_name]["style"]["background-color"],
+ name=self.graph_styles[component_name]["name"],
+ )
+ else:
+ return '{shape}
\
+ \
+ {name}
'.format(
+ shape=self.graph_styles[component_name]["style"]["shape"],
+ name=self.graph_styles[component_name]["name"],
+ )
+
+ def _add_legend(self, path):
+ """Embed legend to html file generated by pyvis."""
+ f = open(path, "r")
+ content = self.BeautifulSoup(f, "html.parser")
+
+ legend = """
+
+ """
+ # iterate through graph styles to get legend
+ for component in self.graph_styles.keys():
+ legend += self._get_legend_line(component_name=component)
+
+ legend += "
"
+
+ legend_div = self.BeautifulSoup(legend, "html.parser")
+
+ content.div.insert_after(legend_div)
+
+ html = content.prettify()
+
+ with open(path, "w", encoding="utf8") as file:
+ file.write(html)
+
def render(self, elements, path="pyvisExample.html"):
"""Render graph for lineage query result.
@@ -325,23 +375,51 @@ def render(self, elements, path="pyvisExample.html"):
display graph: The interactive visualization is presented as a static HTML file.
"""
- net = self.Network(height="500px", width="100%", notebook=True, directed=True)
+ net = self.Network(height="600px", width="82%", notebook=True, directed=True)
net.set_options(self._options)
# add nodes to graph
for arn, source, entity, is_start_arn in elements["nodes"]:
+ entity_text = re.sub(r"(\w)([A-Z])", r"\1 \2", entity)
+ source = re.sub(r"(\w)([A-Z])", r"\1 \2", source)
+ account_id = re.search(r":\d{12}:", arn)
+ name = re.search(r"\/.*", arn)
+ node_info = (
+ "Entity: "
+ + entity_text
+ + "\nType: "
+ + source
+ + "\nAccount ID: "
+ + str(account_id.group()[1:-1])
+ + "\nName: "
+ + str(name.group()[1:])
+ )
if is_start_arn: # startarn
net.add_node(
- arn, label=source, title=entity, color=self._node_color(entity), shape="star"
+ arn,
+ label=source,
+ title=node_info,
+ color=self._node_color(entity),
+ shape="star",
+ borderWidth=3,
)
else:
- net.add_node(arn, label=source, title=entity, color=self._node_color(entity))
+ net.add_node(
+ arn,
+ label=source,
+ title=node_info,
+ color=self._node_color(entity),
+ borderWidth=3,
+ )
# add edges to graph
for src, dest, asso_type in elements["edges"]:
- net.add_edge(src, dest, title=asso_type)
+ net.add_edge(src, dest, title=asso_type, width=2)
+
+ net.write_html(path)
+ self._add_legend(path)
- return net.show(path)
+ return self.IFrame(path, width="100%", height="600px")
class LineageQueryResult(object):
@@ -391,7 +469,7 @@ def __str__(self):
"""
return (
- "{\n"
+ "{"
+ "\n\n".join("'{}': {},".format(key, val) for key, val in self.__dict__.items())
+ "\n}"
)
diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py
index 555b98452e..4b9e816623 100644
--- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py
+++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py
@@ -14,6 +14,7 @@
from __future__ import absolute_import
import time
import os
+import re
import pytest
@@ -160,31 +161,56 @@ def test_graph_visualize(sagemaker_session, extract_data_from_html):
"color": "#146eb4",
"label": "Model",
"shape": "star",
- "title": "Artifact",
+ "title": "Entity: Artifact"
+ + "\nType: Model"
+ + "\nAccount ID: "
+ + str(re.search(r":\d{12}:", graph_startarn).group()[1:-1])
+ + "\nName: "
+ + str(re.search(r"\/.*", graph_startarn).group()[1:]),
},
image_artifact: {
"color": "#146eb4",
"label": "Image",
"shape": "dot",
- "title": "Artifact",
+ "title": "Entity: Artifact"
+ + "\nType: Image"
+ + "\nAccount ID: "
+ + str(re.search(r":\d{12}:", image_artifact).group()[1:-1])
+ + "\nName: "
+ + str(re.search(r"\/.*", image_artifact).group()[1:]),
},
dataset_artifact: {
"color": "#146eb4",
- "label": "DataSet",
+ "label": "Data Set",
"shape": "dot",
- "title": "Artifact",
+ "title": "Entity: Artifact"
+ + "\nType: Data Set"
+ + "\nAccount ID: "
+ + str(re.search(r":\d{12}:", dataset_artifact).group()[1:-1])
+ + "\nName: "
+ + str(re.search(r"\/.*", dataset_artifact).group()[1:]),
},
modeldeploy_action: {
"color": "#88c396",
- "label": "ModelDeploy",
+ "label": "Model Deploy",
"shape": "dot",
- "title": "Action",
+ "title": "Entity: Action"
+ + "\nType: Model Deploy"
+ + "\nAccount ID: "
+ + str(re.search(r":\d{12}:", modeldeploy_action).group()[1:-1])
+ + "\nName: "
+ + str(re.search(r"\/.*", modeldeploy_action).group()[1:]),
},
endpoint_context: {
"color": "#ff9900",
"label": "Endpoint",
"shape": "dot",
- "title": "Context",
+ "title": "Entity: Context"
+ + "\nType: Endpoint"
+ + "\nAccount ID: "
+ + str(re.search(r":\d{12}:", endpoint_context).group()[1:-1])
+ + "\nName: "
+ + str(re.search(r"\/.*", endpoint_context).group()[1:]),
},
}
diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py
index 0a357eb1fc..bac5cb6cdb 100644
--- a/tests/unit/sagemaker/lineage/test_query.py
+++ b/tests/unit/sagemaker/lineage/test_query.py
@@ -580,7 +580,7 @@ def test_query_lineage_result_str(sagemaker_session):
assert (
response_str
- == "{\n'edges': [\n\t{'source_arn': 'arn1', 'destination_arn': 'arn2', 'association_type': 'Produced'}],"
+ == "{'edges': [\n\t{'source_arn': 'arn1', 'destination_arn': 'arn2', 'association_type': 'Produced'}],"
+ "\n\n'vertices': [\n\t{'arn': 'arn1', 'lineage_entity': 'Artifact', 'lineage_source': 'Endpoint', "
+ "'_session': }, \n\t{'arn': 'arn2', 'lineage_entity': 'Context', 'lineage_source': "
+ "'Model', '_session': }],\n\n'startarn': "