diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 0acdccd698..f6966d9fa3 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -230,7 +230,32 @@ class PyvisVisualizer(object): """Create object used for visualizing graph using Pyvis library.""" def __init__(self, graph_styles): - """Init for PyvisVisualizer.""" + """Init for PyvisVisualizer. + + Args: + graph_styles: A dictionary that contains graph style for node and edges by their type. + Example: Display the nodes with different color by their lineage entity / different + shape by start arn. + lineage_graph = { + "TrialComponent": { + "name": "Trial Component", + "style": {"background-color": "#f6cf61"}, + "isShape": "False", + }, + "Context": { + "name": "Context", + "style": {"background-color": "#ff9900"}, + "isShape": "False", + }, + "StartArn": { + "name": "StartArn", + "style": {"shape": "star"}, + "isShape": "True", + "symbol": "★", # shape symbol for legend + }, + } + + """ # import visualization packages ( self.Network, @@ -283,7 +308,23 @@ def _node_color(self, entity): return self.graph_styles[entity]["style"]["background-color"] def render(self, elements, path="pyvisExample.html"): - """Render graph for lineage query result.""" + """Render graph for lineage query result. + + Args: + elements: A dictionary that contains the node and the edges of the graph. + Example: + elements["nodes"] contains list of tuples, each tuple represents a node + format: (node arn, node lineage source, node lineage entity, + node is start arn) + elements["edges"] contains list of tuples, each tuple represents an edge + format: (edge source arn, edge destination arn, edge association type) + + path(optional): The path/filemname of the rendered graph html file. (default path: "pyvisExample.html") + + Returns: + display graph: The interactive visualization is presented as a static HTML file. + + """ net = self.Network(height="500px", width="100%", notebook=True, directed=True) net.set_options(self._options) @@ -384,7 +425,19 @@ def _get_visualization_elements(self): return elements def visualize(self, path="pyvisExample.html"): - """Visualize lineage query result.""" + """Visualize lineage query result. + + Creates a PyvisVisualizer object to render network graph with Pyvis library. + Pyvis library should be installed before using this method (run "pip install pyvis") + The elements(nodes & edges) are preprocessed in this method and sent to + PyvisVisualizer for rendering graph. + + Args: + path(optional): The path/filemname of the rendered graph html file. (default path: "pyvisExample.html") + + Returns: + display graph: The interactive visualization is presented as a static HTML file. + """ lineage_graph = { # nodes can have shape / color "TrialComponent": {