# Finetuning Stable Diffusion 2.1

***
Stable Diffusion is a text-to-image model that creates photorealistic images from just a text prompt. A diffusion model is trained to remove noise that was added to real images. To generate an image, it starts from pure noise and repeatedly applies this learned de-noising until a realistic image emerges. Conditioning each de-noising step on a text prompt lets the model generate images from text alone. Stable Diffusion is a latent diffusion model: it runs this process in a compressed latent space, learning to recognize shapes in a pure noise image and gradually bringing them into focus when they match the words in the input text.
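
To make the noising idea concrete, here is a minimal sketch of the standard DDPM forward process in plain Python. It assumes a linear noise schedule from 1e-4 to 0.02 over 1,000 steps (the values from the original DDPM paper) and treats a short list of floats as the "image". Stable Diffusion applies the same idea to a latent representation and uses its own schedule, so treat this as an illustration rather than its implementation. The de-noising network is trained to predict `eps` from the noisy sample and the step index.

```python
import math
import random


def linear_beta_schedule(num_steps, beta_start=1e-4, beta_end=0.02):
    """Per-step noise variances beta_t, increasing linearly (illustrative DDPM values)."""
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t = prod over s <= t of (1 - beta_s): the share of clean signal left at step t."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def add_noise(x0, t, alpha_bars, rng):
    """Jump straight to step t: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    a = alpha_bars[t]
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = [math.sqrt(a) * x + math.sqrt(1.0 - a) * e for x, e in zip(x0, eps)]
    return x_t, eps


def mse(a, b):
    """Training loss: mean squared error between the true and the predicted noise."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


rng = random.Random(0)
alpha_bars = cumulative_alphas(linear_beta_schedule(1000))
x0 = [0.5, -0.2, 0.9, 0.0]                          # a toy four-value "image"
x_early, _ = add_noise(x0, 0, alpha_bars, rng)      # barely perturbed
x_late, eps = add_noise(x0, 999, alpha_bars, rng)   # almost pure noise
# A perfect noise predictor scores 0; always predicting zeros scores about the noise variance.
print(mse(eps, eps), mse(eps, [0.0] * len(eps)))
```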

Training, deploying, and running inference on large models such as Stable Diffusion is often challenging, with issues such as CUDA out-of-memory errors and exceeded payload size limits. JumpStart simplifies this process by providing ready-to-use, robustly tested scripts. It also provides guidance on each step of the process, including recommended instance types, how to select parameters that guide image generation, and prompt engineering. Moreover, you can deploy and run inference on any of the 80+ diffusion models in JumpStart without writing any code of your own.

In this notebook, you will learn how to use JumpStart to fine-tune the Stable Diffusion model on your own dataset. This is useful for creating art, logos, custom designs, NFTs, and so on, or for fun projects such as generating custom AI images of your pets or avatars of yourself.


Model license: By using this model, you agree to the [CreativeML Open RAIL++-M license](https://huggingface.co/stabilityai/stable-diffusion-2/blob/main/LICENSE-MODEL).


Note: This notebook was tested on an ml.t3.medium instance in Amazon SageMaker Studio with the Python 3 (Data Science) kernel, and on an Amazon SageMaker Notebook instance with the conda_python3 kernel.


## 1. Set Up

***

Before executing the notebook, complete a few initial setup steps. This notebook requires ipywidgets and the latest version of the SageMaker Python SDK (`sagemaker`).


***

In [None]:
!pip install ipywidgets==7.0.0 --quiet
!pip install --upgrade sagemaker

#### Permissions and environment variables

***
To host on Amazon SageMaker, we need to set up and authenticate the use of AWS services. Here, we use the execution role associated with the current notebook as the AWS account role with SageMaker access.

***

In [None]:
import sagemaker, boto3, json
from sagemaker import get_execution_role

aws_role = get_execution_role()
aws_region = boto3.Session().region_name
sagemaker_session = sagemaker.Session()
default_bucket = sagemaker_session.default_bucket()

---
## 2. Prepare the training images
In this section, we prepare an existing set of 18 images of a **cream longhair dachshund** called Peanut, which we will use to fine-tune the Stable Diffusion 2.1 model. The fine-tuned model can then generate new images of Peanut.

The model can be fine-tuned on any image dataset, and it works well with as few as five training images.

The fine-tuning script is built on the script from [DreamBooth](https://dreambooth.github.io/). The model returned by fine-tuning can then be deployed for inference. The training data must be formatted as follows.

- **Input:** A directory containing the instance images, `dataset_info.json`, and an optional `class_data_dir` directory.
  - Images may be in `.png`, `.jpg`, or `.jpeg` format.
  - `dataset_info.json` must have the format `{"instance_prompt": <instance_prompt>, "class_prompt": <class_prompt>}`.
  - If `with_prior_preservation = False`, you may omit `class_prompt`.
  - `class_data_dir` must contain class images. If `with_prior_preservation = True` and `class_data_dir` is missing or does not contain enough images, additional images will be sampled using `class_prompt`.
- **Output:** A trained model that can be deployed for inference.

The S3 path should look like `s3://bucket_name/input_directory/`. Note that the trailing `/` is required.
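
Because the trailing `/` is easy to drop when building the path with string concatenation, a small helper can normalize it. This is a hypothetical convenience function, not part of the SageMaker SDK, and `my-bucket` is a placeholder bucket name.

```python
def s3_prefix_uri(bucket: str, prefix: str) -> str:
    """Build an s3://bucket/prefix/ URI with exactly one trailing slash (hypothetical helper)."""
    if not bucket or "/" in bucket:
        raise ValueError(f"invalid bucket name: {bucket!r}")
    cleaned = prefix.strip("/")  # drop leading/trailing slashes so they are never doubled
    return f"s3://{bucket}/{cleaned}/" if cleaned else f"s3://{bucket}/"


print(s3_prefix_uri("my-bucket", "stable-diffusion/Peanut"))    # s3://my-bucket/stable-diffusion/Peanut/
print(s3_prefix_uri("my-bucket", "/stable-diffusion/Peanut/"))  # s3://my-bucket/stable-diffusion/Peanut/
```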

Here is an example format of the training data.

    input_directory
        |---instance_image_1.png
        |---instance_image_2.png
        |---instance_image_3.png
        |---instance_image_4.png
        |---instance_image_5.png
        |---dataset_info.json
        |---class_data_dir
            |---class_image_1.png
            |---class_image_2.png
            |---class_image_3.png
            |---class_image_4.png
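
The format rules above can be checked locally before anything is uploaded. The sketch below is a hypothetical helper (not part of JumpStart) that reports missing pieces. It accepts a missing `class_data_dir`, because the rules say class images are sampled from `class_prompt` when needed, and it matches file extensions case-sensitively, which is a conservative assumption.

```python
import json
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def check_training_dir(root, with_prior_preservation=False):
    """Return a list of problems in a local training directory; an empty list means it looks valid."""
    root = Path(root)
    problems = []

    # Instance images must sit directly in the input directory.
    images = [p for p in root.iterdir() if p.is_file() and p.suffix in IMAGE_SUFFIXES]
    if not images:
        problems.append("no .png/.jpg/.jpeg instance images found")

    info_path = root / "dataset_info.json"
    if not info_path.is_file():
        problems.append("dataset_info.json is missing")
        return problems

    info = json.loads(info_path.read_text())
    if "instance_prompt" not in info:
        problems.append("dataset_info.json has no 'instance_prompt'")
    # 'class_prompt' may be omitted only when prior preservation is off.
    if with_prior_preservation and "class_prompt" not in info:
        problems.append("'class_prompt' is required when with_prior_preservation is True")
    return problems
```

For example, `check_training_dir("images/Peanut")` should return an empty list once the images are unzipped and `dataset_info.json` has been written.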

**Prior preservation, instance prompt and class prompt:** Prior preservation is a technique that uses additional images of the same class as the subject we are training on. For instance, if the training data consists of images of a particular dog, prior preservation incorporates class images of generic dogs. Showing images of different dogs while training on one particular dog helps avoid overfitting. The instance prompt contains a tag identifying the specific dog, while the class prompt omits it. For instance, the instance prompt may be "a photo of a riobugger cat" and the class prompt "a photo of a cat". You can enable prior preservation by setting the hyperparameter `with_prior_preservation = True`.
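
In DreamBooth, prior preservation adds a second loss term computed on the class images, so fine-tuning on one dog does not overwrite what the model knows about dogs in general. The sketch below writes that combined objective over plain lists of noise values. The name `prior_loss_weight` and its default of 1.0 follow common open-source DreamBooth scripts and are assumptions here; the JumpStart training script may name or weight the terms differently.

```python
def mse(a, b):
    """Mean squared error between two equal-length lists of noise values."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def dreambooth_loss(instance_noise, instance_pred, class_noise=None, class_pred=None,
                    prior_loss_weight=1.0):
    """Instance loss, plus a weighted prior (class) loss when class data is supplied."""
    loss = mse(instance_noise, instance_pred)
    if class_noise is not None:
        loss += prior_loss_weight * mse(class_noise, class_pred)
    return loss


# Only the instance prompt carries the identifier ("riobugger") for the specific subject.
instance_prompt = "a photo of a riobugger cat"
class_prompt = "a photo of a cat"

loss_without_prior = dreambooth_loss([0.1, -0.3], [0.0, -0.1])
loss_with_prior = dreambooth_loss([0.1, -0.3], [0.0, -0.1], [0.5, 0.2], [0.4, 0.4], prior_loss_weight=0.5)
print(loss_without_prior, loss_with_prior)
```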


We provide default datasets of cat and dog images. The cat dataset consists of eight images (instance images corresponding to the instance prompt) of a single cat, with no class images. It can be downloaded from [here](https://github.com/marshmellow77/dreambooth-sm/tree/main/training-images). If you use the cat dataset, try the prompt "a photo of a riobugger cat" when running inference in the demo notebook. The dog dataset consists of 12 images of a single dog, with no class images. If you use the dog dataset, try the prompt "a photo of a Doppler dog" when running inference in the demo notebook.

License: [MIT](https://github.com/marshmellow77/dreambooth-sm/blob/main/LICENSE).

In [None]:
# unzip the image dataset
!unzip images/Peanut.zip -d images

In [None]:
# display some of the training images

from IPython.display import Image
Image(filename='images/Peanut/51276808947_e1036dbbcf_c.jpg') 


In [None]:
Image(filename='images/Peanut/51277531746_a1cdca5453_c.jpg') 

In [None]:
# add a dataset_info.json file
# see the formatting instructions at the beginning of Section 2

import json

dataset_info = {
    'instance_prompt': 'A photo of Peanut',
    'class_prompt': 'A photo of cream longhair dachshund'
}

with open('images/Peanut/dataset_info.json', 'w') as fp:
    json.dump(dataset_info, fp)


In [None]:
# copy all files to S3

import boto3
import glob

ls_images = glob.glob('images/Peanut/*.*')

for i,filename in enumerate(ls_images):
    print('Copying image', i+1, 'out of', len(ls_images), 'to S3...\r', end='')
    sagemaker_session.upload_data(
        path=filename,
        bucket=default_bucket,
        key_prefix='stable-diffusion/Peanut'
    )

s3_uri_data = f's3://{default_bucket}/stable-diffusion/Peanut/'
print('\nFiles copied to:', s3_uri_data)
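
Before uploading, it can help to preview which S3 objects the loop will create. The hypothetical helper below mirrors the key layout the loop relies on (`key_prefix/filename`, which is also what `s3_uri_data` assumes) without touching AWS, so you can confirm that every image and `dataset_info.json` lands directly under the training prefix. `my-bucket` and `a.jpg` are placeholders.

```python
import os


def planned_s3_uris(local_paths, bucket, key_prefix):
    """Return the s3:// URI each local file is expected to land at (key = prefix/basename)."""
    prefix = key_prefix.strip("/")
    return [f"s3://{bucket}/{prefix}/{os.path.basename(p)}" for p in sorted(local_paths)]


for uri in planned_s3_uris(["images/Peanut/dataset_info.json", "images/Peanut/a.jpg"],
                           "my-bucket", "stable-diffusion/Peanut"):
    print(uri)
```

In the notebook you could pass `ls_images`, `default_bucket`, and `'stable-diffusion/Peanut'` instead of the placeholder values.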

---
## 3. Finetuning 


### Set training-job parameters

These are the parameters related to the training job:
- **Training data path**: the S3 folder in which the input data is stored.
- **Output path**: the S3 folder in which the training output is stored.
- **Training instance type**: the type of machine on which to run the training.


In [None]:
s3_path_training_dataset = s3_uri_data  # where the training data is
s3_path_output = f'{s3_path_training_dataset}output/'  # where the finetuned model will be stored
training_instance_type = "ml.p3.2xlarge"

print('Training instance:', training_instance_type)

### Set algorithm training hyperparameters

For algorithm-specific training hyperparameters, we start by fetching a Python dictionary of the hyperparameters that the algorithm accepts, along with their default values. These can then be overridden with custom values.


In [None]:
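import sagemaker.hyperparameters

# The model ID and version are needed here to look up the defaults; the same
# values are set again in "Retrieve training artifacts" below.
train_model_id, train_model_version = (
    "model-txt2img-stabilityai-stable-diffusion-v2-1-base",
    "*",
)

# Retrieve the default hyperparameters for fine-tuning the model. Calling it via
# sagemaker.hyperparameters keeps the dict below from shadowing the module on re-runs.
hyperparameters = sagemaker.hyperparameters.retrieve_default(
    model_id=train_model_id, model_version=train_model_version
)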

# [Optional] Override default hyperparameters with custom values
hyperparameters["max_steps"] = "400"
hyperparameters["epochs"] = "30"
print('Algorithm hyperparameters:\n',hyperparameters)
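
SageMaker passes hyperparameters to the training container as strings, which is why the cell above assigns `"400"` rather than `400`. A small hypothetical wrapper can convert values to strings and reject keys the algorithm did not return, catching typos such as `max_step` before a paid training job starts. The defaults dictionary in the example is made up for illustration.

```python
def override_hyperparameters(defaults, **overrides):
    """Copy `defaults`, applying overrides as strings; unknown keys raise KeyError."""
    unknown = sorted(set(overrides) - set(defaults))
    if unknown:
        raise KeyError(f"unknown hyperparameters: {unknown}")
    merged = dict(defaults)
    merged.update({key: str(value) for key, value in overrides.items()})
    return merged


example_defaults = {"max_steps": "200", "epochs": "20", "learning_rate": "2e-06"}  # hypothetical values
print(override_hyperparameters(example_defaults, max_steps=400, epochs=30))
```

With the notebook's variables, this would be called as `override_hyperparameters(hyperparameters, max_steps=400, epochs=30)`.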

### Choose whether to use Automatic Model Tuning

Amazon SageMaker automatic model tuning, also known as hyperparameter tuning, finds the best version of a model by running many training jobs on your dataset using the algorithm and the hyperparameter ranges that you specify. It then chooses the hyperparameter values that produce the best-performing model, as measured by a metric that you choose. We use a `HyperparameterTuner` object to interact with the Amazon SageMaker hyperparameter tuning APIs. Here we tune two hyperparameters: `learning_rate` and `max_steps`.

Automatic Model Tuning can take significantly longer, but it can produce a better model.

In [None]:
use_amt = False  # use Automatic Model Tuning or not. (True takes longer!)

In [None]:
from sagemaker.tuner import IntegerParameter
from sagemaker.tuner import ContinuousParameter
from sagemaker.tuner import HyperparameterTuner

hyperparameter_ranges = {
    "learning_rate": ContinuousParameter(1e-7, 3e-6, "Linear"),
    "max_steps": IntegerParameter(50, 400, "Linear"),
}
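
Conceptually, the tuner samples values from the ranges defined above, runs a training job for each sample, and keeps the configuration with the lowest `fid_score`. The sketch below replaces the training job with a made-up toy objective and uses plain random search, not the `"Bayesian"` strategy configured in the tuning cell further down. It only illustrates the sample, evaluate, and select loop.

```python
import random


def random_search(ranges, objective, num_trials, seed=0):
    """Sample each hyperparameter uniformly from (low, high, kind) and keep the lowest score."""
    rng = random.Random(seed)
    best_params, best_score = None, float("inf")
    for _ in range(num_trials):
        params = {
            name: rng.randint(low, high) if kind == "int" else rng.uniform(low, high)
            for name, (low, high, kind) in ranges.items()
        }
        score = objective(params)  # stands in for one training job's reported metric
        if score < best_score:
            best_params, best_score = params, score
    return best_params, best_score


ranges = {"learning_rate": (1e-7, 3e-6, "float"), "max_steps": (50, 400, "int")}


def toy_fid(params):
    # Hypothetical stand-in for a training job's fid_score: lower is better.
    return abs(params["learning_rate"] - 1e-6) * 1e6 + abs(params["max_steps"] - 300) / 100


best_params, best_score = random_search(ranges, toy_fid, num_trials=20)
print(best_params, best_score)
```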

### Retrieve training artifacts
We retrieve the training Docker container, the training algorithm source, and the pre-trained base model. Note that `model_version="*"` fetches the latest version of the model.

In [None]:
from sagemaker import image_uris, model_uris, script_uris
import sagemaker.metric_definitions

# Currently, not all Stable Diffusion models in JumpStart support fine-tuning,
# so we manually select one that does.
train_model_id, train_model_version, train_scope = (
    "model-txt2img-stabilityai-stable-diffusion-v2-1-base",
    "*",
    "training",
)

# Retrieve the docker image
train_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    model_id=train_model_id,
    model_version=train_model_version,
    image_scope=train_scope,
    instance_type=training_instance_type
)

# Retrieve the training script. This contains all the necessary files including data processing, model training etc.
train_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope=train_scope
)
# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = model_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, model_scope=train_scope
)
# Retrieve the default metric definitions to emit to CloudWatch Logs
metric_definitions = sagemaker.metric_definitions.retrieve_default(
    model_id=train_model_id, 
    model_version=train_model_version,
)

### Start finetuning

We start by creating the estimator object with all the required assets and then launch the training job. 


In [None]:
from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base
from sagemaker.tuner import HyperparameterTuner

training_job_name = name_from_base(f"jumpstart-example-{train_model_id}-transfer-learning")

# Create SageMaker Estimator instance
sd_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",  # Entry-point file in source_dir and present in train_source_uri.
    instance_count=1,
    instance_type=training_instance_type,
    max_run=360000,
    metric_definitions=metric_definitions,
    hyperparameters=hyperparameters,
    output_path=s3_path_output,
    base_job_name=training_job_name,
)


if use_amt:
    # Let estimator emit fid_score metric to AMT
    sd_estimator.set_hyperparameters(compute_fid="True")
    tuner_parameters = {
        "estimator": sd_estimator,
        "metric_definitions": [{"Name": "fid_score", "Regex": "fid_score=([-+]?\\d\\.?\\d*)"}],
        "objective_metric_name": "fid_score",
        "objective_type": "Minimize",
        "hyperparameter_ranges": hyperparameter_ranges,
        "max_jobs": 3,
        "max_parallel_jobs": 3,
        "strategy": "Bayesian",
        "base_tuning_job_name": training_job_name,
    }

    tuner = HyperparameterTuner(**tuner_parameters)
    tuner.fit({"training": s3_path_training_dataset}, logs=True)
else:
    # Launch a SageMaker Training job by passing s3 path of the training data
    sd_estimator.fit({"training": s3_path_training_dataset}, logs=True)
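
When tuning is enabled, SageMaker reads the objective by applying the `Regex` in `metric_definitions` to the training job's log output and taking the captured group. The sketch below reproduces that extraction with Python's `re` module on a hypothetical log excerpt (the real script's log format may differ), which is a quick way to test a metric regex locally. Comparing `\d+` with a single `\d` before the decimal point shows how a one-digit pattern truncates multi-digit scores.

```python
import re


def extract_metric(log_text, regex):
    """Return every value captured by `regex` in `log_text`, as floats, in log order."""
    return [float(match.group(1)) for match in re.finditer(regex, log_text)]


# Hypothetical log excerpt for illustration only.
log_text = "step 100 loss=0.21\nfid_score=47.83\nstep 200 loss=0.18\nfid_score=41.06\n"

print(extract_metric(log_text, r"fid_score=([-+]?\d+\.?\d*)"))  # [47.83, 41.06]
print(extract_metric(log_text, r"fid_score=([-+]?\d\.?\d*)"))   # [47.0, 41.0]
```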

---
## 4. Deploy and run inference on the fine-tuned model

A trained model does nothing on its own. We now use the fine-tuned model to perform inference, which for this example means generating new images of Peanut from text prompts. We start by retrieving the JumpStart artifacts for deploying an endpoint, and then deploy the `sd_estimator` that we fine-tuned (or the best model found by the `tuner`, if Automatic Model Tuning was enabled).


In [None]:
from sagemaker import instance_types

inference_instance_type = "ml.p3.2xlarge"

# Retrieve the inference docker container uri
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=train_model_id,
    model_version=train_model_version,
    instance_type=inference_instance_type,
)

endpoint_name = name_from_base(f"jumpstart-example-FT-{train_model_id}-")

# Use the estimator from the previous step to deploy to a SageMaker endpoint
finetuned_predictor = (tuner if use_amt else sd_estimator).deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    image_uri=deploy_image_uri,
    endpoint_name=endpoint_name,
)

In [None]:
# helper functions in order to visualize model outputs

import matplotlib.pyplot as plt
import numpy as np


def query(model_predictor, text):
    """Query the model predictor."""

    encoded_text = text.encode("utf-8")

    query_response = model_predictor.predict(
        encoded_text,
        {
            "ContentType": "application/x-text",
            "Accept": "application/json",
        },
    )
    return query_response


def parse_response(query_response):
    """Parse response and return generated image and the prompt"""

    response_dict = json.loads(query_response)
    return response_dict["generated_image"], response_dict["prompt"]


def display_img_and_prompt(img, prmpt):
    """Display hallucinated image."""
    plt.figure(figsize=(12, 12))
    plt.imshow(np.array(img))
    plt.axis("off")
    plt.title(prmpt)
    plt.show()
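
The endpoint returns `generated_image` as nested JSON lists, which `display_img_and_prompt` converts with `np.array`. Assuming those lists are height x width x RGB pixel values, the hypothetical helper below checks that structure with plain Python. It runs on a made-up 2x2 response so it can be tried without a live endpoint.

```python
import json


def image_shape(generated_image):
    """Return (height, width, channels) of a nested-list image, checking that it is rectangular."""
    height = len(generated_image)
    width = len(generated_image[0])
    channels = len(generated_image[0][0])
    for row in generated_image:
        if len(row) != width or any(len(px) != channels for px in row):
            raise ValueError("ragged image: rows or pixels have inconsistent lengths")
    return height, width, channels


# Made-up 2x2 RGB response body in the shape parse_response expects (real ones are far larger).
fake_response = json.dumps({
    "generated_image": [[[255, 0, 0], [0, 255, 0]], [[0, 0, 255], [255, 255, 255]]],
    "prompt": "Peanut wearing a yellow hat",
}).encode("utf-8")

response_dict = json.loads(fake_response)
print(image_shape(response_dict["generated_image"]), response_dict["prompt"])
```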

Now that the fine-tuned model knows who Peanut is, we will try different prompts to create different images of Peanut. Use your imagination and write your own prompts.

**IMPORTANT**: please try each prompt multiple times to see the image variations that the model creates.

In [None]:
text = "Peanut wearing a yellow hat"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

In [None]:
text = "a painting portrait of Peanut"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

In [None]:
text = "Peanut as a pixar cartoon character"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

In [None]:
text = "Peanut riding a motorbike"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

In [None]:
text = "Peanut on the beach"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

In [None]:
text = "Peanut chasing a red ball"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)
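
Rather than re-running a cell to see variations, other JumpStart Stable Diffusion examples send a JSON payload that requests several images at once and fixes a seed for reproducibility. The field names below (`prompt`, `num_images_per_prompt`, `num_inference_steps`, `guidance_scale`, `seed`) are an assumption based on those examples, and the cap of four images is an arbitrary choice for this sketch; confirm both against the model version you deployed. The hypothetical helper only builds the request body and does not call the endpoint.

```python
import json


def build_payload(prompt, num_images=2, num_inference_steps=50, guidance_scale=7.5, seed=None):
    """Return a UTF-8 JSON request body; field names are assumed, not confirmed for this model."""
    if not prompt.strip():
        raise ValueError("prompt must not be empty")
    if not 1 <= num_images <= 4:  # arbitrary cap to keep endpoint memory use modest
        raise ValueError("num_images must be between 1 and 4")
    payload = {
        "prompt": prompt,
        "num_images_per_prompt": num_images,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
    }
    if seed is not None:
        payload["seed"] = seed  # fixing the seed is meant to make results repeatable
    return json.dumps(payload).encode("utf-8")


body = build_payload("Peanut on the beach", num_images=3, seed=1)
print(json.loads(body))
```

Under the same assumption, the body could be sent with `finetuned_predictor.predict(body, {"ContentType": "application/json", "Accept": "application/json"})`, with the images read from a `generated_images` field instead of `generated_image`.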

Finally, we delete the fine-tuned model and its endpoint so that the endpoint does not keep incurring charges.

In [None]:
# Delete the SageMaker endpoint
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()