In [None]:
import os
from pathlib import Path
import torch
from torch.utils.data import DataLoader
from torchvision.transforms.functional import pil_to_tensor
from huggingface_hub import hf_hub_download
from PIL import Image, ImageShow
import PIL
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from diffusers.utils import make_image_grid
from IPython.core.debugger import Pdb
import cv2
import numpy as np
from src.models.attention_processor import SkipAttnProcessor
from src.pipelines.spacat_pipeline import TryOnPipeline
from src.dataset.vitonhd import VITONHDDataset
from src.utils import get_project_root, show, init_attn_processor

In [None]:
def merge(
    img1: PIL.Image.Image,
    img2: PIL.Image.Image
) -> PIL.Image.Image:
    """Place two images of equal height side by side on a new RGB canvas.

    Args:
        img1: Image pasted on the left.
        img2: Image pasted immediately to the right of ``img1``.

    Returns:
        A new RGB image whose width is the sum of both input widths.

    Raises:
        AssertionError: If the two images do not have the same height.
    """
    left_w, left_h = img1.size
    right_w, right_h = img2.size
    assert left_h == right_h

    canvas = Image.new('RGB', (left_w + right_w, left_h))
    # Each tile goes at the x offset where the previous one ended.
    for tile, x_offset in ((img1, 0), (img2, left_w)):
        canvas.paste(tile, (x_offset, 0))
    return canvas

In [None]:
""" Download models from Huggingface Hub
"""
PROJECT_ROOT_PATH = get_project_root()
repo_id = 'bui/Navier-1'
base_folder = 'ckpt-75000'
# Every file is mirrored under this local directory, keeping the repo layout.
checkpoint_dir = os.path.join(PROJECT_ROOT_PATH, 'checkpoints', 'navier-1')

# (component folder inside the checkpoint, file name); None = checkpoint root.
checkpoint_files = [
    ('unet', 'diffusion_pytorch_model.safetensors'),
    ('unet', 'config.json'),
    ('vae', 'diffusion_pytorch_model.safetensors'),
    ('vae', 'config.json'),
    ('scheduler', 'scheduler_config.json'),
    (None, 'model_index.json'),
]

downloaded_paths = {}
for component, filename in checkpoint_files:
    # Hub paths are always '/'-separated; os.path.join would emit '\\' on
    # Windows and produce an invalid repo path, so join explicitly.
    repo_subfolder = base_folder if component is None else f'{base_folder}/{component}'
    downloaded_paths[(component, filename)] = hf_hub_download(
        repo_id=repo_id,
        subfolder=repo_subfolder,
        filename=filename,
        local_dir=checkpoint_dir,
    )

unet_path = downloaded_paths[('unet', 'diffusion_pytorch_model.safetensors')]

# <checkpoint_dir>/<base_folder>/unet/<file> -> <checkpoint_dir>/<base_folder>
model_root_path = os.path.dirname(os.path.dirname(unet_path))
model_root_path

In [None]:
# Pick the device first so the weight precision can follow it: half precision
# is only used on CUDA; on a CPU fallback the weights are loaded in float32,
# because float16 inference on CPU is unsupported or extremely slow.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
weight_dtype = torch.float16 if device == 'cuda' else torch.float32

vae = AutoencoderKL.from_pretrained(
    model_root_path,
    subfolder='vae',
    torch_dtype=weight_dtype
)

unet = UNet2DConditionModel.from_pretrained(
    model_root_path,
    subfolder='unet',
    torch_dtype=weight_dtype
)
# Swap the UNet's cross-attention processors for SkipAttnProcessor.
init_attn_processor(unet, cross_attn_cls=SkipAttnProcessor)

scheduler = DDPMScheduler.from_pretrained(
    model_root_path,
    subfolder='scheduler'
)

In [None]:
bs = 8   # batch size; keep it a multiple of 4 because the result grid has 4 columns
f = 30   # offset folded into the seed below ("for comparison")

# VITON-HD test split at the 512x384 resolution used for inference.
dataset_kwargs = dict(
    data_rootpath=os.path.join(PROJECT_ROOT_PATH, 'datasets', 'vitonhd'),
    use_trainset=False,
    height=512,
    width=384,
    use_CLIPVision=True,
)
test_dataset = VITONHDDataset(**dataset_kwargs)

# torch.manual_seed seeds the global RNG and returns the default generator;
# the same generator object drives the shuffle order here.
generator = torch.manual_seed(1996 + bs * f)

loader_kwargs = dict(
    batch_size=bs,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    generator=generator,
)
test_dataloader = DataLoader(test_dataset, **loader_kwargs)

In [None]:
# Assemble the try-on pipeline from the loaded components, then move it to
# the selected device.
pipe = TryOnPipeline(unet=unet, vae=vae, scheduler=scheduler)
pipe = pipe.to(device)

In [None]:
""" Generate try-on image and concat with its original one (for qualitative comparison)
then build the image grid of all of generated pairs and save it to disk.
"""
h, w = 512, 384
use_poisson_blending = False
# NOTE: display_mask is only honoured when use_poisson_blending is True;
# without blending the tiles are always (generated | original).
display_mask = True
img_list = []
max_run = 5  # number of batches to run through the pipeline

with torch.inference_mode(), torch.amp.autocast(device):
    for idx, batch in enumerate(test_dataloader):
        if idx >= max_run:
            # Stop here instead of letting the loader fetch (and discard)
            # every remaining batch of the test set.
            break

        images = pipe(
            image=batch['image'].to(device),
            mask_image=batch['mask'].to(device),
            densepose_image=batch['densepose'].to(device),
            cloth_image=batch['cloth_raw'].to(device),
            height=h,
            width=w,
            generator=generator,
        ).images

        sample_iter = zip(images, batch['original_image_path'], batch['original_mask_path'])
        for img, origin_img_path, mask_path in sample_iter:
            # Context managers close the underlying files deterministically.
            with Image.open(origin_img_path) as raw_origin:
                origin_img = raw_origin.resize((w, h))

            if not use_poisson_blending:
                img_list.append(merge(img, origin_img))
                continue

            with Image.open(mask_path) as raw_mask:
                mask = raw_mask.convert('L').resize((w, h))
            # Invert the mask so that non-zero pixels select the region copied
            # from the original photo into the generated image.
            np_mask = 255 - np.array(mask)

            # seamlessClone centres the mask's bounding box on the given point,
            # so pass that box's own centre to keep the region in place
            # (equals the image centre when the box spans the whole image).
            box_x, box_y, box_w, box_h = cv2.boundingRect(np_mask)
            clone_center = (box_x + box_w // 2, box_y + box_h // 2)

            blended_img = cv2.seamlessClone(
                np.array(origin_img),
                np.array(img),
                np_mask,
                clone_center,
                cv2.NORMAL_CLONE,
            )
            merged_img = merge(Image.fromarray(blended_img), origin_img)
            if display_mask:
                merged_img = merge(merged_img, mask)
            img_list.append(merged_img)

# Drop the pipeline reference and release cached GPU memory. Re-running this
# cell afterwards requires re-running the cell that builds `pipe`.
del pipe
torch.cuda.empty_cache()

In [None]:
# Arrange the comparison tiles on a 4-column grid. The row count is derived
# from the number of tiles actually produced (not from bs * max_run), so a
# shorter run or a short final batch cannot break make_image_grid's
# rows * cols == len(images) requirement.
grid_cols = 4
if not img_list:
    raise ValueError('img_list is empty; run the generation cell first')

# Pad the last row with blank tiles when the count is not a multiple of 4.
n_padding = -len(img_list) % grid_cols
grid_tiles = img_list + [Image.new('RGB', img_list[0].size)] * n_padding

grid = make_image_grid(
    images=grid_tiles,
    rows=len(grid_tiles) // grid_cols,
    cols=grid_cols
)
grid

In [None]:
# save_path = os.path.join(PROJECT_ROOT_PATH, 'results', 'navier-1', 'beta', base_folder)
# os.makedirs(save_path, exist_ok=True)
# fname = f'bs{bs}-f{f}-poisson.png' if use_poisson_blending else f'bs{bs}-f{f}.png'
# grid.save(Path(save_path, fname))

In [None]:
# single_batch_test = next(iter(test_dataloader))
# idx = 0
# show(torch.cat([
#         single_batch_test['original_image'][idx],
#         single_batch_test['original_mask'][idx],
#         single_batch_test['cloth_raw'][idx],
#         single_batch_test['original_densepose'][idx]
#     ], dim=-1), title=single_batch_test['im_name'][idx]
# )

In [None]:
# print(single_batch_test['image'].shape)
# print(single_batch_test['mask'].shape)
# print(single_batch_test['densepose'].shape)
# print(single_batch_test['cloth_raw'].shape)

In [None]:
# h, w = 512, 384
# images = pipe(
#     image=single_batch_test['image'].to(device),
#     mask_image=single_batch_test['mask'].to(device),
#     densepose_image=single_batch_test['densepose'].to(device),
#     cloth_image=single_batch_test['cloth_raw'].to(device),
#     height=h,
#     width=w
# ).images