In [10]:
import boto3
import botocore
import sagemaker
from sagemaker import image_uris
import sys
import time
from sagemaker.huggingface import HuggingFaceModel, get_huggingface_llm_image_uri
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
from botocore.exceptions import NoCredentialsError
from sagemaker.jumpstart.model import JumpStartModel

session = sagemaker.Session()
runtime = boto3.client("sagemaker-runtime")
client = boto3.client("sagemaker")
region = session.boto_region_name
prefix='spectra-test'
role = "AmazonSageMaker-ExecutionRole-20240618T160945"
account_id = boto3.client('sts').get_caller_identity().get('Account')
endpoint_name = f"{prefix}-endpoint"

In [6]:
session.default_bucket() 

'sagemaker-us-east-1-452706865406'

In [9]:
inference_component_name = f"{prefix}-inference-component"
print(f"Demo inference component name: {inference_component_name}:: endpoint_name={endpoint_name}")
variant_name = "AllTraffic"

Demo inference component name: spectra-test-inference-component:: endpoint_name=spectra-test-endpoint


In [11]:
# retrieve the llm image uri
hf_inference_dlc = get_huggingface_llm_image_uri("huggingface", version="0.9.3")

# print ecr image uri
print(f"llm image uri: {hf_inference_dlc}")

llm image uri: 763104351884.dkr.ecr.us-east-1.amazonaws.com/huggingface-pytorch-tgi-inference:2.0.1-tgi0.9.3-gpu-py39-cu118-ubuntu20.04


In [12]:
deployment_name = "sm"

flant5xxlmodel = {
    "Image": hf_inference_dlc,
    "Environment": {"HF_MODEL_ID": "google/flan-t5-xxl", "HF_TASK": "text-generation"},
}

# create SageMaker Model
client.create_model(
    ModelName=f"{deployment_name}-model-flan-t5-xxl",
    ExecutionRoleArn=f"arn:aws:iam::{account_id}:role/service-role/{role}",
    Containers=[flant5xxlmodel],
)

{'ModelArn': 'arn:aws:sagemaker:us-east-1:452706865406:model/sm-model-flan-t5-xxl',
 'ResponseMetadata': {'RequestId': '0cf7cc7e-5455-49b5-a9b0-bb8ed31eeadc',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '0cf7cc7e-5455-49b5-a9b0-bb8ed31eeadc',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '82',
   'date': 'Mon, 01 Jul 2024 10:19:39 GMT'},
  'RetryAttempts': 0}}

In [13]:
inference_component_name_flant5 = f"{prefix}-IC-flan-xxl"
variant_name = "AllTraffic"

client.create_inference_component(
    InferenceComponentName=inference_component_name_flant5,
    EndpointName=endpoint_name,
    VariantName=variant_name,
    Specification={
        "ModelName": f"{deployment_name}-model-flan-t5-xxl",
        "ComputeResourceRequirements": {
            "NumberOfAcceleratorDevicesRequired": 1,
            "NumberOfCpuCoresRequired": 1,
            "MinMemoryRequiredInMb": 1024,
        },
    },
    RuntimeConfig={"CopyCount": 1},
)

{'InferenceComponentArn': 'arn:aws:sagemaker:us-east-1:452706865406:inference-component/spectra-test-IC-flan-xxl',
 'ResponseMetadata': {'RequestId': '45efa9af-4754-4a9d-9187-1dbc56fc14f2',
  'HTTPStatusCode': 200,
  'HTTPHeaders': {'x-amzn-requestid': '45efa9af-4754-4a9d-9187-1dbc56fc14f2',
   'content-type': 'application/x-amz-json-1.1',
   'content-length': '113',
   'date': 'Mon, 01 Jul 2024 10:22:14 GMT'},
  'RetryAttempts': 0}}

In [15]:
while True:
    desc = client.describe_inference_component(
        InferenceComponentName=inference_component_name_flant5
    )
    status = desc["InferenceComponentStatus"]
    print(status)
    sys.stdout.flush()
    if status in ["InService", "Failed"]:
        break
    time.sleep(30)

Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
Creating
InService
