In [None]:
import os
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image

# Move one directory up (presumably the project root) so the local `dataset`
# module and the relative data/ and results/ paths used below resolve.
# The starting directory is remembered on first execution so that re-running
# this cell is idempotent instead of walking one more level up each time.
try:
    NOTEBOOK_DIR
except NameError:
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, ".."))

from dataset import image_label_paths_from_dir

Supplementary Figure 2: Semantic segmentation of image tiles

In [None]:
VALIDATION_DIR = "data/validation"
PREDICTION_DIR = "results/prediction"

# Collect the paired validation image / label paths.
image_paths, label_paths = image_label_paths_from_dir(VALIDATION_DIR)
image_paths, label_paths = np.asarray(image_paths), np.asarray(label_paths)
assert len(image_paths) == len(label_paths), "image and label path lists differ in length"
assert len(image_paths) > 0, f"no validation images found in {VALIDATION_DIR!r}"

# Each label file `<stem><ext>` has a matching prediction `<stem>_pred<ext>`
# stored in PREDICTION_DIR.
assert os.path.isdir(PREDICTION_DIR), f"prediction directory not found: {PREDICTION_DIR!r}"
pred_paths = np.asarray([
    os.path.join(PREDICTION_DIR, "{}_pred{}".format(*os.path.splitext(os.path.basename(label_path))))
    for label_path in label_paths
])
missing_preds = [p for p in pred_paths if not os.path.isfile(p)]
assert not missing_preds, (
    f"{len(missing_preds)} prediction file(s) missing, e.g. {missing_preds[:3]}"
)

# Visualize a random subset of tiles. Each sample occupies three adjacent
# panels (image, ground truth, prediction); SAMPLES_PER_ROW samples per row.
N_SAMPLES = 24
SAMPLES_PER_ROW = 3
SEED = 0  # fixed so re-running the notebook reproduces the same figure
FIGURE_PATH = "results/figures/image tiles.svg"
PANEL_TITLES = ("Image Tile", "Ground truth", "Prediction")

rng = np.random.default_rng(SEED)
n_samples = min(N_SAMPLES, len(image_paths))  # without replacement: cannot exceed dataset size
assert n_samples > 0, "no samples to plot"
indices = rng.choice(len(image_paths), n_samples, replace=False)

n_rows = -(-n_samples // SAMPLES_PER_ROW)  # ceiling division
fig, axes = plt.subplots(
    n_rows, len(PANEL_TITLES) * SAMPLES_PER_ROW, figsize=(10, 9), squeeze=False
)
# Hide every axis up front so unused trailing panels stay blank.
for ax in axes.flat:
    ax.set_axis_off()

samples = zip(image_paths[indices], label_paths[indices], pred_paths[indices])
for i, (panel_axes, sample_paths) in enumerate(
    zip(axes.reshape((-1, len(PANEL_TITLES))), samples)
):
    for ax, path, title in zip(panel_axes, sample_paths, PANEL_TITLES):
        # imshow converts the PIL image to an array immediately, so the file
        # handle can be closed as soon as it returns.
        with Image.open(path) as img:
            ax.imshow(img)
        if i < SAMPLES_PER_ROW:  # column titles on the first row only
            ax.set_title(title)

fig.tight_layout(pad=0.2)
os.makedirs(os.path.dirname(FIGURE_PATH), exist_ok=True)
fig.savefig(FIGURE_PATH, dpi=200)