In [64]:
import torch
import wandb
from PIL import Image
from torchmetrics import RetrievalMRR
from transformers import BertTokenizerFast, PerceiverFeatureExtractor

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
feature_extractor = PerceiverFeatureExtractor()

In [12]:
ids = torch.randint(low=100, high=30000, size=(5, 10))
text = tokenizer.batch_decode(ids, skip_special_tokens=True)
text = wandb.Table(data=[[sent] for sent in text], columns=['text'])
print(text.get_column('text'))

['##lean consisting robert eddie germanic unsuccessful august closeness mariano karachi', 'simulation financially pas potentially sinister pleasure brewers [unused903] literallyudge', 'oaxaca ʔ clapping armoured hummed 227 mckenziepid remains backwards', '##aged good commander knowing phased acquire cargo organisation [unused773] lexi', 'guerre qualified rated sentences barnard explainssneriens hendrix spicy']


In [99]:
from typing import Tuple
import torch
from torch.nn import functional as F


def knn_core(prediction_features, query_features, labels, k, temperature, zero_diagonal, num_classes, batch_size):
    similarity = F.normalize(query_features) @ F.normalize(prediction_features).t()
    similarity_ground_truth = torch.diag(similarity)

    torch.diagonal(similarity, 0).zero_() if zero_diagonal else None
    distances, indices = similarity.topk(k, largest=True, sorted=True)
    candidates = labels.view(1, -1).expand(batch_size, -1)
    retrieved_neighbors = torch.gather(candidates, 1, indices)
    
    retrieval_one_hot = torch.zeros(batch_size * k, num_classes)
    retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
    distances_transform = (distances / temperature).exp_()
    
    probs = torch.sum(
        torch.mul(
            retrieval_one_hot.view(batch_size, -1, num_classes),
            distances_transform.view(batch_size, -1, 1),
        ),
        1,
    )
    probs.div_(probs.sum(dim=1, keepdim=True))
    probs_sorted, predictions = probs.sort(1, True)
    
    return similarity, similarity_ground_truth, distances, indices, probs, probs_sorted, predictions
    

def k_nearest_neighbor(
    prediction_features: torch.Tensor, 
    query_features: torch.Tensor = None, 
    labels: torch.Tensor = None, 
    k: int = 20, 
    chunking: bool = True
    ) -> Tuple:
    
    probabilities = []
    predictions = []
        
    temperature = 0.1
    num_classes = len(set(list(labels.numpy())))
    
    if query_features is None:
        # means that similarity is computed between prediction features and itself
        query_features = prediction_features
        zero_diagonal = True
        trim_preds = False
    else:
        zero_diagonal = False
        trim_preds = True
        
    if chunking:
        num_chunks = 10 #TODO this was 100 but had to be reduced to 10 to avoid OOM for local testing
        num_test_samples = query_features.size()[0]
        samples_per_chunk = num_test_samples // num_chunks
            
        for idx in range(0, num_test_samples, samples_per_chunk):
            
            
            query_chunk_features = query_features[
                idx : min((idx + samples_per_chunk), num_test_samples), :
            ]
            chunk_labels = labels[
                idx : min((idx + samples_per_chunk), num_test_samples)
            ]
            
            batch_size = chunk_labels.shape[0]
            
            similarity, similarity_ground_truth, distances, indices, probs, probs_sorted, preds = knn_core(prediction_features, query_chunk_features, labels, k, temperature, zero_diagonal, num_classes, batch_size)
            
            probabilities.append(probs)
            predictions.append(preds)
        
        probabilities = torch.cat(probabilities, dim=0)
        predictions = torch.cat(predictions, dim=0)
        
        return probabilities, predictions, labels
    else:
        batch_size = labels.shape[0]
        return knn_core(prediction_features, query_features, labels, k, temperature, zero_diagonal, num_classes, batch_size)
        


def k_nearest_neighbor_simple(
    prediction_features: torch.Tensor, 
    query_features: torch.Tensor = None, 
    labels: torch.Tensor = None, 
    k: int = 20, 
    ) -> Tuple:
        
    temperature = 0.1
    num_classes = len(set(list(labels.numpy())))
    
    if query_features is None:
        # means that similarity is computed between prediction features and itself
        query_features = prediction_features
        zero_diagonal = True
        trim_preds = False
    else:
        zero_diagonal = False
        trim_preds = True
        
    batch_size = labels.shape[0]
    
    similarity = F.normalize(query_features) @ F.normalize(prediction_features).t()
    similarity_ground_truth = torch.diag(similarity)

    torch.diagonal(similarity, 0).zero_() if zero_diagonal else None
    distances, indices = similarity.topk(k, largest=True, sorted=True)
    candidates = labels.view(1, -1).expand(batch_size, -1)
    retrieved_neighbors = torch.gather(candidates, 1, indices)
    
    retrieval_one_hot = torch.zeros(batch_size * k, num_classes)
    retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1)
    distances_transform = (distances / temperature).exp_()
    
    probs = torch.sum(
        torch.mul(
            retrieval_one_hot.view(batch_size, -1, num_classes),
            distances_transform.view(batch_size, -1, 1),
        ),
        1,
    )
    probs.div_(probs.sum(dim=1, keepdim=True))
    probs_sorted, predictions = probs.sort(1, True)
    
    return probs, predictions, labels, distances, indices, similarity_ground_truth


In [101]:
text = ["uhm what is going on you guys?", "I think you're gonna wanna see this!", "It's best if we split up.", "I wouldn't do that if I were you!", "That's gonna leave a mark.", "We are done!"]
text_pred = torch.randn(6, 64)
images = torch.randn(6, 3, 224, 224)
img_pred = torch.randn(6, 64)
labels = torch.tensor(list(range(img_pred.size()[0])))

similarity, similarity_ground_truth, top_k_distances, top_k_indices, probs, probs_sorted, predictions = k_nearest_neighbor(img_pred, text_pred, labels, 3, False)

wandb_imgs = [[wandb.Image(img, caption=text[i]) for i, img in enumerate(images)]]

table = wandb.Table(columns=['query', 'ground truth', 'similarity ground truth', '#1 prediction', 'similarity #1 prediction'])

for query, image, sim_gt, top_k_idx, top_k_dist in zip(text, images, similarity_ground_truth, top_k_indices, top_k_distances):
    table.add_data(
        query, 
        wandb.Image(image, caption=query), 
        sim_gt, 
        wandb.Image(images[top_k_idx[0]], caption=text[top_k_idx[0]]), 
        top_k_dist[0]
    )

print(table.data)

[['uhm what is going on you guys?', <wandb.sdk.data_types.image.Image object at 0x0000028AB50B3970>, tensor(-0.0881), <wandb.sdk.data_types.image.Image object at 0x0000028AB50B32E0>, tensor(0.1368)], ["I think you're gonna wanna see this!", <wandb.sdk.data_types.image.Image object at 0x0000028AB50D5D90>, tensor(0.1661), <wandb.sdk.data_types.image.Image object at 0x0000028AADB08A30>, tensor(0.1661)], ["It's best if we split up.", <wandb.sdk.data_types.image.Image object at 0x0000028AB50B3AF0>, tensor(-0.1086), <wandb.sdk.data_types.image.Image object at 0x0000028AB50B3280>, tensor(0.0863)], ["I wouldn't do that if I were you!", <wandb.sdk.data_types.image.Image object at 0x0000028AB50B3190>, tensor(0.0234), <wandb.sdk.data_types.image.Image object at 0x0000028AB50EAFA0>, tensor(0.2125)], ["That's gonna leave a mark.", <wandb.sdk.data_types.image.Image object at 0x0000028AB50D5CD0>, tensor(-0.0542), <wandb.sdk.data_types.image.Image object at 0x0000028AB50B3F70>, tensor(0.0511)], ['We a

In [103]:
q = torch.randn(4000, 64)
p = torch.randn(4000, 64)
l = torch.tensor(list(range(4000)))

probabilities, predictions, labels = k_nearest_neighbor(p, q, l, 3, True)
print(probabilities.shape, predictions.shape, labels.shape)

torch.Size([4000, 4000]) torch.Size([4000, 4000]) torch.Size([4000])
