In [None]:
import os

from huggingface_hub import snapshot_download

local_dir = './model'

# (Hugging Face repository, single file to fetch from it) for the SDXL base and refiner checkpoints.
# allow_patterns restricts each snapshot to just that one .safetensors file.
_sdxl_checkpoints = (
    ("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"),
    ("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"),
)

for repo, checkpoint_file in _sdxl_checkpoints:
    # Download real files (not symlinks into the HF cache) so they can be copied to S3 later.
    snapshot_download(
        repo_id=repo,
        allow_patterns=checkpoint_file,
        local_dir=local_dir,
        local_dir_use_symlinks=False,
    )

In [None]:
%%bash

# Copy the trained LoRA weights into ./model so the upload cell further down can push them
# to S3 alongside the SDXL base and refiner checkpoints.
# NOTE(review): the bucket (sagemaker-sdxl) and the run-specific directory (1725806787) are
# hard-coded; point this at the output of your own LoRA training run.
aws s3 cp s3://sagemaker-sdxl/output/1725806787/lora-trained-xl/pytorch_lora_weights.safetensors model/


### 3. Set up the SageMaker session and upload the model files (uncompressed) to S3

In [None]:
import boto3
import sagemaker

# Bucket used for uploading data, models and logs. Leave it as None to fall back to the
# session's default bucket, which SageMaker creates automatically if it does not exist yet.
sagemaker_session_bucket = None

sess = sagemaker.Session()
if sagemaker_session_bucket is None:
    sagemaker_session_bucket = sess.default_bucket()


def _lookup_execution_role():
    """Return the ARN of the IAM role that SageMaker resources will run as.

    Prefers the execution role attached to the current notebook environment. When
    ``sagemaker.get_execution_role`` raises ``ValueError`` (e.g. when running outside
    SageMaker), falls back to the IAM role named ``sagemaker_execution_role``.
    """
    try:
        return sagemaker.get_execution_role()
    except ValueError:
        return boto3.client('iam').get_role(RoleName='sagemaker_execution_role')['Role']['Arn']


role = _lookup_execution_role()

# Re-create the session pinned to the resolved bucket.
sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)

print(f"sagemaker role arn: {role}")
print(f"sagemaker bucket: {sess.default_bucket()}")
print(f"sagemaker session region: {sess.boto_region_name}")

In [16]:
# Rerun this cell only if you need to re-upload the weights, otherwise you can reuse the existing model_package_name and upload only your new code
from sagemaker.utils import name_from_base

# S3 prefix that will hold the weights and code. name_from_base derives a fresh name from the
# "sdxl-v1" base each time this cell runs; you may want a fixed name of your choosing instead.
model_package_name = name_from_base("sdxl-v1")
model_uri = "s3://{}/{}/".format(sagemaker_session_bucket, model_package_name)

In [None]:
print(f'Uploading base model to {model_uri}, this will take a while...')
!aws s3 cp model/sd_xl_base_1.0.safetensors {model_uri}
print(f'Uploading refiner model to {model_uri}, this will take a while...')
!aws s3 cp model/sd_xl_refiner_1.0.safetensors {model_uri}
print(f'Uploading LoRA weights to {model_uri}, this will take a while...')
!aws s3 cp model/pytorch_lora_weights.safetensors {model_uri}

In [None]:
# Rerun this cell when you have changed the code or are uploading a fresh copy of the weights
print(f'Uploading code to {model_uri}code')
!aws s3 cp model/code/inference.py {model_uri}code/inference.py
!aws s3 cp model/code/requirements.txt {model_uri}code/requirements.txt
print("Done!")

### 4. Create and deploy a model and perform real-time inference


In [19]:
# Please only use regions with g5 instance support, mentioned at the top of this page.
# Default to the SageMaker session's region rather than hard-coding one: the model and endpoint
# are created in that region, and the container image should be pulled from ECR in the same
# region. Replace with an explicit region string (e.g. "us-east-1") if you need to override it.
inference_image_uri_region = sess.boto_region_name

# AWS account whose ECR registry hosts the prebuilt stabilityai-pytorch-inference image.
# NOTE(review): this account serves most, but not all, regions — confirm it for your region.
inference_image_uri_region_acct = "763104351884"

inference_image_uri = (
    f"{inference_image_uri_region_acct}.dkr.ecr.{inference_image_uri_region}.amazonaws.com/"
    "stabilityai-pytorch-inference:2.0.1-sgm0.1.0-gpu-py310-cu118-ubuntu20.04-sagemaker"
)

In [None]:
endpoint_name = name_from_base("sdxl-v1")

# Create the client in the same region as the SageMaker session, which is where the default
# bucket (and therefore model_uri) was resolved.
sagemaker_client = boto3.client('sagemaker', region_name=sess.boto_region_name)

# Register the model. The weights and code/ were uploaded as individual, uncompressed objects,
# so SageMaker is pointed at the S3 prefix instead of a single compressed archive.
create_model_response = sagemaker_client.create_model(
    ModelName=endpoint_name,
    ExecutionRoleArn=role,
    PrimaryContainer={
        "Image": inference_image_uri,
        "ModelDataSource": {
            "S3DataSource": {               # S3 Data Source configuration:
                "S3Uri": model_uri,         # path to your model and script
                "S3DataType": "S3Prefix",   # causes SageMaker to download from a prefix
                "CompressionType": "None"   # disables compression
            }
        }
    }
)

# GPU instance type hosting the endpoint.
# NOTE(review): a previous inline comment here said "4xlarge is required to load the model",
# which contradicts the ml.g5.2xlarge configured below. Confirm which size is actually needed
# for base + refiner + LoRA before relying on this value.
endpoint_instance_type = "ml.g5.2xlarge"

create_endpoint_config_response = sagemaker_client.create_endpoint_config(
    EndpointConfigName=endpoint_name,
    ProductionVariants=[{
        "ModelName": endpoint_name,
        "VariantName": "sdxl",
        "InitialInstanceCount": 1,
        "InstanceType": endpoint_instance_type,
    }]
)

deploy_model_response = sagemaker_client.create_endpoint(
    EndpointName=endpoint_name,
    EndpointConfigName=endpoint_name
)

# Block until the endpoint_in_service waiter reports the endpoint as in service.
print('Waiting for the endpoint to be in service, this can take 5-10 minutes...')
waiter = sagemaker_client.get_waiter('endpoint_in_service')
waiter.wait(EndpointName=endpoint_name)
print(f'Endpoint {endpoint_name} is in service, but the model is still loading. This may take another 5-10 minutes.')

In [None]:
from sagemaker.deserializers import BytesDeserializer
from sagemaker.predictor import Predictor
from sagemaker.serializers import JSONSerializer

# Create a predictor for the endpoint.
# - Serializer: left at the Predictor default, which passes the request through unchanged. The
#   test cells already build their bodies with json.dumps, so a JSONSerializer would encode them
#   a second time.
# - Deserializer: BytesDeserializer (also the Predictor default, made explicit here) returns the
#   raw response bytes. The test cells parse them with json.loads, so an image/png deserializer
#   would not fit this endpoint's responses.
deployed_model = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
    deserializer=BytesDeserializer(),
)

---
# Test Run

In [None]:
import base64
import io
import json

from diffusers.utils import load_image
from PIL import Image


def image_to_base64(image: Image.Image) -> str:
    """Serialize a PIL image to PNG and return it as a base64-encoded text string.

    Args:
        image: the image to encode.

    Returns:
        The base64 encoding of the PNG bytes, decoded to ``str`` so it can be embedded in JSON.
    """
    with io.BytesIO() as png_bytes:
        image.save(png_bytes, format="PNG")
        encoded = base64.b64encode(png_bytes.getvalue())
    return encoded.decode("utf-8")


def base64_to_image(base64_image: str) -> Image.Image:
    """Decode a base64 text string back into a PIL image.

    Args:
        base64_image: base64-encoded image bytes as a ``str`` (e.g. from a JSON response).

    Returns:
        The image opened from the decoded bytes.
    """
    decoded = base64.decodebytes(bytes(base64_image, "utf-8"))
    image_stream = io.BytesIO(decoded)
    return Image.open(image_stream)

In [None]:
# txt2img: generate an image from the text prompt alone and save it to out1.png
payload = {"text_prompts": [{"text": "a sks dog sitting on a bench"}]}
body = json.dumps(payload)

response_body = json.loads(deployed_model.predict(body))
artifacts = response_body["artifacts"]

image = base64_to_image(artifacts)
image.save("out1.png")

In [None]:
# img2img: send a starting picture together with the prompt and save the result to out2.png
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
init_image = load_image(url).convert("RGB")

payload = {
    "text_prompts": [{"text": "a sks dog sitting on a bench"}],
    "init_image": image_to_base64(init_image),
    # "image_strength": 0.8,  # not sent by default; uncomment to include it in the request
}
body = json.dumps(payload)

response_body = json.loads(deployed_model.predict(body))
artifacts = response_body["artifacts"]

image = base64_to_image(artifacts)
image.save("out2.png")

In [None]:
# inpaint: send a source image plus a mask image with the prompt and save the result to out3.png
examples_base = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples"
img_url = f"{examples_base}/overture-creations-5sI6fQgYIuo.png"
mask_url = f"{examples_base}/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url).convert("RGB")
mask_image = load_image(mask_url).convert("RGB")

payload = {
    "text_prompts": [{"text": "cat"}],
    "init_image": image_to_base64(init_image),
    "mask_image": image_to_base64(mask_image),
    # "image_strength": 0.8,  # not sent by default; uncomment to include it in the request
}
body = json.dumps(payload)

response_body = json.loads(deployed_model.predict(body))
artifacts = response_body["artifacts"]

image = base64_to_image(artifacts)
image.save("out3.png")