In [1]:
import sys

# Prepend the parent directory to the module search path so the `src.*`
# imports in the following cell can be resolved.
sys.path[:0] = ["../"]

In [2]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import pandas as pd
import clip
import os
from PIL import Image
import numpy as np
from skimage import io

from src.datasets.celeba import CelebADataset
from src.datasets.utils import stratified_sampler
from src.utils import extract_keywords, is_keyword_in_caption

In [3]:
def nonzero(tensor: torch.Tensor) -> torch.Tensor:
    """Return the indices of the non-zero entries of ``tensor``.

    ``torch.nonzero`` yields one row per non-zero element and one column per
    input dimension; the last dimension is squeezed, so a 1-D input produces a
    flat 1-D index tensor.
    """
    positions = torch.nonzero(tensor)
    return positions.squeeze(dim=-1)

In [4]:
# Choose the compute device once; both names are reused by later cells.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = "cuda"
else:
    device = "cpu"

In [5]:
# CelebA dataset with image captions enabled.
# NOTE(review): the positional "Male" / 0 arguments are presumably the target
# attribute and a split or seed index -- confirm against CelebADataset.
train_dataset = CelebADataset(
    "../../datasets/celeba/",
    "Male",
    0,
    df_captions_dir="../checkpoints/list_captions_celeba.csv",
    use_image_captions=True,
)

# Batches are drawn through a custom sampler, so DataLoader-level shuffling
# stays off (DataLoader rejects `shuffle=True` together with a sampler).
sampler = stratified_sampler(train_dataset)
shuffle = False
train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    sampler=sampler,
    shuffle=shuffle,
    num_workers=2,
    pin_memory=use_cuda,
)

In [6]:
# Grab a single batch from the loader to experiment with, then stop.
for batch in train_loader:
    # "filename" and "caption" are removed before the device transfer below,
    # which calls `.to(device)` on every remaining entry.
    filename, caption = batch.pop("filename"), batch.pop("caption")
    batch = {name: value.to(device) for name, value in batch.items()}
    x, g = batch["image"], batch["group"]
    y_true = batch["label"].float()
    if y_true.ndim < 2:
        # Add a trailing dimension so the labels have at least two dimensions.
        y_true.unsqueeze_(1)

    # Random 0/1 values with the shape and dtype of the labels stand in for
    # real predictions.
    y_pred = torch.randint_like(y_true, 0, 2)
    break

In [7]:
batch_size = y_true.shape[0]

# Flatten labels, predictions and groups to 1-D tensors of length batch_size.
y_pred, y_true, groups = (t.view(batch_size) for t in (y_pred, y_true, g))

# One row per sample; columns are (label, prediction, group, position in batch).
idx = torch.arange(batch_size, device=device)
a = torch.stack([y_true, y_pred, groups, idx], dim=1)

# True where the prediction column equals the label column.
is_correct = torch.eq(a[:, 0], a[:, 1])

In [9]:
is_correct.shape

torch.Size([128])

: 

In [22]:
class CLIPLoss(nn.Module):
    """Rescale a per-sample loss with CLIP image/keyword similarity scores.

    For each binary class (label 0 and label 1), keywords are extracted from
    the captions of the misclassified samples. Every keyword receives a score:
    the mean CLIP similarity of the correctly classified images of that class
    minus the mean similarity of the misclassified ones. The loss of every
    sample in the batch whose caption contains a keyword is multiplied by that
    keyword's score.

    Args:
        dataset_root_dir: Directory that the batch filenames are relative to.
        normalize: If True, min-max scale the keyword scores of each class
            before they are applied.
        device: Device on which the CLIP model and its inputs are placed.
    """

    def __init__(
        self,
        dataset_root_dir: str,
        normalize: bool = True,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ):
        super().__init__()
        self.dataset_root_dir = dataset_root_dir
        self.normalize = normalize
        self.device = device
        self.model, self.preprocess = clip.load(
            "ViT-B/32", device=self.device, jit=False
        )
        self.model.eval()

    def clip_score(self, filenames, keywords):
        """Return the mean CLIP similarity of the images to each keyword.

        Args:
            filenames: Non-empty iterable of image paths relative to
                ``dataset_root_dir``.
            keywords: Non-empty iterable of keywords; each one is inserted
                into the prompt ``"a photo of a {keyword}"``.

        Returns:
            1-D tensor with one entry per keyword: the cosine similarity
            (scaled by 100) averaged over all images.
        """
        paths = [os.path.join(self.dataset_root_dir, f) for f in filenames]
        preprocessed = []
        for path in paths:
            # Close each file handle as soon as the image has been converted.
            with Image.open(path) as image:
                preprocessed.append(self.preprocess(image).unsqueeze(0))
        # Use the instance's device, not a notebook-level global.
        images = torch.cat(preprocessed).to(self.device)
        images_features = self.model.encode_image(images)

        tokens = torch.cat(
            [clip.tokenize(f"a photo of a {k}") for k in keywords]
        ).to(self.device)
        texts_features = self.model.encode_text(tokens)

        images_features = images_features / images_features.norm(dim=-1, keepdim=True)
        texts_features = texts_features / texts_features.norm(dim=-1, keepdim=True)
        similarity = 100.0 * images_features @ texts_features.T  # (n_images, n_keywords)
        return similarity.mean(dim=0)

    def forward_similarity(self, filenames_wrong, filenames_correct, keywords):
        """Score each keyword as (similarity to correct) - (similarity to wrong).

        Returns:
            1-D tensor with one score per keyword. An empty tensor is returned
            when there are no wrong images, no correct images or no keywords,
            because no score can be computed in that case.
        """
        keywords = list(keywords)
        if not filenames_wrong or not filenames_correct or not keywords:
            return torch.empty(0, device=self.device)

        sim_wrong = self.clip_score(filenames_wrong, keywords)
        sim_correct = self.clip_score(filenames_correct, keywords)
        sim = sim_correct - sim_wrong

        if self.normalize:
            sim = sim - sim.min()
            peak = sim.max()
            # When all scores are equal (or there is a single keyword) the
            # shifted scores are all zero; skip the division to avoid 0/0 = NaN.
            if peak > 0:
                sim = sim / peak
        return sim

    def _reweight_class(self, loss, filenames, captions, wrong_idx, correct_idx):
        """Scale ``loss`` in place using keywords from one class's wrong samples."""
        if not wrong_idx or not correct_idx:
            # Both groups are needed to compute a similarity difference.
            return

        wrong_filenames = [f for i, f in enumerate(filenames) if i in wrong_idx]
        correct_filenames = [f for i, f in enumerate(filenames) if i in correct_idx]
        keywords = extract_keywords(
            " ".join(c for i, c in enumerate(captions) if i in wrong_idx)
        )

        with torch.no_grad():
            sims = self.forward_similarity(
                wrong_filenames, correct_filenames, keywords
            )

        for sim_k, keyword in zip(sims, keywords):
            for i, c in enumerate(captions):
                if is_keyword_in_caption(c, keyword):
                    loss[i] *= sim_k

    def forward(self, y_pred, y_true, groups, loss, filenames=None, captions=None):
        """Multiply per-sample losses by CLIP keyword scores.

        Args:
            y_pred: Predicted labels, reshapeable to ``(batch_size,)``.
            y_true: Ground-truth labels (0 or 1), reshapeable to ``(batch_size,)``.
            groups: Group ids, reshapeable to ``(batch_size,)``. Only the shape
                is checked; the values are not used.
            loss: Per-sample loss indexed by batch position. It is modified in
                place and also returned.
            filenames: Image filenames of the batch. When omitted, the
                notebook-level global ``filename`` is used, as before.
            captions: Captions of the batch. When omitted, the notebook-level
                global ``caption`` is used, as before.

        Returns:
            The rescaled ``loss`` tensor.
        """
        # Backward-compatible fallback to the notebook globals.
        if filenames is None:
            filenames = filename
        if captions is None:
            captions = caption

        batch_size = y_true.size(0)
        y_pred = y_pred.view(batch_size)
        y_true = y_true.view(batch_size)
        groups = groups.view(batch_size)  # shape validation only

        is_correct = y_true == y_pred
        for label in (0, 1):
            in_class = y_true == label
            # Python sets give O(1) membership tests without a device sync per
            # lookup.
            wrong_idx = set(nonzero(in_class & ~is_correct).tolist())
            correct_idx = set(nonzero(in_class & is_correct).tolist())
            self._reweight_class(loss, filenames, captions, wrong_idx, correct_idx)
        return loss

In [26]:
# Element-wise BCE (no reduction) so that each sample's loss can be rescaled
# individually afterwards.
loss_fn = nn.BCELoss(reduction="none")
loss = loss_fn(
    # Random probabilities stand in for model outputs; requires_grad lets the
    # backward pass be exercised later on.
    torch.rand(size=(y_true.size(0), 1), requires_grad=True, device=device),
    # BCELoss requires input and target to have the same shape. An earlier
    # cell flattens y_true to 1-D, so restore the trailing dimension here
    # (a no-op when y_true is already (N, 1)).
    y_true.view(-1, 1),
)

In [28]:
# Build the CLIP-based re-weighting module and apply it to the per-sample loss.
# NOTE: `loss` is rebound here, so re-running this cell rescales it again.
clip_loss_fn = CLIPLoss(dataset_root_dir="../../datasets/celeba/img_align_celeba/")
loss = clip_loss_fn(y_pred=y_pred, y_true=y_true, groups=g, loss=loss)

In [31]:
# Reduce to a scalar and backpropagate to check that the re-weighted loss is
# still differentiable.
torch.mean(loss).backward()

In [None]:
import clip
import torch
import torchvision

# Disabled scratch code, kept for reference: attempts at tracing the non-JIT
# CLIP model with torch.jit.trace and at loading the JIT variant directly.
# model, preprocess = clip.load("ViT-B/32", jit=False, device="cpu")
# data = torch.rand(1, 3, 224, 244, device="cpu")
# text = clip.tokenize("data").to(device="cpu")
# trace = torch.jit.trace(model, (data, text))
# model, preprocess = clip.load("ViT-B/32", jit=True)
# torch._C._jit_pass_inline(trace.graph)

In [12]:
def _node_get(node: torch._C.Node, key: str):
    """Gets attributes of a node which is polymorphic over return type."""
    sel = node.kindOf(key)
    return getattr(node, sel)(key)


# Install subscript access (`node[key]`) on graph nodes before loading the JIT
# model. NOTE(review): presumably a workaround for clip's JIT loading path
# indexing nodes this way on PyTorch builds that lack it -- confirm.
setattr(torch._C.Node, "__getitem__", _node_get)
model, preprocess = clip.load("ViT-B/32", jit=True)