In [None]:
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

# Work from the notebook's parent directory (presumably the project root) so the
# local `dataset` / `utils` imports below and relative paths such as
# "data/composed" resolve. The notebook's own directory is remembered on first run
# so re-executing this cell does not climb one more level each time.
if "NOTEBOOK_DIR" not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

from dataset import image_label_paths_from_dir
from utils import load_colormap


In [None]:
COMPOSED_DIR = "data/composed"
image_paths, label_paths = image_label_paths_from_dir(COMPOSED_DIR)

# Keep only fully annotated pairs: every pixel of the label must be non-zero
# (a 0 pixel marks an empty / unannotated region).
pairs = [
    (image_path, label_path)
    for image_path, label_path in zip(image_paths, label_paths)
    if np.all(Image.open(label_path))
]

print(len(pairs), "fully annotated pairs found")

Supplementary Table 1: Error of random point annotation

In [None]:
# Simulated point annotation: points labelled per trial and trials per image.
num_points = 50
N = 10000  # experiments per image

# When True, classes that do not occur in an image's ground truth are set to NaN
# for that image and therefore excluded from the averages.
aposteriori = True

# Per-image results (one per-class array appended per label image):
# mean absolute error, mean absolute percentage error, leave-out probability.
maes, mapes, plos = [], [], []

def dist_from_arr(arr, classes=10):
    """Share of each non-zero class label among the labelled pixels of ``arr``.

    Args:
        arr: integer label array of any shape (flattened before counting).
        classes: number of non-zero classes; the result has at least this many
            entries (more if ``arr`` contains labels greater than ``classes``).

    Returns:
        1-D float array whose entry ``i`` is the fraction of non-zero pixels with
        label ``i + 1``. Label 0 is dropped before normalizing; if ``arr`` holds
        only zeros the division is 0/0 and the result is NaN.
    """
    counts = np.bincount(arr.reshape(-1), minlength=classes + 1)
    foreground = counts[1:]
    return foreground / foreground.sum()

# Fixed seed so the table is reproducible under Restart & Run All (previously the
# trials were drawn from NumPy's unseeded global generator).
TABLE_SEED = 0
table_rng = np.random.RandomState(TABLE_SEED)


def _mean_over_trials(per_trial, dist, aposteriori):
    """Average a (trials, classes) matrix over trials, per class.

    Args:
        per_trial: array of shape (trials, classes), one value per simulated
            annotation and class.
        dist: ground-truth class distribution of the image, shape (classes,).
        aposteriori: if True, columns of classes absent from the ground truth
            (``dist == 0``) become NaN so later NaN-aware aggregation skips them;
            otherwise they are set to 0.

    Returns:
        Array of shape (classes,).
    """
    fill = np.nan if aposteriori else 0.0
    masked = np.where(dist > 0, per_trial, fill)
    with warnings.catch_warnings():
        # nanmean warns on all-NaN columns (absent classes); NaN is intended there.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        return np.nanmean(masked, axis=0)


for _, label_path in pairs:
    with Image.open(label_path) as label_image:
        label = np.asarray(label_image)

    # Class distribution from the full segmentation mask (ground truth)
    dist = dist_from_arr(label)

    # N simulated annotations, each labelling num_points uniformly drawn pixels
    rows = table_rng.randint(label.shape[0], size=(N, num_points))
    cols = table_rng.randint(label.shape[1], size=(N, num_points))

    # Class distribution estimated from each set of random points: (N, classes)
    rp_dists = np.apply_along_axis(dist_from_arr, 1, label[rows, cols])

    # Class "leave out" probability: the class is in the image but no point hit it
    # (the converse cannot happen because the points are read from the label).
    left_out = np.logical_xor(rp_dists, dist).astype(float)
    plos.append(_mean_over_trials(left_out, dist, aposteriori))

    # Absolute error of the estimated coverage fraction
    abs_err = np.abs(rp_dists - dist)
    maes.append(_mean_over_trials(abs_err, dist, aposteriori))

    # Relative error; the 0/0 of absent classes is replaced by _mean_over_trials
    # (NaN if aposteriori, else 0 — matching the former nan_to_num behaviour).
    with np.errstate(divide="ignore", invalid="ignore"):
        rel_err = np.abs(rp_dists - dist) / dist
    mapes.append(_mean_over_trials(rel_err, dist, aposteriori))

def _mean_and_sem(per_image):
    """Per-class mean and standard error of the mean across images, in percent.

    Args:
        per_image: array of shape (images, classes). NaN entries (classes absent
            from an image when ``aposteriori`` is set) are ignored.

    Returns:
        (mean, sem), each of shape (classes,), scaled by 100.

    Each per-image value is already an average over the N simulated trials, so the
    spread *between images* is the uncertainty of the grand mean. The SEM therefore
    divides by sqrt(number of non-NaN images for that class) only — the former extra
    division by sqrt(N), and by the total image count, understated it.
    """
    n_valid = np.sum(~np.isnan(per_image), axis=0)
    with warnings.catch_warnings(), np.errstate(divide="ignore", invalid="ignore"):
        # All-NaN columns / fewer than two images yield NaN, which is intended.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        mean = np.nanmean(per_image, axis=0)
        sem = np.nanstd(per_image, axis=0, ddof=1) / np.sqrt(n_valid)
    return mean * 100, sem * 100


plos = np.stack(plos)
plos_mean, plos_serr = _mean_and_sem(plos)

maes = np.stack(maes)
maes_mean, maes_serr = _mean_and_sem(maes)

mapes = np.stack(mapes)
mape_mean, mape_serr = _mean_and_sem(mapes)

# Class names in colormap order, skipping the first entry (label 0)
classes = list(load_colormap()[0].keys())[1:]

print(f"{'class':20s}  {'p_leftout':9s}  {'mae':9s}  {'mape':9s}")
for i, c in enumerate(classes):
    print(f"{c:20s}  {plos_mean[i]:4.1f}+-{plos_serr[i]:.1f}  "
          f"{maes_mean[i]:4.2f}+-{maes_serr[i]:.2f}  "
          f"{mape_mean[i]:4.1f}+-{mape_serr[i]:.1f}")

Supplementary Figure 1: Visualization of random point sampling error

In [None]:
EXAMPLE_PAIR = 6  # index into `pairs` of the label image that is visualized
im = Image.open(pairs[EXAMPLE_PAIR][1])
label = np.asarray(im)

seed = 1237
# Draw x and y from ONE seeded generator. Two separate RandomState(seed) instances
# replay the same random stream for both coordinates, so the points were not
# independent (on a square image x == y exactly: every point on the diagonal).
# NOTE(review): the seed was hand-picked for the old draws; re-check the figure.
rng = np.random.RandomState(seed)
x = rng.randint(im.width, size=(3, num_points))   # column indices
y = rng.randint(im.height, size=(3, num_points))  # row indices

fig, axes = plt.subplots(2, 3, figsize=(9, 7.5),
                         gridspec_kw={'height_ratios': [16, 1]})

height = 0.2  # thickness of the coverage bar
colormap_classes, palette = load_colormap()
# Filtered copy instead of pop(): avoids mutating the object load_colormap()
# returned, in case it is shared or cached (not visible from here).
classes = {name: idx for name, idx in colormap_classes.items() if name != "empty"}
colors = np.asarray(palette).reshape((-1, 3)) / 255

for i, (ax1, ax2) in enumerate(axes.T):
    # Top row: label image with this trial's sampled points
    ax1.set_axis_off()
    ax1.set_title(f"Trial {i+1}", size="large")
    ax1.imshow(im)
    ax1.scatter(x[i], y[i], s=10, facecolors="none", edgecolors="black")

    # Class distribution from random points, in percent
    rp_dists = dist_from_arr(label[y[i], x[i]]) * 100

    # Bottom row: stacked horizontal bar of per-class coverage
    cover_prev = 0
    for cover, c in zip(rp_dists, list(classes.keys())):
        color = colors[classes[c]]
        ax2.barh(1, cover, height, left=cover_prev,
                 color=color, label=c)

        if cover > 0:
            # Round rather than int(): shares like 29 / 50 * 100 evaluate to
            # 57.99999999999999, which int() displayed as "57".
            ax2.text(cover_prev + cover / 2, 1, f"{cover:.0f}",
                     ha='center', va='center', color="black")

        cover_prev += cover

    ax2.set_yticks([])
    ax2.set_xlim(0, 100)
    ax2.set_xlabel("Coverage [%]")

# One legend entry per class (duplicate labels collapse in the dict)
handles, labels = ax2.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
fig.legend(by_label.values(), by_label.keys(), loc="lower center", ncol=5,
           bbox_to_anchor=(0.5, -0.07), frameon=False)

fig.tight_layout()
FIGURE_PATH = "results/figures/random point sampling.svg"
os.makedirs(os.path.dirname(FIGURE_PATH), exist_ok=True)
fig.savefig(FIGURE_PATH, dpi=300, bbox_inches="tight")