In [1]:
%load_ext autoreload
%autoreload 2

In [10]:
import os
import sys
sys.path.append('/home/cizinsky/garment-texture-completion')
CKPT_ROOT = '/scratch/izar/cizinsky/garment-completion/checkpoints'

from matplotlib import pyplot as plt
from ipywidgets import interact, IntSlider
import pytorch_lightning as pl

import torch
from torchvision.transforms.functional import pil_to_tensor
from torchvision.utils import make_grid

from helpers.pl_module import GarmentInpainterModule
from helpers.dataset import get_dataloaders
from helpers.data_utils import denormalise_image_torch
from helpers.data_utils import torch_image_to_pil, denormalise_image_torch
from helpers.metrics import compute_all_metrics

from tqdm import tqdm

import pandas as pd

In [3]:
pl.seed_everything(42);

Seed set to 42


### Inspection of trained models

---


In [4]:
!ls /scratch/izar/cizinsky/garment-completion/checkpoints/

bright-universe-112  fast-universe-159	    rose-sun-129
cerulean-smoke-152   fearless-jazz-157	    sandy-capybara-119
cerulean-star-109    fluent-aardvark-116    soft-dust-146
cosmic-cosmos-123    glamorous-thunder-121  stellar-feather-136
crisp-cosmos-153     glorious-oath-130	    stellar-pond-132
curious-oath-144     glowing-disco-164	    swift-lake-110
devout-wind-122      graceful-cloud-156     twilight-sponge-160
eager-energy-111     grateful-terrain-163   unique-water-133
exalted-forest-117   honest-bird-113	    valiant-resonance-120
faithful-dew-158     lilac-hill-102	    vivid-smoke-128
fallen-dew-140	     morning-microwave-118  worthy-paper-139
fallen-shape-155     peach-firebrand-106


In [5]:
run_name = "fast-universe-159"
checkpoint_path = f"{CKPT_ROOT}/{run_name}/last.ckpt"
os.path.exists(checkpoint_path)

True

In [6]:
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
cfg = checkpoint["hyper_parameters"]
cfg.data.num_workers = 10
cfg.data.batch_size = 2
trn_dataloader, val_dataloader = get_dataloaders(cfg)

In [7]:
model = GarmentInpainterModule(cfg, trn_dataloader)
model.setup()
model.load_state_dict(checkpoint["state_dict"])
model.eval().cuda()
print("✅ Model loaded!")

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

✅ Model loaded!


In [8]:
max_n_batches = 5
val_batches = [next(iter(val_dataloader)) for _ in range(max_n_batches)]

In [9]:
image_guidance_scale = [2.5, 5.0] # [2.5, 5.0, 7.5, 10.0]
text_guidance_scale = [2.5, 5.0] # [2.5, 5.0, 7.5, 10.0]

rows = []
for img_scale in image_guidance_scale:
    for text_scale in text_guidance_scale:
        sample_idx = 0
        for batch in tqdm(val_batches, desc=f"img_scale={img_scale}, text_scale={text_scale}"):
            pred_imgs = model.inference(batch["partial_diffuse_img"].to("cuda"), num_inference_steps=50, guidance_scale=text_scale, image_guidance_scale=img_scale)
            pred_imgs_tensors = torch.stack([pil_to_tensor(img) for img in pred_imgs]).to("cuda") / 255.0
            target_imgs = denormalise_image_torch(batch["full_diffuse_img"].to("cuda"))
            image_metrics = compute_all_metrics(pred_imgs_tensors, target_imgs)

            for i in range(len(pred_imgs)):
                ith_ssim = image_metrics["ssim"][i]
                ith_psnr = image_metrics["psnr"][i]
                ith_lpips = image_metrics["lpips"][i]
                rows.append({
                    "img_scale": img_scale,
                    "text_scale": text_scale,
                    "ssim": ith_ssim.item(),
                    "psnr": ith_psnr.item(),
                    "lpips": ith_lpips.item(),
                    "sample_idx": sample_idx,
                })
                sample_idx += 1

img_scale=2.5, text_scale=2.5:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=2.5, text_scale=2.5:  20%|██        | 1/5 [00:21<01:27, 21.94s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=2.5, text_scale=2.5:  40%|████      | 2/5 [00:42<01:03, 21.19s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=2.5, text_scale=2.5:  60%|██████    | 3/5 [01:03<00:41, 20.98s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=2.5, text_scale=2.5:  80%|████████  | 4/5 [01:24<00:20, 20.90s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=2.5, text_scale=2.5: 100%|██████████| 5/5 [01:44<00:00, 20.98s/it]
img_scale=2.5, text_scale=5.0:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=2.5, text_scale=5.0:  20%|██        | 1/5 [00:20<01:23, 20.79s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=2.5, text_scale=5.0:  40%|████      | 2/5 [00:41<01:02, 20.79s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=2.5, text_scale=5.0:  60%|██████    | 3/5 [01:02<00:41, 20.80s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=2.5, text_scale=5.0:  80%|████████  | 4/5 [01:23<00:20, 20.81s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=2.5, text_scale=5.0: 100%|██████████| 5/5 [01:44<00:00, 20.81s/it]
img_scale=5.0, text_scale=2.5:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=5.0, text_scale=2.5:  20%|██        | 1/5 [00:20<01:23, 20.82s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=5.0, text_scale=2.5:  40%|████      | 2/5 [00:41<01:02, 20.84s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=5.0, text_scale=2.5:  60%|██████    | 3/5 [01:02<00:41, 20.84s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=5.0, text_scale=2.5:  80%|████████  | 4/5 [01:23<00:20, 20.84s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=5.0, text_scale=2.5: 100%|██████████| 5/5 [01:44<00:00, 20.84s/it]
img_scale=5.0, text_scale=5.0:   0%|          | 0/5 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=5.0, text_scale=5.0:  20%|██        | 1/5 [00:20<01:23, 20.83s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=5.0, text_scale=5.0:  40%|████      | 2/5 [00:41<01:02, 20.84s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=5.0, text_scale=5.0:  60%|██████    | 3/5 [01:02<00:41, 20.83s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=5.0, text_scale=5.0:  80%|████████  | 4/5 [01:23<00:20, 20.84s/it]

  0%|          | 0/50 [00:00<?, ?it/s]

img_scale=5.0, text_scale=5.0: 100%|██████████| 5/5 [01:44<00:00, 20.84s/it]


In [24]:
df = pd.DataFrame(rows, columns=["sample_idx", "img_scale", "text_scale", "ssim", "psnr", "lpips"])
df

# group by img_scale and text_scale and compute mean of ssim, psnr, lpips + do not show sample_idx
df.groupby(["img_scale", "text_scale"]).mean()[["ssim", "psnr", "lpips"]]

Unnamed: 0_level_0,Unnamed: 1_level_0,ssim,psnr,lpips
img_scale,text_scale,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
2.5,2.5,0.070754,12.672301,0.713807
2.5,5.0,0.107561,13.7336,0.717837
5.0,2.5,0.125872,13.31328,0.819344
5.0,5.0,0.125178,14.769055,0.696974


In [25]:
import wandb

In [26]:
run = wandb.init(entity="ludekcizinsky", project="pbr-generation", id="w5daifhx", resume="must")

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mludekcizinsky[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [27]:
wandb_table = wandb.Table(dataframe=df)
run.log({"df_test": wandb_table})

In [28]:
run.finish()

0,1
epoch,0.0
optim/grad_norm_postclip,0.54837
optim/grad_norm_preclip,0.54837
optim/lr,1e-05
train/ddim_loss,0.25987
train/loss,0.20155
train/mse_loss,0.07161
trainer/global_step,17959.0
val/ddim_loss,0.20606
val/loss,0.21042


In [8]:
batch = next(iter(val_dataloader))

In [None]:
n = 5

In [None]:
# reconstructed_imgs = model.inference(batch["partial_diffuse_img"][:n].to("cuda"), num_inference_steps=50, guidance_scale=7.5, image_guidance_scale=1.5)
partial_img = batch["partial_diffuse_img"][2].unsqueeze(0).to("cuda")
image_guidance_scale = [2.5, 5.0, 7.5, 10.0]
text_guidance_scale = [2.5, 5.0, 7.5, 10.0]
results = []
for img_scale in image_guidance_scale:
    row_results = []
    for t_scale in text_guidance_scale:
        reconstructed_imgs = model.inference(partial_img, num_inference_steps=50, guidance_scale=t_scale, image_guidance_scale=img_scale)
        row_results.extend(reconstructed_imgs)
    results.append(row_results)


In [None]:
fig, axs = plt.subplots(
    len(image_guidance_scale),
    len(text_guidance_scale),
    figsize=(5, 5),
    tight_layout=True
)

for i, img_scale in enumerate(image_guidance_scale):
    for j, t_scale in enumerate(text_guidance_scale):
        ax = axs[i, j]
        ax.imshow(results[i][j])
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_edgecolor('white')

        # keep your per-row / per-col numeric labels if you want
        if j == 0:
            ax.set_ylabel(f"{img_scale}", rotation=0, labelpad=10, va='center')
        if i == 0:
            ax.set_title(f"{t_scale}")

# now add the “global” labels
fig.supxlabel("Text guidance scale", fontsize=12)
fig.supylabel("Image guidance scale", fontsize=12)

plt.show()

In [20]:
cond_images = [torch_image_to_pil(img) for img in denormalise_image_torch(batch["partial_diffuse_img"][:n])]
target_images = [torch_image_to_pil(img) for img in denormalise_image_torch(batch["full_diffuse_img"][:n])]

In [None]:
def plot_images(index):
    fig, axs = plt.subplots(1, 3, figsize=(10, 5))
    axs[0].imshow(cond_images[2])
    axs[0].set_title("Condition")
    axs[0].axis("off")
    axs[1].imshow(results[index])
    img_guidance = image_guidance_scale[index]
    axs[1].set_title(f"Predicted, IMG_GUIDE={img_guidance}")
    axs[1].axis("off")
    axs[2].imshow(target_images[2])
    axs[2].set_title("Target")
    axs[2].axis("off")
    plt.show()

interact(plot_images, index=IntSlider(min=0, max=len(results)-1, step=1));