In [1]:
import torch
from gorillatracker.miewid_model import MiewIdNet

bin_path = "/workspaces/gorillatracker/models/miew_id.ms_face.bin"
device = "cuda" if torch.cuda.is_available() else "cpu"

# MiewIdNet constructor arguments.
# TODO: confirm these hyper-parameters match the ones the checkpoint was trained with.
model_params = {
    "n_classes": 1977,  # NOTE(review): placeholder value ("some value") — confirm
    "model_name": "timm/efficientnetv2_rw_m",
    "use_fc": False,
    "fc_dim": 2152,
    "dropout": 0.0,
    "loss_module": "softmax",
    "s": 49.32675426153405,
    "margin": 0.32,
    "ls_eps": 0.0,
    "theta_zero": 0.785,
    "pretrained": True,
    "margins": None,
    "k": None,
}

# Build the model unconditionally so later cells always have a `model`,
# even when no checkpoint path is configured.
model = MiewIdNet(**model_params)
model.to(device)
# Replace the classification head with a pass-through: inference only needs the embedding.
model.final = torch.nn.Identity()

if bin_path:
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    weights = torch.load(bin_path, map_location=torch.device(device))
    # These head parameters have no counterpart once `final` is an Identity.
    # pop(..., None) tolerates checkpoints that do not contain them.
    weights.pop("final.wmetric_classify.weight", None)
    weights.pop("final.warcface_margin.margins", None)
    print(model.load_state_dict(weights, strict=False))
    print("loaded checkpoint from ", bin_path)

model.eval()

Building Model Backbone for timm/efficientnetv2_rw_m model


model.safetensors:   0%|          | 0.00/214M [00:00<?, ?B/s]

<All keys matched successfully>
loaded checkpoint from  /workspaces/gorillatracker/models/miew_id.ms_face.bin


In [2]:
# Smoke test: push one random image-shaped batch through the model and print the output shape.
# no_grad: this is inference only, so no autograd graph needs to be recorded.
test = torch.randn(1, 3, 224, 224, device=device)
with torch.no_grad():
    output = model(test)
print(output.shape)

torch.Size([1, 2152])


In [3]:
from gorillatracker.utils.embedding_generator import generate_embeddings
import torchvision.transforms as transforms

from gorillatracker.transform_utils import SquarePad
from gorillatracker.datasets.cxl import CXLDataset


# Per-image preprocessing: SquarePad, resize to 440 px, then convert to a tensor.
# There is no Normalize step here.
# NOTE(review): normalisation presumably happens inside generate_embeddings via
# norm_input=True — confirm.
model_transforms = transforms.Compose([SquarePad(), transforms.Resize(440), transforms.ToTensor()])

dataset = CXLDataset(
    transform=model_transforms,
    partition="val",
    data_dir="/workspaces/gorillatracker/data/splits/ground_truth-cxl-face_images-openset-reid-val-0-test-0-mintraincount-3-seed-42-train-50-val-25-test-25",
)

# NOTE(review): the model was moved to `device` in the loading cell, but "cpu" is passed
# here — confirm generate_embeddings reconciles the two.
df = generate_embeddings(model, dataset, norm_input=True, device="cpu")

ModuleNotFoundError: No module named 'gorillatracker.utils.embedding_generator'

In [28]:
from sklearn.manifold import Isomap, LocallyLinearEmbedding, MDS, SpectralEmbedding, TSNE
from sklearn.decomposition import PCA
import umap.umap_ as umap
import numpy as np
from io import BytesIO
import base64
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.io import output_notebook
from bokeh.resources import INLINE
import colorcet as cc
from gorillatracker.datasets.cxl import CXLDataset
from gorillatracker.transform_utils import SquarePad
import torch
import torchvision.transforms as transforms

# Route Bokeh output into the notebook and embed BokehJS inline (no CDN needed).
output_notebook(resources=INLINE)


class EmbeddingProjector:
    """Reduce embeddings to 2-D and show them as an interactive Bokeh scatter plot."""

    # Conversion used to turn a matplotlib-style ``figsize`` (inches) into Bokeh pixel sizes.
    PIXELS_PER_INCH = 72

    def __init__(self):
        # 2-D reducers selectable by name in ``reduce_dimensions``.
        self.algorithms = {
            "tsne": TSNE(n_components=2),
            "isomap": Isomap(n_components=2),
            "lle": LocallyLinearEmbedding(n_components=2),
            "mds": MDS(n_components=2),
            "spectral": SpectralEmbedding(n_components=2),
            "pca": PCA(n_components=2),
            "umap": umap.UMAP(),
        }

    def reduce_dimensions(self, embeddings, method="tsne"):
        """Project ``embeddings`` to two dimensions.

        Args:
            embeddings: array-like with more than two rows (presumably
                ``(n_samples, n_features)``).
            method: key into ``self.algorithms``. Unknown keys fall back to t-SNE.

        Returns:
            The ``fit_transform`` result of the chosen reducer.

        Inputs with 30 rows or fewer (e.g. under --fast_dev_run) always use t-SNE with
        ``perplexity=1``, regardless of ``method``. t-SNE's perplexity must be smaller
        than the number of samples, and its default is 30.
        """
        assert len(embeddings) > 2
        if len(embeddings) <= 30:
            algorithm = TSNE(n_components=2, perplexity=1)
        else:
            algorithm = self.algorithms.get(method)
            if algorithm is None:
                # Only build the fallback when it is actually needed.
                algorithm = TSNE(n_components=2)
        return algorithm.fit_transform(embeddings)

    def plot_clusters(
        self, low_dim_embeddings, labels, og_labels, images, title="Embedding Projector", figsize=(12, 10)
    ):
        """Show a scatter plot of 2-D embeddings with image hover tooltips.

        Args:
            low_dim_embeddings: 2-D array; columns 0 and 1 are used as x and y.
            labels: integer class ids, used to pick point colours.
            og_labels: class names shown in the legend.
            images: base64-encoded JPEG strings, shown on hover.
            title: plot title.
            figsize: ``(width, height)`` in inches, converted with ``PIXELS_PER_INCH``.
        """
        palette = cc.glasbey
        # Keep the original stride of 2 through the palette, but wrap around so that
        # large label ids do not index past the end of the palette.
        colors = [palette[(label * 2) % len(palette)] for label in labels]
        source = ColumnDataSource(
            data={
                "x": low_dim_embeddings[:, 0],
                "y": low_dim_embeddings[:, 1],
                "color": colors,
                "class": og_labels,
                "image": images,
            }
        )

        width_in, height_in = figsize
        fig = figure(
            title=title,
            width=int(width_in * self.PIXELS_PER_INCH),
            height=int(height_in * self.PIXELS_PER_INCH),
            tools="pan, wheel_zoom, box_zoom, reset",
        )
        fig.scatter(
            x="x",
            y="y",
            size=12,
            fill_color="color",
            line_color="black",
            source=source,
            legend_field="class",
        )
        fig.add_tools(HoverTool(tooltips='<img src="data:image/jpeg;base64,@image" width="128" height="128">'))

        # Use the same INLINE resources as the module-level setup. A bare output_notebook()
        # would fall back to Bokeh's default (CDN) resources.
        output_notebook(INLINE)
        show(fig)

In [29]:
# plot embeddings


def _to_base64_jpeg(image) -> str:
    """Encode an image as a base64 JPEG string for the Bokeh hover tooltip.

    Assumes ``image`` is a PIL image (it must support ``.convert`` and ``.save``) — TODO confirm.
    """
    buffer = BytesIO()
    # JPEG cannot store alpha or palette modes; converting to RGB lets such images encode.
    image.convert("RGB").save(buffer, format="JPEG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")


embeddings = np.stack(df["embedding"].to_numpy())
images = [_to_base64_jpeg(image) for image in df["input"]]

ep = EmbeddingProjector()
low_dim_embeddings = ep.reduce_dimensions(embeddings, method="tsne")
ep.plot_clusters(
    low_dim_embeddings, df["label"], df["label_string"], images, title="Embedding Projector", figsize=(12, 10)
)

In [30]:
from torchmetrics.functional import pairwise_euclidean_distance
import pandas as pd


def get_closest_indices(embeddings: torch.Tensor, k: int) -> "list[list[int]]":
    """Return the indices of the ``k`` nearest other embeddings for every embedding.

    Args:
        embeddings: tensor with one embedding per row.
        k: number of neighbours to return per row.

    Returns:
        A list with one entry per row. Entry ``i`` holds neighbour indices ordered by
        increasing Euclidean distance.

    Each embedding is kept out of its own neighbour list by setting the diagonal of the
    distance matrix to +inf. It can only appear if ``k`` exceeds ``len(embeddings) - 1``.
    """
    distance_matrix = pairwise_euclidean_distance(embeddings)
    distance_matrix.fill_diagonal_(float("inf"))
    # Sort all rows in one batched call instead of a Python loop over rows.
    return torch.argsort(distance_matrix, dim=1)[:, :k].tolist()


def get_missclassified_images(embeddings_table: pd.DataFrame, k: int) -> "list[tuple[int, ...]]":
    """Classify each row by a k-nearest-neighbour majority vote and collect the misses.

    Args:
        embeddings_table: frame with an ``"embedding"`` column (tensor or array-like per
            row) and a ``"label"`` column.
        k: number of neighbours taking part in the vote.

    Returns:
        One tuple ``(i, *neighbour_indices)`` per misclassified row. All indices are
        0-based row positions in ``embeddings_table``.

    Also prints the resulting accuracy.
    """
    # Positional list, so lookups work regardless of the frame's index labels.
    labels = embeddings_table["label"].to_list()
    if not labels:
        print("Accuracy: n/a (no embeddings)")
        return []

    # as_tensor accepts both tensors and array-likes. The stacked result is already a
    # tensor, so it is passed on directly rather than re-wrapped with torch.tensor().
    embeddings = torch.stack([torch.as_tensor(e) for e in embeddings_table["embedding"]])
    closest_indices = get_closest_indices(embeddings, k)

    misclassified_images = []
    for i, (true_label, neighbours) in enumerate(zip(labels, closest_indices)):
        nearest_labels = [labels[j] for j in neighbours]
        # Majority vote. On ties, max() keeps the tied label that occurs first, i.e. nearest.
        predicted_label = max(nearest_labels, key=nearest_labels.count)
        if true_label != predicted_label:
            misclassified_images.append((i, *neighbours))

    print(f"Accuracy: {1 - len(misclassified_images) / len(labels)}")

    return misclassified_images

In [36]:
# k=1: each image is classified by its single nearest neighbour. Keep the result for inspection.
misclassified_images = get_missclassified_images(df, k=1)

Accuracy: 0.6031746031746033


  closest_indices = get_closest_indices(torch.tensor(embeddings), k)
