In [1]:
import sys
sys.path.append("../")

from PIL import Image

import torch
import torchvision
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset

from simple_clip import CLIP
from simple_clip.utils import accuracy, get_image_encoder, get_text_encoder
from simple_clip.custom_datasets.clip_datasets import get_image_tranforms

In [2]:
# Run on the first visible CUDA GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [3]:
IMAGE_SIZE = 224
# Settings shared by both splits; tune them here rather than per loader.
DATA_ROOT = "../data"
BATCH_SIZE = 256
NUM_WORKERS = 4

transform = get_image_tranforms(image_size=(IMAGE_SIZE, IMAGE_SIZE))

train_ds = torchvision.datasets.STL10(DATA_ROOT,
                                      split='train',
                                      transform=transform,
                                      download=True)
val_ds = torchvision.datasets.STL10(DATA_ROOT,
                                    split='test',
                                    transform=transform,
                                    download=True)

# shuffle=False is the DataLoader default, made explicit here: these loaders are
# only used for feature extraction, so keeping dataset order makes the extracted
# embeddings reproducible across re-runs.
train_loader = DataLoader(train_ds,
                          batch_size=BATCH_SIZE,
                          shuffle=False,
                          num_workers=NUM_WORKERS)
val_loader = DataLoader(val_ds,
                        batch_size=BATCH_SIZE,
                        shuffle=False,
                        num_workers=NUM_WORKERS)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Class names used to build the text prompts. Their list position is later used
# as the class index (argmax over prompts, target_names in the reports), so the
# order must match the dataset's integer labels.
STL_LABELS = [
    "airplane", "bird", "car", "cat", "deer",
    "dog", "horse", "monkey", "ship", "truck",
]

# One natural-language prompt per class, in the same order as STL_LABELS.
texts = ["a photo of a " + label for label in STL_LABELS]

In [5]:
# Number of examples in each split, shown as (train, test).
tuple(map(len, (train_ds, val_ds)))

(5000, 8000)

In [6]:
image_encoder = get_image_encoder("resnet50")
text_encoder = get_text_encoder("distilbert-base-uncased")
model = CLIP(image_encoder, text_encoder)

# map_location remaps stored tensors onto the selected device. Without it, a
# checkpoint saved from a GPU cannot be loaded when `device` fell back to "cpu".
ckpt = torch.load("../models/clip_model_best.pth", map_location=device)

model.load_state_dict(ckpt)
# eval() puts layers such as dropout/batch-norm into inference mode before use.
model = model.eval().to(device)

In [7]:
from tqdm.auto import tqdm
import numpy as np

def get_embs_labels(dl, encoder=None, target_device=None, progress=None):
    """Extract image embeddings and their targets for every batch of a loader.

    Args:
        dl: iterable yielding ``(images, targets)`` batches, e.g. a
            ``torch.utils.data.DataLoader``.
        encoder: object exposing ``extract_image_features(images)``. Defaults
            to the notebook-level ``model``.
        target_device: device the images are moved to and the results are
            returned on. Defaults to the notebook-level ``device``.
        progress: callable wrapping ``dl`` for progress reporting. Defaults to
            ``tqdm``; pass ``lambda it: it`` to disable the progress bar.

    Returns:
        ``(embs, labels)``: a float32 tensor of the per-batch features stacked
        along dim 0 in loader order, and an int64 tensor of the matching
        targets, both on ``target_device``. An empty loader yields two empty
        tensors.
    """
    if encoder is None:
        encoder = model
    if target_device is None:
        target_device = device
    if progress is None:
        progress = tqdm

    emb_chunks, label_chunks = [], []
    # Inference only: no autograd graph is needed anywhere in the loop.
    with torch.no_grad():
        for images, targets in progress(dl):
            feats = encoder.extract_image_features(images.to(target_device))
            # Keep per-batch tensors on the CPU and concatenate once at the end,
            # instead of round-tripping every value through Python lists.
            emb_chunks.append(feats.detach().cpu())
            label_chunks.append(torch.as_tensor(targets).cpu())

    if not emb_chunks:
        # Empty loader: same result as torch.tensor([]) for both outputs.
        empty = torch.tensor([])
        return empty.to(target_device), empty.clone().to(target_device)

    # .float()/.long() keep the dtypes the former list round-trip produced.
    embs = torch.cat(emb_chunks).float()
    label_tensor = torch.cat(label_chunks).long()
    return embs.to(target_device), label_tensor.to(target_device)

In [8]:
# Embed every image of each split once (train first, then test); the cached
# features are reused by all later cells.
(image_features, labels), (image_features_val, labels_val) = (
    get_embs_labels(train_loader),
    get_embs_labels(val_loader),
)

  0%|          | 0/20 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

In [9]:
print(image_features.shape)
print(labels.shape)
print(image_features_val.shape)
print(labels_val.shape)

# Sanity checks: one label per embedding row, and both splits share the embedding width.
assert image_features.shape[0] == labels.shape[0], "train features/labels length mismatch"
assert image_features_val.shape[0] == labels_val.shape[0], "val features/labels length mismatch"
assert image_features.shape[1:] == image_features_val.shape[1:], "train/val embedding width mismatch"

torch.Size([5000, 768])
torch.Size([5000])
torch.Size([8000, 768])
torch.Size([8000])


In [10]:
from transformers import DistilBertTokenizer

tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenize all class prompts as a single padded batch.
encoded_texts = tokenizer(texts, padding=True, truncation=True, max_length=256)
input_ids = torch.tensor(encoded_texts["input_ids"]).to(device)
attention_mask = torch.tensor(encoded_texts["attention_mask"]).to(device)
# Inference only. Without no_grad the text encoder forward pass records an
# autograd graph (unless its parameters are frozen elsewhere), and
# text_features, plus every score computed from it, would track gradients.
with torch.no_grad():
    text_features = model.extract_text_features(input_ids, attention_mask)

In [11]:
# Zero-shot scores: entry [i, j] is the dot product of image i's embedding with
# the prompt embedding for STL_LABELS[j]. This is a cosine similarity only if the
# extract_*_features methods L2-normalise their outputs (not visible here).
preds_train = torch.matmul(image_features, text_features.T)
preds_val = torch.matmul(image_features_val, text_features.T)

### ResNet-50 image encoder

In [14]:
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
 
# Pin the label set explicitly so the matrix and report always have one row per
# STL_LABELS entry. If a class never occurred in y_true or y_pred, sklearn would
# otherwise infer fewer classes and classification_report would raise on the
# target_names length mismatch.
label_ids = list(range(len(STL_LABELS)))

y_train, y_test = labels.cpu().detach().tolist(), labels_val.cpu().detach().tolist()

# Predicted class = the prompt with the highest score for each image.
y_pred = preds_val.argmax(dim=-1).cpu().detach().tolist()

acc = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred, labels=label_ids)
class_report = classification_report(y_test, y_pred, labels=label_ids, target_names=STL_LABELS)

print("Accuracy: ", acc)
print("Confusion matrix: \n", conf_matrix)
print("Classification report: \n", class_report)

y_pred_train = preds_train.argmax(dim=-1).cpu().detach().tolist()
class_report = classification_report(y_train, y_pred_train, labels=label_ids, target_names=STL_LABELS)
print("Classification report train: \n", class_report)

Accuracy:  0.9375
Confusion matrix: 
 [[735   5   2   0   0   0   0   0  51   7]
 [  0 778   0  10   1   2   0   8   1   0]
 [  1   0 776   0   0   0   0   0   0  23]
 [  0   2   0 641 100  32   1  24   0   0]
 [  0   8   0   1 741  14  22  12   0   2]
 [  0   3   0  50   4 728   9   6   0   0]
 [  0   0   3   3   9  15 761   5   0   4]
 [  0  10   0  11   7  16   0 756   0   0]
 [  1   0   0   0   0   0   0   0 791   8]
 [  1   0   5   0   0   0   0   0   1 793]]
Classification report: 
               precision    recall  f1-score   support

    airplane       1.00      0.92      0.96       800
        bird       0.97      0.97      0.97       800
         car       0.99      0.97      0.98       800
         cat       0.90      0.80      0.85       800
        deer       0.86      0.93      0.89       800
         dog       0.90      0.91      0.91       800
       horse       0.96      0.95      0.96       800
      monkey       0.93      0.94      0.94       800
        ship       