In [None]:
from typing import Any, Dict, List, Literal

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from sklearn.metrics import accuracy_score, top_k_accuracy_score, roc_auc_score, f1_score, precision_score
from src.gorillatracker.utils.labelencoder import LinearSequenceEncoder

In [None]:
# TODO: pass the datamodule (`dm`) in and take the negative samples from it instead of a hand-built dict
# TODO: encode the labels and look up the negative labels in that encoding

In [None]:
def knn_ssl(
    embeddings: torch.Tensor,
    labels: torch.Tensor,
    average: Literal["micro", "macro", "weighted", "none"],
    negatives: Dict[int, List[int]],
    k: int = 1,
) -> Dict[str, Any]:
    """Evaluate embeddings with a kNN restricted to each label's negatives.

    For every label ``l`` present in ``labels``, a NearestNeighbors index is fit on the
    samples whose label is ``l`` or one of ``negatives[l]``. Each sample of label ``l``
    is then classified by the majority vote (``torch.mode``) of its ``k`` nearest
    neighbours in that subset, with the sample itself excluded. A top-5 hit is recorded
    when ``l`` appears among its 5 nearest neighbours.

    Args:
        embeddings: Embedding matrix with one row per sample. Presumably
            ``(n_samples, dim)``; verify against callers.
        labels: Integer label per sample, aligned row-wise with ``embeddings``.
        average: Averaging mode forwarded to ``f1_score`` and ``precision_score``.
            ``"none"`` is translated to sklearn's ``None``, so per-class arrays are
            returned.
        negatives: Maps a label to the labels it should be discriminated against.
            Labels without an entry, or with no other label listed, are skipped.
        k: Number of neighbours used for the majority vote. Must be >= 1.

    Returns:
        Dict with ``accuracy``, ``accuracy_top5``, ``f1`` and ``precision``. Every value
        is ``nan`` when no label could be evaluated.

    Raises:
        ValueError: If ``k < 1``.
    """
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # sklearn expects average=None (not the string "none") for per-class scores.
    sk_average = None if average == "none" else average

    # sklearn needs CPU numpy arrays; detach so GPU/autograd tensors are accepted too.
    embeddings = embeddings.detach().cpu()
    labels = labels.detach().cpu()

    true_labels: List[int] = []
    pred_labels: List[int] = []
    top5_hits: List[bool] = []

    for label in labels.unique().tolist():
        subset_labels = set(negatives.get(label, [])) | {label}
        if len(subset_labels) < 2:
            continue
        subset_mask = torch.isin(labels, torch.tensor(sorted(subset_labels), dtype=labels.dtype))
        subset_embeddings = embeddings[subset_mask]
        subset_label_values = labels[subset_mask]
        n_fit = len(subset_embeddings)
        if n_fit < 2:
            # Only the query point itself: there are no neighbours to vote with.
            continue

        # +1 because the query point is returned as its own neighbour. Clamp so
        # sklearn does not reject n_neighbors > n_samples on small subsets.
        n_neighbors = min(max(5, k) + 1, n_fit)
        knn = NearestNeighbors(n_neighbors=n_neighbors, algorithm="auto").fit(subset_embeddings.numpy())

        query_positions = torch.nonzero(subset_label_values == label).flatten().tolist()
        _, indices = knn.kneighbors(subset_embeddings[query_positions].numpy())

        for position, row in zip(query_positions, indices.tolist()):
            # Drop the query point itself. If an exact duplicate displaced it from the
            # result list, drop the farthest hit instead so the count stays the same.
            neighbor_idx = [i for i in row if i != position][: n_neighbors - 1]
            neighbor_labels = subset_label_values[neighbor_idx]
            true_labels.append(label)
            pred_labels.append(torch.mode(neighbor_labels[:k]).values.item())
            top5_hits.append(bool((neighbor_labels[:5] == label).any()))

    if not true_labels:
        nan = float("nan")
        return {'accuracy': nan, 'accuracy_top5': nan, 'f1': nan, 'precision': nan}

    accuracy = accuracy_score(true_labels, pred_labels)
    top5_accuracy = sum(top5_hits) / len(top5_hits)
    f1 = f1_score(true_labels, pred_labels, average=sk_average)
    precision = precision_score(true_labels, pred_labels, average=sk_average, zero_division=0)

    return {'accuracy': accuracy, 'accuracy_top5': top5_accuracy, 'f1': f1, 'precision': precision}

In [None]:
# Synthetic sanity check: every identity's points are drawn uniformly over the same
# grid, independent of the id. The embeddings therefore carry no identity information,
# and kNN accuracy on each (label + negatives) subset is expected to be near chance.
ids = range(1, 200)  # 199 synthetic identities: 1..199
num_points_per_id = 20
grid_size = (10, 10)  # extent of the uniform sampling area per embedding dimension
negative_prob = 0.01  # chance that any other id is listed as a negative of a given id
rng = np.random.default_rng(43)

# One row per sample; labels repeat each id num_points_per_id times, in id order.
embeddings = rng.uniform(low=0.0, high=grid_size, size=(len(ids) * num_points_per_id, 2))
labels = np.repeat(np.asarray(ids), num_points_per_id)
# Draw each id's negatives exactly once. They used to be redrawn for every point,
# with only the last draw kept.
negatives = {
    id_: [other for other in ids if other != id_ and rng.random() < negative_prob]
    for id_ in ids
}

embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32)
labels_tensor = torch.tensor(labels, dtype=torch.long)

# Expected `negatives` format, e.g. a hand-written toy mapping:
# negatives = {1: [2,3], 2: [1, 3], 3: [1, 2, 4], 4: [3, 5], 5: [4]}
knn_ssl(embeddings_tensor, labels_tensor, "macro", negatives, k=1)

In [None]:
# knn_ssl itself draws no random numbers, so repeated runs on the same inputs are
# expected to give identical metrics. Assert that, rather than printing the same
# dict 20 times.
repeat_results = [knn_ssl(embeddings_tensor, labels_tensor, "macro", negatives, k=1) for _ in range(20)]
assert all(result == repeat_results[0] for result in repeat_results), "knn_ssl results differ between runs"
repeat_results[0]

In [None]:
# Plot a handful of identities to eyeball how separable their embeddings are.
subset_labels = [1, 2, 5, 4, 3]

# Keep only the samples whose label is in the plotted subset.
mask = torch.isin(labels_tensor, torch.tensor(subset_labels))
subset_embeddings = embeddings_tensor[mask]
subset_labels_tensor = labels_tensor[mask]

fig, ax = plt.subplots(figsize=(10, 8))
scatter = ax.scatter(
    subset_embeddings[:, 0].numpy(),
    subset_embeddings[:, 1].numpy(),
    c=subset_labels_tensor.numpy(),
    cmap='viridis',
    alpha=0.7,
)

fig.colorbar(scatter, ax=ax, ticks=sorted(subset_labels), label='label')
ax.set_title('Embeddings Visualization')
ax.set_xlabel('Embedding Dimension 1')
ax.set_ylabel('Embedding Dimension 2')
ax.grid(True)
plt.show()