diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index a198d1ebf5..1ea7baecb6 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -204,10 +204,18 @@ def _artifact_to_lineage_object(self): class DashVisualizer(object): """Create object used for visualizing graph using Dash library.""" - def __init__(self): + def __init__(self, graph_styles): """Init for DashVisualizer.""" # import visualization packages - self.cyto, self.JupyterDash, self.html = self._import_visual_modules() + ( + self.cyto, + self.JupyterDash, + self.html, + self.Input, + self.Output, + ) = self._import_visual_modules() + + self.graph_styles = graph_styles def _import_visual_modules(self): """Import modules needed for visualization.""" @@ -215,21 +223,70 @@ def _import_visual_modules(self): import dash_cytoscape as cyto except ImportError as e: print(e) - print("try pip install dash-cytoscape") + print("Try: pip install dash-cytoscape") + raise try: from jupyter_dash import JupyterDash except ImportError as e: print(e) - print("try pip install jupyter-dash") + print("Try: pip install jupyter-dash") + raise try: from dash import html except ImportError as e: print(e) - print("try pip install dash") + print("Try: pip install dash") + raise + + try: + from dash.dependencies import Input, Output + except ImportError as e: + print(e) + print("Try: pip install dash") + raise + + return cyto, JupyterDash, html, Input, Output + + def _create_legend_component(self, style): + """Create legend component div.""" + text = style["name"] + symbol = "" + color = "#ffffff" + if style["isShape"] == "False": + color = style["style"]["background-color"] + else: + symbol = style["symbol"] + return self.html.Div( + [ + self.html.Div( + symbol, + style={ + "background-color": color, + "width": "1.5vw", + "height": "1.5vw", + "display": "inline-block", + "font-size": "1.5vw", + }, + ), + self.html.Div( + style={ + "width": "0.5vw", + "height": "1.5vw", + "display": "inline-block", + } + ), + self.html.Div( + text, + style={"display": "inline-block", "font-size": "1.5vw"}, + ), + ] + ) - return cyto, JupyterDash, html + def _create_entity_selector(self, entity_name, style): + """Create selector for each lineage entity.""" + return {"selector": "." + entity_name, "style": style["style"]} def _get_app(self, elements): """Create JupyterDash app for interactivity on Jupyter notebook.""" @@ -238,10 +295,17 @@ def _get_app(self, elements): app.layout = self.html.Div( [ + # graph section self.cyto.Cytoscape( - id="cytoscape-layout-1", + id="cytoscape-graph", elements=elements, - style={"width": "100%", "height": "350px"}, + style={ + "width": "84%", + "height": "350px", + "display": "inline-block", + "border-width": "1vw", + "border-color": "#232f3e", + }, layout={"name": "klay"}, stylesheet=[ { @@ -251,6 +315,10 @@ def _get_app(self, elements): "font-size": "3.5vw", "height": "10vw", "width": "10vw", + "border-width": "0.8", + "border-opacity": "0", + "border-color": "#232f3e", + "font-family": "verdana", }, }, { @@ -259,23 +327,61 @@ def _get_app(self, elements): "label": "data(label)", "color": "gray", "text-halign": "left", - "text-margin-y": "3px", - "text-margin-x": "-2px", - "font-size": "3%", - "width": "1%", - "curve-style": "taxi", + "text-margin-y": "2.5", + "font-size": "3", + "width": "1", + "curve-style": "bezier", + "control-point-step-size": "15", "target-arrow-color": "gray", "target-arrow-shape": "triangle", "line-color": "gray", "arrow-scale": "0.5", + "font-family": "verdana", }, }, - ], + {"selector": ".select", "style": {"border-opacity": "0.7"}}, + ] + + [self._create_entity_selector(k, v) for k, v in self.graph_styles.items()], responsive=True, - ) + ), + self.html.Div( + style={ + "width": "0.5%", + "display": "inline-block", + "font-size": "1vw", + "font-family": "verdana", + "vertical-align": "top", + }, + ), + # legend section + self.html.Div( + [self._create_legend_component(v) for k, v in self.graph_styles.items()], + style={ + "display": "inline-block", + "font-size": "1vw", + "font-family": "verdana", + "vertical-align": "top", + }, + ), ] ) + @app.callback( + self.Output("cytoscape-graph", "elements"), + self.Input("cytoscape-graph", "tapNodeData"), + self.Input("cytoscape-graph", "elements"), + ) + def selectNode(tapData, elements): + for n in elements: + if tapData is not None and n["data"]["id"] == tapData["id"]: + # if is tapped node, add "select" class to node + n["classes"] += " select" + elif "classes" in n: + # remove "select" class in "classes" if node not selected + n["classes"] = n["classes"].replace("select", "") + + return elements + return app def render(self, elements, mode): @@ -292,6 +398,7 @@ def __init__( self, edges: List[Edge] = None, vertices: List[Vertex] = None, + startarn: List[str] = None, ): """Init for LineageQueryResult. @@ -301,6 +408,7 @@ def __init__( """ self.edges = [] self.vertices = [] + self.startarn = [] if edges is not None: self.edges = edges @@ -308,56 +416,67 @@ def __init__( if vertices is not None: self.vertices = vertices + if startarn is not None: + self.startarn = startarn + def __str__(self): """Define string representation of ``LineageQueryResult``. Format: { 'edges':[ - { + "{ 'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string' - }, + }", ... - ] + ], 'vertices':[ - { + "{ 'arn': 'string', 'lineage_entity': 'string', 'lineage_source': 'string', '_session': - }, + }", + ... + ], + 'startarn':[ + 'string', ... ] } """ result_dict = vars(self) - return str({k: [vars(val) for val in v] for k, v in result_dict.items()}) + return str({k: [str(val) for val in v] for k, v in result_dict.items()}) def _covert_vertices_to_tuples(self): """Convert vertices to tuple format for visualizer.""" verts = [] + # get vertex info in the form of (id, label, class) for vert in self.vertices: - verts.append((vert.arn, vert.lineage_source)) + if vert.arn in self.startarn: + # add "startarn" class to node if arn is a startarn + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity + " startarn")) + else: + verts.append((vert.arn, vert.lineage_source, vert.lineage_entity)) return verts def _covert_edges_to_tuples(self): """Convert edges to tuple format for visualizer.""" edges = [] + # get edge info in the form of (source, target, label) for edge in self.edges: edges.append((edge.source_arn, edge.destination_arn, edge.association_type)) return edges def _get_visualization_elements(self): """Get elements for visualization.""" + # get vertices and edges info for graph verts = self._covert_vertices_to_tuples() edges = self._covert_edges_to_tuples() nodes = [ - { - "data": {"id": id, "label": label}, - } - for id, label in verts + {"data": {"id": id, "label": label}, "classes": classes} for id, label, classes in verts ] edges = [ @@ -373,7 +492,38 @@ def visualize(self): """Visualize lineage query result.""" elements = self._get_visualization_elements() - dash_vis = DashVisualizer() + lineage_graph = { + # nodes can have shape / color + "TrialComponent": { + "name": "Trial Component", + "style": {"background-color": "#f6cf61"}, + "isShape": "False", + }, + "Context": { + "name": "Context", + "style": {"background-color": "#ff9900"}, + "isShape": "False", + }, + "Action": { + "name": "Action", + "style": {"background-color": "#88c396"}, + "isShape": "False", + }, + "Artifact": { + "name": "Artifact", + "style": {"background-color": "#146eb4"}, + "isShape": "False", + }, + "StartArn": { + "name": "StartArn", + "style": {"shape": "star"}, + "isShape": "True", + "symbol": "★", # shape symbol for legend + }, + } + + # initialize DashVisualizer instance to render graph & interactive components + dash_vis = DashVisualizer(lineage_graph) dash_server = dash_vis.render(elements=elements, mode="inline") @@ -453,9 +603,8 @@ def _get_vertex(self, vertex): sagemaker_session=self._session, ) - def _convert_api_response(self, response) -> LineageQueryResult: + def _convert_api_response(self, response, converted) -> LineageQueryResult: """Convert the lineage query API response to its Python representation.""" - converted = LineageQueryResult() converted.edges = [self._get_edge(edge) for edge in response["Edges"]] converted.vertices = [self._get_vertex(vertex) for vertex in response["Vertices"]] @@ -538,7 +687,9 @@ def query( Filters=query_filter._to_request_dict() if query_filter else {}, MaxDepth=max_depth, ) - query_response = self._convert_api_response(query_response) + # create query result for startarn info + query_result = LineageQueryResult(startarn=start_arns) + query_response = self._convert_api_response(query_response, query_result) query_response = self._collapse_cross_account_artifacts(query_response) return query_response