diff --git a/requirements/extras/local_requirements.txt b/requirements/extras/local_requirements.txt index 17512c3388..439f42010c 100644 --- a/requirements/extras/local_requirements.txt +++ b/requirements/extras/local_requirements.txt @@ -1,4 +1,4 @@ urllib3==1.26.8 docker-compose==1.29.2 docker~=5.0.0 -PyYAML==5.4.1 +PyYAML==5.4.1 \ No newline at end of file diff --git a/requirements/extras/test_requirements.txt b/requirements/extras/test_requirements.txt index 2247394441..7197147032 100644 --- a/requirements/extras/test_requirements.txt +++ b/requirements/extras/test_requirements.txt @@ -18,3 +18,4 @@ fabric==2.6.0 requests==2.27.1 sagemaker-experiments==0.1.35 Jinja2==3.0.3 +pyvis==0.2.1 diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index 5f3164bdc0..a46e88867a 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -364,7 +364,7 @@ def _get_visualization_elements(self): elements = {"nodes": verts, "edges": edges} return elements - def visualize(self): + def visualize(self, path="pyvisExample.html"): """Visualize lineage query result.""" lineage_graph = { # nodes can have shape / color @@ -398,7 +398,7 @@ def visualize(self): pyvis_vis = PyvisVisualizer(lineage_graph) elements = self._get_visualization_elements() - return pyvis_vis.render(elements=elements) + return pyvis_vis.render(elements=elements, path=path) class LineageFilter(object): diff --git a/tests/integ/sagemaker/lineage/conftest.py b/tests/integ/sagemaker/lineage/conftest.py index 5e201eef42..7450cc5935 100644 --- a/tests/integ/sagemaker/lineage/conftest.py +++ b/tests/integ/sagemaker/lineage/conftest.py @@ -19,6 +19,7 @@ import pytest import logging import uuid +import json from sagemaker.lineage import ( action, context, @@ -891,3 +892,17 @@ def _deploy_static_endpoint(execution_arn, sagemaker_session): pass else: raise (e) + + +@pytest.fixture +def extract_data_from_html(): + def _method(data): + start = data.find("[") + end = data.find("]") + res = data[start + 1 : end].split("}, ") + res = [i + "}" for i in res] + res[-1] = res[-1][:-1] + data_dict = [json.loads(i) for i in res] + return data_dict + + return _method diff --git a/tests/integ/sagemaker/lineage/helpers.py b/tests/integ/sagemaker/lineage/helpers.py index fb71d1d88c..3ab22ce332 100644 --- a/tests/integ/sagemaker/lineage/helpers.py +++ b/tests/integ/sagemaker/lineage/helpers.py @@ -78,3 +78,94 @@ def visit(arn, visited: set): ret = [] return visit(start_arn, set()) + + +class LineageResourceHelper: + def __init__(self, sagemaker_session): + self.client = sagemaker_session.sagemaker_client + self.artifacts = [] + self.actions = [] + self.contexts = [] + self.associations = [] + + def create_artifact(self, artifact_name, artifact_type="Dataset"): + response = self.client.create_artifact( + ArtifactName=artifact_name, + Source={ + "SourceUri": "Test-artifact-" + artifact_name, + "SourceTypes": [ + {"SourceIdType": "S3ETag", "Value": "Test-artifact-sourceId-value"}, + ], + }, + ArtifactType=artifact_type, + ) + self.artifacts.append(response["ArtifactArn"]) + + return response["ArtifactArn"] + + def create_action(self, action_name, action_type="ModelDeployment"): + response = self.client.create_action( + ActionName=action_name, + Source={ + "SourceUri": "Test-action-" + action_name, + "SourceType": "S3ETag", + "SourceId": "Test-action-sourceId-value", + }, + ActionType=action_type, + ) + self.actions.append(response["ActionArn"]) + + return response["ActionArn"] + + def create_context(self, context_name, context_type="Endpoint"): + response = self.client.create_context( + ContextName=context_name, + Source={ + "SourceUri": "Test-context-" + context_name, + "SourceType": "S3ETag", + "SourceId": "Test-context-sourceId-value", + }, + ContextType=context_type, + ) + self.contexts.append(response["ContextArn"]) + + return response["ContextArn"] + + def create_association(self, source_arn, dest_arn, association_type="AssociatedWith"): + response = self.client.add_association( + SourceArn=source_arn, DestinationArn=dest_arn, AssociationType=association_type + ) + if "SourceArn" in response.keys(): + self.associations.append((source_arn, dest_arn)) + return True + else: + return False + + def clean_all(self): + # clean all lineage data created by LineageResourceHelper + + time.sleep(1) # avoid GSI race condition between create & delete + + for source, dest in self.associations: + try: + self.client.delete_association(SourceArn=source, DestinationArn=dest) + except Exception as e: + print("skipped " + str(e)) + + for artifact_arn in self.artifacts: + try: + self.client.delete_artifact(ArtifactArn=artifact_arn) + except Exception as e: + print("skipped " + str(e)) + + for action_arn in self.actions: + try: + self.client.delete_action(ActionArn=action_arn) + except Exception as e: + print("skipped " + str(e)) + + for context_arn in self.contexts: + try: + self.client.delete_context(ContextArn=context_arn) + except Exception as e: + print("skipped " + str(e)) diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py new file mode 100644 index 0000000000..d9b3f879a8 --- /dev/null +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -0,0 +1,237 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License. +"""This module contains code to test SageMaker ``LineageQueryResult.visualize()``""" +from __future__ import absolute_import +import time +import os + +import pytest + +import sagemaker.lineage.query +from sagemaker.lineage.query import LineageQueryDirectionEnum +from tests.integ.sagemaker.lineage.helpers import name, LineageResourceHelper + + +def test_LineageResourceHelper(sagemaker_session): + # check if LineageResourceHelper works properly + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) + try: + art1 = lineage_resource_helper.create_artifact(artifact_name=name()) + art2 = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association(source_arn=art1, dest_arn=art2) + except Exception as e: + print(e) + assert False + finally: + lineage_resource_helper.clean_all() + + +@pytest.mark.skip("visualizer load test") +def test_wide_graph_visualize(sagemaker_session): + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) + wide_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + + # create wide graph + # Artifact ----> Artifact + # \ \ \-> Artifact + # \ \--> Artifact + # \---> ... + try: + for i in range(150): + artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association( + source_arn=wide_graph_root_arn, dest_arn=artifact_arn + ) + + lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) + lq_result = lq.query(start_arns=[wide_graph_root_arn]) + lq_result.visualize(path="wideGraph.html") + + print("vertex len = ") + print(len(lq_result.vertices)) + assert False + + except Exception as e: + print(e) + assert False + + finally: + lineage_resource_helper.clean_all() + + +@pytest.mark.skip("visualizer load test") +def test_long_graph_visualize(sagemaker_session): + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) + long_graph_root_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + last_arn = long_graph_root_arn + + # create long graph + # Artifact -> Artifact -> ... -> Artifact + try: + for i in range(10): + new_artifact_arn = lineage_resource_helper.create_artifact(artifact_name=name()) + lineage_resource_helper.create_association( + source_arn=last_arn, dest_arn=new_artifact_arn + ) + last_arn = new_artifact_arn + + lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) + lq_result = lq.query( + start_arns=[long_graph_root_arn], direction=LineageQueryDirectionEnum.DESCENDANTS + ) + # max depth = 10 -> graph rendered only has length of ten (in DESCENDANTS direction) + lq_result.visualize(path="longGraph.html") + + except Exception as e: + print(e) + assert False + + finally: + lineage_resource_helper.clean_all() + + +def test_graph_visualize(sagemaker_session, extract_data_from_html): + lineage_resource_helper = LineageResourceHelper(sagemaker_session=sagemaker_session) + + # create lineage data + # image artifact ------> model artifact(startarn) -> model deploy action -> endpoint context + # /-> + # dataset artifact -/ + try: + graph_startarn = lineage_resource_helper.create_artifact( + artifact_name=name(), artifact_type="Model" + ) + image_artifact = lineage_resource_helper.create_artifact( + artifact_name=name(), artifact_type="Image" + ) + lineage_resource_helper.create_association( + source_arn=image_artifact, dest_arn=graph_startarn, association_type="ContributedTo" + ) + dataset_artifact = lineage_resource_helper.create_artifact( + artifact_name=name(), artifact_type="DataSet" + ) + lineage_resource_helper.create_association( + source_arn=dataset_artifact, dest_arn=graph_startarn, association_type="AssociatedWith" + ) + modeldeploy_action = lineage_resource_helper.create_action( + action_name=name(), action_type="ModelDeploy" + ) + lineage_resource_helper.create_association( + source_arn=graph_startarn, dest_arn=modeldeploy_action, association_type="ContributedTo" + ) + endpoint_context = lineage_resource_helper.create_context( + context_name=name(), context_type="Endpoint" + ) + lineage_resource_helper.create_association( + source_arn=modeldeploy_action, + dest_arn=endpoint_context, + association_type="AssociatedWith", + ) + time.sleep(3) + + # visualize + lq = sagemaker.lineage.query.LineageQuery(sagemaker_session) + lq_result = lq.query(start_arns=[graph_startarn]) + lq_result.visualize(path="testGraph.html") + + # check generated graph info + fo = open("testGraph.html", "r") + lines = fo.readlines() + for line in lines: + if "nodes = " in line: + node = line + if "edges = " in line: + edge = line + + node_dict = extract_data_from_html(node) + edge_dict = extract_data_from_html(edge) + + # check node number + assert len(node_dict) == 5 + + expected_nodes = { + graph_startarn: { + "color": "#146eb4", + "label": "Model", + "shape": "star", + "title": "Artifact", + }, + image_artifact: { + "color": "#146eb4", + "label": "Image", + "shape": "dot", + "title": "Artifact", + }, + dataset_artifact: { + "color": "#146eb4", + "label": "DataSet", + "shape": "dot", + "title": "Artifact", + }, + modeldeploy_action: { + "color": "#88c396", + "label": "ModelDeploy", + "shape": "dot", + "title": "Action", + }, + endpoint_context: { + "color": "#ff9900", + "label": "Endpoint", + "shape": "dot", + "title": "Context", + }, + } + + # check node properties + for node in node_dict: + for label, val in expected_nodes[node["id"]].items(): + assert node[label] == val + + # check edge number + assert len(edge_dict) == 4 + + expected_edges = { + image_artifact: { + "from": image_artifact, + "to": graph_startarn, + "title": "ContributedTo", + }, + dataset_artifact: { + "from": dataset_artifact, + "to": graph_startarn, + "title": "AssociatedWith", + }, + graph_startarn: { + "from": graph_startarn, + "to": modeldeploy_action, + "title": "ContributedTo", + }, + modeldeploy_action: { + "from": modeldeploy_action, + "to": endpoint_context, + "title": "AssociatedWith", + }, + } + + # check edge properties + for edge in edge_dict: + for label, val in expected_edges[edge["from"]].items(): + assert edge[label] == val + + except Exception as e: + print(e) + assert False + + finally: + lineage_resource_helper.clean_all() + os.remove("testGraph.html")