In [1]:
import sys
sys.path.insert(0, "../")

In [116]:
import torch
import timm
import pandas as pd
from tqdm.notebook import tqdm
import clip
from PIL import Image
import os
from skimage import io

from src.models.resnet import ResNet
from src.utils import update_state_dict, extract_keywords, is_keyword_in_caption
from src.datasets.loader import get_loaders
from src.datasets import CelebA
from src.metrics import WorstGroupAccuracy

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [47]:
test_dataset = CelebA("../../datasets/celeba", "Male", 2)
test_loader = get_loaders(
    test_dataset,
    batch_size=10,
    num_workers=4,
    pin_memory=True,
)

batch = next(iter(test_loader))

In [48]:
batch.keys()

dict_keys(['filename', 'caption', 'image', 'label', 'spurious_label', 'group'])

In [51]:
batch_size = 10
num_classes = 3
num_groups = num_classes**2

y_true = torch.randint(0, num_classes, size=(batch_size,))
y_pred = torch.randint(0, num_classes, size=(batch_size,))
g = batch["group"]

In [52]:
print(
    f"y_pred:\t{y_pred.tolist()}\n"
    f"y_true:\t{y_true.tolist()}\n"
    f"g:\t{g.tolist()}\n"
)

y_pred:	[0, 2, 1, 2, 1, 0, 0, 0, 1, 1]
y_true:	[2, 1, 1, 0, 1, 1, 1, 2, 2, 2]
g:	[0, 0, 0, 0, 0, 1, 1, 0, 0, 1]



In [126]:
@torch.no_grad()
def similarity(keywords, root_dir, image_filenames, device):    
    text_embeddings = torch.cat([clip.tokenize(f"a photo of a {c}") for c in keywords])

    model, preprocess = clip.load("ViT-B/32", device=device)
    images = [
        Image.fromarray(io.imread(os.path.join(root_dir, f))) for f in image_filenames
    ]
    images = torch.cat([preprocess(im).unsqueeze(0) for im in images])

    text_embeddings = model.encode_text(text_embeddings)
    image_embeddings = model.encode_image(images)

    image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
    text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)
    similarity = (
        100.0 * image_embeddings @ text_embeddings.T
    )  # size: (len(image_filenames), 20)
    similarity = similarity.mean(dim=0)  # size: (20,)
    return similarity

In [127]:
t = torch.empty(size=(2, batch_size), dtype=torch.bool)
for i in range(num_classes):
    # retrieve the indices where the model mispredicts the labels of the images
    t[0, :] = ((y_pred == y_true) & (y_true == i)).bool()  # correct set
    t[1, :] = ((y_pred != y_true) & (y_true == i)).bool()  # incorrect set

    # extract keywords from the correct set
    caption = " ".join([c for j, c in enumerate(batch["caption"]) if t[1, j]])
    keywords = extract_keywords(caption)

    # calculate the similarity between the keyword and the correct set, as well
    # as the similarity between the keyword and the incorrect set
    filenames_correct_set = [f for j, f in enumerate(batch["filename"]) if t[0, j]]
    sim_k_correct_set = similarity(
        keywords=keywords,
        image_filenames=filenames_correct_set,
        root_dir="../../datasets/celeba/img_align_celeba/",
        device=device,
    )  # size: (num_keywords,)

    filenames_incorrect_set = [f for j, f in enumerate(batch["filename"]) if t[1, j]]
    sim_k_incorrect_set = similarity(
        keywords=keywords,
        image_filenames=filenames_incorrect_set,
        root_dir="../../datasets/celeba/img_align_celeba/",
        device=device,
    )  # size: (num_keywords,)

    # clip_score = sim_k_correct_set - sim_k_incorrect_set  # size: (num_keywords,)

RuntimeError: torch.cat(): expected a non-empty list of Tensors

In [105]:
filenames = filenames_incorrect_set

In [76]:
images = batch["image"][t[1, :]]

In [84]:
text_embeddings = torch.cat([clip.tokenize(f"a photo of a {c}") for c in keywords])
text_embeddings.size()

torch.Size([20, 77])

In [86]:
device = images.device
model, preprocess = clip.load("ViT-B/32", device=device)


In [114]:
root_dir = "../../datasets/celeba/img_align_celeba/"

images = [Image.fromarray(io.imread(os.path.join(root_dir, f))) for f in filenames]
images = torch.cat([preprocess(im).unsqueeze(0) for im in images])


In [122]:
similarity

tensor([21.0033, 24.3015, 21.6432, 21.5929, 24.8060, 21.9088, 21.8825, 23.3264,
        20.3384, 22.4032, 23.5213, 22.1934, 23.2563, 24.3473, 20.0536, 21.7447,
        22.7959, 22.4792, 21.7573, 23.2692])