In [1]:
%load_ext autoreload
%autoreload 2

In [1]:
import os
import sys
sys.path.append('/home/cizinsky/garment-texture-completion')
CKPT_ROOT = '/scratch/izar/cizinsky/garment-completion/checkpoints'

from matplotlib import pyplot as plt
from ipywidgets import interact, IntSlider
import pytorch_lightning as pl

import torch

from helpers.pl_module import GarmentInpainterModule
from helpers.dataset import get_dataloaders
from helpers.data_utils import denormalise_image_torch
from helpers.data_utils import torch_image_to_pil, denormalise_image_torch

In [3]:
pl.seed_everything(42);

Seed set to 42


### Inspection of trained models

---


In [4]:
!ls /scratch/izar/cizinsky/garment-completion/checkpoints/

lilac-hill-102


In [5]:
run_name = "lilac-hill-102"
checkpoint_path = f"{CKPT_ROOT}/{run_name}/last.ckpt"
os.path.exists(checkpoint_path)

True

In [6]:
checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
cfg = checkpoint["hyper_parameters"]
trn_dataloader, val_dataloader = get_dataloaders(cfg)



In [7]:
model = GarmentInpainterModule(cfg, trn_dataloader).to(torch.float16)
model.setup()
model.load_state_dict(checkpoint["state_dict"])
model.eval().cuda()
print("✅ Model loaded!")

✅ Model loaded!


In [8]:
batch = next(iter(val_dataloader))

In [29]:
reconstructed_imgs = model.inference(batch["partial_diffuse_img"].to("cuda"), strength=1.0, num_inference_steps=50, guidance_scale=5.0)
cond_images = [torch_image_to_pil(img) for img in denormalise_image_torch(batch["partial_diffuse_img"])]
pred_images = [torch_image_to_pil(img) for img in denormalise_image_torch(reconstructed_imgs)]
target_images = [torch_image_to_pil(img) for img in denormalise_image_torch(batch["full_diffuse_img"])]

Denoising loop during inference: 100%|██████████| 50/50 [00:39<00:00,  1.28it/s]


In [30]:
def plot_images(index):
    fig, axs = plt.subplots(1, 3, figsize=(10, 5))
    axs[0].imshow(cond_images[index])
    axs[0].set_title("Condition")
    axs[0].axis("off")
    axs[1].imshow(pred_images[index])
    axs[1].set_title("Predicted")
    axs[1].axis("off")
    axs[2].imshow(target_images[index])
    axs[2].set_title("Target")
    axs[2].axis("off")
    plt.show()

interact(plot_images, index=IntSlider(min=0, max=len(pred_images)-1, step=1));

interactive(children=(IntSlider(value=0, description='index', max=19), Output()), _dom_classes=('widget-intera…