In [1]:
import torch
import torch.nn as nn
import clip
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import os, json
import pandas as pd
from tqdm import tqdm
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

## Load Dataset

In [3]:
# Load the CUB-200-2011 dataset
def load_cub_dataset(data_dir):
    """Read the CUB-200-2011 annotation tables and the LLaVA captions stored under ``data_dir``.

    Parameters:
    data_dir: Directory containing ``images.txt``, ``image_class_labels.txt``,
        ``classes.txt``, ``bounding_boxes.txt``, the ``parts/`` and ``attributes/``
        sub-directories and ``llava_captions.json``.

    Returns:
    tuple: ``(images, labels, classes, bounding_boxes, parts, part_locs,
        parts_click_locs, attributes, certainties, image_attribute_labels,
        llava_captions)``. The first ten entries are pandas DataFrames; the last
        is whatever ``json.load`` produced for ``llava_captions.json``.

    Note that ``part_locs`` names its image column ``img_id`` while every other
    table uses ``image_id``.
    """
    def in_data_dir(relative_path):
        # Resolve a dataset-relative file name against data_dir.
        return os.path.join(data_dir, relative_path)

    def read_space_delimited(relative_path, column_names):
        # Header-less table whose fields are separated by exactly one space.
        return pd.read_csv(in_data_dir(relative_path), sep=' ', names=column_names)

    images = read_space_delimited('images.txt', ['image_id', 'file_path'])
    labels = read_space_delimited('image_class_labels.txt', ['image_id', 'class_id'])
    classes = read_space_delimited('classes.txt', ['class_id', 'class_name'])
    bounding_boxes = read_space_delimited('bounding_boxes.txt', ['image_id', 'x', 'y', 'width', 'height'])
    part_locs = read_space_delimited('parts/part_locs.txt', ['img_id', 'part_id', 'x', 'y', 'visible'])

    # Read as fixed-width (id in the first two characters, name in the rest),
    # presumably because the name field can itself contain spaces.
    parts = pd.read_fwf(
        in_data_dir('parts/parts.txt'),
        colspecs=[(0, 2), (2, None)],
        header=None,
        names=['part_id', 'part_name'],
    )

    parts_click_locs = read_space_delimited(
        'parts/part_click_locs.txt',
        ['image_id', 'part_id', 'x', 'y', 'visible', 'time'],
    )
    attributes = read_space_delimited('attributes/attributes.txt', ['attribute_id', 'attribute_name'])

    # Fixed-width again: one-character id, then the name starting at column 2.
    certainties = pd.read_fwf(
        in_data_dir('attributes/certainties.txt'),
        colspecs=[(0, 1), (2, None)],
        names=["certainty_id", "certainty_name"],
    )

    image_attribute_labels = read_space_delimited(
        'attributes/image_attribute_labels.txt',
        ['image_id', 'attribute_id', 'is_present', 'certainty_id', 'time'],
    )

    with open(in_data_dir('llava_captions.json'), 'r') as captions_file:
        llava_captions = json.load(captions_file)

    return (
        images, labels, classes, bounding_boxes, parts, part_locs, parts_click_locs,
        attributes, certainties, image_attribute_labels, llava_captions,
    )
# Dataset root plus the two sub-directories used elsewhere in the notebook.
data_dir = 'data'
images_dir, parts_dir = (os.path.join(data_dir, sub_dir) for sub_dir in ('images', 'parts'))

# Load every annotation table and the captions in one call.
(
    images, labels, classes, bounding_boxes, parts, part_locs, parts_click_locs,
    attributes, certainties, image_attribute_labels, llava_captions,
) = load_cub_dataset(data_dir)

# Preview the first rows of the three core tables, one per line group.
print(*(frame.head() for frame in (images, labels, classes)), sep='\n')

# Then report their (rows, columns) shapes.
print(*(frame.shape for frame in (images, labels, classes)), sep='\n')

   image_id                                          file_path
0         1  001.Black_footed_Albatross/Black_Footed_Albatr...
1         2  001.Black_footed_Albatross/Black_Footed_Albatr...
2         3  001.Black_footed_Albatross/Black_Footed_Albatr...
3         4  001.Black_footed_Albatross/Black_Footed_Albatr...
4         5  001.Black_footed_Albatross/Black_Footed_Albatr...
   image_id  class_id
0         1         1
1         2         1
2         3         1
3         4         1
4         5         1
   class_id                  class_name
0         1  001.Black_footed_Albatross
1         2        002.Laysan_Albatross
2         3         003.Sooty_Albatross
3         4       004.Groove_billed_Ani
4         5          005.Crested_Auklet
(11788, 2)
(11788, 2)
(200, 2)


## CLIP

In [8]:
# Load the "ViT-B/32" CLIP checkpoint onto `device`. `clip_model` does the image/text
# encoding below and `clip_preprocess` is applied to each PIL image before encoding.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)

In [None]:
# torchvision pipeline: resize to 224x224, convert to a tensor, then normalise each channel.
# NOTE(review): get_clip_img_features accepts this as `transform` but never applies it
# (it calls clip_preprocess instead), so this pipeline does not affect the saved embeddings.
data_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.481, 0.457, 0.408), std=(0.268, 0.261, 0.275))
])

def get_clip_img_features(img_path, transform):
    """
    Encodes one image file with CLIP and returns the embedding as a NumPy array.

    Parameters:
    img_path: Path of the image file to open.
    transform: Ignored. Kept only so existing callers keep working; preprocessing
        is always done with ``clip_preprocess``.

    Returns:
    numpy.ndarray: Output of ``clip_model.encode_image`` for a batch containing
        this single image, moved to the CPU.
    """
    # Use a context manager so the underlying file is closed as soon as the pixel
    # data has been copied into the RGB image, instead of waiting for garbage
    # collection (this function is called once per dataset image).
    with Image.open(img_path) as raw_image:
        rgb_image = raw_image.convert("RGB")

    # clip_preprocess yields one image tensor; unsqueeze adds the batch dimension.
    batch = clip_preprocess(rgb_image).unsqueeze(0).to(device)
    with torch.no_grad():
        img_features = clip_model.encode_image(batch).cpu().numpy()
    return img_features

def get_clip_text_features(text):
    """
    Encodes a single string with CLIP's text encoder.

    Parameters:
    text: The string to tokenize and encode.

    Returns:
    numpy.ndarray: Output of ``clip_model.encode_text`` for a batch containing
        this single string, moved to the CPU.
    """
    # Tokenize as a one-element batch and place the tokens on the model's device.
    token_batch = clip.tokenize([text])
    token_batch = token_batch.to(device)

    with torch.no_grad():
        encoded = clip_model.encode_text(token_batch)
        on_cpu = encoded.cpu()
        return on_cpu.numpy()

def get_cosine_similarity(img_features: "torch.Tensor | np.ndarray", txt_features: "torch.Tensor | np.ndarray") -> float:
    """
    Computes the cosine similarity between an image and a text feature vector.

    Parameters:
    img_features (torch.Tensor or numpy.ndarray): Feature vector for the image.
        Singleton dimensions are squeezed away, so shape (1, D) is accepted.
    txt_features (torch.Tensor or numpy.ndarray): Feature vector for the text,
        with the same number of elements as ``img_features``.

    Returns:
    float: Cosine similarity score between -1 and 1, or 0.0 if either vector
        has zero norm.
    """

    # Accept NumPy arrays as well as tensors: the feature helpers in this
    # notebook return NumPy arrays, which torch.dot cannot consume directly.
    img_vec = torch.as_tensor(img_features).squeeze()
    txt_vec = torch.as_tensor(txt_features).squeeze()

    # torch.dot needs both operands on one device with one dtype. Promote to the
    # wider of the two input dtypes, and to at least float32 so half-precision or
    # integer inputs are handled too.
    common_dtype = torch.promote_types(
        torch.promote_types(img_vec.dtype, txt_vec.dtype), torch.float32
    )
    img_vec = img_vec.to(dtype=common_dtype)
    txt_vec = txt_vec.to(device=img_vec.device, dtype=common_dtype)

    # Compute L2 norms
    norm_img = torch.norm(img_vec, p=2)
    norm_txt = torch.norm(txt_vec, p=2)

    # Avoid division by zero; return a plain float like the normal path does.
    if norm_img == 0 or norm_txt == 0:
        return 0.0

    # Compute cosine similarity
    similarity = torch.dot(img_vec, txt_vec) / (norm_img * norm_txt)

    return similarity.item()

def get_clip_llava_caption_features(text, embed_dim=512):
    """
    Embeds a caption by averaging the CLIP text features of its sentences.

    The caption is split on '.', each non-empty piece is encoded with
    ``get_clip_text_features`` and the results are averaged.

    Parameters:
    text (str): Caption text. ``None``, an empty string or a string with no
        non-empty sentences yields an all-zero vector instead of NaNs.
    embed_dim (int): Width of the embedding returned by ``get_clip_text_features``.
        Defaults to 512, the value this function previously hard-coded.

    Returns:
    numpy.ndarray: Array of shape (1, embed_dim) holding the mean sentence embedding.
    """
    sentences = [sentence.strip() for sentence in (text or '').split('.') if sentence.strip()]
    text_features = np.zeros((1, embed_dim))

    # Nothing to encode: return zeros rather than dividing 0 by 0 (which would
    # silently produce an all-NaN embedding).
    if not sentences:
        return text_features

    for sentence in sentences:
        text_features += get_clip_text_features(sentence)
    text_features /= len(sentences)
    return text_features

In [16]:
data_dir = './data/'
def generate_clip_embeddings_imgs(data_dir, transform):
    """
    Computes a CLIP image embedding for every row of the global ``images`` table.

    Parameters:
    data_dir: Dataset root; image files are read from ``<data_dir>/images``.
    transform: Passed through to ``get_clip_img_features``.

    Returns:
    numpy.ndarray: Array of shape (number of images, 512) where row ``i`` holds the
        embedding of the image whose ``image_id`` is ``i + 1``. Image ids are
        assumed to run contiguously from 1; a missing id raises ``KeyError``.
    """
    img_dir = os.path.join(data_dir, 'images')

    # Build the id -> path lookup once instead of scanning the whole table with a
    # boolean mask on every iteration. drop_duplicates keeps the first row per id,
    # matching the previous `.iloc[0]` choice.
    path_by_id = (
        images.drop_duplicates('image_id')
        .set_index('image_id')['file_path']
        .to_dict()
    )

    # Size the output from the table rather than a hard-coded dataset size.
    num_images = len(path_by_id)
    clip_embeds = np.zeros((num_images, 512))
    for img_idx in tqdm(range(1, num_images + 1)):
        img_path = os.path.join(img_dir, path_by_id[img_idx])
        clip_embeds[img_idx - 1, :] = get_clip_img_features(img_path, transform).squeeze()
    return clip_embeds
# Embed every image in the dataset (the recorded run below took roughly 13m41s).
clip_embeds_imgs = generate_clip_embeddings_imgs(data_dir, data_transform)

# Persist the embedding matrix so later work can reload it instead of re-encoding.
np.save('./data/clip_embeds_imgs2.npy', clip_embeds_imgs)

100%|██████████| 11788/11788 [13:41<00:00, 14.36it/s]


In [13]:
data_dir = './data/'
def generate_clip_embeddings_text(data_dir):
    """
    Computes a CLIP text embedding for the LLaVA caption of every image.

    Captions are read from the global ``llava_captions`` mapping, keyed by the
    image id as a string, using each entry's ``'llava_text'`` field.

    Parameters:
    data_dir: Unused; kept so existing callers keep working.

    Returns:
    numpy.ndarray: Array of shape (number of images, 512) where row ``i`` holds the
        caption embedding for image id ``i + 1``. A missing caption raises
        ``KeyError``.
    """
    # Size the output from the global `images` table rather than a hard-coded
    # dataset size; 512 matches what get_clip_llava_caption_features returns.
    num_images = len(images)
    clip_embeds = np.zeros((num_images, 512))
    for img_idx in tqdm(range(1, num_images + 1)):
        llava_caption = llava_captions[str(img_idx)]['llava_text']
        clip_embeds[img_idx - 1, :] = get_clip_llava_caption_features(llava_caption)
    return clip_embeds
# Embed every caption (the recorded run below took roughly 1h41m).
clip_embeds_text = generate_clip_embeddings_text(data_dir)

# Persist the embedding matrix so later work can reload it instead of re-encoding.
np.save('./data/clip_embeds_text.npy', clip_embeds_text)

100%|██████████| 11788/11788 [1:41:10<00:00,  1.94it/s] 
