Skip to content

Commit

Permalink
feat: jsch jumpstart estimator support (aws#4439)
Browse files Browse the repository at this point in the history
  • Loading branch information
bencrabtree committed Mar 13, 2024
1 parent 29718b4 commit d7c4307
Show file tree
Hide file tree
Showing 19 changed files with 65 additions and 35 deletions.
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,7 @@ def get_model_specs(
hub_arn: Optional[str] = None,
s3_client: Optional[boto3.client] = None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
) -> JumpStartModelSpecs:
"""Returns model specs from JumpStart models cache.
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
get_wildcard_model_version_msg,
get_wildcard_proprietary_model_version_msg,
)
from sagemaker.jumpstart.exceptions import get_wildcard_model_version_msg
from sagemaker.jumpstart.parameters import (
JUMPSTART_DEFAULT_MAX_S3_CACHE_ITEMS,
JUMPSTART_DEFAULT_MAX_SEMANTIC_VERSION_CACHE_ITEMS,
Expand Down
3 changes: 1 addition & 2 deletions src/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,7 @@
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY = "models_manifest.json"
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY = "proprietary-sdk-manifest.json"

# works cross-partition
HUB_MODEL_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub_content/(.*?)/Model/(.*?)/(.*?)$"
HUB_CONTENT_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub-content/(.*?)/(.*?)/(.*?)/(.*?)$"
HUB_ARN_REGEX = r"arn:(.*?):sagemaker:(.*?):(.*?):hub/(.*?)$"

INFERENCE_ENTRY_POINT_SCRIPT_NAME = "inference.py"
Expand Down
1 change: 1 addition & 0 deletions src/sagemaker/jumpstart/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,6 +534,7 @@ def _validate_model_id_and_get_type_hook():
model_version=model_version,
hub_arn=hub_arn,
model_type=self.model_type,
hub_arn=hub_arn,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
role=role,
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def get_init_kwargs(
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
model_type: Optional[JumpStartModelType] = JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: Optional[bool] = None,
tolerate_deprecated_model: Optional[bool] = None,
region: Optional[str] = None,
Expand Down Expand Up @@ -140,6 +141,7 @@ def get_init_kwargs(
model_version=model_version,
hub_arn=hub_arn,
model_type=model_type,
hub_arn=hub_arn,
role=role,
region=region,
instance_count=instance_count,
Expand Down
2 changes: 2 additions & 0 deletions src/sagemaker/jumpstart/factory/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,7 @@ def get_deploy_kwargs(
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
model_type: JumpStartModelType = JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
initial_instance_count: Optional[int] = None,
instance_type: Optional[str] = None,
Expand Down Expand Up @@ -584,6 +585,7 @@ def get_deploy_kwargs(
model_version=model_version,
hub_arn=hub_arn,
model_type=model_type,
hub_arn=hub_arn,
region=region,
initial_instance_count=initial_instance_count,
instance_type=instance_type,
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1420,6 +1420,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"initial_instance_count",
"instance_type",
"region",
Expand Down Expand Up @@ -1453,6 +1454,7 @@ class JumpStartModelDeployKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"region",
"tolerate_deprecated_model",
"tolerate_vulnerable_model",
Expand Down Expand Up @@ -1499,6 +1501,7 @@ def __init__(
self.model_version = model_version
self.hub_arn = hub_arn
self.model_type = model_type
self.hub_arn = hub_arn
self.initial_instance_count = initial_instance_count
self.instance_type = instance_type
self.region = region
Expand Down Expand Up @@ -1535,6 +1538,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"instance_type",
"instance_count",
"region",
Expand Down Expand Up @@ -1596,6 +1600,7 @@ class JumpStartEstimatorInitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
}

def __init__(
Expand Down Expand Up @@ -1725,6 +1730,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"region",
"inputs",
"wait",
Expand All @@ -1741,6 +1747,7 @@ class JumpStartEstimatorFitKwargs(JumpStartKwargs):
"model_version",
"hub_arn",
"model_type",
"hub_arn",
"region",
"tolerate_deprecated_model",
"tolerate_vulnerable_model",
Expand Down Expand Up @@ -1769,6 +1776,7 @@ def __init__(
self.model_version = model_version
self.hub_arn = hub_arn
self.model_type = model_type
self.hub_arn = hub_arn
self.region = region
self.inputs = inputs
self.wait = wait
Expand Down
2 changes: 1 addition & 1 deletion src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
from __future__ import absolute_import
import logging
import os
from typing import Any, Dict, List, Set, Optional, Tuple, Union
import re
from typing import Any, Dict, List, Set, Optional, Tuple, Union
from urllib.parse import urlparse
import boto3
from packaging.version import Version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ def add_options_to_hyperparameter(*largs, **kwargs):
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)

patched_get_model_specs.reset_mock()
Expand Down Expand Up @@ -516,6 +517,7 @@ def test_jumpstart_validate_all_hyperparameters(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)

patched_get_model_specs.reset_mock()
Expand Down
4 changes: 4 additions & 0 deletions tests/unit/sagemaker/image_uris/jumpstart/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def test_jumpstart_common_image_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -78,6 +79,7 @@ def test_jumpstart_common_image_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -100,6 +102,7 @@ def test_jumpstart_common_image_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -122,6 +125,7 @@ def test_jumpstart_common_image_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,9 +123,9 @@ def test_jumpstart_instance_types(patched_get_model_specs, patched_validate_mode
region=region,
model_id=model_id,
version=model_version,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
s3_client=mock_client,
model_type=JumpStartModelType.OPEN_WEIGHTS,
)

patched_get_model_specs.reset_mock()
Expand Down
3 changes: 2 additions & 1 deletion tests/unit/sagemaker/jumpstart/curated_hub/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@

from unittest.mock import Mock
from sagemaker.jumpstart.types import HubArnExtractedInfo
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.curated_hub import utils
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME
from sagemaker.jumpstart.curated_hub.types import HubArnExtractedInfo


def test_get_info_from_hub_resource_arn():
Expand Down
25 changes: 25 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,31 @@ def test_jumpstart_proprietary_models_cache_get(mock_cache):
> 0
)

@patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._cache")
def test_jumpstart_models_cache_get_model_specs(mock_cache):
mock_cache.get_specs = Mock()
mock_cache.get_hub_model = Mock()
model_id, version = "pytorch-ic-mobilenet-v2", "*"
region = "us-west-2"

accessors.JumpStartModelsAccessor.get_model_specs(
region=region, model_id=model_id, version=version
)
mock_cache.get_specs.assert_called_once_with(model_id=model_id, semantic_version_str=version)
mock_cache.get_hub_model.assert_not_called()

accessors.JumpStartModelsAccessor.get_model_specs(
region=region,
model_id=model_id,
version=version,
hub_arn=f"arn:aws:sagemaker:{region}:123456789123:hub/my-mock-hub",
)
mock_cache.get_hub_model.assert_called_once_with(
hub_model_arn=(
f"arn:aws:sagemaker:{region}:123456789123:hub-content/my-mock-hub/Model/{model_id}/{version}"
)
)

# necessary because accessors is a static module
reload(accessors)

Expand Down
1 change: 1 addition & 0 deletions tests/unit/sagemaker/jumpstart/test_notebook_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,4 +751,5 @@ def test_get_model_url(
s3_client=DEFAULT_JUMPSTART_SAGEMAKER_SESSION.s3_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
29 changes: 0 additions & 29 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1214,35 +1214,6 @@ def test_mime_type_enum_from_str():
assert MIMEType.from_suffixed_type(mime_type_with_suffix) == mime_type


def test_extract_info_from_hub_content_arn():
model_arn = (
"arn:aws:sagemaker:us-west-2:000000000000:hub_content/MockHub/Model/my-mock-model/1.0.2"
)
assert utils.extract_info_from_hub_content_arn(model_arn) == (
"MockHub",
"us-west-2",
"my-mock-model",
"1.0.2",
)

hub_arn = "arn:aws:sagemaker:us-west-2:000000000000:hub/MockHub"
assert utils.extract_info_from_hub_content_arn(hub_arn) == ("MockHub", "us-west-2", None, None)

invalid_arn = "arn:aws:sagemaker:us-west-2:000000000000:endpoint/my-endpoint-123"
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)

invalid_arn = "nonsense-string"
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)

invalid_arn = ""
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)

invalid_arn = (
"arn:aws:sagemaker:us-west-2:000000000000:hub-content/MyHub/Notebook/my-notebook/1.0.0"
)
assert utils.extract_info_from_hub_content_arn(invalid_arn) == (None, None, None, None)


class TestIsValidModelId(TestCase):
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor._get_manifest")
@patch("sagemaker.jumpstart.utils.accessors.JumpStartModelsAccessor.get_model_specs")
Expand Down
5 changes: 4 additions & 1 deletion tests/unit/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
JUMPSTART_REGION_NAME_SET,
)
from sagemaker.jumpstart.types import (
HubDataType,
HubContentType,
JumpStartCachedContentKey,
JumpStartCachedContentValue,
JumpStartModelSpecs,
Expand Down Expand Up @@ -253,6 +253,9 @@ def patched_retrieval_function(
model_type=JumpStartModelType.PROPRIETARY,
)
)
# TODO: Implement
if datatype == HubContentType.HUB:
return None

raise ValueError(f"Bad value for datatype: {datatype}")

Expand Down
4 changes: 4 additions & 0 deletions tests/unit/sagemaker/model_uris/jumpstart/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_jumpstart_common_model_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -73,6 +74,7 @@ def test_jumpstart_common_model_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -93,6 +95,7 @@ def test_jumpstart_common_model_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -113,6 +116,7 @@ def test_jumpstart_common_model_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def test_jumpstart_resource_requirements(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_get_model_specs.reset_mock()

Expand Down
4 changes: 4 additions & 0 deletions tests/unit/sagemaker/script_uris/jumpstart/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def test_jumpstart_common_script_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -73,6 +74,7 @@ def test_jumpstart_common_script_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -93,6 +95,7 @@ def test_jumpstart_common_script_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand All @@ -113,6 +116,7 @@ def test_jumpstart_common_script_uri(
s3_client=mock_client,
hub_arn=None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn=None,
)
patched_verify_model_region_and_return_specs.assert_called_once()

Expand Down

0 comments on commit d7c4307

Please sign in to comment.