[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/PhotoMaker-colab/blob/main/PhotoMaker_style_colab.ipynb)

In [None]:
# Work from /content, clone the `dev` branch of the PhotoMaker fork, and
# switch into the checkout so the relative paths used by later cells resolve.
%cd /content
!git clone -b dev https://github.com/camenduru/PhotoMaker
%cd /content/PhotoMaker

# Install pinned cu118 (CUDA 11.8) builds of the PyTorch packages from the
# PyTorch wheel index, then a pinned xformers plus the remaining dependencies.
!pip install -q torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 torchtext==0.15.2 torchdata==0.6.1 --extra-index-url https://download.pytorch.org/whl/cu118 -U
!pip install -q xformers==0.0.20 diffusers accelerate einops onnxruntime-gpu omegaconf

import torch
import numpy as np
import random
import os
from PIL import Image

from diffusers.utils import load_image
from diffusers import DDIMScheduler
from huggingface_hub import hf_hub_download

from photomaker.pipeline import PhotoMakerStableDiffusionXLPipeline

# global variables and helper function
def image_grid(imgs, rows, cols, size_after_resize):
    """Tile images into a single ``rows`` x ``cols`` RGB grid image.

    Each image is resized to a ``size_after_resize`` square and pasted in
    row-major order (left to right, then top to bottom).

    Args:
        imgs: Sequence of image objects supporting ``resize``; its length
            must equal ``rows * cols``.
        rows: Number of grid rows.
        cols: Number of grid columns.
        size_after_resize: Edge length, in pixels, of every tile.

    Returns:
        A new RGB image of size
        ``(cols * size_after_resize, rows * size_after_resize)``.

    Raises:
        AssertionError: If ``len(imgs)`` differs from ``rows * cols``.
    """
    # Keep the original AssertionError contract, but say what went wrong.
    assert len(imgs) == rows * cols, (
        f"expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}"
    )

    tile = size_after_resize
    grid = Image.new('RGB', size=(cols * tile, rows * tile))

    for index, img in enumerate(imgs):
        # Row-major placement: index -> (row, column) of the target cell.
        row, col = divmod(index, cols)
        grid.paste(img.resize((tile, tile)), box=(col * tile, row * tile))
    return grid

# Base SDXL checkpoint, read from a local path.
# Download source: https://civitai.com/api/download/models/276923
# NOTE(review): no cell shown here downloads this file — confirm it is placed
# under ./civitai_models before this cell runs.
base_model_path = './civitai_models/sdxlUnstableDiffusers_v11.safetensors'
# PhotoMaker adapter weights fetched from the Hugging Face Hub; the returned
# value is used below as a local file path.
photomaker_path = hf_hub_download(repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")
# Style LoRA, read from a local path.
# Download source: https://civitai.com/api/download/models/152309?type=Model&format=SafeTensor
# NOTE(review): likewise not downloaded by any cell shown here.
lora_path = './civitai_models/xl_more_art-full.safetensors'

# Device the pipeline and the random generator are created on.
device = "cuda"
# Directory the generated images are written to by the final cell.
save_path = "./outputs"

# Load base model
# Builds the pipeline from the single-file checkpoint in bfloat16 and moves it
# to `device`.
pipe = PhotoMakerStableDiffusionXLPipeline.from_single_file(
    base_model_path, 
    torch_dtype=torch.bfloat16, 
    original_config_file=None,
).to(device)

# Load PhotoMaker checkpoint
# The downloaded path is split into directory + file name; "img" is the
# trigger word that prompts are expected to contain.
pipe.load_photomaker_adapter(
    os.path.dirname(photomaker_path),
    subfolder="",
    weight_name=os.path.basename(photomaker_path),
    trigger_word="img"
)     

# Swap in a DDIM scheduler built from the current scheduler's config.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
print("Loading lora...")
# Register the style LoRA under the adapter name "xl_more_art-full".
pipe.load_lora_weights(os.path.dirname(lora_path), weight_name=os.path.basename(lora_path), adapter_name="xl_more_art-full")
# Enable both adapters: "photomaker" at weight 1.0, the style LoRA at 0.5.
pipe.set_adapters(["photomaker", "xl_more_art-full"], adapter_weights=[1.0, 0.5])
# NOTE(review): presumably merges the weighted adapters into the model
# weights — confirm against the diffusers documentation.
pipe.fuse_lora()

In [None]:
# Load the reference identity image(s) and preview them as a grid
image_path = './examples/scarletthead_woman/scarlett_0.jpg'

# A single reference image is used here
input_id_images = [load_image(image_path)]

input_grid = image_grid(input_id_images, rows=1, cols=1, size_after_resize=224)
print("Input ID images:")
input_grid

In [None]:
## For personalization, the class word must be followed by the trigger word `img`
prompt = "A girl img riding dragon over a whimsical castle, 3d CGI, art by Pixar, half-body, screenshot from animation"
negative_prompt = "realistic, photo-realistic, bad quality, bad anatomy, worst quality, low quality, lowres, extra fingers, blur, blurry, ugly, wrong proportions, watermark, image artifacts, bad eyes, bad hands, bad arms"
generator = torch.Generator(device=device).manual_seed(42)

## Sampling settings
num_steps = 50
style_strength_ratio = 20
# Turn the percentage into a step index, never letting it exceed 30
start_merge_step = min(int(float(style_strength_ratio) / 100 * num_steps), 30)

images = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    input_id_images=input_id_images,
    num_inference_steps=num_steps,
    num_images_per_prompt=4,
    start_merge_step=start_merge_step,
    generator=generator,
).images

In [None]:
# Show and save the results
## Downsample for visualization
# Take the column count from the number of images actually generated, so the
# grid stays valid if `num_images_per_prompt` is changed in the generation cell.
grid = image_grid(images, 1, len(images), size_after_resize=512)

# Write each generated image to `save_path` as photomaker_style_NN.png.
os.makedirs(save_path, exist_ok=True)
for idx, image in enumerate(images):
    image.save(os.path.join(save_path, f"photomaker_style_{idx:02d}.png"))

print("Results:")
# A bare expression as the last line makes the notebook display the grid.
grid