### Task c) Use BERT as Feature Extractor

### Import required libraries

In [2]:
import glob
from itertools import chain
import os
import random
import zipfile
from tqdm.notebook import tqdm

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
from sklearn.model_selection import train_test_split
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models, ops
from typing import Any, Callable, List, Optional, Tuple
from PIL import Image
import json
from transformers import BertTokenizer, BertModel


### Load models

In [None]:
device = "cuda"

# ResNet-50 image encoder whose classification head is replaced by a single
# 512-d linear layer; the trained weights for this architecture are loaded
# from disk.
model = models.resnet50()
num_features = model.fc.in_features
model.fc = nn.Sequential(
    nn.Linear(num_features, 512))
model = model.to(device)
model.load_state_dict(torch.load('./weights-resnet.pth', map_location=device))
# Inference only: in eval mode BatchNorm uses its running statistics instead of
# the statistics of the current (single-image) batch, and does not update them.
model.eval()

In [69]:
class CaptionEncoder(nn.Module):
    """Encode a tokenized caption as a fixed-size embedding.

    The caption is run through ``bert_model``. The hidden state at the first
    sequence position ([CLS] for BERT tokenizers) is then projected by a
    linear layer to ``embed_dim`` dimensions.

    Args:
        bert_model: transformer module called as
            ``bert_model(input_ids=..., attention_mask=...)``. Its output must
            expose ``last_hidden_state`` of shape (batch, seq_len, hidden).
        embed_dim: size of the output embedding (default 512).
        hidden_size: width of ``last_hidden_state``. If omitted, it is taken
            from ``bert_model.config.hidden_size`` when available, else 768
            (BERT-base).
    """

    def __init__(self, bert_model, embed_dim=512, hidden_size=None):
        super().__init__()
        self.bert = bert_model
        if hidden_size is None:
            config = getattr(bert_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 768)
        # A one-layer Sequential keeps state_dict keys as "linear.0.*", so
        # checkpoints saved from the previous definition still load.
        self.linear = nn.Sequential(nn.Linear(hidden_size, embed_dim))

    def forward(self, input_ids, attention_mask):
        """Return the projected first-token embedding, shape (batch, embed_dim)."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :]
        return self.linear(cls_embedding)
    
from transformers import BertTokenizer, BertModel, AutoTokenizer, AutoModel
import re

epochs = 3  # NOTE(review): not used by any visible cell of this notebook
device = 'cuda'
SEED = 0

# Load pre-trained BERT model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
bert_model = AutoModel.from_pretrained('bert-base-uncased')

# Our base model: BERT followed by a linear projection to 512 dimensions.
# The projection head is randomly initialized unless trained weights are
# loaded, so seed right before constructing it to make the caption
# embeddings (and every metric computed from them) reproducible.
torch.manual_seed(SEED)
caption_encoder = CaptionEncoder(bert_model).to(device)
# TODO(review): without trained projection weights, the caption embeddings
# are presumably not aligned with the ResNet image embeddings. Confirm
# whether this load was disabled on purpose:
# caption_encoder.load_state_dict(torch.load('./weights-caption-encoder-definitive.pth', map_location=device))

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Build database

In [None]:
# Requires faiss (install once with e.g. `%pip install faiss-gpu` if missing).
import faiss
from pycocotools.coco import COCO

PATH_TRAIN = "./COCO/train2014/"
PATH_VALID = "./COCO/val2014/"
ANNOTATIONS = "./COCO/mcv_image_retrieval_annotations.json"

# contents["database"] maps each category to an iterable of COCO image ids.
with open(ANNOTATIONS, 'r') as j:
    contents = json.load(j)

coco = COCO('./COCO/instances_train2014.json')

val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])


# Build database: every image id listed under any category, de-duplicated
# and sorted, so that row i of `database` corresponds to database_id[i].
database_id = []
database = []
image_ids_database = sorted(
    {image_id for ids in contents['database'].values() for image_id in ids}
)

# Feature extraction must run in eval mode. In train mode, BatchNorm would
# normalize with single-image batch statistics and update its running stats,
# which yields wrong embeddings.
model.eval()
with torch.no_grad():
    for image_id in tqdm(image_ids_database, desc="Embedding database images"):
        filename = coco.loadImgs(image_id)[0]["file_name"]
        # `with` closes the underlying file handle once the RGB copy is made.
        with Image.open(os.path.join(PATH_TRAIN, filename)) as img:
            im = img.convert('RGB')
        # Kept to reproduce the original preprocessing exactly. The transform
        # also resizes to 224x224, but removing this PIL resize could change
        # the resampling and therefore the embeddings.
        im = im.resize((224, 224))
        batch = val_transforms(im).unsqueeze(0).to(device)
        output = model(batch)
        database.append(output.cpu().numpy().reshape(1, -1))
        database_id.append(image_id)

# Shape (num_images, 1, embed_dim); flattened to 2-D when building the index.
database = np.asarray(database)

In [71]:
# Flatten (num_images, 1, embed_dim) -> (num_images, embed_dim). reshape with
# -1 keeps this cell idempotent: re-running it on an already 2-D array is a
# no-op instead of an IndexError. faiss expects a C-contiguous float32 matrix.
database = np.ascontiguousarray(database.reshape(len(database), -1), dtype=np.float32)
index = faiss.IndexFlatL2(database.shape[1])
print("Adding database to index")
index.add(database)
print("Database added to index")

Adding database to index
Database added to index


### Build queries

In [None]:
print("Building list of queries...")
# Text-to-image protocol: every caption of every database image is a query.
# Its relevant result is the database image with the same image id.
queries = []
queries_id = []
queries_id_test = sorted(
    {image_id for ids in contents["database"].values() for image_id in ids}
)

coco_captions = COCO('./COCO/captions_train2014.json')
caption_encoder.eval()
with torch.no_grad():
    for image_id in tqdm(queries_id_test, desc="Encoding captions"):
        captions = [
            ann["caption"].lower()
            for ann in coco_captions.loadAnns(coco_captions.getAnnIds(image_id))
        ]
        if not captions:
            continue
        # Encode all captions of one image as a single padded batch. The
        # attention mask keeps padding from influencing each caption's
        # embedding (up to floating-point rounding).
        tokens = tokenizer(captions, return_tensors="pt", padding=True, truncation=True)
        embeddings = caption_encoder(
            tokens["input_ids"].to(device), tokens["attention_mask"].to(device)
        )
        queries.append(embeddings.cpu().numpy().reshape(len(captions), -1))
        queries_id.extend([image_id] * len(captions))

# One row per caption, aligned with queries_id; faiss needs contiguous float32.
queries = np.ascontiguousarray(np.concatenate(queries, axis=0), dtype=np.float32)
assert queries.shape[0] == len(queries_id)

print("Searching K neighbors for each query")
# Retrieve the 5 nearest database images for every caption.
D, I = index.search(queries, 5)
print("Finished searching")

### Precision@5

In [None]:

# Precision@k: for each query, the fraction of its k retrieved images whose id
# equals the query's image id; then averaged over all queries.
# faiss pads I with -1 when it cannot return k results. Those slots count as
# misses instead of silently wrapping around to database_id[-1].
retrieved_ids = np.asarray(database_id)[I]
hits = (retrieved_ids == np.asarray(queries_id)[:, None]) & (I >= 0)
precision = np.mean(hits.mean(axis=1))
print(precision)


### Mean average precision (mAP@5)

In [None]:
import sklearn.metrics 
# Average precision over each query's top-k list: the mean of precision@r over
# the ranks r that hold a relevant image, or 0 when none was retrieved.
# This matches sklearn.metrics.average_precision_score(y_true, scores) with
# strictly decreasing scores. However, sklearn warns for queries with no hit
# and, depending on the scikit-learn version, may return NaN, which would
# poison the mean. faiss -1 padding is counted as a miss.
retrieved_ids = np.asarray(database_id)[I]
relevant = (retrieved_ids == np.asarray(queries_id)[:, None]) & (I >= 0)
ranks = np.arange(1, relevant.shape[1] + 1)
precision_at_rank = np.cumsum(relevant, axis=1) / ranks
num_relevant = relevant.sum(axis=1)
ap_numerator = (precision_at_rank * relevant).sum(axis=1)
average_precisions = np.divide(
    ap_numerator,
    num_relevant,
    out=np.zeros(len(num_relevant), dtype=np.float64),
    where=num_relevant > 0,
)

average_precision = np.asarray(average_precisions)
print(np.mean(average_precision))
