In [None]:
!git clone https://github.com/camlhui/inpainting-review.git

In [None]:
# Work from the repo's src/ folder; the relative path "inpainting_review/tasks.json"
# opened below is resolved against this directory.
%cd /content/inpainting-review/src
# Refresh the existing checkout with the latest upstream commits.
!git pull

In [None]:
from google.colab import drive
import os


drive.mount('/content/drive')

# Root folder on Drive holding the benchmark source images/masks and the
# `outputs/` tree that the figure cells read from and write to.
DATA_DIR = '/content/drive/MyDrive/Colab storage/inpainting-review'
# Also export it, so code that looks it up via os.environ['DATA_DIR'] finds it
# on a fresh kernel instead of raising KeyError.
os.environ['DATA_DIR'] = DATA_DIR

In [None]:
%cd /content/inpainting-review/src

import json
import os
from inpainting_review.models import InpaintingTask
from inpainting_review.utils.image import load_and_preprocess_image, load_and_preprocess_mask


# Load the benchmark task definitions (path is relative to src/, see `%cd` above).
with open("inpainting_review/tasks.json", encoding="utf-8") as f:
    tasks = [InpaintingTask(**t) for t in json.load(f)["tasks"]]
assert tasks, "inpainting_review/tasks.json lists no tasks"

# Smoke check: load the first task's inputs. Use the notebook-level DATA_DIR
# defined in the Drive cell (as every later cell does) rather than depending on
# a DATA_DIR environment variable.
first_task = tasks[0]
image = load_and_preprocess_image(os.path.join(DATA_DIR, first_task.source_image))
mask = load_and_preprocess_mask(os.path.join(DATA_DIR, first_task.mask_image))
prompt = first_task.prompt

# Benchmark dataset

In [34]:
from PIL import Image


def merge_image_and_mask(image: Image.Image, mask: Image.Image, mask_color=(255, 0, 0, 128)) -> Image.Image:
    """Tint the masked region of an image with a semi-transparent colour.

    Args:
        image: Source image; converted to RGBA.
        mask: Mask image; converted to 8-bit greyscale ("L") and used as the
            per-pixel blend weight between the colour layer and a fully
            transparent layer. Must have the same size as ``image``
            (required by ``Image.composite``).
        mask_color: RGBA tuple painted over the masked area; its alpha channel
            sets how strongly the tint covers the image.

    Returns:
        A new RGBA image the size of ``image``.
    """
    base = image.convert("RGBA")
    weights = mask.convert("L")

    # Use the caller's colour. The body used to overwrite it with a hard-coded
    # (255, 0, 0, 128), silently ignoring the argument; the default now equals
    # that value so calls without ``mask_color`` render exactly as before.
    colored_layer = Image.new("RGBA", base.size, mask_color)
    transparent_layer = Image.new("RGBA", base.size, (0, 0, 0, 0))
    mask_overlay = Image.composite(colored_layer, transparent_layer, weights)

    return Image.alpha_composite(base, mask_overlay)

## Samples

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
import textwrap

from inpainting_review.utils.image import load_and_preprocess_image, load_and_preprocess_mask


# One wrapper serves every prompt; no need to rebuild it per task.
wrapper = textwrap.TextWrapper(width=50)

# squeeze=False keeps `ax` 2-D even with a single task, so ax[i, k] always works.
fig, ax = plt.subplots(len(tasks), 2, figsize=(5*2, 4*len(tasks)), squeeze=False)

for i, task in enumerate(tasks):
    image = load_and_preprocess_image(os.path.join(DATA_DIR, task.source_image))
    mask = load_and_preprocess_mask(os.path.join(DATA_DIR, task.mask_image))

    # Left column: source image with the inpainting region tinted.
    ax[i, 0].imshow(merge_image_and_mask(image, mask))
    ax[i, 0].set_title(task.task_id)
    ax[i, 0].axis("off")

    # Right column: the wrapped positive and negative prompts as plain text.
    positive_prompt = wrapper.fill(f"Positive: {task.prompt}")
    negative_prompt = wrapper.fill(f"Negative: {task.negative_prompt}")
    full_prompt_text = f"{positive_prompt}\n\n\n{negative_prompt}"
    ax[i, 1].text(0.5, 0.6, full_prompt_text, ha='center', va='center', wrap=True, fontsize=10)
    ax[i, 1].set_title("Prompts")
    ax[i, 1].axis("off")

fig.tight_layout()
output_dir = os.path.join(DATA_DIR, "outputs")
os.makedirs(output_dir, exist_ok=True)  # savefig does not create missing folders
fig.savefig(os.path.join(output_dir, "benchmark.png"), dpi=150)

## Grid

In [None]:
# Lay the tasks out row-major on a fixed-width grid. The row count follows the
# number of tasks instead of assuming exactly 3 rows (which broke above 12 tasks).
n_cols = 4
n_rows = max(1, (len(tasks) + n_cols - 1) // n_cols)
# squeeze=False keeps `ax` 2-D even when everything fits on one row.
fig, ax = plt.subplots(n_rows, n_cols, figsize=(4*n_cols, 3*n_rows), squeeze=False)

for k, task in enumerate(tasks):
    image = load_and_preprocess_image(os.path.join(DATA_DIR, task.source_image))
    mask = load_and_preprocess_mask(os.path.join(DATA_DIR, task.mask_image))
    i, j = divmod(k, n_cols)
    ax[i, j].imshow(merge_image_and_mask(image, mask))
    ax[i, j].set_title(task.task_id)

# Hide frames/ticks on every panel, including the unused trailing cells of the last row.
for panel in ax.flat:
    panel.axis("off")

fig.tight_layout()
output_dir = os.path.join(DATA_DIR, "outputs")
os.makedirs(output_dir, exist_ok=True)  # savefig does not create missing folders
fig.savefig(os.path.join(output_dir, "benchmark_grid.png"), dpi=150)

# Results

In [None]:
N_TRIALS = 3  # number of trial folders shown per model

outputs_dir = os.path.join(DATA_DIR, "outputs")
# Each sub-directory of outputs/ holds one model's runs; sort for a stable figure order.
models = sorted(item for item in os.listdir(outputs_dir)
                if os.path.isdir(os.path.join(outputs_dir, item)))

for model in models:
    model_dir = os.path.join(outputs_dir, model)
    # Trial folders do not depend on the task, so list them once per model.
    # Reverse-lexicographic order keeps the "last" names; presumably trial names
    # sort chronologically so these are the newest runs - TODO confirm naming scheme.
    trials = sorted((item for item in os.listdir(model_dir)
                     if os.path.isdir(os.path.join(model_dir, item))),
                    reverse=True)[:N_TRIALS]

    # squeeze=False keeps `ax` 2-D even with a single task.
    fig, ax = plt.subplots(len(tasks), 1 + N_TRIALS,
                           figsize=(4*(1 + N_TRIALS), 5*len(tasks)), squeeze=False)

    for i, task in enumerate(tasks):
        image = load_and_preprocess_image(os.path.join(DATA_DIR, task.source_image))
        mask = load_and_preprocess_mask(os.path.join(DATA_DIR, task.mask_image))

        # Column 0: the source image with the masked region tinted.
        ax[i, 0].imshow(merge_image_and_mask(image, mask))
        ax[i, 0].set_title(task.task_id)

        # Columns 1..N_TRIALS: this task's output from each available trial.
        # Iterating over the trials that exist (rather than range(N_TRIALS))
        # avoids IndexError when a model has fewer runs; missing panels stay blank.
        for j, trial in enumerate(trials, start=1):
            result_path = os.path.join(model_dir, trial, f"{task.task_id}.png")
            if os.path.exists(result_path):
                ax[i, j].imshow(load_and_preprocess_image(result_path))
                ax[i, j].set_title(f"trial {j}")

        # Hide frames on every panel of the row, including empty ones.
        for panel in ax[i]:
            panel.axis("off")

    fig.tight_layout()
    fig.savefig(os.path.join(model_dir, "results.png"), dpi=150)