# Qualitative Inpainting Demo

Runs both vanilla and improved DDPM inpainting on your own images and displays results side-by-side.

**Setup:** Drop files into the directory set by `MY_IMAGES_DIR` in Cell 1 (currently `our_images/`), following the naming convention:
```
our_images/
  photo.png
  photo_mask.png    ← white = inpaint, black = keep
  cat.jpg
  cat_mask.png
```

In [None]:
# Cell 1 — Imports & config
import sys, os, warnings
import torch
import matplotlib.pyplot as plt
warnings.filterwarnings("ignore", message="IProgress not found")
from tqdm.auto import tqdm
from utils.image import apply_mask_for_display
from utils.cli import preprocess_inputs, load_sd_pipeline

# Run configuration read by the cells below.
STEPS = 50  # passed to both inpainting calls in Cell 5
GUIDANCE_SCALE = 7.5  # passed to both inpainting calls in Cell 5
SEED = 42  # the same seed value is handed to both methods
RESAMPLE_STEPS = 5  # extra argument given only to ddpm_inpaint_improved
MY_IMAGES_DIR = './our_images'  # directory scanned for image/mask pairs in Cell 2

In [None]:
# Cell 2 — Discover image/mask pairs
# An image `<stem><ext>` is paired with the first existing `<stem>_mask<ext>`
# file, trying the extensions in IMAGE_EXTS order. Files whose stem ends in
# `_mask` are never treated as source images. Images that have no mask are
# reported at the end instead of being dropped silently.
IMAGE_EXTS = ('.png', '.jpg', '.jpeg')

pairs = []      # list of (image_path, mask_path, stem)
unmatched = []  # image file names for which no mask file exists

for f in sorted(os.listdir(MY_IMAGES_DIR)):
    stem, ext = os.path.splitext(f)
    if stem.endswith('_mask') or ext.lower() not in IMAGE_EXTS:
        continue
    candidates = (
        os.path.join(MY_IMAGES_DIR, stem + '_mask' + mask_ext)
        for mask_ext in IMAGE_EXTS
    )
    # First candidate that exists on disk wins; None when there is no mask.
    mask_file = next((c for c in candidates if os.path.exists(c)), None)
    if mask_file is None:
        unmatched.append(f)
        continue
    pairs.append((os.path.join(MY_IMAGES_DIR, f), mask_file, stem))

print(f'Found {len(pairs)} image/mask pair(s):')
for _, _, stem in pairs:
    print(f'  {stem}')

if unmatched:
    # Tell the user why some images will not be processed.
    print(f'Skipped {len(unmatched)} image(s) with no matching "_mask" file:')
    for name in unmatched:
        print(f'  {name}')

In [None]:
# Cell 3 — Load model
# Put the current working directory at the front of the import path before
# importing the two inpainting entry points.
project_root = os.path.abspath('.')
sys.path.insert(0, project_root)
from vanilla_inpaint import ddpm_inpaint
from improved_inpaint import ddpm_inpaint_improved

# Use the MPS backend when PyTorch reports it as available, otherwise CPU.
# NOTE(review): CUDA is never selected here — confirm that is intended.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f'Using device: {device}')

pipe = load_sd_pipeline(device)
# Disable the pipeline's own progress bar; Cell 5 wraps its loop in tqdm.
pipe.set_progress_bar_config(disable=True)
print('Model loaded.')

In [None]:
# Cell 4 — Interactive prompt collection
# Shows each image next to its masked preview, then reads one prompt per pair
# from the user. Answers are stored in `prompts`, keyed by the file stem.

prompts = {}

for img_path, mask_path, stem in pairs:
    image_preview, mask_preview = preprocess_inputs(img_path, mask_path)
    panels = (
        (image_preview, 'Original'),
        (apply_mask_for_display(image_preview, mask_preview), 'Masked (region to inpaint)'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    for ax, (picture, title) in zip(axes, panels):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')
    fig.suptitle(stem, fontsize=12)
    fig.tight_layout()
    plt.show()

    # Whitespace-only input is stored as the empty string.
    reply = input(f'Prompt for "{stem}" (leave blank for unconditional): ').strip()
    prompts[stem] = reply
    print(f'  -> "{reply}"\n')

In [None]:
# Cell 5 — Run inpainting & display side-by-side
# For every pair: call both inpainting functions with the same settings, then
# draw one four-panel row (original, masked input, vanilla, improved).
for img_path, mask_path, stem in tqdm(pairs, desc='Inpainting'):
    # Falls back to an empty prompt when Cell 4 stored nothing for this stem.
    prompt = prompts.get(stem, '')
    image, mask = preprocess_inputs(img_path, mask_path)

    # Both methods receive the same leading positional arguments.
    shared_args = (pipe, image, mask, prompt, STEPS, GUIDANCE_SCALE, SEED)
    van = ddpm_inpaint(*shared_args)
    imp = ddpm_inpaint_improved(*shared_args, RESAMPLE_STEPS)

    columns = (
        ('Original', image),
        ('Masked', apply_mask_for_display(image, mask)),
        ('Vanilla', van),
        ('Improved', imp),
    )
    fig, axes = plt.subplots(1, len(columns), figsize=(20, 5))
    for ax, (title, picture) in zip(axes, columns):
        ax.imshow(picture)
        ax.set_title(title)
        ax.axis('off')
    fig.suptitle(f'{stem}  |  prompt: "{prompt}"', fontsize=11, y=1.01)
    fig.tight_layout()
    plt.show()