論文<br>
https://arxiv.org/abs/2302.03668<br>
<br>
GitHub<br>
https://github.com/YuxinWenRick/hard-prompts-made-easy<br>
<br>
<a href="https://colab.research.google.com/github/kaz12tech/ai_demos/blob/master/PEZDispenser_demo.ipynb" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 環境セットアップ

## GPU確認

In [None]:
# Show the GPU(s) visible to this runtime (the models loaded below are large, so a GPU runtime is expected).
!nvidia-smi

## GitHubからコード取得

In [None]:
%cd /content

# Clone the hard-prompts-made-easy repository; it provides the optim_utils module imported later.
!git clone https://github.com/YuxinWenRick/hard-prompts-made-easy.git

# Pin the checkout to the commit from Feb 10, 2023 so later upstream changes cannot break this notebook.
%cd /content/hard-prompts-made-easy
!git checkout 9d6254a77fa4aa440cc83507cfaf210b35204d16

## ライブラリのインストール

In [None]:
%cd /content/hard-prompts-made-easy

# Pinned dependency versions for the PEZ dispenser code (optim_utils) and for Stable Diffusion inference via diffusers.
!pip3 install transformers==4.23.1 sentence-transformers==2.2.2 ftfy==6.1.1 mediapy==1.1.2 diffusers==0.11.1

## ライブラリのインポート

In [None]:
%cd /content/hard-prompts-made-easy

import argparse

import torch
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = "cpu"
print("using device is", device)

import open_clip
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline

import mediapy as media

from optim_utils import (download_image, optimize_prompt)

# 学習済みモデルのセットアップ

## load CLIP

In [None]:
# OpenCLIP architecture name and pretrained-weights tag; both are also handed to the optimizer through `args`.
clip_model = 'ViT-H-14'
clip_pretrain = 'laion2b_s32b_b79k'

# create_model_and_transforms yields (model, train-time transform, eval-time transform);
# the train-time transform is discarded and the eval-time one is kept as clip_preprocess.
model, _, clip_preprocess = open_clip.create_model_and_transforms(clip_model, pretrained=clip_pretrain, device=device)

## load Stable Diffusion

In [None]:
model_id = "stabilityai/stable-diffusion-2-1-base"

# `device` may fall back to "cpu" when no GPU is present. Half-precision inference is
# largely unsupported on CPU in PyTorch, so only use the fp16 weights/dtype on CUDA
# and load full-precision weights otherwise.
use_fp16 = device == "cuda"

# Multistep DPM-Solver scheduler, loaded from the checkpoint's own scheduler config.
scheduler = DPMSolverMultistepScheduler.from_pretrained(
    model_id,
    subfolder="scheduler")

# Stable Diffusion pipeline using the scheduler above.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    scheduler=scheduler,
    torch_dtype=torch.float16 if use_fp16 else torch.float32,
    revision="fp16" if use_fp16 else None)
pipe = pipe.to(device)

# Side length in pixels of generated images; used as both height and width when calling `pipe`.
image_length = 512

# テスト画像のセットアップ

In [None]:
# Target image(s) that the optimized prompt should describe.
urls = [
        "https://a.1stdibscdn.com/alexander-averin-paintings-pony-riding-on-the-beach-post-impressionist-style-oil-painting-for-sale-picture-6/a_7443/a_28523631526593507117/Promenade_detalle_5_master.JPG?disable=upscale&auto=webp&quality=60&width=1318",
       ]

# Download every URL and keep only the truthy results. download_image (from optim_utils)
# presumably returns None for a failed download; verify this in optim_utils.
orig_images = [image for image in (download_image(url) for url in urls) if image]
if not orig_images:
    # Fail fast with a clear message instead of handing an empty target list to optimize_prompt.
    raise RuntimeError("No target image could be downloaded; check the entries in `urls`.")

# Preview the downloaded target image(s).
media.show_images(orig_images, height=512)

# Image to Optimized Prompt


## Setting args

In [None]:
# Settings consumed by optim_utils.optimize_prompt, including the CLIP model identifiers chosen above.
args = argparse.Namespace(
    prompt_len=16,
    iter=1500,
    lr=0.1,
    weight_decay=0.1,
    prompt_bs=1,
    print_step=100,
    batch_size=1,
    clip_model=clip_model,
    clip_pretrain=clip_pretrain,
)

# Echo the settings as the cell output.
args

## optimize prompt

In [None]:
# Search for the prompt that best describes the target image(s).
learned_prompt = optimize_prompt(model, clip_preprocess, args, device, target_images=orig_images)

## Optimized Prompt to Image

In [None]:
# Use the prompt learned from the target image for the generation cell that follows.
prompt = learned_prompt

In [None]:
# Generation settings.
num_images = 4             # images produced for the prompt
guidance_scale = 9         # classifier-free guidance strength
num_inference_steps = 25   # denoising steps run by the scheduler

# A fixed seed keeps the sampled images identical across runs.
images = pipe(
    prompt,
    height=image_length,
    width=image_length,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    num_images_per_prompt=num_images,
    generator=torch.Generator(device=device).manual_seed(0),
).images

In [None]:
# Show the prompt and the images generated from it as small thumbnails.
print("prompt:", prompt)
media.show_images(images, width=128)

# Prompt to Optimized Prompt

In [None]:
# Hand-written text prompt that will be converted into an optimized (hard) prompt.
target_prompts = ['Very detailed and colorful wall art pictures']
print(target_prompts)

## optimize prompt

In [None]:
# Optimize a prompt against the target text prompt(s) rather than against images.
learned_prompt = optimize_prompt(model, clip_preprocess, args, device, target_prompts=target_prompts)

In [None]:
# Show the prompt learned from the text target.
print(learned_prompt)

## Optimized Prompt to Image

### Before Optimization

In [None]:
# Generation settings (same as the other generation cells, for a fair comparison).
num_images = 4             # images produced per prompt
guidance_scale = 9         # classifier-free guidance strength
num_inference_steps = 25   # denoising steps run by the scheduler

# Baseline: generate from the original hand-written prompt with the same fixed seed.
images = pipe(
    target_prompts,
    height=image_length,
    width=image_length,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    num_images_per_prompt=num_images,
    generator=torch.Generator(device=device).manual_seed(0),
).images

In [None]:
# Show the original prompt and the images generated from it.
print("prompt:", target_prompts)
media.show_images(images, width=128)

### After Optimization

In [None]:
# Generation settings (identical to the baseline cell so only the prompt differs).
num_images = 4             # images produced for the prompt
guidance_scale = 9         # classifier-free guidance strength
num_inference_steps = 25   # denoising steps run by the scheduler

# Generate from the learned prompt using the same fixed seed as the baseline.
images = pipe(
    learned_prompt,
    height=image_length,
    width=image_length,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    num_images_per_prompt=num_images,
    generator=torch.Generator(device=device).manual_seed(0),
).images

In [None]:
# Show the learned prompt and the images generated from it.
print("prompt:", learned_prompt)
media.show_images(images, width=128)