# Generative AI: Stable Diffusion SDXL
---

**Stable Diffusion SDXL** is a foundation model for `text-to-image` and `image-to-image` tasks. It lets customers address use cases such as:

- Creative ideation for media, fashion, and other industries
- Image inpainting
- Image colorization


Documentation for this model is available on the [Stability.ai](https://stability.ai/sdxl-aws-documentation) website.

<div class="alert alert-block alert-info">
    <b>Note</b>: This notebook uses SDXL 1.0, which requires you to subscribe to the model on its AWS Marketplace page. See <b>Important</b> below for more information.
</div>


### Important

Visit the model detail page at https://aws.amazon.com/marketplace/pp/prodview-pe7wqwehghdtm to learn more. If you cannot access the link, contact your account administrator for help.

The page lists details about the model, including pricing, supported Regions, and the end user license agreement. To use the model, click “Continue to Subscribe” on the detail page, then return here to learn how to deploy the model and run inference.


## Set up

In [None]:
!pip install 'stability-sdk[sagemaker] @ git+https://github.com/Stability-AI/stability-sdk.git@sagemaker' --quiet
!pip install protobuf==3.20 --quiet
!pip install --upgrade pip sagemaker boto3 --quiet
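
After the installs, it can help to confirm which versions the notebook environment actually resolves, particularly the pinned `protobuf`. Below is a minimal sketch using the standard library's `importlib.metadata`; the helper name `installed_versions` is hypothetical and not part of any SDK. If a package was already imported before it was upgraded, restart the kernel so the new version is loaded.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str]) -> Dict[str, Optional[str]]:
    """Return the installed version of each distribution, or None if it is not installed."""
    versions: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


# Print the versions the current kernel would import
print(installed_versions(["protobuf", "sagemaker", "boto3"]))
```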

In [None]:
import sagemaker
from sagemaker import ModelPackage, get_execution_role
from stability_sdk_sagemaker.predictor import StabilityPredictor
from stability_sdk_sagemaker.models import get_model_package_arn
from stability_sdk.api import GenerationRequest, GenerationResponse, TextPrompt

from IPython.display import display
from PIL import Image
from typing import Union, Tuple
import io
import os
import base64
import boto3

In [None]:
import logging

logger = logging.getLogger()
logging.basicConfig(format='%(asctime)s %(levelname)s:%(message)s', level=logging.INFO, datefmt='%I:%M:%S')
logger.setLevel(logging.INFO)

## Create the model and endpoint

In [None]:
from sagemaker.utils import name_from_base

endpoint_name = name_from_base('sdxl-1-0')  # insert your model endpoint name here
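
`name_from_base` appends a timestamp to the base name. If you replace it with a hand-picked name, the name must satisfy SageMaker's endpoint-name rules: at most 63 characters, matching `^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}` (letters, digits, and hyphens, with no leading or trailing hyphen). A small stdlib check; the helper `is_valid_endpoint_name` is hypothetical.

```python
import re

# Pattern and 63-character limit documented for the SageMaker CreateEndpoint EndpointName field
ENDPOINT_NAME_RE = re.compile(r"[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}")
MAX_ENDPOINT_NAME_LEN = 63


def is_valid_endpoint_name(name: str) -> bool:
    """Return True if `name` satisfies SageMaker's endpoint-name constraints."""
    # The regex alone does not bound hyphens, so check the total length separately
    if len(name) > MAX_ENDPOINT_NAME_LEN:
        return False
    return ENDPOINT_NAME_RE.fullmatch(name) is not None
```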

Get the model package mapping. Note that this model is not available in every AWS Region, and the list of supported Regions may change over time.

In [None]:
model_package_map = {
    "us-east-1": "arn:aws:sagemaker:us-east-1:865070037744:model-package/sdxl-v1-0-2042286-300c57d4fa1e39968d711d754640b0b6",
    "us-east-2": "arn:aws:sagemaker:us-east-2:057799348421:model-package/sdxl-v1-0-2042286-300c57d4fa1e39968d711d754640b0b6",
    "us-west-2": "arn:aws:sagemaker:us-west-2:594846645681:model-package/sdxl-v1-0-2042286-300c57d4fa1e39968d711d754640b0b6",
    "ca-central-1": "arn:aws:sagemaker:ca-central-1:470592106596:model-package/sdxl-v1-0-2042286-300c57d4fa1e39968d711d754640b0b6",
    "eu-central-1": "arn:aws:sagemaker:eu-central-1:446921602837:model-package/sdxl-v1-0-2042286-300c57d4fa1e39968d711d754640b0b6",
    "eu-west-1": "arn:aws:sagemaker:eu-west-1:985815980388:model-package/sdxl-v1-0-2042286-300c57d4fa1e39968d711d754640b0b6",
    "eu-west-2": "arn:aws:sagemaker:eu-west-2:856760150666:model-package/sdxl-v1-0-2042286-300c57d4fa1e39968d711d754640b0b6",
    "ap-northeast-2": "arn:aws:sagemaker:ap-northeast-2:745090734665:model-package/sdxl-v1-0-2042286-300c57d4fa1e39968d711d754640b0b6",
    "ap-northeast-1": "arn:aws:sagemaker:ap-northeast-1:977537786026:model-package/sdxl-v1-0-2042286-300c57d4fa1e39968d711d754640b0b6",
    "ap-south-1": "arn:aws:sagemaker:ap-south-1:077584701553:model-package/sdxl-v1-0-2042286-300c57d4fa1e39968d711d754640b0b6"
}

REGION = boto3.Session().region_name
logger.info(f'Region => {REGION}')

if REGION not in model_package_map:
    raise ValueError(f"Unsupported region: {REGION}. Supported regions: {list(model_package_map)}")
    
package_arn = model_package_map[REGION]
ROLE = get_execution_role()
sagemaker_session = sagemaker.Session()
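
Each ARN in the mapping has the form `arn:aws:sagemaker:<region>:<account>:model-package/<name>`, so its Region field should match the key it is stored under. As a sanity check when editing the table, a minimal stdlib sketch; the helper names `arn_region` and `mismatched_regions` are hypothetical.

```python
from typing import Dict, List


def arn_region(arn: str) -> str:
    """Return the Region field of an ARN (arn:partition:service:region:account:resource)."""
    parts = arn.split(":", 5)
    if len(parts) != 6 or parts[0] != "arn":
        raise ValueError(f"Not an ARN: {arn!r}")
    return parts[3]


def mismatched_regions(mapping: Dict[str, str]) -> List[str]:
    """List the mapping keys whose ARN points at a different Region than the key."""
    return [region for region, arn in mapping.items() if arn_region(arn) != region]
```

In the notebook, `mismatched_regions(model_package_map)` returns an empty list for the mapping above, since every ARN's Region matches its key.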

In [None]:
sdxl_model = ModelPackage(
    role=ROLE,
    model_package_arn=package_arn,
    sagemaker_session=sagemaker_session,
    predictor_cls=StabilityPredictor
)

instance_type = "ml.g5.2xlarge" 
deployed_model = sdxl_model.deploy(
    initial_instance_count=1,
    instance_type=instance_type,
    endpoint_name=endpoint_name
)

### Validate endpoint: `text-to-image`

See the [API parameter reference](https://platform.stability.ai/docs/features/api-parameters) for details on each request parameter.

In [None]:
def _decode_image(model_response: GenerationResponse) -> Image.Image:
    """Decode the first Base64 artifact in an SDXL response into a PIL image."""
    image_b64 = model_response.artifacts[0].base64
    image_data = base64.b64decode(image_b64.encode())
    return Image.open(io.BytesIO(image_data))


def decode_and_show(model_response: GenerationResponse) -> None:
    """
    Decodes and displays an image from SDXL output

    Args:
        model_response (GenerationResponse): The response object from the deployed SDXL model.

    Returns:
        None
    """
    display(_decode_image(model_response))


def save_img_to_png(model_response: GenerationResponse, img_path: str) -> None:
    """
    Save the image from SDXL output

    Args:
        model_response (GenerationResponse): The response object from the deployed SDXL model.
        img_path (str): The path where the image will be saved; missing parent folders are created.

    Returns:
        None
    """
    # Create the output folder (e.g. out_img/) if it does not exist yet
    os.makedirs(os.path.dirname(img_path) or ".", exist_ok=True)
    _decode_image(model_response).save(img_path)
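
The helpers above decode the Base64 artifact and pass it through PIL. If you only need the file on disk, the raw bytes can be written directly. Below is a minimal stdlib sketch, assuming the artifact is PNG-encoded (an assumption; adjust or disable the check if your endpoint returns another format). The helper `save_b64_artifact` is hypothetical; it also creates any missing parent folders such as `out_img/`.

```python
import base64
from pathlib import Path

# The 8-byte signature every PNG file starts with
PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def save_b64_artifact(b64_data: str, path: str, require_png: bool = True) -> Path:
    """Decode a Base64 string and write the raw bytes to `path`, creating parent folders."""
    raw = base64.b64decode(b64_data, validate=True)
    if require_png and not raw.startswith(PNG_SIGNATURE):
        raise ValueError("Decoded data does not start with the PNG signature")
    out = Path(path)
    out.parent.mkdir(parents=True, exist_ok=True)
    out.write_bytes(raw)
    return out
```

In the notebook this would be called as `save_b64_artifact(output.artifacts[0].base64, 'out_img/raw.png')`, skipping the PIL decode and re-encode.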


In [None]:
text = "Landscape of buildings, multiple buildings, dark, raining, glimpse, quiet, 8k, realistic"

output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[TextPrompt(text=text)],
        height=1024,
        width=1024,
        steps=100,
        cfg_scale=12,
    )
)
decode_and_show(output)
save_img_to_png(output, 'out_img/dark_raining_landscape.png')
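
The requests in this notebook use sizes such as 1024x1024 and 1536x640. Stability AI's hosted API accepts only a fixed set of dimensions for SDXL 1.0, all multiples of 64 and close to one megapixel. Assuming the SageMaker model package enforces the same list (not confirmed in this notebook), a pre-flight check can catch an invalid size before invoking the endpoint. The helper `check_sdxl_dimensions` is hypothetical.

```python
from typing import Set, Tuple

# (width, height) pairs listed for SDXL 1.0 in Stability AI's hosted API documentation.
# Assumption: the SageMaker model package accepts the same set.
SDXL_V1_DIMENSIONS: Set[Tuple[int, int]] = {
    (1024, 1024),
    (1152, 896), (896, 1152),
    (1216, 832), (832, 1216),
    (1344, 768), (768, 1344),
    (1536, 640), (640, 1536),
}


def check_sdxl_dimensions(width: int, height: int) -> None:
    """Raise ValueError if (width, height) is not one of the listed SDXL 1.0 sizes."""
    if (width, height) not in SDXL_V1_DIMENSIONS:
        allowed = ", ".join(f"{w}x{h}" for w, h in sorted(SDXL_V1_DIMENSIONS))
        raise ValueError(f"{width}x{height} is not a listed SDXL 1.0 size; try one of: {allowed}")
```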

Next, explore combining multiple weighted `TextPrompt`s with the `negative_prompts` parameter.

In [None]:
output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[
            TextPrompt(text='beautiful night sky above japanese town', weight=0.9),
            TextPrompt(text='Snow falling down, night sky', weight=0.5),
            TextPrompt(text='Clouds', weight=-1)
        ],
        style_preset="anime",
        height=1024,
        width=1024,
        negative_prompts=['people in image', 'cloud', 'poorly rendered'],
    )
)
decode_and_show(output)
save_img_to_png(output, 'out_img/night_japan_town.png')
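
Each `TextPrompt` above carries a `weight`. In Stability's REST API, prompts are sent as a list of `{"text", "weight"}` objects, and negative weights mark negative prompts, which is presumably what `weight=-1` on `'Clouds'` does here. How `stability_sdk` serializes `TextPrompt` for SageMaker is not shown in this notebook, so the sketch below is only an illustration of that REST-style shape; `WeightedPrompt`, `split_prompts`, and `to_request_body` are hypothetical names, not SDK classes.

```python
import json
from dataclasses import asdict, dataclass
from typing import Any, List, Tuple


@dataclass
class WeightedPrompt:
    """Hypothetical stand-in for a prompt entry; not the stability_sdk TextPrompt class."""
    text: str
    weight: float = 1.0


def split_prompts(prompts: List[WeightedPrompt]) -> Tuple[List[str], List[str]]:
    """Separate prompts that pull toward (weight > 0) and away from (weight < 0) a concept."""
    positive = [p.text for p in prompts if p.weight > 0]
    negative = [p.text for p in prompts if p.weight < 0]
    return positive, negative


def to_request_body(prompts: List[WeightedPrompt], **params: Any) -> str:
    """Serialize prompts plus extra parameters into a REST-style JSON body."""
    return json.dumps({"text_prompts": [asdict(p) for p in prompts], **params})


prompts = [
    WeightedPrompt("beautiful night sky above japanese town", 0.9),
    WeightedPrompt("Snow falling down, night sky", 0.5),
    WeightedPrompt("Clouds", -1),
]
print(split_prompts(prompts))
print(to_request_body(prompts, style_preset="anime", height=1024, width=1024))
```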

In [None]:
text = "A cute fluffy white cat stands on its hind legs, peering curiously into an ornate golden mirror. In the reflection the cat sees itself."

negative_prompts = ['distorted cat features', 'distorted lion features', 'poorly rendered']

output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[TextPrompt(text=text)],
        style_preset="neon-punk",
        seed=4343434,
        height=640,
        width=1536,
        steps=150,
        cfg_scale=7,
        negative_prompts=negative_prompts
    )
)
decode_and_show(output)
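
Fixing `seed`, as in the cell above, makes it easier to see the effect of changing one parameter at a time. A minimal sketch that builds a grid of keyword arguments for a sweep over `cfg_scale` and `steps`; the helper `build_sweep` is hypothetical, and each combination costs one endpoint invocation.

```python
from itertools import product
from typing import Any, Dict, Iterable, List


def build_sweep(seed: int, cfg_scales: Iterable[float], step_counts: Iterable[int],
                **fixed: Any) -> List[Dict[str, Any]]:
    """Return one kwargs dict per (cfg_scale, steps) combination, all sharing the same seed."""
    return [
        {"seed": seed, "cfg_scale": cfg, "steps": n, **fixed}
        for cfg, n in product(cfg_scales, step_counts)
    ]


sweep = build_sweep(4343434, cfg_scales=[5, 7, 12], step_counts=[50, 100], height=640, width=1536)
print(len(sweep))  # 3 cfg values x 2 step counts = 6 requests
print(sweep[0])
```

In the notebook, each dict could be unpacked into a request, for example `decode_and_show(deployed_model.predict(GenerationRequest(text_prompts=[TextPrompt(text=text)], **kw)))` inside a loop over `sweep`.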

### Validate endpoint: `image-to-image`

I use both a public cat image (from stability.ai) and a photo of my own cat to demonstrate this. Feel free to adjust the prompt parameters and experiment!

In [None]:
!wget https://platform.stability.ai/Cat_August_2010-4.jpg
display(Image.open('Cat_August_2010-4.jpg'))

In [None]:
def encode_image(image_path: str, resize: bool = True, size: Tuple[int, int] = (1024, 1024)) -> Union[str, None]:
    """
    Encode an image as a base64 string, optionally resizing it to a supported resolution.

    Args:
        image_path (str): The path to the image file.
        resize (bool, optional): Whether to resize the image. Defaults to True.

    Returns:
        Union[str, None]: The encoded image as a string, or None if encoding failed.
    """
    assert os.path.exists(image_path)

    if resize:
        image = Image.open(image_path)
        image = image.resize(size)
        image.save("image_path_resized.png")
        image_path = "image_path_resized.png"
        
    image = Image.open(image_path)
    assert image.size == size
    with open(image_path, "rb") as image_file:
        img_byte_array = image_file.read()
        # Encode the byte array as a Base64 string
        try:
            base64_str = base64.b64encode(img_byte_array).decode("utf-8")
            return base64_str
        except Exception as e:
            print(f"Failed to encode image {image_path} as base64 string.")
            print(e)
            return None
    

In [None]:
cat_path = "Cat_August_2010-4.jpg"

size = (1536, 640)
cat_data = encode_image(cat_path, size=size)

output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[TextPrompt(text="sketch cat image")],
        init_image=cat_data,
        cfg_scale=9,
        image_strength=0.8,
        seed=42,
        width=size[0],   # PIL sizes are (width, height)
        height=size[1],
        init_image_mode="STEP_SCHEDULE"
    )
)
decode_and_show(output)
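
PIL reports `image.size` as `(width, height)`, so `size = (1536, 640)` means 1536 pixels wide and 640 tall, and the request's `width` and `height` must follow that order. Because `encode_image` saves the resized image as PNG, the dimensions of the encoded string can be double-checked without PIL. A stdlib sketch that reads them from the PNG header; the helper `png_size_from_b64` is hypothetical, and it does not apply to a JPEG passed with `resize=False`.

```python
import base64
import struct
from typing import Tuple

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size_from_b64(b64_data: str) -> Tuple[int, int]:
    """Return (width, height) read from the IHDR chunk of a Base64-encoded PNG."""
    # Only the first 24 bytes are needed: 8-byte signature, 4-byte chunk length,
    # 4-byte chunk type "IHDR", then width and height as big-endian uint32.
    # 32 Base64 characters decode to exactly 24 bytes.
    raw = base64.b64decode(b64_data[:32])
    if raw[:8] != PNG_SIGNATURE or raw[12:16] != b"IHDR":
        raise ValueError("Not a PNG (missing signature or IHDR chunk)")
    width, height = struct.unpack(">II", raw[16:24])
    return width, height
```

For the cell above, `png_size_from_b64(cat_data)` should return `(1536, 640)`.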

In [None]:
cat_path = 'my-cat.jpg'  # replace with the path to your own photo
size = (1024, 1024)
cat_data = encode_image(cat_path, size=size)

output = deployed_model.predict(
    GenerationRequest(
        text_prompts=[TextPrompt(text="looking up, playful, handsome looking, anime style")],
        init_image=cat_data,
        cfg_scale=20,
        steps=130,
        image_strength=0.4,
        width=size[0],   # PIL sizes are (width, height)
        height=size[1],
        style_preset='anime'
    )
)
decode_and_show(output)

## Optional: Delete model and endpoint
---

Use the command below to list your endpoints and find the `EndpointName` to delete.

In [None]:
!aws sagemaker list-endpoints
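
To find endpoints left over from earlier runs as well, the listing can be filtered by name. Below is a minimal sketch against the boto3 SageMaker client's `list_endpoints` (with `NameContains` and `NextToken` pagination) and `delete_endpoint` calls. The helper names are hypothetical, and deletion defaults to a dry run. Deleting an endpoint this way does not remove its endpoint configuration or model, which have their own `delete_endpoint_config` and `delete_model` calls.

```python
from typing import Any, Dict, List


def find_endpoints(client: Any, name_contains: str) -> List[str]:
    """Return all endpoint names containing `name_contains`, following NextToken pages."""
    names: List[str] = []
    kwargs: Dict[str, str] = {"NameContains": name_contains}
    while True:
        response = client.list_endpoints(**kwargs)
        names.extend(ep["EndpointName"] for ep in response.get("Endpoints", []))
        token = response.get("NextToken")
        if not token:
            return names
        kwargs["NextToken"] = token


def delete_endpoints(client: Any, names: List[str], dry_run: bool = True) -> List[str]:
    """Delete each named endpoint unless dry_run is True; return the names acted on."""
    for name in names:
        if dry_run:
            print(f"[dry run] would delete endpoint {name}")
        else:
            client.delete_endpoint(EndpointName=name)
    return names
```

Usage in the notebook: `sm = boto3.client("sagemaker")`, then `delete_endpoints(sm, find_endpoints(sm, "sdxl-1-0"), dry_run=False)` once the dry-run output looks right.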

In [None]:
deployed_model.delete_model()
deployed_model.delete_endpoint()