# LLaVA stateful inference with SageMaker

## Contents

This notebook runs on the `conda_pytorch_p310` kernel of a SageMaker notebook instance and demonstrates how to use TorchServe to deploy the LLaVA model on SageMaker with sticky session routing.

This is the code accompanying the AWS Machine Learning Blog post at https://aws.amazon.com/blogs/machine-learning/build-ultra-low-latency-multimodal-generative-ai-applications-using-sticky-session-routing-in-amazon/


## Step 0: Upgrade SageMaker and import dependencies

In [1]:
!python --version && aws --version

Python 3.10.14
aws-cli/1.35.9 Python/3.10.14 Linux/5.10.226-214.880.amzn2.x86_64 botocore/1.35.43


In [18]:
!pip install -Uq pip
!pip install -Uq sagemaker
!pip install -Uq torch-model-archiver
!pip install -Uq botocore
!pip install -Uq boto3

In [3]:
# You need boto3 1.35.15 or later to pass SessionId to invoke_endpoint (see Step 4.1); older versions of boto3 do not accept this parameter.
!pip freeze|grep boto

boto3==1.35.51
botocore==1.35.51
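The `grep` above only prints the installed versions. A minimal sketch of a programmatic guard, assuming plain `major.minor.patch` version strings (pre-release suffixes are not handled); `parse_version` and `meets_min_version` are hypothetical helpers, not part of boto3:

```python
def parse_version(version: str) -> tuple:
    """Turn a 'major.minor.patch' string into a tuple of ints for comparison."""
    return tuple(int(part) for part in version.split(".")[:3])


def meets_min_version(installed: str, required: str) -> bool:
    """Return True if `installed` is at least `required`, compared numerically."""
    return parse_version(installed) >= parse_version(required)


# Usage in the notebook (requires boto3 to be importable):
# import boto3
# assert meets_min_version(boto3.__version__, "1.35.15"), "Upgrade boto3 to pass SessionId"
print(meets_min_version("1.35.51", "1.35.15"))   # True
print(meets_min_version("1.34.162", "1.35.15"))  # False
```

Comparing integer tuples rather than raw strings matters here: as strings, `"1.35.9"` sorts after `"1.35.15"`, even though it is the older release.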


In [4]:
import os
import shutil
import importlib
import botocore

In [5]:
import sagemaker
from sagemaker import image_uris
import boto3
import os
import time
import json
from pathlib import Path

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


In [6]:
import boto3
import sagemaker
from sagemaker import Model, image_uris, serializers, deserializers

region_name: str = "us-east-1"

boto3_session = boto3.session.Session(region_name=region_name)
# SageMaker Runtime client, used to invoke the endpoint
smr = boto3.client('sagemaker-runtime', region_name=region_name)
# SageMaker client, used to create and manage models and endpoints
sm = boto3.client('sagemaker', region_name=region_name)
# execution role for the endpoint
role = sagemaker.get_execution_role()
# SageMaker session for interacting with different AWS APIs
sess = sagemaker.session.Session(boto3_session, sagemaker_client=sm, sagemaker_runtime_client=smr)
# region name of the current SageMaker session
region = sess.boto_region_name
# account ID of the current AWS session
account = sess.account_id()

# Configuration:
bucket_name = sess.default_bucket()
prefix = "torchserve"
output_path = f"s3://{bucket_name}/{prefix}"
model_name = "llava-sm"
print(f'account={account}, region={region}, role={role}, output_path={output_path}')

account=015469603702, region=us-east-1, role=arn:aws:iam::015469603702:role/sm-vision-llama32-SageMakerEndpointRole-Ab6IVBACmhji, output_path=s3://sagemaker-us-east-1-015469603702/torchserve


## Step 1: Build a BYOD TorchServe Docker container and push it to Amazon ECR

1. Create an ECR repo: https://docs.aws.amazon.com/AmazonECR/latest/userguide/repository-create.html
2. Get Base Image: https://github.com/aws/deep-learning-containers/blob/master/available_images.md
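Both the base image and the custom image are referenced by private ECR URIs of the form `<account>.dkr.ecr.<region>.amazonaws.com/<repository>:<tag>`, which is how the notebook builds `baseimage` and `container`. A small sketch that builds and splits such URIs, assuming a tag reference (not a `@sha256:` digest); `build_ecr_uri` and `parse_ecr_uri` are hypothetical helpers and `111122223333` is a placeholder account ID:

```python
def build_ecr_uri(account: str, region: str, repository: str, tag: str) -> str:
    """Compose a private ECR image URI from its parts."""
    return f"{account}.dkr.ecr.{region}.amazonaws.com/{repository}:{tag}"


def parse_ecr_uri(uri: str) -> dict:
    """Split a tag-based private ECR image URI back into its parts."""
    registry, _, repo_and_tag = uri.partition("/")
    # The tag follows the last ':'; the repository may itself contain '/'
    repository, _, tag = repo_and_tag.rpartition(":")
    # Registry host looks like <account>.dkr.ecr.<region>.amazonaws.com
    account, _dkr, _ecr, region = registry.split(".")[:4]
    return {"account": account, "region": region, "repository": repository, "tag": tag}


uri = build_ecr_uri("111122223333", "us-east-1", "llava-stateful", "1.0")
print(uri)                               # 111122223333.dkr.ecr.us-east-1.amazonaws.com/llava-stateful:1.0
print(parse_ecr_uri(uri)["repository"])  # llava-stateful
```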

In [8]:
baseimage = f"763104351884.dkr.ecr.{region}.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-sagemaker"
reponame = "llava-stateful"
versiontag = "1.0"
print("Run the command printed below in a terminal; the build script gives better feedback there.")
print(f"cd docker && ./build_and_push.sh {reponame} {versiontag} {baseimage} {region} {account}")
print("If you run this command in a terminal, you can skip the next cell.")

Run the command printed below in a terminal; the build script gives better feedback there.
cd docker && ./build_and_push.sh llava-stateful 1.0 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-sagemaker us-east-1 015469603702
If you run this command in a terminal, you can skip the next cell.


In [9]:
# # %%capture build_output

# # # Build our own docker image
# # !cd docker && ./build_and_push.sh {reponame} {versiontag} {baseimage} {region} {account}
# import subprocess

# # Build our own docker image
# command = f"cd docker && ./build_and_push.sh {reponame} {versiontag} {baseimage} {region} {account}"

# # Run the command and capture output in real-time
# process = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, universal_newlines=True)

# # Print output in real-time
# for line in process.stdout:
#     print(line, end='')

# # Wait for the process to complete
# process.wait()

# # Check if the process completed successfully
# if process.returncode != 0:
#     print(f"Error: Command exited with return code {process.returncode}")
# else:
#     print("Image build and push completed successfully")

https://docs.docker.com/engine/reference/commandline/login/#credentials-store

Login Succeeded
https://docs.docker.com/engine/reference/commandline/login/#credentials-store

Login Succeeded
#0 building with "default" instance using docker driver

#1 [internal] load build definition from Dockerfile
#1 transferring dockerfile:
#1 transferring dockerfile: 643B done
#1 DONE 0.1s

#2 [auth] sharing credentials for 763104351884.dkr.ecr.us-east-1.amazonaws.com
#2 DONE 0.0s

#3 [internal] load metadata for 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-sagemaker
#3 DONE 0.3s

#4 [internal] load .dockerignore
#4 transferring context: 2B done
#4 DONE 0.0s

#5 [internal] load build context
#5 transferring context: 5.78kB done
#5 DONE 0.1s

#6 [1/9] FROM 763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-sagemaker@sha256:bd4dfe46e1f8c71620a210a572e3527834430d00549de62d5f2d20c708f091eb
#6 resolve 7631043

In [19]:
# URI of the custom image pushed to ECR in Step 1
container = f"{account}.dkr.ecr.{region}.amazonaws.com/{reponame}:{versiontag}"
print(baseimage)


763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-inference:2.3.0-gpu-py311-cu121-ubuntu20.04-sagemaker


## Step 2: Build TorchServe Model Artifacts and Upload to S3

In [20]:
!rm -rf code/{model_name}

In [21]:
!cd code && torch-model-archiver --model-name {model_name} --version 1.0 --handler handler/custom_handler.py --config-file handler/model-config.yaml --archive-format no-archive --extra-files handler/ -f
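With `--archive-format no-archive`, `torch-model-archiver` writes an unpacked model directory (here `code/llava-sm`) instead of a single `.mar` file, which is what lets the artifacts be served uncompressed from an S3 prefix in Step 3. The upload log in the next cell shows that this directory includes `MAR-INF/MANIFEST.json`. A sketch for inspecting that manifest locally, assuming it nests the model details (handler, config file, and so on) under a top-level `"model"` key; `read_manifest` is a hypothetical helper and the exact key names may differ between archiver versions:

```python
import json
from pathlib import Path


def read_manifest(model_dir: str) -> dict:
    """Load MAR-INF/MANIFEST.json from an unarchived TorchServe model directory.

    Assumes model details live under a "model" key; returns {} if that key is absent.
    """
    manifest_path = Path(model_dir) / "MAR-INF" / "MANIFEST.json"
    if not manifest_path.is_file():
        raise FileNotFoundError(f"No manifest at {manifest_path}; did the archiver run?")
    manifest = json.loads(manifest_path.read_text())
    return manifest.get("model", {})


# Usage (from the notebook's working directory, after the archiver cell):
# model_info = read_manifest("code/llava-sm")
# print(model_info.get("handler"), model_info.get("configFile"))
```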

In [22]:
!cd code && aws s3 cp {model_name} {output_path}/{model_name} --recursive

upload: llava-sm/data_types.py to s3://sagemaker-us-east-1-015469603702/torchserve/llava-sm/data_types.py
upload: llava-sm/llava/__init__.py to s3://sagemaker-us-east-1-015469603702/torchserve/llava-sm/llava/__init__.py
upload: llava-sm/MAR-INF/MANIFEST.json to s3://sagemaker-us-east-1-015469603702/torchserve/llava-sm/MAR-INF/MANIFEST.json
upload: llava-sm/llava/model/language_model/llava_mistral.py to s3://sagemaker-us-east-1-015469603702/torchserve/llava-sm/llava/model/language_model/llava_mistral.py
upload: llava-sm/__init__.py to s3://sagemaker-us-east-1-015469603702/torchserve/llava-sm/__init__.py
upload: llava-sm/llava/model/language_model/llava_mpt.py to s3://sagemaker-us-east-1-015469603702/torchserve/llava-sm/llava/model/language_model/llava_mpt.py
upload: llava-sm/llava/mm_utils.py to s3://sagemaker-us-east-1-015469603702/torchserve/llava-sm/llava/mm_utils.py
upload: llava-sm/llava/model/apply_delta.py to s3://sagemaker-us-east-1-015469603702/torchserve/llava-sm/llava/model/a

In [23]:
s3_uri = f"{output_path}/{model_name}/"
print(s3_uri)

s3://sagemaker-us-east-1-015469603702/torchserve/llava-sm/
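`s3_uri` is passed to the model with `"S3DataType": "S3Prefix"`, so SageMaker treats it as a key prefix rather than a single object. A defensive sketch that splits the URI and checks that the prefix ends with `/`, under the assumption that you want to avoid accidentally matching sibling prefixes (in S3, the prefix `torchserve/llava-sm` also matches keys under `torchserve/llava-sm-old/`); `split_s3_uri` and `check_model_prefix` are hypothetical helpers and `amzn-s3-demo-bucket` is a placeholder:

```python
def split_s3_uri(uri: str) -> tuple:
    """Split 's3://bucket/key/prefix/' into (bucket, prefix)."""
    if not uri.startswith("s3://"):
        raise ValueError(f"Not an S3 URI: {uri}")
    bucket, _, prefix = uri[len("s3://"):].partition("/")
    return bucket, prefix


def check_model_prefix(uri: str) -> str:
    """Sanity-check a model-artifact prefix before using it with S3DataType=S3Prefix."""
    bucket, prefix = split_s3_uri(uri)
    if not bucket or not prefix:
        raise ValueError(f"Expected s3://<bucket>/<prefix>/, got {uri}")
    if not prefix.endswith("/"):
        # Without the trailing slash, the prefix could also match sibling directories
        raise ValueError(f"Prefix should end with '/': {uri}")
    return uri


print(split_s3_uri("s3://amzn-s3-demo-bucket/torchserve/llava-sm/"))
# ('amzn-s3-demo-bucket', 'torchserve/llava-sm/')
```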


## Step 3: Create SageMaker Endpoint

### 3.1 Create Model

In [24]:
from datetime import datetime

instance_type = "ml.p4d.24xlarge"
endpoint_name = sagemaker.utils.name_from_base(model_name)

model = Model(
    name=f"{model_name}-{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}",
    # Enable SageMaker uncompressed model artifacts via "S3DataType": "S3Prefix"
    model_data={
        "S3DataSource": {
            "S3Uri": s3_uri,
            "S3DataType": "S3Prefix",
            "CompressionType": "None",
        }
    },
    image_uri=container,
    role=role,
    sagemaker_session=sess,
    env={
        # TorchServe configuration file
        "TS_CONFIG_FILE": "/home/model-server/config.properties",
        # Disable token authorization for REST APIs
        "TS_DISABLE_TOKEN_AUTHORIZATION": "true", 
        # Headers to indicate Session ID
        "TS_HEADER_KEY_SEQUENCE_ID": "X-Amzn-SageMaker-Session-Id",
        "TS_REQUEST_SEQUENCE_ID": "X-Amzn-SageMaker-Session-Id",
        # Headers to indicate closed session
        "TS_HEADER_KEY_SEQUENCE_END": "X-Amzn-SageMaker-Closed-Session-Id",
        "TS_REQUEST_SEQUENCE_END": "X-Amzn-SageMaker-Closed-Session-Id",
        # Enable system metrics aggregation
        "TS_DISABLE_SYSTEM_METRICS": "false" 
    },
)
print(model)

<sagemaker.model.Model object at 0x7f10cbf3a9e0>
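The model name above appends a timestamp to `model_name`. SageMaker model and endpoint names are limited to 63 characters, which is why `sagemaker.utils.name_from_base` (used for `endpoint_name`) truncates its base argument. A sketch of the same idea for the model name, assuming the 63-character limit; `timestamped_name` is a hypothetical helper:

```python
from datetime import datetime

MAX_NAME_LEN = 63  # SageMaker model/endpoint name limit (verify against current AWS docs)


def timestamped_name(base: str, now: datetime) -> str:
    """Append '-YYYY-MM-DD-HH-MM-SS' to `base`, trimming `base` so the result fits."""
    suffix = now.strftime("-%Y-%m-%d-%H-%M-%S")  # always 20 characters
    # Trim the base and drop a trailing hyphen so the name never contains '--'
    trimmed = base[: MAX_NAME_LEN - len(suffix)].rstrip("-")
    return trimmed + suffix


print(timestamped_name("llava-sm", datetime(2024, 10, 30, 8, 12, 2)))
# llava-sm-2024-10-30-08-12-02
```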


### 3.2 Deploy Model and Create Endpoint

In [25]:
model.deploy(
    initial_instance_count=1, # increase the number of instances based on your load
    instance_type=instance_type,
    endpoint_name=endpoint_name,
    #volume_size=512, # increase the size to store large model
    model_data_download_timeout=3600, 
    container_startup_health_check_timeout=3600, 
)

--------------------------!
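`model.deploy` blocks until the endpoint is ready because its `wait` parameter defaults to `True`; the dashes above are its progress output. If you deploy with `wait=False` instead, you need to poll the endpoint status yourself. boto3 ships a waiter for this (`sm.get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)`); the sketch below shows the equivalent polling logic explicitly. `wait_for_in_service` is a hypothetical helper, and `describe` is assumed to return a `DescribeEndpoint`-style dict:

```python
import time
from typing import Callable


def wait_for_in_service(describe: Callable[[], dict], poll_seconds: float = 30.0,
                        timeout_seconds: float = 3600.0) -> str:
    """Poll until EndpointStatus is InService; raise on Failed or on timeout.

    `describe` is any zero-argument callable, for example
    lambda: sm.describe_endpoint(EndpointName=endpoint_name).
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = describe()["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError("Endpoint creation failed; check FailureReason in DescribeEndpoint")
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Endpoint still {status} after {timeout_seconds} s")
        time.sleep(poll_seconds)
```

Passing a callable rather than the client keeps the loop testable without a live endpoint.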

### 3.3 Create a Predictor

In [26]:
predictor = sagemaker.predictor.Predictor(
    endpoint_name=model.endpoint_name,
    sagemaker_session=sess
)
print(predictor)

Predictor: {'endpoint_name': 'llava-sm-2024-10-30-08-12-02-039', 'sagemaker_session': <sagemaker.session.Session object at 0x7f10fc86ca00>, 'serializer': <sagemaker.base_serializers.IdentitySerializer object at 0x7f10cdf0f490>, 'deserializer': <sagemaker.base_deserializers.BytesDeserializer object at 0x7f10cdf0fb50>}


In [27]:
# Optional: attach a predictor to an existing endpoint by name instead of deploying a new one
# predictor = sagemaker.predictor.Predictor(
#     endpoint_name='llava-sm-2024-09-04-06-35-10-354',
#     sagemaker_session=sess
# )
# print(predictor)

## Step 4: Run Inference

In [28]:
# Add the local handler directory to sys.path so that data_types can be imported
import os, sys

demo_data_path = os.path.join(os.getcwd(), "code/handler")
if demo_data_path not in sys.path:
    sys.path.append(demo_data_path)

In [29]:
# Install the client-side dependencies needed by data_types
!pip install torch dataclasses_json

Collecting dataclasses_json
  Downloading dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses_json)
  Downloading marshmallow-3.23.0-py3-none-any.whl.metadata (7.6 kB)
Collecting typing-inspect<1,>=0.4.0 (from dataclasses_json)
  Downloading typing_inspect-0.9.0-py3-none-any.whl.metadata (1.5 kB)
Collecting mypy-extensions>=0.3.0 (from typing-inspect<1,>=0.4.0->dataclasses_json)
  Downloading mypy_extensions-1.0.0-py3-none-any.whl.metadata (1.1 kB)
Downloading dataclasses_json-0.6.7-py3-none-any.whl (28 kB)
Downloading marshmallow-3.23.0-py3-none-any.whl (49 kB)
Downloading typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)
Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)
Installing collected packages: mypy-extensions, typing-inspect, marshmallow, dataclasses_json
Successfully installed dataclasses_json-0.6.7 marshmallow-3.23.0 mypy-extensions-1.0.0 typing-inspect-0.9.0


### 4.1 Open Session 1

In [35]:
# image_url="https://images.pexels.com/photos/1519753/pexels-photo-1519753.jpeg"
image_url='https://fileinfo.com/img/ss/sm/jpeg_43-2.jpg'

In [36]:
%%time
from data_types import (
    BaseRequest,
    CloseSessionRequest,
    StartSessionRequest,
    TextPromptRequest,
    OpenSessionResponse,
    TextPromptResponse,
    CloseSessionResponse
)

def send_and_check_request(r, seq_id):
    """Send request `r` on sticky session `seq_id` and return the first line of the response body.

    Pass seq_id="NEW_SESSION" to ask SageMaker to open a new session.
    """
    response = smr.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=r.to_json(),
        ContentType="application/json",
        SessionId=seq_id,
    )
    assert response["ResponseMetadata"]["HTTPStatusCode"] == 200, f"Sending request failed: {r}"
    return response['Body'].readlines()[0]

open_request = StartSessionRequest(
    type="start_session",
    path=image_url,
)

open_response = send_and_check_request(open_request, "NEW_SESSION")
open_response = OpenSessionResponse.from_json(open_response)
print(open_response)
assert open_response.session_id.startswith("ts-seq-")

OpenSessionResponse(session_id='ts-seq-18902bce-996e-4b7a-a33f-a5b6e10e154c')
CPU times: user 32.9 ms, sys: 0 ns, total: 32.9 ms
Wall time: 384 ms


In [37]:
open_response.session_id

'ts-seq-18902bce-996e-4b7a-a33f-a5b6e10e154c'
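In this example the TorchServe handler returns the session ID in the response body. SageMaker's sticky-session protocol can also surface the ID of a newly opened session through the `X-Amzn-SageMaker-New-Session-Id` response header, which boto3 exposes as `NewSessionId` in the `invoke_endpoint` response; per the SageMaker documentation its value looks like `<session-id>; Expires=<timestamp>`. Whether it is populated depends on the container's response headers, so treat this as an optional cross-check. A small parser, assuming that format (`parse_new_session_id` is hypothetical):

```python
def parse_new_session_id(header_value: str) -> tuple:
    """Split '<session-id>; Expires=<timestamp>' into (session_id, expires).

    `expires` is None when no Expires attribute is present.
    """
    session_id, _, rest = header_value.partition(";")
    expires = None
    for attr in rest.split(";"):
        key, _, value = attr.strip().partition("=")
        if key.lower() == "expires":
            expires = value.strip()
    return session_id.strip(), expires


# Usage (assuming the response carries the field):
# new_id = response.get("NewSessionId")
# if new_id:
#     session_id, expires = parse_new_session_id(new_id)
```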

### 4.2 Send Text Prompt 1

In [38]:
%%time
text_prompt_request1 = TextPromptRequest(
    type="send_text_prompt",
    session_id=open_response.session_id,
    prompt_text="describe the picture"
)

text_prompt_response1 = send_and_check_request(text_prompt_request1, open_response.session_id)
text_prompt_response1 = TextPromptResponse.from_json(text_prompt_response1)
print(text_prompt_response1.response_text)
assert text_prompt_response1.response_text

The image features a lush green field filled with a variety of colorful flowers. The flowers are scattered throughout the field, with some being closer to the foreground and others further in the background. The vibrant colors of the flowers create a beautiful and eye-catching scene. The field appears to be a garden or a park, providing a serene and picturesque environment for visitors to enjoy.
CPU times: user 4.95 ms, sys: 0 ns, total: 4.95 ms
Wall time: 5.07 s


### 4.3 Send Text Prompt 2

In [40]:
%%time
text_prompt_request2 = TextPromptRequest(
    type="send_text_prompt",
    session_id=open_response.session_id,
    prompt_text="is there grass in the picture, describe it"
)

text_prompt_response2 = send_and_check_request(text_prompt_request2, open_response.session_id)
text_prompt_response2 = TextPromptResponse.from_json(text_prompt_response2)
print(text_prompt_response2.response_text)
assert text_prompt_response2.response_text

Yes, there is grass in the picture. The grass is green and appears to be growing in the background of the image, alongside the colorful flowers.
CPU times: user 5.01 ms, sys: 0 ns, total: 5.01 ms
Wall time: 1.32 s
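Steps 4.2 and 4.3 repeat the same request/response pattern within one session; note the second call (1.32 s) is faster than the first (5.07 s) on this run. A minimal sketch of a loop over several prompts, assuming the JSON payload uses the same field names as `TextPromptRequest` (`type`, `session_id`, `prompt_text`) and the reply carries `response_text` as `TextPromptResponse` does; `ask_all` and its `send` callback are hypothetical:

```python
import json
from typing import Callable, List, Union


def ask_all(send: Callable[[str, str], Union[str, bytes]], session_id: str,
            prompts: List[str]) -> List[str]:
    """Send each prompt in order on one sticky session and collect the answers.

    `send(body_json, session_id)` must return the raw JSON response body.
    """
    answers = []
    for prompt in prompts:
        body = json.dumps({
            "type": "send_text_prompt",
            "session_id": session_id,
            "prompt_text": prompt,
        })
        reply = json.loads(send(body, session_id))  # json.loads accepts str or bytes
        answers.append(reply["response_text"])
    return answers


# Wiring it to the endpoint in this notebook (not run here):
# send = lambda body, sid: smr.invoke_endpoint(
#     EndpointName=endpoint_name, Body=body,
#     ContentType="application/json", SessionId=sid)["Body"].read()
# ask_all(send, open_response.session_id, ["describe the picture", "is there grass in the picture?"])
```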


### 4.4 Close Session

In [41]:
# close session
close_request = CloseSessionRequest(
    type="close_session",
    session_id=open_response.session_id,
)
    
close_response = send_and_check_request(
    close_request, open_response.session_id
)

close_response = CloseSessionResponse.from_json(close_response)
assert close_response.success
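If a prompt raises partway through (for example, the `assert` in `send_and_check_request`), the cell above never reaches the close request and the session stays open on the endpoint. A sketch of a context manager that always closes the session, assuming you supply callables that open and close a session; `sticky_session` is a hypothetical helper, not part of the SageMaker SDK:

```python
from contextlib import contextmanager
from typing import Callable, Iterator


@contextmanager
def sticky_session(open_session: Callable[[], str],
                   close_session: Callable[[str], None]) -> Iterator[str]:
    """Yield a session ID and close that session on exit, even if the body raises."""
    session_id = open_session()
    try:
        yield session_id
    finally:
        close_session(session_id)


# Usage with this notebook's objects (not run here):
# with sticky_session(
#     lambda: OpenSessionResponse.from_json(
#         send_and_check_request(open_request, "NEW_SESSION")).session_id,
#     lambda sid: send_and_check_request(
#         CloseSessionRequest(type="close_session", session_id=sid), sid),
# ) as sid:
#     ...  # send TextPromptRequest objects with session_id=sid
```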

In [42]:
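# Clean up: delete the endpoint, its endpoint configuration, and the model to stop incurring charges
sess.delete_endpoint(endpoint_name)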
sess.delete_endpoint_config(endpoint_name)
model.delete_model()