Skip to content

Commit

Permalink
feat: jsch jumpstart estimator support (aws#4439)
Browse files Browse the repository at this point in the history
  • Loading branch information
bencrabtree committed Mar 13, 2024
1 parent 700c16d commit 2418598
Show file tree
Hide file tree
Showing 53 changed files with 1,167 additions and 98 deletions.
4 changes: 4 additions & 0 deletions src/sagemaker/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
include_aws_sdk_env_vars: bool = True,
Expand All @@ -46,6 +47,8 @@ def retrieve_default(
retrieve the default environment variables. (Default: None).
model_version (str): Optional. The version of the model for which to retrieve the
default environment variables. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand Down Expand Up @@ -80,6 +83,7 @@ def retrieve_default(
return artifacts._retrieve_default_environment_variables(
model_id,
model_version,
hub_arn,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
instance_type: Optional[str] = None,
include_container_hyperparameters: bool = False,
tolerate_vulnerable_model: bool = False,
Expand All @@ -46,6 +47,8 @@ def retrieve_default(
retrieve the default hyperparameters. (Default: None).
model_version (str): The version of the model for which to retrieve the
default hyperparameters. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
instance_type (str): An instance type to optionally supply in order to get hyperparameters
specific for the instance type.
include_container_hyperparameters (bool): ``True`` if the container hyperparameters
Expand Down Expand Up @@ -80,6 +83,7 @@ def retrieve_default(
return artifacts._retrieve_default_hyperparameters(
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
instance_type=instance_type,
region=region,
include_container_hyperparameters=include_container_hyperparameters,
Expand All @@ -93,6 +97,7 @@ def validate(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
hyperparameters: Optional[dict] = None,
validation_mode: HyperparameterValidationMode = HyperparameterValidationMode.VALIDATE_PROVIDED,
tolerate_vulnerable_model: bool = False,
Expand All @@ -107,6 +112,8 @@ def validate(
(Default: None).
model_version (str): The version of the model for which to validate hyperparameters.
(Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
hyperparameters (dict): Hyperparameters to validate.
(Default: None).
validation_mode (HyperparameterValidationMode): Method of validation to use with
Expand Down Expand Up @@ -148,6 +155,7 @@ def validate(
return validate_hyperparameters(
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
hyperparameters=hyperparameters,
validation_mode=validation_mode,
region=region,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ def retrieve(
training_compiler_config=None,
model_id=None,
model_version=None,
hub_arn=None,
tolerate_vulnerable_model=False,
tolerate_deprecated_model=False,
sdk_version=None,
Expand Down Expand Up @@ -102,6 +103,8 @@ def retrieve(
(default: None).
model_version (str): The version of the JumpStart model for which to retrieve the
image URI (default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
tolerate_vulnerable_model (bool): ``True`` if vulnerable versions of model specifications
should be tolerated without an exception raised. If ``False``, raises an exception if
the script used by this version of the model has dependencies with known security
Expand Down Expand Up @@ -147,6 +150,7 @@ def retrieve(
model_id,
model_version,
image_scope,
hub_arn,
framework,
region,
version,
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def retrieve_default(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
scope: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand All @@ -46,6 +47,8 @@ def retrieve_default(
retrieve the default instance type. (Default: None).
model_version (str): The version of the model for which to retrieve the
default instance type. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
scope (str): The model type, i.e. what it is used for.
Valid values: "training" and "inference".
tolerate_vulnerable_model (bool): True if vulnerable versions of model
Expand Down Expand Up @@ -82,6 +85,7 @@ def retrieve_default(
model_id,
model_version,
scope,
hub_arn,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
Expand All @@ -95,6 +99,7 @@ def retrieve(
region: Optional[str] = None,
model_id: Optional[str] = None,
model_version: Optional[str] = None,
hub_arn: Optional[str] = None,
scope: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand All @@ -110,6 +115,8 @@ def retrieve(
retrieve the supported instance types. (Default: None).
model_version (str): The version of the model for which to retrieve the
supported instance types. (Default: None).
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand Down Expand Up @@ -145,6 +152,7 @@ def retrieve(
model_id,
model_version,
scope,
hub_arn,
region,
tolerate_vulnerable_model,
tolerate_deprecated_model,
Expand Down
9 changes: 9 additions & 0 deletions src/sagemaker/jumpstart/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from sagemaker.jumpstart.types import JumpStartModelHeader, JumpStartModelSpecs
from sagemaker.jumpstart.enums import JumpStartModelType
from sagemaker.jumpstart import cache
from sagemaker.jumpstart.curated_hub.utils import construct_hub_model_arn_from_inputs
from sagemaker.jumpstart.constants import JUMPSTART_DEFAULT_REGION_NAME


Expand Down Expand Up @@ -255,6 +256,7 @@ def get_model_specs(
version: str,
s3_client: Optional[boto3.client] = None,
model_type=JumpStartModelType.OPEN_WEIGHTS,
hub_arn: Optional[str] = None,
) -> JumpStartModelSpecs:
"""Returns model specs from JumpStart models cache.
Expand All @@ -274,6 +276,13 @@ def get_model_specs(
{**JumpStartModelsAccessor._cache_kwargs, **additional_kwargs}
)
JumpStartModelsAccessor._set_cache_and_region(region, cache_kwargs)

if hub_arn:
hub_model_arn = construct_hub_model_arn_from_inputs(
hub_arn=hub_arn, model_name=model_id, version=version
)
return JumpStartModelsAccessor._cache.get_hub_model(hub_model_arn=hub_model_arn)

return JumpStartModelsAccessor._cache.get_specs( # type: ignore
model_id=model_id, version_str=version, model_type=model_type
)
Expand Down
9 changes: 9 additions & 0 deletions src/sagemaker/jumpstart/artifacts/environment_variables.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
def _retrieve_default_environment_variables(
model_id: str,
model_version: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand All @@ -47,6 +48,8 @@ def _retrieve_default_environment_variables(
retrieve the default environment variables.
model_version (str): Version of the JumpStart model for which to retrieve the
default environment variables.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (Optional[str]): Region for which to retrieve default environment variables.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
Expand Down Expand Up @@ -78,6 +81,7 @@ def _retrieve_default_environment_variables(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=script,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down Expand Up @@ -116,6 +120,7 @@ def _retrieve_default_environment_variables(
] = lambda instance_type: _retrieve_gated_model_uri_env_var_value(
model_id=model_id,
model_version=model_version,
hub_arn=hub_arn,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
tolerate_deprecated_model=tolerate_deprecated_model,
Expand Down Expand Up @@ -161,6 +166,7 @@ def _retrieve_default_environment_variables(
def _retrieve_gated_model_uri_env_var_value(
model_id: str,
model_version: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand All @@ -174,6 +180,8 @@ def _retrieve_gated_model_uri_env_var_value(
retrieve the gated model env var URI.
model_version (str): Version of the JumpStart model for which to retrieve the
gated model env var URI.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (Optional[str]): Region for which to retrieve the gated model env var URI.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
Expand Down Expand Up @@ -204,6 +212,7 @@ def _retrieve_gated_model_uri_env_var_value(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=JumpStartScriptScope.TRAINING,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/hyperparameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
def _retrieve_default_hyperparameters(
model_id: str,
model_version: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
include_container_hyperparameters: bool = False,
tolerate_vulnerable_model: bool = False,
Expand All @@ -44,6 +45,8 @@ def _retrieve_default_hyperparameters(
retrieve the default hyperparameters.
model_version (str): Version of the JumpStart model for which to retrieve the
default hyperparameters.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (str): Region for which to retrieve default hyperparameters.
(Default: None).
include_container_hyperparameters (bool): True if container hyperparameters
Expand Down Expand Up @@ -76,6 +79,7 @@ def _retrieve_default_hyperparameters(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=JumpStartScriptScope.TRAINING,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/image_uris.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _retrieve_image_uri(
model_id: str,
model_version: str,
image_scope: str,
hub_arn: Optional[str] = None,
framework: Optional[str] = None,
region: Optional[str] = None,
version: Optional[str] = None,
Expand All @@ -57,6 +58,8 @@ def _retrieve_image_uri(
model_id (str): JumpStart model ID for which to retrieve image URI.
model_version (str): Version of the JumpStart model for which to retrieve
the image URI.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
image_scope (str): The image type, i.e. what it is used for.
Valid values: "training", "inference", "eia". If ``accelerator_type`` is set,
``image_scope`` is ignored.
Expand Down Expand Up @@ -110,6 +113,7 @@ def _retrieve_image_uri(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=image_scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/artifacts/incremental_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def _model_supports_incremental_training(
model_id: str,
model_version: str,
region: Optional[str],
hub_arn: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
sagemaker_session: Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION,
Expand All @@ -43,6 +44,8 @@ def _model_supports_incremental_training(
support status for incremental training.
region (Optional[str]): Region for which to retrieve the
support status for incremental training.
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
specifications should be tolerated (exception not raised). If False, raises an
exception if the script used by this version of the model has dependencies with known
Expand All @@ -64,6 +67,7 @@ def _model_supports_incremental_training(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=JumpStartScriptScope.TRAINING,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down
8 changes: 8 additions & 0 deletions src/sagemaker/jumpstart/artifacts/instance_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def _retrieve_default_instance_type(
model_id: str,
model_version: str,
scope: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand All @@ -50,6 +51,8 @@ def _retrieve_default_instance_type(
default instance type.
scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (Optional[str]): Region for which to retrieve default instance type.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
Expand Down Expand Up @@ -82,6 +85,7 @@ def _retrieve_default_instance_type(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down Expand Up @@ -122,6 +126,7 @@ def _retrieve_instance_types(
model_id: str,
model_version: str,
scope: str,
hub_arn: Optional[str] = None,
region: Optional[str] = None,
tolerate_vulnerable_model: bool = False,
tolerate_deprecated_model: bool = False,
Expand All @@ -137,6 +142,8 @@ def _retrieve_instance_types(
supported instance types.
scope (str): The script type, i.e. what it is used for.
Valid values: "training" and "inference".
hub_arn (str): The arn of the SageMaker Hub for which to retrieve
model details from. (default: None).
region (Optional[str]): Region for which to retrieve supported instance types.
(Default: None).
tolerate_vulnerable_model (bool): True if vulnerable versions of model
Expand Down Expand Up @@ -169,6 +176,7 @@ def _retrieve_instance_types(
model_specs = verify_model_region_and_return_specs(
model_id=model_id,
version=model_version,
hub_arn=hub_arn,
scope=scope,
region=region,
tolerate_vulnerable_model=tolerate_vulnerable_model,
Expand Down
Loading

0 comments on commit 2418598

Please sign in to comment.