In [1]:
#!/usr/bin/env python3
import json
import random
from pathlib import Path
from types import SimpleNamespace

import numpy as np
import torch
import yaml
from diffusers import KandinskyV22InpaintPipeline, AutoPipelineForInpainting, KandinskyV22PriorPipeline
from PIL import Image, ImageDraw, ImageOps

from dataset import SVHNFullBBox


CONFIG_DEFAULT = {
    "data_root": "../data/svhn_full/train_extracted",
    "split": "train",
    "out_dir": "outputs/svhn_edits",
    "num_samples": 10,
    "seed": 42,
    "mask_pad": 2,
    "prompt_template": "replace the existing digits in the image with: {new_digits}",
    "model_id": "kandinsky-community/kandinsky-2-2-decoder-inpaint",
    "steps": 30,
    "guidance": 7.5,
}

def load_config(path=None, overrides=None):
    cfg = dict(CONFIG_DEFAULT)
    if path:
        with open(path, "r") as f:
            if path.endswith((".yml", ".yaml")):
                cfg.update(yaml.safe_load(f) or {})
            else:
                cfg.update(json.load(f))
    if overrides:
        cfg.update(overrides)
    return SimpleNamespace(**cfg)


def make_mask(size, boxes, pad=2):
    mask = Image.new("L", size, 0)
    draw = ImageDraw.Draw(mask)
    for x1, y1, x2, y2 in boxes:
        draw.rectangle((x1 - pad, y1 - pad, x2 + pad, y2 + pad), fill=255)
    return mask


def pad_to_multiple_of_8(img):
    w, h = img.size
    new_w = (w + 7) // 8 * 8
    new_h = (h + 7) // 8 * 8
    if new_w == w and new_h == h:
        return img, (0, 0)
    fill = 0 if img.mode != "RGB" else (0, 0, 0)
    padded = ImageOps.expand(img, border=(0, 0, new_w - w, new_h - h), fill=fill)
    return padded, (new_w - w, new_h - h)


def sample_new_labels(labels, rng):
    return [rng.choice([d for d in range(10) if d != l]) for l in labels]


def run(cfg):
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32

    rng = random.Random(cfg.seed)
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    ds = SVHNFullBBox(root=cfg.data_root, split=cfg.split)
    chosen = rng.sample(range(len(ds)), cfg.num_samples)

    pipe_prior = KandinskyV22PriorPipeline.from_pretrained(
        "kandinsky-community/kandinsky-2-2-prior", torch_dtype=dtype
    ).to(device)
    pipe = KandinskyV22InpaintPipeline.from_pretrained(
        "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=dtype
    ).to(device)
    pipe.safety_checker = None  # optional; keep feature_extractor default

    out_dir = Path(cfg.out_dir); out_dir.mkdir(parents=True, exist_ok=True)
    meta = []

    for i, idx in enumerate(chosen):
        img, target = ds[idx]
        boxes = target["boxes"].tolist()
        labels = target["labels"].tolist()
        new_labels = [rng.choice([d for d in range(10) if d != l]) for l in labels]
        prompt = cfg.prompt_template.format(new_digits="".join(map(str, new_labels)))

        mask = make_mask(img.size, boxes, pad=cfg.mask_pad)
        img_pad, pad_wh = pad_to_multiple_of_8(img)
        mask_pad, _ = pad_to_multiple_of_8(mask)

        gen = torch.Generator(device=device).manual_seed(cfg.seed + i)
        prior_out = pipe_prior(prompt=prompt, num_inference_steps=50, generator=gen)
        edited = pipe(
            prompt=prompt,
            image=img_pad,
            mask_image=mask_pad,
            image_embeds=prior_out.image_embeds,
            negative_image_embeds=prior_out.negative_image_embeds,
            generator=gen,
            num_inference_steps=cfg.steps,
            guidance_scale=cfg.guidance,
        ).images[0]

        if pad_wh != (0, 0):
            w, h = img.size
            edited = edited.crop((0, 0, w, h))
        edited_composite = Image.composite(edited, img, mask)

        base = Path(ds.records[idx]["name"]).stem
        img.save(out_dir / f"{base}_orig.png")
        edited_composite.save(out_dir / f"{base}_edit.png")
        meta.append({...})
    json.dump(meta, open(out_dir / "metadata.json", "w"), indent=2)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
cfg = load_config()
run(cfg)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]`torch_dtype` is deprecated! Use `dtype` instead!
Loading pipeline components...: 100%|██████████| 6/6 [00:00<00:00,  6.85it/s]
Loading pipeline components...: 100%|██████████| 3/3 [00:00<00:00,  5.91it/s]
100%|██████████| 50/50 [00:00<00:00, 70.68it/s]
100%|██████████| 30/30 [00:01<00:00, 18.28it/s]
100%|██████████| 50/50 [00:00<00:00, 79.36it/s]
100%|██████████| 30/30 [00:01<00:00, 19.79it/s]
100%|██████████| 50/50 [00:00<00:00, 78.71it/s]
100%|██████████| 30/30 [00:01<00:00, 19.77it/s]
100%|██████████| 50/50 [00:00<00:00, 79.65it/s]
100%|██████████| 30/30 [00:01<00:00, 19.72it/s]
100%|██████████| 50/50 [00:00<00:00, 79.91it/s]
100%|██████████| 30/30 [00:01<00:00, 19.76it/s]
100%|██████████| 50/50 [00:00<00:00, 79.95it/s]
100%|██████████| 30/30 [00:01<00:00, 19.83it/s]
100%|██████████| 50/50 [00:00<00:00, 79.55it/s]
100%|██████████| 30/30 [00:01<00:00, 19.88it/s]
100%|██████████| 50/50 [00:00<00:00, 79.17it/s]
100%|

TypeError: Object of type set is not JSON serializable