In [133]:
from typing import Tuple
import torch
from torch.nn import functional as F
from torchmetrics import Accuracy, Recall, RetrievalMRR

In [351]:
def k_nearest_neighbor(
    prediction_features: torch.Tensor, 
    query_features: torch.Tensor = None, 
    labels: torch.Tensor = None, 
    k: int = 20, 
    temperature: float = 0.1
    ) -> Tuple(torch.Tensor, torch.Tensor, torch.Tensor):
    
    probabilities = []
    predictions = []
        
    num_classes = len(set(list(labels.numpy())))
    
    if query_features is None:
        # means that similarity is computed between prediction features and itself
        query_features = prediction_features
        zero_diagonal = True
        trim_preds = False
    else:
        zero_diagonal = False
        trim_preds = True
        
    num_chunks = 100
    num_test_samples = query_features.size()[0]
    samples_per_chunk = num_test_samples // num_chunks
        
    for idx in range(0, num_test_samples, samples_per_chunk):
        
        chunk_features = query_features[
            idx : min((idx + samples_per_chunk), num_test_samples), :
        ]
        chunk_labels = labels[
            idx : min((idx + samples_per_chunk), num_test_samples)
        ]
        
        batch_size = chunk_labels.shape[0]
        
        similarity = F.normalize(chunk_features) @ F.normalize(prediction_features).t() 
        torch.diagonal(similarity, 0).zero_() if zero_diagonal else None
        distances, indices = similarity.topk(k, largest=True, sorted=True)
        candidates = labels.view(1, -1).expand(batch_size, -1)
        retrieved_neighbors = torch.gather(candidates, 1, indices)
        
        retrieval_one_hot = torch.zeros(batch_size * k, num_classes)
        retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
        distances_transform = (distances / temperature).exp_()
        
        probs = torch.sum(
            torch.mul(
                retrieval_one_hot.view(batch_size, -1, num_classes),
                distances_transform.view(batch_size, -1, 1),
            ),
            1,
        )
        probs.div_(probs.sum(dim=1, keepdim=True))
        probs_sorted, preds = probs.sort(1, True)
        
        probabilities.append(probs)
        predictions.append(preds)
    
    probabilities = torch.cat(probabilities, dim=0)
    predictions = torch.cat(predictions, dim=0)
    
    return probabilities, predictions, labels, num_classes


In [352]:
queries = torch.rand(40000, 64)
preds = torch.randn(40000, 64)

num_classes = 10
labels = torch.randint(0, num_classes, (preds.size()[0],))
probabilities, predictions, labels_generated = k_nearest_neighbor(prediction_features=preds, query_features=queries)

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0.])
tensor([0.1659, 0.0651, 0.0567, 0.0547, 0.0532, 0.0505, 0.0497, 0.0491, 0.0475,
        0.0420, 0.0411, 0.0387, 0.0382, 0.0369, 0.0368, 0.0368, 0.0347, 0.0345,
        0.0344, 0.0335, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])
tensor([24901, 12787,  3562, 37723, 17172, 34831,  7749, 11388, 34461, 18395,
        10454, 10677,  3214, 18817, 14909, 21315,  4674, 22027, 28746,  8902,
            7,     8,     9,    10,    11])


In [335]:
acc = Accuracy(num_classes=num_classes, top_k=5)
acc(predictions, labels)

ValueError: If `preds` have one dimension more than `target`, `preds` should be a float tensor.

In [357]:
predictions_generated = F.softmax(torch.rand(40000, 40000), dim=1)
labels_generated = torch.tensor(list(range(40000)))

In [358]:
# size of tensor in GB
print(f"Size of predictions: {predictions_generated.size()[0] * predictions_generated.size()[1] * 4 / 1e9} GB")

Size of predictions: 6.4 GB


In [360]:
rec = Recall(num_classes=len(labels_generated), top_k=5)
rec(predictions_generated, labels_generated)

tensor(0.0002)