In [None]:
import sys
import argparse

sys.path.insert(0, "../sd-scripts")
import torch
from library import train_util, sdxl_train_util, sdxl_model_util
from library.sdxl_lpw_stable_diffusion import (
    SdxlStableDiffusionLongPromptWeightingPipeline,
)
from functools import partial
from PIL import Image

from library.video_inpainting_patch import (
    VideoInpaintingPatchPipeline,
    VideoInpaintingPatch,
)
from networks import lora

In [None]:
import argparse
import train_network


def setup_parser() -> argparse.ArgumentParser:
    """Return the ``train_network`` argument parser extended with the SDXL training options."""
    sdxl_parser = train_network.setup_parser()
    sdxl_train_util.add_sdxl_training_arguments(sdxl_parser)
    return sdxl_parser


parser = setup_parser()

# Emulate a command line that only names the TOML config file; the remaining
# options are presumably filled in from that file by read_config_from_file.
argv = ["--config_file", "/home/longc/data/code/lora-scripts/config/video-example/video-debug.toml"]
args = train_util.read_config_from_file(parser.parse_args(argv), parser, argv)

In [None]:
print("preparing accelerator")
accelerator = train_util.prepare_accelerator(args)
is_main_process = accelerator.is_main_process

# Prepare the dtype matching the mixed-precision setting and cast to it where needed.
# (The second value returned by prepare_dtype is not used in this notebook.)
weight_dtype, _ = train_util.prepare_dtype(args)

# Normalise the accelerator's device to a torch.device.
# To debug on CPU, replace accelerator.device with "cpu" here.
device = torch.device(accelerator.device)
if device.type == "cpu":
    # On CPU always run in full precision, whatever the mixed-precision setting says.
    weight_dtype = torch.float32

In [None]:
# load_tokenizers may hand back a single tokenizer or a list; downstream code wants a list.
tokenizers = sdxl_train_util.load_tokenizers(args)
tokenizers = tokenizers if isinstance(tokenizers, list) else [tokenizers]

In [None]:
# Load the SDXL base checkpoint components (two text encoders, VAE, UNet, ...),
# passing weight_dtype through to the loader.
(load_stable_diffusion_format, text_encoder1, text_encoder2, vae, unet, logit_scale, ckpt_info) = (
    sdxl_train_util.load_target_model(
        args, accelerator, sdxl_model_util.MODEL_VERSION_SDXL_BASE_V1_0, weight_dtype
    )
)

# Both text encoders in order; the LoRA and sampling helpers below take this list.
text_encoders = [text_encoder1, text_encoder2]

In [None]:

if device.type == "cuda":
    # Presumably swaps the UNet attention for the backend selected in the config
    # (mem_eff_attn / xformers / sdpa) — see train_util.replace_unet_modules.
    train_util.replace_unet_modules(unet, args.mem_eff_attn, args.xformers, args.sdpa)

    # The VAE's xformers switch is only applied on torch >= 2.0.0.
    # NOTE(review): this is a version-aware comparison only because torch.__version__
    # is a TorchVersion (str subclass); with a plain str it would compare lexicographically.
    if torch.__version__ >= "2.0.0":
        vae.set_use_memory_efficient_attention_xformers(args.xformers)

In [None]:
lora_file = "/home/longc/data/code/lora-scripts/output/video-debug2/video-debug-step00100000.safetensors"
multiplier = 1.0

# Rebuild the LoRA network from the saved weights, then merge it into the text encoders
# and UNet (merge_to presumably folds the LoRA deltas into the base weights).
lora_network, weights_sd = lora.create_network_from_weights(
    multiplier,
    lora_file,
    vae,
    text_encoders,
    unet,
    for_inference=True,
)
lora_network.merge_to(
    text_encoders,
    unet,
    weights_sd,
    weight_dtype,
    device if args.lowram else "cpu",  # target device only with --lowram, otherwise CPU
)

In [None]:
# Put every model on the inference device/dtype and switch it to eval mode.
# nn.Module.to() works in place, so the text encoders inside the list need no rebinding.
unet = unet.to(device, dtype=weight_dtype)
unet.eval()
for t_enc in text_encoders:
    t_enc.to(device, dtype=weight_dtype)
    t_enc.eval()
# Keep an assignment as the last statement so the cell displays no module repr.
vae = vae.to(device, dtype=weight_dtype).eval()

In [None]:
patch_file = "/home/longc/data/code/lora-scripts/output/video-debug2/video-debug-step00100000_inpainting_head.pth"

# Load the trained inpainting head: build and populate it on CPU (the state dict is
# loaded with map_location="cpu"), then move it to the inference device/dtype once.
# NOTE(review): torch.load unpickles the file and can execute arbitrary code; only load
# trusted checkpoints (torch >= 1.13 also accepts weights_only=True for plain state dicts).
inpainting_head = VideoInpaintingPatch(sdxl_model_util.VAE_SCALE_FACTOR)
inpainting_head.load_state_dict(torch.load(patch_file, map_location="cpu"))
# Assigning (rather than leaving a bare expression) keeps Jupyter from printing the module repr.
inpainting_head = inpainting_head.to(device, dtype=weight_dtype).eval()


In [None]:
from pathlib import Path
import copy


# Default sampling settings for one video-inpainting prompt.
# create_prompts() shallow-copies this dict and overrides "prompt", "width", "height",
# "image", "mask", "prev_image" and "prev_mask" for every frame pair. Only
# "negative_prompt", "scale", "strength", "sample_steps" and "seed" keep the values below.
PROMPT_TEMPLATE = {
    "prompt": "(masterpiece, best quality:1.2), a person preparing a pizza on a table in a kitchen,",
    "negative_prompt": "lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts,signature, watermark, username, blurry, ",
    "width": 1024,
    "height": 576,
    # presumably the classifier-free guidance scale — confirm against sample_images_common
    "scale": 4.0,
    # presumably the inpainting/img2img denoising strength — confirm
    "strength": 1,
    "sample_steps": 50,
    # fixed seed shared by every generated prompt
    "seed": 1337,
    # Example paths only; overridden per frame by create_prompts().
    "image": "/home/longc/data/code/lora-scripts/config/video-example/images/00006154.jpg",
    "mask": "/home/longc/data/code/lora-scripts/config/video-example/masks/00006154.png",
    "prev_image": "/home/longc/data/code/lora-scripts/video-inpainting/output/video-debug/sample/video-debug_20240124104531_e000000_00_1337.png",
    "prev_mask": "/home/longc/data/code/lora-scripts/config/video-example/masks/00006152.png",
}


def create_prompts(data_dir, max_size=1024):
    """Yield one sampling prompt dict per consecutive frame pair in ``data_dir``.

    ``data_dir`` must contain ``images/*.jpg``, ``masks/*.png`` and ``caption/*.txt``.
    Each listing is sorted by file name and the three are paired by position. For every
    frame ``curr`` from the second onward, the prompt uses frame ``curr``'s caption,
    image and mask, plus frame ``curr - 1``'s image and mask as the "previous" inputs.
    The output size is the image size, downscaled (keeping the aspect ratio) so that
    neither side exceeds ``max_size``. Smaller images are not upscaled.

    Args:
        data_dir: directory (str or Path) containing the ``images``, ``masks`` and
            ``caption`` sub-folders.
        max_size: upper bound for the generated width and height, in pixels.

    Yields:
        A shallow copy of ``PROMPT_TEMPLATE`` with ``prompt``, ``width``, ``height``,
        ``image``, ``mask``, ``prev_image`` and ``prev_mask`` replaced.

    Raises:
        ValueError: if the numbers of images, masks and captions differ. Pairing by
            position would otherwise silently mix up frames.
    """
    data_dir = Path(data_dir)
    images = sorted(data_dir.glob("images/*.jpg"))
    masks = sorted(data_dir.glob("masks/*.png"))
    captions = sorted(data_dir.glob("caption/*.txt"))
    if not len(images) == len(masks) == len(captions):
        raise ValueError(
            f"mismatched dataset in {data_dir}: {len(images)} images, "
            f"{len(masks)} masks, {len(captions)} captions"
        )

    for curr in range(1, len(images)):
        prev = curr - 1
        # Explicit encoding so captions don't depend on the locale; strip the trailing newline.
        caption = captions[curr].read_text(encoding="utf-8").strip()
        # Only the size is needed; the context manager closes the file handle.
        with Image.open(images[curr]) as image:
            w, h = image.size
        scale = min(max_size / w, max_size / h)
        if scale < 1:
            w = int(w * scale)
            h = int(h * scale)

        prompt = copy.copy(PROMPT_TEMPLATE)
        prompt.update(
            {
                "prompt": caption,
                "width": w,
                "height": h,
                "image": str(images[curr]),
                "mask": str(masks[curr]),
                "prev_image": str(images[prev]),
                "prev_mask": str(masks[prev]),
            }
        )
        yield prompt

In [None]:
# Materialise every frame-pair prompt up front; the sampling loop indexes into this list.
prompts = [*create_prompts("/home/longc/data/code/lora-scripts/config/video-example-2")]

In [None]:
# Sampling-time overrides applied on top of the options read from the TOML config.
args.sample_sampler = "euler_a"
args.output_dir = "./output/video-debug-iterate7"

In [None]:
from tqdm import tqdm


# When True, each frame's "prev_image" is replaced by the image generated for the
# previous prompt instead of the ground-truth previous frame. This mutates `prompts`,
# so re-run the prompt-building cell before re-running this one.
CHAIN_PREV_FRAME = False

# The pipeline factory does not change between iterations, so build it once.
pipeline_factory = partial(VideoInpaintingPatchPipeline, inpainting_head=inpainting_head)

for i in tqdm(range(len(prompts))):
    # Sample one prompt at a time so its output can optionally feed the next prompt.
    generated_files = train_util.sample_images_common(
        pipeline_factory,
        accelerator,
        args,
        epoch=0,
        steps=0,
        device=device,
        vae=vae,
        tokenizer=tokenizers,
        text_encoder=text_encoders,
        unet=unet,
        prompts_data=prompts[i : i + 1],
        verbose=False,
    )
    # generated_files is presumably the list of image paths written for this prompt.
    if CHAIN_PREV_FRAME and generated_files and i + 1 < len(prompts):
        prompts[i + 1]["prev_image"] = generated_files[0]