## Set Up Hosting Container

For production workloads, we recommend building a custom MME (multi-model endpoint) container with the Stable Diffusion base model and custom packages pre-installed. This has two advantages over the alternative, which extends the container on the fly with the setup model:

1) Multi-instance scalability: as of today, there is no control over the placement of the setup model behind MME endpoints. Therefore, when you scale to multiple instances, it is not possible to guarantee that the base Stable Diffusion model and conda environment are preloaded on each instance. A custom container preloads all shared components and ensures they are available on every instance.

2) Improved cold start: when an MME model is invoked for the first time, each model installs the conda environment, which adds unnecessary overhead. With a custom container, the packages are installed directly into the container image. This shaves 10-20 s off a model's cold start and removes the redundancy of installing the same conda packages for each model.

This notebook was tested on an `ml.g4dn.2xlarge` SageMaker notebook instance using the `conda_pytorch_p310` kernel. **DO NOT use SageMaker Studio.**

In [None]:
!pip install -Uq nvidia-pyindex 
!pip install -Uq "tritonclient[http]"
!pip install -Uq sagemaker ipywidgets pillow numpy 
!pip install -Uq transformers==4.26
# For LCM, use a newer version of diffusers instead:
# !pip install -Uq diffusers==0.25.0
!pip install -Uq diffusers==0.21.4
!pip install -Uq accelerate==0.22.0

In [None]:
import boto3
import sagemaker
from sagemaker import get_execution_role

import tritonclient.http as httpclient
from tritonclient.utils import *
import time
from PIL import Image
import numpy as np

# variables
s3_client = boto3.client("s3")

# sagemaker variables
role = get_execution_role()
sm_client = boto3.client(service_name="sagemaker")
runtime_sm_client = boto3.client("sagemaker-runtime")
sagemaker_session = sagemaker.Session(boto_session=boto3.Session())
region = sagemaker_session.boto_region_name
account = sagemaker_session.account_id()
bucket = sagemaker_session.default_bucket()

prefix = "stable-diffusion-dreambooth"

### Import and Save Stable Diffusion Model

Uncomment the code below to use LCM (Latent Consistency Model), a distilled version of SD that enables fast inference. Read more [here](https://huggingface.co/docs/diffusers/main/en/using-diffusers/inference_with_lcm_lora).

In [None]:
from diffusers import (
    DiffusionPipeline, 
#     UNet2DConditionModel, 
#     LCMScheduler
)

import torch 

model_name_base = "stabilityai/stable-diffusion-xl-base-1.0"
# lcm_unet_id = "latent-consistency/lcm-sdxl"

# unet = UNet2DConditionModel.from_pretrained(
#     lcm_unet_id,
#     torch_dtype=torch.float16,
#     variant="fp16",
# )
    
pipe = DiffusionPipeline.from_pretrained(
    model_name_base,
#     unet=unet,
    torch_dtype=torch.float16,
)
# pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

In [None]:
sd_dir = 'stable_diff'
pipe.save_pretrained(sd_dir)

In [None]:
import os
import tarfile

sd_tar = f"docker/{sd_dir}.tar.gz"

def make_tarfile(output_filename, source_dir):
    with tarfile.open(output_filename, "w:gz") as tar:
        tar.add(source_dir, arcname=os.path.basename(source_dir))

make_tarfile(sd_tar, sd_dir)
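
Because `make_tarfile` passes `arcname=os.path.basename(source_dir)`, the archive should unpack into a single top-level folder named after the model directory. The sketch below checks this on a small throwaway directory instead of the real model; `top_level_entries` and the demo file names are hypothetical helpers introduced only for illustration.

```python
import os
import tarfile
import tempfile


def make_tarfile(output_filename, source_dir):
    # Same helper as in the cell above: archive source_dir under its base name.
    with tarfile.open(output_filename, "w:gz") as tar:
        tar.add(source_dir, arcname=os.path.basename(source_dir))


def top_level_entries(tar_path):
    # Hypothetical helper: collect the first path component of every member.
    with tarfile.open(tar_path, "r:gz") as tar:
        return {member.name.split("/")[0] for member in tar.getmembers()}


with tempfile.TemporaryDirectory() as tmp:
    # Build a tiny stand-in for the saved pipeline directory.
    demo_dir = os.path.join(tmp, "stable_diff")
    os.makedirs(os.path.join(demo_dir, "unet"))
    with open(os.path.join(demo_dir, "model_index.json"), "w") as f:
        f.write("{}")

    demo_tar = os.path.join(tmp, "stable_diff.tar.gz")
    make_tarfile(demo_tar, demo_dir)
    entries = top_level_entries(demo_tar)

print(entries)
```

Running the same check against `sd_tar` is one way to confirm the archive layout before the Docker build consumes it.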

### Extend SageMaker Managed Triton Container

In [None]:
# account mapping for SageMaker Triton Image
account_id_map = {
    "us-east-1": "785573368785",
    "us-east-2": "007439368137",
    "us-west-1": "710691900526",
    "us-west-2": "301217895009",
    "eu-west-1": "802834080501",
    "eu-west-2": "205493899709",
    "eu-west-3": "254080097072",
    "eu-north-1": "601324751636",
    "eu-south-1": "966458181534",
    "eu-central-1": "746233611703",
    "ap-east-1": "110948597952",
    "ap-south-1": "763008648453",
    "ap-northeast-1": "941853720454",
    "ap-northeast-2": "151534178276",
    "ap-southeast-1": "324986816169",
    "ap-southeast-2": "355873309152",
    "cn-northwest-1": "474822919863",
    "cn-north-1": "472730292857",
    "sa-east-1": "756306329178",
    "ca-central-1": "464438896020",
    "me-south-1": "836785723513",
    "af-south-1": "774647643957",
}



region = boto3.Session().region_name
if region not in account_id_map:
    # Raising a plain string is itself a TypeError; raise a real exception instead.
    raise ValueError(f"Unsupported region: {region}")

base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
mme_triton_image_uri = (
    "{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:22.12-py3".format(
        account_id=account_id_map[region], region=region, base=base
    )
)
triton_account_id = account_id_map[region]
mme_triton_image_uri
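
The URI construction above can also be wrapped in a small function so that it is easy to check for different regions. This is a minimal sketch rather than part of the SageMaker SDK: `triton_image_uri` is a hypothetical helper, and the two account IDs are copied from the mapping in the cell above.

```python
def triton_image_uri(region, account_id, tag="22.12-py3"):
    # China regions use the amazonaws.com.cn domain suffix.
    base = "amazonaws.com.cn" if region.startswith("cn-") else "amazonaws.com"
    return f"{account_id}.dkr.ecr.{region}.{base}/sagemaker-tritonserver:{tag}"


# Account IDs taken from account_id_map in the cell above.
uri_us = triton_image_uri("us-east-1", "785573368785")
uri_cn = triton_image_uri("cn-north-1", "472730292857")
print(uri_us)
print(uri_cn)
```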

Preview the Dockerfile

In [None]:
!cat docker/Dockerfile

Create new container

In [None]:
# Change this var to change the name of new container image
new_image_name = f"sagemaker-tritonserver-{prefix}-prod"

In [None]:
%%capture build_output
!cd docker && bash build_and_push.sh "$new_image_name" "latest" "$mme_triton_image_uri" "$region" "$account" "$triton_account_id"

In [None]:
print(build_output)
if 'Error response from daemon' in str(build_output):    
    raise SystemExit('\n\n!!There was an error with the container build!!')
else:
    extended_triton_image_uri = str(build_output).strip().split('\n')[-1]
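
The cell above relies on the build script printing the image URI as its last output line. A slightly more defensive version of that parsing is sketched below; `parse_build_output` is a hypothetical helper, and the sample text and URI are made-up placeholders, not real build output.

```python
def parse_build_output(output_text, error_marker="Error response from daemon"):
    # Fail loudly if Docker reported an error anywhere in the output.
    if error_marker in output_text:
        raise RuntimeError("There was an error with the container build")
    # Ignore blank lines so trailing newlines do not hide the last real line.
    lines = [line.strip() for line in output_text.splitlines() if line.strip()]
    if not lines:
        raise RuntimeError("The build produced no output")
    # Assumption: the last non-empty line is the pushed image URI.
    return lines[-1]


# Placeholder output for illustration only.
sample_output = (
    "Step 1/5 : FROM base\n"
    "Successfully built\n"
    "123456789012.dkr.ecr.us-east-1.amazonaws.com/demo-image:latest\n\n"
)
parsed_uri = parse_build_output(sample_output)
print(parsed_uri)
```

In the notebook, this would be called as `parse_build_output(str(build_output))`.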

Store the new container image URI from ECR

In [None]:
%store extended_triton_image_uri
extended_triton_image_uri