# Deploy NVIDIA Inference Microservice (NIM) on Amazon SageMaker

NVIDIA NIM for LLMs brings state-of-the-art large language models (LLMs) to your applications, providing natural language understanding and generation capabilities. Whether you're developing chatbots, content analyzers, or any other application that needs to understand and generate human language, NVIDIA NIM for LLMs has you covered. Built on the NVIDIA software platform, including CUDA, TensorRT (TRT), TensorRT-LLM (TRT-LLM), and Triton Inference Server, NVIDIA NIM for LLMs provides state-of-the-art GPU-accelerated LLM serving.

In this example, we show how to deploy a prebuilt and optimized Llama 2 7B Chat model engine with NIM on Amazon SageMaker.

## Setup

Import the dependencies and set up the clients and IAM execution role required to package the model and create a SageMaker endpoint.

In [None]:
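import json
import time
from pathlib import Path

import boto3
import sagemaker
from sagemaker import get_execution_role

sess = boto3.Session()
sm = sess.client("sagemaker")  # control-plane client: create, describe, and delete models and endpoints
sagemaker_session = sagemaker.Session(boto_session=sess)  # used to upload the model archive to S3
role = get_execution_role()  # IAM role that SageMaker assumes to pull the image and model data
client = sess.client("sagemaker-runtime")  # data-plane client: send inference requests
region = sess.region_name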

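Next, we define the NIM container image to deploy on the SageMaker endpoint. Replace `<ACCOUNT>` and `<REGION>` with the account ID and Region of the Amazon ECR repository that hosts the image.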

In [None]:
nim_image_uri = "<ACCOUNT>.dkr.ecr.<REGION>.amazonaws.com/nim-24.02-sm"

### Package the model and upload it to S3

The prebuilt model engine we downloaded from NGC is already packaged as a `.tar.gz` archive, so we only need to upload it to S3 and then create a SageMaker model from it.

First, set the path to the downloaded model `.tar.gz` file.

In [None]:
# Set this to the path of a different prebuilt engine to try another model
# MODEL_DOWNLOADED_PATH = <MODEL PATH>

# Path for the Llama 2 7B Chat example
MODEL_DOWNLOADED_PATH = 'llama-2-7b-chat_vLLAMA-2-7B-CHAT-4K-FP16-1-A100.24.02.rc2/LLAMA-2-7B-CHAT-4K-FP16-1-A100.24.02.rc2.tar.gz'
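Optionally, before uploading a multi-gigabyte archive, you can confirm that the file is a readable tarball. The sketch below uses only Python's standard `tarfile` module; `describe_model_archive` is a hypothetical helper, not part of the SageMaker SDK or NIM. Because `tarfile` reads member headers one at a time, the helper stops after the first few entries instead of listing the whole archive.

```python
import tarfile
from pathlib import Path


def describe_model_archive(archive_path, max_members=10):
    """Return up to `max_members` member names from a (compressed) tar archive.

    Raises ValueError if the file is not a readable tar archive.
    """
    archive_path = Path(archive_path)
    if not tarfile.is_tarfile(archive_path):
        raise ValueError(f"{archive_path} is not a tar archive")
    names = []
    # "r:*" lets tarfile auto-detect the compression (gzip for a .tar.gz)
    with tarfile.open(archive_path, mode="r:*") as tar:
        for member in tar:
            names.append(member.name)
            if len(names) >= max_members:
                break  # stop early; no need to walk the whole archive
    return names
```

For example, `describe_model_archive(MODEL_DOWNLOADED_PATH)` should list the first files of the engine if the download is intact.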

Next, upload the archive to S3.

In [None]:
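# Resolve the archive path relative to the notebook's working directory
current_directory = Path.cwd()
path = current_directory / MODEL_DOWNLOADED_PATH
key_prefix = "nim-model"
# Upload to the SageMaker default bucket; returns the s3:// URI of the uploaded archive
model_uri = sagemaker_session.upload_data(path=str(path), key_prefix=key_prefix)
print("Model URI: " + model_uri)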

### Create SageMaker Endpoint

Next, we create a SageMaker model from the archive uploaded to S3 in the previous step.

In this step, we also need to provide two environment variables:
- `SAGEMAKER_MODEL_NAME`: the name under which the NIM container on SageMaker loads the model. You can choose any name, but it must match the `model` field you send in inference requests.
- `SAGEMAKER_NUM_GPUS`: the number of GPUs the prebuilt model engine was built to run inference on. This number is part of the name of the engine you downloaded from NGC.

For example, `LLAMA-2-7B-CHAT-4K-FP16-1-A100.24.02` is the Llama 2 7B Chat model with a `4K` context length, optimized for `FP16` precision, and built to run on `1` A100 GPU.

Similarly, `LLAMA-2-70B-CHAT-4K-FP16-4-A100-24.02` is the Llama 2 70B Chat model with a `4K` context length, optimized for `FP16` precision, and built to run on `4` A100 GPUs.

Here we set the model name to `llama-2-7b` and the number of GPUs to `1`, because that is what our example prebuilt engine was built for.

In [None]:
SAGEMAKER_MODEL_NAME = "llama-2-7b"
SAGEMAKER_NUM_GPUS = "1"
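Instead of reading the GPU count off the engine name by hand, you could parse it. The sketch below assumes every engine name follows the pattern of the two examples above (`<MODEL>-<CONTEXT>-<PRECISION>-<NUM_GPUS>-<GPU>` followed by a version suffix). That pattern is inferred from these examples, not from an NGC specification, and `parse_engine_name` is a hypothetical helper.

```python
import re
from typing import NamedTuple

# Assumed naming pattern, inferred from names such as
# "LLAMA-2-7B-CHAT-4K-FP16-1-A100.24.02"; other NGC engines may differ.
ENGINE_NAME_RE = re.compile(
    r"^(?P<model>.+?)-(?P<context>\d+K)-(?P<precision>[A-Z]+\d+)"
    r"-(?P<num_gpus>\d+)-(?P<gpu>[A-Z0-9]+)[.-](?P<version>.+)$"
)


class EngineSpec(NamedTuple):
    model: str
    context: str
    precision: str
    num_gpus: int
    gpu: str
    version: str


def parse_engine_name(name: str) -> EngineSpec:
    """Split a prebuilt engine name into its components."""
    match = ENGINE_NAME_RE.match(name)
    if match is None:
        raise ValueError(f"unrecognized engine name: {name}")
    parts = match.groupdict()
    return EngineSpec(
        model=parts["model"],
        context=parts["context"],
        precision=parts["precision"],
        num_gpus=int(parts["num_gpus"]),
        gpu=parts["gpu"],
        version=parts["version"],
    )
```

With this helper, `SAGEMAKER_NUM_GPUS = str(parse_engine_name(Path(MODEL_DOWNLOADED_PATH).name).num_gpus)` derives the value from the archive file name. SageMaker environment variables must be strings, hence the `str(...)`.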

In [None]:
container = {
    "Image": nim_image_uri,
    "ModelDataUrl": model_uri,
    "Environment": {"SAGEMAKER_MODEL_NAME": SAGEMAKER_MODEL_NAME,
                    "SAGEMAKER_NUM_GPUS": SAGEMAKER_NUM_GPUS}
}
sm_prefix = "nim-model-"

sm_model_name = sm_prefix + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())
create_model_response = sm.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

Using the model above, we create an endpoint configuration that specifies the instance type for the endpoint.

**IMPORTANT: Because this prebuilt engine was optimized for A100 GPUs, we set the instance type to `ml.p4d.24xlarge`, which has eight A100 GPUs with 40 GB of memory each. The `ml.p4de.24xlarge` instance has eight A100 GPUs with 80 GB of memory each, so use it if you want to try larger LLMs such as Mixtral-8x7B or Llama 2 70B, or if you need larger batch sizes, longer sequence lengths, or heavier traffic.**
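A rough weights-only estimate helps explain this guidance. In FP16 each parameter takes 2 bytes, so a nominal 7B-parameter model needs about 14 GB for its weights, while a 70B model needs about 140 GB, or about 35 GB per GPU when split evenly across 4 GPUs. On a 40 GB A100, that leaves little room for the KV cache, which grows with batch size and sequence length. The sketch below only performs this arithmetic; the function names are illustrative, and actual memory use also depends on the engine and runtime.

```python
def weight_memory_gb(num_params: float, bytes_per_param: float = 2.0) -> float:
    """Approximate memory for the model weights alone, in GB (1 GB = 1e9 bytes).

    FP16 stores each parameter in 2 bytes. The KV cache, activations, and
    runtime overhead all need additional memory on top of this.
    """
    return num_params * bytes_per_param / 1e9


def weight_memory_per_gpu_gb(num_params: float, num_gpus: int,
                             bytes_per_param: float = 2.0) -> float:
    """Weights-only memory per GPU, assuming the weights are split evenly across GPUs."""
    return weight_memory_gb(num_params, bytes_per_param) / num_gpus


# Nominal parameter counts; real checkpoints differ slightly.
for label, num_params, num_gpus in [("Llama 2 7B", 7e9, 1), ("Llama 2 70B", 70e9, 4)]:
    per_gpu = weight_memory_per_gpu_gb(num_params, num_gpus)
    print(f"{label}: ~{per_gpu:.0f} GB of FP16 weights per GPU across {num_gpus} GPU(s)")
```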

In [None]:
endpoint_config_name = sm_prefix + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_config_response = sm.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": "ml.p4d.24xlarge",
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
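            # Allow extra time to download the multi-gigabyte model archive from S3
            "ModelDataDownloadTimeoutInSeconds": 2700,
            # Allow extra time for the container to load the engine before health checks fail
            "ContainerStartupHealthCheckTimeoutInSeconds": 3600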
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Using this endpoint configuration, we create a new SageMaker endpoint and wait for the deployment to finish. The status changes to `InService` once the deployment succeeds.

In [None]:
endpoint_name = sm_prefix + time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

create_endpoint_response = sm.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
resp = sm.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
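The loop above keeps polling for as long as the status is `Creating` and does not stop on its own. A variant with an upper bound on the waiting time is sketched below; `wait_for_endpoint` is a hypothetical helper that takes the describe call as a function, so it can be exercised without AWS. boto3 also offers a built-in waiter for the same purpose: `sm.get_waiter("endpoint_in_service").wait(EndpointName=endpoint_name)`.

```python
import time
from typing import Callable


def wait_for_endpoint(describe: Callable[[], dict], poll_seconds: float = 60,
                      timeout_seconds: float = 3600, sleep=time.sleep) -> dict:
    """Poll `describe()` until the endpoint leaves `Creating`, or raise on timeout.

    Returns the last describe response if the endpoint reaches `InService`;
    raises RuntimeError for any other final status.
    """
    deadline = time.monotonic() + timeout_seconds
    resp = describe()
    while resp["EndpointStatus"] == "Creating":
        if time.monotonic() >= deadline:
            raise TimeoutError(f"endpoint still Creating after {timeout_seconds} s")
        sleep(poll_seconds)
        resp = describe()
    status = resp["EndpointStatus"]
    if status != "InService":
        reason = resp.get("FailureReason", "no reason given")
        raise RuntimeError(f"endpoint ended in status {status}: {reason}")
    return resp
```

In this notebook it would be called as `resp = wait_for_endpoint(lambda: sm.describe_endpoint(EndpointName=endpoint_name))`.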

### Run Inference

Once the endpoint's status is `InService`, we can send a text completion request with a sample prompt, using JSON as the payload format. NIM on SageMaker currently supports the OpenAI API completions protocol for inference requests; for an explanation of the supported parameters, see the [OpenAI completions API reference](https://platform.openai.com/docs/api-reference/completions/create).

In [None]:
payload = {
  "model": SAGEMAKER_MODEL_NAME,
  "prompt": "The capital of France is called",
  "max_tokens": 100,
  "temperature": 1,
  "n": 1,
  "stream": False,
  "stop": ["string"],  # stop sequence(s); "string" is a placeholder value, replace or remove as needed
  "frequency_penalty": 0.0
}

response = client.invoke_endpoint(
    EndpointName=endpoint_name, ContentType="application/json", Body=json.dumps(payload)
)

output = json.loads(response["Body"].read().decode("utf8"))
print(json.dumps(output, indent=2))
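The response body follows the OpenAI completions format, where the generated text is returned in the `text` field of each entry in `choices`. A small helper to extract that text is sketched below; `completion_texts` is a hypothetical name, and the sample dictionary is an illustrative shape, not actual output from this endpoint.

```python
from typing import Any, Dict, List


def completion_texts(output: Dict[str, Any]) -> List[str]:
    """Return the generated text of every choice, ordered by its `index` field."""
    choices = sorted(output.get("choices", []), key=lambda choice: choice.get("index", 0))
    return [choice["text"] for choice in choices]


# Illustrative response shape only (OpenAI completions format), not real endpoint output
sample_output = {
    "object": "text_completion",
    "model": "llama-2-7b",
    "choices": [
        {"index": 0, "text": " <generated text>", "finish_reason": "stop"},
    ],
}
print(completion_texts(sample_output))
```

Applied to the real response, `completion_texts(output)[0]` gives the completion for the first (and, with `"n": 1`, only) choice.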

### Try streaming inference

NIM on SageMaker also supports streaming inference. To enable it, set **`"stream"` to `True`** in the payload and call the [`invoke_endpoint_with_response_stream`](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sagemaker-runtime/client/invoke_endpoint_with_response_stream.html) method.
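The streaming response body is an event stream whose `PayloadPart` events carry raw bytes, and one line of output can be split across several events. The `LineIterator` helper imported from the local `utils` module (its source is not shown here) presumably reassembles those lines; the sketch below is an assumed, simplified equivalent that buffers bytes until a newline arrives. `iter_lines` is a hypothetical name.

```python
from typing import Iterable, Iterator


def iter_lines(events: Iterable[dict]) -> Iterator[bytes]:
    """Yield complete lines (without the trailing newline) from PayloadPart events."""
    buffer = b""
    for event in events:
        part = event.get("PayloadPart")
        if part is None:
            # Skip non-payload events; a production version should surface
            # error events instead of ignoring them.
            continue
        buffer += part["Bytes"]
        # Emit every complete line currently in the buffer
        while b"\n" in buffer:
            line, buffer = buffer.split(b"\n", 1)
            yield line
    if buffer:
        yield buffer  # trailing data that never received a newline
```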

In [None]:
payload = {
  "model": SAGEMAKER_MODEL_NAME,
  "prompt": "The capital of France is called",
  "max_tokens": 100,
  "temperature": 1,
  "n": 1,
  "stream": True,
  "stop": ["string"],  # stop sequence(s); "string" is a placeholder value, replace or remove as needed
  "frequency_penalty": 0.0
}
response = client.invoke_endpoint_with_response_stream(
    EndpointName=endpoint_name,
    Body=json.dumps(payload),
    ContentType="application/json",
    Accept="application/jsonlines",
)

We post-process the event stream to reassemble and print the streamed output tokens.

In [None]:
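import re

from utils import LineIterator  # local helper module shipped alongside this notebook

try:
    from utils import has_same_prefix
except ImportError:
    # Assumed stand-in if utils does not provide it: treat the new last word
    # as a continuation of the previous one when it starts with it.
    def has_same_prefix(prev, curr):
        return prev is not None and curr.startswith(prev)

# Matches a word (letters, digits, apostrophes) or a single punctuation character
TOKEN_RE = re.compile(r"[\w']+|[.,!?;&()\"\u2013\u2014:*#@$%/\\<>\[\]{}|^~=+]")

event_stream = response["Body"]

# Split the raw byte stream into complete lines
line_iterator = LineIterator(event_stream)

prev = None
for line in line_iterator:
    # Strip the "data:" prefix and decode the bytes into text
    decoded_line = line[len(b"data:"):].decode("utf-8").strip()
    if not decoded_line:
        continue  # skip blank separator lines

    if decoded_line == "[DONE]":
        print(prev)
        print("\nStreaming Generation Finished!")
        break

    # Extract the generated text from the JSON chunk
    decoded_json = json.loads(decoded_line)
    text = decoded_json["choices"][0]["text"]
    # Find the last word or punctuation mark in this chunk
    words_and_punctuations = TOKEN_RE.findall(text)
    if words_and_punctuations:
        # Print the previous word once the stream has moved past it
        if prev is not None and not has_same_prefix(prev, words_and_punctuations[-1]):
            print(prev, end=" ")
        prev = words_and_punctuations[-1]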

### Terminate endpoint and clean up artifacts

In [None]:
sm.delete_endpoint(EndpointName=endpoint_name)
sm.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm.delete_model(ModelName=sm_model_name)
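The cell above deletes the endpoint, endpoint configuration, and model, but not the model archive uploaded to S3. To remove that as well, split `model_uri` into a bucket and key. `split_s3_uri` is a hypothetical helper; the deletion itself would use boto3's S3 `delete_object` call.

```python
from typing import Tuple


def split_s3_uri(uri: str) -> Tuple[str, str]:
    """Split an s3://bucket/key URI into (bucket, key)."""
    prefix = "s3://"
    if not uri.startswith(prefix):
        raise ValueError(f"not an s3:// URI: {uri}")
    bucket, _, key = uri[len(prefix):].partition("/")
    if not bucket or not key:
        raise ValueError(f"URI must include a bucket and a key: {uri}")
    return bucket, key
```

In this notebook, `bucket, key = split_s3_uri(model_uri)` followed by `sess.client("s3").delete_object(Bucket=bucket, Key=key)` would remove the uploaded archive.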