Skip to content

Commit

Permalink
initial version of wuerstchen
Browse files Browse the repository at this point in the history
  • Loading branch information
kohya-ss committed Oct 3, 2023
1 parent 2d87bb6 commit 4c5d6d1
Show file tree
Hide file tree
Showing 3 changed files with 850 additions and 1 deletion.
7 changes: 6 additions & 1 deletion library/train_util.py
Expand Up @@ -1995,7 +1995,7 @@ def debug_dataset(train_dataset, show_input_ids=False):

if show_input_ids:
print(f"input ids: {iid}")
if "input_ids2" in example:
if "input_ids2" in example and example["input_ids2"] is not None:
print(f"input ids2: {example['input_ids2'][j]}")
if example["images"] is not None:
im = example["images"][j]
Expand All @@ -2012,6 +2012,11 @@ def debug_dataset(train_dataset, show_input_ids=False):
cond_img = cond_img[:, :, ::-1]
if os.name == "nt":
cv2.imshow("cond_img", cond_img)

for key in example.keys():
if key in ["images", "conditioning_images", "input_ids", "input_ids2"]:
continue
print(f"{key}: {example[key][j] if example[key] is not None else None}")

if os.name == "nt": # only windows
cv2.imshow("img", im)
Expand Down
196 changes: 196 additions & 0 deletions wuerstchen/wuerstchen_inference.py
@@ -0,0 +1,196 @@
# use Diffusers' pipeline to generate images

import argparse
import datetime
import math
import os
import random
import re
from einops import repeat
import numpy as np
import torch

try:
import intel_extension_for_pytorch as ipex

if torch.xpu.is_available():
from library.ipex import ipex_init

ipex_init()
except Exception:
pass
from tqdm import tqdm
from PIL import Image
from transformers import CLIPTextModel, PreTrainedTokenizerFast
from diffusers.pipelines.wuerstchen.modeling_wuerstchen_prior import WuerstchenPrior
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler

# from diffusers.pipelines.wuerstchen.pipeline_wuerstchen_prior import DEFAULT_STAGE_C_TIMESTEPS
from wuerstchen_train import EfficientNetEncoder


def generate(args):
dtype = torch.float32
if args.fp16:
dtype = torch.float16
elif args.bf16:
dtype = torch.bfloat16

device = args.device

os.makedirs(args.outdir, exist_ok=True)

# load tokenizer
print("load tokenizer")
tokenizer = PreTrainedTokenizerFast.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="tokenizer")

# load text encoder
print("load text encoder")
text_encoder = CLIPTextModel.from_pretrained(
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=dtype
)

# load prior model
print("load prior model")
prior: WuerstchenPrior = WuerstchenPrior.from_pretrained(
args.pretrained_prior_model_name_or_path, subfolder="prior", torch_dtype=dtype
)

# Diffusers版のxformers使用フラグを設定する関数
def set_diffusers_xformers_flag(model, valid):
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid)

for child in module.children():
fn_recursive_set_mem_eff(child)

fn_recursive_set_mem_eff(model)

# モデルに xformers とか memory efficient attention を組み込む
if args.diffusers_xformers:
print("Use xformers by Diffusers")
set_diffusers_xformers_flag(prior, True)

# load pipeline
print("load pipeline")
pipeline = AutoPipelineForText2Image.from_pretrained(
args.pretrained_decoder_model_name_or_path,
prior_prior=prior,
prior_text_encoder=text_encoder,
prior_tokenizer=tokenizer,
)
pipeline = pipeline.to(device, torch_dtype=dtype)

# generate image
while True:
width = args.w
height = args.h
seed = args.seed
negative_prompt = None

if args.interactive:
print("prompt:")
prompt = input()
if prompt == "":
break

# parse prompt
prompt_args = prompt.split(" --")
prompt = prompt_args[0]

for parg in prompt_args[1:]:
try:
m = re.match(r"w (\d+)", parg, re.IGNORECASE)
if m:
width = int(m.group(1))
print(f"width: {width}")
continue

m = re.match(r"h (\d+)", parg, re.IGNORECASE)
if m:
height = int(m.group(1))
print(f"height: {height}")
continue

m = re.match(r"d ([\d,]+)", parg, re.IGNORECASE)
if m: # seed
seed = int(m.group(1))
print(f"seed: {seed}")
continue

m = re.match(r"n (.+)", parg, re.IGNORECASE)
if m: # negative prompt
negative_prompt = m.group(1)
print(f"negative prompt: {negative_prompt}")
continue
except ValueError as ex:
print(f"Exception in parsing / 解析エラー: {parg}")
print(ex)
else:
prompt = args.prompt
negative_prompt = args.negative_prompt

if seed is None:
generator = None
else:
generator = torch.Generator(device=device).manual_seed(seed)

with torch.autocast(device):
image = pipeline(
prompt,
negative_prompt=negative_prompt,
# prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
generator=generator,
width=width,
height=height,
).images[0]

# save image
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
image.save(os.path.join(args.outdir, f"image_{timestamp}.png"))

if not args.interactive:
break

print("Done!")


def setup_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()

# train_util.add_sd_models_arguments(parser)
parser.add_argument(
"--pretrained_prior_model_name_or_path",
type=str,
default="warp-ai/wuerstchen-prior",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_decoder_model_name_or_path",
type=str,
default="warp-ai/wuerstchen",
required=False,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)

parser.add_argument("--prompt", type=str, default="A photo of a cat")
parser.add_argument("--negative_prompt", type=str, default="")
parser.add_argument("--outdir", type=str, default=".")
parser.add_argument("--w", type=int, default=1024)
parser.add_argument("--h", type=int, default=1024)
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--fp16", action="store_true")
parser.add_argument("--bf16", action="store_true")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--seed", type=int, default=None)
parser.add_argument("--diffusers_xformers", action="store_true", help="use xformers by diffusers / Diffusersでxformersを使用する")
return parser


if __name__ == "__main__":
parser = setup_parser()

args = parser.parse_args()
generate(args)

0 comments on commit 4c5d6d1

Please sign in to comment.