# Using OctoAI SDXL on SageMaker through Model Packages

OctoAI SDXL is a performant, feature-rich SDXL implementation that lets you generate images from text, modify existing images with image-to-image, and use a host of other capabilities. You can read more about what's possible in the [documentation](https://octo.ai/docs/octostack/sagemaker).

This sample notebook shows you how to deploy OctoAI SDXL using Amazon SageMaker.

> **Note**: This is a reference notebook; it will not run until you make the changes suggested in the notebook.

## Pre-requisites:
1. Before running this notebook, make sure you obtained it from the model catalog in the SageMaker section of the AWS Management Console.
1. **Note**: This notebook contains elements that render correctly in the Jupyter interface. Open it from an Amazon SageMaker Notebook Instance or Amazon SageMaker Studio.
1. Ensure that the IAM role you use has the **AmazonSageMakerFullAccess** policy attached.

## Contents:
1. [Select the model package](#1.-Select-the-model-package)
2. [Create an endpoint and perform real-time inference](#2.-Create-an-endpoint-and-perform-real-time-inference)
   1. [Create an endpoint](#A.-Create-an-endpoint)
   2. [Create input payload](#B.-Create-input-payload)
   3. [Perform real-time inference](#C.-Perform-real-time-inference)
   4. [Visualize output](#D.-Visualize-output)
   5. [Delete the endpoint](#E.-Delete-the-endpoint)
3. [Clean-up](#3.-Clean-up)
   1. [Delete the model](#A.-Delete-the-model)

## Usage instructions
You can run this notebook one cell at a time (press Shift+Enter to run a cell).

## 1. Select the model package
Confirm that you received this notebook from the model catalog in the SageMaker section of the AWS Management Console.

Note that you also have to subscribe to the OctoAI SDXL product on AWS Marketplace.

In [41]:
# Mapping from AWS region to the model package ARN available in that region
model_package_map = {
    "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "us-west-1": "arn:aws:sagemaker:us-west-1:382657785993:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "eu-west-3": "arn:aws:sagemaker:eu-west-3:843114510376:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "eu-north-1": "arn:aws:sagemaker:eu-north-1:136758871317:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "ap-southeast-1": "arn:aws:sagemaker:ap-southeast-1:192199979996:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "ap-southeast-2": "arn:aws:sagemaker:ap-southeast-2:666831318237:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
    "sa-east-1": "arn:aws:sagemaker:sa-east-1:270155090741:model-package/octoai-sdxl-f213d1052497ffea6b-5237376082f63d8a848ceecdf1393e1f",
}

In [11]:
import json

import boto3
import sagemaker as sage
from sagemaker import ModelPackage, get_execution_role

In [27]:
boto_session = boto3.Session()
region = boto_session.region_name
if region not in model_package_map:
    # Raising a bare string is a TypeError in Python 3, so raise a real exception
    raise ValueError(f"Unsupported region: {region}")

model_package_arn = model_package_map[region]
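
Because the region is embedded in each ARN as well as used as the dictionary key, a quick consistency check can catch copy-and-paste mistakes in the mapping. The sketch below is illustrative only: `check_arn_regions` is a hypothetical helper, and the sample map uses placeholder account IDs and package names rather than real ARNs.

```python
def check_arn_regions(package_map):
    """Return the keys whose ARN does not embed the same region as the key."""
    mismatched = []
    for region, arn in package_map.items():
        # ARN layout: arn:partition:service:region:account-id:resource
        parts = arn.split(":")
        if len(parts) < 6 or parts[3] != region:
            mismatched.append(region)
    return mismatched


# Placeholder entries for illustration; the second one is deliberately wrong
sample_map = {
    "us-east-1": "arn:aws:sagemaker:us-east-1:111111111111:model-package/example",
    "eu-west-1": "arn:aws:sagemaker:us-east-1:111111111111:model-package/example",
}
print(check_arn_regions(sample_map))
```

In the notebook you could call `check_arn_regions(model_package_map)` and expect an empty list.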

In [28]:

sagemaker_session = sage.Session(boto_session=boto_session)
role = get_execution_role(sagemaker_session=sagemaker_session)

runtime_sm_client = boto_session.client("runtime.sagemaker")

## 2. Create an endpoint and perform real-time inference

If you want to understand how real-time inference with Amazon SageMaker works, see [Documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/deploy-model.html).

### A. Create an endpoint

In [42]:
# NOTE: you can change this to whatever you'd like to name your SDXL endpoint
model_name = "sdxl"

initial_instance_count = 1

content_type = "application/json"

real_time_inference_instance_type = (
    "ml.g5.4xlarge"
    # other possible options are:
    # ml.g5.xlarge  
    # ml.g5.2xlarge 
    # ml.g5.4xlarge 
    # ml.g5.8xlarge 
    # ml.g5.12xlarge
    # ml.g5.16xlarge
    # ml.g5.24xlarge
    # ml.g5.48xlarge
    # ml.p4d.24xlarge
    # ml.p4de.24xlarge
    # ml.p5.48xlarge
)

In [None]:
# create a deployable model from the model package.
model = ModelPackage(
    role=role, model_package_arn=model_package_arn, sagemaker_session=sagemaker_session
)

# Deploy the model
predictor = model.deploy(
    initial_instance_count=initial_instance_count,
    instance_type=real_time_inference_instance_type,
    endpoint_name=model_name,
    model_data_download_timeout=1200,  # seconds
    container_startup_health_check_timeout=600,  # seconds
)

Once the endpoint has been created, you can use it to generate images. The next sections show examples.
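
If you return to the notebook later, or start the deployment without waiting for it, you may want to confirm that the endpoint is ready before sending requests. The following is a minimal sketch that polls the SageMaker `describe_endpoint` API. `wait_for_endpoint` and its default intervals are assumptions made for illustration, and `sm_client` is expected to be a client such as `boto_session.client("sagemaker")`.

```python
import time


def wait_for_endpoint(sm_client, endpoint_name, poll_seconds=30, timeout_seconds=1800):
    """Poll describe_endpoint until the endpoint is InService, fails, or times out."""
    deadline = time.monotonic() + timeout_seconds
    while True:
        status = sm_client.describe_endpoint(EndpointName=endpoint_name)["EndpointStatus"]
        if status == "InService":
            return status
        if status == "Failed":
            raise RuntimeError(f"Endpoint {endpoint_name} failed to deploy")
        if time.monotonic() >= deadline:
            raise TimeoutError(
                f"Endpoint {endpoint_name} is still {status} after {timeout_seconds}s"
            )
        time.sleep(poll_seconds)
```

A call would look like `wait_for_endpoint(boto_session.client("sagemaker"), model_name)`.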

### B. Create input payload

You can use the SDXL API to generate a variety of images through text-to-image, image-to-image, inpainting, outpainting, and PhotoMerge. See the [documentation](https://octo.ai/docs/octostack/sagemaker) for a complete description.

In [44]:
# A simple text-to-image payload, generating a single image
payload = {
    "prompt": "A wizard octopus in the forest conjuring a spell",
    "negative_prompt": "Blurry,low-res,poor quality",
    "steps": 30,
    "num_images": 1,
    "sampler": "DDIM",
    "cfg_scale": 12,
    "width": 1024,
    "height": 1024,
}
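
To inspect what will be sent to the endpoint, you can pretty-print the payload. The sketch below uses a hypothetical helper, `preview_payload`, that also shortens very long string values, which helps once a payload carries a base64-encoded image. The 60-character cutoff is an arbitrary choice.

```python
import json


def preview_payload(payload, max_chars=60):
    """Return the payload as indented JSON, shortening very long string values."""
    shortened = {}
    for key, value in payload.items():
        if isinstance(value, str) and len(value) > max_chars:
            value = f"{value[:max_chars]}... ({len(value)} chars)"
        shortened[key] = value
    return json.dumps(shortened, indent=2)


# Same text-to-image payload as in the cell above
payload = {
    "prompt": "A wizard octopus in the forest conjuring a spell",
    "negative_prompt": "Blurry,low-res,poor quality",
    "steps": 30,
    "num_images": 1,
    "sampler": "DDIM",
    "cfg_scale": 12,
    "width": 1024,
    "height": 1024,
}
print(preview_payload(payload))
```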

### C. Perform real-time inference

In [45]:
response = runtime_sm_client.invoke_endpoint(
    EndpointName=model_name,
    ContentType=content_type,
    Body=json.dumps(payload),
)

output = json.loads(response["Body"].read().decode("utf8"))
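
Since the same invoke-and-decode steps are repeated for every payload in this notebook, they can be wrapped in a small function. This is a sketch rather than part of the OctoAI API: `invoke_sdxl` is a hypothetical name, and the check for an `images` key relies on the response structure described in the next section.

```python
import json


def invoke_sdxl(runtime_client, endpoint_name, payload):
    """Send a JSON payload to the endpoint and return the decoded JSON response."""
    response = runtime_client.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    output = json.loads(response["Body"].read().decode("utf-8"))
    # Fail early with a clear message instead of a KeyError further down
    if not isinstance(output, dict) or "images" not in output:
        raise ValueError("Endpoint response does not contain an 'images' field")
    return output
```

With the objects defined earlier, this would be called as `output = invoke_sdxl(runtime_sm_client, model_name, payload)`.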

### D. Visualize output

OctoAI SDXL will return a JSON payload that includes all the images generated, each encoded in base64. The structure looks like this:

```
    {
        "images": [
            {
                "image_b64":"<base64 encoded image>",
                "removed_for_safety":<true|false>,
                "seed":<integer>
            },
            { ... }
        ]
    }
```

First, let's define a simple helper function that takes the output and displays every generated image (useful when more than one image is returned).

In [33]:
from IPython import display
from base64 import b64decode

def display_output(endpoint_output):
    # Convert each base64-encoded image and display all of them
    imgs = []
    for output_img in endpoint_output["images"]:
        decoded_image = b64decode(output_img["image_b64"])
        imgs.append(display.Image(decoded_image))
    display.display(*imgs)

We can now use it on the output we received earlier.

In [46]:
display_output(output)
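
If you would rather keep the results than only display them, the same response structure can be written to disk. The sketch below makes several assumptions: `save_images` and `guess_extension` are hypothetical helpers, the output directory name is arbitrary, entries flagged `removed_for_safety` are skipped, and the file extension is guessed from the PNG and JPEG signatures because the image format is not stated here.

```python
from base64 import b64decode
from pathlib import Path


def guess_extension(data):
    """Guess a file extension from the leading bytes of the image data."""
    if data.startswith(b"\x89PNG\r\n\x1a\n"):
        return ".png"
    if data.startswith(b"\xff\xd8\xff"):
        return ".jpg"
    return ".bin"


def save_images(endpoint_output, out_dir="generated"):
    """Decode each returned image and write it to out_dir; return the paths written."""
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    saved = []
    for index, item in enumerate(endpoint_output["images"]):
        # Assumption: images removed by the safety filter are not worth saving
        if item.get("removed_for_safety"):
            continue
        data = b64decode(item["image_b64"])
        seed = item.get("seed", "unknown")
        target = out_path / f"image_{index}_seed_{seed}{guess_extension(data)}"
        target.write_bytes(data)
        saved.append(target)
    return saved
```

Including the seed in the file name makes it easier to reproduce a particular image later.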

In addition to SDXL, your endpoint also supports SDXL Lightning, which lets you create high-quality images faster by using fewer steps.

In [None]:
lightning_payload = {
    "prompt": "A wizard octopus in the forest conjuring a spell",
    "negative_prompt": "Blurry,low-res,poor quality",
    "num_images": 1,
    "sampler": "DDIM",
    "width": 1024,
    "height": 1024,

    # The main changes are:
    # 1. A lower number of steps
    # 2. A lower CFG scale
    # 3. Specifying the SDXL Lightning checkpoint
    # You can still change other parameters, such as number of images, resolution, sampler, etc
    "steps": 8,
    "cfg_scale": 2,
    "checkpoint": "octoai:lightning_sdxl"
}

response = runtime_sm_client.invoke_endpoint(
    EndpointName=model_name,
    ContentType=content_type,
    Body=json.dumps(lightning_payload),
)

output = json.loads(response["Body"].read().decode("utf8"))
display_output(output)
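
Because a Lightning request differs from a regular one in only three fields, it can be derived from any existing payload. The helper below is a hypothetical convenience, with defaults copied from the values used in this notebook. It copies the payload so the original is left untouched.

```python
def to_lightning(payload, steps=8, cfg_scale=2):
    """Return a copy of payload adjusted for the SDXL Lightning checkpoint."""
    lightning = dict(payload)  # shallow copy; the original payload is not modified
    lightning.update(
        steps=steps,
        cfg_scale=cfg_scale,
        checkpoint="octoai:lightning_sdxl",
    )
    return lightning


base = {"prompt": "A wizard octopus in the forest conjuring a spell", "steps": 30, "cfg_scale": 12}
print(to_lightning(base))
```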

Just as we generate images from text, we can also generate images from other images, with the same performance and functionality (this also works with SDXL Lightning).

In [39]:
from base64 import b64encode, b64decode

# Read the image and encode it as base64
with open("./astronaut.png", "rb") as image_file:
    init_image = b64encode(image_file.read()).decode("utf-8")

img2img_payload = {
    "prompt": "breathtaking, american woman, award winning photography, best quality, 8K HDR",
    "negative_prompt": "worst quality, low quality, bad quality, lazy eye",
    "width": 1344,
    "height": 768,
    "num_images": 1,
    "sampler": "DDIM",
    "steps": 30,
    "cfg_scale": 12,
    "use_refiner": False,
    "style_preset": "neon-punk",
    "strength": 0.8,
    "init_image": init_image,

    # We use a specific seed to get a specific image out, but you can
    # change this or omit it
    "seed": 2701628909,
}

response = runtime_sm_client.invoke_endpoint(
    EndpointName=model_name,
    ContentType=content_type,
    Body=json.dumps(img2img_payload),
)

output = json.loads(response["Body"].read().decode("utf8"))
print("Generated image")
display_output(output)

# Display the original image too for comparison
print("Original image")
display.display(display.Image(b64decode(init_image)))
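
If you plan to try several source images, the encoding and payload assembly can be factored into helpers. This is a sketch under stated assumptions: `encode_image` and `build_img2img_payload` are hypothetical names, the defaults mirror the values used above, and any extra keyword arguments are passed through to the payload unchanged.

```python
from base64 import b64encode
from pathlib import Path


def encode_image(path):
    """Read an image file and return its contents as a base64 string."""
    return b64encode(Path(path).read_bytes()).decode("utf-8")


def build_img2img_payload(image_path, prompt, strength=0.8, **overrides):
    """Assemble an image-to-image payload; overrides replace or add fields."""
    payload = {
        "prompt": prompt,
        "init_image": encode_image(image_path),
        "strength": strength,
        "num_images": 1,
        "sampler": "DDIM",
        "steps": 30,
        "cfg_scale": 12,
    }
    payload.update(overrides)
    return payload
```

For example, `build_img2img_payload("./astronaut.png", "neon city at night", style_preset="neon-punk")` would produce a payload similar to the one above.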

### E. Delete the endpoint

Now that you have successfully performed real-time inference, you no longer need the endpoint. Delete it to avoid further charges.

In [50]:
model.sagemaker_session.delete_endpoint(model_name)
model.sagemaker_session.delete_endpoint_config(model_name)

## 3. Clean-up

### A. Delete the model

In [51]:
model.delete_model()
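
After cleaning up, it can be reassuring to confirm that no endpoint is still running. The sketch below lists endpoints whose names contain a given string using the SageMaker `list_endpoints` API and follows pagination tokens. `list_endpoint_names` is a hypothetical helper, and `sm_client` is assumed to be a client such as `boto_session.client("sagemaker")`.

```python
def list_endpoint_names(sm_client, name_contains):
    """Return the names of all endpoints whose name contains name_contains."""
    names = []
    kwargs = {"NameContains": name_contains}
    while True:
        page = sm_client.list_endpoints(**kwargs)
        names.extend(endpoint["EndpointName"] for endpoint in page.get("Endpoints", []))
        token = page.get("NextToken")
        if not token:
            return names
        # Request the next page of results
        kwargs["NextToken"] = token
```

An empty list from `list_endpoint_names(boto_session.client("sagemaker"), model_name)` indicates that no endpoint with that name remains.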