# Deploy Stable Diffusion on a SageMaker GPU Multi-Model Endpoint with Triton

In this notebook, we will host Stable Diffusion on SageMaker GPU Multi-Model Endpoints (MME GPU) powered by NVIDIA Triton Inference Server. We will compile Stable Diffusion with [AITemplate](https://github.com/facebookincubator/AITemplate) for lower latency.

Skip to:
1. [Installs and imports](#installs)
2. [Packaging a conda environment, extending SageMaker Triton container](#condaenv)
3. [Compile model with AITemplate](#aitemplate)
4. [Local testing of Triton model repository](#local)
5. [Deploy to SageMaker Real-Time Endpoint](#deploy)
6. [Analyze endpoint logs](#logs)
7. [Clean up](#cleanup)
------
------

### Part 1 - Installs and imports <a name="installs"></a>

In [None]:
!pip install nvidia-pyindex
!pip install tritonclient[http]
!pip install -U sagemaker ipywidgets pillow numpy transformers accelerate diffusers

In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role

import tritonclient.http as httpclient
from tritonclient.utils import np_to_triton_dtype  # the only utility this notebook uses
import time
from PIL import Image
import numpy as np

# variables
s3_client = boto3.client("s3")
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

# sagemaker variables
role = get_execution_role()
sm_client = boto3.client(service_name="sagemaker")
runtime_sm_client = boto3.client("sagemaker-runtime")
sagemaker_session = sagemaker.Session(boto_session=boto3.Session())
bucket = sagemaker_session.default_bucket()

### Part 2 - Packaging a conda environment, extending SageMaker Triton container <a name="condaenv"></a>

When using the Triton Python backend (which our Stable Diffusion model will run on), you can include your own environment and dependencies. The recommended way to do this is to use [conda pack](https://conda.github.io/conda-pack/) to generate a conda environment archive in `tar.gz` format, and point to it in the `config.pbtxt` file of the models that should use it, adding the snippet: 

```
parameters: {
  key: "EXECUTION_ENV_PATH",
  value: {string_value: "path_to_your_env.tar.gz"}
}

```
You can use a different environment with every new loaded model, or the same for all models loaded into the container (read more on this [here](https://github.com/triton-inference-server/python_backend#creating-custom-execution-environments)). We will extend the public SageMaker Triton container image to include our environment, to avoid increasing the model S3 download time. 
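
As an illustration of the `config.pbtxt` snippet above, here is a minimal sketch that appends the `EXECUTION_ENV_PATH` parameter to an existing configuration string. The helper name `with_execution_env` and the path `/home/stablediff_env.tar.gz` are assumptions made for this example; use the location where your image actually stores the archive.

```python
def with_execution_env(config_text: str, env_path: str) -> str:
    """Return config_text with an EXECUTION_ENV_PATH parameter block appended."""
    if "EXECUTION_ENV_PATH" in config_text:
        return config_text  # already configured; leave it untouched
    block = (
        "parameters: {\n"
        '  key: "EXECUTION_ENV_PATH",\n'
        f'  value: {{string_value: "{env_path}"}}\n'
        "}\n"
    )
    return config_text.rstrip("\n") + "\n\n" + block


# Hypothetical config and archive path, for illustration only
example_config = 'name: "pipeline_0"\nbackend: "python"\n'
print(with_execution_env(example_config, "/home/stablediff_env.tar.gz"))
```

The function is idempotent, so running it twice on the same configuration does not duplicate the parameter block.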

Let's start by creating the conda environment with the necessary dependencies; this script will output a `stablediff_env.tar.gz` file.

We pass the conda path because the script changes a file in an installed library (AITemplate) to support A10G GPUs.

In [None]:
%%capture conda_path
!echo $CONDA_PREFIX

In [None]:
# Drop the last path component of $CONDA_PREFIX to get its parent directory
temp_path = str(conda_path).strip().split("/")[1:]
conda_path = "/" + "/".join(temp_path[:-1])
print(conda_path)

In [None]:
!cd docker && bash conda_dependencies.sh "$conda_path"

Now, we get the correct URI for the SageMaker Triton container image. Check out all the available Deep Learning Container images that AWS maintains [here](https://github.com/aws/deep-learning-containers/blob/master/available_images.md). 

In [None]:
# account mapping for SageMaker Triton Image
account_id_map = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}



region = boto3.Session().region_name
if region not in account_id_map:
    # Raising a plain string is a TypeError in Python 3; raise a real exception
    raise ValueError(f"Unsupported region: {region}")

base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
mme_triton_image_uri = (
    "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.12-py3".format(
        account_id=account_id_map[region], region=region, base=base
    )
)

We then build our extended image, which does nothing more than copy the packaged environment into the container. Let's check out the Dockerfile.

In [None]:
!cat docker/Dockerfile

In [None]:
# Change this var to change the name of new container image
new_image_name = 'sagemaker-tritonserver-stablediffusion'

We catch the docker build process' output so that we can easily capture the output container image URI, and check for build errors.

In [None]:
%%capture build_output
!cd docker && bash build_and_push.sh "$new_image_name" 22.12 "$mme_triton_image_uri" "$region"

In [None]:
if 'Error response from daemon' in str(build_output):
    print(build_output)
    raise SystemExit('\n\n!!There was an error with the container build!!')
else:
    extended_triton_image_uri = str(build_output).strip().split('\n')[-1]

If the previous cell failed, check the Docker build logs to understand the error, and see the possible resolution in the next cell.

In [None]:
"""
If the cell above fails (check out the build_output) because of missing permissions to pull the public Triton base container image,
uncomment the commands in this cell, run them and retry the build
"""
# mapped_region_account = account_id_map[region]
# !aws ecr get-login-password --region "$region" | docker login --username AWS --password-stdin "$mapped_region_account".dkr.ecr."$region"."$base"

----
----

### Part 3 - Compile model with AITemplate <a name="aitemplate"></a>

The next cell will use AITemplate to compile the Stable Diffusion 2.1 base model and move it to the Triton model repo.

In [None]:
repo_name = "model_repo_0"

In [None]:
!docker run --gpus=all -it --shm-size=4G --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd):/model_repository $extended_triton_image_uri /bin/bash /model_repository/workspace/compile_model.sh "$repo_name"


------
------
### Part 4 - Local testing of Triton model repository <a name="local"></a>

Now you can test the model repository and validate that it is working. Let's run the Triton docker container locally and invoke the model to check this. We are running the Triton container in detached mode with the `-d` flag so that it runs in the background.

In [None]:
!docker run --gpus=all -d --shm-size=4G --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd)/$repo_name:/model_repository $extended_triton_image_uri tritonserver --model-repository=/model_repository --exit-on-error=false
time.sleep(60)

In [None]:
CONTAINER_ID=!docker container ls -q
FIRST_CONTAINER_ID = CONTAINER_ID[0]

In [None]:
!echo $FIRST_CONTAINER_ID

In [None]:
!docker logs $FIRST_CONTAINER_ID 

<div class="alert alert-warning">
<b>Warning</b>: Rerun the cell above to check the container logs until you verify that Triton has loaded all models successfully; otherwise, inference requests will fail.
</div>
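
As an alternative to rereading the logs by hand, the readiness check can be polled. The sketch below is a minimal example that uses only the standard library and assumes the server is reachable at `http://localhost:8000`. The helper names `wait_until` and `triton_is_ready` are made up for this example. `/v2/health/ready` is the readiness route of Triton's HTTP API, and per-model readiness is exposed at `/v2/models/<model_name>/ready`.

```python
import time
import urllib.request


def wait_until(predicate, timeout_s=300.0, interval_s=5.0):
    """Call predicate() until it returns True or timeout_s elapses."""
    deadline = time.monotonic() + timeout_s
    while True:
        if predicate():
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(interval_s)


def triton_is_ready(base_url="http://localhost:8000", path="/v2/health/ready"):
    """Return True if the given Triton readiness route answers with HTTP 200."""
    try:
        with urllib.request.urlopen(base_url + path, timeout=5) as resp:
            return resp.status == 200
    except OSError:
        # Covers connection refused, timeouts and HTTP error statuses
        return False


# Usage (requires the local container to be running):
# wait_until(triton_is_ready)
# wait_until(lambda: triton_is_ready(path="/v2/models/pipeline_0/ready"))
```

Polling with a timeout avoids guessing a fixed sleep duration, but the container logs are still the place to look when a model fails to load.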

#### Now we will invoke the model locally

We will use Triton's HTTP client and its utility functions to send a request to `localhost:8000`, where the server is listening. We are sending text as binary data for input and receiving an array that we decode with numpy as output. Check out the code in `model_repository/pipeline/1/model.py` to understand how the input data is decoded and the output data returned, and check out more Triton Python backend [docs](https://github.com/triton-inference-server/python_backend) and [examples](https://github.com/triton-inference-server/python_backend/tree/main/examples) to understand how to handle other data types.
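
To make "text as binary data" concrete, here is a minimal standard-library sketch of how a Triton `BYTES` tensor is laid out on the wire: each element is a 4-byte little-endian length followed by the raw bytes. The client library does this for you when you call `set_data_from_numpy`; the function names below are illustrative and are not part of `tritonclient`.

```python
import struct


def serialize_bytes_tensor(strings):
    """Encode strings as length-prefixed elements (4-byte little-endian length + bytes)."""
    out = bytearray()
    for s in strings:
        data = s.encode("utf-8")
        out += struct.pack("<I", len(data)) + data
    return bytes(out)


def deserialize_bytes_tensor(buf):
    """Decode a length-prefixed buffer back into a list of strings."""
    items, offset = [], 0
    while offset < len(buf):
        (length,) = struct.unpack_from("<I", buf, offset)
        offset += 4
        items.append(buf[offset:offset + length].decode("utf-8"))
        offset += length
    return items


payload = serialize_bytes_tensor(["Pikachu in a detective trench coat"])
print(len(payload), deserialize_bytes_tensor(payload))
```

The image output takes the opposite path: it comes back as a fixed-size numeric tensor, which is why `as_numpy` can return it directly as an array.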

In [None]:
client = httpclient.InferenceServerClient(url="localhost:8000")

prompt = "Pikachu in a detective trench coat, photorealistic, nikon"
text_obj = np.array([prompt], dtype="object").reshape((-1, 1))

input_text = httpclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype))

input_text.set_data_from_numpy(text_obj)

output_img = httpclient.InferRequestedOutput("generated_image")

start = time.time()
query_response = client.infer(model_name="pipeline_0", inputs=[input_text], outputs=[output_img])
print(f"took {time.time()-start} seconds")

image = query_response.as_numpy("generated_image")
im = Image.fromarray(np.squeeze(image))
im.save("generated_image.jpg")

In [None]:
display(im)

Let's stop the container that is running locally so we don't take up notebook resources.

In [None]:
!docker kill $FIRST_CONTAINER_ID

----
----
### Part 5 - Deploy to SageMaker Real-Time Endpoint <a name="deploy"></a>

SageMaker expects a .tar.gz file containing the Triton model repository to be hosted on the endpoint.

In [None]:
prefix = 'stable-diffusion-aitemplate'
tar_file_name = 'sd2-aitemplate.tar.gz'
!tar -C model_repo_0/ -czf "$tar_file_name" .
model_url = sagemaker_session.upload_data(path=tar_file_name, key_prefix=prefix)
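
Before uploading, it can be worth checking that the archive has the layout Triton expects, with each model in its own top-level directory containing a `config.pbtxt`. The following is a minimal sketch; the helper names `find_models` and `models_in_archive` are assumptions made for this example.

```python
import tarfile
from pathlib import PurePosixPath


def find_models(member_names):
    """Return sorted model names that have a top-level <model>/config.pbtxt entry."""
    models = set()
    for name in member_names:
        # `tar -C dir .` produces names such as "./pipeline_0/config.pbtxt"
        parts = tuple(p for p in PurePosixPath(name).parts if p != ".")
        if len(parts) == 2 and parts[1] == "config.pbtxt":
            models.add(parts[0])
    return sorted(models)


def models_in_archive(path):
    """Open a .tar.gz model repository and list the models found in it."""
    with tarfile.open(path, "r:gz") as tar:
        return find_models(tar.getnames())


# Usage, once the archive exists:
# print(models_in_archive("sd2-aitemplate.tar.gz"))
```

An empty result usually means the archive was created one directory level too high or too low.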

Create the SageMaker container and model definitions.

In [None]:
ts = time.strftime("%Y-%m-%d-%H-%M-%S", time.gmtime())

container = {
    "Image": extended_triton_image_uri,
    # "Image": mme_triton_image_uri,
    "ModelDataUrl": model_url,
    "Mode": "SingleModel",
    "Environment": {
        "SAGEMAKER_TRITON_DEFAULT_MODEL_NAME": "pipeline_0",
    }
}

In [None]:
sm_model_name = f"{prefix}-mdl-{ts}"

create_model_response = sm_client.create_model(
    ModelName=sm_model_name, ExecutionRoleArn=role, PrimaryContainer=container
)

print("Model Arn: " + create_model_response["ModelArn"])

Create a SageMaker endpoint configuration.

In [None]:
endpoint_config_name = f"{prefix}-epc-{ts}"
instance_type = 'ml.g5.xlarge'

create_endpoint_config_response = sm_client.create_endpoint_config(
    EndpointConfigName=endpoint_config_name,
    ProductionVariants=[
        {
            "InstanceType": instance_type,
            "InitialVariantWeight": 1,
            "InitialInstanceCount": 1,
            "ModelName": sm_model_name,
            "VariantName": "AllTraffic",
        }
    ],
)

print("Endpoint Config Arn: " + create_endpoint_config_response["EndpointConfigArn"])

Create the endpoint, and wait for it to be up and running.

In [None]:
endpoint_name = f"{prefix}-ep-{ts}"

create_endpoint_response = sm_client.create_endpoint(
    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name
)

print("Endpoint Arn: " + create_endpoint_response["EndpointArn"])

In [None]:
resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
status = resp["EndpointStatus"]
print("Status: " + status)

while status == "Creating":
    time.sleep(60)
    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)
    status = resp["EndpointStatus"]
    print("Status: " + status)

print("Arn: " + resp["EndpointArn"])
print("Status: " + status)
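
Note that the loop above also exits when the endpoint ends up in a status such as `Failed`. The sketch below is one way to make that case explicit. The helper name `wait_for_endpoint` is an assumption made for this example, and it takes a `describe` callable so that it can be exercised without AWS access. boto3 also provides a built-in waiter, `sm_client.get_waiter("endpoint_in_service")`, which may be sufficient on its own.

```python
import time


def wait_for_endpoint(describe, poll_s=60.0, sleep=time.sleep):
    """Poll describe() until the endpoint leaves 'Creating'.

    Returns the final response if the status is 'InService'; raises otherwise.
    """
    while True:
        resp = describe()
        status = resp["EndpointStatus"]
        if status != "Creating":
            break
        sleep(poll_s)
    if status != "InService":
        reason = resp.get("FailureReason", "no FailureReason returned")
        raise RuntimeError(f"Endpoint ended in status {status}: {reason}")
    return resp


# Usage with the clients defined earlier in this notebook:
# wait_for_endpoint(lambda: sm_client.describe_endpoint(EndpointName=endpoint_name))
```

Raising at this point stops the notebook before the invocation cells run against an endpoint that never came up.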

#### Invoke models

In [None]:
prompt = "Smiling person"
inputs = []
outputs = []

text_obj = np.array([prompt], dtype="object").reshape((-1, 1))

inputs.append(httpclient.InferInput("prompt", text_obj.shape, np_to_triton_dtype(text_obj.dtype)))
inputs[0].set_data_from_numpy(text_obj)


outputs.append(httpclient.InferRequestedOutput("generated_image"))

Since we are using the SageMaker Runtime client to send an HTTP request to the endpoint now, we use Triton's `generate_request_body` method to create the right [request format](https://github.com/triton-inference-server/server/tree/main/docs/protocol) for us.

In [None]:
request_body, header_length = httpclient.InferenceServerClient.generate_request_body(
    inputs, outputs=outputs
)

print(request_body)

We are sending our request in binary format for lower inference latency. 

With the binary+json format, we have to specify the length of the request metadata in the header to allow Triton to correctly parse the binary payload. This is done using a custom Content-Type header, which is different from using an `Inference-Header-Content-Length` header on a standalone Triton server because custom headers aren’t allowed in SageMaker. 
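
The following sketch illustrates the layout described above: the body is a JSON header immediately followed by the raw tensor bytes, and the Content-Type carries the header's length. The helper names and the toy payload are assumptions made for this example; in the notebook itself, `generate_request_body` and `parse_response_body` do this work.

```python
import json

PREFIX = "application/vnd.sagemaker-triton.binary+json;json-header-size="


def build_content_type(header_length):
    """Build the SageMaker Content-Type that carries the JSON header length."""
    return f"{PREFIX}{header_length}"


def split_body(body, content_type):
    """Split a binary+json payload into (parsed JSON header, raw tensor bytes)."""
    if not content_type.startswith(PREFIX):
        raise ValueError(f"unexpected content type: {content_type}")
    header_length = int(content_type[len(PREFIX):])
    header = json.loads(body[:header_length].decode("utf-8"))
    return header, body[header_length:]


# Toy payload: one BYTES input whose data is appended after the JSON header
raw = b"\x02\x00\x00\x00hi"
header_json = json.dumps(
    {"inputs": [{"name": "prompt", "shape": [1, 1], "datatype": "BYTES",
                 "parameters": {"binary_data_size": len(raw)}}]}
).encode("utf-8")
body = header_json + raw

header, tensor_bytes = split_body(body, build_content_type(len(header_json)))
print(header["inputs"][0]["name"], tensor_bytes)
```

Because the JSON and binary parts are simply concatenated, the header length is the only information the server has for finding the boundary between them.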

In [None]:
header_length_prefix = "application/vnd.sagemaker-triton.binary+json;json-header-size="
# Send the same request 20 times and print each round-trip latency in seconds
for i in range(20):
    tick = time.time()
    response = runtime_sm_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType=f"{header_length_prefix}{header_length}",
        Body=request_body,
    )
    print(time.time() - tick)

In [None]:
header_length_str = response["ContentType"][len(header_length_prefix) :]
result = httpclient.InferenceServerClient.parse_response_body(
            response["Body"].read(), header_length=int(header_length_str))
image_array = result.as_numpy("generated_image")
image = Image.fromarray(np.squeeze(image_array))

In [None]:
display(image)

-----
-----
### Part 6 - Analyze endpoint logs <a name="logs"></a>

Let's analyze our endpoint's CloudWatch logs and verify the behaviour triggered by MME: when the GPU runs out of memory, the models we invoked first are unloaded to make room for the ones invoked later. MME follows a Least Recently Used (LRU) policy to evict models from GPU memory, or from RAM in the case of MME on CPU.
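
To make the LRU policy concrete, here is a toy model of the eviction order. It is only an illustration and not how SageMaker implements MME: capacity is counted in model slots rather than bytes of GPU memory, and the class name `LRUModelCache` is made up for this example.

```python
from collections import OrderedDict


class LRUModelCache:
    """Toy LRU cache that tracks which models are currently 'loaded'."""

    def __init__(self, capacity):
        self.capacity = capacity
        self._loaded = OrderedDict()  # least recently used first

    def invoke(self, model_name):
        """Mark model_name as used; return the models evicted to make room."""
        evicted = []
        if model_name in self._loaded:
            self._loaded.move_to_end(model_name)  # now the most recently used
            return evicted
        while len(self._loaded) >= self.capacity:
            oldest, _ = self._loaded.popitem(last=False)
            evicted.append(oldest)
        self._loaded[model_name] = True
        return evicted

    def loaded(self):
        """Return loaded model names, least recently used first."""
        return list(self._loaded)


cache = LRUModelCache(capacity=2)
for name in ["model_a", "model_b", "model_a", "model_c"]:
    print(name, "evicted:", cache.invoke(name))
```

In the logs, the equivalent pattern is an unload message for an older model shortly before the load message of a newly invoked one.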

First we build the URL where we can access our endpoint's logs.

In [None]:
cloudwatch_log_url = f'https://{region}.console.aws.amazon.com/cloudwatch/home?region={region}#logStream:group=/aws/sagemaker/Endpoints/{endpoint_name}'

print('↓↓↓Click the following link to access the endpoint logs↓↓↓\n')
print(cloudwatch_log_url)

----
----
### Part 7 - Clean up <a name="cleanup"></a>

In [None]:
sm_client.delete_endpoint(EndpointName=endpoint_name)
sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)
sm_client.delete_model(ModelName=sm_model_name)