In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from scipy.ndimage import grey_opening
from sklearn.metrics.pairwise import cosine_distances
from tqdm import tqdm

# Work from the notebook's parent directory so that the project modules imported
# below and the relative "data/..." / "results/..." paths resolve.
# The starting directory is remembered on first execution: a bare os.chdir("..")
# would climb one level higher every time this cell is re-run in the same kernel.
if "_NOTEBOOK_DIR" not in globals():
    _NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_DIR, ".."))

from dataset import image_label_paths_from_dir
from predict import predict_embeddings, predict_images
from split_generate import calculate_probabilities
from training import allow_gpu_memory_growth
from utils import indexed_from_arr, load_colormap

# Project helper from `training`; called once before any model is loaded.
allow_gpu_memory_growth()


In [None]:
# Model file handed to predict_embeddings / predict_images; the embedding figure
# is saved into the same directory.
MODEL_PATH = "results/models/4_final_balanced_efficientnet_unet/model.h5"
# Directories scanned for images (and their labels, where present).
DATA_DIRS = ["data/training", "data/validation", "data/pool"]
# Location of the cached embedding arrays (embeddings_pca.npy / embeddings_tsne.npy).
RESULTS_DIR = "results"
# True: load the cached embeddings from RESULTS_DIR.
# False: recompute them with the model and overwrite the cache.
LOAD_EMBEDDING = True

In [None]:
# Gather image paths, label paths and the per-image `labeled` flags from all data
# directories (filter_unlabeled=False — presumably keeps unlabeled images; confirm)
image_paths, label_paths, labeled = image_label_paths_from_dir(DATA_DIRS, filter_unlabeled=False)

# Oversampled copies carry "_over_" in their label path; keep everything else
mask = np.array(["_over_" not in path for path in label_paths])
image_paths, label_paths, labeled = (arr[mask] for arr in (image_paths, label_paths, labeled))

print(len(image_paths), "images found")
print(np.count_nonzero(labeled), "labeled images")

In [None]:
# Embeddings: reuse the cached arrays when requested, otherwise compute and cache them
pca_file = os.path.join(RESULTS_DIR, "embeddings_pca.npy")
tsne_file = os.path.join(RESULTS_DIR, "embeddings_tsne.npy")

if LOAD_EMBEDDING:
    X_pca = np.load(pca_file)
    X_tsne = np.load(tsne_file)
else:
    X_pca, X_tsne = predict_embeddings(MODEL_PATH, image_paths)

    np.save(pca_file, X_pca)
    np.save(tsne_file, X_tsne)

# The embeddings are indexed below with masks/indices derived from `image_paths`,
# so a stale cache (different image set) must be caught before the slow inference.
assert len(X_pca) == len(X_tsne) == len(image_paths), (
    "Embeddings do not match the current image list; "
    "set LOAD_EMBEDDING = False to recompute them."
)

# Calculate uncertainty for all patches: mean of each image's entropy map (x[2])
# NOTE(review): 128 is assumed to be predict_images' default batch size — confirm.
# The batch count is rounded up so the progress bar also covers the final partial batch.
n_batches = int(np.ceil(len(image_paths) / 128))
entropies = [x[2].mean(axis=(-1, -2)) for x in tqdm(
    predict_images(MODEL_PATH, image_paths), total=n_batches)]
entropies = np.hstack(entropies)

# One entropy value per image is required by the plots and the selection below.
assert entropies.shape == (len(image_paths),), "Expected exactly one entropy value per image."

Active Learning: Select uncertain and representative images for human annotation

In [None]:
def plot_embedding(selected, uncertainty=False):
    """Scatter the 2-D t-SNE embedding and highlight the selected samples.

    Reads the notebook-level arrays ``X_tsne``, ``entropies`` and ``labeled``
    and draws into a newly created figure (which stays the current pyplot figure).

    Args:
        selected: Indices into ``X_tsne`` of the samples to highlight.
        uncertainty: If True, colour every point by its entropy and ring the
            selected samples in green. If False, colour points by
            labeled/unlabeled status and mark the selected samples in red.
    """
    if uncertainty:
        fig, ax = plt.subplots(figsize=(10, 6))

        # Round the upper colour limit *up* to one decimal so the largest entropies
        # are never clipped, and keep it above vmin even for near-zero entropies.
        vmax = max(np.ceil(entropies.max() * 10) / 10, 0.1)

        sc = ax.scatter(X_tsne[:, 0], X_tsne[:, 1],
            c=entropies, s=3, cmap="inferno", vmin=0, vmax=vmax,
            label="All")
        # Hollow rings carry no colour data, so no colormap arguments are passed
        # here (matplotlib ignores them with a warning when `c` is missing).
        ax.scatter(X_tsne[selected, 0], X_tsne[selected, 1],
            edgecolor="green", s=24, facecolor="None",
            label="Annotate by human")
        fig.colorbar(sc)
    else:
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.scatter(X_tsne[~labeled,0], X_tsne[~labeled,1], s=3, label="unlabeled")
        ax.scatter(X_tsne[labeled,0], X_tsne[labeled,1], s=3, label="labeled")
        ax.scatter(X_tsne[selected,0], X_tsne[selected,1], s=12, label="selected for AL", c="red")

    # Both variants label their series, so both get a legend
    ax.legend()
    ax.set_xlabel("1st component")
    ax.set_ylabel("2nd component")

In [None]:
K = 64  # size of the candidate pool: the most uncertain unlabeled images
k = 16  # number of images that remain once the pool has been thinned out

# Rank all images from highest to lowest entropy, discard the labeled ones,
# and keep the K most uncertain of the remainder
by_uncertainty = np.argsort(entropies)[::-1]
labeled_idx = np.flatnonzero(labeled)
is_unlabeled = np.isin(by_uncertainty, labeled_idx, invert=True)
selected = by_uncertainty[is_unlabeled][:K]

plot_embedding(selected, uncertainty=True)

In [None]:
# Greedy trade-off between diversity and representativity:
#   - diversity:        selected samples should be far from each other (maximise)
#   - representativity: selected samples should be close to the unlabeled pool (minimise)

# Mean cosine distance from every candidate to its 16 closest unlabeled samples.
# NOTE(review): the candidates are unlabeled themselves, so each one presumably
# counts itself (distance 0) among those 16 — confirm this is intended.
to_unlabeled = cosine_distances(X_pca[selected], X_pca[~labeled])
dists_inter = np.sort(to_unlabeled, axis=-1)[:, :16].mean(axis=1)

while len(selected) > k:
    # After sorting each row, column 0 is the sample's distance to itself and
    # column 1 the distance to its nearest fellow candidate
    pairwise = cosine_distances(X_pca[selected], X_pca[selected])
    dists_intra = np.sort(pairwise, axis=-1)[:, 1]

    # Low score: redundant within the selection and/or far from the unlabeled pool
    score = dists_intra / (dists_inter + 1e-7)

    # Drop the lowest-scoring candidate and its cached inter-distance
    worst = np.argmin(score)
    selected = np.delete(selected, worst, axis=0)
    dists_inter = np.delete(dists_inter, worst, axis=0)

# The uncertainty view is written to the model directory; the labeled/unlabeled
# view is only displayed
figure_path = os.path.join(os.path.dirname(MODEL_PATH), "al_embeddings.png")
plot_embedding(selected, uncertainty=True)
plt.savefig(figure_path, bbox_inches="tight")
plot_embedding(selected, uncertainty=False)

In [None]:
def plot_uncertainty_masks(image_paths):
    """Plot image, entropy map and predicted mask side by side for each path.

    All images are predicted in a single batch with the model at ``MODEL_PATH``.
    The figure holds two image/entropy/mask triples per row and one shared
    colorbar for the entropy maps.

    Args:
        image_paths: Paths of the images to visualise; must not be empty.

    Raises:
        ValueError: If ``image_paths`` is empty.
    """
    if len(image_paths) == 0:
        raise ValueError("plot_uncertainty_masks needs at least one image path")

    # Inference
    preds = predict_images(MODEL_PATH, image_paths, batch_size=len(image_paths))
    images, _, entropy_maps, masks, _ = next(iter(preds))

    # Visualization
    rows = int(np.ceil(len(image_paths)/2))
    fig, axes = plt.subplots(rows, 6, figsize=(10, 1.6 * rows))

    # Hide every axis up front: with an odd number of images the last triple of
    # panels stays empty and would otherwise show bare axis frames
    for ax in axes.ravel():
        ax.set_axis_off()

    palette = load_colormap()[1]
    for panel, im, emap, mask in zip(axes.reshape(-1, 3), images, entropy_maps, masks):
        panel[0].imshow(im.astype(np.uint8))
        sm = panel[1].imshow(emap, vmin=0, vmax=1, cmap="inferno")
        panel[2].imshow(indexed_from_arr(mask, palette))

    # Global colorbar (all entropy maps share vmin/vmax, so any of them can provide it)
    fig.tight_layout()
    fig.subplots_adjust(right=0.95)
    fig.colorbar(sm, cax=fig.add_axes([0.98, 0.05, 0.02, 0.9])) # left, bottom, width, height

plot_uncertainty_masks(image_paths[selected])
# plt.savefig(os.path.join(os.path.dirname(MODEL_PATH), "al_samples.png"), bbox_inches="tight")

Suggest labels for human annotators

In [None]:
# The function is named `suggest_masks` so that the `suggested_masks` list created
# below no longer overwrites the function that produced it.
def suggest_masks(image_paths: list[str],
                  threshold: float = 0.1):
    """Predict masks and keep only their confident regions as label suggestions.

    All images are predicted in a single batch with the model at ``MODEL_PATH``.
    Pixels whose entropy exceeds ``threshold`` are reset to 0, each mask is
    cleaned with a 5x5 grey-scale opening, and the result is converted with
    ``indexed_from_arr`` using the project colour palette.

    Args:
        image_paths: Paths of the images to predict.
        threshold: Entropy above which a predicted pixel is discarded.

    Returns:
        A list with one converted mask per input image.
    """
    # Inference
    preds = predict_images(MODEL_PATH, image_paths, batch_size=len(image_paths))
    _, _, entropies, masks, _ = next(iter(preds))

    # Truncate prediction based on uncertainty
    # (value 0 is presumably the background / "no label" class — confirm)
    masks = masks.astype(np.uint8)
    masks[entropies > threshold] = 0

    palette = load_colormap()[1]

    return [indexed_from_arr(grey_opening(mask, size=(5, 5)), palette) for mask in masks]

suggested_masks = suggest_masks(image_paths[selected])

# Round the row count up so no mask is dropped when the number of masks is not a
# multiple of the column count; squeeze=False keeps `axes` 2-D for a single row
n_cols = 4
n_rows = int(np.ceil(len(suggested_masks) / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(8, 2 * n_rows), squeeze=False)
for ax in axes.ravel():
    ax.set_axis_off()  # also hides the unused panels of an incomplete last row
for ax, mask in zip(axes.ravel(), suggested_masks):
    ax.imshow(mask)
fig.tight_layout()

In [None]:
# Upload these to your labeling tool
# NOTE(review): a notebook cell renders only its last expression, so the bare
# `suggested_masks` line produces no output; only the selected image paths are shown.
suggested_masks
image_paths[selected]