# Notebook to create the panel images to upload in the google forms for evaluation of the models

In [4]:
import os
import random
import csv
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
import numpy as np
from torchvision import transforms
from PIL import ImageDraw, ImageFont
from skimage import measure

In [5]:
def create_evaluation_panels_final_randomized(flair_folder, mask_folder, generated_folders, output_folder, metadata_csv, size=512):
    """
    This function generates visual evaluation panels for synthetic lesion assessment.
    It combines original FLAIR images, synthetic masks, overlays, and generated outputs
    from multiple models into a unified image. The generated images are randomized
    in position (A, B, C), and a mapping is saved for later analysis.

    Inputs:
    - flair_folder: path to original FLAIR images
    - mask_folder: path to synthetic lesion masks
    - generated_folders: dict of model name -> path to generated lesion images
    - output_folder: where final panel images are saved
    - metadata_csv: path to CSV that stores A/B/C -> model mappings

    Output:
    - For each image: a 2-row panel image (top: original/mask/overlay, bottom: A/B/C outputs)
    - CSV: with filename and which model corresponds to each label (A, B, C)
    """

    flair_folder = Path(flair_folder)
    mask_folder = Path(mask_folder)
    output_folder = Path(output_folder)
    output_folder.mkdir(parents=True, exist_ok=True)

    image_transform = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(size)
    ])

    def torch_transform(pil_img):
        return image_transform(pil_img)

    def add_title(image, title, size, font_path="./fonts/DejaVuSans.ttf"):
        new_img = Image.new("RGB", (size, size + 40), (255, 255, 255))
        new_img.paste(image, (0, 0))
        draw = ImageDraw.Draw(new_img)

        try:
            font = ImageFont.truetype(font_path, 28)
        except:
            print("Font not found, using default small font")
            font = ImageFont.load_default()

        bbox = draw.textbbox((0, 0), title, font=font)
        text_w = bbox[2] - bbox[0]
        draw.text(((size - text_w) // 2, size + 5), title, fill=(0, 0, 0), font=font)
        return new_img

    flair_images = list(Path(flair_folder).glob("*.png"))
    model_names = sorted(generated_folders.keys())
    metadata = []

    for flair_path in flair_images:
        base_name = flair_path.name
        mask_path = Path(mask_folder) / base_name

        if not mask_path.exists():
            print(f"Skipping {base_name}: mask missing")
            continue

        flair_img = Image.open(flair_path).convert("RGB")
        mask_img = Image.open(mask_path).convert("L")

        flair_np = np.array(flair_img)
        mask_np = np.array(mask_img)
        contours = measure.find_contours(mask_np > 0, 0.5)
        for contour in contours:
            for y, x in contour.astype(np.int32):
                if 0 <= y < flair_np.shape[0] and 0 <= x < flair_np.shape[1]:
                    flair_np[y, x] = [255, 0, 0]
        overlay_img = Image.fromarray(flair_np)

        flair_img = torch_transform(flair_img)
        mask_rgb = torch_transform(Image.fromarray(np.stack([mask_np]*3, axis=-1)))
        overlay_img = torch_transform(overlay_img)

        top_row = [
            add_title(flair_img, "Original", size),
            add_title(mask_rgb, "Mask", size),
            add_title(overlay_img, "Overlay", size),
        ]

        model_to_image = {}
        for model in model_names:
            gen_path = Path(generated_folders[model]) / base_name
            if gen_path.exists():
                gen_img = torch_transform(Image.open(gen_path).convert("RGB"))
                model_to_image[model] = gen_img
            else:
                print(f"Skipping {base_name}: missing image for model {model}")
                break

        if len(model_to_image) != 3:
            continue

        # Randomize model-to-label mapping
        items = list(model_to_image.items())
        random.shuffle(items)
        label_to_model = dict(zip(['A', 'B', 'C'], [model for model, _ in items]))
        images_ordered = [img for _, img in items]

        bottom_row = [add_title(img, label, size) for img, label in zip(images_ordered, ['A', 'B', 'C'])]

        composite = Image.new('RGB', (size * 3, (size + 40) * 2), (255, 255, 255))
        for i, img in enumerate(top_row):
            composite.paste(img, (i * size, 0))
        for i, img in enumerate(bottom_row):
            composite.paste(img, (i * size, size + 40))

        composite.save(Path(output_folder) / base_name)
        metadata.append({
            "filename": base_name,
            "A": label_to_model['A'],
            "B": label_to_model['B'],
            "C": label_to_model['C'],
        })
        print(f"Saved {base_name}")

    with open(metadata_csv, 'w', newline='') as csvfile:
        fieldnames = ['filename', 'A', 'B', 'C']
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()
        writer.writerows(metadata)
    print(f"Metadata saved to {metadata_csv}")


In [6]:
# Example usage
flair_folder = "./test_images/flair"
mask_folder = "./test_images/synthetic_masks_big"
generated_folders = {
    "vh": "./test_images/generated_lesions_big/vh",
    "vh_shifts": "./test_images/generated_lesions_big/vh_shifts",
    "vh_shifts_wmh": "./test_images/generated_lesions_big/vh_shifts_wmh",
}
output_folder = "./evaluation_panels_big"
metadata_csv = "./evaluation_panels_big/model_order.csv"

create_evaluation_panels_final_randomized(flair_folder, mask_folder, generated_folders, output_folder, metadata_csv)


Saved WMH2017_27_9.png
Saved eval_in_30_12.png
Saved dev_out_22_7.png
Saved dev_out_24_4.png
Saved VH_749_6.png
Saved VH_741_1.png
Saved train_14_5.png
Saved eval_in_18_3.png
Saved VH_746_5.png
Saved VH_752_10.png
Saved VH_727_8.png
Saved VH_754_11.png
Saved WMH2017_132_2.png
Saved WMH2017_58_7.png
Saved WMH2017_103_6.png
Saved dev_out_25_8.png
Saved eval_in_4_9.png
Saved train_22_0.png
Saved VH_758_12.png
Saved train_4_11.png
Saved WMH2017_56_0.png
Saved train_32_6.png
Saved WMH2017_50_1.png
Saved train_6_2.png
Saved dev_in_4_1.png
Saved WMH2017_59_5.png
Saved WMH2017_6_11.png
Saved WMH2017_101_4.png
Saved WMH2017_137_8.png
Saved WMH2017_65_3.png
Saved WMH2017_33_10.png
Saved VH_751_9.png
Saved WMH2017_126_12.png
Saved VH_729_2.png
Saved VH_745_4.png
Saved VH_738_7.png
Saved eval_in_25_10.png
Saved VH_648_0.png
Saved VH_739_3.png
Metadata saved to ./evaluation_panels_big/model_order.csv
