class DashVisualizer(object):
    """Create object used for visualizing graph using Dash library.

    The visualization packages (``dash-cytoscape``, ``jupyter-dash`` and
    ``dash``) are imported when the visualizer is constructed rather than at
    module import time, so they are only required by code that actually
    builds a ``DashVisualizer``.
    """

    def __init__(self):
        """Import the visualization packages and keep them on the instance.

        Raises:
            ImportError: If any of the visualization packages is missing.
        """
        self.cyto, self.JupyterDash, self.html = self._import_visual_modules()

    def _import_visual_modules(self):
        """Import modules needed for visualization.

        Every missing package is reported with an install hint before the
        method fails, so the user sees the full list in a single run.

        Returns:
            tuple: ``(cyto, JupyterDash, html)`` where ``cyto`` is the
            ``dash_cytoscape`` module, ``JupyterDash`` is imported from
            ``jupyter_dash`` and ``html`` is imported from ``dash``.

        Raises:
            ImportError: If one or more of the packages cannot be imported.
        """
        missing = []

        try:
            import dash_cytoscape as cyto
        except ImportError as e:
            print(e)
            print("try pip install dash-cytoscape")
            missing.append("dash-cytoscape")

        try:
            from jupyter_dash import JupyterDash
        except ImportError as e:
            print(e)
            print("try pip install jupyter-dash")
            missing.append("jupyter-dash")

        try:
            from dash import html
        except ImportError as e:
            print(e)
            print("try pip install dash")
            missing.append("dash")

        if missing:
            # Fail here with a meaningful error; falling through to the return
            # would raise UnboundLocalError for the names that never got bound.
            raise ImportError(
                "Missing visualization package(s): "
                + ", ".join(missing)
                + ". try pip install "
                + " ".join(missing)
            )

        return cyto, JupyterDash, html

    def _get_app(self, elements):
        """Create JupyterDash app for interactivity on Jupyter notebook.

        Args:
            elements: Cytoscape elements (nodes and edges) to display.

        Returns:
            The ``JupyterDash`` app whose layout is a single Cytoscape graph.
        """
        app = self.JupyterDash(__name__)
        # Called before the layout below selects the "klay" layout by name.
        self.cyto.load_extra_layouts()

        node_style = {
            "selector": "node",
            "style": {
                "label": "data(label)",
                "font-size": "3.5vw",
                "height": "10vw",
                "width": "10vw",
            },
        }
        edge_style = {
            "selector": "edge",
            "style": {
                "label": "data(label)",
                "color": "gray",
                "text-halign": "left",
                "text-margin-y": "3px",
                "text-margin-x": "-2px",
                "font-size": "3%",
                "width": "1%",
                "curve-style": "taxi",
                "target-arrow-color": "gray",
                "target-arrow-shape": "triangle",
                "line-color": "gray",
                "arrow-scale": "0.5",
            },
        }

        app.layout = self.html.Div(
            [
                self.cyto.Cytoscape(
                    id="cytoscape-layout-1",
                    elements=elements,
                    style={"width": "100%", "height": "350px"},
                    layout={"name": "klay"},
                    stylesheet=[node_style, edge_style],
                    responsive=True,
                )
            ]
        )

        return app

    def render(self, elements, mode):
        """Render graph for lineage query result.

        Args:
            elements: Cytoscape elements (nodes and edges) to display.
            mode: Passed through unchanged as ``app.run_server(mode=...)``.

        Returns:
            Whatever ``app.run_server`` returns.
        """
        app = self._get_app(elements)

        return app.run_server(mode=mode)
class LineageQueryResult(object):
    """A wrapper around the results of a lineage query."""

    # NOTE(review): ``__init__`` lies outside the visible diff hunk and is not
    # reproduced here. The methods below read ``self.edges`` and
    # ``self.vertices`` -- confirm against the full file before applying.

    def __str__(self):
        """Define string representation of ``LineageQueryResult``.

        Format:
        {
            'edges':[
                {
                    'source_arn': 'string', 'destination_arn': 'string',
                    'association_type': 'string'
                },
                ...
            ]
            'vertices':[
                {
                    'arn': 'string', 'lineage_entity': 'string',
                    'lineage_source': 'string',
                    '_session': <sagemaker.session.Session object>
                },
                ...
            ]
        }

        """
        result_dict = vars(self)
        return str({k: [vars(val) for val in v] for k, v in result_dict.items()})

    def _covert_vertices_to_tuples(self):
        """Convert vertices to tuple format for visualizer.

        Returns:
            list: One ``(arn, lineage_source)`` tuple per entry of
            ``self.vertices``, in the same order.
        """
        return [(vert.arn, vert.lineage_source) for vert in self.vertices]

    def _covert_edges_to_tuples(self):
        """Convert edges to tuple format for visualizer.

        Returns:
            list: One ``(source_arn, destination_arn, association_type)``
            tuple per entry of ``self.edges``, in the same order.
        """
        return [
            (edge.source_arn, edge.destination_arn, edge.association_type)
            for edge in self.edges
        ]

    def _get_visualization_elements(self):
        """Get elements for visualization.

        Returns:
            list: Node dicts (``{"data": {"id", "label"}}``) for every vertex,
            followed by edge dicts (``{"data": {"source", "target", "label"}}``)
            for every edge.
        """
        # Distinct local names avoid shadowing the ``id`` builtin and avoid
        # rebinding one variable from tuples to dicts.
        nodes = [
            {"data": {"id": arn, "label": source}}
            for arn, source in self._covert_vertices_to_tuples()
        ]
        links = [
            {"data": {"source": source, "target": target, "label": association}}
            for source, target, association in self._covert_edges_to_tuples()
        ]

        return nodes + links

    def visualize(self):
        """Visualize lineage query result.

        Returns:
            Whatever ``DashVisualizer.render`` returns for ``mode="inline"``.
        """
        elements = self._get_visualization_elements()

        dash_vis = DashVisualizer()

        return dash_vis.render(elements=elements, mode="inline")