In [None]:
import os
import torch
import matplotlib.pyplot as plt
from PIL import Image
from utils import apply_mask_for_display

# Result roots for the three runs being compared. Files inside their
# subfolders (e.g. 'inpainted/', 'originals/', 'masks/') are named NNNN.<ext>.
VANILLA_DIR  = 'eval_results'
IMPROVED_DIR = './eval_results_improved'
NOVEL_DIR    = './eval_results_novel'
ROWS_PER_FIG = 25  # images (grid rows) per figure — tune if rendering is too slow/fast

# Image indices are taken from the vanilla run's inpainted PNGs (NNNN.png).
# The whole stem is parsed, so the index round-trips through the
# f'{idx:04d}' naming used by load(). Non-numeric names (e.g. a stray
# preview image) are skipped instead of crashing int().
_vanilla_pngs = (
    os.path.splitext(f) for f in os.listdir(os.path.join(VANILLA_DIR, 'inpainted'))
)
indices = sorted(
    int(stem) for stem, ext in _vanilla_pngs if ext == '.png' and stem.isdecimal()
)
print(f'{len(indices)} images found')

def load(directory, subdir, idx, ext='png'):
    """Load artifact ``idx`` from ``<directory>/<subdir>/NNNN.<ext>``.

    Args:
        directory: Root results directory (e.g. VANILLA_DIR).
        subdir: Subfolder inside ``directory`` (e.g. 'originals', 'masks',
            'inpainted').
        idx: Integer image index, zero-padded to four digits in the filename.
        ext: File extension. ``'pt'`` files are read with
            ``torch.load(..., weights_only=True)``; anything else is opened
            with PIL and converted to RGB.

    Returns:
        The object deserialized by ``torch.load`` for ``'pt'``, otherwise an
        RGB PIL image.
    """
    path = os.path.join(directory, subdir, f'{idx:04d}.{ext}')
    if ext == 'pt':
        return torch.load(path, weights_only=True)
    # Image.open is lazy and keeps the file handle open. convert() returns a
    # new, fully loaded image, so the source file can be closed right away
    # instead of leaking one descriptor per loaded image.
    with Image.open(path) as img:
        return img.convert('RGB')

def _has_novel(idx):
    """Report whether NOVEL_DIR holds an inpainted PNG for image *idx*.

    The filename uses the same zero-padded ``NNNN.png`` scheme as ``load``.
    """
    novel_png = os.path.join(NOVEL_DIR, 'inpainted', f'{idx:04d}.png')
    found = os.path.exists(novel_png)
    return found

def _hide_frame(ax):
    """Remove ticks and the frame from *ax* while keeping its axis labels.

    ``ax.axis('off')`` would also hide the y-axis label, which is where each
    row's ``#NNNN`` index is drawn, so only ticks and spines are cleared.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


# Render the comparison grid in chunks of ROWS_PER_FIG rows: one row per image,
# columns Original | Masked | Vanilla | Improved [| Novel].
for batch_start in range(0, len(indices), ROWS_PER_FIG):
    batch = indices[batch_start:batch_start + ROWS_PER_FIG]
    n = len(batch)
    # Probe the novel directory once per index; the result is reused for the
    # column decision and for each row's Novel cell.
    novel_available = [_has_novel(i) for i in batch]
    show_novel = any(novel_available)
    ncols = 5 if show_novel else 4

    # squeeze=False always yields a 2-D (n, ncols) array of axes, so
    # axes[row][col] indexing works even for a single-row batch.
    fig, axes = plt.subplots(n, ncols, figsize=(4 * ncols, 4 * n), squeeze=False)

    column_titles = ['Original', 'Masked', 'Vanilla', 'Improved']
    if show_novel:
        column_titles.append('Novel')
    for col, title in enumerate(column_titles):
        axes[0][col].set_title(title, fontsize=11)

    for row, (idx, has_novel) in enumerate(zip(batch, novel_available)):
        orig     = load(VANILLA_DIR,  'originals', idx)
        mask     = load(VANILLA_DIR,  'masks',     idx, ext='pt')
        vanilla  = load(VANILLA_DIR,  'inpainted', idx)
        improved = load(IMPROVED_DIR, 'inpainted', idx)
        masked_vis = apply_mask_for_display(orig, mask)

        axes[row][0].imshow(orig)
        axes[row][0].set_ylabel(f'#{idx:04d}', fontsize=10, rotation=0, labelpad=40, va='center')
        axes[row][1].imshow(masked_vis)
        axes[row][2].imshow(vanilla)
        axes[row][3].imshow(improved)
        if show_novel:
            novel_ax = axes[row][4]
            if has_novel:
                novel_ax.imshow(load(NOVEL_DIR, 'inpainted', idx))
            else:
                # Placeholder so rows without a novel result stay aligned.
                novel_ax.text(0.5, 0.5, 'N/A', ha='center', va='center', transform=novel_ax.transAxes)

        for ax in axes[row]:
            _hide_frame(ax)

    plt.tight_layout()
    plt.show()
