In [1]:
import pandas as pd
import numpy as np
from cleanlab import Datalab
import math
import matplotlib.pyplot as plt
from PIL import Image

In [2]:
# Load the precomputed arrays this notebook analyses (all stored as .npy files).
features = np.load(file="features.npy")      # passed to Datalab.find_issues(features=...)
pred_probs = np.load(file="pred_probs.npy")  # passed to Datalab.find_issues(pred_probs=...)
# labels.npy is at least 2-D; only its first column is used as the given label.
true_labels = np.load(file="labels.npy")[:, 0]
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can run
# arbitrary code — only load image_paths.npy from a trusted source.
image_paths = np.load(file="image_paths.npy", allow_pickle=True)

In [3]:
# One row per sample. "id" is the sample's position in the loaded arrays
# (0 .. len(image_paths) - 1); "image_path" is kept for the plotting helpers below.
data = pd.DataFrame(
    {
        "id": np.arange(len(image_paths)),
        "label": true_labels,
        "image_path": image_paths,
    }
)

In [None]:
# Build a Datalab over the metadata frame; the "label" column holds the given labels.
lab = Datalab(data=data, label_name="label")
# Run issue detection using the precomputed features and predicted probabilities.
lab.find_issues(features=features, pred_probs=pred_probs)
# Print Datalab's summary report of the issues it found.
lab.report()
# get_issues() returns one row per example (an is_<type>_issue flag plus a score per
# issue type), not only the flagged examples. Keep just the rows where at least one
# flag is set, so "problematic" really means flagged.
all_issues = lab.get_issues()
issue_flags = all_issues.filter(regex=r"^is_.*_issue$")
problematic_samples = all_issues[issue_flags.any(axis=1)]
# Image file paths of the flagged samples.
problematic_images = data.loc[problematic_samples.index, "image_path"]

In [None]:
# Per-example label-issue results from Datalab (one row per example).
label_issues = lab.get_issues("label")
# Keep only flagged examples, ordered by ascending label_score
# (in cleanlab, lower scores indicate a more likely label issue).
label_issues_df = label_issues.query("is_label_issue").sort_values("label_score")
label_issues_df.head()

In [6]:
def plot_label_issue_examples(label_issues_df, num_examples=15, ncols=5):
    """Show flagged label-issue images in a grid, titled with given vs. predicted label.

    Image paths are looked up in the module-level ``data`` DataFrame by matching
    each row's index against ``data["id"]``.

    Args:
        label_issues_df (DataFrame): Label-issue rows to plot, in display order; must
            have ``given_label`` and ``predicted_label`` columns.
        num_examples (int): Maximum number of images to show. Clamped to the number
            of rows available, so requesting more than exist does not raise.
        ncols (int): Number of images per grid row.
    """
    if label_issues_df.empty:
        print("⚠️ No label issues found.")
        return

    num_examples = min(num_examples, len(label_issues_df))
    if num_examples <= 0:
        return  # nothing requested; plt.subplots rejects a grid with 0 rows

    nrows = math.ceil(num_examples / ncols)
    # squeeze=False always yields a 2-D array of Axes, so ravel() works for any grid.
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(1.5 * ncols, 1.5 * nrows), squeeze=False
    )
    axes_list = axes.ravel()

    for ax, (idx, row) in zip(axes_list, label_issues_df.head(num_examples).iterrows()):
        idx = int(idx)
        img_path = data.loc[data["id"] == idx, "image_path"].values[0]
        # imshow copies the pixel data, so the file can be closed right after drawing.
        with Image.open(img_path) as img:
            ax.imshow(img, cmap="gray")
        ax.set_title(f"id: {idx}\nGT: {row.given_label}\nPRED: {row.predicted_label}", fontsize=8)

    # Hide ticks/frames on every cell, including unused trailing cells of the grid.
    for ax in axes_list:
        ax.axis("off")

    plt.subplots_adjust(hspace=0.7)
    plt.show()

In [None]:
# Plot the first 20 rows of label_issues_df (most suspicious first).
plot_label_issue_examples(label_issues_df, 20)

In [None]:
# Flagged outliers only, ordered by ascending outlier_score.
outlier_issues_df = (
    lab.get_issues("outlier")
    .query("is_outlier_issue")
    .sort_values("outlier_score")
)
outlier_issues_df.head()

In [9]:
def plot_outlier_issues_examples(outlier_issues_df, num_examples=10, ncols=5):
    """Display outlier images in a grid without reference images.

    Image paths are looked up in the module-level ``data`` DataFrame by matching
    each row's index against ``data["id"]``.

    Args:
        outlier_issues_df (DataFrame): Cleanlab-detected outliers, in display order.
        num_examples (int): Maximum number of outliers to display (clamped to the
            number of rows available).
        ncols (int): Number of images per row.
    """
    if outlier_issues_df.empty:
        print("⚠️ No outlier issues found.")
        return

    num_examples = min(num_examples, len(outlier_issues_df))
    if num_examples <= 0:
        return  # nothing requested; plt.subplots rejects a grid with 0 rows

    nrows = math.ceil(num_examples / ncols)
    # squeeze=False guarantees a 2-D array of Axes even for a single row, so
    # axes_list is always a flat sequence of Axes (a 1-row grid used to break here).
    fig, axes = plt.subplots(
        nrows=nrows, ncols=ncols, figsize=(2 * ncols, 2 * nrows), squeeze=False
    )
    axes_list = axes.ravel()

    for ax, idx in zip(axes_list, outlier_issues_df.index[:num_examples]):
        img_path = data.loc[data["id"] == idx, "image_path"].values[0]
        # imshow copies the pixel data, so the file can be closed right after drawing.
        with Image.open(img_path) as img:
            ax.imshow(img, cmap="gray")
        ax.set_title(f"id: {idx}", fontsize=8)

    # Hide ticks/frames on every cell, including all unused trailing cells.
    for ax in axes_list:
        ax.axis("off")

    plt.subplots_adjust(hspace=0.5, wspace=0.5)
    plt.show()

In [None]:
# Show the first 10 flagged outliers.
plot_outlier_issues_examples(outlier_issues_df, 10)

In [None]:
# Flagged near duplicates only, ordered by ascending near_duplicate_score.
near_duplicate_issues_df = (
    lab.get_issues("near_duplicate")
    .query("is_near_duplicate_issue")
    .sort_values("near_duplicate_score")
)
near_duplicate_issues_df.head()

In [12]:
def plot_near_duplicate_issue_examples(near_duplicate_issues_df, num_examples=3):
    """Plot one figure per group of near-duplicate images.

    Each row is expanded to a group: its own id plus the ids in its
    ``near_duplicate_sets`` column. A group that shares any id with a group already
    shown is skipped, so each image appears at most once. Image paths are looked up
    in the module-level ``data`` DataFrame via ``data["id"]``.

    Args:
        near_duplicate_issues_df (DataFrame): Near-duplicate rows in display order,
            with a ``near_duplicate_sets`` column of ids.
        num_examples (int): Maximum number of groups (figures) to show.
    """
    max_groups = min(num_examples, len(near_duplicate_issues_df))
    shown_ids = set()
    groups_shown = 0

    for idx, row in near_duplicate_issues_df.iterrows():
        # Check before plotting so num_examples=0 shows nothing.
        if groups_shown >= max_groups:
            break

        group_ids = {int(i) for i in row.near_duplicate_sets}
        group_ids.add(int(idx))
        if group_ids & shown_ids:
            continue

        # squeeze=False: axes is always 2-D, even for a one-image group.
        fig, axes = plt.subplots(
            1, len(group_ids), figsize=(len(group_ids), 3), squeeze=False
        )
        # Sorted so the ids appear left-to-right in a deterministic order.
        for image_id, ax in zip(sorted(group_ids), axes[0]):
            img_path = data.loc[data["id"] == image_id, "image_path"].values[0]
            # imshow copies the pixel data, so the file can be closed right after drawing.
            with Image.open(img_path) as img:
                ax.imshow(img, cmap="gray")
            ax.set_title(f"id: {image_id}", fontsize=8)
            ax.axis("off")

        shown_ids.update(group_ids)
        groups_shown += 1

    plt.show()

In [None]:
# Show up to 10 distinct near-duplicate groups, one figure each.
plot_near_duplicate_issue_examples(near_duplicate_issues_df, 10)