# Stable Diffusion with Distributed Training and Hosting on Amazon SageMaker

In this notebook, you will learn how you can fine-tune a pretrained [Stable Diffusion](https://stability.ai/blog/stable-diffusion-public-release) model on SageMaker and deploy it for inference.

Stable Diffusion, produced by Stability AI, is an open-source model available to researchers and the broader ML community. We point to the core content on Hugging Face [here](https://huggingface.co/CompVis/stable-diffusion-v1-4) and provide private access in the limited context of hands-on workshops. If you'd like longer-term access to Stable Diffusion, you'll need to sign up on the Hugging Face Hub, accept the terms, create a token, and download the model and dataset.

In this lab, we've done that for you already. So let's get started!


This notebook is purely educational and shows how to fine-tune latent Stable Diffusion on Amazon SageMaker. Neither the images produced nor the code represents Amazon or its views in any way, shape, or form. To use this codebase properly, read the corresponding licenses for the model from [CompVis](https://huggingface.co/spaces/CompVis/stable-diffusion-license) and for the Pokemon captions dataset.

Model weights were provided by CompVis/stable-diffusion-v1-4. You can find the licensing, README, and more [here](https://huggingface.co/CompVis/stable-diffusion-v1-4). Please note that the finetune.py script has been slightly modified from a pull request [here](https://github.com/huggingface/diffusers/pull/356).

In [None]:
%pip install -U sagemaker

In [None]:
!pip install transformers==4.21.3 datasets==2.5.2

### Step 1. Inspect the Dataset

Let's take a look at the dataset we will use to fine-tune the Stable Diffusion model. For this exercise, we will use the Pokemon BLIP captions dataset.

In [None]:
from datasets import load_dataset

dataset = load_dataset("lambdalabs/pokemon-blip-captions")

#### Let's plot some sample images and captions

In [None]:
import matplotlib.pyplot as plt

%matplotlib inline

def plot_image(text, image):
    plt.figure()
    plt.title(text)
    plt.imshow(image)

for i in range(5):  
    plot_image(dataset['train'][i]['text'],dataset['train'][i]['image'])

### Step 2. Run distributed training on Amazon SageMaker
Next, let's configure the scripts to run as SageMaker training jobs on high-performance GPUs. First, you'll need to determine which instances to use. We suggest you start with the `ml.g5.12xlarge`, which has 4 GPUs and is known to work well with this training script and dataset.

The training script we're working with today uses Hugging Face's [`accelerate`](https://huggingface.co/docs/accelerate/index) library to run data-parallel training on all available GPUs. While likely not as performant on AWS as [SageMaker Distributed Data Parallel](https://docs.aws.amazon.com/sagemaker/latest/dg/data-parallel.html), it's still an easy and efficient way to run data-parallel training on SageMaker Training.
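With data parallelism, each GPU processes its own mini-batch, so the number of samples consumed per optimizer step grows with the GPU count. The sketch below works through that arithmetic. It assumes the batch size passed to the script is a per-device value, as is conventional for `accelerate`-based scripts; the function name `effective_batch_size` and the `grad_accum_steps` argument are illustrative and not taken from the training script. The per-device batch of 2 matches the `train_batch_size` hyperparameter used in this notebook.

```python
def effective_batch_size(per_device_batch, gpus_per_instance,
                         instance_count=1, grad_accum_steps=1):
    """Samples consumed per optimizer step under data parallelism.

    Assumption: the batch size given to the script is per device.
    """
    return per_device_batch * gpus_per_instance * instance_count * grad_accum_steps


# One ml.g5.12xlarge (4 GPUs) with a per-device batch size of 2
print(effective_batch_size(per_device_batch=2, gpus_per_instance=4))  # 8
```

If that assumption holds, moving to a different instance type changes the effective batch size, which may call for revisiting the learning rate.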

In [None]:
instance_type = 'ml.g5.12xlarge'

#### Point to an AWS-managed Deep Learning Container
At AWS, we provide 70+ prebuilt containers that are battle-tested and known to run efficiently across SageMaker instances and features.

Available images are listed here: https://github.com/aws/deep-learning-containers/blob/master/available_images.md 

You're welcome to bring your own Dockerfile, and either extend an AWS Deep Learning Container, or simply add the [sagemaker-training toolkit](https://github.com/aws/sagemaker-training-toolkit) to enable remote training job features like script-mode, local mode, distributed training, etc.

In [None]:
import sagemaker
from sagemaker.huggingface import HuggingFace
from sagemaker.pytorch import PyTorch

def get_estimator(instance_type):
    
    sess = sagemaker.Session()

    role = sagemaker.get_execution_role()

    est = PyTorch(entry_point='train.py',
                      source_dir='scripts',
                      framework_version="1.13.1",
                      sagemaker_session=sess,
                      role=role,
                      instance_type=instance_type,
                      keep_alive_time_in_seconds = 3600,
                      # output_path = can define s3 output here,
                      py_version='py39',
                      base_job_name='stable-diffusion', 
                      instance_count=1,
                      checkpoint_local_path="",
                      # all opt/ml paths point to SageMaker training 
                      hyperparameters={
                        'pretrained_model_name_or_path':'CompVis/stable-diffusion-v1-4',
                        'dataset_name':'lambdalabs/pokemon-blip-captions',
                        'caption_column':'text',
                        'image_column':'image',
                        'resolution':256,
                        'mixed_precision':'fp16',
                        'train_batch_size':2,
                        'learning_rate': '1e-10',
                        'max_train_steps':100,
                        'num_train_epochs':1,
                        'seed':100,
                        'output_dir':'/opt/ml/model/sd-output-final',   
                      },
                      distribution={"pytorchddp":{"enabled": True}},
                   
                )
    return est

est = get_estimator(instance_type)
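In script mode, SageMaker hands the `hyperparameters` dictionary to the entry point as `--name value` command-line arguments, so the training script typically reads them with `argparse`. The sketch below illustrates the receiving side. The actual parser in `train.py` is not shown in this notebook, so the subset of arguments and the default values here are assumptions.

```python
import argparse


def parse_hyperparameters(argv):
    """Parse a hypothetical subset of the hyperparameters defined above."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_name_or_path", type=str)
    parser.add_argument("--resolution", type=int, default=512)
    parser.add_argument("--train_batch_size", type=int, default=1)
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--max_train_steps", type=int, default=None)
    parser.add_argument("--output_dir", type=str, default="/opt/ml/model")
    # Ignore any arguments this sketch does not declare
    args, _unknown = parser.parse_known_args(argv)
    return args


# Roughly what the entry point would receive for the estimator above
example_argv = [
    "--pretrained_model_name_or_path", "CompVis/stable-diffusion-v1-4",
    "--resolution", "256",
    "--train_batch_size", "2",
    "--learning_rate", "1e-10",
    "--max_train_steps", "100",
    "--output_dir", "/opt/ml/model/sd-output-final",
]
args = parse_hyperparameters(example_argv)
print(args.resolution, args.learning_rate, args.output_dir)
```

Values arrive as strings, which is why each argument declares a `type`; a learning rate passed as `'1e-10'` is converted to a float by the parser.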

#### Start the training job

In [None]:
# Please note training can take upwards of 25 minutes (13 minutes for saving the model). 
# only run this cell ONCE!
est.fit(wait=True)

### Step 3. Distributed Inference
Next, we'll point to the model we just trained in the previous step and use it to spin up a SageMaker endpoint.

In [None]:
# define from the S3 path if you need to manually point to your model artifact
# SageMaker hosting will want to see the model artifact be wrapped in tar.gz format
#model_data = ''
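If you do need to package a model directory yourself, the standard-library `tarfile` module can produce the `tar.gz` archive. The following is a minimal sketch, not the workshop's packaging code: the helper name `pack_model_dir` and the placeholder file are hypothetical, and it assumes the model files should sit at the root of the archive rather than under a parent folder.

```python
import tarfile
import tempfile
from pathlib import Path


def pack_model_dir(model_dir, archive_path="model.tar.gz"):
    """Write the contents of model_dir to a gzip-compressed tar archive."""
    model_dir = Path(model_dir)
    with tarfile.open(archive_path, "w:gz") as tar:
        for item in sorted(model_dir.iterdir()):
            # arcname drops the parent directory so files land at the archive root
            tar.add(item, arcname=item.name)
    return Path(archive_path)


# Demonstration with a throwaway directory and a placeholder file
with tempfile.TemporaryDirectory() as tmp:
    model_dir = Path(tmp) / "sd-output-final"
    model_dir.mkdir()
    (model_dir / "model_index.json").write_text("{}")
    archive = pack_model_dir(model_dir, Path(tmp) / "model.tar.gz")
    with tarfile.open(archive) as tar:
        archived_names = tar.getnames()

print(archived_names)
```

The resulting archive would then be uploaded to S3, and its S3 URI assigned to `model_data`.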

In [None]:
from sagemaker.huggingface import HuggingFaceModel
import sagemaker

role = sagemaker.get_execution_role()

sess = sagemaker.Session()

# hard code point to an image we're hosting for this workshop
image_uri = '911195073761.dkr.ecr.us-east-1.amazonaws.com/sd-inference-gpu:latest'

est=HuggingFaceModel(role=role,
                     py_version='py38',
                     model_data=est.model_data,
                     image_uri=image_uri,
                     sagemaker_session= sess,
                     # set this to the number of GPUs in the instance type you'd like to use
                     model_server_workers= 1
)

Deploy your model for inference!

In [None]:
pred = est.deploy(instance_type='ml.g5.2xlarge',
                  initial_instance_count=1)

Provide prompts for inference. The first prompt is modeled on the captions in the fine-tuning dataset.

In [None]:
prompts = ['A drawing of a green pokemon with red eyes.', 
           'A pokemon wearing karate belt in the snow.', 
           'pokemon swimming at the Olympics 400m Butter-fly event.',
           'A pokemon is kicking a soccer ball.']

For more parameters, feel free to explore the documentation [here](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion). To use one, add `'parameters': {'key': 'value'}` to the input dict.
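As a small sketch of that payload shape, the hypothetical helper below builds the input dict and attaches a `parameters` entry only when extra options are given. `num_inference_steps` and `guidance_scale` are arguments of the diffusers Stable Diffusion pipeline; whether the workshop's inference container forwards them unchanged is an assumption.

```python
def build_payload(prompt, **parameters):
    """Build a request dict of the form {'inputs': ..., 'parameters': {...}}."""
    payload = {"inputs": prompt}
    if parameters:
        payload["parameters"] = dict(parameters)
    return payload


example = build_payload(
    "A pokemon is kicking a soccer ball.",
    num_inference_steps=30,  # assumed to be passed through to the pipeline
    guidance_scale=7.5,      # assumed to be passed through to the pipeline
)
print(example)
```

The resulting dict could be passed to `pred.predict(...)` in place of the plain `{'inputs': prompt}` payload.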

In [None]:
outputs = [pred.predict({'inputs':prompt}) \
           for prompt in prompts]

In [None]:
outputs = [output['images'][0] for output in outputs]

In [None]:
from PIL import Image
from io import BytesIO
import base64

def process_result(out):
    return Image.open(BytesIO(base64.b64decode(out)))

In [None]:
images = [[process_result(output),prompt] for output,prompt in zip(outputs,prompts)]

#### Visualize the results from the inference

In [None]:
import matplotlib.pyplot as plt

for i in range(len(images)):
    plt.figure()
    plt.title(images[i][1])
    plt.imshow(images[i][0])

#### Generate images from text
Now let's test the endpoint one prompt at a time!

In [None]:
prompt = "a beautiful hot arabian desert"

output = pred.predict({'inputs':prompt})
process_result(output['images'][0])

In [None]:
prompt = "a delicious arabian dessert"

output = pred.predict({'inputs':prompt})
process_result(output['images'][0])

In [None]:
# clean up your endpoint
#pred.delete_endpoint()