# Deploy NVIDIA Inference Microservice on Amazon SageMaker

## Setup

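Import the dependencies required to upload the model and run inference against the NVIDIA Inference Microservice (NIM) endpoint, and create the boto3 and SageMaker clients used throughout the notebook.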

The same cell also retrieves the IAM execution role that gives SageMaker access to the model artifacts in Amazon S3 and to the NIM container image in Amazon ECR.

In [1]:
import json
import time
from pathlib import Path

import boto3
import sagemaker
from sagemaker import get_execution_role

# A single boto3 session backs every client below, so they all share the same
# credentials and region (previously the runtime client and `region` each
# created their own implicit session).
sess = boto3.Session()
sm = sess.client("sagemaker")  # control plane: create/describe/delete resources
sagemaker_session = sagemaker.Session(boto_session=sess)
# IAM role passed as ExecutionRoleArn when the model is created.
role = get_execution_role(sagemaker_session=sagemaker_session)
client = sess.client("sagemaker-runtime")  # data plane: invoke the endpoint
region = sess.region_name

sagemaker.config INFO - Not applying SDK defaults from location: /etc/xdg/sagemaker/config.yaml
sagemaker.config INFO - Not applying SDK defaults from location: /home/ec2-user/.config/sagemaker/config.yaml


### Upload the packaged model to Amazon S3

In [None]:
# Location of the packaged model archive, relative to the notebook's working
# directory. Replace it with the path of your own archive if it differs.
MODEL_DOWNLOADED_PATH = (
    "llama-2-7b-chat_vLLAMA-2-7B-CHAT-4K-FP16-1-A100.24.02.rc2/"
    "LLAMA-2-7B-CHAT-4K-FP16-1-A100.24.02.rc2.tar.gz"
)

In [None]:
#MODEL_DOWNLOADED_PATH = <MODEL PATH>

In [15]:
# Resolve the archive location against the current working directory so the
# upload below receives an absolute path.
current_directory = Path.cwd()
path = Path(current_directory, MODEL_DOWNLOADED_PATH)

In [16]:
# Upload the archive to the SageMaker session's default bucket under the "nim"
# key prefix; model_uri is the S3 URI later used as the container's ModelDataUrl.
# str() keeps this working with boto3 releases whose upload_file accepts only
# string filenames (not pathlib.Path objects).
model_uri = sagemaker_session.upload_data(path=str(path), key_prefix="nim")

In [5]:
# NIM container image in Amazon ECR. The account ID and region below belong to
# the environment this notebook was written in; point this at your own copy,
# formatted as <ACCOUNT>.dkr.ecr.<REGION>.amazonaws.com/<REPOSITORY>.
nim_image_uri = (
    "354625738399.dkr.ecr.us-east-1.amazonaws.com"
    "/nim-24.02-sm-final"
)

In [None]:
#nim_image_uri = "<ACCOUNT>.dkr.ecr.<REGION>.amazonaws.com/nim-24.02-sm-final"

In [6]:
# PrimaryContainer definition for create_model: the NIM image, the uploaded
# model archive, and environment variables passed to the container.
container = dict(
    Image=nim_image_uri,
    ModelDataUrl=model_uri,
    Environment=dict(
        SAGEMAKER_MODEL_NAME="llama-2-7b",  # same name as the "model" field in request payloads
        SAGEMAKER_NUM_GPUS="1",
    ),
)

### Create SageMaker Endpoint

In [1]:
sm_prefix = "nim-model-"

# Append a UTC timestamp (to the second) so repeated runs get distinct names.
sm_model_name = "{}{}".format(sm_prefix, time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime()))

NameError: name 'sm_prefix' is not defined

In [9]:
# Register the model with SageMaker from the container definition and IAM role.
create_model_response = sm.create_model(
    ModelName=sm_model_name,
    PrimaryContainer=container,
    ExecutionRoleArn=role,
)

print(f"Model Arn: {create_model_response['ModelArn']}")

Model Arn: arn:aws:sagemaker:us-east-1:354625738399:model/symlink-nim-llama-2-7b-a100-2024-03-12-01-28-48


In [10]:
endpoint_config_name = f"{sm_prefix}{time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime())}"

# One production variant: a single ml.p4d.24xlarge instance that serves all
# traffic for the model registered above.
create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        dict(
            VariantName="AllTraffic",
            ModelName=sm_model_name,
            InstanceType="ml.p4d.24xlarge",
            InitialInstanceCount=1,
            InitialVariantWeight=1,
        )
    ],
)

print(f"Endpoint Config Arn: {create_endpoint_config_response['EndpointConfigArn']}")

Endpoint Config Arn: arn:aws:sagemaker:us-east-1:354625738399:endpoint-config/symlink-nim-llama-2-7b-a100-2024-03-12-01-28-48


In [11]:
endpoint_name = f"{sm_prefix}{time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime())}"

# Start provisioning the endpoint. The call returns while the endpoint is still
# being created, so the next cell polls its status.
create_endpoint_response = sm.create_endpoint(
    EndpointConfigName=endpoint_config_name,
    EndpointName=endpoint_name,
)

print(f"Endpoint Arn: {create_endpoint_response['EndpointArn']}")

Endpoint Arn: arn:aws:sagemaker:us-east-1:354625738399:endpoint/symlink-nim-llama-2-7b-a100-2024-03-12-01-28-49


In [12]:
# Poll once a minute until the endpoint leaves the "Creating" state.
while True:
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)
    if status != "Creating":
        break
    time.sleep(60)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)

# Stop the notebook here, with SageMaker's explanation, if the endpoint did not
# come up, instead of failing later with an opaque invoke error.
if status != "InService":
    raise RuntimeError(
        f"Endpoint {endpoint_name} ended in status {status!r}: "
        f"{resp.get('FailureReason', 'no FailureReason reported')}"
    )

Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: Creating
Status: InService
Arn: arn:aws:sagemaker:us-east-1:354625738399:endpoint/symlink-nim-llama-2-7b-a100-2024-03-12-01-28-49
Status: InService


In [14]:
# Non-streaming completion request. "model" matches the SAGEMAKER_MODEL_NAME set
# on the container. No "stop" list is sent: the API schema's placeholder value
# ["string"] would end generation at the literal word "string".
payload = {
    "model": "llama-2-7b",
    "prompt": "The capital of France is called",
    "max_tokens": 100,
    "temperature": 1,
    "n": 1,
    "stream": False,
    "frequency_penalty": 0.0,
}

response = client.invoke_endpoint(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Body=json.dumps(payload),
)

# The body is a single JSON document; the generated text is in choices[i]["text"].
print(json.loads(response["Body"].read().decode("utf8")))

{'id': 'cmpl-4f77577f-24a3-4e70-bfec-1fcef3420720', 'object': 'text_completion', 'created': 1710207480, 'model': 'llama-2-7b', 'choices': [{'index': 0, 'text': " Paris. It's a beautiful city with many famous landmarks, such as the Eiffel Tower and Notre Dame Cathedral.\nParis has been home to some very important historical events in Europe including: 1789 French Revolution began here; Napoleon Bonaparte was crowned emperor at the cathedral (Notre-Dame)in year 1804 & ended his reign there after defeat by allied forces during Franco Prussian War(1", 'logprobs': {'text_offset': [], 'token_logprobs': [0.0, 0.0], 'tokens': [], 'top_logprobs': []}}], 'usage': {'prompt_tokens': 7, 'total_tokens': 107, 'completion_tokens': 100}}


## Try streaming

Send the same prompt with `"stream": True` and print the generated text as it arrives.

In [None]:
import io
class LineIterator:
    r"""Iterate over newline-delimited lines in a SageMaker response event stream.

    ``invoke_endpoint_with_response_stream`` returns an event stream whose
    events look like ``{'PayloadPart': {'Bytes': b'...'}}``. A logical line
    (for example one JSON object) usually arrives in a single event, but that
    is not guaranteed, and a line may be split across events::

        {'PayloadPart': {'Bytes': b'{"outputs": '}}
        {'PayloadPart': {'Bytes': b'[" problem"]}\n'}}

    Every payload is appended to an internal buffer, and complete lines are
    yielded one at a time as ``bytes`` without the trailing ``\n``. The read
    position is tracked so previously returned bytes are never yielded again.
    If the stream ends with bytes that are not newline-terminated, they are
    yielded as a final line.

    Args:
        stream: Any iterable of event dicts, typically ``response["Body"]``.
    """

    def __init__(self, stream):
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            # Return the next complete line if the buffer already holds one.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line.endswith(b"\n"):
                self.read_pos += len(line)
                return line[:-1]
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                # Stream exhausted: flush any unterminated trailing bytes as a
                # last line rather than looping forever waiting for a "\n".
                if line:
                    self.read_pos += len(line)
                    return line
                raise
            if "PayloadPart" not in chunk:
                # Events without a payload carry no bytes; report and skip them.
                print(f"Unknown event type: {chunk}")
                continue
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])

In [None]:
# Streaming variant of the request: with "stream": True the completion is
# returned incrementally through invoke_endpoint_with_response_stream.
# No "stop" list is sent: the API schema's placeholder value ["string"] would
# end generation at the literal word "string".
payload = {
    "model": "llama-2-7b",
    "prompt": "The capital of France is called",
    "max_tokens": 100,
    "temperature": 1,
    "n": 1,
    "stream": True,
    "frequency_penalty": 0.0,
}

response = client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name,
    ContentType="application/json",
    Body=json.dumps(payload),
)

In [None]:
def _extract_stream_text(line):
    """Return the generated text carried by one streamed line, or "" if none.

    Args:
        line: One line (bytes, without the trailing newline) from LineIterator.

    Two payload shapes are accepted:
    * completion chunks with a ``choices`` list, optionally prefixed with a
      server-sent-events ``data:`` marker; this is the schema of the
      non-streaming response (``choices[0]["text"]``). NOTE(review): confirm
      the exact streaming framing the NIM container emits.
    * ``{"outputs": [...]}`` lines, as in the LineIterator docstring example.
    """
    text = line.decode("utf-8").strip()
    if text.startswith("data:"):
        text = text[len("data:"):].strip()
    # Skip blank separators, a "[DONE]" sentinel and any other non-JSON lines.
    if not text.startswith("{"):
        return ""
    chunk = json.loads(text)
    if chunk.get("choices"):
        return chunk["choices"][0].get("text", "")
    if chunk.get("outputs"):
        return chunk["outputs"][0]
    return ""


# response["Body"] is an event stream of {'PayloadPart': {'Bytes': ...}} events
# (it has no .read()), so consume it through LineIterator and print the text
# as it arrives. The stream can only be consumed once.
for line in LineIterator(response["Body"]):
    print(_extract_stream_text(line), end="", flush=True)
print()

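## Delete the endpoint and clean up resources

The cell below deletes the endpoint, the endpoint configuration, and the model. The model archive uploaded to S3 (`model_uri`) is not removed; delete it separately if you no longer need it.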

In [None]:
# Tear down in dependency order: the endpoint, then the endpoint configuration
# it was created from, then the model that configuration referenced.
for delete_resource, resource_kwargs in (
    (sm.delete_endpoint, {"EndpointName": endpoint_name}),
    (sm.delete_endpoint_config, {"EndpointConfigName": endpoint_config_name}),
    (sm.delete_model, {"ModelName": sm_model_name}),
):
    delete_resource(**resource_kwargs)