## Image Captioning using KNN

Although Vision Language Models (VLMs) are currently the go-to tools for image captioning, earlier work used KNN for captioning and performed surprisingly well.

Further, libraries like [Faiss](https://engineering.fb.com/2017/03/29/data-infrastructure/faiss-a-library-for-efficient-similarity-search/) perform nearest-neighbor computation efficiently and are used in many industrial applications.

- In this question you will implement an algorithm to perform captioning using KNN, based on the paper [A Distributed Representation Based Query Expansion Approach for Image Captioning](https://aclanthology.org/P15-2018.pdf)

- Dataset: [MS COCO](https://cocodataset.org/#home) 2014 (val set only)

- Algorithm:
    1. Given: image embeddings and the corresponding caption embeddings (5 per image).
    1. For every image, find its k nearest images and compute its query vector as the weighted sum of the caption embeddings of those nearest images (k*5 captions per image).
    1. The predicted caption is the caption in the dataset that is closest to the query vector. (For the sake of the assignment, use the same COCO val set captions as the dataset.)

- The image and text embeddings are extracted from the [CLIP](https://openai.com/research/clip) model. (You need not know about this right now)

- Tasks:
    1. Implement the algorithm and compute the BLEU score. Use Faiss for nearest-neighbor computation. Starter code is provided below.
    1. Try a few options for k. Record your observations.
    1. For a fixed k, try a few options in the Faiss index factory to speed up the computation in step 2. Record your observations.
    1. Qualitative study: Visualize five images, their ground-truth captions, and the predicted caption.
    
Note: Run this notebook on Colab for fastest results.
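
The three algorithm steps above can be sketched on toy data with plain NumPy. This is a minimal illustration rather than the assignment solution: the function names `knn_caption` and `l2_normalize`, the random toy embeddings, and the use of cosine similarity as the neighbor weight are all assumptions, and the sketch deliberately excludes the query image from its own neighbors.

```python
import numpy as np

def knn_caption(img_emb, cap_emb, query_idx, k):
    """Return (image_idx, caption_idx) of the predicted caption for one image.

    img_emb: (n_images, d) L2-normalized image embeddings
    cap_emb: (n_images, m, d) L2-normalized caption embeddings
    """
    sims = img_emb @ img_emb[query_idx]          # cosine similarity to every image
    sims[query_idx] = -np.inf                    # assumption: never use the query itself
    nearest = np.argsort(-sims)[:k]              # indices of the k most similar images

    # Query vector: similarity-weighted mean of the k*m neighbor captions.
    weights = sims[nearest][:, None, None]       # (k, 1, 1)
    m = cap_emb.shape[1]
    query = (weights * cap_emb[nearest]).sum(axis=(0, 1)) / (k * m)

    # Pick the neighbor caption closest (by cosine similarity) to the query vector.
    candidates = cap_emb[nearest].reshape(-1, cap_emb.shape[2])   # (k*m, d)
    scores = candidates @ query / np.linalg.norm(query)
    img_pos, cap_pos = divmod(int(np.argmax(scores)), m)
    return int(nearest[img_pos]), cap_pos

def l2_normalize(x):
    """Scale vectors along the last axis to unit length."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

rng = np.random.default_rng(0)
toy_imgs = l2_normalize(rng.normal(size=(20, 8)))
toy_caps = l2_normalize(rng.normal(size=(20, 5, 8)))
pred_img, pred_cap = knn_caption(toy_imgs, toy_caps, query_idx=3, k=4)
print(pred_img, pred_cap)
```

The returned pair indexes into the caption array, so the caption text would be looked up in whatever list of strings is aligned with `cap_emb`.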

In [1]:
!gdown 1RwhwntZGZ9AX8XtGIDAcQD3ByTcUiOoO #image embeddings

Downloading...
From: https://drive.google.com/uc?id=1RwhwntZGZ9AX8XtGIDAcQD3ByTcUiOoO
To: /content/coco_imgs.npy
100% 83.0M/83.0M [00:00<00:00, 151MB/s]


In [2]:
!gdown 1b-4hU2Kp93r1nxMUGEgs1UbZov0OqFfW #caption embeddings

Downloading...
From (original): https://drive.google.com/uc?id=1b-4hU2Kp93r1nxMUGEgs1UbZov0OqFfW
From (redirected): https://drive.google.com/uc?id=1b-4hU2Kp93r1nxMUGEgs1UbZov0OqFfW&confirm=t&uuid=adda0530-64c8-4ef4-b5b6-0efbe742f47e
To: /content/coco_captions.npy
100% 415M/415M [00:02<00:00, 168MB/s]


In [3]:
!wget http://images.cocodataset.org/zips/val2014.zip
!unzip /content/val2014.zip
!wget http://images.cocodataset.org/annotations/annotations_trainval2014.zip
!unzip /content/annotations_trainval2014.zip
!pip install faiss-cpu

In [4]:
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
from nltk.translate import bleu_score
import faiss
import numpy as np

In [5]:
def get_transform():
    transform = transforms.Compose([
        transforms.Resize((224,224)),
        transforms.ToTensor(),  # convert the PIL Image to a tensor
        transforms.Normalize(
            (0.485, 0.456, 0.406),  # normalize image for pre-trained model
            (0.229, 0.224, 0.225),
        )
    ])
    return transform

coco_dset = dset.CocoCaptions(root = '/content/val2014',
                        annFile = '/content/annotations/captions_val2014.json',
                        transform=get_transform())

print('Number of samples: ', len(coco_dset))
img, target = coco_dset[3] # load 4th sample

print("Image Size: ", img.shape)
print(target)

loading annotations into memory...
Done (t=0.50s)
creating index...
index created!
Number of samples:  40504
Image Size:  torch.Size([3, 224, 224])
['A loft bed with a dresser underneath it.', 'A bed and desk in a small room.', 'Wooden bed on top of a white dresser.', 'A bed sits on top of a dresser and a desk.', 'Bunk bed with a narrow shelf sitting underneath it. ']


In [6]:


ids = list(sorted(coco_dset.coco.imgs.keys()))
captions = []
for i in range(len(ids)):
    captions.append([ele['caption'] for ele in coco_dset.coco.loadAnns(coco_dset.coco.getAnnIds(ids[i]))][:5]) #5 per image
captions_np = np.array(captions)
print('Captions:', captions_np.shape)

Captions: (40504, 5)


In [8]:
captions_flat = captions_np.flatten().tolist()
print('Total captions:', len(captions_flat))

Total captions: 202520


In [9]:
cap_path = '/content/coco_captions.npy'
caption_embeddings = np.load(cap_path)
print('Caption embeddings',caption_embeddings.shape)

Caption embeddings (40504, 5, 512)


In [11]:
img_path = '/content/coco_imgs.npy'
image_embeddings = np.load(img_path)
print('Image embeddings',image_embeddings.shape)

Image embeddings (40504, 512)


In [10]:
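def accuracy(predict, real):
    '''
    use bleu score as a measurement of accuracy
    :param predict: a list of predicted captions, each a list of tokens
    :param real: a list of reference sets, each a list of tokenized captions
    :return: mean sentence-level bleu score
    '''
    # sentence_bleu expects token sequences; raw strings would be compared
    # character by character (accuracy_v2 below tokenizes first).
    accuracy = 0
    for i, pre in enumerate(predict):
        references = real[i]
        score = bleu_score.sentence_bleu(references, pre)
        accuracy += score
    return accuracy/len(predict)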
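
Because n-gram metrics such as BLEU count overlaps between sequences, it matters whether a caption is passed as a raw string or as a list of tokens. The sketch below illustrates n-gram extraction only (it is not NLTK's implementation, and `ngrams` here is a hypothetical helper): iterating over a string yields characters, while a tokenized caption yields words.

```python
def ngrams(seq, n):
    """All contiguous n-grams of a sequence (string or list of tokens)."""
    return [tuple(seq[i:i + n]) for i in range(len(seq) - n + 1)]

caption = "A bed and desk"

char_bigrams = ngrams(caption, 2)                  # string -> character bigrams
word_bigrams = ngrams(caption.lower().split(), 2)  # token list -> word bigrams

print(char_bigrams[:3])   # [('A', ' '), (' ', 'b'), ('b', 'e')]
print(word_bigrams)       # [('a', 'bed'), ('bed', 'and'), ('and', 'desk')]
```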

In [14]:
print(image_embeddings.shape[1])
print((image_embeddings[0:1]).shape)

512
(1, 512)


In [18]:
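def perform_knn(image_embeddings, caption_embeddings, k):
    # Exact inner-product index over the image embeddings.
    # caption_embeddings is accepted but not used in this function.
    index = faiss.IndexFlatIP(image_embeddings.shape[1])
    index.add(image_embeddings)

    results_dict = {}
    for i in range(len(image_embeddings)):
        # D holds inner-product similarities (higher is closer), I the neighbor ids.
        # Every query image is also in the index, so its top hit is typically itself.
        D, I = index.search(image_embeddings[i:i + 1], k)

        # Store results in the dictionary
        print(i)
        results_dict[i] = {'nearest_indices': I.flatten(), 'distances': D.flatten()}

    return results_dict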

# Example usage
k = 5  # Specify the value of k
knn_results = perform_knn(image_embeddings, caption_embeddings, k)

# Access results using the dictionary
for index, values in knn_results.items():
    print(f"Index {index} - Nearest Indices: {values['nearest_indices']}, Distances: {values['distances']}")


Streaming output truncated to the last 5000 lines.
Index 35504 - Nearest Indices: [35504  7987 39097  6574 22804], Distances: [1.0006146  0.9369757  0.92908263 0.92781603 0.92535114]
Index 35505 - Nearest Indices: [35505  2635 23060 37014 38238], Distances: [1.0005968  0.88677955 0.86815053 0.86686397 0.865749  ]
Index 35506 - Nearest Indices: [35506  2272 20845  7805 20636], Distances: [1.0000198  0.85742724 0.8545218  0.8495043  0.84887695]
Index 35507 - Nearest Indices: [35507 31070 13648 20943 23629], Distances: [0.99996114 0.8791736  0.8720451  0.86503947 0.8644443 ]
Index 35508 - Nearest Indices: [35508 30805 33972  7086 10335], Distances: [1.0000963  0.7894603  0.78142774 0.7798523  0.77887857]
Index 35509 - Nearest Indices: [35509 36411  8383 14768 24752], Distances: [0.9999108  0.91609    0.9049259  0.90180016 0.89617133]
Index 35510 - Nearest Indices: [35510 18035 37625 18283  8612], Distances: [0.9998869  0.86896664 0.8644017  0.85566866 0.8468934 ]
Index 35511
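
In the output above, the first neighbor of each image is the image itself, with similarity close to 1, because every query is also stored in the index. If the query image should not contribute its own captions, the self match can be masked out before taking the top k. The sketch below does this with a brute-force NumPy search as a stand-in for `IndexFlatIP`; `search_excluding_self` and the toy data are hypothetical, and since it builds the full similarity matrix it is only suitable for small inputs. With Faiss, a comparable effect could presumably be had by searching for k + 1 neighbors and discarding the self match.

```python
import numpy as np

def search_excluding_self(embeddings, k):
    """Exact inner-product k-NN for every row, never returning the row itself."""
    sims = embeddings @ embeddings.T                 # (n, n) inner products
    np.fill_diagonal(sims, -np.inf)                  # mask each row's self match
    idx = np.argsort(-sims, axis=1)[:, :k]           # top-k neighbor ids per row
    top = np.take_along_axis(sims, idx, axis=1)      # matching similarities
    return top, idx

rng = np.random.default_rng(0)
toy = rng.normal(size=(50, 16)).astype('float32')
toy /= np.linalg.norm(toy, axis=1, keepdims=True)    # unit-length rows

D, I = search_excluding_self(toy, k=5)
print(D.shape, I.shape)   # (50, 5) (50, 5)
```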

In [40]:
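N = 5  # number of nearest images (k)
M = 5  # number of captions per image
for index, values in knn_results.items():
    print(index)
    nearest_indices = values['nearest_indices']
    distances = values['distances']  # inner-product similarities from IndexFlatIP
    # caption_embeddings[nearest_indices] has shape (N, M, 512); summing over
    # axis 0 adds the neighbors together and leaves one row per caption slot.
    concatenated_captions = np.sum(caption_embeddings[nearest_indices], axis=0)
    # print(concatenated_captions.shape)
    # Weights are 1 - similarity / max similarity, so the top hit gets weight 0
    # and less similar neighbors get larger weights.
    normalized_distances = 1 - (distances / np.max(distances))
    # The (N,) weights are applied to the M caption-slot rows, which only
    # lines up because N == M here.
    distributed_query = np.sum(concatenated_captions * normalized_distances[:, np.newaxis], axis=0)
    distributed_query = distributed_query/(N*M)
    # Add the distributed_query to knn_results dictionary
    knn_results[index]['distributed_query'] = distributed_query
    print(distributed_query.shape)


Streaming output truncated to the last 5000 lines.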
38004
(512,)
38005
(512,)
38006
(512,)

In [41]:
print(knn_results[30]['distributed_query'])


[-1.69955636e-03 -1.46693795e-03  3.56103672e-04 -8.69904179e-04
  1.68032141e-03  1.17832984e-04  3.87524269e-05 -4.63441620e-03
 -1.75189425e-03  7.69493927e-05  2.24123313e-03 -1.43928069e-03
  5.74249076e-03 -1.94305729e-03  2.13364791e-03  1.34629605e-04
  4.10424458e-04  2.82575074e-03 -2.77643651e-03  2.24809488e-03
  4.90531675e-04  1.15928950e-03  4.15282906e-04 -5.29091631e-04
  1.59268721e-03 -1.49218176e-04  5.46985248e-04 -1.77818947e-04
  7.83832220e-04  1.37339323e-03 -1.86834252e-03 -2.17630574e-03
 -2.46135797e-03  5.10608545e-04 -7.26580247e-03 -2.00164132e-03
  1.82994572e-03 -2.37093633e-03 -2.01618276e-03 -4.31749824e-04
 -3.69761721e-04  8.67216615e-04  2.13005790e-03 -6.89231907e-04
  2.60671508e-03 -1.76101481e-03 -9.19057289e-04 -1.85369258e-03
 -1.57142384e-03 -8.92537704e-04 -2.44831177e-03 -5.44511992e-03
  1.14933820e-03 -4.83554235e-04  1.73547232e-04 -2.31264639e-04
 -4.95374063e-03 -1.52743640e-04  1.58898165e-05 -2.73596519e-03
 -5.51720243e-03  1.08445
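
For comparison, a query vector in which each neighbor's captions are weighted by that neighbor's similarity to the query image can be computed by broadcasting the weights over the neighbor axis. This is a sketch under assumptions, not the code that produced the values above: `weighted_query` is a hypothetical helper, the weights are the raw similarities, and the shapes follow the `(k, M, d)` layout of `caption_embeddings[nearest_indices]`.

```python
import numpy as np

def weighted_query(neighbor_caps, similarities):
    """Similarity-weighted mean of neighbor caption embeddings.

    neighbor_caps: (k, m, d) caption embeddings of the k nearest images
    similarities:  (k,) similarity of each neighbor to the query image
    """
    k, m, _ = neighbor_caps.shape
    weights = similarities[:, None, None]            # (k, 1, 1): one weight per neighbor
    return (weights * neighbor_caps).sum(axis=(0, 1)) / (k * m)

# Toy check: 2 neighbors, 1 caption each, 3-dimensional embeddings.
caps = np.array([[[1.0, 0.0, 0.0]],
                 [[0.0, 1.0, 0.0]]])
sims = np.array([0.8, 0.4])
result = weighted_query(caps, sims)
print(result)
```

In the toy check, the first neighbor's caption contributes 0.8 / 2 and the second 0.4 / 2, so the more similar neighbor dominates the query vector.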

In [66]:
from sklearn.metrics.pairwise import cosine_similarity
N = 5  # Number of nearest images (K)
M = 5  # Number of captions per image
for target_index, values in knn_results.items():
    query_vector = values['distributed_query']
    nearest_indices = values['nearest_indices']
    print(target_index)
    # Candidates are the N*M captions of the nearest images, flattened to (N*M, 512)
    candidate_captions = caption_embeddings[nearest_indices].reshape(-1, caption_embeddings.shape[2])
    similarities = cosine_similarity([query_vector], candidate_captions)[0]
    closest_caption_index = np.argmax(similarities)
    # Convert the flat index back to (neighbor position, caption position)
    closest_image_index = closest_caption_index // M
    actual_image = nearest_indices[closest_image_index]
    closest_caption_within_image = closest_caption_index % M
    values['closest_image_index'] = closest_image_index
    values['closest_caption_index'] = closest_caption_within_image
    # Look up the caption strings (this also loads and transforms both images)
    ground_image, ground_captions = coco_dset[target_index]
    our_image, our_result = coco_dset[actual_image]
    # print(our_result)
    our_caption = our_result[closest_caption_within_image]
    values['our_caption'] = our_caption
    values['ground_captions'] = ground_captions


Streaming output truncated to the last 5000 lines.
35504
35505
35506
35507
35508
35509
35659
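
The cell above searches only the captions of the k nearest images. To search a larger caption pool, as the task statement's "caption in the dataset that is closest to the query vector" suggests, one option is to flatten all caption embeddings and map the winning flat index back to an image and caption position. A minimal NumPy sketch, assuming L2-normalized caption embeddings; `closest_caption` and the toy arrays are hypothetical:

```python
import numpy as np

def closest_caption(query, cap_emb):
    """Return (image_idx, caption_idx) of the caption most similar to query.

    query:   (d,) query vector
    cap_emb: (n_images, m, d) L2-normalized caption embeddings
    """
    n, m, d = cap_emb.shape
    flat = cap_emb.reshape(n * m, d)          # row i*m + j is caption j of image i
    scores = flat @ query                     # proportional to cosine similarity
    return divmod(int(np.argmax(scores)), m)  # flat index -> (image, caption)

# Toy pool: 3 images, 2 captions each, 4-dimensional one-hot embeddings.
caps = np.zeros((3, 2, 4))
caps[0, 0, 0] = caps[0, 1, 1] = 1.0
caps[1, 0, 2] = caps[1, 1, 3] = 1.0
caps[2, 0, 0] = caps[2, 1, 2] = 1.0
best = closest_caption(np.array([0.0, 0.1, 0.2, 0.9]), caps)
print(best)   # (1, 1)
```

The flat-index convention matches `captions_np.flatten()`, where the captions of each image are stored contiguously.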

In [67]:
# Iterate through knn_results and print ground truth captions and predicted captions
for target_index, values in knn_results.items():
    print(f"Index {target_index}")

    # Display the ground truth captions
    ground_truth_captions = values['ground_captions']
    print("Ground Truth Captions:")
    for caption in ground_truth_captions:
        print(caption)

    # Display the predicted caption
    predicted_caption = values['our_caption']
    print("\nPredicted Caption:")
    print(predicted_caption)

    print("\n----------------------------------------\n")


Streaming output truncated to the last 5000 lines.
Traffic signal on wall attached to cement structure.
Traffic signal on the side of a bridge outside.
The traffic light is bolted on a concrete wall. 
a large cement building with a stop light resting on the side of it
The light switch box determines the color of the traffic lights.

Predicted Caption:
A sign pointing up next to a stop light

----------------------------------------

Index 9280
Ground Truth Captions:
A double decker bus driving down a street next to a tall building.
busy street with two buses and cars 
two public transit buses on a city street
A group of buses driving down a city street
An intersection with two buses and a car

Predicted Caption:
The bus is driving on the city street.

----------------------------------------

Index 9281
Ground Truth Captions:
A street sign standing on a corner in front of a building.
A street sign outside of a building with a store front
A set of street signs in front of 

Streaming output truncated to the last 5000 lines.
Half a cup of coffee are next to a couple pieces of bread.
A half-filled cup of coffee on a table with pastries and bread.
A cup of coffee is sitting next to a pastry.

Predicted Caption:
A half-filled cup of coffee on a table with pastries and bread.

----------------------------------------

Index 25792
Ground Truth Captions:
A group of man riding a skateboard down a street.
A man on a skateboard about to do a trick.
A young man in a park on a skateboard
A MAN IS ON HIS SKATE BOARD IN THE PARK 
A man is on his skateboard in a park setting.

Predicted Caption:
People riding skateboards and skates on a street.

----------------------------------------

Index 25793
Ground Truth Captions:
Three horse grazing on grass near a street sign.
three horses stand next to a street sign grazing 
three brown white black  horses and a an arrow sign
Three horses are grazing on grass by the road.
Three horses grazing on the side of the r

Streaming output truncated to the last 5000 lines.

Predicted Caption:
a close up of a number of cows near one another 

----------------------------------------

Index 40121
Ground Truth Captions:
A person on a skateboard does an air trick.
A somewhat blurry image of a young man in the air on his skateboard. 
A man in red shirt doing a trick on skateboard.
A man flying through the air on top of a skateboard.
A person jumping up in the air on a skateboard.

Predicted Caption:
A man riding a skateboard up the side of a ramp.

----------------------------------------

Index 40122
Ground Truth Captions:
A bathroom with white fixtures and green flooring
A bathroom area with toilet, sink and tub.
A fisheye lens photograph of a residential bathroom
A sparsely furnished bathroom is dimly lit by the overhead bulb.
A picture taken with a fish bowl lens

Predicted Caption:
A bathroom with a sink, toilet, tub and a mirror. 

----------------------------------------

Index 40123
Grou

In [73]:
def accuracy_v2(predict, real):
    '''
    use bleu score as a measurement of accuracy
    captions are lowercased and split on whitespace before scoring
    :param predict: a list of predicted captions (strings)
    :param real: a list of lists of reference captions (strings)
    :return: mean sentence-level bleu score
    '''
    def lower_n_split(text):
        return text.lower().split()

    accuracy = 0
    for i, pre in enumerate(predict):
        refs = [lower_n_split(ref) for ref in real[i]]
        score = bleu_score.sentence_bleu(refs, lower_n_split(pre))
        accuracy += score
    return accuracy/len(predict)

In [74]:
# Convert knn_results to the format suitable for accuracy_v2
predict_list = []
real_list = []

for target_index, values in knn_results.items():
    # Extract predicted caption
    predicted_caption = values['our_caption']

    # Extract ground truth captions
    ground_truth_captions = values['ground_captions']

    # Append to the lists
    predict_list.append(predicted_caption)
    real_list.append(ground_truth_captions)

# Calculate accuracy using accuracy_v2
accuracy_v2_result = accuracy_v2(predict_list, real_list)
print('Accuracy (v2):', accuracy_v2_result)


Accuracy (v2): 0.28449150762621117
