# 스테이블 디퓨전 파인튜닝

## 1. 의존성 설치

In [1]:
!pip -q install wandb datasets torchvision

## 2. 환경설정
- 의존성 임포트 및 AWS 역할/세션 설정
- 데이터 폴더 및 검증(validation)용 프롬프트 파일 생성 (dataset_info.json)

In [2]:
# for common
import os
import json
from pathlib import Path

import wandb
from datasets import load_dataset, Image
import boto3, botocore
import sagemaker
from sagemaker import get_execution_role

# for resizing
from PIL import Image
from torchvision import transforms

# for training
from sagemaker import image_uris, model_uris, script_uris

from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base
from sagemaker.parameter import ContinuousParameter, IntegerParameter
from sagemaker.tuner import HyperparameterTuner

# for inference
import matplotlib.pyplot as plt
import numpy as np
from sagemaker.predictor import Predictor

In [3]:
try:
    # Inside SageMaker (Studio / notebook instance) the attached execution role is returned.
    aws_role = sagemaker.get_execution_role()
except ValueError:
    # Outside SageMaker the SDK raises ValueError because the caller identity is not a role
    # (e.g. an IAM user), so look the role up by name instead. Only ValueError is caught so
    # unrelated failures (missing credentials, KeyboardInterrupt, ...) still surface.
    iam = boto3.client("iam")
    # TODO: set SAGEMAKER_ROLE_NAME or replace the placeholder with your role name
    # (e.g. "AmazonSageMaker-ExecutionRole-20211014T154824").
    role_name = os.environ.get("SAGEMAKER_ROLE_NAME", "<replace with your RoleName>")
    aws_role = iam.get_role(RoleName=role_name)["Role"]["Arn"]

boto_session = boto3.Session()
aws_region = boto_session.region_name
sess = sagemaker.Session(boto_session=boto_session)

print(aws_role)
print(aws_region)
print(sess.boto_region_name)

Couldn't call 'get_role' to get Role ARN from role name dongkyl to get Role path.


ClientError: An error occurred (ValidationError) when calling the GetRole operation: The specified value for roleName is invalid. It must contain only alphanumeric characters and/or the following: +=,.@_-

## 3. 데이터셋 확인

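- 로컬에 준비된 이미지를 사용하거나 (`use_local_images = True`; `text_to_image_training_images` 폴더에 이미지와 `file_name`, `text` 필드를 가진 `metadata.jsonl` 을 직접 준비),
- [haandol/icon](https://huggingface.co/datasets/haandol/icon) 데이터셋을 사용 (기본값).
  - [Unsplash](https://unsplash.com) 에서 다운받은 라이선스 프리 이미지로 구성된 데이터셋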


In [79]:
# Local staging folder: the images, metadata.jsonl and dataset_info.json written by the
# following cells end up here before the folder is uploaded to S3.
local_training_dataset_folder = Path("text_to_image_training_images")
# exist_ok=True makes the cell safe to re-run without a separate existence check.
local_training_dataset_folder.mkdir(parents=True, exist_ok=True)

In [80]:
# Set to True to train on images (plus a metadata.jsonl) you already placed in
# local_training_dataset_folder instead of downloading the Hub dataset.
use_local_images = False

if not use_local_images:
    dataset_name = 'haandol/icon'
    # force_redownload ignores any cached copy, so every run fetches the dataset again.
    dataset = load_dataset(dataset_name, split='train', download_mode='force_redownload')

    # Write one JPEG per example plus a metadata.jsonl whose `file_name` field links each
    # caption (`text`) to its image file (the Hugging Face imagefolder metadata layout).
    metadata = []
    for i, datum in enumerate(dataset):
        fn = f'{i:03d}.jpg'
        # JPEG cannot store alpha/palette modes, so normalize to RGB before saving.
        datum['image'].convert('RGB').save(local_training_dataset_folder / fn)
        metadata.append(json.dumps({
            'file_name': fn,
            'text': datum['text'],
        }))

    with open(local_training_dataset_folder / 'metadata.jsonl', 'w', encoding='utf-8') as fp:
        fp.write('\n'.join(metadata))



In [81]:
# Prompts handed to the training script through dataset_info.json; presumably rendered every
# `validation_steps` steps to monitor fine-tuning progress — TODO confirm in the training script.
validation_prompts = [
    'a photo of red telephone in front of the Eiffel tower',
    'a photo of black telephone handset in front of the Eiffel tower',
    'a white speech bubble in front of the Eiffel tower',
    'a white speech bubble on black background',
    'a woman profile picture in front of the Eiffel tower, looking at viewer',
    'a man profile picture in front of the Eiffel tower, looking at viewer',
    'a white closed envelope on black background',
    'a red closed envelope on white background',
    'a red closed envelope on white background, flat vector',
    'a telephone handset on white background, flat vector',
]

# dataset_info.json lives in the training folder, so it is uploaded with the images and
# read by the training script. Here it carries only the validation prompts above.
dataset_info = {'validation_prompts': validation_prompts}
with open(local_training_dataset_folder / "dataset_info.json", "w") as f:
    json.dump(dataset_info, f)

### 데이터셋 업로드
- 학습 데이터용 S3 버킷을 생성하고 데이터 폴더를 업로드

In [82]:
# Region and account id are both baked into the training bucket name.
mySession = boto3.session.Session()
AwsRegion = mySession.region_name
caller_identity = boto3.client("sts").get_caller_identity()
account_id = caller_identity.get("Account")

# Fetch the JumpStart notebook helper module from the region-local JumpStart cache bucket,
# save it as ./utils.py and import the bucket-creation helper from it.
s3 = boto3.client("s3")
s3.download_file(
    Bucket=f"jumpstart-cache-prod-{AwsRegion}",
    Key="ai_services_assets/custom_labels/cl_jumpstart_ic_notebook_utils.py",
    Filename="utils.py",
)
from utils import create_bucket_if_not_exists

In [83]:
# S3 bucket names are global, so scope the bucket by region and account id.
training_bucket = "stable-diffusion-text-to-image-{}-{}".format(AwsRegion, account_id)
create_bucket_if_not_exists(training_bucket)

# Prefix the dataset is uploaded to and that is passed to the job as the "training" channel.
train_s3_path = "s3://" + training_bucket + "/text-to-image/"

Using an existing bucket stable-diffusion-text-to-image-us-east-1-981794133797


In [84]:
!aws s3 cp --recursive $local_training_dataset_folder $train_s3_path

upload: text_to_image_training_images/.ipynb_checkpoints/metadata-checkpoint.jsonl to s3://stable-diffusion-text-to-image-us-east-1-981794133797/text-to-image/.ipynb_checkpoints/metadata-checkpoint.jsonl
upload: text_to_image_training_images/002.jpg to s3://stable-diffusion-text-to-image-us-east-1-981794133797/text-to-image/002.jpg
upload: text_to_image_training_images/.ipynb_checkpoints/000-checkpoint.jpg to s3://stable-diffusion-text-to-image-us-east-1-981794133797/text-to-image/.ipynb_checkpoints/000-checkpoint.jpg
upload: text_to_image_training_images/.ipynb_checkpoints/013-checkpoint.jpg to s3://stable-diffusion-text-to-image-us-east-1-981794133797/text-to-image/.ipynb_checkpoints/013-checkpoint.jpg
upload: text_to_image_training_images/.ipynb_checkpoints/dataset_info-checkpoint.json to s3://stable-diffusion-text-to-image-us-east-1-981794133797/text-to-image/.ipynb_checkpoints/dataset_info-checkpoint.json
upload: text_to_image_training_images/000.jpg to s3://stable-diffusion-text-

## 4. 세이지메이커 학습작업으로 파인튜닝

### 4.1. 트레이닝 파라미터 설정

- 기본 모델은 Stable Diffusion 2.1 base 모델
- `train_model_uri` 를 이전 세이지메이커 학습작업의 출력 모델 S3 URI 로 바꾸면 해당 모델에서 이어서 추가 학습 가능

In [85]:
# JumpStart model to fine-tune (Stable Diffusion 2.1 base); "*" requests the latest version.
train_model_id = "model-txt2img-stabilityai-stable-diffusion-v2-1-base"
train_model_version = "*"
train_scope = "training"

# According to the original notes, training was tested on ml.g4dn.2xlarge (16GB GPU memory)
# and ml.g5.2xlarge (24GB GPU memory); this notebook uses ml.g5.12xlarge. Other GPU
# instance types available in your account may work as well.
training_instance_type = "ml.g5.12xlarge"

# S3 URI of the pre-trained model tarball that fine-tuning starts from.
train_model_uri = model_uris.retrieve(
    model_id=train_model_id,
    model_version=train_model_version,
    model_scope=train_scope,
)
train_model_uri

INFO:sagemaker.image_uris:image_uri is not presented, retrieving image_uri based on instance_type, framework etc.


---

- `/train/text-to-image/build_and_push.sh` 로 학습 이미지를 ECR 에 푸시한 후
- 푸시된 학습용 이미지의 태그(주소)를 아래 셀에 채우기

---

In [87]:
# Tag of the training image pushed with build_and_push.sh; update it after every rebuild.
tag = '65739bf0d3'
# Assumes build_and_push.sh pushed to the `text-to-image-training-gpu` ECR repository in the
# same account and region this notebook runs in (TODO confirm). Deriving them avoids
# hard-coding one person's account id and region.
train_image_uri = f'{account_id}.dkr.ecr.{AwsRegion}.amazonaws.com/text-to-image-training-gpu:{tag}'
train_image_uri

'981794133797.dkr.ecr.us-east-1.amazonaws.com/text-to-image-training-gpu:7accbb38c1'

---

- 학습된 모델 파일이 저장될 S3 위치 지정

---

In [88]:
# Trained model artifacts go to <SageMaker default bucket>/text-to-image-training/output.
output_prefix = "text-to-image-training"
output_bucket = sess.default_bucket()
s3_output_location = "s3://{}/{}/output".format(output_bucket, output_prefix)
s3_output_location

's3://sagemaker-us-east-1-981794133797/text-to-image-training/output'

### 4.2. 하이퍼파라미터 

- diffusers [text_to_image 예제](https://github.com/huggingface/diffusers/blob/v0.16.1/examples/text_to_image/train_text_to_image.py) 의 입력 파라미터를 사용


In [None]:
# Arguments forwarded to the training script; names follow diffusers' train_text_to_image.py
# example where they overlap.
hyperparameters = dict(
    # optimisation / schedule length
    learning_rate=1e-06,
    max_train_steps=15000,
    num_train_epochs=1,
    batch_size=1,
    # monitoring and checkpointing cadence (in steps)
    validation_steps=50,
    checkpointing_steps=400,
    checkpoints_total_limit=10,
    # memory / stability
    gradient_accumulation_steps=4,
    use_ema=True,
    resolution=512,
    use_8bit_adam=True,
    gradient_checkpointing=True,
    # learning-rate schedule and regularisation
    lr_warmup_steps=100,
    input_perturbation=0.1,
    lr_scheduler='cosine_with_restarts',
    # NOTE(review): the original note said "5 cycles". In diffusers' cosine_with_restarts
    # scheduler num_cycles counts hard restarts, so 3000 would restart about every 5 steps
    # over 15000 steps; 5 cycles would be lr_num_cycles=5. Confirm how the training script
    # interprets this value before changing it.
    lr_num_cycles=3000,
)
hyperparameters

{'use_ema': True,
 'resolution': 512,
 'center_crop': True,
 'random_flip': True,
 'learning_rate': 1e-05,
 'max_train_steps': 15000,
 'gradient_accumulation_steps': 4,
 'lr_warmup_steps': 0,
 'mixed_precision': 'fp16',
 'enable_xformers_memory_efficient_attention': True}

### 4.3. 학습하기

---

- `use_hpo=True` 로 설정하면 HPO(하이퍼파라미터 튜닝)를 이용하여 학습 (총 14개 작업, 최대 7개를 동시에 진행)
- [FID](https://wandb.ai/wandb_fc/korean/reports/-Frechet-Inception-distance-FID-GANs---Vmlldzo0MzQ3Mzc) 를 이용하여 최적 학습결과 찾기
- Weights and Biases 를 이용할 경우 `use_wandb=True` 로 설정하고 로그인 (API 키는 노트북에 하드코딩하지 말 것)

---

In [90]:
# Shared name prefix: used as the SageMaker base job name (the SDK appends a timestamp per
# job), in the S3 checkpoint location and in the endpoint name.
training_job_name = "sd-text-to-image"
training_job_name

'sd-text-to-image-2023-06-05-16-51-33-630'

In [None]:
# Give every run its own checkpoint prefix. When a job starts, SageMaker copies whatever already
# exists under checkpoint_s3_uri into the local checkpoint path. With a fixed prefix, each new
# run would start with (and keep syncing alongside) the previous run's checkpoints.
# To deliberately resume an earlier run, set this to that run's checkpoint URI instead.
s3_checkpoint_location = f"s3://{output_bucket}/{name_from_base(training_job_name)}/checkpoints"
s3_checkpoint_location

In [91]:
# Set use_wandb = False to disable Weights & Biases logging.
use_wandb = True
if use_wandb:
    from getpass import getpass

    # Take the key from the environment (or prompt for it) instead of hard-coding it in the
    # notebook. An empty placeholder would leave the training job without W&B credentials.
    wandb_api_key = os.environ.get("WANDB_API_KEY") or getpass("W&B API key: ")
    wandb.login(key=wandb_api_key)
    hyperparameters['report_to'] = 'wandb'
    # Passed to the training container as environment variables. NOTE: job environment
    # variables are visible via DescribeTrainingJob / the SageMaker console.
    environment = {
        'WANDB_API_KEY': wandb_api_key,
        'WANDB_PROJECT': "sd-text-to-image",
    }

In [None]:
# Set use_hpo = True to run a Bayesian hyperparameter-tuning job instead of a single training job.
use_hpo = False

# Create SageMaker Estimator instance. In HPO mode it is the template each tuning trial is
# launched from; otherwise it is fitted directly.
sd_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    model_uri=train_model_uri,
    instance_count=1,
    instance_type=training_instance_type,
    max_run=360000,  # seconds (= 100 hours)
    volume_size=128,  # GB of attached storage
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
    base_job_name=training_job_name,
    # Local checkpoint directory that SageMaker keeps in sync with checkpoint_s3_uri.
    checkpoint_local_path="/opt/ml/checkpoints",
    checkpoint_s3_uri=s3_checkpoint_location,
    environment=environment if use_wandb else None,
)

if use_hpo:
    # Presumably makes the training script compute and log `fid_score=<value>` — TODO confirm.
    sd_estimator.set_hyperparameters(compute_fid="True")
    # The tuner searches these ranges in place of the fixed values in `hyperparameters`.
    hyperparameter_ranges = {
        "learning_rate": ContinuousParameter(5e-7, 2e-6, "Linear"),
        "max_train_steps": IntegerParameter(4000, 12000, "Linear"),
    }
    # SageMaker extracts the objective from the job logs with this regex. The number pattern must
    # allow a multi-digit integer part: with `\d\.?\d*` (one digit before the point),
    # "fid_score=45.3" was captured as 45 and the decimals were silently dropped.
    metric_definitions = [
        {'Name': 'fid_score', 'Regex': 'fid_score=([-+]?[0-9]*\\.?[0-9]+)'},
    ]
    # NOTE(review): every trial inherits the estimator's checkpoint_s3_uri; confirm that
    # parallel trials do not share (and overwrite) one checkpoint prefix.
    tuner_parameters = {
        "estimator": sd_estimator,
        "metric_definitions": metric_definitions,
        "objective_metric_name": "fid_score",
        "objective_type": "Minimize",  # lower FID is better
        "hyperparameter_ranges": hyperparameter_ranges,
        "max_jobs": 14,
        "max_parallel_jobs": 7,
        "strategy": "Bayesian",
        "base_tuning_job_name": training_job_name,
    }

    bayesian_tuner = HyperparameterTuner(**tuner_parameters)
    # wait=False: returns immediately; follow the tuning job in the SageMaker console.
    bayesian_tuner.fit({"training": train_s3_path}, wait=False)
else:
    # Blocks until the job finishes and streams its logs into the notebook.
    sd_estimator.fit({"training": train_s3_path}, logs=True)

INFO:sagemaker:Creating training-job with name: sd-text-to-image-2023-06-05-16-51-33-63-2023-06-05-16-51-34-659


2023-06-05 16:51:34 Starting - Starting the training job.

## 5. 세이지메이커 엔드포인트 생성

- 파인튜닝된 모델로 세이지메이커 엔드포인트 프로비저닝
- 엔드포인트를 생성하는 대신 `inference.ipynb` 노트북으로 직접 아이콘을 생성할 수도 있음

In [None]:
%time

inference_instance_type = "ml.g5.2xlarge"

# Retrieve the inference docker container uri
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=train_model_id,
    model_version=train_model_version,
    instance_type=inference_instance_type,
)
# Retrieve the inference script uri. This includes scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope="inference"
)

---

이미 생성된 다른 엔드포인트를 사용할 경우, 아래 코드를 실행하여 predictor 생성

```python
endpoint_name = 'sd-dreambooth-txt2img-stab-2023-05-27-07-04-28-234'
finetuned_predictor = Predictor(
    endpoint_name=endpoint_name,
    sagemaker_session=sess,
)
```

---

In [None]:
# Endpoint names must end with an alphanumeric character and be unique in the account/region.
# The bare prefix plus "-" was invalid; name_from_base appends a timestamp instead.
endpoint_name = name_from_base(training_job_name)

# Deploy the fine-tuned model: the best tuning trial when HPO was used, otherwise the single
# training job from the previous step.
finetuned_predictor = (bayesian_tuner if use_hpo else sd_estimator).deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uri
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    endpoint_name=endpoint_name,
)

---

사용가능한 파라미터 목록:

* **prompt**: prompt to guide the image generation. Must be specified and can be a string or a list of strings.
* **width**: width of the hallucinated image. If specified, it must be a positive integer divisible by 8.
* **height**: height of the hallucinated image. If specified, it must be a positive integer divisible by 8.
* **num_inference_steps**: Number of denoising steps during image generation. More steps lead to higher quality image. If specified, it must be a positive integer.
* **guidance_scale**: Higher guidance scale results in image closely related to the prompt, at the expense of image quality. If specified, it must be a float. guidance_scale<=1 is ignored.
* **negative_prompt**: guide image generation against this prompt. If specified, it must be a string or a list of strings and used with guidance_scale. If guidance_scale is disabled, this is also disabled. Moreover, if prompt is a list of strings then negative_prompt must also be a list of strings. 
* **num_images_per_prompt**: number of images returned per prompt. If specified, it must be a positive integer. 
* **seed**: Fix the randomized state for reproducibility. If specified, it must be an integer.
---

In [None]:
def query(model_predictor, text):
    """Send a plain-text prompt to the endpoint and return the raw response.

    The prompt is UTF-8 encoded and sent as ``application/x-text``; a JSON
    response is requested via the ``Accept`` header.
    """
    payload_bytes = text.encode("utf-8")
    request_args = {
        "ContentType": "application/x-text",
        "Accept": "application/json",
    }
    return model_predictor.predict(payload_bytes, request_args)


def parse_response(query_response):
    """Decode a single-image JSON response into ``(generated_image, prompt)``."""
    body = json.loads(query_response)
    image = body["generated_image"]
    prompt_text = body["prompt"]
    return image, prompt_text


def display_img_and_prompt(img, prmpt):
    """Show a generated image (array-like pixels) with its prompt as the title."""
    _, ax = plt.subplots(figsize=(12, 12))
    ax.imshow(np.array(img))
    ax.axis("off")
    ax.set_title(prmpt)
    plt.show()

In [None]:
def query_endpoint_with_json_payload(model_predictor, payload, content_type, accept):
    """Serialize ``payload`` to UTF-8 JSON, send it to the endpoint and return the raw response.

    ``content_type`` and ``accept`` become the request's ContentType / Accept headers.
    """
    body = json.dumps(payload).encode("utf-8")
    return model_predictor.predict(body, {"ContentType": content_type, "Accept": accept})


def parse_response_multiple_images(query_response):
    """Decode a multi-image JSON response into ``(generated_images, prompt)``."""
    body = json.loads(query_response)
    images = body["generated_images"]
    prompt_text = body["prompt"]
    return images, prompt_text

In [None]:
# One 512x512 sample with 30 denoising steps; the negative prompt steers away from common artifacts.
payload = {
    "prompt": "a photo of telephone handset",
    "negative_prompt": "cropped, shadow, out of frame, duplicate, watermark, signature, text, ugly, sketch, deformed, mutated, blurry",
    "width": 512,
    "height": 512,
    "num_images_per_prompt": 1,
    "num_inference_steps": 30,
    "guidance_scale": 7.5,
}

# JSON in, JSON out.
json_type = "application/json"
query_response = query_endpoint_with_json_payload(
    finetuned_predictor, payload, content_type=json_type, accept=json_type
)
generated_images, prompt = parse_response_multiple_images(query_response)
first_image = generated_images[0]
display_img_and_prompt(first_image, prompt)