### CLIP as a zero-shot classifier
___

In [1]:
import os
import clip
import torch
import torchvision
import numpy as np
from tqdm import tqdm

In [2]:
# Prefer a CUDA GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Device: {device}')

if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")

# The torch version does not depend on the device, so report it on CPU-only machines as well.
print(f"Torch version: {torch.__version__}")

Device: cuda
GPU: Tesla K80
Torch version: 1.10.0


In [3]:
# Display the model names returned by clip.available_models(); one of them is passed to clip.load() below.
clip.available_models()

['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']

In [19]:
# Load the RN101 CLIP model together with its image preprocessing transform.
# The device is passed explicitly so the model lives on the same device that was
# selected (and printed) earlier, instead of relying on clip.load's own default.
model, preprocess = clip.load("RN101", device=device)

In [20]:
# Summarise the loaded model: total parameter count and the sizes it exposes.
input_resolution = model.visual.input_resolution
context_length, vocab_size = model.context_length, model.vocab_size

# numel() gives the element count of each parameter tensor.
param_count = sum(p.numel() for p in model.parameters())

for label, value in (
    ("Model parameters:", f"{param_count:,}"),
    ("Input resolution:", input_resolution),
    ("Context length:", context_length),
    ("Vocab size:", vocab_size),
):
    print(label, value)

Model parameters: 119,688,033
Input resolution: 224
Context length: 77
Vocab size: 49408


In [21]:
dataset_root = "data/food-101/images"

# Only sub-directories are classes. Skipping stray files (e.g. hidden OS metadata)
# keeps the class list, and therefore the class indices, in step with the labels
# produced by torchvision's ImageFolder, which is built from the same root below.
class_dirs = sorted(
    name for name in os.listdir(dataset_root)
    if os.path.isdir(os.path.join(dataset_root, name))
)
# Directory names use underscores; the text prompts read better with spaces.
class_names = [name.replace('_', ' ') for name in class_dirs]
class_to_idx = {name: idx for idx, name in enumerate(class_names)}

# Prompt templates; each '{}' is filled with a class name.
templates = ['a photo of {}, a type of food.']

print(f"{len(class_names)} classes, {len(templates)} templates")

101 classes, 1 templates


In [22]:
# Every image under dataset_root, labelled by its sub-directory and run through CLIP's preprocessing.
images = torchvision.datasets.ImageFolder(root=dataset_root, transform=preprocess)
# os.cpu_count() can return None when the count is undetermined; fall back to
# loading in the main process (num_workers=0) rather than passing None to DataLoader.
loader = torch.utils.data.DataLoader(images, batch_size=32, num_workers=os.cpu_count() or 0)

In [23]:
def zeroshot_classifier(classnames, templates, device=None):
    """Build zero-shot classification weights from text prompts.

    For each class name, every template is formatted with that name, tokenized
    and encoded by the text encoder of the notebook-level `model`. The
    per-template embeddings are L2-normalised, averaged, and the mean is
    normalised again to give one embedding per class.

    Args:
        classnames: iterable of class-name strings substituted into the templates.
        templates: iterable of format strings with a single `{}` placeholder.
        device: device the token tensors are moved to before encoding. Defaults
            to the device holding the model's parameters, so the function also
            works when no GPU is available.

    Returns:
        Tensor of class embeddings stacked along dim=1 (one column per class),
        located on the device the text encoder ran on.
    """
    if device is None:
        # Follow the model rather than assuming CUDA.
        device = next(model.parameters()).device

    with torch.no_grad():
        per_class_embeddings = []
        for classname in tqdm(classnames):
            prompts = [template.format(classname) for template in templates]
            tokens = clip.tokenize(prompts).to(device)
            class_embeddings = model.encode_text(tokens)
            # Normalise each prompt embedding before averaging so every template contributes equally.
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            # The mean of unit vectors is generally not unit length; renormalise.
            class_embedding /= class_embedding.norm()
            per_class_embeddings.append(class_embedding)
        # The embeddings already sit on the model's device, so no further transfer is needed.
        zeroshot_weights = torch.stack(per_class_embeddings, dim=1)
    return zeroshot_weights

# Compute the per-class text embeddings once; the evaluation loop multiplies image features by this matrix.
zeroshot_weights = zeroshot_classifier(class_names, templates)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 101/101 [00:06<00:00, 15.85it/s]


In [24]:
def accuracy(output, target, topk=(1,)):
    """Count correct top-k predictions in a batch.

    Args:
        output: 2-D tensor of scores, one row per sample and one column per class.
        target: 1-D tensor of ground-truth class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        A list of floats, one per entry of `topk`, each giving the NUMBER of
        samples (not a ratio) whose target is among the k highest-scoring classes.
    """
    # Never ask topk() for more ranks than there are classes; with fewer classes
    # than k, every class is in the top k and slicing below simply takes all rows.
    max_k = min(max(topk), output.size(1))
    _, pred = output.topk(max_k, dim=1, largest=True, sorted=True)
    pred = pred.t()  # one row per rank, one column per sample
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # .item() yields a Python scalar directly, avoiding the deprecated
    # float(<1-element numpy array>) conversion.
    return [float(correct[:k].sum().item()) for k in topk]

In [25]:
# Zero-shot evaluation over the whole loader: accumulate the number of top-1 and
# top-5 hits per batch, then convert the totals to percentages.
with torch.no_grad():
    top1, top5, n = 0., 0., 0.
    # The batch gets its own name so the `images` dataset defined earlier is not overwritten.
    for image_batch, target in tqdm(loader):
        # Use the device selected at the top of the notebook instead of assuming CUDA.
        image_batch = image_batch.to(device)
        target = target.to(device)

        # predict
        image_features = model.encode_image(image_batch)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        logits = 100. * image_features @ zeroshot_weights

        # measure accuracy (accuracy() returns hit counts, not ratios)
        acc1, acc5 = accuracy(logits, target, topk=(1, 5))
        top1 += acc1
        top5 += acc5
        n += image_batch.size(0)

top1 = (top1 / n) * 100
top5 = (top5 / n) * 100

print(f"Top-1 accuracy: {top1:.2f}")
print(f"Top-5 accuracy: {top5:.2f}")

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3157/3157 [49:47<00:00,  1.06it/s]

Top-1 accuracy: 80.61
Top-5 accuracy: 95.86



