Retrieval results on COCO validation dataset #115

Closed
mabravo641 opened this issue Jun 9, 2021 · 10 comments

@mabravo641

mabravo641 commented Jun 9, 2021

Thank you for the code.
I ran the available models on the COCO Captions 5k validation set for retrieval, using only the first caption of each image. The results are good but still not close to those reported in the paper. Which architecture was used for that experiment (Table 13)? I am posting my implementation of this zero-shot retrieval evaluation below.

Paper CLIP (Table 13)
i2t - 'r1': 58.4, 'r5': 81.5, 'r10': 88.1
t2i - 'r1': 37.8, 'r5': 62.4, 'r10': 72.2

Code Results:
RN50
i2t - 'r1': 31.82, 'r5': 56.56, 'r10': 67.44, 'r50': 89.86, 'medr': 4.0, 'meanr': 21.4242, 'sum': 155.82
t2i - 'r1': 27.40, 'r5': 51.52, 'r10': 62.28, 'r50': 87.14, 'medr': 5.0, 'meanr': 31.1382, 'sum': 141.2

RN101
i2t - 'r1': 32.22, 'r5': 57.7, 'r10': 68.2, 'r50': 90.46, 'medr': 4.0, 'meanr': 21.0626, 'sum': 158.12
t2i - 'r1': 28.26, 'r5': 52.64, 'r10': 63.3, 'r50': 87.78, 'medr': 5.0, 'meanr': 30.8762, 'sum': 144.2

RN50x4
i2t - 'r1': 34.8, 'r5': 59.38, 'r10': 70.76, 'r50': 91.44, 'medr': 3.0, 'meanr': 19.715, 'sum': 164.94
t2i - 'r1': 30.84, 'r5': 54.76, 'r10': 65.6, 'r50': 88.98, 'medr': 4.0, 'meanr': 29.3306, 'sum': 151.2

ViT-B/32
i2t - 'r1': 32.56, 'r5': 57.98, 'r10': 68.24, 'r50': 90.48, 'medr': 4.0, 'meanr': 20.0076, 'sum': 158.78
t2i - 'r1': 28.42, 'r5': 53.1, 'r10': 64.16, 'r50': 88.3, 'medr': 5.0, 'meanr': 27.6856, 'sum': 145.68

import numpy as np
from torch import nn
import torch
import clip
import yaml
import os
from torch.utils.data import DataLoader
from torchvision.datasets.coco import CocoCaptions

single_caption = True # choose if evaluating only using the first caption
model_name = "ViT-B/32" #"RN50" #"RN50x4" #"RN101" #

def compute_similarity(image_features, text_features, bs = 1000):
    # compute similarity
    max_pairs = image_features.shape[0]
    similarity_scores = torch.zeros(max_pairs, max_pairs)
    for v in range(0, max_pairs, bs):
        for t in range(0, max_pairs, bs):
            print('Processing Visual '+str(v)+' Text '+str(t), end='\r')
            batch_visual_emb = image_features[v:v+bs]
            batch_caption_emb = text_features[t:t+bs]

            logits = batch_visual_emb @ batch_caption_emb.t()
            similarity_scores[v:v+bs,t:t+bs] = logits

    print('Done similarity')
    return similarity_scores

def compute_retrieval(a2b_sims, return_ranks=True):
    """
    Args:
        a2b_sims: Result of computing similarity between two sets of embeddings (emb1 @ emb2.T)
            with shape (num_datapoints, num_datapoints).

    Returns:
        Retrieval metrics for that similarity.
    """
    npts = a2b_sims.shape[0]
    ranks = np.zeros(npts)
    top1 = np.zeros(npts)
    # loop source embedding indices
    for index in range(npts):
        # get order of similarities to target embeddings
        inds = np.argsort(a2b_sims[index])[::-1]
        # find where the correct embedding is ranked
        where = np.where(inds == index)
        rank = where[0][0]
        ranks[index] = rank
        # save the top1 result as well
        top1[index] = inds[0]

    # Compute metrics
    r1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
    r5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
    r10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
    r50 = 100.0 * len(np.where(ranks < 50)[0]) / len(ranks)
    medr = np.floor(np.median(ranks)) + 1
    meanr = ranks.mean() + 1

    report_dict = {"r1": r1, "r5": r5, "r10": r10, "r50": r50, "medr": medr, "meanr": meanr, "sum": r1 + r5 + r10}

    if return_ranks:
        return report_dict, (ranks, top1)
    else:
        return report_dict

print(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load(model_name, device=device)

data_root = "path-to-data"
train_root = os.path.join(data_root, 'train2017')
valid_root = os.path.join(data_root, 'val2017')
train_captions = os.path.join(data_root, 'annotations/captions_train2017.json')
valid_captions = os.path.join(data_root, 'annotations/captions_val2017.json')

valid_dataset = CocoCaptions(root = valid_root,
                        annFile = valid_captions,
                        transform = preprocess)
valid_dataloader = DataLoader(valid_dataset, batch_size = 1)

# fwd all samples
image_features = []
text_features = []
for batch_idx, batch in enumerate(valid_dataloader):
    print('Evaluating batch {}/{}'.format(batch_idx, len(valid_dataloader)), end = "\r")
    images, texts = batch
    if single_caption:
        texts = [texts[0][0]]
    else:
        texts = [txt[0] for txt in texts]

    texts = clip.tokenize(texts).to(device) #tokenize, on the same device as the model
    text_emb = model.encode_text(texts) #embed with text encoder
    if not single_caption:
        text_emb = text_emb.unsqueeze(0)

    image_emb = model.encode_image(images.to(device)) #embed with image encoder
    

    text_features.append(text_emb.detach().cpu())
    image_features.append(image_emb.detach().cpu())

image_features = torch.cat(image_features, 0)
text_features = torch.cat(text_features, 0)
print('Done forward')

# normalized features
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)

if not single_caption:
    for cap_idx in range(text_features.shape[1]):
        similarity_scores = compute_similarity(image_features, text_features[:,cap_idx,:])
        i2t_dict = compute_retrieval(similarity_scores.numpy())
        t2i_dict = compute_retrieval(similarity_scores.t().numpy())
        print(cap_idx, 'i2t', i2t_dict)
        print(cap_idx, 't2i', t2i_dict)
else:
    similarity_scores = compute_similarity(image_features, text_features)
    i2t_dict = compute_retrieval(similarity_scores.numpy())
    t2i_dict = compute_retrieval(similarity_scores.t().numpy())
    print('i2t', i2t_dict)
    print('t2i', t2i_dict)

Thanks,
MA
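
To make the metric definitions in the script above concrete, here is a minimal sketch that recomputes R@K, median rank and mean rank on a toy 4x4 similarity matrix in plain Python. The matrix values and the helper names `ranks_of_matches` and `summarize` are made up for illustration. Like the script, it assumes the correct match for row i is column i.

```python
from statistics import median

def ranks_of_matches(sims):
    """0-based rank of the diagonal (correct) entry in each row, best score first."""
    ranks = []
    for i, row in enumerate(sims):
        order = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        ranks.append(order.index(i))
    return ranks

def summarize(ranks, ks=(1, 5, 10)):
    """Recall@K in percent, plus 1-based median and mean rank."""
    n = len(ranks)
    report = {f"r{k}": 100.0 * sum(r < k for r in ranks) / n for k in ks}
    report["medr"] = int(median(ranks)) + 1  # floor(median) + 1, as in the script
    report["meanr"] = sum(ranks) / n + 1
    return report

# Toy similarities (hypothetical values): rows are queries, columns are candidates.
toy_sims = [
    [0.9, 0.1, 0.2, 0.3],  # correct item ranked 1st
    [0.8, 0.7, 0.1, 0.2],  # correct item ranked 2nd
    [0.1, 0.2, 0.9, 0.3],  # correct item ranked 1st
    [0.6, 0.5, 0.4, 0.3],  # correct item ranked 4th
]
ranks = ranks_of_matches(toy_sims)
report = summarize(ranks, ks=(1, 2))
print(ranks, report)
```

Here two of the four queries rank their match first, so R@1 is 50%, and three of four have it in the top two, so R@2 is 75%.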

@jongwook
Collaborator

CLIP metrics in Table 13 (and many others) refer to the largest CLIP model, ViT-L/14 fine-tuned at 336px, following this sentence in Section 2.5:

We denote this model as ViT-L/14@336px. Unless otherwise specified, all results reported in this paper as “CLIP” use this model which we found to perform best.

Apologies that the models are not publicly available yet. The team is still working on the plan, but I'm hoping all models will eventually be released, as happened with GPT-2.

@justlovebarbecue

Hi @mabravo641 @jongwook ,

Thanks for your issue and comments! I tried the updated ViT-L/14@336px for zero-shot evaluation on COCO, but I still get low scores, close to what Maria posted.

I used Maria's zero-shot script with both single_caption = True and single_caption = False. Both give relatively low results.

Do you have any suggestions for reproducing the results in Table 13?

Thanks!

@Neltherion

@mabravo641 Hi, I wanted to ask whether you wrote this code yourself to test CLIP's image retrieval capabilities, or whether it was adapted from somewhere else?

I'm looking for COCO image retrieval code that is considered the standard approach, and I'm not even sure whether such code exists or whether every researcher in this field implements it themselves (which could add to discrepancies between results).

Thanks!

@zzbuzzard

For what it's worth, I tested this using my own independent code and was able to (approximately) reproduce the results. I get the following values for ViT-L/14@336px on the COCO 5k val set:

text-to-image
R@1 = 36.10% (vs 37.8% in the paper)
R@5 = 60.75% (vs 62.4% in the paper)
R@10 = 70.75% (vs 72.2% in the paper)

image-to-text
R@1 = 57.48% (vs 58.4% in the paper)
R@5 = 80.32% (vs 81.5% in the paper)
R@10 = 87.60% (vs 88.1% in the paper)

The paper does state the following (page 45):

For both these datasets we prepend the prompt “a photo
of” to the description of each image which we found boosts
CLIP’s zero-shot R@1 performance between 1 and 2 points

Implementing this change gave me values of 57.94 / 81.50 / 88.42 for image-to-text retrieval (even closer to the paper), but seemed to slightly reduce my text-to-image retrieval scores.
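
The prompt trick quoted above can be sketched as a plain string transform applied before tokenization. The paper only says that "a photo of" is prepended. The joining space, the whitespace stripping and the helper name `add_prompt` are assumptions made for this example.

```python
def add_prompt(captions, prompt="a photo of"):
    """Prepend a fixed prompt to every caption (hypothetical helper)."""
    return [f"{prompt} {caption.strip()}" for caption in captions]

# Made-up captions, just to show the transform.
captions = ["A man riding a wave on a surfboard.", " two dogs playing in the snow "]
prompted = add_prompt(captions)
print(prompted)
```

The prompted strings would then be passed to `clip.tokenize` in place of the raw captions.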

So, I would attribute the differences to either some other minor tricks like this that they used to slightly improve performance, or simply differences in implementation - it's not totally obvious how to handle the one-to-many relationship between images and captions.

(unfortunately I am not able to share my code, as it is part of an ongoing assessed project)

@Neltherion I was also hoping to find some commonly used definition, for example in the torchmetrics library (for example, there seems to be a standard definition for mean average precision), but failed to find one.
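
Since there seems to be no single agreed definition, here is one possible convention written out on toy data: a text query counts as correct if its own image is in the top K, and an image query counts as correct if any of its captions is in the top K. This is a sketch of that convention only, not necessarily what the paper used. The function names and scores are invented for illustration.

```python
def top_k(scores, k):
    """Indices of the k highest scores, best first."""
    return sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)[:k]

def text_to_image_recall(sims_t2i, text_to_image, k):
    """sims_t2i[t][i] is the similarity of text t to image i."""
    hits = sum(text_to_image[t] in top_k(row, k) for t, row in enumerate(sims_t2i))
    return hits / len(sims_t2i)

def image_to_text_recall(sims_t2i, text_to_image, k):
    """An image is a hit if ANY of its captions is among its top-k texts."""
    num_images = len(sims_t2i[0])
    hits = 0
    for i in range(num_images):
        column = [row[i] for row in sims_t2i]  # scores of all texts for image i
        hits += any(text_to_image[t] == i for t in top_k(column, k))
    return hits / num_images

# Two images with two captions each: texts 0-1 -> image 0, texts 2-3 -> image 1.
text_to_image = [0, 0, 1, 1]
sims_t2i = [
    [0.9, 0.2],
    [0.3, 0.6],  # this caption of image 0 is closer to image 1
    [0.1, 0.8],
    [0.4, 0.7],
]
t2i_r1 = text_to_image_recall(sims_t2i, text_to_image, k=1)
i2t_r1 = image_to_text_recall(sims_t2i, text_to_image, k=1)
print(t2i_r1, i2t_r1)
```

In this toy case one of the four captions retrieves the wrong image, while both images retrieve one of their own captions first, so the two directions give different scores for the same similarity matrix.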

@Neltherion

@zzbuzzard Thanks. Is it possible for you to upload your own independent code for me to test?

I'm really wondering whether there is any official code for evaluating retrieval tasks at all.

@Neltherion

@zzbuzzard Oh! I just saw the part you mentioned why you can't share your code! Either way, Thanks for the info😉

@nimakasipour

@zzbuzzard I am trying to reproduce this result on the COCO 5k val set, but my text retrieval scores are still not close, even after adding the prompt "a photo of " to the text descriptions. Could you please give us some tips on what might differ from the code provided by @mabravo641?

@zzbuzzard

@Neltherion @nimakasipour ok! I spent a while extracting it from the rest of my codebase (hopefully I won't get flagged by plagiarism checkers for this haha) but here's my script. I am able to run exactly this to get the following output:

Text-to-image Recall@K
 R@1: 36.09%
 R@5: 60.77%
 R@10: 70.78%
 R@50: 91.00%
Image-to-text Recall@K
 R@1: 57.40%
 R@5: 80.30%
 R@10: 87.56%
 R@50: 97.44%

I am running on a GPU with 6GB of VRAM, and this code certainly uses a decent amount of it. If you have memory issues, you can try editing it: you only really need things on the GPU when evaluating CLIP, so you could move each result to the CPU immediately (you will need to cast to float32, as I believe CPU support for float16 is limited). Alternatively, you could compute the whole thing on the CPU, though this would be insanely slow.
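
One way to act on the memory advice above is to avoid building and sorting the full similarity matrix at once, and instead process the queries in blocks and keep only the top-k indices for each block. The sketch below uses NumPy arrays as a stand-in for the encodings so that it runs without a GPU. The function name `topk_in_chunks`, the chunk size and the random test data are assumptions, not part of the script in this thread.

```python
import numpy as np

def topk_in_chunks(queries, candidates, k, chunk_size=1024):
    """Top-k candidate indices per query, computed one block of queries at a time."""
    candidates = candidates.astype(np.float32)  # cast up from float16 for CPU maths
    results = []
    for start in range(0, len(queries), chunk_size):
        block = queries[start:start + chunk_size].astype(np.float32)
        sims = block @ candidates.T                      # shape (block, num_candidates)
        results.append(np.argsort(-sims, axis=1)[:, :k])  # highest similarity first
    return np.concatenate(results)

# Random float16 "encodings" (hypothetical data): 10 queries, 7 candidates, dim 8.
rng = np.random.default_rng(0)
queries = rng.standard_normal((10, 8)).astype(np.float16)
candidates = rng.standard_normal((7, 8)).astype(np.float16)
top3 = topk_in_chunks(queries, candidates, k=3, chunk_size=4)
print(top3.shape)
```

Only one block of similarities exists at a time, so peak memory scales with `chunk_size` rather than with the number of queries. With torch tensors the same pattern would apply after moving each block of encodings to the CPU with `.float().cpu()`.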

I don't have the time to compare to the code at the top of this issue, so I'm not sure why we're producing different answers, but this version seems to replicate the paper's results. It's not perfect, and I haven't properly tested it, but I hope this helps!

import torch
from torchvision.datasets import CocoCaptions
import torch.utils.data as dutils
from typing import List
import clip

# Change these to path of local COCO dataset:
coco_root = "**your_path_here**/coco/val2017"
coco_ann_file = "**your_path_here**/coco/annotations/captions_val2017.json"

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model, transform = clip.load("ViT-L/14@336px")
model.to(device).eval()

dataset = CocoCaptions(
    root=coco_root,
    annFile=coco_ann_file,
    transform=transform,
    # Note: almost all images have 5 captions, but 12/5000 have 6, and 1/5000 has 7 - I ignore these few extra captions.
    target_transform=lambda texts: clip.tokenize(texts[:5])
)

k_vals=[1, 5, 10, 50]


# Encodes all text and images in a dataset
def encode_dataset(clip, dataset: dutils.Dataset, batch_size = 16):
    with torch.no_grad():
        # image_to_text_map[i] gives the corresponding text indices for the ith image
        #  (as there are multiple pieces of text for each image)
        image_to_text_map = []

        # text_to_image_map[i] gives the corresponding image index for the ith text
        text_to_image_map = []

        dataloader = dutils.DataLoader(dataset, batch_size=batch_size, shuffle=False)

        image_encodings = []
        text_encodings = []

        text_index = 0
        image_index = 0

        for images, text in dataloader:
            images = images.to(device)
            text = text.to(device)

            # text has shape B x 5 x 77
            batch_size, captions_per_image, _ = text.shape

            # Update text_to_image_map and image_to_text_map for this batch
            for i in range(batch_size):
                # the next image corresponds to text captions [text_index ... text_index + captions_per_image - 1]
                text_indices = list(range(text_index, text_index + captions_per_image))
                image_to_text_map.append(text_indices)
                text_index += captions_per_image

                # Each of the next captions_per_image text captions correspond to the same image
                text_to_image_map += [image_index] * captions_per_image
                image_index += 1

            # B x 5 x 77 -> (B*5) x 77
            text = torch.flatten(text, start_dim=0, end_dim=1)
            
            image_encodings.append(clip.encode_image(images))
            text_encodings.append(clip.encode_text(text))

        image_encodings = torch.cat(image_encodings)
        text_encodings = torch.cat(text_encodings)
        text_to_image_map = torch.LongTensor(text_to_image_map).to(device)
        image_to_text_map = torch.LongTensor(image_to_text_map).to(device)

        # Normalise encodings
        image_encodings = image_encodings / image_encodings.norm(dim=-1, keepdim=True)
        text_encodings = text_encodings / text_encodings.norm(dim=-1, keepdim=True)

        return image_encodings, text_encodings, text_to_image_map, image_to_text_map


def recall_at_k(clip, dataset: dutils.Dataset, k_vals: List[int], batch_size: int):
    print("Encoding all data...")
    image_encodings, text_encodings, text_to_image_map, image_to_text_map = encode_dataset(clip, dataset, batch_size=batch_size)
 
    num_text = text_encodings.shape[0]
    num_im = image_encodings.shape[0]
    captions_per_image = image_to_text_map.shape[1]

    # text-to-image recall
    print("Text-to-image recall...")

    dist_matrix = text_encodings @ image_encodings.T  # dist_matrix[i] gives logits for ith text

    # Note: this matrix is pretty big (5000 x 25000 with dtype float16 = 250MB)
    #  torch.argsort runs out of memory for me (6GB VRAM) so I move to CPU for sorting
    dist_matrix = dist_matrix.cpu()

    # Sort in descending order; first is the biggest logit
    inds = torch.argsort(dist_matrix, dim=1, descending=True)
    inds = inds.to(device)

    text_to_image_recall = []

    for k in k_vals:
        # Extract top k indices only
        topk = inds[:, :k]

        # Correct iff one of the top_k values equals the correct image (as given by text_to_image_map)
        correct = torch.eq(topk, text_to_image_map.unsqueeze(-1)).any(dim=1)

        num_correct = correct.sum().item()
        text_to_image_recall.append(num_correct / num_text)


    # image-to-text recall
    print("Image-to-text recall...")
    dist_matrix = dist_matrix.T  # dist_matrix[i] gives logits for the ith image

    # Sort in descending order; first is the biggest logit
    inds = torch.argsort(dist_matrix, dim=1, descending=True)
    inds = inds.to(device)

    image_to_text_recall = []

    for k in k_vals:
        # Extract top k indices only
        topk = inds[:, :k]

        correct = torch.zeros((num_im,), dtype=torch.bool, device=device)

        #  For each image, check whether one of the 5 relevant captions was retrieved
        # Check if image matches its ith caption (for i=0..4)
        for i in range(captions_per_image):
            contains_index = torch.eq(topk, image_to_text_map[:, i].unsqueeze(-1)).any(dim=1)
            correct = torch.logical_or(correct, contains_index)

        num_correct = correct.sum().item()
        image_to_text_recall.append(num_correct / num_im)

    print("Done.")
    return text_to_image_recall, image_to_text_recall


t2i, i2t = recall_at_k(model, dataset, k_vals=k_vals, batch_size=16)

print("Text-to-image Recall@K")
for k, x in zip(k_vals, t2i):
    print(f" R@{k}: {100*x:.2f}%")

print("Image-to-text Recall@K")
for k, x in zip(k_vals, i2t):
    print(f" R@{k}: {100*x:.2f}%")

@nimakasipour

@zzbuzzard I greatly appreciate you sharing the code! It provided a straightforward and intriguing approach to evaluating the performance of CLIP. Through this evaluation, I could identify the source of my problem. It turns out that associating a single caption with each image led to poorer results in the text retrieval task (20% lower for each metric).
Using my code with ViT-L/14@336px, I also noticed a minor decrease in image retrieval performance when I added the prompt "a photo of" to the image descriptions. Before adding the prompt: R@1 = 33.88, R@5 = 58.5, R@10 = 69.32, R@50 = 90.66, MeanR = 24.26. After adding it: R@1 = 33.44, R@5 = 58.36, R@10 = 69.42, R@50 = 90.58, MeanR = 24.50. However, there was a slight improvement in the text retrieval task (where each image was associated with only its first caption).
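
One way to see why a single caption per image can score lower on image-to-text retrieval: if an image's best-matching caption is not its first one, dropping the other captions removes the hit. The toy sketch below illustrates that effect only. The scores and the helper name `i2t_recall_at_1` are invented, and it does not reproduce either script in this thread.

```python
def i2t_recall_at_1(sims_i2t, caption_owner):
    """sims_i2t[i][t]: similarity of image i to caption t.
    caption_owner[t]: index of the image that caption t describes."""
    hits = 0
    for i, row in enumerate(sims_i2t):
        best = max(range(len(row)), key=lambda t: row[t])
        hits += caption_owner[best] == i
    return hits / len(sims_i2t)

# Captions 0-1 belong to image 0, captions 2-3 to image 1 (toy scores).
owner_all = [0, 0, 1, 1]
sims_all = [
    [0.5, 0.9, 0.6, 0.1],  # image 0: its second caption is the best match
    [0.2, 0.3, 0.8, 0.7],
]

# Single-caption protocol: keep only the first caption of each image.
first = [0, 2]
sims_first = [[row[t] for t in first] for row in sims_all]
owner_first = [owner_all[t] for t in first]

r1_all = i2t_recall_at_1(sims_all, owner_all)
r1_first = i2t_recall_at_1(sims_first, owner_first)
print(r1_all, r1_first)
```

With all captions in the pool both images retrieve one of their own captions first. With only the first captions, image 0 prefers the caption of image 1, so R@1 drops.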

@lerogo

lerogo commented Jun 28, 2023

text-to-image
R@1 = 37.68% (vs 36.10% above, 37.8% in the paper)
R@5 = 62.10% (vs 60.75% above, 62.4% in the paper)
R@10 = 71.87% (vs 70.75% above, 72.2% in the paper)

image-to-text
R@1 = 56.26% (vs 57.48% above, 58.4% in the paper)
R@5 = 80.22% (vs 80.32% above, 81.5% in the paper)
R@10 = 87.32% (vs 87.60% above, 88.1% in the paper)
