In [1]:
import os
import clip
import torch
from torchvision import transforms, models

import numpy as np
from sklearn.linear_model import LogisticRegression
from torch.utils.data import DataLoader
from tqdm import tqdm

import argparse
from omegaconf import OmegaConf

import json

from datasets import *
from clip_model_comparison import *

# Path string for results (presumably an output directory; it is not used in
# the visible cells).
results_path = "results"

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

# Load the ViT-B/32 CLIP checkpoint together with its image preprocessing
# transform, placing the model on the selected device.
clip_model, clip_preprocess = clip.load("ViT-B/32", device)

import torch.nn as nn
import torch.optim as optim

def get_clip_features(dataset):
    """Encode every batch of ``dataset`` with the module-level CLIP model.

    Args:
        dataset: Iterable yielding ``(images, labels)`` batches. ``images``
            must support ``.to(device)``; ``labels`` must be tensors that
            ``torch.cat`` can concatenate.

    Returns:
        A ``(features, labels)`` tuple in which the per-batch outputs of
        ``clip_model.encode_image`` and the per-batch labels have each been
        concatenated along the first dimension. Labels are not moved to
        ``device``; they stay wherever the loader produced them.
    """
    feature_batches, label_batches = [], []

    # Inference only: no autograd graph is needed for feature extraction.
    with torch.no_grad():
        for batch_images, batch_labels in tqdm(dataset):
            batch_images = batch_images.to(device)
            feature_batches.append(clip_model.encode_image(batch_images))
            label_batches.append(batch_labels)

    features = torch.cat(feature_batches)
    labels = torch.cat(label_batches)
    return features, labels

cuda


In [2]:
# Build the Flowers102 dataset wrapper (presumably provided by one of the
# star imports above).
# NOTE(review): the meaning of the positional arguments (4, 100) is not visible
# in this notebook — confirm against the Flowers102 definition.
dataset_obj = Flowers102(4, 100)
# get_train_loaders returns two values; only the first is kept. Images go
# through CLIP's own preprocessing so they match what the image encoder expects.
# NOTE(review): num_elements_per_class=10 presumably caps training samples per class — confirm.
train_loader, _ = dataset_obj.get_train_loaders(transform_fn=clip_preprocess,num_elements_per_class=10)
test_loader = dataset_obj.get_test_loader(transform_fn=clip_preprocess)
# Class names exposed by the dataset wrapper; used below only for their count.
classes = dataset_obj.classes

In [3]:
def text_zeroshot_classifier(classnames, templates):
    """Build zero-shot classifier weights from CLIP text embeddings.

    For each class name, every template is formatted with the name, tokenized
    and embedded with ``clip_model.encode_text``. The per-template embeddings
    are L2-normalised, averaged, and the average is re-normalised to give one
    embedding per class.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of format strings with a single ``{}`` placeholder
            that receives the class name.

    Returns:
        Tensor on ``device`` in which the per-class embeddings are stacked
        along dimension 1 (one column per class).
    """
    with torch.no_grad():
        zeroshot_weights = []
        for classname in tqdm(classnames):
            # One prompt per template for this class.
            texts = [template.format(classname) for template in templates]
            # Use the notebook-wide `device` instead of a hard-coded .cuda()
            # so the function also runs on CPU-only hosts.
            texts = clip.tokenize(texts).to(device)
            class_embeddings = clip_model.encode_text(texts)
            # Normalise each prompt embedding before averaging so every
            # template contributes equally.
            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)
            class_embedding = class_embeddings.mean(dim=0)
            # Re-normalise: the mean of unit vectors is not unit length.
            class_embedding /= class_embedding.norm()
            zeroshot_weights.append(class_embedding)
        zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(device)
    return zeroshot_weights



In [4]:
# Linear probe: fit a logistic-regression classifier on L2-normalised CLIP
# image features and report its test accuracy in percent.
len_classes = len(classes)

train_features, train_labels = get_clip_features(train_loader)
test_features, test_labels = get_clip_features(test_loader)

# Scale every feature vector to unit length.
train_features = train_features / train_features.norm(dim=-1, keepdim=True)
test_features = test_features / test_features.norm(dim=-1, keepdim=True)

classifier = LogisticRegression(C=1, max_iter=1000, n_jobs=6)
classifier.fit(train_features.cpu().numpy(), train_labels.cpu().numpy())
predictions = classifier.predict(test_features.cpu().numpy())
# `np.float` was removed in NumPy 1.24; `np.float64` is the dtype it aliased.
# NOTE: a later cell rebinds the name `accuracy` to a function.
accuracy = np.mean((test_labels.cpu().numpy() == predictions).astype(np.float64)) * 100.0

print(accuracy)



100%|██████████| 11/11 [00:09<00:00,  1.10it/s]
100%|██████████| 9/9 [00:06<00:00,  1.38it/s]


83.15018315018315


In [5]:
# Show the shape of the fitted coefficient matrix (recorded output: (102, 512)).
print(classifier.coef_.shape)

(102, 512)


In [6]:
# Transpose the logistic-regression coefficients and cast them to float16 so
# they can be multiplied against CLIP image features in a later cell.
# Only `coef_` is used here; the classifier's intercept is not carried over.
zeroshot_weights = torch.from_numpy(classifier.coef_.T).to(torch.float16)
# Display the resulting shape (recorded output: torch.Size([512, 102])).
zeroshot_weights.shape

torch.Size([512, 102])

In [4]:
# Use the transposed logistic-regression coefficients, cast to float16, as the
# weight matrix for the evaluation loop below.
zeroshot_weights = torch.from_numpy(classifier.coef_.T).to(torch.float16)
# Only `coef_` is used; the classifier's intercept is not carried over.

def accuracy(output, target, topk=(1,)):
    """Count top-k correct predictions for a batch.

    Args:
        output: 2-D score tensor with one row per sample and one column per
            class.
        target: 1-D tensor holding the true class index of each sample.
        topk: Iterable of k values to evaluate.

    Returns:
        List of Python floats, one per entry in ``topk``, giving the NUMBER
        of samples (not a percentage) whose true class is among the k
        highest-scoring classes.
    """
    # Indices of the max(topk) highest scores per sample, transposed so that
    # row r holds every sample's rank-r prediction.
    pred = output.topk(max(topk), 1, True, True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # `.item()` yields a Python float directly; calling float() on a
    # one-element ndarray is deprecated in recent NumPy releases.
    return [correct[:k].reshape(-1).float().sum().item() for k in topk]

# Lazy load: only (re)load CLIP when no model is currently bound.
# NOTE(review): `clip_model_name` is not defined in the visible cells —
# presumably it comes from one of the star imports; confirm.
if clip_model is None:
    clip_model, clip_preprocess = clip.load(clip_model_name, device)

# Evaluate `zeroshot_weights` as a linear classifier over normalised CLIP
# image features on the test set.
with torch.no_grad():
    top1, top5, n = 0.0, 0.0, 0.0
    # Weights are moved to the features' device/dtype once, on the first
    # batch, instead of being re-transferred on every iteration. Matching the
    # feature dtype also keeps the matmul valid when features are not float16.
    eval_weights = None
    for images, target in tqdm(test_loader):
        # Use the notebook-wide `device` rather than a hard-coded .cuda().
        images = images.to(device)
        target = target.to(device)

        # predict
        image_features = clip_model.encode_image(images)
        image_features /= image_features.norm(dim=-1, keepdim=True)
        if eval_weights is None:
            eval_weights = zeroshot_weights.to(
                device=image_features.device, dtype=image_features.dtype
            )
        logits = 100.0 * image_features @ eval_weights

        # measure accuracy (accuracy() returns counts of correct samples)
        acc1, acc5 = accuracy(logits, target, topk=(1, 5))
        top1 += acc1
        top5 += acc5
        n += images.size(0)

# Convert the accumulated counts to percentages.
top1 = (top1 / n) * 100
top5 = (top5 / n) * 100

print(top1)

100%|██████████| 9/9 [00:07<00:00,  1.25it/s]

84.004884004884



