In [1]:
# Imports 
import os
import clip
import torch
from torchvision import transforms, models

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm

import argparse
from omegaconf import OmegaConf

import json

from datasets import *
device = "cuda" if torch.cuda.is_available() else "cpu"
from sklearn.metrics.pairwise import euclidean_distances, cosine_similarity
import seaborn as sn

from columnar import columnar
from nltk.corpus import wordnet as wn

In /nethome/bdevnani3/anaconda3/envs/p3/lib/python3.8/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The text.latex.preview rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /nethome/bdevnani3/anaconda3/envs/p3/lib/python3.8/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The mathtext.fallback_to_cm rcparam was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /nethome/bdevnani3/anaconda3/envs/p3/lib/python3.8/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: Support for setting the 'mathtext.fallback_to_cm' rcParam is deprecated since 3.3 and will be removed two minor releases later; use 'mathtext.fallback : 'cm' instead.
In /nethome/bdevnani3/anaconda3/envs/p3/lib/python3.8/site-packages/matplotlib/mpl-data/stylelib/_classic_test.mplstyle: 
The validate_bool_maybe_none function was deprecated in Matplotlib 3.3 and will be removed two minor releases later.
In /

In [2]:
def clip_zero_shot(
    loader,
    classes,
    zeroshot_weights,
    clip_model_name="ViT-B/32",
):

    global clip_model, clip_preprocess
    device = "cuda" if torch.cuda.is_available() else "cpu"

    def accuracy(output, target, topk=(1,)):
        pred = output.topk(max(topk), 1, True, True)[1].t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        return [
            float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy())
            for k in topk
        ]

    # lazy load
    if clip_model == None:
        clip_model, clip_preprocess = clip.load(clip_model_name, device)
        
    per_class_accuracy_top1 = { k:[0,0, classes[k]] for k in range(len(classes))} # correct, total, class_name
    per_class_accuracy_top5 = { k:[0,0, classes[k]] for k in range(len(classes))} 

    with torch.no_grad():
        top1, top5, n = 0.0, 0.0, 0.0
        for i , (images, target) in enumerate(tqdm(loader)):
            images = images.cuda()
            target = target.cuda()

            # predict
            image_features = clip_model.encode_image(images)
            image_features /= image_features.norm(dim=-1, keepdim=True)
            logits = 100.0 * image_features @ zeroshot_weights

            # measure accuracy
            acc1, acc5 = accuracy(logits, target, topk=(1, 5))
            top1 += acc1
            top5 += acc5
            n += images.size(0)
            
            per_class_accuracy_top1[target.cpu().detach().numpy()[0]][0]+= acc1
            per_class_accuracy_top1[target.cpu().detach().numpy()[0]][1]+= 1

    top1 = (top1 / n) * 100
    top5 = (top5 / n) * 100

    return top1, per_class_accuracy_top1

In [3]:
clip_model, clip_preprocess = clip.load("ViT-B/32", )

flowers = Flowers102(4, 1, '/nethome/bdevnani3/raid/data/')
f_train_loader, _ = flowers.get_train_loaders(transform_fn=clip_preprocess)
f_test_loader = flowers.get_test_loader(transform_fn=clip_preprocess)

pets = OxfordPets(4, 1, '/nethome/bdevnani3/raid/data/')
p_train_loader, _ = pets.get_train_loaders(transform_fn=clip_preprocess)
p_test_loader = pets.get_test_loader(transform_fn=clip_preprocess)

In [4]:
def per_class_performance(d, print_out=True):
    x = []
    labels = []
    corr = []
    tot = []
    for el in d:
        x.append((d[el][0]*100)/d[el][1])
        labels.append(d[el][2])
        corr.append(d[el][0])
        tot.append(d[el][1])
    idx = np.argsort(x)
    x = np.array(x)[idx]
    labels = np.array(labels)[idx]
    corr = np.array(corr)[idx]
    tot = np.array(tot)[idx]
    out = {}
    for l,per,c,t in zip(labels,x,corr,tot):
        out[l] = [np.around(per,5),int(c),int(t)]
    if print_out:
        table = columnar([[l, o[0], o[1], o[2]] for l,o in out.items()], ["Class Name", "Accuracy(%)", "Num Correct", "Total"])
        print(table)
    return out

In [5]:
def relative_per_class_performance(d1,d2, print_out=True):
    "Positive value means d1 did better"
    x1 = []
    labels = []
    x2 = []
    tot_num_labels_difference = 0
    for el in d1:
        x1.append(((d1[el][0]-d2[el][0])*100)/(d2[el][0]+0.0000001))
        labels.append(d1[el][2])
        num_labels_difference = d1[el][0]
        num_labels_difference -= d2[el][0]
        x2.append(num_labels_difference)
        tot_num_labels_difference += num_labels_difference
        
#         x2.append((d2[el][0]*100)/d2[el][1])
        
#     final = np.array(x1)-np.array(x2)
    idx = np.argsort(x1)
    x1 = np.array(x1)[idx]
    x2 = np.array(x2)[idx]
    labels = np.array(labels)[idx]
    out = {}
    for l,per,num_per in zip(labels,x1,x2):
        if per > 100:
            per = 100
        out[l] = np.around(per,decimals=5), num_per
    if print_out:
        table = columnar([[x,out[x][0],out[x][1]] for x in out], ["Class Name", "Relative Accuracy(%)", "Difference in Num of Labels"])
        print(table)
    return out, tot_num_labels_difference

### Main Goal
In this notebook, we will try to understand the features of a good text prompt. We will be using the flowers and the pets dataset to do so.

In [6]:
def zeroshot_classifier_baseline(classnames, templates):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            texts = [
                template.format(classname) for template in templates
            ]  # format with class
            texts = clip.tokenize(texts).cuda()  # tokenize
            class_embeddings = clip_model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights


In [7]:
def per_class_performance(d, print_out=True):
    x = []
    labels = []
    corr = []
    tot = []
    for el in d:
        x.append((d[el][0]*100)/d[el][1])
        labels.append(d[el][2])
        corr.append(d[el][0])
        tot.append(d[el][1])
    idx = np.argsort(x)
    x = np.array(x)[idx]
    labels = np.array(labels)[idx]
    corr = np.array(corr)[idx]
    tot = np.array(tot)[idx]
    out = {}
    for l,per,c,t in zip(labels,x,corr,tot):
        out[l] = [np.around(per,5),int(c),int(t)]
    if print_out:
        table = columnar([[l, o[0], o[1], o[2]] for l,o in out.items()], ["Class Name", "Accuracy(%)", "Num Correct", "Total"])
        print(table)
    return out

### Baseline

In [8]:
# Performance with just class names


templates = ["{}"]

classes = [c.replace("_", " ")for c in pets.classes]

pets_baseline_zw = zeroshot_classifier_baseline(classes,templates)

pets_baseline_czs = clip_zero_shot(
    p_test_loader,
    classes,
    pets_baseline_zw
)
print("Pets baseline performance: ",pets_baseline_czs[0])
per_class_performance(pets_baseline_czs[1])

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/3669 [00:00<?, ?it/s]

Pets baseline performance:  82.33851185609157
|--------------------------|-----------|-----------|-----|
|Class Name                |Accuracy(%)|Num Correct|Total|
|Persian                   |0.0        |0          |100  |
|--------------------------|-----------|-----------|-----|
|Ragdoll                   |9.0        |9          |100  |
|--------------------------|-----------|-----------|-----|
|newfoundland              |11.0       |11         |100  |
|--------------------------|-----------|-----------|-----|
|boxer                     |40.40404   |40         |99   |
|--------------------------|-----------|-----------|-----|
|Bengal                    |49.0       |49         |100  |
|--------------------------|-----------|-----------|-----|
|american pit bull terrier |72.0       |72         |100  |
|--------------------------|-----------|-----------|-----|
|staffordshire bull terrier|76.40449   |68         |89   |
|--------------------------|-----------|-----------|-----|
|British S

{'Persian': [0.0, 0, 100],
 'Ragdoll': [9.0, 9, 100],
 'newfoundland': [11.0, 11, 100],
 'boxer': [40.40404, 40, 99],
 'Bengal': [49.0, 49, 100],
 'american pit bull terrier': [72.0, 72, 100],
 'staffordshire bull terrier': [76.40449, 68, 89],
 'British Shorthair': [79.0, 79, 100],
 'chihuahua': [79.0, 79, 100],
 'english cocker spaniel': [81.0, 81, 100],
 'beagle': [85.0, 85, 100],
 'Birman': [87.0, 87, 100],
 'Maine Coon': [88.0, 88, 100],
 'Sphynx': [90.0, 90, 100],
 'Egyptian Mau': [90.72165, 88, 97],
 'Abyssinian': [90.81633, 89, 98],
 'havanese': [92.0, 92, 100],
 'Siamese': [93.0, 93, 100],
 'great pyrenees': [93.0, 93, 100],
 'wheaten terrier': [94.0, 94, 100],
 'american bulldog': [95.0, 95, 100],
 'basset hound': [95.0, 95, 100],
 'saint bernard': [95.0, 95, 100],
 'keeshond': [95.9596, 95, 99],
 'english setter': [96.0, 96, 100],
 'japanese chin': [96.0, 96, 100],
 'leonberger': [96.0, 96, 100],
 'Russian Blue': [96.0, 96, 100],
 'german shorthaired': [97.0, 97, 100],
 'pome

### Simple sentences

In [9]:
templates = ["{}",
             "a {}",
             "is a {}",
             "This is a {}",
             "This is a {}.",
             "a is this {}.",
             "is a this {}.",
             "That is a {}.",
             "That is a {} which is a very wonderful and cute pet to have, we love it very much.",
             "This is a photo of a {}", 
             "This is a photo of a {}, a pet", 
             "house pet cute {}"]
out = []
for template in templates:
    
    t = [template]

    pets_zw = zeroshot_classifier_baseline(classes,t)

    pets_czs = clip_zero_shot(
        p_test_loader,
        classes,
        pets_zw
    )
    out.append(pets_czs[0])
    
for t,o in zip(templates,out):
    print(f"{t}: {o}")

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/3669 [00:00<?, ?it/s]

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/3669 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [10]:
def zeroshot_classifier_1(classnames, templates):
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            if classname[0].isupper():
                animal = "dog"
            else:
                animal = "cat"
            texts = [
                template.format(classname, animal) for template in templates
            ]  # format with class
            texts = clip.tokenize(texts).cuda()  # tokenize
            class_embeddings = clip_model.encode_text(texts)  # embed with text encoder
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda()
    return zeroshot_weights
templates = ["This is a {}, a pet {}"]
out = []
for template in templates:
             
    
    t = [template]

    pets_zw = zeroshot_classifier_1(classes,t)

    pets_czs = clip_zero_shot(
        p_test_loader,
        classes,
        pets_zw
    )
    out.append(pets_czs[0])
    
for t,o in zip(templates,out):
    print(f"{t}: {o}")

  0%|          | 0/37 [00:00<?, ?it/s]

  0%|          | 0/3669 [00:00<?, ?it/s]

This is a {}, a pet {}: 86.91741618969746
