# Part 1: Sagemaker Async Endpoint Creation

Install the dependencies.

In [None]:

!pip install boto3 sagemaker


Initialize the packages, AWS clients, and variables used later in the notebook.

In [5]:
import sagemaker
from sagemaker.pytorch import PyTorchModel
import boto3
import datetime
import time
from time import strftime,gmtime
import json
import os
import urllib
import sys
import io

# AWS clients shared by the rest of the notebook. Everything is derived from a single
# boto3 session so that every client (and the SageMaker SDK session) uses the same
# credentials and region.
boto_session = boto3.session.Session()
sm_session = sagemaker.session.Session(boto_session=boto_session)
sm_client = boto_session.client("sagemaker")
sm_runtime = boto_session.client("sagemaker-runtime")
region = boto_session.region_name
sns_client = boto_session.client("sns")  # same region as boto_session

# S3 bucket/prefix under which the async endpoint config places input, output and failure objects.
bucket = "sm-ball-tracking-output-blobs"
prefix = 'async-inference'

print(region)
print(prefix)

ca-central-1
async-inference


## Sagemaker Inference IAM Role
The SageMaker endpoint needs this IAM role to read the model data, send SNS notifications, and upload/download video files and labels on S3.

#### Delete Role
If you need to add new permissions later, delete the existing IAM role and recreate it. Otherwise, skip the deletion and recreation cells.

In [24]:
import boto3
import json
from botocore.exceptions import ClientError

# Delete the SageMaker execution role so it can be recreated with new permissions.
# IAM rejects DeleteRole (DeleteConflict) while any managed policy is still attached
# or any inline policy is still embedded, so both are removed first.
role_name = 'SageMaker-Role'

# Create an IAM client
iam_client = boto3.client('iam')

try:
    # Managed policies are shared objects: detach them from the role (do not delete).
    attached_pages = iam_client.get_paginator('list_attached_role_policies')
    for page in attached_pages.paginate(RoleName=role_name):
        for policy in page['AttachedPolicies']:
            policy_arn = policy['PolicyArn']
            iam_client.detach_role_policy(RoleName=role_name, PolicyArn=policy_arn)
            print(f"Detached policy: {policy_arn}")

    # Inline policies live inside the role, so they must be deleted.
    inline_pages = iam_client.get_paginator('list_role_policies')
    for page in inline_pages.paginate(RoleName=role_name):
        for policy_name in page['PolicyNames']:
            iam_client.delete_role_policy(RoleName=role_name, PolicyName=policy_name)
            print(f"Deleted inline policy: {policy_name}")

    iam_client.delete_role(RoleName=role_name)
    print(f"Deleted role: {role_name}")
except ClientError as e:
    if e.response['Error']['Code'] == 'NoSuchEntity':
        print(f"Role {role_name} does not exist; nothing to delete.")
    else:
        # Best effort: report the problem but keep the notebook running.
        print(e)

Detached policy: arn:aws:iam::aws:policy/AmazonSNSFullAccess
Detached policy: arn:aws:iam::aws:policy/AmazonSageMakerFullAccess
Detached policy: arn:aws:iam::aws:policy/AmazonS3FullAccess


### Create Role

In [25]:
import boto3
import json
from botocore.exceptions import ClientError

# Create (or reuse) the SageMaker execution role and attach the managed policies it needs.
# Leaves `create_role_response` holding {'Role': {..., 'Arn': ...}} in both cases so the
# next cell can read the role ARN.
role_name = 'SageMaker-Role'

# Managed policies for SageMaker.
# NOTE(review): these *FullAccess policies are very broad; consider scoping them to the
# buckets and topics this notebook actually uses.
managed_policy_arns = [
    'arn:aws:iam::aws:policy/AmazonSageMakerFullAccess',
    'arn:aws:iam::aws:policy/AmazonS3FullAccess',
    'arn:aws:iam::aws:policy/AmazonSNSFullAccess'
]

# Create an IAM client
iam_client = boto3.client('iam')

# Trust policy: allows the SageMaker service to assume this role.
assume_role_policy_document = {
    "Version": "2012-10-17",
    "Statement": [
        {
            "Effect": "Allow",
            "Principal": {
                "Service": "sagemaker.amazonaws.com"
            },
            "Action": "sts:AssumeRole"
        }
    ]
}

try:
    create_role_response = iam_client.create_role(
        RoleName=role_name,
        AssumeRolePolicyDocument=json.dumps(assume_role_policy_document)
    )
    print("SageMaker role created successfully:", create_role_response['Role']['Arn'])
except ClientError as e:
    if e.response['Error']['Code'] == 'EntityAlreadyExists':
        print("Role with the same name already exists.")
        # get_role returns the same {'Role': {...}} shape as create_role, so later cells
        # that read create_role_response['Role']['Arn'] keep working on re-runs.
        create_role_response = iam_client.get_role(RoleName=role_name)
    else:
        print("Error creating SageMaker role:", e)
        raise

# Attach managed policies in both the "created" and "already exists" cases, so a
# pre-existing role without its policies is repaired.
for policy_arn in managed_policy_arns:
    iam_client.attach_role_policy(
        RoleName=role_name,
        PolicyArn=policy_arn
    )
    print(f"Attached policy {policy_arn} to the role.")


SageMaker role created successfully: arn:aws:iam::800241512715:role/SageMaker-Role
Attached policy arn:aws:iam::aws:policy/AmazonSageMakerFullAccess to the role.
Attached policy arn:aws:iam::aws:policy/AmazonS3FullAccess to the role.
Attached policy arn:aws:iam::aws:policy/AmazonSNSFullAccess to the role.


In [26]:
# ARN of the IAM role from the create-role cell; used as ExecutionRoleArn for create_model.
created_role = create_role_response['Role']
role_arn = created_role['Arn']

If you do not want to delete and recreate the IAM role as in the cells above, skip them and initialize `role_arn` directly in the next cell instead. This works too.

In [15]:
# Look up the existing role's ARN by name instead of hard-coding an account ID.
# This works in any account/partition and fails fast if the role does not exist.
role_arn = boto3.client('iam').get_role(RoleName='SageMaker-Role')['Role']['Arn']

### Create SNS Topics for Inference Success and Failure Notifications

In [21]:
# SNS topic notified when an async inference request fails.
# CreateTopic is idempotent: re-running returns the ARN of the existing topic.
response = sns_client.create_topic(Name="Async-ML-ErrorTopic")
error_topic = response["TopicArn"]
print(error_topic)

arn:aws:sns:ca-central-1:800241512715:Async-ML-ErrorTopic


In [22]:
# SNS topic notified when an async inference request succeeds.
# CreateTopic is idempotent: re-running returns the ARN of the existing topic.
response = sns_client.create_topic(Name="Async-ML-SuccessTopic")
success_topic = response["TopicArn"]
print(success_topic)

arn:aws:sns:ca-central-1:800241512715:Async-ML-SuccessTopic


In [23]:
# ListTopics returns at most 100 topics per call; page through all of them so the
# newly created topics show up even in accounts with many topics.
topic_pages = sns_client.get_paginator("list_topics").paginate()
topics = [topic for page in topic_pages for topic in page["Topics"]]
print(topics)

[{'TopicArn': 'arn:aws:sns:ca-central-1:800241512715:Async-ML-ErrorTopic'}, {'TopicArn': 'arn:aws:sns:ca-central-1:800241512715:Async-ML-SuccessTopic'}, {'TopicArn': 'arn:aws:sns:ca-central-1:800241512715:Test-VoD-v120-NotificationSnsTopicB941FD22-lz6dg1qRRgHh'}, {'TopicArn': 'arn:aws:sns:ca-central-1:800241512715:aws-controltower-SecurityNotifications'}, {'TopicArn': 'arn:aws:sns:ca-central-1:800241512715:test-new-videos'}]


## Subscription for notifications
Change the email address to the address where you want status notifications delivered. When you execute this cell, you should receive a subscription confirmation email from AWS for each topic. You must confirm each subscription by clicking the link in the email; otherwise you will not receive any notifications.

In [44]:


# Address that receives success/error notifications. Set NOTIFICATION_EMAIL to override
# it without editing the notebook. Each subscription must be confirmed via the email
# AWS sends; until then the subscription is "pending confirmation".
email_id = os.environ.get('NOTIFICATION_EMAIL', 'mfahadm8@gmail.com')

# Subscribe the same address to both topics (success first, then error).
email_sub_1, email_sub_2 = [
    sns_client.subscribe(TopicArn=topic_arn, Protocol='email', Endpoint=email_id)
    for topic_arn in (success_topic, error_topic)
]
print(email_sub_1['SubscriptionArn'], email_sub_2['SubscriptionArn'])


## Async Endpoint Creation Step By Step

#### Create/Update Model Artifacts in S3

If you update the inference code, delete and recreate the model package (model.tar.gz) and upload it to S3 again. The next two cells do this for you.

In [1]:
# Remove any previous model archive so the next cell packages the code from scratch.
# Like `rm -f`, a missing file is not an error.
try:
    os.remove("model.tar.gz")
except FileNotFoundError:
    pass

In [2]:
!tar -czvf model.tar.gz yolov7.pt code >> /dev/null 2>&1
!aws s3 cp model.tar.gz s3://sm-ball-tracking-inputs/models/ >> /dev/null 2>&1

###  Model Config Creation In Sagemaker

In [16]:
from sagemaker.image_uris import retrieve

# CPU instance type for the endpoint; the inference image below is chosen to match it.
deploy_instance_type = 'ml.c5.xlarge'

# AWS-managed PyTorch 1.12 / Python 3.8 inference container for this region.
pytorch_inference_image_uri = retrieve(
    framework='pytorch',
    region=region,
    version='1.12',
    py_version='py38',
    instance_type=deploy_instance_type,
    accelerator_type=None,
    image_scope='inference',
)
print(pytorch_inference_image_uri)

763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:1.12-cpu-py38


##### Optional: delete the model if you want to recreate it

In [17]:
# Optionally delete the SageMaker model so the next cell can recreate it.
from botocore.exceptions import ClientError

model_name = 'ball-tracking-yolo7'
try:
    sm_client.delete_model(ModelName=model_name)
    print(f"Deleted model {model_name}")
except ClientError as e:
    # SageMaker reports a missing resource as a ValidationException ("Could not find ...").
    if e.response['Error']['Code'] == 'ValidationException':
        print("Model not found!")
    else:
        # Best effort, but do not hide real problems (permissions, throttling, ...).
        print(e)

Model not found!


In [18]:
# Register the model in SageMaker: the PyTorch inference image plus the model.tar.gz
# uploaded earlier, with inference.py from its code/ directory as the entry point.
container = pytorch_inference_image_uri
model_artifact = "s3://sm-ball-tracking-inputs/models/model.tar.gz"
model_name = 'ball-tracking-yolo7'
print(container)
print(model_name)

container_environment = {
    'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
    'SAGEMAKER_PROGRAM': 'inference.py',
    'SAGEMAKER_REGION': region,
    'SAGEMAKER_SUBMIT_DIRECTORY': '/opt/ml/model/code',
    # TorchServe's default request/response size limit is only a few MB; raise it
    # to ~100 MB so large (~70 MB) payloads are accepted.
    'TS_MAX_REQUEST_SIZE': '100000000',
    'TS_MAX_RESPONSE_SIZE': '100000000',
    'TS_DEFAULT_RESPONSE_TIMEOUT': '1000',
}

primary_container = {
    'Image': container,
    'ModelDataUrl': model_artifact,
    'Mode': 'SingleModel',
    'Environment': container_environment,
}

create_model_response = sm_client.create_model(
    ModelName=model_name,
    ExecutionRoleArn=role_arn,
    PrimaryContainer=primary_container,
)

763104351884.dkr.ecr.ca-central-1.amazonaws.com/pytorch-inference:1.12-cpu-py38
ball-tracking-yolo7


### Endpoint Config Creation For Sagemaker

##### Optional: delete the endpoint configuration if you want to recreate it

In [19]:
# Optionally delete the async endpoint configuration so the next cell can recreate it.
from botocore.exceptions import ClientError

endpoint_config_name = "BallTrackingV7AsyncEndpointConfig"
try:
    sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
    print(f"Deleted endpoint configuration {endpoint_config_name}")
except ClientError as e:
    # SageMaker reports a missing resource as a ValidationException ("Could not find ...").
    if e.response['Error']['Code'] == 'ValidationException':
        print("Endpoint Configuration not found!")
    else:
        # Best effort, but do not hide real problems (permissions, throttling, ...).
        print(e)

Endpoint Configuratoin not found!


In [24]:
# Async endpoint configuration: one instance serving `model_name`. Results, failures and
# SNS notifications are configured under AsyncInferenceConfig.
print(model_name)
endpoint_config_name = "BallTrackingV7AsyncEndpointConfig"
create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "VariantName": "variant1",
            "ModelName": model_name,
            # Use the instance type the inference image was retrieved for, so the
            # (CPU) image and the hardware cannot drift apart.
            "InstanceType": deploy_instance_type,
            "InitialInstanceCount": 1
        }
    ],
    AsyncInferenceConfig={
        "OutputConfig": {
            # Successful responses are written here.
            "S3OutputPath": f"s3://{bucket}/{prefix}/output",
            # SNS topics notified on success / error.
            "NotificationConfig": {
                "SuccessTopic": success_topic,
                "ErrorTopic": error_topic,
            },
            # Failed requests are written here.
            "S3FailurePath": f"s3://{bucket}/{prefix}/failure",
        },
        "ClientConfig": {
            # At most 2 requests are sent concurrently to each instance.
            "MaxConcurrentInvocationsPerInstance": 2
        }
    }
)
print(f"Created EndpointConfig: {create_endpoint_config_response['EndpointConfigArn']}")

ball-tracking-yolo7
Created EndpointConfig: arn:aws:sagemaker:ca-central-1:800241512715:endpoint-config/balltrackingv7asyncendpointconfig


### Endpoint Creation

##### Optional: delete the existing endpoint

In [25]:
# Optionally delete the existing endpoint so it can be recreated.
from botocore.exceptions import ClientError

endpoint_name = "ball-tracking-v7"
try:
    sm_client.delete_endpoint(EndpointName=endpoint_name)
except ClientError as e:
    # e.g. ValidationException "Could not find endpoint ..." when there is nothing to delete.
    print(e)
else:
    # Deletion is asynchronous: CreateEndpoint with the same name fails while the old
    # endpoint is still "Deleting", so block until it is gone.
    print(f"Deleting endpoint {endpoint_name}...")
    sm_client.get_waiter('endpoint_deleted').wait(EndpointName=endpoint_name)
    print("Endpoint deleted.")

An error occurred (ValidationException) when calling the DeleteEndpoint operation: Could not find endpoint "arn:aws:sagemaker:ca-central-1:800241512715:endpoint/ball-tracking-v7".


In [26]:
# Launch the async endpoint from the configuration above. Creation continues in the
# background; the next cell waits for it to reach InService.
endpoint_name = "ball-tracking-v7"
endpoint_config_name = "BallTrackingV7AsyncEndpointConfig"
create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_config_name,
)
print("Creating Endpoint: {}".format(create_endpoint_response['EndpointArn']))

Creating Endpoint: arn:aws:sagemaker:ca-central-1:800241512715:endpoint/ball-tracking-v7


In [27]:
# Block until the endpoint is InService. Use the notebook's sm_client (same session and
# region as the rest of the notebook) rather than a fresh default-session client.
from botocore.exceptions import WaiterError

waiter = sm_client.get_waiter('endpoint_in_service')
print("Waiting for endpoint to create...")
try:
    waiter.wait(EndpointName=endpoint_name)
except WaiterError:
    # The waiter stops early if the endpoint goes to Failed; show why before re-raising.
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    print(f"Endpoint Status: {resp['EndpointStatus']}, FailureReason: {resp.get('FailureReason')}")
    raise
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
print(f"Endpoint Status: {resp['EndpointStatus']}")

Waiting for endpoint to create...


Endpoint Status: InService


## Testing

In [28]:
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer
import uuid
import os

# Buckets where the inference code is asked to write the labels CSV and the output video.
LABELS_BUCKET = "sm-ball-tracking-output-labels"
VIDEO_BUCKET = "sm-ball-tracking-output-blobs"

# Video to process, plus a unique ID used for the request and its output paths.
input_video_uri = "s3://test-vod-v120-source71e471f1-5vcytwlc3m1b/test-videos/20200616_VB_trim.mp4"
inference_id = str(uuid.uuid4())

# Wrap the endpoint in an AsyncPredictor whose requests are serialized as JSON.
endpoint_name = "ball-tracking-v7"
predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sm_session,
    serializer=JSONSerializer(),
)
async_predictor = AsyncPredictor(predictor=predictor)

# Split "s3://<bucket>/<key>" and derive output file names from the input video's name.
s3_input_path_without_prefix = input_video_uri[len("s3://"):]
input_bucket_name, input_key = s3_input_path_without_prefix.split('/', 1)
input_base_file = os.path.basename(input_key)                   # e.g. 20200616_VB_trim.mp4
input_base_filename = os.path.splitext(input_base_file)[0]      # e.g. 20200616_VB_trim

input_data = {
    'input_location': input_video_uri,
    'output_label_location': f"s3://{LABELS_BUCKET}/{inference_id}/{input_base_filename}.csv",
    'output_video_location': f"s3://{VIDEO_BUCKET}/{inference_id}/{input_base_file}",
}

# S3 location for this request's JSON payload, then send it to the endpoint asynchronously.
input_s3_uri = f"s3://{bucket}/{prefix}/input/{inference_id}.json"
response = async_predictor.predict_async(
    data=input_data,
    input_path=input_s3_uri,
    inference_id=inference_id,
)


In [38]:

# Fetch the async inference result. Without a waiter config, get_result checks S3 only
# once and raises if the video is still being processed. Instead, poll every 15 s for
# up to 240 attempts (~1 hour).
from sagemaker.async_inference.waiter_config import WaiterConfig

result = response.get_result(
    waiter_config=WaiterConfig(max_attempts=240, delay=15)
).decode('utf-8')
message = json.loads(result)
# If the payload carries an "Error" key, show only that value.
if "Error" in message:
    message = message["Error"]
print(message)

{'output_label_location': 's3://sm-ball-tracking-output-labels/f88e140c-647a-4b54-890a-d956a29dd29c/20200616_VB_trim.csv', 'output_video_location': 's3://sm-ball-tracking-output-blobs/f88e140c-647a-4b54-890a-d956a29dd29c/20200616_VB_trim.mp4'}
