In [None]:
import torch
from gorillatracker.miewid_model import MiewIdNet

bin_path = "/workspaces/gorillatracker/models/miew_id.ms_face.bin"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Keyword arguments forwarded verbatim to the MiewIdNet constructor.
model_params = {  # TODO
    "n_classes": 1977,  # some value
    "model_name": "timm/efficientnetv2_rw_m",
    "use_fc": False,
    "fc_dim": 2152,
    "dropout": 0.0,
    "loss_module": "softmax",
    "s": 49.32675426153405,
    "margin": 0.32,
    "ls_eps": 0.0,
    "theta_zero": 0.785,
    "pretrained": True,
    "margins": None,
    "k": None,
}

if bin_path:
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only point `bin_path` at checkpoints from a trusted source.
    weights = torch.load(bin_path, map_location=torch.device(device))

    # These head tensors are irrelevant for inference. Drop them if present
    # instead of failing with a KeyError on checkpoints that lack them.
    for head_key in ("final.wmetric_classify.weight", "final.warcface_margin.margins"):
        weights.pop(head_key, None)

    model = MiewIdNet(**model_params)
    model.to(device)
    # Swap the `final` module for a no-op before loading the weights
    # (presumably `final` is the training-time classification head -- confirm
    # against MiewIdNet).
    model.final = torch.nn.Identity()
    # strict=False tolerates key mismatches; the printed result lists any
    # missing/unexpected keys so they can be inspected.
    print(model.load_state_dict(weights, strict=False))
    model.eval()
    print("loaded checkpoint from ", bin_path)

In [None]:
# Smoke test: push one random 3x224x224 image through the model and show the
# shape of what comes back.
test = torch.randn(1, 3, 224, 224).to(device)
# Inference only, so skip autograd bookkeeping for the forward pass.
with torch.no_grad():
    output = model(test)
print(output.shape)

In [None]:
from gorillatracker.utils.embedding_generator import generate_embeddings
import torchvision.transforms as transforms

from gorillatracker.transform_utils import SquarePad
from gorillatracker.datasets.cxl import CXLDataset


# Directory of the CXL split whose "val" partition is embedded below.
_cxl_split_dir = "/workspaces/gorillatracker/data/splits/ground_truth-cxl-face_images-openset-reid-val-0-test-0-mintraincount-3-seed-42-train-50-val-25-test-25"

# Preprocessing applied to every image, in order: pad, resize, convert to tensor.
# No Normalize step is part of this pipeline.
_preprocessing_steps = [
    SquarePad(),
    # Uniform input, you may choose higher/lower sizes.
    transforms.Resize(440),
    transforms.ToTensor(),
]
model_transforms = transforms.Compose(_preprocessing_steps)

dataset = CXLDataset(data_dir=_cxl_split_dir, partition="val", transform=model_transforms)

# NOTE(review): embeddings are requested with device="cpu" although the model was
# moved to `device` in the loading cell -- confirm how generate_embeddings uses it.
# NOTE(review): presumably norm_input=True makes generate_embeddings normalise the
# inputs itself -- verify against its implementation.
df = generate_embeddings(model, dataset, device="cpu", norm_input=True)

In [None]:
def create_latex_table(metrics, formatted_names):
    """Render a LaTeX comparison table with one row per metric.

    Args:
        metrics: Mapping from ``(model, dataset)`` tuples, e.g.
            ``("ViT-F", "SPAC")``, to a dict of metric key -> value.
        formatted_names: Mapping from metric key to the label printed in the
            first column. Rows appear in this mapping's iteration order.

    Returns:
        The complete ``table`` environment as a string. Float values are
        printed with two decimals, other values are printed as-is, and a
        missing model/dataset/metric combination is printed as ``-``.
    """
    table_head = r"""
\begin{table}
    \centering
    \resizebox{0.8\paperwidth}{!}{
    \begin{tabular}{l *{8}{c}}
    \toprule
    & \multicolumn{2}{c}{Bristol} & \multicolumn{2}{c}{SPAC} & \multicolumn{2}{c}{Bristol} & \multicolumn{2}{c}{SPAC} \\
    \cmidrule(lr){2-3} \cmidrule(lr){4-5} \cmidrule(lr){6-7} \cmidrule(lr){8-9}
    Metric & ViT-F & ViT-P & ViT-F & ViT-P & EfN-F & EfN-P & EfN-F & EfN-P \\
    \midrule
"""
    table_foot = r"""    \bottomrule
    \end{tabular}
    }
    \caption{Metrics comparison across models and datasets.}
    \label{tab:metrics-comparison}
\end{table}
"""

    # Lookup keys in the same left-to-right order as the header columns:
    # architecture first, then dataset, then finetuned/pretrained variant.
    column_keys = [
        (f"{architecture}{variant}", dataset_name)
        for architecture in ("ViT-", "EfN-")
        for dataset_name in ("Bristol", "SPAC")
        for variant in ("F", "P")
    ]

    def render_cell(value):
        # Only genuine floats get the two-decimal treatment.
        if isinstance(value, float):
            return f"& {value:.2f} "
        return f"& {value} "

    body_rows = []
    for metric_key, display_name in formatted_names.items():
        cells = "".join(render_cell(metrics.get(key, {}).get(metric_key, "-")) for key in column_keys)
        body_rows.append(f"{display_name} {cells}" + r"\\" + "\n")

    return table_head + "".join(body_rows) + table_foot


df = merged_df

# Short table label -> value stored in the "model" column of `df`.
# "EfN-F" ("EfN-Finetuned") is not evaluated here, so create_latex_table
# prints "-" in its columns.
_model_column_values = {
    "ViT-F": "ViT-Finetuned",
    "ViT-P": "ViT-Pretrained",
    "EfN-P": "EfN-Pretrained",
}

# One analysis per (model, dataset) pair, each on the matching slice of `df`.
actual_metrics = {
    (short_name, dataset_name): analyse_embedding_space(
        df[(df["model"] == full_name) & (df["dataset"] == dataset_name)]
    )
    for dataset_name in ("SPAC", "Bristol")
    for short_name, full_name in _model_column_values.items()
}

print(create_latex_table(actual_metrics, formatted_names))

In [None]:
# plot embeddings
def _as_base64_jpeg(picture):
    """Serialise an image to JPEG in memory and return it base64-encoded as text."""
    jpeg_buffer = BytesIO()
    picture.save(jpeg_buffer, format="JPEG")
    return base64.b64encode(jpeg_buffer.getvalue()).decode("utf-8")


# Stack the per-row embeddings into a single array.
embeddings = np.stack(df["embedding"].to_numpy())

# One base64 JPEG string per row of `df`, in row order.
images = [_as_base64_jpeg(picture) for picture in df["input"]]

ep = EmbeddingProjector()
low_dim_embeddings = ep.reduce_dimensions(embeddings, method="tsne")
ep.plot_clusters(
    low_dim_embeddings,
    df["label"],
    df["label_string"],
    images,
    title="Embedding Projector",
    figsize=(12, 10),
)

In [None]:
from torchmetrics.functional import pairwise_euclidean_distance
import pandas as pd


def get_closest_indices(embeddings: torch.Tensor, k: int) -> list[list[int]]:
    """Find the ``k`` nearest other embeddings for every embedding.

    Args:
        embeddings: Tensor holding one embedding per row.
        k: Number of neighbours to return per embedding.

    Returns:
        A list with one entry per embedding. Entry ``i`` lists the row indices
        of the closest embeddings to row ``i``, ordered by increasing Euclidean
        distance. A row's own distance is set to infinity, so it sorts last and
        is only returned when ``k`` is at least the number of embeddings.
    """
    distance_matrix = pairwise_euclidean_distance(embeddings)
    # An embedding must not be reported as its own nearest neighbour.
    distance_matrix.fill_diagonal_(float("inf"))
    # Sort every row in one batched call and keep the first k columns.
    return torch.argsort(distance_matrix, dim=1)[:, :k].tolist()


def get_missclassified_images(embeddings_table: pd.DataFrame, k: int) -> list[tuple[int, ...]]:
    """Classify every row by majority vote of its ``k`` nearest neighbours.

    Prints the resulting accuracy and returns the rows whose voted label
    differs from their own label.

    Args:
        embeddings_table: Frame with a ``"label"`` column and an
            ``"embedding"`` column holding one tensor per row.
        k: Number of nearest neighbours that vote on each row's label.

    Returns:
        One tuple per misclassified row: the row's position in
        ``embeddings_table`` followed by the positions of its nearest
        neighbours, closest first.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """
    if k < 1:
        raise ValueError("k must be a positive integer")

    # Neighbour indices are positions, so read labels by position too.
    # Label-based `Series[...]` lookup breaks on a filtered or re-indexed frame.
    labels = embeddings_table["label"].tolist()
    if not labels:
        print("Accuracy: n/a (no embeddings)")
        return []

    # torch.stack already returns a tensor; no further conversion is needed.
    embeddings = torch.stack(embeddings_table["embedding"].tolist())
    closest_indices = get_closest_indices(embeddings, k)

    misclassified_images = []
    for position, (true_label, neighbour_positions) in enumerate(zip(labels, closest_indices)):
        # Iterate over the neighbours actually returned rather than range(k).
        nearest_labels = [labels[neighbour] for neighbour in neighbour_positions]
        # Majority vote; on a tie the label that appears first (closest) wins.
        predicted_label = max(nearest_labels, key=nearest_labels.count)
        if true_label != predicted_label:
            misclassified_images.append((position, *neighbour_positions))

    print(f"Accuracy: {1 - len(misclassified_images) / len(labels)}")

    return misclassified_images

In [None]:
# Print the 1-nearest-neighbour accuracy for the embeddings currently bound to `df`.
# The trailing semicolon keeps the notebook from echoing the returned list.
get_missclassified_images(df, 1);