## 1 Imports

In [1]:
import sys, os
sys.path.append(os.path.join(os.getcwd(), 'modules'))

import torch
from PIL import Image
from diffusers import (
    AutoencoderKL, 
    UNet2DConditionModel, 
    # StableDiffusionPipeline,
    EulerAncestralDiscreteScheduler
)
from transformers import CLIPTextModel, CLIPTokenizer#, CLIPImageProcessor
from modified_imagic_stable_diffusion import ImagicStableDiffusionPipeline

In [2]:
# Prefer the first CUDA GPU when one is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


## 2 Load checkpoints and configs

In [3]:
# Hugging Face Hub repositories that every component below is pulled from.
sd_path = "runwayml/stable-diffusion-v1-5"
clip_path = "openai/clip-vit-large-patch14"

# Stable Diffusion v1.5 components; each one lives in its own subfolder of `sd_path`.
vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(sd_path, subfolder="unet")
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(sd_path, subfolder="scheduler")

# CLIP tokenizer and text encoder, loaded from the standalone `clip_path` repository.
tokenizer = CLIPTokenizer.from_pretrained(clip_path)
text_encoder = CLIPTextModel.from_pretrained(clip_path)

# The safety checker and its image feature extractor are intentionally disabled (left as None).
safety_checker = None
feature_extractor = None



In [4]:
# Put the three torch sub-models on the selected device before wiring them together.
vae, text_encoder, unet = (module.to(device) for module in (vae, text_encoder, unet))

# Assemble the Imagic pipeline from the individually loaded components.
pipeline_components = dict(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
)
pipeline = ImagicStableDiffusionPipeline(**pipeline_components).to(device)
print(pipeline.dtype)

torch.float32


## 3 Optimize the target embedding and fine-tune the pre-trained model

In [None]:
prompt = "A cat wearing the Renaissance-style knight's helmet and full plate armour, with only its bare legs showing"

# Load the source photo eagerly and close the file handle. Force 3-channel RGB because
# grayscale/RGBA/CMYK files would not yield an (H, W, 3) array; the pipeline's preprocessing
# presumably expects RGB — TODO confirm in modified_imagic_stable_diffusion.
with Image.open("./image_stocks/ginger_cat2.jpg") as raw_image:
    init_image = raw_image.convert("RGB")
guidance_scale = 13

# Fixed seed for the random draws made during `pipeline.train` below.
generator = torch.Generator(device).manual_seed(0)
display(init_image)

In [None]:
# Hyper-parameters for Imagic's two training phases: target-embedding optimization
# and diffusion-model fine-tuning (step counts and learning rates set per phase).
train_kwargs = {
    "prompt": prompt,
    "image": init_image,
    "guidance_scale": guidance_scale,
    "generator": generator,
    "height": 512,
    "width": 512,
    "embedding_learning_rate": 5e-4,
    "diffusion_model_learning_rate": 1e-7,
    "text_embedding_optimization_steps": 1000,
    "model_fine_tuning_optimization_steps": 1000,
    "show_progress": True,
}
pipeline.train(**train_kwargs)

## 4 Run inference with the fine-tuned model

In [None]:
def save_images(image_list:list, alpha_list:list, folder_name:str="outputs", filename_prefix:str="generated_image"):
    """Save each generated image as a PNG named after the alpha it was generated with.

    Files are written to ``<folder_name>/<filename_prefix>_alpha-<alpha>.png``.

    Args:
        image_list: Objects exposing ``.save(path)`` (e.g. PIL images), one per alpha.
        alpha_list: Alpha values paired position-wise with ``image_list``; used only
            to build the file names.
        folder_name: Output directory; created (including parents) if it does not exist.
        filename_prefix: Prefix for every file name.

    Raises:
        ValueError: If ``image_list`` and ``alpha_list`` differ in length. The check runs
            before any file is written, so a mismatch never leaves a partially saved set.
    """
    if len(image_list) != len(alpha_list):
        raise ValueError(
            f"image_list has {len(image_list)} items but alpha_list has {len(alpha_list)}"
        )
    # exist_ok=True replaces the racy "check exists, then create" pattern.
    os.makedirs(folder_name, exist_ok=True)
    for image, alpha in zip(image_list, alpha_list):
        filename = os.path.join(folder_name, f"{filename_prefix}_alpha-{alpha}.png")
        image.save(filename)

In [None]:
# Alpha sweep with init_timestep_rate = 0; images are shown inline and saved with prefix "Puss0".
imgs = []
alphas = [0.0, 0.4, 0.6, 0.8, 1, 1.2, 1.4]
for alpha in alphas:
    # Re-seed on every iteration so each alpha sees the same random draws and only alpha varies.
    # Use the notebook's `device` (not a hard-coded "cuda") so the cell also runs on a CPU-only kernel.
    generator = torch.Generator(device).manual_seed(0)
    result = pipeline(
        alpha=alpha,
        height=512,
        width=512,
        guidance_scale=guidance_scale,
        generator=generator,
        init_timestep_rate=0,
    )
    image = result.images[0]
    imgs.append(image)
    print(f"alpha: {alpha}")
    display(image)
save_images(imgs, alphas, filename_prefix="Puss0")

In [None]:
# Alpha sweep with init_timestep_rate = 0.1; images are shown inline and saved with prefix "Puss0.1".
imgs = []
alphas = [0.0, 0.4, 0.6, 0.8, 1, 1.2, 1.4]
for alpha in alphas:
    # Re-seed on every iteration so each alpha sees the same random draws and only alpha varies.
    # Use the notebook's `device` (not a hard-coded "cuda") so the cell also runs on a CPU-only kernel.
    generator = torch.Generator(device).manual_seed(0)
    result = pipeline(
        alpha=alpha,
        height=512,
        width=512,
        guidance_scale=guidance_scale,
        generator=generator,
        init_timestep_rate=0.1,
    )
    image = result.images[0]
    imgs.append(image)
    print(f"alpha: {alpha}")
    display(image)
save_images(imgs, alphas, filename_prefix="Puss0.1")

In [None]:
# Alpha sweep with init_timestep_rate = 0.2; images are shown inline and saved with prefix "Puss0.2".
imgs = []
alphas = [0.0, 0.4, 0.6, 0.8, 1, 1.2, 1.4]
for alpha in alphas:
    # Re-seed on every iteration so each alpha sees the same random draws and only alpha varies.
    # Use the notebook's `device` (not a hard-coded "cuda") so the cell also runs on a CPU-only kernel.
    generator = torch.Generator(device).manual_seed(0)
    result = pipeline(
        alpha=alpha,
        height=512,
        width=512,
        guidance_scale=guidance_scale,
        generator=generator,
        init_timestep_rate=0.2,
    )
    image = result.images[0]
    imgs.append(image)
    print(f"alpha: {alpha}")
    display(image)
save_images(imgs, alphas, filename_prefix="Puss0.2")

In [None]:
# Alpha sweep with init_timestep_rate = 0.3; images are shown inline and saved with prefix "Puss0.3".
imgs = []
alphas = [0.0, 0.4, 0.6, 0.8, 1, 1.2, 1.4]
for alpha in alphas:
    # Re-seed on every iteration so each alpha sees the same random draws and only alpha varies.
    # Use the notebook's `device` (not a hard-coded "cuda") so the cell also runs on a CPU-only kernel.
    generator = torch.Generator(device).manual_seed(0)
    result = pipeline(
        alpha=alpha,
        height=512,
        width=512,
        guidance_scale=guidance_scale,
        generator=generator,
        init_timestep_rate=0.3,
    )
    image = result.images[0]
    imgs.append(image)
    print(f"alpha: {alpha}")
    display(image)
save_images(imgs, alphas, filename_prefix="Puss0.3")

In [None]:
# Alpha sweep with init_timestep_rate = 0.4; images are shown inline and saved with prefix "Puss0.4".
imgs = []
alphas = [0.0, 0.4, 0.6, 0.8, 1, 1.2, 1.4]
for alpha in alphas:
    # Re-seed on every iteration so each alpha sees the same random draws and only alpha varies.
    # Use the notebook's `device` (not a hard-coded "cuda") so the cell also runs on a CPU-only kernel.
    generator = torch.Generator(device).manual_seed(0)
    result = pipeline(
        alpha=alpha,
        height=512,
        width=512,
        guidance_scale=guidance_scale,
        generator=generator,
        init_timestep_rate=0.4,
    )
    image = result.images[0]
    imgs.append(image)
    print(f"alpha: {alpha}")
    display(image)
save_images(imgs, alphas, filename_prefix="Puss0.4")