diff --git a/src/sagemaker/jumpstart/accessors.py b/src/sagemaker/jumpstart/accessors.py index dfc833ec28..c9f805c225 100644 --- a/src/sagemaker/jumpstart/accessors.py +++ b/src/sagemaker/jumpstart/accessors.py @@ -257,6 +257,7 @@ def get_model_specs( hub_arn: Optional[str] = None, s3_client: Optional[boto3.client] = None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, ) -> JumpStartModelSpecs: """Returns model specs from JumpStart models cache. diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py index 0d93ffb76b..75317b6784 100644 --- a/src/sagemaker/jumpstart/cache.py +++ b/src/sagemaker/jumpstart/cache.py @@ -39,6 +39,7 @@ get_wildcard_model_version_msg, get_wildcard_proprietary_model_version_msg, ) +from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg from sagemaker.jumpstart.parameters import ( JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS, JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS, diff --git a/src/sagemaker/jumpstart/constants.py b/src/sagemaker/jumpstart/constants.py index 21412d65ea..1b679d44f6 100644 --- a/src/sagemaker/jumpstart/constants.py +++ b/src/sagemaker/jumpstart/constants.py @@ -172,8 +172,7 @@ JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json" JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json" -# works cross-partition -HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$" +HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$" HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$" INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py" diff --git a/src/sagemaker/jumpstart/estimator.py b/src/sagemaker/jumpstart/estimator.py index 6406932924..4cdd540111 100644 --- a/src/sagemaker/jumpstart/estimator.py +++ b/src/sagemaker/jumpstart/estimator.py @@ -534,6 +534,7 @@ def _validate_model_id_and_get_type_hook(): model_version=model_version, hub_arn=hub_arn, model_type=self.model_type, + hub_arn=hub_arn, tolerate_vulnerable_model=tolerate_vulnerable_model, tolerate_deprecated_model=tolerate_deprecated_model, role=role, diff --git a/src/sagemaker/jumpstart/factory/estimator.py b/src/sagemaker/jumpstart/factory/estimator.py index fb598256fa..fa04b46a7c 100644 --- a/src/sagemaker/jumpstart/factory/estimator.py +++ b/src/sagemaker/jumpstart/factory/estimator.py @@ -81,6 +81,7 @@ def get_init_kwargs( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, tolerate_vulnerable_model: Optional[bool] = None, tolerate_deprecated_model: Optional[bool] = None, region: Optional[str] = None, @@ -140,6 +141,7 @@ def get_init_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, + hub_arn=hub_arn, role=role, region=region, instance_count=instance_count, diff --git a/src/sagemaker/jumpstart/factory/model.py b/src/sagemaker/jumpstart/factory/model.py index 6f7a83cef1..273257088e 100644 --- a/src/sagemaker/jumpstart/factory/model.py +++ b/src/sagemaker/jumpstart/factory/model.py @@ -550,6 +550,7 @@ def get_deploy_kwargs( model_version: Optional[str] = None, hub_arn: Optional[str] = None, model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS, + hub_arn: Optional[str] = None, region: Optional[str] = None, initial_instance_count: Optional[int] = None, instance_type: Optional[str] = None, @@ -584,6 +585,7 @@ def get_deploy_kwargs( model_version=model_version, hub_arn=hub_arn, model_type=model_type, + hub_arn=hub_arn, region=region, initial_instance_count=initial_instance_count, instance_type=instance_type, diff --git a/src/sagemaker/jumpstart/types.py b/src/sagemaker/jumpstart/types.py index 82bfc5c13f..84fe68c0ef 100644 --- a/src/sagemaker/jumpstart/types.py +++ b/src/sagemaker/jumpstart/types.py @@ -1420,6 +1420,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "initial_instance_count", "instance_type", "region", @@ -1453,6 +1454,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1499,6 +1501,7 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type + self.hub_arn = hub_arn self.initial_instance_count = initial_instance_count self.instance_type = instance_type self.region = region @@ -1535,6 +1538,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "instance_type", "instance_count", "region", @@ -1596,6 +1600,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", } def __init__( @@ -1725,6 +1730,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "region", "inputs", "wait", @@ -1741,6 +1747,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs): "model_version", "hub_arn", "model_type", + "hub_arn", "region", "tolerate_deprecated_model", "tolerate_vulnerable_model", @@ -1769,6 +1776,7 @@ def __init__( self.model_version = model_version self.hub_arn = hub_arn self.model_type = model_type + self.hub_arn = hub_arn self.region = region self.inputs = inputs self.wait = wait diff --git a/src/sagemaker/jumpstart/utils.py b/src/sagemaker/jumpstart/utils.py index 82b43e7d22..4fc8752625 100644 --- a/src/sagemaker/jumpstart/utils.py +++ b/src/sagemaker/jumpstart/utils.py @@ -14,8 +14,8 @@ from __future__ import absolute_import import logging import os -from typing import Any, Dict, List, Set, Optional, Tuple, Union import re +from typing import Any, Dict, List, Set, Optional, Tuple, Union from urllib.parse import urlparse import boto3 from packaging.version import Version diff --git a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py index 0f69cb572a..93d7098870 100644 --- a/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py +++ b/tests/unit/sagemaker/hyperparameters/jumpstart/test_validate.py @@ -453,6 +453,7 @@ def add_options_to_hyperparameter(*largs, **kwargs): s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() @@ -516,6 +517,7 @@ def test_jumpstart_validate_all_hyperparameters( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py index bd4383499d..45af6faeed 100644 --- a/tests/unit/sagemaker/image_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/image_uris/jumpstart/test_common.py @@ -56,6 +56,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -78,6 +79,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -100,6 +102,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -122,6 +125,7 @@ def test_jumpstart_common_image_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py index ea6835bec3..05c4df28f9 100644 --- a/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py +++ b/tests/unit/sagemaker/instance_types/jumpstart/test_instance_types.py @@ -123,9 +123,9 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode region=region, model_id=model_id, version=model_version, + model_type=JumpStartModelType.OPEN_WEIGHTS, hub_arn=None, s3_client=mock_client, - model_type=JumpStartModelType.OPEN_WEIGHTS, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py index 15b6d8fba3..0c2393b71a 100644 --- a/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py @@ -14,8 +14,9 @@ from unittest.mock import Mock from sagemaker.jumpstart.types import HubArnExtractedInfo -from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME from sagemaker.jumpstart.curated_hub import utils +from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME +from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo def test_get_info_from_hub_resource_arn(): diff --git a/tests/unit/sagemaker/jumpstart/test_accessors.py b/tests/unit/sagemaker/jumpstart/test_accessors.py index fb22287909..6ab71c3f53 100644 --- a/tests/unit/sagemaker/jumpstart/test_accessors.py +++ b/tests/unit/sagemaker/jumpstart/test_accessors.py @@ -137,6 +137,31 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache): > 0 ) +@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache") +def test_jumpstart_models_cache_get_model_specs(mock_cache): + mock_cache.get_specs = Mock() + mock_cache.get_hub_model = Mock() + model_id, version = "pytorch-ic-mobilenet-v2", "*" + region = "us-west-2" + + accessors.JumpStartModelsAccessor.get_model_specs( + region=region, model_id=model_id, version=version + ) + mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version) + mock_cache.get_hub_model.assert_not_called() + + accessors.JumpStartModelsAccessor.get_model_specs( + region=region, + model_id=model_id, + version=version, + hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub", + ) + mock_cache.get_hub_model.assert_called_once_with( + hub_model_arn=( + f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}" + ) + ) + # necessary because accessors is a static module reload(accessors) diff --git a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py index a5d1ee3ac2..ed7c870a0e 100644 --- a/tests/unit/sagemaker/jumpstart/test_notebook_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_notebook_utils.py @@ -751,4 +751,5 @@ def test_get_model_url( s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) diff --git a/tests/unit/sagemaker/jumpstart/test_utils.py b/tests/unit/sagemaker/jumpstart/test_utils.py index 472b2dfdd9..fa1a3fc72f 100644 --- a/tests/unit/sagemaker/jumpstart/test_utils.py +++ b/tests/unit/sagemaker/jumpstart/test_utils.py @@ -1214,35 +1214,6 @@ def test_mime_type_enum_from_str(): assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type -def test_extract_info_from_hub_content_arn(): - model_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2" - ) - assert utils.extract_info_from_hub_content_arn(model_arn) == ( - "MockHub", - "us-west-2", - "my-mock-model", - "1.0.2", - ) - - hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub" - assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None) - - invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = "nonsense-string" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = "" - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - invalid_arn = ( - "arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0" - ) - assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None) - - class TestIsValidModelId(TestCase): @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest") @patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs") diff --git a/tests/unit/sagemaker/jumpstart/utils.py b/tests/unit/sagemaker/jumpstart/utils.py index 128e41a796..4e816b6b97 100644 --- a/tests/unit/sagemaker/jumpstart/utils.py +++ b/tests/unit/sagemaker/jumpstart/utils.py @@ -22,7 +22,7 @@ JUMPSTART_REGION_NAME_SET, ) from sagemaker.jumpstart.types import ( - HubDataType, + HubContentType, JumpStartCachedContentKey, JumpStartCachedContentValue, JumpStartModelSpecs, @@ -253,6 +253,9 @@ def patched_retrieval_function( model_type=JumpStartModelType.PROPRIETARY, ) ) + # TODO: Implement + if datatype == HubContentType.HUB: + return None raise ValueError(f"Bad value for datatype: {datatype}") diff --git a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py index 2bb327c26f..06587a2074 100644 --- a/tests/unit/sagemaker/model_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/model_uris/jumpstart/test_common.py @@ -54,6 +54,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -73,6 +74,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -93,6 +95,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -113,6 +116,7 @@ def test_jumpstart_common_model_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() diff --git a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py index 2a4d913a75..b2e055dd3c 100644 --- a/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py +++ b/tests/unit/sagemaker/resource_requirements/jumpstart/test_resource_requirements.py @@ -57,6 +57,7 @@ def test_jumpstart_resource_requirements( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_get_model_specs.reset_mock() diff --git a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py index 87364a16fc..14ad48082e 100644 --- a/tests/unit/sagemaker/script_uris/jumpstart/test_common.py +++ b/tests/unit/sagemaker/script_uris/jumpstart/test_common.py @@ -54,6 +54,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -73,6 +74,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -93,6 +95,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once() @@ -113,6 +116,7 @@ def test_jumpstart_common_script_uri( s3_client=mock_client, hub_arn=None, model_type=JumpStartModelType.OPEN_WEIGHTS, + hub_arn=None, ) patched_verify_model_region_and_return_specs.assert_called_once()