diff --git a/src/sagemaker/lineage/query.py b/src/sagemaker/lineage/query.py index a46e88867a..0acdccd698 100644 --- a/src/sagemaker/lineage/query.py +++ b/src/sagemaker/lineage/query.py @@ -104,6 +104,18 @@ def __str__(self): """ return str(self.__dict__) + def __repr__(self): + """Define string representation of ``Edge``. + + Format: + { + 'source_arn': 'string', 'destination_arn': 'string', + 'association_type': 'string' + } + + """ + return "\n\t" + str(self.__dict__) + class Vertex: """A vertex for a lineage graph.""" @@ -155,6 +167,19 @@ def __str__(self): """ return str(self.__dict__) + def __repr__(self): + """Define string representation of ``Vertex``. + + Format: + { + 'arn': 'string', 'lineage_entity': 'string', + 'lineage_source': 'string', + '_session': + } + + """ + return "\n\t" + str(self.__dict__) + def to_lineage_object(self): """Convert the ``Vertex`` object to its corresponding lineage object. @@ -312,29 +337,23 @@ def __str__(self): Format: { 'edges':[ - "{ - 'source_arn': 'string', 'destination_arn': 'string', - 'association_type': 'string' - }", - ... - ], + {'source_arn': 'string', 'destination_arn': 'string', 'association_type': 'string'}, + ...], + 'vertices':[ - "{ - 'arn': 'string', 'lineage_entity': 'string', - 'lineage_source': 'string', - '_session': - }", - ... - ], - 'startarn':[ - 'string', - ... - ] + {'arn': 'string', 'lineage_entity': 'string', 'lineage_source': 'string', + '_session': }, + ...], + + 'startarn':['string', ...] } """ - result_dict = vars(self) - return str({k: [str(val) for val in v] for k, v in result_dict.items()}) + return ( + "{\n" + + "\n\n".join("'{}': {},".format(key, val) for key, val in self.__dict__.items()) + + "\n}" + ) def _covert_edges_to_tuples(self): """Convert edges to tuple format for visualizer.""" diff --git a/tests/integ/sagemaker/lineage/test_lineage_visualize.py b/tests/integ/sagemaker/lineage/test_lineage_visualize.py index d9b3f879a8..555b98452e 100644 --- a/tests/integ/sagemaker/lineage/test_lineage_visualize.py +++ b/tests/integ/sagemaker/lineage/test_lineage_visualize.py @@ -57,10 +57,6 @@ def test_wide_graph_visualize(sagemaker_session): lq_result = lq.query(start_arns=[wide_graph_root_arn]) lq_result.visualize(path="wideGraph.html") - print("vertex len = ") - print(len(lq_result.vertices)) - assert False - except Exception as e: print(e) assert False diff --git a/tests/unit/sagemaker/lineage/test_query.py b/tests/unit/sagemaker/lineage/test_query.py index ae76fd199c..0a357eb1fc 100644 --- a/tests/unit/sagemaker/lineage/test_query.py +++ b/tests/unit/sagemaker/lineage/test_query.py @@ -18,6 +18,7 @@ from sagemaker.lineage.lineage_trial_component import LineageTrialComponent from sagemaker.lineage.query import LineageEntityEnum, LineageSourceEnum, Vertex, LineageQuery import pytest +import re def test_lineage_query(sagemaker_session): @@ -524,3 +525,64 @@ def test_vertex_to_object_unconvertable(sagemaker_session): with pytest.raises(ValueError): vertex.to_lineage_object() + + +def test_get_visualization_elements(sagemaker_session): + lineage_query = LineageQuery(sagemaker_session) + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"}, + {"Arn": "arn2", "Type": "Model", "LineageType": "Context"}, + { + "Arn": "arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext", + "Type": "Model", + "LineageType": "Context", + }, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + + query_response = lineage_query.query( + start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] + ) + + elements = query_response._get_visualization_elements() + + assert elements["nodes"][0] == ("arn1", "Endpoint", "Artifact", False) + assert elements["nodes"][1] == ("arn2", "Model", "Context", False) + assert elements["nodes"][2] == ( + "arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext", + "Model", + "Context", + True, + ) + assert elements["edges"][0] == ("arn1", "arn2", "Produced") + + +def test_query_lineage_result_str(sagemaker_session): + lineage_query = LineageQuery(sagemaker_session) + sagemaker_session.sagemaker_client.query_lineage.return_value = { + "Vertices": [ + {"Arn": "arn1", "Type": "Endpoint", "LineageType": "Artifact"}, + {"Arn": "arn2", "Type": "Model", "LineageType": "Context"}, + ], + "Edges": [{"SourceArn": "arn1", "DestinationArn": "arn2", "AssociationType": "Produced"}], + } + + query_response = lineage_query.query( + start_arns=["arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext"] + ) + + response_str = query_response.__str__() + pattern = r"Mock id='\d*'" + replace = r"Mock id=''" + response_str = re.sub(pattern, replace, response_str) + + assert ( + response_str + == "{\n'edges': [\n\t{'source_arn': 'arn1', 'destination_arn': 'arn2', 'association_type': 'Produced'}]," + + "\n\n'vertices': [\n\t{'arn': 'arn1', 'lineage_entity': 'Artifact', 'lineage_source': 'Endpoint', " + + "'_session': }, \n\t{'arn': 'arn2', 'lineage_entity': 'Context', 'lineage_source': " + + "'Model', '_session': }],\n\n'startarn': " + + "['arn:aws:sagemaker:us-west-2:0123456789012:context/mycontext'],\n}" + )