In [None]:
from sagemaker import get_execution_role
from sagemaker.pytorch import PyTorchModel  
from sagemaker.serializers import JSONSerializer
from sagemaker.deserializers import JSONDeserializer
from sagemaker.async_inference import AsyncInferenceConfig
from sagemaker.utils import name_from_base
import sagemaker
import boto3
import json

In [None]:
# Restore variables persisted with `%store` (presumably by an earlier notebook
# in this series -- they must have been stored there first):
#   prefix                - naming prefix used for SageMaker resource names and S3 keys
#   model_uncompressed_s3 - S3 URI (a prefix, not a .tar.gz) of the uncompressed model artifacts
#   bucket_name           - S3 bucket receiving async inference outputs and generated wav files
# Comments are kept on their own lines: text after a line magic is passed to it as arguments.
%store -r prefix
%store -r model_uncompressed_s3
%store -r bucket_name

### > Set up the uncompressed model data source

In [None]:
# Point SageMaker at an S3 *prefix* holding the already-extracted model
# artifacts. 'CompressionType': 'None' tells SageMaker not to expect a
# model.tar.gz archive; the objects under the prefix are used as-is.
model_data = dict(
    S3DataSource=dict(
        S3Uri=model_uncompressed_s3,
        S3DataType='S3Prefix',
        CompressionType='None',
    )
)

### > Create a new PyTorch model

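You can deploy this model to either a real-time or an asynchronous endpoint. However, inference with this model is very slow, so a real-time endpoint would consistently fail with an inference timeout (real-time invocations must complete within SageMaker's 60-second limit). We therefore use an asynchronous endpoint instead.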

In [None]:
# A unique, timestamped model name so re-running the notebook does not collide
# with an existing SageMaker model.
model_name = name_from_base(f"{prefix}-model")

# Serving-container environment:
#   SAGEMAKER_TS_RESPONSE_TIMEOUT - response timeout (seconds) raised to 900 because
#                                   inference is slow (presumably the TorchServe timeout)
#   SM_MODEL_DIR                  - where the model artifacts are mounted in the container
#   SAGEMAKER_PROGRAM             - the inference handler script to load
container_env = {
    'SAGEMAKER_TS_RESPONSE_TIMEOUT': '900',
    'SM_MODEL_DIR': '/opt/ml/model',
    'SAGEMAKER_PROGRAM': 'inference.py',
}

model = PyTorchModel(
    name=model_name,
    model_data=model_data,
    role=get_execution_role(),
    framework_version="2.1",
    py_version="py310",
    env=container_env,
)

# Real-time deployment is OFF by default. This model's inference is too slow for
# real-time invocation (requests time out), so a real-time endpoint would only sit
# idle on a billed GPU instance. It could also use up ml.g5.2xlarge quota that the
# async endpoint below needs.
# Set DEPLOY_REALTIME = True only to experiment with it, and delete the endpoint afterwards.
DEPLOY_REALTIME = False

endpoint_name = name_from_base(f"{prefix}-endpoint")

# `predictor` stays defined (as None) when the real-time deploy is skipped,
# so later references fail loudly instead of with a NameError.
predictor = None
if DEPLOY_REALTIME:
    predictor = model.deploy(
        initial_instance_count=1,
        instance_type="ml.g5.2xlarge",
        endpoint_name=endpoint_name,
    )

### > Create an async endpoint with the SageMaker SDK

In [None]:
# Async inference settings (AsyncInferenceConfig is imported in the imports cell).
# - Results are written under s3://{bucket_name}/{prefix}/output.
# - Each instance processes at most 2 requests at once; further requests are
#   queued by the endpoint.
# - For SNS notifications on success or failure, also pass e.g.:
#     notification_config={
#         "SuccessTopic": "arn:aws:sns:<region>:<account-id>:<topic>",
#         "ErrorTopic": "arn:aws:sns:<region>:<account-id>:<topic>",
#     }
async_config = AsyncInferenceConfig(
    output_path=f"s3://{bucket_name}/{prefix}/output",
    max_concurrent_invocations_per_instance=2,
)

In [None]:
# Deploy the model behind an asynchronous endpoint (this rebinds `endpoint_name`
# to a fresh, timestamped async endpoint name).
# - Requests are serialized to JSON.
# - Responses fetched via `get_result()` are parsed as JSON. This assumes
#   inference.py returns a JSON body -- confirm. Without an explicit deserializer,
#   the predictor class's default one would be used instead.
endpoint_name = name_from_base(f"{prefix}-async-endpoint")

async_predictor = model.deploy(
    async_inference_config=async_config,
    instance_type="ml.g5.2xlarge",
    initial_instance_count=1,
    endpoint_name=endpoint_name,
    serializer=JSONSerializer(),
    deserializer=JSONDeserializer(),
)

### > Invoke the async endpoint

In [None]:
import uuid

# Request body for the async endpoint. Judging by the field names, the handler
# synthesizes `text` in the chosen voice and writes the wav file to
# s3://<output_bucket>/<output_key> (confirm against inference.py).
# voice_id options: male_voice, female_voice, adam, vladimire, swami.
payload = dict(
    text="The lead engineer, a confident woman, stands before them, her presentation deck loaded.",
    voice_id="female_voice",
    output_bucket=bucket_name,
    output_key=f"{prefix}/output/wav_file/{uuid.uuid4()}.wav",
    inference_params={},
)

In [None]:
# Submit a single request. predict_async returns right away with a response
# handle; the endpoint's result is written to `output_path` once processing finishes.
json_request_args = {'ContentType': 'application/json'}
response = async_predictor.predict_async(data=payload, initial_args=json_request_args)
print(response.output_path)

### > Invoke in batch

In [None]:
# Submit one async request per available voice.
# - Each request gets its own unique wav key, so outputs do not overwrite one another.
# - The response handles are kept in `batch_responses` instead of being discarded,
#   so each request's output path is still known and can be polled later.
VOICE_IDS = ["male_voice", "female_voice", "adam", "vladimire", "swami"]

batch_responses = []
for voice_id in VOICE_IDS:
    # Shallow copy of the base payload with per-request overrides; the nested
    # `inference_params` dict is shared but never mutated here.
    p = {
        **payload,
        "voice_id": voice_id,
        "output_key": f"{prefix}/output/wav_file/{uuid.uuid4()}.wav",
    }
    response = async_predictor.predict_async(
        data=p,
        initial_args={'ContentType': 'application/json'},
    )
    batch_responses.append(response)
    print(p)
    print(response.output_path)