In [1]:
from typing import Tuple
import torch
from torch.nn import functional as F
from torchmetrics import Accuracy, Recall, RetrievalMRR

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def k_nearest_neighbor(
    prediction_features: torch.Tensor, 
    query_features: torch.Tensor = None, 
    labels: torch.Tensor = None, 
    k: int = 20, 
    temperature: float = 0.1
    ) -> Tuple:
    
    probabilities = []
    predictions = []
        
    num_classes = len(set(list(labels.numpy())))
    
    if query_features is None:
        # means that similarity is computed between prediction features and itself
        query_features = prediction_features
        zero_diagonal = True
        trim_preds = False
    else:
        zero_diagonal = False
        trim_preds = True
        
    num_chunks = 100
    num_test_samples = query_features.size()[0]
    samples_per_chunk = num_test_samples // num_chunks
        
    for idx in range(0, num_test_samples, samples_per_chunk):
        
        chunk_features = query_features[
            idx : min((idx + samples_per_chunk), num_test_samples), :
        ]
        chunk_labels = labels[
            idx : min((idx + samples_per_chunk), num_test_samples)
        ]
        
        batch_size = chunk_labels.shape[0]
        
        similarity = F.normalize(chunk_features) @ F.normalize(prediction_features).t() 
        torch.diagonal(similarity, 0).zero_() if zero_diagonal else None
        distances, indices = similarity.topk(k, largest=True, sorted=True)
        candidates = labels.view(1, -1).expand(batch_size, -1)
        retrieved_neighbors = torch.gather(candidates, 1, indices)
        
        retrieval_one_hot = torch.zeros(batch_size * k, num_classes)
        retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
        distances_transform = (distances / temperature).exp_()
        
        probs = torch.sum(
            torch.mul(
                retrieval_one_hot.view(batch_size, -1, num_classes),
                distances_transform.view(batch_size, -1, 1),
            ),
            1,
        )
        probs.div_(probs.sum(dim=1, keepdim=True))
        probs_sorted, preds = probs.sort(1, True)
        
        probabilities.append(probs)
        predictions.append(preds)
    
    probabilities = torch.cat(probabilities, dim=0)
    predictions = torch.cat(predictions, dim=0)
    
    return probabilities, predictions, labels


In [4]:
features_unimodal = torch.rand(40000, 64)

num_classes = 1000
labels_unimodal = torch.randint(0, num_classes, (features_unimodal.size()[0],))

probabilities_unimodal, _, labels_generated_unimodal, num_classes_unimodal = k_nearest_neighbor(prediction_features=features_unimodal, labels=labels_unimodal)

acc = Accuracy(num_classes=num_classes_unimodal, top_k=5)
acc(probabilities_unimodal, labels_unimodal)

tensor(0.9900)

In [None]:
features_multimodal = torch.rand(4000, 64)
queries_multimodal = torch.randn(4000, 64)

labels_multimodal = torch.tensor(list(range(features_multimodal.size()[0])))

probabilities_multimodal, _, labels_generated_multimodal, num_classes_multimodal = k_nearest_neighbor(prediction_features=features_multimodal, query_features=queries_multimodal, labels=labels_multimodal)

rec = Recall(num_classes=num_classes_multimodal, top_k=5)
rec.compute(probabilities_multimodal, labels_multimodal)

In [30]:
# size of tensor in GB
print(f"Size of predictions: {probabilities_multimodal.size()[0] * probabilities_multimodal.size()[1] * 4 / 1e9} GB")

Size of predictions: 6.4 GB


In [360]:
text_pred = torch.randn(6, 64)
img_pred = torch.randn(6, 64)
labels = torch.tensor(list(range(img_pred.size()[0])))

probs, predictions, labels = k_nearest_neighbor(prediction_features=img_pred, query_features=text_pred, labels=labels, k=3)

mrr = RetrievalMRR()
indexes = torch.tensor([[n]*len(labels) for n in range(len(labels))], dtype=torch.long).flatten()
preds = probs.flatten()
target = torch.eye(len(labels)).flatten()
print(mrr(preds, target, indexes=indexes))

tensor(0.0002)