In [27]:
import clip
from clip.simple_tokenizer import SimpleTokenizer
import torch
import torchvision.transforms as T
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

from src import FoodDataset
from src import TextTransformer

In [5]:
# Load the CLIP model together with its matching image preprocessing pipeline.
clip_backbone = "ViT-B/32"
device = "cuda" if torch.cuda.is_available() else "cpu"
# Pass `device` explicitly so the model is guaranteed to live on the same device
# that images/text are moved to in the evaluation loop; relying on clip.load's
# own default would silently diverge if `device` is ever changed here.
model, image_transform = clip.load(clip_backbone, device=device, jit=False)
# Cast to float32 so features are computed in full precision regardless of the
# dtype the loaded weights come in.
model = model.to(dtype=torch.float32)

input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Input resolution: 224
Context length: 77
Vocab size: 49408


In [26]:
# Prompt templates under comparison. The "{}" slot is presumably filled with the
# Food-101 class name by TextTransformer — TODO confirm.
# NOTE(review): the saved scores (templates_score.csv) also rank
# 'a photo of a tasty {}, a type of food.', which is not in this list, so the
# stored results do not come from this exact list.
templates = [
    'a photo of {}, a type of food.',
    'a cropped photo of the {}, a type of food.',
    'a close-up photo of a {}, a type of food.',
    'a photo of a delicious {}, a type of food.',
    'a photo of the small {}, a type of food.',
    'a photo of the large {}, a type of food.',
]

tokenizer = SimpleTokenizer()
# Build one single-template transformer per candidate so that each template is
# evaluated in its own pass over the dataset.
text_transformers = [
    TextTransformer(tokenizer, [single_template], context_length)
    for single_template in templates
]

In [33]:
dataset_root = "data/food-101/images"
BATCH_SIZE = 32
# Seed for the evaluation shuffle. Each loader gets its *own* generator seeded
# with the same value, so every template is scored on exactly the same sequence
# of batches. The metric below contrasts each image only against the captions
# in its batch, so without this the per-template accuracies would also differ
# by random batch composition, not just by prompt.
SHUFFLE_SEED = 0

dataloaders = [
    DataLoader(
        FoodDataset(dataset_root,
                    image_transform=image_transform,
                    prompt_transform=text_transformer,
                    return_indices=False),
        batch_size=BATCH_SIZE,
        # Shuffling is kept on: with in-batch retrieval, an unshuffled (and
        # presumably class-sorted — TODO confirm) dataset would give batches of
        # near-identical captions.
        shuffle=True,
        generator=torch.Generator().manual_seed(SHUFFLE_SEED),
    )
    for text_transformer in text_transformers
]

In [34]:
def accuracy(output, target, topk=(1,)):
    """Count how many rows of ``output`` rank their ``target`` column in the top-k.

    Args:
        output: 2-D score tensor of shape (batch, num_candidates).
        target: 1-D tensor of length ``batch`` holding the correct column index
            for each row.
        topk: the k values to evaluate.

    Returns:
        A list of Python floats, one per entry of ``topk``: the number of rows
        (not a percentage) whose target appears among the k highest scores.
        When k exceeds ``num_candidates`` every row counts as a hit.
    """
    # Clamp so small batches (fewer candidates than max(topk)) do not make
    # Tensor.topk raise; rows beyond the clamp simply contain every candidate.
    maxk = min(max(topk), output.size(1))
    # (batch, maxk) indices of the best-scoring columns -> transpose to (maxk, batch)
    pred = output.topk(maxk, 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() yields a plain float without NumPy's deprecated array->scalar cast.
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]

In [None]:
# Score every template by in-batch image <-> caption retrieval accuracy.
with torch.no_grad():
    tops1 = []  # top-1 accuracy (%) per template, in `templates` order
    tops5 = []  # top-5 accuracy (%) per template, in `templates` order
    for template_id, loader in enumerate(dataloaders):
        top1, top5, n = 0., 0., 0.
        for images, text in tqdm(loader, desc=f"template {template_id}"):
            images = images.to(device)
            # NOTE(review): assumes each sample's prompt tensor is
            # (num_templates=1, context_length); keep the single template.
            text = text[:, 0, :].to(device)
            # Image i is paired with caption i, so the correct column for
            # row i of the logits is i.
            # NOTE(review): two images of the same class in one batch get
            # identical captions, which makes those rows ambiguous.
            target = torch.arange(len(images), device=device)

            image_features = model.visual(images)
            text_features = model.encode_text(text)

            # L2-normalize so the dot product is cosine similarity; the epsilon
            # guards against division by zero.
            image_features = image_features / (image_features.norm(dim=-1, keepdim=True) + 1e-6)
            text_features = text_features / (text_features.norm(dim=-1, keepdim=True) + 1e-6)

            # 100 acts as the softmax temperature scale used for CLIP logits.
            logits = 100. * image_features @ text_features.T

            # measure accuracy (counts of correct rows, not percentages)
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)

        tops1.append((top1 / n) * 100)
        tops5.append((top5 / n) * 100)

        # Report inside the loop so every template is printed, not only the last.
        print(f'Template: {template_id}')
        print(f"Top-1 accuracy: {tops1[-1]:.2f}")
        print(f"Top-5 accuracy: {tops5[-1]:.2f}")

In [None]:
# `np` was previously used here without ever being imported, so this cell
# failed on a fresh kernel.
import numpy as np
import pandas as pd

# Rank templates by top-1 accuracy, best first.
ranking = np.argsort(tops1)[::-1]
sorted_templates = [(templates[i], tops1[i], tops5[i]) for i in ranking]
df_templates = pd.DataFrame(sorted_templates,
                            index=np.arange(1, len(sorted_templates) + 1),  # 1-based rank
                            columns=['Template', 'Top-1 Accuracy', 'Top-5 Accuracy'])

# The rank index is not written; row order in the CSV encodes the ranking.
df_templates.to_csv("templates_score.csv", index=False)
df_templates

In [44]:
import pandas as pd

# Reload the saved template ranking so results can be inspected without
# re-running the evaluation.
scores_path = "templates_score.csv"
df = pd.read_csv(scores_path)
df

Unnamed: 0,Template,Top-1 Accuracy,Top-5 Accuracy
0,"a photo of a delicious {}, a type of food.",78.173267,98.625743
1,"a photo of {}, a type of food.",78.109901,98.582178
2,"a close-up photo of a {}, a type of food.",78.005941,98.526733
3,"a photo of a tasty {}, a type of food.",77.959406,98.652475
4,"a cropped photo of the {}, a type of food.",77.893069,98.59505
5,"a photo of the small {}, a type of food.",77.543564,98.565347
6,"a photo of the large {}, a type of food.",76.593069,98.427723
