# Vanilla vs Improved DDPM Inpainting — Comparison

This notebook loads pre-computed results from both evaluation runs and
produces comparative statistics and visualizations.

**Pre-condition:** Both `eval_results/` and `eval_results_improved/` must
already exist with their `inpainted/`, `originals/`, and `masks/` subdirectories
(populated by running `Vanilla-Inpainting-Demo.ipynb` and
`Improved-Inpainting-Demo.ipynb` respectively).

**Outputs:**
- `comparison_kde.png` — overlaid KDE for SSIM / PSNR / LPIPS
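- `comparison_best5.png` / `comparison_worst5.png` — side-by-side grids (Original | Vanilla | Improved | Mask) of the 5 images where the improved run gains the most / loses the most SSIM relative to vanilla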


In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from utils import (
    load_data, run_metrics,
    print_stats_table, plot_kde_overlay, apply_mask_for_display,
)

# Root directories of the two evaluation runs being compared.
VANILLA_DIR = './eval_results'
IMPROVED_DIR = './eval_results_improved'

# Run on the MPS backend when PyTorch reports it as available, else on CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Determine how many results are available in each directory
def _count_results(d):
    """Return how many ``.png`` files sit in the ``inpainted`` subdirectory of *d*.

    A missing ``inpainted`` directory counts as zero results rather than
    raising.
    """
    inpainted_dir = os.path.join(d, 'inpainted')
    if os.path.isdir(inpainted_dir):
        return sum(
            1 for name in os.listdir(inpainted_dir) if name.endswith('.png')
        )
    return 0

# Count what each run produced; only as many images as the smaller run has
# are compared.
n_van, n_imp = (_count_results(root) for root in (VANILLA_DIR, IMPROVED_DIR))
N = min(n_van, n_imp)

print(
    f'Vanilla results:  {n_van} images',
    f'Improved results: {n_imp} images',
    f'Comparing first {N} images (smaller of the two)',
    sep='\n',
)
assert N > 0, 'No results found — run both evaluation notebooks first.'


In [None]:
def _evaluate_run(root_dir):
    """Call ``run_metrics`` on the inpainted/masks/originals folders under *root_dir*."""
    return run_metrics(
        inpainted_dir=os.path.join(root_dir, 'inpainted'),
        masks_dir=os.path.join(root_dir, 'masks'),
        originals_dir=os.path.join(root_dir, 'originals'),
        n_images=N,
        device=device,
    )


# Both runs are scored with the same image budget (N) and device.
results_vanilla = _evaluate_run(VANILLA_DIR)
results_improved = _evaluate_run(IMPROVED_DIR)

print(f'Vanilla  metrics: {len(results_vanilla)} images')
print(f'Improved metrics: {len(results_improved)} images')


In [None]:
# Print one summary table per run, vanilla first.
for _results, _label in (
    (results_vanilla, 'Vanilla DDPM'),
    (results_improved, 'Improved (RePaint)'),
):
    print_stats_table(_results, label=_label)


In [None]:
# Overlay both runs' metric distributions in one figure written to disk.
_kde_options = {
    'label_a': 'Vanilla',
    'label_b': 'Improved',
    'out_path': 'comparison_kde.png',
}
plot_kde_overlay(results_vanilla, results_improved, **_kde_options)


In [None]:
# ---- Per-image SSIM delta (improved - vanilla) over indices present in both runs ----
van_by_idx = dict((rec['idx'], rec) for rec in results_vanilla)
imp_by_idx = dict((rec['idx'], rec) for rec in results_improved)
common_idx = sorted(van_by_idx.keys() & imp_by_idx.keys())


def _delta_entry(i):
    """Pair the vanilla and improved records for index *i* with their SSIM difference."""
    van_rec = van_by_idx[i]
    imp_rec = imp_by_idx[i]
    return {
        'idx':   i,
        'delta': imp_rec['ssim'] - van_rec['ssim'],
        'van':   van_rec,
        'imp':   imp_rec,
    }


deltas = list(map(_delta_entry, common_idx))

# Ascending by delta: the head is where improved loses most, the tail where it
# gains most.
deltas_sorted = sorted(deltas, key=lambda d: d['delta'])

worst5 = deltas_sorted[:5]          # improved most WORSE than vanilla
best5 = deltas_sorted[-5:][::-1]    # improved most BETTER than vanilla, largest gain first


def _show_comparison(entries, title, out_path):
    """Plot, save and show a comparison grid: Original | Vanilla | Improved | Mask.

    One row is drawn per entry, so fewer (or more) than five entries produce a
    correspondingly sized figure instead of blank rows or an IndexError.

    Args:
        entries: Sequence of dicts with keys ``'idx'``, ``'delta'``, ``'van'``
            and ``'imp'``. ``entry['van']`` must provide ``'original'``,
            ``'inpainted'``, ``'mask'`` and ``'ssim'``; ``entry['imp']`` must
            provide ``'inpainted'`` and ``'ssim'``. Images are handed to
            ``imshow`` unchanged; the mask must support ``.squeeze()``.
        title: Figure-level title.
        out_path: File path the figure is written to.

    Returns:
        None. When ``entries`` is empty nothing is drawn or saved.
    """
    entries = list(entries)
    if not entries:
        # plt.subplots rejects a zero-row grid, and an empty figure is useless.
        print(f'No entries to plot, skipped: {out_path}')
        return

    col_titles = ['Original', 'Vanilla Inpainted', 'Improved Inpainted', 'Mask']
    n_rows, n_cols = len(entries), len(col_titles)

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    # 5x5 inches per cell reproduces the former (20, 25) size for five rows.
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(5 * n_cols, 5 * n_rows), squeeze=False
    )
    fig.suptitle(title, fontsize=16, y=1.0)
    for col, ct in enumerate(col_titles):
        axes[0, col].set_title(ct, fontsize=12)

    for row, entry in enumerate(entries):
        idx = entry['idx']
        rv = entry['van']
        ri = entry['imp']
        delta = entry['delta']

        # The original and the mask are taken from the vanilla record.
        axes[row, 0].imshow(rv['original'])
        axes[row, 0].set_ylabel(
            f'#{idx}\ndSSIM={delta:+.3f}',
            fontsize=9, rotation=0, labelpad=70, va='center'
        )
        axes[row, 1].imshow(rv['inpainted'])
        axes[row, 1].set_xlabel(
            f'SSIM={rv["ssim"]:.3f}', fontsize=8, labelpad=4
        )
        axes[row, 2].imshow(ri['inpainted'])
        axes[row, 2].set_xlabel(
            f'SSIM={ri["ssim"]:.3f}', fontsize=8, labelpad=4
        )
        axes[row, 3].imshow(rv['mask'].squeeze(), cmap='gray')

        # Hide ticks and spines individually. axis('off') would also suppress
        # the axis labels, so the dSSIM / SSIM annotations set above would
        # never be rendered.
        for ax in axes[row]:
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)
    print(f'Saved: {out_path}')


# Render the two grids: largest SSIM gains first, then largest SSIM losses.
for _entries, _direction, _out_path in (
    (best5, 'BETTER', 'comparison_best5.png'),
    (worst5, 'WORSE', 'comparison_worst5.png'),
):
    _show_comparison(
        _entries,
        title=f'Top 5 — Improved most {_direction} than Vanilla (by SSIM delta)',
        out_path=_out_path,
    )
