In [None]:
import os
import numpy as np
from omegaconf import OmegaConf

import torch
from torchtnt.utils.device import copy_data_to_device
from torchvision.transforms.functional import resize, normalize
import clip

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sns

from sklearn.metrics import average_precision_score
import logging

from model import Model
from dataset import *

# --- Experiment selection ---------------------------------------------------
dataset_name = "CELEBA"                # dataset key handed to get_loader
root = "./experiments/celeba_linear"   # run directory containing the Hydra config
device = "cuda:0"                      # device used for the models and every batch

# --- Logging ------------------------------------------------------------------
# Root logger with its threshold lowered to DEBUG.
logger = logging.getLogger()
logger.setLevel(logging.DEBUG)

# --- Configuration ------------------------------------------------------------
# Load the Hydra configuration stored inside the run directory.
with open(f"{root}/.hydra/config.yaml", "r") as f:
    cfg = OmegaConf.load(f)

In [None]:
# Keyword arguments that are identical for the train and test loaders.
# NOTE(review): num_workers / pin_memory are read from cfg.train and
# persistent_workers from cfg.val for BOTH splits -- confirm this mix is intended.
shared_loader_kwargs = dict(
    dataset_name=dataset_name,
    img_size=cfg.general.img_size,
    num_workers=cfg.train.num_workers,
    pin_memory=cfg.train.pin_memory,
    persistent_workers=cfg.val.persistent_workers,
)

# Only the split name and the batch size differ between the two loaders.
trainloader = get_loader(
    split="train",
    batch_size=cfg.train.batch_size,
    **shared_loader_kwargs,
)
testloader = get_loader(
    split="test",
    batch_size=cfg.val.batch_size,
    **shared_loader_kwargs,
)

# Load CLIP

In [None]:
# Load the "ViT-B/32" CLIP model together with its preprocessing transform.
clip_model, clip_preprocess = clip.load("ViT-B/32")

# Move the model to the evaluation device and put it in eval mode
# (trailing `;` keeps the module repr out of the cell output).
clip_model = clip_model.to(device)
clip_model.eval();

# Load our model

In [None]:
# Build the model from the run's configuration and place it on `device`.
module = Model(
    dim_in=cfg.model.dim_in,
    dim_out=cfg.model.dim_out,
    backbone=cfg.model.backbone,
    kernel=cfg.model.kernel,
    activation=torch.nn.functional.leaky_relu,
    alpha=cfg.train.loss_coefs.alpha,
    beta=cfg.train.loss_coefs.beta,
    device=device,
).to(device)

# Restore the trained weights (tensors only, deserialised on CPU first).
checkpoint_path = os.path.join(root, "state_dict.pt")
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
module.load_state_dict(state_dict)
module.eval()

# Run the whole training set through the model between forget() and
# update_minterms(); the forward outputs themselves are not needed here.
# NOTE(review): presumably forget() resets and the forward pass accumulates the
# statistics update_minterms() consumes -- confirm against model.py.
module.forget()
with torch.no_grad():
    for batch in trainloader:
        module(copy_data_to_device(batch, device=device))
module.update_minterms()

# Unpack the model's minterm dictionary into two aligned tensors:
# the keys stacked as rows of `minterms`, the values concatenated into `minterm_evecs`.
minterms = torch.tensor(list(module._minterms.keys())).to(device)
minterm_evecs = torch.cat(list(module._minterms.values())).to(device)

In [None]:
# Squared kernel values between every pair of minterm vectors, drawn as a heatmap.
minterm_kernel = module.kernel(minterm_evecs, minterm_evecs).cpu()
plt.figure()
sns.heatmap(minterm_kernel ** 2)

# Embed test set with our model

In [None]:
# Per-batch pieces, concatenated once the loop is done.
embedding_chunks, image_chunks, label_chunks = [], [], []

with torch.no_grad():
    for batch in testloader:
        batch = copy_data_to_device(batch, device=device)
        image_chunks.append(batch.images)
        label_chunks.append(batch.labels)
        # Only the first element returned by embed() is kept.
        embedding_chunks.append(module.embed(batch)[0])

# One tensor per quantity, all aligned along the sample dimension.
our_test_embeddings = torch.cat(embedding_chunks, dim=0)
test_images = torch.cat(image_chunks, dim=0)
test_labels = torch.cat(label_chunks)

# Embed test set with CLIP

In [None]:
# Per-channel statistics handed to `normalize` before images are encoded by CLIP.
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# NOTE(review): this is a second pass over `testloader`; the rows only line up with
# test_labels / our_test_embeddings if the loader yields samples in a fixed order.
clip_chunks = []
with torch.no_grad():
    for batch in testloader:
        batch = copy_data_to_device(batch, device=device)
        # Resize to 224, then normalise with the statistics above.
        clip_input = normalize(resize(batch.images, 224), mean=CLIP_MEAN, std=CLIP_STD)
        clip_chunks.append(clip_model.encode_image(clip_input).float())

clip_test_embeddings = torch.cat(clip_chunks, dim=0)

# Retrieval (all queries)

In [None]:
# When True, the natural-language query from the file is replaced by a plain
# comma-separated list of the attribute names that occur in the logical query.
bag_of_words = True


def _literal_indices(query, class_to_idx):
    """Return ``(pos_idx, neg_idx)`` attribute indices for a conjunctive query.

    Literals are separated by ``and``; a literal containing ``not`` counts as
    negated. Matching is by substring, so attribute names must not themselves
    contain ``and`` or ``not``.
    """
    literals = [part.strip() for part in query.split("and")]
    pos_idx = [class_to_idx[l] for l in literals if "not" not in l]
    neg_idx = [class_to_idx[l.replace("not", "").strip()] for l in literals if "not" in l]
    return pos_idx, neg_idx


def _ranking_metrics(target, scores, k=10):
    """Return ``(average precision, precision@k)`` of ``scores`` against boolean ``target``."""
    top_k = torch.argsort(scores, descending=True)[:k]
    ap = average_precision_score(target.cpu(), scores.cpu())
    precision = (target[top_k].sum() / k).cpu().item()
    return ap, precision


def _group_mean(values, selector):
    """Mean of ``values[selector]``; NaN (without a NumPy warning) if nothing is selected."""
    picked = np.asarray(values, dtype=float)[selector]
    return picked.mean() if picked.size else float("nan")


# Each non-empty line of queries.txt has the form "<logical query> - <natural-language query>".
with open("./queries.txt") as file:
    lines = [line.strip() for line in file if line.strip()]  # blank lines are skipped

queries_, nl_queries_ = [], []
for line in lines:
    # Split on the FIRST dash only, so hyphens inside the natural-language text survive.
    logical_part, dash, natural_part = line.partition("-")
    if not dash:
        raise ValueError(f"Malformed line in queries.txt (expected 'query - text'): {line!r}")
    queries_.append(logical_part.strip())
    nl_queries_.append(natural_part.strip())

our_ap = []
our_pr = []
clip_ap = []
clip_pr = []
queries = []
nl_queries = []

class_to_idx = testloader.dataset.class_to_idx

# Loop-invariant: unit-norm image embeddings, computed once instead of once per
# query and without modifying the tensors produced by earlier cells.
our_unit_embeddings = torch.nn.functional.normalize(our_test_embeddings, dim=-1, p=2)
clip_unit_embeddings = clip_test_embeddings / clip_test_embeddings.norm(dim=-1, keepdim=True)

for query, nl_query in zip(queries_, nl_queries_):
    if bag_of_words:
        nl_query = query.replace(" and", ", ").replace("not", "").replace("_", " ").strip()

    print(f"\n{query}\n{nl_query}")

    pos_idx, neg_idx = _literal_indices(query, class_to_idx)

    # Ground truth: every positive attribute is 1 and every negated attribute is 0.
    target = torch.logical_and(
        (test_labels[:, pos_idx] == 1).all(-1),
        (test_labels[:, neg_idx] == 0).all(-1),
    )

    # Skip queries with fewer than 10 positives (Pr@10 could never reach 1).
    if target.sum() < 10:
        continue

    queries.append(query)
    nl_queries.append(nl_query)

    # --- Ours -----------------------------------------------------------------
    # Keep the minterms that agree with the query on every literal (same rule as
    # `target`), and build the matrix of the quadratic form from their vectors.
    mask = torch.logical_and(
        (minterms[:, pos_idx] == 1).all(-1),
        (minterms[:, neg_idx] == 0).all(-1),
    )
    selected_evecs = minterm_evecs[mask]
    projection = selected_evecs.T @ selected_evecs

    # Score of image b: e_b^T P e_b with e_b the unit-norm embedding.
    p = torch.einsum("bi,ij,bj->b", our_unit_embeddings, projection, our_unit_embeddings)

    ap, precision = _ranking_metrics(target, p)
    print(f"Our AP = {ap}")
    print(f"Our Pr@10 = {precision}")
    our_ap.append(ap)
    our_pr.append(precision)

    # --- CLIP -----------------------------------------------------------------
    with torch.no_grad():
        text_tokens = clip.tokenize([nl_query]).to(device)
        clip_text_embedding = clip_model.encode_text(text_tokens).float()

    # Cosine similarity between the single text embedding and every image embedding.
    clip_text_embedding = clip_text_embedding / clip_text_embedding.norm(dim=-1, keepdim=True)
    similarity = (clip_text_embedding @ clip_unit_embeddings.T).squeeze(0)

    ap, precision = _ranking_metrics(target, similarity)
    print(f"CLIP AP = {ap}")
    print(f"CLIP Pr@10 = {precision}")
    clip_ap.append(ap)
    clip_pr.append(precision)

# Summary, split into queries without any negation ("positive") and queries with
# at least one negated literal ("negative"). dtype=bool keeps the selector valid
# even when no query passed the filter above.
is_negative = np.array(["not" in q for q in queries], dtype=bool)
is_positive = ~is_negative

for method, pr_values, ap_values in (("Ours", our_pr, our_ap), ("CLIP", clip_pr, clip_ap)):
    print("-"*50)
    print(f"Pr@10 positive ({method}) = {_group_mean(pr_values, is_positive)}")
    print(f"mAP   positive ({method}) = {_group_mean(ap_values, is_positive)}")
    print(f"Pr@10 negative ({method}) = {_group_mean(pr_values, is_negative)}")
    print(f"mAP   negative ({method}) = {_group_mean(ap_values, is_negative)}")

# Retrieval visualization

In [None]:
# Logical query (attribute names joined by "and", negation written as "not")
# and the natural-language prompt given to CLIP.
# NOTE(review): the logical query is just "Bald" while the prompt describes
# "male and not bald" -- the two do not describe the same set; confirm which is intended.
query = "Bald"
nl_query = "a male person that is not bald"

# Split the conjunction into literals and sort them by polarity.
literals = [part.strip() for part in query.split("and")]
pos_literals, neg_literals = [], []
for literal in literals:
    if "not" in literal:
        neg_literals.append(literal.replace("not", "").strip())
    else:
        pos_literals.append(literal)

print(f"Positive literals {pos_literals}")
print(f"Negative literals {neg_literals}")

# Attribute names -> label column indices.
attribute_index = testloader.dataset.class_to_idx
pos_idx = [attribute_index[name] for name in pos_literals]
neg_idx = [attribute_index[name] for name in neg_literals]

# A test sample is relevant when all positive attributes are 1 and all negated ones are 0.
has_all_positive = (test_labels[:, pos_idx] == 1).all(-1)
has_no_negative = (test_labels[:, neg_idx] == 0).all(-1)
target = torch.logical_and(has_all_positive, has_no_negative)

In [None]:
# Keep the minterms that agree with the query on every literal: positive
# attributes equal to 1, negated attributes equal to 0 (same rule as `target`).
mask = torch.logical_and(
    (minterms[:, pos_idx] == 1).all(-1),
    (minterms[:, neg_idx] == 0).all(-1),
)

# Matrix of the quadratic form built from the selected minterm vectors.
selected_evecs = minterm_evecs[mask]
projection = selected_evecs.T @ selected_evecs

# Only the singular values are inspected, so the singular vectors are not computed.
s = torch.linalg.svdvals(projection)
print(s)

# Score of image b: e_b^T P e_b with e_b the unit-norm embedding.
# torch.nn.functional is referenced explicitly because the notebook never imports `F`.
unit_embeddings = torch.nn.functional.normalize(our_test_embeddings, dim=-1, p=2)
p = torch.einsum("bi,ij,bj->b", unit_embeddings, projection, unit_embeddings)
idx = torch.argsort(p, descending=True)  # ranking reused by the image-grid cell below

ap = average_precision_score(target.cpu(), p.cpu())
# .item() so a plain float is printed, matching the all-queries cell.
precision = (target[idx[:10]].sum() / 10).cpu().item()
print(f"AP = {ap}")
print(f"Precision@10 = {precision}")

In [None]:
# Show the n highest-ranked test images (ranking taken from `idx`) in a 2-row grid.
n = 20
fig, axes = plt.subplots(2, n // 2, figsize=(20, 5))
for rank, ax in enumerate(axes.flat):
    # Images are stored channel-first; imshow expects channel-last.
    ax.imshow(test_images[idx[rank]].permute(1, 2, 0).cpu())
    ax.axis("off")
fig.tight_layout()
# Optional: fig.savefig(f"{query}_ours.png")

In [None]:
# Encode the natural-language query with CLIP's text encoder.
with torch.no_grad():
    text_tokens = clip.tokenize([nl_query]).to(device)
    clip_text_embedding = clip_model.encode_text(text_tokens).float()

# Cosine similarity between the single text embedding and every image embedding.
# The normalised copies are new tensors: `clip_test_embeddings` is left untouched,
# so re-running this cell (or running cells out of order) cannot change its input.
clip_unit_embeddings = clip_test_embeddings / clip_test_embeddings.norm(dim=-1, keepdim=True)
clip_text_embedding = clip_text_embedding / clip_text_embedding.norm(dim=-1, keepdim=True)
similarity = (clip_text_embedding @ clip_unit_embeddings.T).squeeze(0)
idx = torch.argsort(similarity, descending=True)  # ranking reused by the image-grid cell below

ap = average_precision_score(target.cpu(), similarity.cpu())
# .item() so a plain float is printed, matching the all-queries cell.
precision = (target[idx[:10]].sum() / 10).cpu().item()
print(f"AP = {ap}")
print(f"Precision@10 = {precision}")

In [None]:
# Show the n highest-ranked test images (ranking taken from `idx`) in a 2-row grid.
n = 20
fig, axes = plt.subplots(2, n // 2, figsize=(20, 5))
for rank, ax in enumerate(axes.flat):
    # Images are stored channel-first; imshow expects channel-last.
    ax.imshow(test_images[idx[rank]].permute(1, 2, 0).cpu())
    ax.axis("off")
fig.tight_layout()
# Optional: fig.savefig(f"{nl_query}_clip.png")