In [None]:
import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
from utils import apply_mask_for_display

# Result roots for the three evaluation runs. Each may contain `inpainted/`
# (PNG results), `originals/` (PNG inputs) and `masks/` (torch `.pt` masks),
# with files named by zero-padded index, e.g. `0007.png` / `0007.pt`.
VANILLA_DIR  = 'eval_results'
IMPROVED_DIR = './eval_results_improved'
NOVEL_DIR    = './eval_results_novel'
ROWS_PER_FIG = 25  # images per figure cell — tune if too slow/fast

def _inpainted(directory, idx):
    """Report whether an inpainted PNG for image ``idx`` exists under ``directory``.

    The expected location is ``<directory>/inpainted/<idx zero-padded to 4>.png``.
    """
    filename = f'{idx:04d}.png'
    candidate = os.path.join(directory, 'inpainted', filename)
    return os.path.exists(candidate)

# Union of all indices across all three dirs.
# File stems are parsed as whole integers: f'{idx:04d}' pads to at least four
# digits but never truncates, so slicing the first four characters would
# misread e.g. `12345.png` as 1234. Non-numeric stems (stray PNGs) are skipped
# instead of raising ValueError.
all_indices = set()
for d in [VANILLA_DIR, IMPROVED_DIR, NOVEL_DIR]:
    inpainted_dir = os.path.join(d, 'inpainted')
    if not os.path.isdir(inpainted_dir):
        continue
    for fname in os.listdir(inpainted_dir):
        stem, ext = os.path.splitext(fname)
        if ext == '.png' and stem.isdecimal():
            all_indices.add(int(stem))
indices = sorted(all_indices)
print(f'{len(indices)} images found')

def load(directory, subdir, idx, ext='png'):
    """Load one per-index artifact from ``<directory>/<subdir>/``.

    Args:
        directory: Result root directory (e.g. one of the ``*_DIR`` constants).
        subdir: Sub-folder inside ``directory``, such as ``'inpainted'``.
        idx: Integer image index; the file name is ``f'{idx:04d}.{ext}'``.
        ext: ``'pt'`` loads the file with ``torch.load(weights_only=True)``;
            any other extension is opened with PIL and converted to RGB.

    Returns:
        Whatever ``torch.load`` returns for ``'pt'``; otherwise an RGB
        ``PIL.Image.Image``.
    """
    path = os.path.join(directory, subdir, f'{idx:04d}.{ext}')
    if ext == 'pt':
        return torch.load(path, weights_only=True)
    # `convert` returns a new, fully loaded image, so the source file can be
    # closed deterministically instead of relying on PIL's lazy-close rules.
    with Image.open(path) as img:
        return img.convert('RGB')

def load_orig_and_mask(idx):
    """Load original + mask from whichever dir has them.

    Directories are tried in the order vanilla, novel, improved. The first one
    containing BOTH ``originals/{idx:04d}.png`` and ``masks/{idx:04d}.pt`` is
    used, so an original and a mask are never mixed across runs.

    Returns:
        ``(original, mask)``: an RGB ``PIL.Image.Image`` and whatever
        ``torch.load`` returns for the mask file, or ``(None, None)`` if no
        directory has both files.
    """
    name = f'{idx:04d}'
    for d in [VANILLA_DIR, NOVEL_DIR, IMPROVED_DIR]:
        orig_path = os.path.join(d, 'originals', f'{name}.png')
        mask_path = os.path.join(d, 'masks', f'{name}.pt')
        if os.path.exists(orig_path) and os.path.exists(mask_path):
            # Close the PNG deterministically; `convert` yields a fully loaded copy.
            with Image.open(orig_path) as img:
                orig = img.convert('RGB')
            return orig, torch.load(mask_path, weights_only=True)
    return None, None

def _hide_frame(ax):
    """Remove ticks and spines from ``ax`` while keeping its axis labels.

    ``ax.axis('off')`` would also hide the y-label used as the row index.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def _show_or_na(ax, image):
    """Draw ``image`` on ``ax``, or a centred 'N/A' placeholder if it is None."""
    if image is None:
        ax.text(0.5, 0.5, 'N/A', ha='center', va='center', transform=ax.transAxes)
    else:
        ax.imshow(image)


# One figure per batch of ROWS_PER_FIG indices: columns are the original, the
# masked input, then one column per method that has results in this batch.
for batch_start in range(0, len(indices), ROWS_PER_FIG):
    batch = indices[batch_start:batch_start + ROWS_PER_FIG]
    n = len(batch)

    # Vanilla is always shown; Improved/Novel only if some image in the batch has them.
    method_cols = [('Vanilla', VANILLA_DIR)]
    if any(_inpainted(IMPROVED_DIR, i) for i in batch):
        method_cols.append(('Improved', IMPROVED_DIR))
    if any(_inpainted(NOVEL_DIR, i) for i in batch):
        method_cols.append(('Novel', NOVEL_DIR))
    col_titles = ['Original', 'Masked'] + [title for title, _ in method_cols]
    ncols = len(col_titles)

    # squeeze=False always yields a 2-D (n, ncols) array, even for a one-row batch.
    fig, axes = plt.subplots(n, ncols, figsize=(4 * ncols, 4 * n), squeeze=False)

    for ci, title in enumerate(col_titles):
        axes[0, ci].set_title(title, fontsize=11)

    for row, idx in enumerate(batch):
        orig, mask = load_orig_and_mask(idx)
        masked_vis = apply_mask_for_display(orig, mask) if orig is not None else None

        _show_or_na(axes[row, 0], orig)
        _show_or_na(axes[row, 1], masked_vis)
        for ci, (_, directory) in enumerate(method_cols, start=2):
            result = load(directory, 'inpainted', idx) if _inpainted(directory, idx) else None
            _show_or_na(axes[row, ci], result)

        for ax in axes[row]:
            _hide_frame(ax)
        # Set after hiding the frame; this label stays visible (axis('off') would hide it).
        axes[row, 0].set_ylabel(f'#{idx:04d}', fontsize=10, rotation=0, labelpad=40, va='center')

    plt.tight_layout()
    plt.show()
    # Release the figure explicitly so non-auto-closing backends don't accumulate one per batch.
    plt.close(fig)
