# Dependencies and Variables

In [2]:
import sys
import math
import base64
import argparse
from io import BytesIO
from typing import List

import tqdm
import torch
import pandas as pd

import webdataset as wds
import matplotlib.pyplot as plt
import torchvision.transforms as T
from torch.utils.data import DataLoader
from IPython.display import Image, HTML, Markdown

sys.path.append("./")
sys.path.append("../")

from models.open_CLIP import OpenCLIP
from models.open_CLIP_adapter import OpenCLIPAdapter
from models.open_clip_wrapper import OpenCLIPWrapper
from utils.capivara_utils import download_pretrained_from_hf

ImportError: cannot import name 'narrow_tensor_by_index' from 'torch.distributed._shard._utils' (/home/alef.ferreira/miniconda3/envs/capivara/lib/python3.9/site-packages/torch/distributed/_shard/_utils.py)

In [2]:
# Run configuration — everything a reader might want to tune lives here.

# Hardware / throughput
gpu = 4  # CUDA device index; the device cell falls back to CPU when CUDA is unavailable
batch = 100  # samples grouped per WebDataset batch

# Model selection
open_clip = True  # True -> CAPIVARA via OpenCLIPWrapper/OpenCLIPAdapter; False -> plain OpenCLIP()
adapter = 'hiaac-nlp/CAPIVARA-LoRA'  # adapter id loaded into OpenCLIPAdapter; None -> full CAPIVARA checkpoint

# Data
translation = 'google'  # caption source: 'english', 'google' or 'marian'
dataset_path = "/hadatasets/clip_pt/final_webdatasets/flickr30k_val_v2/00000.tar"

In [3]:
# Set the Hugging Face cache locations (TRANSFORMERS_CACHE / HF_HOME) for this kernel.
# NOTE(review): these run after the import cell; libraries that read the variables at
# import time may not see them — confirm they take effect, or move this cell above the imports.
# NOTE(review): absolute, user-specific paths — consider making them configurable.
# %env PYTHONPATH=${PYTHONPATH}:/work/diego.moreira/CLIP-PtBr/clip_pt/src
%env TRANSFORMERS_CACHE=/work/diego.moreira/hf_dir
%env HF_HOME=/work/diego.moreira/hf_dir

env: TRANSFORMERS_CACHE=/work/diego.moreira/hf_dir
env: HF_HOME=/work/diego.moreira/hf_dir


# Auxiliary Methods

In [4]:
def tokenize(example: List, translation: str):
    """Preprocess one ``(image, metadata)`` sample and tokenize the selected captions.

    Reads the globals ``vision_processor`` and ``text_tokenizer``. They are bound in the
    model-loading cell; the WebDataset pipeline is lazy, so they only need to exist by
    the time the dataloader is iterated.

    Args:
        example: ``(image, metadata)`` pair as produced by
            ``WebDataset.to_tuple("jpg;png", "json")``.
        translation: ``"english"`` to use ``captions-en``; ``"google"`` or ``"marian"`` to
            take the odd / even positions of ``captions-pt``. Compared case-insensitively.

    Returns:
        ``(image_input, text_input, captions)``: the preprocessed image, the tokenizer
        output and the raw caption(s) that were tokenized.

    Raises:
        ValueError: if ``translation`` is not one of the values above and the sample has
            more than one Portuguese caption to choose from.
    """
    image_input = vision_processor(example[0])
    metadata = example[1]
    mode = translation.lower()

    if mode == "english":
        captions = metadata["captions-en"]
    elif len(metadata["captions-pt"]) == 1:
        # Only one Portuguese caption: use it regardless of the requested translator.
        captions = metadata["captions-pt"][0]
    elif mode == "google":
        captions = metadata["captions-pt"][1::2]
    elif mode == "marian":
        captions = metadata["captions-pt"][0::2]
    else:
        # Previously `captions` stayed None and the tokenizer failed with an opaque error.
        raise ValueError(
            f"Unknown translation {translation!r}; expected 'english', 'google' or 'marian'"
        )

    text_input = text_tokenizer(captions)

    return image_input, text_input, captions

def format_batch(batch):
    """Unpack one batch from the WebDataset pipeline and flatten its tokens.

    ``batch`` is indexed as ``(images, tokens, captions)``. The token tensor is reshaped
    to ``(-1, 77)`` so each caption becomes its own row, whatever per-sample grouping the
    batching produced. Images and captions pass through unchanged.
    """
    return batch[0], batch[1].reshape((-1, 77)), batch[2]

def feature_extraction(model, dataloader, device):
    """Encode every batch of ``dataloader`` with ``model`` and collect the results.

    Args:
        model: module called as ``model((image_input, text_input))`` that returns
            ``(image_features, text_features)``.
        dataloader: iterable of ``(image_input, text_input, captions)`` batches, e.g. the
            output of ``format_batch``.
        device: device the model and its inputs are moved to.

    Returns:
        Four lists with one entry per batch:
        - L2-normalized image features (on ``device``);
        - L2-normalized text features (on ``device``);
        - the preprocessed image tensors, left on whatever device the dataloader produced
          them (they are only needed for later concatenation, not for encoding);
        - the caption lists.
    """
    image_features = []
    text_features = []
    all_images = []
    all_captions = []

    model.to(device)
    model.eval()
    with torch.no_grad():
        for image_input, text_input, caption_input in tqdm.tqdm(dataloader, desc="Extracting features"):
            img_features, txt_features = model((image_input.to(device), text_input.to(device)))

            # Unit-normalize so dot products between features are cosine similarities.
            image_features.append(img_features / img_features.norm(dim=1, keepdim=True))
            text_features.append(txt_features / txt_features.norm(dim=1, keepdim=True))
            # Keep the original (un-moved) pixel batch: storing every batch on the
            # accelerator just to concatenate it afterwards wastes device memory.
            all_images.append(image_input)
            all_captions.append(caption_input)

    return image_features, text_features, all_images, all_captions

def text_to_image_retrieval(text_required, model, image_features, text_features, all_images, all_texts, device=None):
    """Rank every image against each text query.

    Args:
        text_required: list of queries. An ``int`` selects that row of ``text_features``
            (a caption from the dataset); any other value is treated as free text. Free
            text is tokenized with the global ``text_tokenizer`` and encoded with
            ``model.encode_text``.
        model: model exposing ``encode_text``; only used for free-text queries.
        image_features: L2-normalized image features, one row per image, as produced by
            ``feature_extraction`` and concatenated.
        text_features: L2-normalized caption features, one row per caption. This tensor
            is never modified, so integer queries always index the dataset captions.
        all_images: unused; kept for call compatibility.
        all_texts: unused; kept for call compatibility.
        device: device used to encode free-text queries. Defaults to
            ``image_features.device``.

    Returns:
        ``(df_list, caption)``:
        - ``df_list`` holds one DataFrame per query, with columns ``score`` and ``id``
          (image index), sorted by descending score;
        - ``caption`` lists the queries, with integer queries wrapped in a one-element
          list.
    """
    if device is None:
        device = image_features.device

    caption = []
    df_list = []
    for text in text_required:
        if isinstance(text, int):
            caption.append([text])
            query_features = text_features[text]
        else:
            caption.append(text)
            with torch.no_grad():
                encoded = model.encode_text(text_tokenizer(text).to(device))
            # Normalize like the dataset features so scores are cosine similarities.
            # Use a separate name: overwriting `text_features` here used to break any
            # integer query that came after a free-text one.
            query_features = encoded / encoded.norm(dim=-1, keepdim=True)

        similarities = []
        for i in tqdm.tqdm(range(len(image_features)), desc="t2i retrieval"):
            scores = query_features @ image_features[i].t()
            similarities.append({
                'score': scores.cpu(),
                'id': i,
            })
        similarities_df = pd.DataFrame(similarities)
        df_list.append(similarities_df.sort_values(by='score', ascending=False))
    return df_list, caption

def image_to_text_retrieval(image_required, image_features, text_features, all_images, all_texts):
    """Rank every caption against each selected image.

    Args:
        image_required: indices of the query images, used both for ``all_images`` and
            for the rows of ``image_features``.
        image_features: image features, one row per image.
        text_features: caption features, one row per caption.
        all_images: indexable collection of images; the queried entries are returned
            so they can be displayed.
        all_texts: list of per-sample caption lists. It is flattened one level here;
            position ``i`` is assumed to line up with row ``i`` of ``text_features``.

    Returns:
        ``(df_list, images_selected)``:
        - ``df_list`` holds one DataFrame per query image, with columns ``score``, ``id``
          and ``text``, sorted by descending score;
        - ``images_selected`` holds the queried images.
    """
    flat_texts = sum(all_texts, [])
    rankings = []
    chosen_images = []
    for query_idx in image_required:
        chosen_images.append(all_images[query_idx])
        rows = []
        for caption_idx in tqdm.tqdm(range(len(text_features)), desc="i2t retrieval"):
            caption_score = text_features[caption_idx] @ image_features[query_idx].t()
            rows.append(dict(score=caption_score.cpu(), id=caption_idx, text=flat_texts[caption_idx]))
        ranked = pd.DataFrame(rows).sort_values(by='score', ascending=False)
        rankings.append(ranked)
    return rankings, chosen_images

In [5]:
# Pick the configured GPU (CPU fallback) and build the lazy evaluation pipeline:
# decode -> (image, metadata) tuples -> preprocess + tokenize -> batches -> flat tokens.
# NOTE: tokenize() reads vision_processor / text_tokenizer, which are bound in the
# model-loading cell; nothing here runs until the dataloader is iterated.
if torch.cuda.is_available():
    device = torch.device(f"cuda:{gpu}")
else:
    device = torch.device("cpu")
print("Device: ", device)

dataset = (
    wds.WebDataset(dataset_path)
    .decode("pil")
    .to_tuple("jpg;png", "json")
    .map(lambda sample: tokenize(sample, translation))
    .batched(batch)
    .map(format_batch)
)

# batch_size=None: batching already happens inside the WebDataset pipeline.
dataloader = DataLoader(dataset, batch_size=None, num_workers=10)

Device:  cuda:4


In [6]:
# Load the model selected in the config cell, expose its preprocessors as globals
# (tokenize() looks them up at call time), then encode the whole split.
print(">>>>>>> Loading model")
if open_clip:
    if adapter is None:
        # Full CAPIVARA checkpoint; presumably fetched from the Hugging Face Hub by the helper.
        model_path = download_pretrained_from_hf(model_id="hiaac-nlp/CAPIVARA")
        model = OpenCLIPWrapper.load_from_checkpoint(model_path, strict=False).model
    else:
        # Adapter-based model with the adapter named in the config cell.
        model = OpenCLIPAdapter(inference=True, devices=device)
        model.load_adapters(pretrained_adapter=True, model_path=adapter)
else:
    model = OpenCLIP()

# Must be bound before the dataloader is iterated below.
vision_processor = model.image_preprocessor
text_tokenizer = model.text_tokenizer

print(">>>>>>> Extracting features")
image_features, text_features, all_images, all_texts = feature_extraction(model, dataloader, device)

# Merge per-batch results: features and images into single tensors, and captions into one
# list with one entry per sample (usually that sample's list of captions).
image_features = torch.cat(image_features, axis=0)
text_features = torch.cat(text_features, axis=0)
all_images = torch.cat(all_images, axis=0)
all_texts = sum(all_texts, [])

>>>>>>> Loading model


Some weights of the model checkpoint at xlm-roberta-base were not used when initializing XLMRobertaModel: ['lm_head.layer_norm.weight', 'lm_head.dense.bias', 'lm_head.dense.weight', 'lm_head.bias', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing XLMRobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


>>>>>>> Extracting features


Extracting features: 11it [00:11,  1.09s/it]


# Text-to-Image Retrieval

In [7]:
# Text-to-image queries. An int picks that caption (row of `text_features`) from the
# dataset; a string is encoded on the fly, e.g. 'A imagem de um cachorro' or
# ["Cidade de São Paulo", " Um laço azul"].
text = [0, 250, 50, 750, 999]
num_texts = len(text)

top_k_images_predictions, caption = text_to_image_retrieval(
    text, model, image_features, text_features, all_images, all_texts
)

t2i retrieval: 100%|███████████████████████████████████████████████████████████████████| 1014/1014 [00:00<00:00, 14220.93it/s]
t2i retrieval: 100%|███████████████████████████████████████████████████████████████████| 1014/1014 [00:00<00:00, 31266.03it/s]
t2i retrieval: 100%|███████████████████████████████████████████████████████████████████| 1014/1014 [00:00<00:00, 30589.81it/s]
t2i retrieval: 100%|███████████████████████████████████████████████████████████████████| 1014/1014 [00:00<00:00, 31187.39it/s]
t2i retrieval: 100%|████████████████████████████████████████████████████████████████████| 1014/1014 [00:00<00:00, 4556.04it/s]


In [8]:
# Sanity check: one ranked DataFrame per query.
len(top_k_images_predictions)

5

In [9]:
def pos_tokenize(example, translation):
    """Select and tokenize captions like ``tokenize``, but keep the raw image and add English.

    It does not preprocess the image, so the original (decoded) image is returned for
    display. Reads the global ``text_tokenizer``.

    Args:
        example: ``(image, metadata)`` pair as produced by
            ``WebDataset.to_tuple("jpg;png", "json")``.
        translation: ``"english"`` to use ``captions-en``; ``"google"`` or ``"marian"`` to
            take the odd / even positions of ``captions-pt``. Compared case-insensitively.

    Returns:
        ``(image_input, text_input, captions, en_captions)``:
        - ``image_input``: the raw image;
        - ``text_input``: the tokenizer output;
        - ``captions``: the selected captions;
        - ``en_captions``: the English captions.

    Raises:
        ValueError: if ``translation`` is not one of the values above and the sample has
            more than one Portuguese caption to choose from.
    """
    image_input = example[0]
    metadata = example[1]
    mode = translation.lower()

    if mode == "english":
        captions = metadata["captions-en"]
    elif len(metadata["captions-pt"]) == 1:
        # Only one Portuguese caption: use it regardless of the requested translator.
        captions = metadata["captions-pt"][0]
    elif mode == "google":
        captions = metadata["captions-pt"][1::2]
    elif mode == "marian":
        captions = metadata["captions-pt"][0::2]
    else:
        raise ValueError(
            f"Unknown translation {translation!r}; expected 'english', 'google' or 'marian'"
        )

    text_input = text_tokenizer(captions)

    en_captions = metadata["captions-en"]

    return image_input, text_input, captions, en_captions

def pos_format_batch(batch):
    """Unpack a display batch and flatten its tokens.

    ``batch`` is indexed as ``(images, tokens, captions, en_captions)``. Only the token
    tensor changes: it is reshaped to ``(-1, 77)``, one row per caption.
    """
    return batch[0], batch[1].reshape((-1, 77)), batch[2], batch[3]

In [10]:
# Second pass over the split for display: keep raw images and both caption languages.
dataset = (
    wds.WebDataset(dataset_path)
    .decode("pil")
    .to_tuple("jpg;png", "json")
    .map(lambda sample: pos_tokenize(sample, translation))
    .batched(batch)
    .map(pos_format_batch)
)

# batch_size=None: batching already happens inside the WebDataset pipeline.
dataloader = DataLoader(dataset, batch_size=None, num_workers=10)

In [12]:
# Gather raw images plus Portuguese and English captions for the display cells.
# NOTE(review): this rebinds `all_images` — it now holds a flat list of decoded images
# instead of the stacked preprocessed tensor from the feature-extraction cell.
batched_images, batched_pt_captions, batched_en_captions = [], [], []
for image_input, text_input, caption_input, en_caption_input in tqdm.tqdm(dataloader, desc="Extracting features"):
    batched_images.append(image_input)
    batched_pt_captions.append(caption_input)
    batched_en_captions.append(en_caption_input)

# One entry per sample (each entry being that sample's captions) ...
all_images = sum(batched_images, [])
all_caption_input = sum(batched_pt_captions, [])
all_en_caption_input = sum(batched_en_captions, [])

# ... and one entry per individual caption, aligned with the rows of `text_features`.
all_caption_input_extended = sum(all_caption_input, [])
all_en_caption_input_extended = sum(all_en_caption_input, [])


Extracting features: 11it [00:05,  2.16it/s]


In [23]:
import html


def center_print_t2i(text_list):
    """Display each caption in ``text_list`` as its own centered, HTML-escaped line.

    A bare string is shown as one line instead of being split into single characters.
    """
    if isinstance(text_list, str):
        text_list = [text_list]
    for text in text_list:
        display(HTML(f"<div style='text-align:center'>{html.escape(str(text))}</div>"))


# Number of top-ranked images shown for each query (any value up to the number of images).
top_k = num_texts

for i, ranking in enumerate(top_k_images_predictions):
    query = text[i]
    # Integer queries index the flattened caption list; free-text queries are shown as typed.
    query_label = all_caption_input_extended[query] if isinstance(query, int) else query
    header_text = f"Texto #{i+1}: {html.escape(str(query_label))}"
    display(HTML(f"<h1 style='text-align:center'>{header_text}</h1>"))
    for rank in range(min(top_k, len(ranking))):
        display(HTML(f"<h2 style='text-align:center'>{'Match #' + str(rank + 1)}</h2>"))
        # int(): guard against pandas upcasting the id when a row is read back.
        img_id = int(ranking.iloc[rank]['id'])
        img = all_images[img_id]

        display(HTML(f"<h3 style='text-align:center'>Portuguese:</h3>"))
        center_print_t2i(all_caption_input[img_id])

        display(HTML(f"<h3 style='text-align:center'>English:</h3>"))
        center_print_t2i(all_en_caption_input[img_id])

        # Inline the image as base64 PNG so plain HTML can center it.
        image_buffer = BytesIO()
        img.save(image_buffer, format="PNG")
        image_data = base64.b64encode(image_buffer.getvalue()).decode()
        center_image_html = f"""
        <div style="display: flex; justify-content: center;">
            <img src="data:image/png;base64,{image_data}" alt="Centered Image">
        </div>
        """

        display(HTML(center_image_html))
        display(Markdown('---'))

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

---

# Image-to-Text Retrieval

In [14]:
# Image-to-text queries: indices into `all_images` (here the flat list of decoded images
# built by the caption-loading cell) and rows of `image_features`.
image = [1, 251, 501, 751, 1000]
num_images = len(image)

top_k_text_predictions, images = image_to_text_retrieval(
    image, image_features, text_features, all_images, all_texts
)

i2t retrieval: 100%|███████████████████████████████████████████████████████████████████| 5070/5070 [00:00<00:00, 19339.86it/s]
i2t retrieval: 100%|███████████████████████████████████████████████████████████████████| 5070/5070 [00:00<00:00, 32423.11it/s]
i2t retrieval: 100%|███████████████████████████████████████████████████████████████████| 5070/5070 [00:00<00:00, 30908.65it/s]
i2t retrieval: 100%|███████████████████████████████████████████████████████████████████| 5070/5070 [00:00<00:00, 31750.55it/s]
i2t retrieval: 100%|███████████████████████████████████████████████████████████████████| 5070/5070 [00:00<00:00, 30741.89it/s]


In [24]:
import html


def center_print_i2t(text_list, auxiliar_text_list=None, lang="pt", num_images=None):
    """Display the top-ranked captions for one image query, one centered line each.

    Args:
        text_list: for ``lang="pt"``, a ranking DataFrame from ``image_to_text_retrieval``
            whose ``text`` column is displayed. Otherwise, a flat list of captions indexed
            by the ranking's ``id`` column.
        auxiliar_text_list: the ranking DataFrame used to look up ids when
            ``lang != "pt"``.
        lang: ``"pt"`` shows the ranked captions directly; any other value maps the
            ranked ids into ``text_list``.
        num_images: how many top-ranked entries to show (despite the name, this counts
            captions). ``None`` shows the whole ranking instead of raising.
    """
    ranking = text_list if lang == "pt" else auxiliar_text_list
    count = len(ranking) if num_images is None else min(num_images, len(ranking))
    for j in range(count):
        row = ranking.iloc[j]
        if lang == "pt":
            caption = row["text"]
        else:
            caption = text_list[int(row["id"])]
        display(HTML(f"<div style='text-align:center'>{html.escape(str(caption))}</div>"))


for i, (img, ranking) in enumerate(zip(images, top_k_text_predictions)):
    display(HTML(f"<h3 style='text-align:center'>Image #{i+1}</h3>"))

    # Inline the image as base64 PNG so plain HTML can center it.
    image_buffer = BytesIO()
    img.save(image_buffer, format="PNG")
    image_data = base64.b64encode(image_buffer.getvalue()).decode()
    center_image_html = f"""
    <div style="display: flex; justify-content: center;">
        <img src="data:image/png;base64,{image_data}" alt="Centered Image">
    </div>
    """
    display(HTML(center_image_html))

    display(HTML(f"<h3 style='text-align:center'>Portuguese:</h3>"))
    center_print_i2t(ranking, lang="pt", num_images=num_images)

    display(HTML(f"<h3 style='text-align:center'>English:</h3>"))
    center_print_i2t(all_en_caption_input_extended, ranking, lang="en", num_images=num_images)

    display(Markdown('---'))

---

---

---

---

---

In [17]:
# Free memory. Drop the notebook's reference to the model, then run the garbage collector
# so its parameters can actually be released. Finally, hand PyTorch's cached CUDA blocks
# back to the driver: `del` alone leaves them reserved by the caching allocator.
# NOTE(review): vision_processor / text_tokenizer came from the model and may keep parts of
# it alive; large tensors such as image_features / text_features also stay in memory.
import gc

del model
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()