In [31]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os
from PIL import Image
from tqdm import tqdm
from lavis.models import load_model_and_preprocess
import torch
from transformers import BertTokenizer
import torch.nn.functional as F

In [51]:
def init_tokenizer():
    """Build the BERT tokenizer used for the BLIP / BLIP-2 text branches.

    Loads ``bert-base-uncased``, registers ``[DEC]`` as the BOS token and
    ``[ENC]`` as an additional special token, and stores the id of ``[ENC]``
    on the tokenizer as ``enc_token_id``.
    """
    tok = BertTokenizer.from_pretrained("bert-base-uncased")
    for special_tokens in ({"bos_token": "[DEC]"}, {"additional_special_tokens": ["[ENC]"]}):
        tok.add_special_tokens(special_tokens)
    # [ENC] is the only additional special token registered above, hence index 0.
    tok.enc_token_id = tok.additional_special_tokens_ids[0]
    return tok

class blip2:
    """BLIP-2 feature extractor (LAVIS ``blip2_feature_extractor``) wrapped for retrieval.

    ``encode_image`` and ``encode_text`` each return a single L2-normalised
    1-D NumPy vector, so image/text similarity is a plain dot product.
    """

    def __init__(self, model_size='pretrain', use_cpu=False):
        """Load the LAVIS model, its processors and a BERT tokenizer.

        Args:
            model_size: LAVIS ``model_type`` for ``blip2_feature_extractor``
                (e.g. ``'pretrain'`` or ``'coco'``).
            use_cpu: force CPU even when CUDA is available.
        """
        # Fall back to CPU when CUDA is unavailable instead of failing at load time.
        use_cuda = torch.cuda.is_available() and not use_cpu
        self.device = torch.device("cuda" if use_cuda else "cpu")
        self.model, self.vis_processor, self.text_processor = load_model_and_preprocess(
            name="blip2_feature_extractor",
            model_type=model_size,
            is_eval=True,
            device=self.device,
        )
        self.tokenizer = init_tokenizer()
        # Cast every weight to float32 so it matches the float inputs built below.
        self.model = self.model.to(torch.float)

    @torch.no_grad()  # inference only: skip building the autograd graph
    def encode_image(self, image):
        """Embed one image (as accepted by ``vis_processor["eval"]``, e.g. a PIL image).

        Returns:
            1-D NumPy array: the L2-normalised ``vision_proj`` output of the first
            Q-Former query token.
        """
        image_processed = self.vis_processor["eval"](image).unsqueeze(0).to(torch.float).to(self.device)

        image_embeds = self.model.ln_vision(self.model.visual_encoder(image_processed)).float()
        # Attend to every visual token.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long, device=image_embeds.device)
        query_tokens = self.model.query_tokens.expand(image_embeds.shape[0], -1, -1)

        query_output = self.model.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
        )
        image_feats = F.normalize(self.model.vision_proj(query_output.last_hidden_state), dim=-1)
        # NOTE(review): there is one feature per query token, and only the first is kept.
        # LAVIS's own BLIP-2 image-text score presumably takes the max over all query
        # tokens (TODO confirm); returning all of them could improve retrieval.
        return image_feats[0][0].detach().cpu().numpy()

    @torch.no_grad()  # inference only: skip building the autograd graph
    def encode_text(self, text):
        """Embed one caption string.

        Returns:
            1-D NumPy array: the L2-normalised ``text_proj`` output of the first
            ([CLS]) token of the Q-Former text encoding.
        """
        text_input = self.text_processor["eval"](text)
        tokens = self.tokenizer(text_input, return_tensors="pt", padding=True).to(self.device)
        text_output = self.model.Qformer.bert(
            tokens.input_ids,
            attention_mask=tokens.attention_mask,
            return_dict=True,
        )
        text_feat = F.normalize(self.model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
        return text_feat[0].detach().cpu().numpy()
    
class blip:
    """BLIP feature extractor (LAVIS ``blip_feature_extractor``) wrapped for retrieval.

    ``encode_image`` and ``encode_text`` each return the L2-normalised projection
    of the first ([CLS]) token as a 1-D NumPy array, so image/text similarity is a
    plain dot product.
    """

    def __init__(self, model_size='base', use_cpu=False):  # model_size must be "base" or "large"
        """Load the LAVIS model, its processors and a BERT tokenizer.

        Args:
            model_size: LAVIS ``model_type``; "base" or "large".
            use_cpu: force CPU even when CUDA is available.
        """
        if use_cpu or not torch.cuda.is_available():
            self.device = torch.device("cpu")
        else:
            self.device = torch.device("cuda")

        self.model, self.vis_processor, self.txt_processor = load_model_and_preprocess(
            name="blip_feature_extractor",
            model_type=model_size,
            is_eval=True,
            device=self.device,
        )
        self.tokenizer = init_tokenizer()

    @torch.no_grad()  # inference only: skip building the autograd graph
    def encode_image(self, image):
        """Embed one image (as accepted by ``vis_processor["eval"]``, e.g. a PIL image).

        Returns:
            1-D NumPy array: the normalised ``vision_proj`` of the ViT [CLS] token.
        """
        image_processed = self.vis_processor["eval"](image).unsqueeze(0).to(self.device)
        # Assumes (batch, tokens, dim) with [CLS] at position 0.
        image_embeds = self.model.visual_encoder.forward_features(image_processed)
        # Only the [CLS] row is returned, so project just that row instead of every patch token.
        cls_features = F.normalize(self.model.vision_proj(image_embeds[:, 0, :]), dim=-1)
        return cls_features[0].detach().cpu().numpy()

    @torch.no_grad()  # inference only: skip building the autograd graph
    def encode_text(self, text):
        """Embed one caption string.

        Returns:
            1-D NumPy array: the normalised ``text_proj`` of the BERT [CLS] token.
        """
        text_input = self.txt_processor["eval"](text)
        tokens = self.tokenizer(text_input, return_tensors="pt", padding=True).to(self.device)
        text_output = self.model.text_encoder(
            tokens.input_ids,
            attention_mask=tokens.attention_mask,
            return_dict=True,
            mode="text",
        )
        # Project only the [CLS] token (position 0), which is the representation returned.
        cls_features = F.normalize(self.model.text_proj(text_output.last_hidden_state[:, 0, :]), dim=-1)
        return cls_features[0].detach().cpu().numpy()

In [64]:
model = blip2(model_size='coco')  # BLIP-2 feature extractor with the 'coco' model_type

Position interpolate from 16x16 to 26x26


100%|██████████| 4.37G/4.37G [04:28<00:00, 17.5MB/s] 


In [65]:
# Test split: each row has an 'image' file name and a list-like 'caption' field.
# The folder can be overridden with the CS336_DATA_DIR environment variable.
DATA_DIR = os.environ.get('CS336_DATA_DIR', '/home/nhan-softzone/cs336/CS336.O11.KHTN/data')
testset = pd.read_json(os.path.join(DATA_DIR, 'test.json'))

In [66]:
# Encode every test image, in testset order, into one embedding each.
# The image folder can be overridden with the FLICKR30K_IMAGE_DIR environment variable.
IMAGE_DIR = os.environ.get('FLICKR30K_IMAGE_DIR', '/home/nhan-softzone/cs336/flickr30k')
image_embeddings = []
for image_name in tqdm(testset['image']):
    # The context manager closes the file handle once the pixels have been converted.
    with Image.open(os.path.join(IMAGE_DIR, image_name)) as img:
        rgb_image = img.convert('RGB')
    image_embeddings.append(model.encode_image(rgb_image))

100%|██████████| 1000/1000 [02:50<00:00,  5.88it/s]


In [67]:
# Flatten the per-image caption lists into one list, preserving image order then caption order.
texts = [caption for captions in testset['caption'] for caption in captions]

In [68]:
# One embedding per caption, in the same order as `texts`.
text_embeddings = [model.encode_text(caption) for caption in tqdm(texts)]

100%|██████████| 5000/5000 [01:28<00:00, 56.80it/s]


In [69]:
# One row per (image, caption) pair, in the same order as `texts`. All captions of
# each image are included, rather than assuming exactly five.
# NOTE: the ' comment' column name (with a leading space) is kept because it is read later.
ground_truth = pd.DataFrame(
    [
        {'image_name': image_name, ' comment': caption}
        for image_name, captions in zip(testset['image'], testset['caption'])
        for caption in captions
    ]
)

In [125]:
def recall_at_k(similarities, k, caption_images=None, candidate_images=None):
    """Text-to-image Recall@K.

    Row ``i`` of ``similarities`` scores caption ``i`` against every candidate
    image (columns). A caption is a hit when its correct image is among the ``k``
    highest-scoring columns. The result is the fraction of captions that hit.

    Args:
        similarities: 2-D array-like, captions x images.
        k: number of top-ranked images to inspect per caption.
        caption_images: correct image name for each caption row. Defaults to
            ``[caption_to_image[t] for t in texts]`` (notebook globals). Pass it
            explicitly when caption strings are not unique, because a
            caption -> image dict keeps only one image per caption string.
        candidate_images: image name of each column. Defaults to the global
            ``image_names``.

    Returns:
        Recall@K as a float in [0, 1]; 0.0 when there are no captions.
    """
    if caption_images is None:
        caption_images = [caption_to_image[caption] for caption in texts]
    if candidate_images is None:
        candidate_images = image_names

    n_captions = len(caption_images)
    if n_captions == 0:
        return 0.0

    # Column index per image name, built once. The first occurrence wins, as with list.index.
    column_of = {}
    for column, name in enumerate(candidate_images):
        column_of.setdefault(name, column)
    correct_columns = np.array([column_of[name] for name in caption_images])

    # Indices of the k highest similarities for each caption row.
    top_k = np.argsort(-np.asarray(similarities)[:n_captions], axis=1)[:, :k]
    hits = (top_k == correct_columns[:, None]).any(axis=1)
    return int(hits.sum()) / n_captions

In [126]:
# Stack the embeddings into matrices, build the caption -> image lookup and
# report text-to-image Recall@K for several K.
import time

start_time = time.time()
text_embeddings = np.array(text_embeddings)
image_embeddings = np.array(image_embeddings)
image_names = testset['image'].to_list()

# Caption text -> image file name. NOTE: if two images share an identical caption
# string, the later row overwrites the earlier one.
caption_to_image = dict(zip(ground_truth[' comment'], ground_truth['image_name']))

# Rows are captions and columns are images. Both encoder wrappers L2-normalise
# their outputs, so this dot product is the cosine similarity.
similarities = text_embeddings @ image_embeddings.T

RECALL_KS = (1, 5, 10, 20, 40, 80, 100)
recalls = {k: recall_at_k(similarities, k) for k in RECALL_KS}
print(f"Elapsed (similarity + all recalls): {time.time() - start_time:.3f}s")

for k, recall in recalls.items():
    print(f"Recall@{k}: {recall}")

1.5791950225830078
Recall@1: 0.745
Recall@5: 0.9378
Recall@10: 0.9724
Recall@20: 0.9868
Recall@40: 0.994
Recall@80: 0.998
Recall@100: 0.998


In [123]:
# Hard-coded summary of text-to-image Recall@K for several backbones. Nothing here
# is recomputed: the values are presumably copied from earlier runs of this notebook.
# The 'BLIP2 COCO' block matches the output of the metrics cell above.
print('''
BLIP2 Features Extractor - pretrain ViT-L
Recall@1: 0.5264
Recall@5: 0.7682
Recall@10: 0.8444
Recall@20: 0.9014
Recall@40: 0.939
Recall@80: 0.9706
Recall@100: 0.979
      
BLIP2 Features Extractor - pretrain ViT-G
Recall@1: 0.5052
Recall@5: 0.7774
Recall@10: 0.8552
Recall@20: 0.9134
Recall@40: 0.9536
Recall@80: 0.9808
Recall@100: 0.9862
      
BLIP Features Extractor - base
Recall@1: 0.713
Recall@5: 0.915
Recall@10: 0.9494
Recall@20: 0.9712
Recall@40: 0.9856
Recall@80: 0.9924
Recall@100: 0.9942
      
BLIP2 COCO
Recall@1: 0.745
Recall@5: 0.9378
Recall@10: 0.9724
Recall@20: 0.9868
Recall@40: 0.994
Recall@80: 0.998
Recall@100: 0.998''')


BLIP2 Features Extractor - pretrain ViT-L
Recall@1: 0.5264
Recall@5: 0.7682
Recall@10: 0.8444
Recall@20: 0.9014
Recall@40: 0.939
Recall@80: 0.9706
Recall@100: 0.979
      
BLIP2 Features Extractor - pretrain ViT-G
Recall@1: 0.5052
Recall@5: 0.7774
Recall@10: 0.8552
Recall@20: 0.9134
Recall@40: 0.9536
Recall@80: 0.9808
Recall@100: 0.9862
      
BLIP Features Extractor - base
Recall@1: 0.713
Recall@5: 0.915
Recall@10: 0.9494
Recall@20: 0.9712
Recall@40: 0.9856
Recall@80: 0.9924
Recall@100: 0.9942
      
BLIP2 COCO
Recall@1: 0.745
Recall@5: 0.9378
Recall@10: 0.9724
Recall@20: 0.9868
Recall@40: 0.994
Recall@80: 0.998
Recall@100: 0.998


In [103]:
# Image-to-text retrieval: for each query image, rank every caption by similarity
# and record the indices of that image's own captions as ground truth.
N_QUERY_IMAGES = 4  # number of images to evaluate; set to None for the whole test set
TOP_N = 10  # number of ranked captions kept per image

# Captions were flattened image by image (see `texts`), so image i owns the
# contiguous index range [caption_offsets[i], caption_offsets[i + 1]).
caption_counts = [len(captions) for captions in testset['caption']]
caption_offsets = np.concatenate(([0], np.cumsum(caption_counts)))

retrievals = []
ground_truths = []
for i in range(len(image_names[:N_QUERY_IMAGES])):
    retrievals.append(list(np.argsort(-similarities.T[i])[:TOP_N]))
    ground_truths.append(list(range(caption_offsets[i], caption_offsets[i + 1])))

In [115]:
def per_query_recall_at_k(retrieved_items, ground_truth_items, k):
    """Recall@K for a single query.

    Counts how many ``ground_truth_items`` appear in the first ``k`` entries of
    ``retrieved_items``, then divides by ``min(k, len(ground_truth_items))``. A
    query with more relevant items than ``k`` can therefore still score 1.0.

    Named distinctly from the text-to-image ``recall_at_k(similarities, k)`` so
    that running this cell does not shadow it.

    Returns:
        A float in [0, 1]; 0.0 when ``k <= 0`` or there are no ground-truth items.
    """
    denominator = min(k, len(ground_truth_items))
    if denominator <= 0:
        return 0.0
    top_k = retrieved_items[:k]
    hits = sum(item in top_k for item in ground_truth_items)
    return hits / denominator


def overall_recall_at_k(all_retrievals, all_ground_truths, k):
    """Mean per-query Recall@K over aligned lists of retrievals and ground truths.

    ``all_retrievals[i]`` is the ranked list of retrieved ids for query ``i``, and
    ``all_ground_truths[i]`` is the list of ids relevant to that query.

    Returns:
        The mean recall, or 0 when there are no queries.
    """
    # Read the ground truths from the argument, not from a same-named global.
    recalls = [
        per_query_recall_at_k(retrieved, all_ground_truths[i], k)
        for i, retrieved in enumerate(all_retrievals)
    ]
    return sum(recalls) / len(recalls) if recalls else 0

In [122]:
overall_recall_at_k(all_retrievals=retrievals, all_ground_truths=ground_truths, k=10)

0.9