Skip to content

Commit

Permalink
Fix: Add Image URI overrides for transformers models (aws#4693)
Browse files Browse the repository at this point in the history
* Fix: Add Image URI overrides for transformers models

* Increase coverage

* Fix formatting
  • Loading branch information
samruds authored and jiapinw committed Jun 25, 2024
1 parent 32917ac commit 55c4600
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 6 deletions.
28 changes: 23 additions & 5 deletions src/sagemaker/serve/builder/transformers_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,25 @@ def _prepare_for_mode(self):
"""Abstract method"""

def _create_transformers_model(self) -> Type[Model]:
    """Create the HuggingFace model, honoring a caller-supplied ``image_uri``.

    When ``self.image_uri`` is ``None`` the model is obtained from
    ``self._get_hf_metadata_create_model()``. Otherwise a
    ``HuggingFaceModel`` is constructed directly from the override image
    together with the builder's VPC config, environment variables, role
    and SageMaker session.

    In both cases the model's own ``deploy`` is stashed on
    ``self._original_deploy`` and replaced with
    ``self._transformers_model_builder_deploy_wrapper``.

    Returns:
        The created model, with its ``deploy`` attribute replaced by the
        builder's wrapper.
    """
    if self.image_uri is None:
        pysdk_model = self._get_hf_metadata_create_model()
    else:
        # An explicit image override was given: build the model straight
        # from it instead of going through the metadata-based path.
        pysdk_model = HuggingFaceModel(
            image_uri=self.image_uri,
            vpc_config=self.vpc_config,
            env=self.env_vars,
            role=self.role_arn,
            sagemaker_session=self.sagemaker_session,
        )

    logger.info("Detected %s. Proceeding with the deployment.", self.image_uri)

    # Keep a handle on the model's real deploy so the wrapper installed
    # below still has access to it.
    self._original_deploy = pysdk_model.deploy
    pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
    return pysdk_model

def _get_hf_metadata_create_model(self) -> Type[Model]:
"""Initializes the model after fetching image
1. Get the metadata for deciding framework
Expand Down Expand Up @@ -132,22 +151,21 @@ def _create_transformers_model(self) -> Type[Model]:
vpc_config=self.vpc_config,
)

if not self.image_uri and self.mode == Mode.LOCAL_CONTAINER:
if self.mode == Mode.LOCAL_CONTAINER:
self.image_uri = pysdk_model.serving_image_uri(
self.sagemaker_session.boto_region_name, "local"
)
elif not self.image_uri:
else:
self.image_uri = pysdk_model.serving_image_uri(
self.sagemaker_session.boto_region_name, self.instance_type
)

logger.info("Detected %s. Proceeding with the the deployment.", self.image_uri)
if pysdk_model is None or self.image_uri is None:
raise ValueError("PySDK model unable to be created, try overriding image_uri")

if not pysdk_model.image_uri:
pysdk_model.image_uri = self.image_uri

self._original_deploy = pysdk_model.deploy
pysdk_model.deploy = self._transformers_model_builder_deploy_wrapper
return pysdk_model

@_capture_telemetry("transformers.deploy")
Expand Down
2 changes: 1 addition & 1 deletion tests/integ/sagemaker/serve/test_serve_transformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,4 @@ def test_pytorch_transformers_sagemaker_endpoint(
logger.exception(caught_ex)
assert (
False
), f"{caught_ex} was thrown when running pytorch transformers sagemaker endpoint test"
), f"{caught_ex} thrown when running pytorch transformers sagemaker endpoint test"
26 changes: 26 additions & 0 deletions tests/unit/sagemaker/serve/builder/test_transformers_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,29 @@ def test_image_uri_override(

with self.assertRaises(ValueError) as _:
model.deploy(mode=Mode.IN_PROCESS)

@patch("sagemaker.serve.builder.model_builder.ModelBuilder._build_for_transformers")
@patch(
    "sagemaker.serve.builder.transformers_builder._get_nb_instance",
    return_value="ml.g5.24xlarge",
)
@patch("sagemaker.serve.builder.transformers_builder._capture_telemetry", side_effect=None)
@patch(
    "sagemaker.huggingface.llm_utils.get_huggingface_model_metadata",
    return_value=None,
)
def test_failure_hf_md(
    self, mock_model_md, mock_telemetry, mock_get_nb_instance, mock_build_for_transformers
):
    """Check that ``build()`` calls the transformers build path exactly once.

    ``get_huggingface_model_metadata`` is patched to return ``None`` and
    ``ModelBuilder._build_for_transformers`` is patched out, so the test
    only checks that ``build()`` dispatches to it.

    ``@patch`` decorators inject their mocks bottom-up, so the parameters
    are listed in the reverse of the decorator order above.
    """
    builder = ModelBuilder(
        model=mock_model_id,
        schema_builder=mock_schema_builder,
        mode=Mode.LOCAL_CONTAINER,
    )

    # Stub out mode preparation so it does no real work during build().
    builder._prepare_for_mode = MagicMock()
    builder._prepare_for_mode.side_effect = None

    builder.build()

    mock_build_for_transformers.assert_called_once()

0 comments on commit 55c4600

Please sign in to comment.