In this notebook we use the open-source cleanlab tool to identify issues with the Wake Vision validation and test sets.

First we import the necessary libraries

In [None]:
# Run this only once per kernel session to be able to import modules from the
# project root directory. Every execution moves the working directory up one
# more level, so running it again would move past the intended directory.
%cd ..

In [None]:

import os
import glob
# Select the Keras backend through the environment; this is done on purpose
# before `keras` is imported further down.
os.environ["KERAS_BACKEND"] = "jax"
import cleanlab
import datasets
import yaml
import numpy as np
import tensorflow as tf
import keras
import wake_vision_loader
from ml_collections import config_dict
# NOTE: `cfg` is rebound later in this notebook to the configuration that is
# loaded from the trained model's YAML file.
from experiment_config import default_cfg as cfg
import matplotlib.pyplot as plt
import math

We then load the dataset that we are interested in cleaning.

In [None]:
# Name of the dataset split to analyse. It selects the image folder that is
# read below and is part of the path the Datalab is saved to at the end.
SPLIT = "validation"

In [None]:
# Collect the image files of both classes for the selected split. glob returns
# paths in arbitrary, filesystem-dependent order, so the lists are sorted to
# keep the row indices of the dataset (which the issue tables below are indexed
# by) reproducible between runs and machines.
person_path_list = sorted(glob.glob(f"tmp/wv_image_folder/{SPLIT}/person/*"))
no_person_path_list = sorted(glob.glob(f"tmp/wv_image_folder/{SPLIT}/no_person/*"))

# One dataset per class: the image (decoded lazily through the Image feature),
# its file name, and the class label (1 = person, 0 = no person).
person_dataset = datasets.Dataset.from_dict(
    {
        "image": person_path_list,
        "filename": [os.path.basename(path) for path in person_path_list],
        "label": [1] * len(person_path_list),
    }
).cast_column("image", datasets.Image())

no_person_dataset = datasets.Dataset.from_dict(
    {
        "image": no_person_path_list,
        "filename": [os.path.basename(path) for path in no_person_path_list],
        "label": [0] * len(no_person_path_list),
    }
).cast_column("image", datasets.Image())

# Person rows come first, followed by the no-person rows.
ds = datasets.concatenate_datasets([person_dataset, no_person_dataset])


Now initialize the cleanlab Datalab using the dataset.

In [None]:
# Wrap the dataset in a Datalab, naming the column that holds the labels and
# the column that holds the images.
lab = cleanlab.Datalab(data=ds,label_name="label", image_key="image")

Next we need to use a model to get predicted probabilities for our dataset. We can make use of one of our models previously trained on the training set for this.

First get a model that we can use

In [None]:
# YAML configuration stored in the folder of a previously trained model.
model_yaml = "gs://wake-vision-storage/saved_models/bbox_trained2024_01_25-07_17_34_PM/config.yaml"

# NOTE(review): yaml.unsafe_load can construct arbitrary Python objects, so it
# must only be pointed at a trusted file. Presumably the saved config contains
# Python-specific tags that a safe loader would reject — confirm before
# switching loaders.
# NOTE: this rebinds `cfg`, replacing the `default_cfg` imported at the top of
# the notebook.
with tf.io.gfile.GFile(model_yaml, 'r') as fp:
    cfg = yaml.unsafe_load(fp)
    cfg = config_dict.ConfigDict(cfg)

# The loaded configuration records where the trained model was saved.
model_path = cfg.SAVE_FILE
model = keras.saving.load_model(model_path)

Now use this model to get predicted probabilities

In [None]:
# Expose the "image" and "label" columns of the dataset as a tf.data.Dataset.
tf_ds = ds.to_tf_dataset(columns=["image", "label"])
def rename_label(x):
    """Copy the "label" entry of ``x`` into a "person" entry and return ``x``.

    The mapping is modified in place and the original "label" entry is kept.
    """
    label = x["label"]
    x["person"] = label
    return x
# Add a "person" entry (a copy of "label") to every element; presumably this is
# the key the project's preprocessing expects — confirm in wake_vision_loader.
tf_ds = tf_ds.map(rename_label, num_parallel_calls=tf.data.AUTOTUNE)
# Apply the project's preprocessing with a batch size of 128 and the
# configuration loaded from the trained model.
# NOTE(review): the model outputs are later matched to dataset rows by
# position, which assumes preprocessing neither reorders nor drops elements —
# confirm in wake_vision_loader.
tf_ds = wake_vision_loader.preprocessing(tf_ds, batch_size = 128,cfg=cfg)


In [None]:
# Run the trained model over the preprocessed dataset; the result is handed to
# cleanlab as `pred_probs` further down.
pred_probabilities = model.predict(tf_ds)

Apart from the predicted probabilities, we can improve the issue finding by also generating feature embeddings. We can simply get these embeddings from the model that we used to get the predicted probabilities.

In [None]:
# Build a model that shares the trained model's inputs but outputs the
# activations of the layer named "global_average_pooling2d", and use those
# activations as feature embeddings.
embedding_model = keras.Model(inputs=model.inputs, outputs=model.get_layer("global_average_pooling2d").output)
embeddings = embedding_model.predict(tf_ds)

Now we use the predicted probabilities and the embeddings to find issues in the dataset.

In [None]:
# Run cleanlab's checks, supplying both the model outputs and the embeddings.
lab.find_issues(pred_probs=pred_probabilities, features = embeddings)

Let us see a report of the issues found in the dataset.

In [None]:
# Print cleanlab's summary of the issues found by find_issues.
lab.report()

Let us first plot some of the label issues found in the dataset.

In [None]:
# Per-example results of cleanlab's label check. This frame is also read by the
# plotting helpers below to look up given and predicted labels.
label_issues = lab.get_issues("label")
# Keep only the rows flagged as label issues, sorted by label_score
# (ascending, the sort_values default).
label_issues_df = label_issues.query("is_label_issue").sort_values("label_score")

In [None]:
def plot_label_issue_examples(label_issues_df, num_examples=15):
    """Plot the first ``num_examples`` images of ``label_issues_df`` in a grid.

    Every image is titled with its dataset index, its given label (GL) and its
    predicted label (SL). Both labels are looked up in the notebook-level
    ``label_issues`` frame, and the images are read from the notebook-level
    ``ds`` dataset.

    Args:
        label_issues_df: DataFrame whose index holds row indices of ``ds``.
            Rows are plotted in the order in which they appear.
        num_examples: Maximum number of images to plot. If the frame has fewer
            rows, only those rows are plotted.

    Returns:
        None. The figure is shown with ``plt.show()``; nothing is drawn when
        there are no examples to plot.
    """
    # Never request more examples than there are rows: indexing past the end
    # of the index array would raise IndexError.
    num_examples = min(num_examples, len(label_issues_df))
    if num_examples <= 0:
        # plt.subplots cannot build a grid with zero rows.
        print("No examples to plot.")
        return

    ncols = 5
    nrows = math.ceil(num_examples / ncols)

    # squeeze=False keeps `axes` two-dimensional regardless of the grid shape.
    _, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(1.5 * ncols, 1.5 * nrows),
        squeeze=False,
    )
    issue_indices = label_issues_df.index.values

    for i, ax in enumerate(axes.flatten()):
        if i >= num_examples:
            # Unused cells of the last grid row stay blank.
            ax.axis("off")
            continue
        idx = int(issue_indices[i])
        row = label_issues.loc[idx]
        ax.set_title(
            f"id: {idx}\n GL: {row.given_label}\n SL: {row.predicted_label}",
            fontdict={"fontsize": 8},
        )
        ax.imshow(ds[idx]["image"], cmap="gray")
        ax.axis("off")
    plt.subplots_adjust(hspace=0.7)
    plt.show()

In [None]:
# Show the first 15 rows of the sorted label-issue frame.
plot_label_issue_examples(label_issues_df, num_examples=15)

Next let us take a look at some outliers.

In [None]:
# Keep only the rows flagged as outliers, sorted by outlier_score (ascending,
# the sort_values default).
outlier_issues_df = lab.get_issues("outlier")
outlier_issues_df = outlier_issues_df.query("is_outlier_issue").sort_values("outlier_score")
# NOTE(review): the plotting helper below appears to be adapted from the
# cleanlab tutorials on docs.cleanlab.ai — confirm if provenance matters.

def plot_outlier_issues_examples(outlier_issues_df, num_examples):
    """Plot outlier images next to non-outlier images with the same given label.

    Every figure row shows one outlier in the first column (titled with its
    dataset index and given label, GL) followed by up to three comparison
    images. The comparison images are the ones with the highest
    ``label_score`` among a random sample of up to 30 non-outlier rows that
    share the outlier's given label.

    Relies on the notebook-level ``ds`` dataset and ``label_issues`` frame.

    Args:
        outlier_issues_df: DataFrame of flagged outliers whose index holds row
            indices of ``ds``. Rows are plotted in the order in which they
            appear.
        num_examples: Maximum number of outliers (figure rows) to plot. If the
            frame has fewer rows, only those rows are plotted.

    Returns:
        None. The figure is shown with ``plt.show()``; nothing is drawn when
        there are no outliers to plot.
    """
    n_comparison_images = 3
    ncols = 1 + n_comparison_images
    sample_pool_size = 30

    # Never build more figure rows than there are outliers; otherwise the grid
    # would contain cells without an image to draw.
    nrows = min(num_examples, len(outlier_issues_df))
    if nrows <= 0:
        # plt.subplots cannot build a grid with zero rows.
        print("No outlier examples to plot.")
        return

    # The join does not depend on the example being plotted, so it is computed
    # once. Rows that are not in outlier_issues_df get NaN in the outlier
    # columns, which is what the query below uses to find non-outliers.
    joined = label_issues.join(outlier_issues_df)

    def sample_from_class(label, number_of_samples, index):
        """Return comparison images for ``label``, excluding row ``index``."""
        candidates = joined.query(
            "given_label == @label and is_outlier_issue.isnull()"
        ).index
        candidates = candidates[candidates != int(index)]

        # Sampling without replacement raises if more samples are requested
        # than there are candidates, so cap the request.
        number_of_samples = min(number_of_samples, len(candidates))
        if number_of_samples == 0:
            return []

        sampled_indices = np.random.choice(candidates, number_of_samples, replace=False)
        sampled_scores = label_issues.loc[sampled_indices]["label_score"]
        # Highest label_score first, keeping only as many as there are columns.
        best_first = np.argsort(sampled_scores.values)[::-1][:n_comparison_images]
        return [ds[int(i)]["image"] for i in sampled_indices[best_first]]

    # Gather everything before creating the figure so that a failure while
    # loading images does not leave a half-drawn figure behind.
    rows_to_plot = []
    for idx in outlier_issues_df.index[:nrows]:
        given_label = label_issues.loc[idx]["given_label"]
        images = [ds[int(idx)]["image"]]
        images.extend(sample_from_class(given_label, sample_pool_size, idx))
        rows_to_plot.append((idx, given_label, images))

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    _, axes = plt.subplots(
        nrows=nrows,
        ncols=ncols,
        figsize=(1.5 * ncols, 1.5 * nrows),
        squeeze=False,
    )
    for axes_row, (idx, given_label, images) in zip(axes, rows_to_plot):
        axes_row[0].set_title(f"id: {idx}\n GL: {given_label}", fontdict={"fontsize": 8})
        for ax, image in zip(axes_row, images):
            ax.imshow(image, cmap="gray")
        # Cells without a comparison image are simply left blank.
        for ax in axes_row:
            ax.axis("off")
    plt.subplots_adjust(hspace=0.7)
    plt.show()
# Show the first 5 rows of the sorted outlier frame.
plot_outlier_issues_examples(outlier_issues_df, num_examples=5)

Now let us see some near duplicates

In [None]:
# Keep only the rows flagged as near duplicates, sorted by
# near_duplicate_score (ascending, the sort_values default).
near_duplicate_issues_df = lab.get_issues("near_duplicate")
near_duplicate_issues_df = near_duplicate_issues_df.query("is_near_duplicate_issue").sort_values(
    "near_duplicate_score"
)
# NOTE(review): the plotting helper below appears to be adapted from the
# cleanlab tutorials on docs.cleanlab.ai — confirm if provenance matters.

def plot_near_duplicate_issue_examples(near_duplicate_issues_df, num_examples=3):
    """Plot sets of near-duplicate images, one figure per set.

    For every row, the set consists of the row's own index plus the indices in
    its ``near_duplicate_sets`` entry. A set is skipped when it shares an image
    with a set that was already plotted. Each image is titled with its dataset
    index and its given label (GL).

    Relies on the notebook-level ``ds`` dataset and ``label_issues`` frame.

    Args:
        near_duplicate_issues_df: DataFrame of flagged near duplicates whose
            index holds row indices of ``ds`` and which has a
            ``near_duplicate_sets`` column.
        num_examples: Maximum number of sets to plot.

    Returns:
        None. The figures are shown with ``plt.show()``; nothing is drawn when
        there is nothing to plot.
    """
    if num_examples <= 0:
        return

    seen_ids = set()
    plotted = 0
    for idx, row in near_duplicate_issues_df.iterrows():
        nd_set = {int(i) for i in row.near_duplicate_sets}
        nd_set.add(int(idx))

        # Skip sets that overlap with a set that has already been plotted.
        if nd_set & seen_ids:
            continue

        # squeeze=False keeps `axes` an array even if the set holds one image.
        _, axes = plt.subplots(1, len(nd_set), figsize=(len(nd_set), 3), squeeze=False)
        # Sorting gives a deterministic left-to-right order.
        for i, ax in zip(sorted(nd_set), axes.flatten()):
            label = label_issues.loc[i]["given_label"]
            ax.set_title(f"id: {i}\n GL: {label}", fontdict={"fontsize": 8})
            ax.imshow(ds[i]["image"], cmap="gray")
            ax.axis("off")
        seen_ids.update(nd_set)
        plotted += 1
        if plotted >= num_examples:
            break

    if not plotted:
        print("No near-duplicate examples to plot.")
        return
    plt.show()
# Show up to 5 non-overlapping sets of near duplicates.
plot_near_duplicate_issue_examples(near_duplicate_issues_df, num_examples=5)

In [None]:
# Keep only the rows flagged as blurry, sorted by blurry_score (ascending, the
# sort_values default).
blurry = lab.get_issues("blurry")
blurry_issues_df = blurry.query("is_blurry_issue").sort_values("blurry_score")
# Reuse the label-issue grid plot. Its titles show the given and predicted
# labels taken from `label_issues`, not any blurry-specific value.
plot_label_issue_examples(blurry_issues_df, num_examples=15)

Finally we save the datalab instance to use for fixing issues in the cleanlab_fix_issues notebook.

In [None]:
# Persist the Datalab under a path that includes the split name.
lab.save(f'tmp/wv_datalab_{SPLIT}')