# Vanilla vs Improved vs Novel DDPM Inpainting — Comparison

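This notebook loads the saved inpainting outputs of all three evaluation runs,
computes SSIM / PSNR / LPIPS for each, and produces comparative statistics and
visualizations. Only the first N images are compared, where N is the smallest
result count across the three runs.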

**Pre-condition:** `eval_results/`, `eval_results_improved/`, and
`eval_results_novel/` must already exist with their `inpainted/` subdirectories
(populated by running the three evaluation notebooks). `eval_results/` must
additionally contain `masks/` and `originals/`; these originals and masks are
shared by all three methods when computing metrics.

**Outputs:**
- `comparison_kde.png` — overlaid KDE for SSIM / PSNR / LPIPS (all three methods)
- `comparison_best5.png` / `comparison_worst5.png` — side-by-side grids of the 5 images where Improved beats / trails Vanilla most (by SSIM delta, Improved − Vanilla)
- `comparison_best5_lpips.png` — top 5 images where Improved beats Vanilla most (by LPIPS delta)

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from utils import (
    load_data, run_metrics,
    print_stats_table, plot_kde_overlay, apply_mask_for_display,
)

# Output directories written by the three evaluation notebooks. Originals and
# masks are read from VANILLA_DIR only (see the metrics cell).
VANILLA_DIR  = 'eval_results'
IMPROVED_DIR = './eval_results_improved'
NOVEL_DIR    = './eval_results_novel'

# Use the Apple-silicon GPU backend when torch reports it, else the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Determine how many results are available in each directory
def _count_results(d):
    inp = os.path.join(d, 'inpainted')
    if not os.path.isdir(inp):
        return 0
    return len([f for f in os.listdir(inp) if f.endswith('.png')])

# Count available inpainted results per method; compare only the overlap.
# NOTE(review): assumes run_metrics(n_images=N) scores the same first N images
# in each directory — confirm against utils.run_metrics.
n_van  = _count_results(VANILLA_DIR)
n_imp  = _count_results(IMPROVED_DIR)
n_nov  = _count_results(NOVEL_DIR)
N      = min(n_van, n_imp, n_nov)

print(f'Vanilla results:  {n_van} images')
print(f'Improved results: {n_imp} images')
print(f'Novel results:    {n_nov} images')
print(f'Comparing first {N} images (minimum of the three)')

# Explicit raise instead of ``assert`` so the guard is not stripped under
# ``python -O``, and so the message names the directories lacking results.
if N == 0:
    _empty_dirs = [
        d for d, n in ((VANILLA_DIR, n_van), (IMPROVED_DIR, n_imp), (NOVEL_DIR, n_nov))
        if n == 0
    ]
    raise RuntimeError(
        'No results found in ' + ', '.join(_empty_dirs)
        + ' — run all three evaluation notebooks first.'
    )

In [None]:
def _metrics_for(result_dir):
    """Score ``result_dir/inpainted`` against the shared masks/originals in VANILLA_DIR."""
    return run_metrics(
        inpainted_dir=os.path.join(result_dir, 'inpainted'),
        masks_dir=os.path.join(VANILLA_DIR, 'masks'),
        originals_dir=os.path.join(VANILLA_DIR, 'originals'),
        n_images=N,
        device=device,
    )


# All three methods are scored against the same ground truth from VANILLA_DIR.
results_vanilla  = _metrics_for(VANILLA_DIR)
results_improved = _metrics_for(IMPROVED_DIR)
results_novel    = _metrics_for(NOVEL_DIR)

print(f'Vanilla  metrics: {len(results_vanilla)} images')
print(f'Improved metrics: {len(results_improved)} images')
print(f'Novel    metrics: {len(results_novel)} images')

In [None]:
# One summary table per method, in presentation order.
for _method_results, _method_label in (
    (results_vanilla,  'Vanilla DDPM'),
    (results_improved, 'Improved (RePaint)'),
    (results_novel,    'Novel (Cosine Mask Dilation)'),
):
    print_stats_table(_method_results, label=_method_label)

In [None]:
import scipy.stats as stats

# (key in per-image result dicts, display label, whether higher is better)
metrics_cfg = [
    ('ssim',  'SSIM',  True),
    ('psnr',  'PSNR',  True),
    ('lpips', 'LPIPS', False),
]


def _kde_curve(vals, n_points=300, n_bandwidths=3.0):
    """Return ``(xs, density)`` for a Gaussian KDE of 1-D *vals*, or None.

    The grid extends ``n_bandwidths`` kernel bandwidths beyond the data on
    each side, so the padding scales with the metric. PSNR spans tens of dB,
    while SSIM/LPIPS live in [0, 1]; a fixed absolute pad would clip the PSNR
    tails. Returns None when a KDE cannot be fitted: fewer than two samples,
    or all samples identical (singular covariance).
    """
    if vals.size < 2 or np.ptp(vals) == 0:
        return None
    kde = stats.gaussian_kde(vals)
    # For 1-D data ``covariance`` is a (1, 1) matrix: the kernel variance.
    bandwidth = float(np.sqrt(kde.covariance[0, 0]))
    pad = n_bandwidths * bandwidth
    xs = np.linspace(vals.min() - pad, vals.max() + pad, n_points)
    return xs, kde(xs)


fig, axes = plt.subplots(1, 3, figsize=(18, 5))
fig.suptitle('Metric Distributions — Vanilla vs Improved vs Novel', fontsize=14)

colors = {'Vanilla': 'steelblue', 'Improved': 'darkorange', 'Novel': 'forestgreen'}
result_sets = {
    'Vanilla':  results_vanilla,
    'Improved': results_improved,
    'Novel':    results_novel,
}

for ax, (key, label, higher_better) in zip(axes, metrics_cfg):
    for name, res in result_sets.items():
        vals = np.array([r[key] for r in res])
        if vals.size == 0:
            continue  # nothing to draw for this method
        mean = vals.mean()
        curve = _kde_curve(vals)
        if curve is not None:
            ax.plot(*curve, label=f'{name} (μ={mean:.3f})', color=colors[name])
            ax.axvline(mean, linestyle='--', alpha=0.5, color=colors[name])
        else:
            # No density can be fitted; the mean marker carries the legend entry.
            ax.axvline(mean, linestyle='--', alpha=0.5, color=colors[name],
                       label=f'{name} (μ={mean:.3f})')
    direction = '↑ better' if higher_better else '↓ better'
    ax.set_title(f'{label}  [{direction}]')
    ax.set_xlabel(label)
    ax.set_ylabel('Density')
    ax.legend(fontsize=8)

plt.tight_layout()
plt.savefig('comparison_kde.png', dpi=150, bbox_inches='tight')
plt.show()
print('Saved: comparison_kde.png')

In [None]:
from utils import apply_mask_for_display

# ---- Per-image SSIM delta (improved - vanilla), joined on image index ----
van_by_idx, imp_by_idx, nov_by_idx = (
    {r['idx']: r for r in method_results}
    for method_results in (results_vanilla, results_improved, results_novel)
)
common_idx = sorted(van_by_idx.keys() & imp_by_idx.keys() & nov_by_idx.keys())


def _ssim_record(i):
    """Bundle image *i*'s three result dicts with its Improved-minus-Vanilla SSIM."""
    van, imp = van_by_idx[i], imp_by_idx[i]
    return {
        'idx':   i,
        'delta': imp['ssim'] - van['ssim'],
        'van':   van,
        'imp':   imp,
        'nov':   nov_by_idx[i],
    }


deltas = list(map(_ssim_record, common_idx))
deltas_sorted = sorted(deltas, key=lambda rec: rec['delta'])

# Ascending order: the head holds Improved's biggest losses, and the tail
# (read back-to-front) holds its biggest wins.
worst5 = deltas_sorted[:5]
best5  = deltas_sorted[-5:][::-1]


def _show_comparison(entries, title, out_path):
    """Plot an N-row x 5-col grid: Original | Masked | Vanilla | Improved | Novel.

    One row per entry of the SSIM-delta ranking (dicts with keys ``'idx'``,
    ``'delta'``, ``'van'``, ``'imp'``, ``'nov'``). Column 0 is labelled with
    the image index and SSIM delta. The three inpainted columns are labelled
    with their per-image SSIM. The figure is saved to *out_path*, then shown.

    Args:
        entries: Ranked per-image records; the grid has ``len(entries)`` rows.
        title: Figure super-title.
        out_path: Destination path for the saved PNG.
    """
    if not entries:
        print(f'No entries to plot; skipped {out_path}')
        return

    n_rows = len(entries)
    col_titles = ['Original', 'Masked', 'Vanilla', 'Improved', 'Novel']
    # squeeze=False keeps ``axes`` 2-D even when only one row is plotted.
    fig, axes = plt.subplots(n_rows, 5, figsize=(25, 5 * n_rows), squeeze=False)
    fig.suptitle(title, fontsize=16, y=1.0)
    for col, ct in enumerate(col_titles):
        axes[0, col].set_title(ct, fontsize=12)

    for row, entry in enumerate(entries):
        rv, ri, rn = entry['van'], entry['imp'], entry['nov']
        row_axes = axes[row]

        row_axes[0].imshow(rv['original'])
        row_axes[0].set_ylabel(
            f'#{entry["idx"]}\ndSSIM={entry["delta"]:+.3f}',
            fontsize=9, rotation=0, labelpad=70, va='center'
        )
        row_axes[1].imshow(apply_mask_for_display(rv['original'], rv['mask']))
        for ax, res in zip(row_axes[2:], (rv, ri, rn)):
            ax.imshow(res['inpainted'])
            ax.set_xlabel(f'SSIM={res["ssim"]:.3f}', fontsize=8, labelpad=4)

        # Hide ticks and frame only. ``ax.axis('off')`` would also stop the
        # axis artists from drawing, which suppresses the metric labels above.
        for ax in row_axes:
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f'Saved: {out_path}')


# Render both ends of the SSIM-delta ranking: biggest wins first, then biggest losses.
for _entries, _title, _path in (
    (best5,  'Top 5 — Improved most BETTER than Vanilla (by SSIM delta)', 'comparison_best5.png'),
    (worst5, 'Top 5 — Improved most WORSE than Vanilla (by SSIM delta)',  'comparison_worst5.png'),
):
    _show_comparison(_entries, title=_title, out_path=_path)

In [None]:
# ---- Top 5 best improvements by LPIPS delta (improved - vanilla, most negative = best) ----
def _lpips_record(i):
    """Bundle image *i*'s three result dicts with its Improved-minus-Vanilla LPIPS."""
    van, imp = van_by_idx[i], imp_by_idx[i]
    return {
        'idx':         i,
        'delta_lpips': imp['lpips'] - van['lpips'],
        'van':         van,
        'imp':         imp,
        'nov':         nov_by_idx[i],
    }


lpips_deltas = list(map(_lpips_record, common_idx))
# Lower LPIPS is better, so the most negative deltas come first.
lpips_best5 = sorted(lpips_deltas, key=lambda rec: rec['delta_lpips'])[:5]


def _show_comparison_lpips(entries, title, out_path):
    """Plot an N-row x 5-col grid: Original | Masked | Vanilla | Improved | Novel.

    One row per entry of the LPIPS-delta ranking (dicts with keys ``'idx'``,
    ``'delta_lpips'``, ``'van'``, ``'imp'``, ``'nov'``). Column 0 is labelled
    with the image index and LPIPS delta. The three inpainted columns are
    labelled with their per-image LPIPS. The figure is saved to *out_path*,
    then shown.

    Args:
        entries: Ranked per-image records; the grid has ``len(entries)`` rows.
        title: Figure super-title.
        out_path: Destination path for the saved PNG.
    """
    if not entries:
        print(f'No entries to plot; skipped {out_path}')
        return

    n_rows = len(entries)
    col_titles = ['Original', 'Masked', 'Vanilla', 'Improved', 'Novel']
    # squeeze=False keeps ``axes`` 2-D even when only one row is plotted.
    fig, axes = plt.subplots(n_rows, 5, figsize=(25, 5 * n_rows), squeeze=False)
    fig.suptitle(title, fontsize=16, y=1.0)
    for col, ct in enumerate(col_titles):
        axes[0, col].set_title(ct, fontsize=12)

    for row, entry in enumerate(entries):
        rv, ri, rn = entry['van'], entry['imp'], entry['nov']
        row_axes = axes[row]

        row_axes[0].imshow(rv['original'])
        row_axes[0].set_ylabel(
            f'#{entry["idx"]}\ndLPIPS={entry["delta_lpips"]:+.3f}',
            fontsize=9, rotation=0, labelpad=80, va='center'
        )
        row_axes[1].imshow(apply_mask_for_display(rv['original'], rv['mask']))
        for ax, res in zip(row_axes[2:], (rv, ri, rn)):
            ax.imshow(res['inpainted'])
            ax.set_xlabel(f'LPIPS={res["lpips"]:.3f}', fontsize=8, labelpad=4)

        # Hide ticks and frame only. ``ax.axis('off')`` would also stop the
        # axis artists from drawing, which suppresses the metric labels above.
        for ax in row_axes:
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)

    plt.tight_layout()
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()
    print(f'Saved: {out_path}')


# Render the LPIPS-ranked grid (largest LPIPS reduction by Improved first).
_show_comparison_lpips(
    lpips_best5,
    'Top 5 — Improved most BETTER than Vanilla (by LPIPS delta, lower is better)',
    'comparison_best5_lpips.png',
)