In [None]:
%load_ext autoreload
%autoreload 2

from gorillatracker.datasets.cxl import CXLDataset
from gorillatracker.model import EfficientNetV2Wrapper
from gorillatracker.transform_utils import SquarePad
import pandas as pd
import wandb
import torch
import torchvision.transforms as transforms

# --- Fetch the trained checkpoint from Weights & Biases -----------------------
# Log in for API access; the run itself is initialised in "disabled" mode.
wandb.login()
wandb.init(mode="disabled")
api = wandb.Api()

artifact = api.artifact(
    "gorillas/Embedding-ALL-SPAC-Open/model-3ag1c2vf:v1",  # your artifact name
    type="model",
)
artifact_dir = artifact.download()
checkpoint_path = artifact_dir + "/model.ckpt"

# --- Rebuild the network and restore its weights (on CPU) ---------------------
checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
state_dict = checkpoint["state_dict"]

# Switch this wrapper (and its hyper-parameters) for the model you want to inspect.
model = EfficientNetV2Wrapper(
    model_name_or_path="EfficientNetV2_Large",
    from_scratch=False,
    loss_mode="softmax/arcface",
    weight_decay=0.001,
    lr_schedule="cosine",
    warmup_mode="cosine",
    warmup_epochs=10,
    max_epochs=100,
    initial_lr=0.01,
    start_lr=0.01,
    end_lr=0.0001,
    beta1=0.9,
    beta2=0.999,
    embedding_size=128,
)

# Models trained with arcface store their prototypes in the state dict, so the
# prototype parameters are replaced with the saved ones before loading.
model.loss_module_train.prototypes = torch.nn.Parameter(state_dict["loss_module_train.prototypes"])
model.loss_module_val.prototypes = torch.nn.Parameter(state_dict["loss_module_val.prototypes"])

model.load_state_dict(state_dict)
model.eval()

# Preprocessing applied to every image: the same transforms the model was
# trained with, minus any data augmentation.
eval_transform = transforms.Compose(
    [
        SquarePad(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# Validation partition of the open-set re-identification split.
dataset = CXLDataset(
    data_dir="/workspaces/gorillatracker/data/splits/ground_truth-cxl-face_images-openset-reid-val-0-test-0-mintraincount-3-seed-42-train-50-val-25-test-25",
    partition="val",
    transform=eval_transform,
)

# Empty table that will hold labels, images and embeddings.
df = pd.DataFrame(columns=["label", "image", "embedding"])

# Embed every image of `dataset` and collect one record per image.
#
# Changes compared to growing `df` with `pd.concat` inside the loop:
#   * rows are gathered in a list and turned into a DataFrame once (linear
#     instead of quadratic work),
#   * inference runs under `torch.no_grad()` so no autograd graph is built,
#   * the loop-invariant transforms are created once,
#   * the frame gets a clean 0..n-1 index and no longer carries the all-zero
#     "index" column that `reset_index(drop=False)` used to add.
to_pil_image = transforms.ToPILImage()

# If your model was trained with normalization, you need to normalize the
# images here as well.
# NOTE(review): the commonly used ImageNet std is [0.229, 0.224, 0.225]; the
# 0.228 below is kept unchanged - confirm it matches the training pipeline.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.228, 0.224, 0.225])

records = []
with torch.no_grad():
    for i in range(len(dataset)):
        image_tensor, label = dataset[i]
        # Keep the un-normalized picture for display; only the model input is normalized.
        image = to_pil_image(image_tensor)
        embedding = model(normalize(image_tensor).unsqueeze(0))
        records.append(
            {
                "label": label,
                "image": image,
                "embedding": embedding[0].detach().numpy(),
                "label_string": dataset.mapping[label],
            }
        )

        if i % 10 == 0:
            print(f"\rprocessed {i} images")

# Passing `columns` keeps the expected schema even when the dataset is empty.
df = pd.DataFrame(records, columns=["label", "image", "embedding", "label_string"])

In [None]:
# Display the collected table: each row holds the label, PIL image, embedding
# and label string of one image from the dataset.
df

In [None]:
from gorillatracker.metrics import knn
import numpy as np

def _as_float32(vector):
    """Return `vector` as a new float32 numpy array."""
    return np.array(vector, dtype=np.float32)


# Normalise dtypes before evaluation: integer labels, float32 embeddings.
df["label"] = df["label"].astype(int)
df["embedding"] = df["embedding"].apply(_as_float32)

# k-NN evaluation (gorillatracker.metrics.knn) with a single neighbour.
d = knn(df["embedding"], df["label"].to_numpy(), k=1)
d

In [None]:
# Same k-NN evaluation on all embeddings, now with five neighbours.
all_labels_np = df["label"].to_numpy()
knn(df["embedding"], all_labels_np, k=5)

In [None]:
# Bar chart of how many images each label has.
images_per_label = df["label"].value_counts()
images_per_label.plot(kind="bar")

In [None]:
# Keep only individuals that have at least 3 images.
label_counts = df["label"].value_counts()
min3labels = label_counts[label_counts >= 3].index
has_min3_images = df["label"].isin(min3labels)
min3df = df[has_min3_images].reset_index(drop=True)

# Bar chart of the remaining images per label.
min3df["label"].value_counts().plot(kind="bar")

In [None]:
# k-NN evaluation on the >= 3 images subset, one neighbour.
min3_embeddings_k1 = torch.tensor(min3df.embedding)
min3_labels_k1 = min3df.label.to_numpy()
knn(min3_embeddings_k1, min3_labels_k1, k=1)

In [None]:
# k-NN evaluation on the >= 3 images subset, five neighbours.
min3_embeddings_k5 = torch.tensor(min3df.embedding)
min3_labels_k5 = min3df.label.to_numpy()
knn(min3_embeddings_k5, min3_labels_k5, k=5)

# Real-world setting
In a real-world context, new individuals will keep arriving over time.

Options:
- on centroids (of known classes and of new individuals)
- with boolean filter (individuals seen at the same time)


Relevant Metrics:
- What is the average/min/max distance within images of an individual? 
- What is the average distance between centroids of individuals?
- What is the change in average distance between centroids of faces for increasing margins (0.5, 1.0, 1.5, 2, 4, 8)?


In [None]:
# One centroid (mean embedding) per individual.
# NOTE(review): this is computed on the full `df`, not on the >= 3 images
# subset `min3df` - confirm which of the two is intended.


def _mean_embedding(vectors):
    """Stack the embeddings of one individual and average them element-wise."""
    return np.mean(np.vstack(vectors), axis=0)


grouped = df.groupby(["label", "label_string"])["embedding"].apply(_mean_embedding)

# Turn the (label, label_string) index levels back into columns and keep the
# column order centroid, label, label_string.
centroid_df = grouped.reset_index(name="centroid")[["centroid", "label", "label_string"]]

assert centroid_df["label"].nunique() == centroid_df["label_string"].nunique(), "Label does not have a 1:1 mapping with label_string"
centroid_df

In [None]:
import numpy as np
from scipy.spatial.distance import cdist

# For every individual, measure how far its images lie from the class centroid
# and store the min / max / mean of those distances next to the centroid.
for label in centroid_df["label"]:
    is_label = centroid_df["label"] == label
    centroid = centroid_df.loc[is_label, "centroid"].iloc[0]
    embeddings = df.loc[df["label"] == label, "embedding"].tolist()

    # One distance per image of this individual (cdist defaults to Euclidean).
    distances = cdist(embeddings, [centroid])
    min_distance, max_distance, avg_distance = distances.min(), distances.max(), distances.mean()

    centroid_df.loc[is_label, "min_distance"] = min_distance
    centroid_df.loc[is_label, "max_distance"] = max_distance
    centroid_df.loc[is_label, "avg_distance"] = avg_distance

centroid_df


In [None]:
from scipy.spatial.distance import cdist

# Compute the pairwise distances between centroids (square, symmetric matrix
# whose diagonal holds the zero distance of each centroid to itself).
distances = cdist(centroid_df['centroid'].tolist(), centroid_df['centroid'].tolist())

# Compute the average distance between classes.
# Only the strict upper triangle is averaged: taking the mean of the whole
# matrix would include the zero self-distances on the diagonal and
# underestimate the result by a factor of (n - 1) / n.
pair_rows, pair_cols = np.triu_indices(len(distances), k=1)
if len(pair_rows) > 0:
    avg_distance = np.mean(distances[pair_rows, pair_cols])
else:
    # Fewer than two centroids: there is no pair to measure.
    avg_distance = np.nan

avg_distance


In [None]:
%autoreload 2

from gorillatracker.metrics import tsne

# Offset added to centroid labels so they get label values distinct from the
# image labels when passed to tsne.
centroid_marker = 1000000

# Centroid-only variant:
# p = tsne(torch.tensor(centroid_df.centroid.tolist()), torch.tensor(centroid_df.label.tolist()), perplexity=min(30, len(centroid_df)-1))

# Image embeddings first, followed by the class centroids.
tsne_points = df.embedding.tolist() + centroid_df.centroid.tolist()
tsne_labels = df.label.tolist() + [centroid_marker + c for c in centroid_df.label.tolist()]
p = tsne(torch.tensor(tsne_points), torch.tensor(tsne_labels))

In [None]:
from sklearn.cluster import KMeans
# Cluster ALL image embeddings into as many clusters as there are individuals.
# The matrix is built explicitly from `df` here: a bare `embeddings` variable
# would be whatever an earlier loop left behind (the images of the last
# individual only), which is not what should be clustered.
all_embeddings = np.vstack(df["embedding"].tolist())

k_means = KMeans(n_clusters=centroid_df.label.nunique(), random_state=42)
outputs = k_means.fit_predict(all_embeddings)
k_means.cluster_centers_, outputs

In [None]:
# Add all train embeddings
# Then add embeddings from validation and check how close they are to the centroids
# 
# Check how many images we have per individual.