# Deploy NVIDIA NIM on Amazon SageMaker

NVIDIA NIM, part of NVIDIA AI Enterprise, brings state-of-the-art large language models (LLMs) to your applications, providing natural language processing and understanding capabilities. Whether you are developing chatbots, content analyzers, or any other application that needs to understand and generate human language, NVIDIA NIM for LLMs is designed for that workload.

In this example, we show how to deploy the optimized `Llama 3 70B` Instruct model on a `p4d` instance with NIM on Amazon SageMaker.

<div class="alert alert-block alert-info">
<b>IMPORTANT:</b> To run NIM on SageMaker, you need an NGC API key, which is required to access NGC resources. See the <a href="https://docs.nvidia.com/nim/large-language-models/latest/getting-started.html#ngc-authentication">NIM docs</a> to learn how to get an NGC API key.
</div>

## Setup

Import the dependencies and set up the session, execution role, and clients required to create the SageMaker endpoint.

In [None]:
import boto3, json, sagemaker, time
from sagemaker import get_execution_role
from pathlib import Path

sess = boto3.Session()
sm = sess.client("sagemaker")
sagemaker_session = sagemaker.Session(boto_session=sess)
role = get_execution_role()
client = boto3.client("sagemaker-runtime")
region = sess.region_name
sts_client = sess.client('sts')
account_id = sts_client.get_caller_identity()['Account']

We define the NIM image to deploy on the SageMaker endpoint.

In [None]:
nim_image_uri = "354625738399.dkr.ecr.us-east-1.amazonaws.com/nim-llama3-70b-instruct"
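
If the NIM image lives in a different Amazon ECR registry than the one shown above (for example, your own account's), the URI can be composed from the `account_id` and `region` values computed during setup. The sketch below is illustrative only: `build_ecr_image_uri` is a hypothetical helper, the account ID, repository name, and tag are placeholder assumptions, and it assumes the standard commercial AWS domain (`amazonaws.com`).

```python
def build_ecr_image_uri(account_id: str, region: str, repository: str, tag: str = "latest") -> str:
    """Compose a private ECR image URI from its parts (hypothetical helper)."""
    return f"{account_id}.dkr.ecr.{region}.amazonaws.com/{repository}:{tag}"


# Placeholder values for illustration; substitute your own account, region, and repository.
example_uri = build_ecr_image_uri("123456789012", "us-east-1", "nim-llama3-70b-instruct")
print(example_uri)
```

In the notebook, the call would look like `build_ecr_image_uri(account_id, region, "<your-repository>")`, with the repository name depending on where the image was pushed.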

### Create SageMaker Endpoint

**Before proceeding further, please set your NGC API Key.**

In [None]:
# SET ME
NGC_API_KEY = ""

In [None]:
# An empty string is also treated as "not set", so check truthiness rather than `is not None`.
assert NGC_API_KEY, "NGC API key is not set. Please set the NGC_API_KEY variable. It is required for running NIM."
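
To avoid saving the key inside the notebook, one option is to read it from an environment variable instead of hardcoding it. The following is a minimal sketch, not part of the original notebook: `load_ngc_api_key` is a hypothetical helper, and storing the key in an environment variable named `NGC_API_KEY` is an assumption.

```python
import os


def load_ngc_api_key(env=None, var_name="NGC_API_KEY"):
    """Return the API key from a mapping (defaults to os.environ) or raise if it is missing."""
    env = os.environ if env is None else env
    key = (env.get(var_name) or "").strip()
    if not key:
        raise ValueError(f"{var_name} is not set. It is required for running NIM.")
    return key


# Demonstration with a made-up mapping so the example does not depend on the real environment.
demo_key = load_ngc_api_key({"NGC_API_KEY": "  example-key  "})
print(demo_key)
```

In the notebook, `NGC_API_KEY = load_ngc_api_key()` would then replace the hardcoded assignment.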

We define the SageMaker model from the NIM container, making sure to pass in **NGC_API_KEY**.

In [None]:
sm_model_name = "nim-llama3-70b-instruct"
container = {
    "Image": nim_image_uri,
    "Environment": {"NGC_API_KEY": NGC_API_KEY}
}
create_model_response = sm.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

Next, we create the endpoint configuration. Here we deploy the Llama 3 70B model on an `ml.p4d.24xlarge` instance.

In [None]:
endpoint_config_name = sm_model_name

create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.p4d.24xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
            "ContainerStartupHealthCheckTimeoutInSeconds": 850
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Using the endpoint configuration above, we create a new SageMaker endpoint and wait for the deployment to finish. The status changes to `InService` once the deployment succeeds.

In [None]:
endpoint_name = sm_model_name

create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
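
The polling loop above runs for as long as the endpoint reports `Creating`. A variant with a timeout and an explicit failure check can be sketched as follows. This is an illustration under stated assumptions: `wait_for_endpoint` is a hypothetical helper, the default timeout is arbitrary, and the fake `describe` function only stands in for `sm.describe_endpoint` so the example runs without AWS access.

```python
import time


def wait_for_endpoint(describe_endpoint, endpoint_name, poll_seconds=60,
                      timeout_seconds=3600, sleep=time.sleep):
    """Poll until the endpoint leaves the 'Creating' state, fails, or the timeout elapses."""
    waited = 0
    while True:
        resp = describe_endpoint(EndpointName=endpoint_name)
        status = resp["EndpointStatus"]
        if status == "Failed":
            reason = resp.get("FailureReason", "unknown reason")
            raise RuntimeError(f"Endpoint creation failed: {reason}")
        if status != "Creating":
            return resp
        if waited >= timeout_seconds:
            raise TimeoutError(f"Endpoint still 'Creating' after {waited} seconds")
        sleep(poll_seconds)
        waited += poll_seconds


# Fake describe call for demonstration: reports 'Creating' twice, then 'InService'.
_fake_statuses = iter(["Creating", "Creating", "InService"])


def _fake_describe(EndpointName):
    return {"EndpointStatus": next(_fake_statuses), "EndpointName": EndpointName}


demo = wait_for_endpoint(_fake_describe, "demo-endpoint", sleep=lambda seconds: None)
print("Status: " + demo["EndpointStatus"])
```

In the notebook, the equivalent call would be `wait_for_endpoint(sm.describe_endpoint, endpoint_name)`. If no custom logic is needed, the built-in boto3 waiter obtained with `sm.get_waiter("endpoint_in_service")` may be a simpler choice.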

### Run Inference

Once the endpoint's status is `InService`, we can send a sample chat completion request using JSON as the payload format. NIM on SageMaker currently supports the OpenAI API chat completions inference protocol. For an explanation of the supported parameters, see the [OpenAI API reference](https://platform.openai.com/docs/api-reference/chat).

<div class="alert alert-block alert-info">
<b>IMPORTANT:</b> The model name in the inference request payload must be the name of the NIM model. Do not change it below.
</div>

In [None]:
messages = [
    {"role": "user", "content": "Hello! How are you?"},
    {"role": "assistant", "content": "Hi! I am quite well, how can I help you today?"},
    {"role": "user", "content": "Write a short limerick about the wonders of GPU Computing."}
]
payload = {
  "model": "meta/llama3-70b-instruct",
  "messages": messages,
  "max_tokens": 100
}


response = client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/json", Body=json.dumps(payload)
)

output = json.loads(response["Body"].read().decode("utf8"))
print(json.dumps(output, indent=2))
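
The cell above prints the full JSON response. Assuming the response follows the OpenAI chat completions shape (a `choices` list whose entries hold a `message` with `content`), the generated text alone can be extracted as sketched below. `extract_reply` is a hypothetical helper, and `sample_output` is made-up data standing in for a real endpoint response.

```python
def extract_reply(completion: dict) -> str:
    """Return the assistant text from an OpenAI-style chat completion, or '' if absent."""
    choices = completion.get("choices") or []
    if not choices:
        return ""
    message = choices[0].get("message") or {}
    return message.get("content") or ""


# Made-up response for illustration only; a real response contains more fields.
sample_output = {
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "A placeholder reply."},
            "finish_reason": "stop",
        }
    ]
}
print(extract_reply(sample_output))
```

In the notebook, `print(extract_reply(output))` would show only the model's reply.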

### Try streaming inference

NIM on SageMaker also supports streaming inference. To enable it, set **`"stream"` to `True`** in the payload and use the [`invoke_endpoint_with_response_stream`](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime/client/invoke_endpoint_with_response_stream.html) method.

In [None]:
messages = [
    {"role": "user", "content": "Hello! How are you?"},
    {"role": "assistant", "content": "Hi! I am quite well, how can I help you today?"},
    {"role": "user", "content": "Write a short limerick about the wonders of GPU Computing."}
]
payload = {
  "model": "meta/llama3-70b-instruct",
  "messages": messages,
  "max_tokens": 100,
  "stream": True
}


response = client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name,
    Body=json.dumps(payload),
    ContentType="application/json",
    Accept="application/jsonlines",
)

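The following code post-processes the streaming output. Each event carries raw bytes under `PayloadPart`; chunks that start with `data:` hold a JSON object, and the generated text is read from `choices[0].delta.content`.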

In [None]:
event_stream = response['Body']

for event in event_stream:
    try:
        # Each event carries a chunk of raw bytes under PayloadPart.Bytes.
        chunk = event.get('PayloadPart', {}).get('Bytes', b'')
        if not chunk:
            continue
        data_str = chunk.decode('utf-8')
        # Only chunks of the form "data: {json}" are processed.
        if not data_str.startswith('data:'):
            continue
        json_data = data_str[5:].strip()
        if not json_data:
            continue
        try:
            data = json.loads(json_data)
        except json.JSONDecodeError:
            # Skip chunks that are not valid JSON.
            continue
        content = data.get('choices', [{}])[0].get('delta', {}).get('content', "")
        if content:
            print(content, end='', flush=True)
    except Exception as e:
        print(f"\nError processing event: {e}", flush=True)
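
The loop above treats every event as exactly one complete `data:` line. If a JSON line is ever split across events, or several lines arrive in one event, a buffered parser is more robust. The sketch below is one possible approach, not the notebook's method: `iter_stream_content` is a hypothetical helper, the `[DONE]` sentinel is assumed from the OpenAI streaming convention, and `demo_chunks` is made-up data standing in for the `PayloadPart` bytes.

```python
import json


def _content_from_line(line: bytes) -> str:
    """Extract delta content from one 'data: {json}' line, or return '' if there is none."""
    text = line.decode("utf-8", errors="replace").strip()
    if not text.startswith("data:"):
        return ""
    body = text[5:].strip()
    if not body or body == "[DONE]":  # assumed OpenAI-style end-of-stream sentinel
        return ""
    try:
        data = json.loads(body)
    except json.JSONDecodeError:
        return ""
    choices = data.get("choices") or [{}]
    return (choices[0].get("delta") or {}).get("content") or ""


def iter_stream_content(byte_chunks):
    """Yield text pieces from byte chunks, buffering until each line is complete."""
    buffer = b""
    for chunk in byte_chunks:
        buffer += chunk
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            piece = _content_from_line(line)
            if piece:
                yield piece
    # Flush any final line that arrived without a trailing newline.
    piece = _content_from_line(buffer)
    if piece:
        yield piece


# Made-up chunks: the second JSON line is deliberately split across two chunks.
demo_chunks = [
    b'data: {"choices": [{"delta": {"content": "Hel"}}]}\n\ndata: {"choi',
    b'ces": [{"delta": {"content": "lo"}}]}\n\n',
    b"data: [DONE]\n\n",
]
demo_text = "".join(iter_stream_content(demo_chunks))
print(demo_text)
```

In the notebook, the byte chunks could be supplied with a generator such as `(e.get("PayloadPart", {}).get("Bytes", b"") for e in response["Body"])`.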

### Terminate endpoint and clean up artifacts

In [None]:
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)