In [1]:
import sys
sys.path.append("../")

from PIL import Image

import torch
import torchvision
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from simple_clip import CLIP
from simple_clip.utils import accuracy, get_image_encoder, get_text_encoder
from simple_clip.custom_datasets.clip_datasets import get_image_tranforms
from simple_clip.imagenet_eval import ImageNetValidation, ImageNetDataset

In [2]:
# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
from datasets import load_dataset

# Data config: tune these values instead of editing the calls below.
IMAGE_SIZE = 224
N_TRAIN_SAMPLES = 50_000  # rows 0..N-1 of the train split are used as the "train" eval set
BATCH_SIZE = 256
NUM_WORKERS = 4

# NOTE(review): the same transform is used for train and val. If get_image_tranforms
# includes random augmentation, the reported accuracies are noisy; confirm it is
# deterministic, or switch to an eval-only transform.
transform = get_image_tranforms(image_size=(IMAGE_SIZE, IMAGE_SIZE))

dataset = load_dataset("imagenet-1k")
# Raw Hugging Face splits; `select` accepts a range directly, so no list is materialized.
train_hf = dataset["train"].select(range(N_TRAIN_SAMPLES))
val_hf = dataset["validation"]

# Wrapped splits the DataLoaders iterate over. Batches are unpacked as (images, targets)
# downstream; presumably ImageNetDataset applies `transform` to each image (TODO confirm).
train_ds = ImageNetDataset(train_hf, transform)
val_ds = ImageNetDataset(val_hf, transform)

# shuffle stays False (the default): this notebook only extracts embeddings and scores
# them, so sample order does not affect any metric.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [4]:
# Class names in label-id order (LABELS[i] is the name of class i). They are read from
# the dataset's ClassLabel feature rather than hard-coding 1000 classes.
LABELS = list(dataset["validation"].features["label"].names)

# One zero-shot prompt per class; texts[i] describes class i.
# NOTE(review): these names appear to be comma-separated synonym lists (the report cell
# keeps only the text before the first comma), so each prompt contains every synonym.
# Prompting with just the first synonym is a common alternative worth comparing.
texts = [f"a photo of a {label}" for label in LABELS]

In [5]:
# Sanity check: number of samples in the (train subset, validation) splits.
tuple(map(len, (train_ds, val_ds)))

(50000, 50000)

In [6]:
# Rebuild the CLIP architecture the checkpoint was trained with (image tower "resnet50",
# text tower "distilbert-base-uncased"). load_state_dict is strict by default, so any
# architecture mismatch raises instead of silently loading partial weights.
image_encoder = get_image_encoder("resnet50")
text_encoder = get_text_encoder("distilbert-base-uncased")
model = CLIP(image_encoder, text_encoder)

CKPT_PATH = "../models/clip_model_best.pth"
# map_location remaps tensors saved on a GPU onto `device`. Without it, torch.load
# raises on a CPU-only machine, which is exactly the fallback the device cell allows.
# NOTE(review): torch.load unpickles the file, so only load trusted checkpoints;
# on torch>=1.13 consider weights_only=True.
ckpt = torch.load(CKPT_PATH, map_location=device)

model.load_state_dict(ckpt)
# eval() disables training-only behaviour (e.g. dropout) before inference.
model = model.eval().to(device)

In [7]:
from tqdm.auto import tqdm
import numpy as np

def get_embs_labels(dl, clip_model=None, target_device=None):
    """Embed every image yielded by a loader and collect the matching labels.

    Args:
        dl: iterable of ``(images, targets)`` batches, e.g. a ``DataLoader``.
        clip_model: object exposing ``extract_image_features(images)``. Defaults to
            the notebook-global ``model``, so existing ``get_embs_labels(loader)``
            calls behave as before.
        target_device: device the images are sent to and the results are returned on.
            Defaults to the notebook-global ``device``.

    Returns:
        ``(embeddings, labels)``. The embeddings are float32 image features, one row
        per sample, concatenated in loader order. The labels are concatenated in the
        same order. Both live on ``target_device``. An empty loader yields two empty
        1-D tensors.
    """
    # Conditional expressions are lazy, so the globals are only touched when needed.
    encoder = model if clip_model is None else clip_model
    dev = device if target_device is None else target_device

    emb_chunks, label_chunks = [], []
    # Pure inference: one no_grad scope around the whole loop, nothing needs gradients.
    with torch.no_grad():
        for images, targets in tqdm(dl):
            feats = encoder.extract_image_features(images.to(dev))
            # Park each batch on the CPU so accelerator memory does not grow with the
            # dataset. Chunks are concatenated once at the end, which is far cheaper
            # than converting every float to a Python object and back.
            emb_chunks.append(feats.cpu())
            label_chunks.append(targets.cpu())

    if not emb_chunks:
        return torch.empty(0, device=dev), torch.empty(0, dtype=torch.long, device=dev)
    # .float() keeps the historical float32 return dtype regardless of model precision.
    return torch.cat(emb_chunks).float().to(dev), torch.cat(label_chunks).to(dev)

In [8]:
# Embed both splits with the frozen image tower (train first, then validation).
(image_features, labels), (image_features_val, labels_val) = [
    get_embs_labels(loader) for loader in (train_loader, val_loader)
]

  0%|          | 0/196 [00:00<?, ?it/s]

  0%|          | 0/196 [00:00<?, ?it/s]

In [9]:
from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenize all class prompts as one batch. padding=True pads to the longest prompt,
# and attention_mask marks real tokens versus padding.
encoded_texts = tokenizer(texts, padding=True, truncation=True, max_length=100)
input_ids = torch.tensor(encoded_texts["input_ids"]).to(device)
attention_mask = torch.tensor(encoded_texts["attention_mask"]).to(device)

# Pure inference: without no_grad, autograd records the whole text-encoder forward pass
# and keeps its activations alive, and every op on text_features would keep extending
# that graph.
with torch.no_grad():
    text_features = model.extract_text_features(input_ids, attention_mask)

In [10]:
# Zero-shot logits: entry [i, j] scores image i against class prompt j, so argmax over
# the last dim is the predicted class. no_grad keeps these large (N x num_classes)
# matrices free of any autograd graph, even if text_features was built with grad enabled.
# NOTE(review): this is a raw dot product. It equals cosine similarity only if the
# extract_*_features methods L2-normalize their outputs; confirm in simple_clip.
with torch.no_grad():
    preds_train = image_features @ text_features.t()
    preds_val = image_features_val @ text_features.t()

### Results

In [11]:
from sklearn.metrics import accuracy_score, classification_report
 
# sklearn wants plain Python sequences, so pull the ground truth off the device first.
# NOTE: "test" in y_test refers to the ImageNet *validation* split.
y_test = labels_val.cpu().detach().tolist()
y_train = labels.cpu().detach().tolist()

# Top-1 zero-shot prediction is the index of the best-scoring class prompt.
y_pred, y_pred_train = (
    logits.argmax(dim=-1).cpu().detach().tolist() for logits in (preds_val, preds_train)
)

# Report validation accuracy first, then train; `acc` ends up holding the train value.
for caption, truth, guess in (
    ("Accuracy Val: ", y_test, y_pred),
    ("Accuracy Train: ", y_train, y_pred_train),
):
    acc = accuracy_score(truth, guess)
    print(caption, acc)

Accuracy Val:  0.36476
Accuracy Train:  0.38378


In [12]:
# Short display names: keep only the text before the first comma of each class name.
labels_trunc = [name.split(",")[0] for name in LABELS]

# labels= pins the rows to every class id, in order, so target_names stays aligned even
# if some class is missing from both y_test and y_pred. zero_division=0 makes explicit
# the 0 that sklearn already reports for never-predicted classes, and suppresses the
# UndefinedMetricWarning flood.
class_report = classification_report(
    y_test,
    y_pred,
    labels=list(range(len(labels_trunc))),
    target_names=labels_trunc,
    zero_division=0,
)
print("Classification report: \n", class_report)

Classification report: 
                                 precision    recall  f1-score   support

                         tench       0.00      0.00      0.00        50
                      goldfish       0.51      0.92      0.65        50
             great white shark       0.33      0.74      0.46        50
                   tiger shark       0.36      0.40      0.38        50
                    hammerhead       0.22      0.14      0.17        50
                  electric ray       0.08      0.02      0.03        50
                      stingray       0.39      0.70      0.50        50
                          cock       0.62      0.68      0.65        50
                           hen       0.56      0.70      0.62        50
                       ostrich       0.98      0.86      0.91        50
                     brambling       0.25      0.02      0.04        50
                     goldfinch       0.82      0.98      0.89        50
                   house finch       0

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
