In [None]:
import os
import random
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
from PIL import Image
from tqdm import tqdm

# Step one directory up before importing the project-local modules below
# (dataset, predict, ...); the relative "data/..." and "results/..." paths used
# throughout this notebook are resolved against the new working directory.
# NOTE(review): presumably the notebook lives in a subfolder of the project
# root -- confirm. Because the path is relative, re-running this cell without
# restarting the kernel climbs one more level each time.
os.chdir("..")

from dataset import image_label_paths_from_dir
from predict import predict_embeddings, predict_images
from split_generate import calculate_probabilities
from training import allow_gpu_memory_growth
from utils import load_colormap

# Project helper from the training module; judging by its name it switches the
# GPU to on-demand memory allocation -- TODO confirm against training.py.
allow_gpu_memory_growth()


In [None]:
# Saved model handed to predict_embeddings() and predict_images() below.
MODEL_PATH = "results/models/4_final_balanced_efficientnet_unet/model.h5"
# Directories passed to image_label_paths_from_dir() to collect the samples.
DATA_DIRS = ["data/training", "data/validation", "data/pool"]
# Folder holding the cached embeddings_pca.npy / embeddings_tsne.npy arrays.
RESULTS_DIR = "results"
# True: read the cached embedding arrays from RESULTS_DIR.
# False: recompute them with the model and overwrite the cached files.
LOAD_EMBEDDING = True

In [None]:
# Gather image/label paths from every data directory, unlabeled samples included
image_paths, label_paths, labeled = image_label_paths_from_dir(
    DATA_DIRS, filter_unlabeled=False)

# Drop oversampled entries, recognised by the "_over_" marker in the label path
mask = np.asarray(["_over_" not in path for path in label_paths])
image_paths = image_paths[mask]
label_paths = label_paths[mask]
labeled = labeled[mask]

print(f"{len(image_paths)} images found")
print(f"{np.count_nonzero(labeled)} labeled images")

In [None]:
# Either reuse the cached embedding arrays or recompute them with the model
# and refresh the cache on disk.
pca_cache = os.path.join(RESULTS_DIR, "embeddings_pca.npy")
tsne_cache = os.path.join(RESULTS_DIR, "embeddings_tsne.npy")

if LOAD_EMBEDDING:
    X_pca = np.load(pca_cache)
    X_tsne = np.load(tsne_cache)
else:
    X_pca, X_tsne = predict_embeddings(MODEL_PATH, image_paths)

    np.save(pca_cache, X_pca)
    np.save(tsne_cache, X_tsne)

# The plots below index X_tsne with the boolean `labeled` array, so a cache
# produced from a different file list must be caught here rather than as an
# opaque indexing error later on.
assert len(X_tsne) == len(image_paths), (
    f"{tsne_cache} holds {len(X_tsne)} rows but {len(image_paths)} images were "
    "found; set LOAD_EMBEDDING = False to recompute the embeddings")

# Calculate uncertainty for all patches
# NOTE(review): 128 is assumed to be the batch size predict_images yields --
# confirm. Rounding up keeps the progress bar total correct when the number of
# images is not a multiple of the batch size.
n_batches = int(np.ceil(len(image_paths) / 128))
# x[2] is averaged over its last two axes, giving one value per image
# (presumably the entropy map of each prediction -- TODO confirm).
entropies = [x[2].mean(axis=(-1, -2)) for x in tqdm(
    predict_images(MODEL_PATH, image_paths), total=n_batches)]
entropies = np.hstack(entropies)
assert len(entropies) == len(image_paths), (
    f"expected one uncertainty value per image, got {len(entropies)} "
    f"for {len(image_paths)} images")

Figure 2a: Visualization of the semantic latent space

In [None]:
def truncate_colormap(cmap_name, minval=0.0, maxval=1.0, n=100):
    """Build a colormap from a sub-range of an existing one.

    Args:
        cmap_name: Colormap identifier handed to ``plt.get_cmap``.
        minval: Lower end of the sampled range of the source colormap.
        maxval: Upper end of the sampled range of the source colormap.
        n: Number of evenly spaced samples taken between ``minval`` and
            ``maxval``.

    Returns:
        A ``LinearSegmentedColormap`` built from the sampled colors and named
        ``trunc(<source name>,<minval>,<maxval>)``.
    """
    base = plt.get_cmap(cmap_name)
    sample_points = np.linspace(minval, maxval, n)
    label = f'trunc({base.name},{minval:.2f},{maxval:.2f})'
    return LinearSegmentedColormap.from_list(label, base(sample_points))

# Interactive plot for cluster analysis
def on_pick(event):
    """Pick-event callback: export and display the image behind a clicked point.

    Relies on the notebook globals ``scu`` (scatter of unlabeled samples),
    ``labeled`` (boolean array) and ``image_paths``.

    Args:
        event: Matplotlib pick event; ``event.artist`` is the scatter that was
            hit and ``event.ind`` the indices of the picked points within it.
    """
    # Each scatter only holds a subset of the samples, so the index reported by
    # the event has to be mapped back to a position in the full arrays.
    ind = event.ind[0]
    query = ~labeled if event.artist is scu else labeled
    idx = np.where(query)[0][ind]

    src_path = image_paths[idx]
    out_dir = "results/figures/examples"
    # Create the export folder on first use instead of failing inside save().
    os.makedirs(out_dir, exist_ok=True)

    # The context manager closes the underlying file once the copy has been
    # written and handed to the viewer.
    with Image.open(src_path) as im:
        im.save(f"{out_dir}/{os.path.basename(src_path)}.png")
        im.show()

fig, ax = plt.subplots(figsize=(10, 6))
vis_kwargs = dict(vmin=0, vmax=0.5)

# Marker styling shared by both point clouds; picker=3 makes the points
# selectable for the on_pick callback.
marker_kwargs = dict(s=10, edgecolor="none", picker=3, **vis_kwargs)
unlabeled_xy = X_tsne[~labeled]
labeled_xy = X_tsne[labeled]

# Unlabeled samples in blue, labeled samples in green, both coloured by entropy
scu = ax.scatter(
    unlabeled_xy[:, 0], unlabeled_xy[:, 1],
    c=entropies[~labeled],
    cmap=truncate_colormap("Blues_r", 0, 0.8),
    label="Unlabeled",
    **marker_kwargs)
scl = ax.scatter(
    labeled_xy[:, 0], labeled_xy[:, 1],
    c=entropies[labeled],
    cmap=truncate_colormap("Greens_r", 0, 0.8),
    label="Labeled",
    **marker_kwargs)

ax.set(
    xlabel="1st component",
    ylabel="2nd component",
    ylim=(-110, 110),
    xlim=(-105, 110))

# Two colorbars side by side: the labeled one carries the caption, the
# unlabeled one is drawn without ticks.
labeled_cbar = fig.colorbar(scl, pad=-0.1)
labeled_cbar.set_label("Mean Image Uncertainty", labelpad=15)
fig.colorbar(scu, pad=0.03, ticks=[])
ax.legend(markerscale=2., frameon=False)

# To pick samples interactively, enable the two lines below:
# fig.canvas.mpl_connect('pick_event', on_pick)
# plt.show()

fig.tight_layout()
fig.savefig("results/figures/latent space.svg", bbox_inches="tight")

Figure 2b: Example images for the clusters picked from the semantic latent space

In [None]:
# Visualize picked examples
# Maps the cluster folder name under results/figures/clusters/ to the caption
# shown above its row of example images.
clusters = {
    "a":"Bare",
    "b":"Bare & slime",
    "c":"Slime",
    "d":"Encrusting bry. & barnacles",
    "e":"Tubeworms & encrusting bry.",
    "f":"Big individual tubeworms",
    "g":"Dense tubeworms",
    "h":"Tubeworms & arborescent bry.",
    "i":"Dense barnacles",
    "j":"Sponges & colonial tunicates",
    "k":"Small tubeworms",
    "l":"Mixed species with slime",
}

# Four rows; every cluster occupies three neighbouring panels of a row.
fig, axes = plt.subplots(4, 3 * len(clusters) // 4, figsize=(10, 5.5),
    gridspec_kw=dict(wspace=0.05, hspace=0.2))

# Dedicated, seeded generator: the selection of examples is reproducible
# across runs and the global `random` state is left untouched.
rng = random.Random(0)

for axs, (c, description) in zip(axes.reshape((-1, 3)), clusters.items()):
    # Hide ticks and frames for the whole group up front so panels stay blank
    # when a cluster folder holds fewer than three images.
    for ax in axs:
        ax.set_axis_off()

    # glob() returns files in arbitrary order; sort before shuffling so the
    # seeded shuffle always starts from the same sequence.
    img_paths = sorted(glob(f"results/figures/clusters/{c}/*.png"))
    rng.shuffle(img_paths)

    for ax, img_path in zip(axs, img_paths):
        # Close the file as soon as the pixel data has been handed to imshow.
        with Image.open(img_path) as im:
            ax.imshow(im)
    axs[0].set_title(f"({c.upper()}) {description}", loc="left", size="medium")

fig.savefig("results/figures/latent space examples.svg", bbox_inches="tight", dpi=150)

Supplementary Figure 4: Ordering of the macrofouling classes in latent space

In [None]:
# Rank the classes of every labeled image by the area they cover
classes, palette = load_colormap()
del classes["empty"]
colors = np.reshape(palette, (-1, 3)) / 255

class_share = calculate_probabilities(
    list(label_paths[labeled]), len(classes) + 1, average=False)
# Tiny offset on column 0 so that it ranks ahead of any column it would
# otherwise tie with.
class_share[:, 0] += 1e-6
# Column k of `majors` holds the index of the (k+1)-th largest class per image
majors = class_share.argsort(axis=-1)[:, ::-1]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 3))

# Left panel colours each sample by its largest class, right panel by its
# second largest one.
tsne_labeled = X_tsne[labeled]
for name, c_idx in classes.items():
    for rank, panel in enumerate((ax1, ax2)):
        mask = majors[:, rank] == c_idx
        panel.scatter(
            tsne_labeled[mask, 0], tsne_labeled[mask, 1],
            color=colors[c_idx], s=12, edgecolor="none", label=name)

ax1.set(
    title="Largest class",
    xlabel="1st component",
    ylabel="2nd component",
    ylim=(-110, 110),
    xlim=(-105, 110))

# The right panel mirrors the limits of the left one
ax2.set(
    title="Second largest class",
    xlabel="1st component",
    ylim=ax1.get_ylim(),
    xlim=ax1.get_xlim())

# One shared legend below both panels, with a single entry per class name
handles, labels = ax1.get_legend_handles_labels()
by_label = dict(zip(labels, handles))
fig.legend(by_label.values(), by_label.keys(), loc="lower center", ncol=5,
    bbox_to_anchor=(0.5, -0.25), frameon=False, markerscale=2.)
fig.savefig("results/figures/latent space largest classes.svg", bbox_inches="tight")