Imports and EmbeddingProjector

In [None]:
import numpy as np
from io import BytesIO
import base64
from bokeh.io import output_notebook
from bokeh.resources import INLINE
from gorillatracker.scripts.visualize_embeddings import EmbeddingProjector
import pandas as pd
import torch

# Render bokeh plots inside the notebook, using bokeh's INLINE resources bundle.
output_notebook(resources=INLINE)

Generate embeddings from a W&B run (if embedding.pkl already exists, this cell can be skipped)

In [None]:
from gorillatracker.utils.embedding_generator import generate_embeddings_from_run
from gorillatracker.utils.embedding_generator import read_embeddings_from_disk
from PIL import Image  # not used in this cell; the misclassification cell below relies on it

# Run whose model is used to embed the dataset.
# Alternative run previously referenced here: https://wandb.ai/gorillas/Embedding-SwinV2-CXL-Open/runs/7wg98d3l/workspace
RUN_URL = "https://wandb.ai/gorillas/Embedding-SwinV2Large-CXL-Open/runs/4nlubzcy/workspace"
EMBEDDINGS_PATH = "embedding.pkl"
DATASET_CLASS = "gorillatracker.datasets.cxl.CXLDataset"
DATA_DIR = (
    "/workspaces/gorillatracker/data/splits/"
    "ground_truth-cxl-face_images-openset-reid-val-0-test-0-mintraincount-3-seed-42-train-50-val-25-test-25"
)
regenerate = True  # set to False to reuse the embeddings already stored at EMBEDDINGS_PATH

# Embedding generation is expensive, so it runs at most once per execution of this cell.
if regenerate:
    df = generate_embeddings_from_run(RUN_URL, EMBEDDINGS_PATH, DATASET_CLASS, DATA_DIR)
else:
    df = read_embeddings_from_disk(EMBEDDINGS_PATH)
df.head()

In [None]:
from gorillatracker.utils.embedding_generator import read_embeddings_from_disk
from bokeh.plotting import save, output_file, show
from PIL import Image  # imported here too: the generation cell above may be skipped

df = read_embeddings_from_disk("embedding.pkl")

# np.stack adds a leading axis with one entry per DataFrame row.
embeddings = np.stack(df["embedding"].to_numpy())


def encode_image_as_base64_jpeg(path) -> str:
    """Read the image at ``path`` and return it re-encoded as a base64 JPEG string."""
    with Image.open(path) as image:  # `with` releases the file handle as soon as we are done
        # JPEG cannot store modes such as RGBA or P (e.g. PNGs with transparency); convert those first.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")
        buffer = BytesIO()
        image.save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


# One encoded image per row, in the same order as the embeddings / labels passed to plot_clusters.
images = [encode_image_as_base64_jpeg(path) for path in df["id"]]

ep = EmbeddingProjector()
low_dim_embeddings = ep.reduce_dimensions(embeddings, method="tsne")
fig = ep.plot_clusters(
    low_dim_embeddings, df["label"], df["label_string"], images, title="Embedding Projector", figsize=(12, 10)
)
output_file(filename="embedding.html")
show(fig)

Local viewing of misclassified images (uses the df from whichever of the cells above was executed last)

In [None]:
from torchmetrics.functional import pairwise_euclidean_distance
import matplotlib.pyplot as plt
import ipywidgets as widgets


def get_closest_indices(embeddings: torch.Tensor, k: int) -> "tuple[list[list[int]], list[list[torch.Tensor]]]":
    """Find the ``k`` nearest neighbours (Euclidean distance) of every embedding.

    Args:
        embeddings: 2-D tensor with one embedding per row. Non-float tensors are converted to float.
        k: Number of neighbours to return per row. Clamped to ``len(embeddings) - 1`` so that a row
            is never reported as its own neighbour; values below 1 yield empty neighbour lists.

    Returns:
        ``(closest_indices, distances)``. ``closest_indices[i]`` lists the row indices of the
        neighbours of row ``i``, nearest first. ``distances[i]`` holds the matching distances as
        0-d tensors, so callers can use ``.item()`` on them.
    """
    n = len(embeddings)
    k = max(0, min(k, n - 1))
    if k == 0:
        return [[] for _ in range(n)], [[] for _ in range(n)]

    if not embeddings.is_floating_point():
        embeddings = embeddings.float()  # cdist only supports floating point inputs
    distance_matrix = torch.cdist(embeddings, embeddings)
    # A row's distance to itself is 0 and would always be selected; exclude self-matches.
    distance_matrix.fill_diagonal_(float("inf"))

    # topk with largest=False gives the k smallest distances per row, sorted ascending.
    values, indices = torch.topk(distance_matrix, k, dim=1, largest=False, sorted=True)
    closest_indices = indices.tolist()
    distances = [list(row) for row in values]  # iterating a 1-D tensor yields 0-d tensors
    return closest_indices, distances


def get_missclassified_images(embeddings_table: pd.DataFrame, k: int) -> "tuple[list[tuple], list[tuple]]":
    """Run leave-one-out k-NN classification over the embeddings and collect the misclassified rows.

    Each row gets the majority label among its ``k`` nearest neighbours (excluding itself). On a
    tie, the label that appears first in nearest-first order wins. The resulting accuracy is printed.

    Args:
        embeddings_table: Frame with an ``"embedding"`` column (one vector per row; tensors,
            numpy arrays or lists) and a ``"label"`` column.
        k: Number of neighbours that vote.

    Returns:
        ``(misclassified_images, misclassified_distances)``. Each entry of the first list is
        ``(i, n_1, ..., n_k)``: the positional index of a misclassified row followed by the
        positional indices of its neighbours. The matching entry of the second list is
        ``(i, d_1, ..., d_k)`` with the neighbour distances.

    Raises:
        ValueError: If a row has no neighbours to vote (``k < 1`` or a single-row table).
    """
    if len(embeddings_table) == 0:
        print("No embeddings to evaluate.")
        return [], []

    # Plain list => positional indexing, matching the row positions of the embedding tensor even
    # when the frame's index is not 0..n-1 (e.g. after filtering).
    labels = embeddings_table["label"].tolist()
    # as_tensor accepts tensors (no copy) as well as numpy arrays / lists.
    embeddings = torch.stack([torch.as_tensor(e) for e in embeddings_table["embedding"]])
    closest_indices, distances = get_closest_indices(embeddings, k)

    misclassified_images = []
    misclassified_distances = []
    for i, true_label in enumerate(labels):
        nearest_labels = [labels[j] for j in closest_indices[i][:k]]
        if not nearest_labels:
            raise ValueError(f"row {i} has no neighbours to vote; k must be >= 1 and the table needs >= 2 rows")
        # max() keeps the first maximal element, so ties resolve towards the nearer neighbour.
        predicted_label = max(nearest_labels, key=nearest_labels.count)
        if true_label != predicted_label:
            misclassified_images.append((i, *closest_indices[i]))
            misclassified_distances.append((i, *distances[i]))

    print(f"Accuracy: {1 - len(misclassified_images) / len(labels)}")

    return misclassified_images, misclassified_distances


from PIL import Image

k = 3  # number of closest images to display
compare_amount = 4  # amount of images to compare for each missclassified image
missclassified, distances = get_missclassified_images(df, k)
images_per_page = k + 1

# Reset the index so labels[i] is positional, matching the row positions in the k-NN results.
labels = df["label_string"].reset_index(drop=True)


def load_image(path):
    """Return a fully loaded in-memory copy of the image at ``path`` and close the file immediately.

    Keeping raw ``Image.open`` results around would hold one open file handle per dataset image.
    """
    with Image.open(path) as image:
        return image.copy()


images = [load_image(path) for path in df["id"]]

# The first row (misclassified image + neighbours) is taller than each comparison row.
height_ratios = [0.4] + [0.3] * compare_amount

scale_factor = 2.5  # Scale factor for the height of the subplots

# Function to display a page of images
def display_images(page):
    """Plot one misclassified image, its nearest neighbours, and reference images for each label.

    Row 0 shows the misclassified image followed by its nearest neighbours (with distances).
    Each column below shows up to ``compare_amount`` other images that carry that column's label.

    Args:
        page: Index into ``missclassified`` selecting the misclassified image to show.

    Uses the notebook globals ``missclassified``, ``distances``, ``images``, ``labels``, ``df``,
    ``images_per_page``, ``compare_amount``, ``height_ratios`` and ``scale_factor``.
    """
    if not 0 <= page < len(missclassified):
        print(f"No misclassified image for page {page}.")
        return

    row = missclassified[page]  # (misclassified index, neighbour indices...) -- all positional
    fig, axs = plt.subplots(
        1 + compare_amount,
        images_per_page,
        figsize=(15, (compare_amount + 1) * scale_factor),
        height_ratios=height_ratios,
        squeeze=False,  # keep axs 2-D even when compare_amount == 0 or images_per_page == 1
    )
    for ax in axs.flat:
        ax.axis("off")  # unused cells stay blank

    ids = df["id"]
    for col, ind in enumerate(row[:images_per_page]):
        # .iloc: indices are positions, the frame's index may not be 0..n-1.
        label = labels.iloc[ind]
        image_id = ids.iloc[ind]

        axs[0, col].imshow(images[ind])
        if col == 0:
            axs[0, col].set_title(f"missclassified image ({label})")
        else:
            dist = round(distances[page][col].item(), 3)
            axs[0, col].set_title(f"{col}. closest image ({label}) \n dist: {dist}")

        # Other images of the same label, for visual comparison.
        same_label = df[(df["label_string"] == label) & (df["id"] != image_id)].head(compare_amount)
        for plot_row, comp_path in enumerate(same_label["id"], start=1):
            with Image.open(comp_path) as comp_image:  # imshow copies the pixels, so closing is safe
                axs[plot_row, col].imshow(comp_image)
            axs[plot_row, col].set_title(label)

    plt.tight_layout()
    plt.show()


# An IntSlider with max < min raises, so only build the pager when there is something to page through.
if missclassified:
    page_selector = widgets.IntSlider(min=0, max=(len(missclassified) - 1), description="Page:")
    widgets.interact(display_images, page=page_selector)
else:
    print("No misclassified images - nothing to display.")