In [None]:
import numpy as np
import torch
from PIL import Image
import os.path
import argparse
from pathlib import Path

from torch.utils.data import DataLoader
from tqdm import tqdm
from utils.factory import create_model_and_transforms, get_tokenizer
from utils.binary_waterbirds import BinaryWaterbirds
from prs_hook import hook_prs_logger
from torchvision.datasets import CIFAR100, CIFAR10, ImageNet, ImageFolder
from torch.nn import functional as F

import os
# Restrict the process to GPU 7 unless the caller (shell, scheduler, ...) has
# already chosen the visible devices; the previous unconditional assignment
# silently overrode that choice. The variable only takes effect if it is set
# before CUDA is first initialised, so this must stay ahead of any CUDA call.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7")

import torch.nn as nn
from load import *
import random
from torch import optim
from loss import *

In [2]:
# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Architecture name handed to create_model_and_transforms / get_tokenizer.
model_name = "ViT-B-32"

# Weights to load: a local checkpoint file.
# (Alternative kept from the original cell: pretrained = "openai")
pretrained = "./DCP-ViT-B-32.pt"

# Mini-batch size shared by every DataLoader built in this notebook.
batch_size = 64

In [None]:
# Build the model together with its image preprocessing transform from the
# configured weights; the middle element of the returned triple is not used.
model, _, preprocess = create_model_and_transforms(model_name, pretrained=pretrained)
model.to(device)
model.eval()  # evaluation only: put every module in inference mode

context_length = model.context_length
vocab_size = model.vocab_size

# Total number of scalar parameters, printed with thousands separators.
num_params = np.sum([int(np.prod(p.shape)) for p in model.parameters()])
print("Model parameters:", f"{num_params:,}")
print("Context length:", context_length)
print("Vocab size:", vocab_size)
print("Len of res:", len(model.visual.transformer.resblocks))

# Attach the logger whose reinit()/finalize() calls bracket each image forward
# pass in the evaluation cells.
prs = hook_prs_logger(model, device)

In [5]:
class BatchCosineOrthogonalLoss(nn.Module):
    """Mean squared pairwise cosine similarity between the rows of each batch item.

    For an input of shape ``[batch_size, n, features]`` every item contributes the
    ``n * (n - 1)`` off-diagonal entries of its ``n x n`` cosine-similarity matrix.
    The loss is the mean of their squares, so it is 0 when all rows of every item
    are mutually orthogonal and 1 when they are all parallel.

    Args:
        eps: Lower bound applied to each row norm before dividing. It keeps an
            all-zero row from producing 0/0 = NaN; such a row simply has zero
            similarity with every other row.
    """

    def __init__(self, eps=1e-12):
        super().__init__()
        self.eps = eps

    def forward(self, heatmaps):
        """Return the scalar loss for ``heatmaps`` of shape ``[batch_size, n, features]``.

        Returns a zero scalar when there are no row pairs to compare (empty batch
        or ``n < 2``) instead of dividing by zero.
        """
        b, n, _ = heatmaps.shape

        # No off-diagonal pairs exist, so the normaliser below would be zero.
        if b == 0 or n < 2:
            return heatmaps.new_zeros(())

        # Unit-normalise every row along the feature dimension; the clamp guards
        # against division by a zero norm.
        norm = torch.norm(heatmaps, p=2, dim=-1, keepdim=True)
        normalized_heatmaps = heatmaps / norm.clamp_min(self.eps)

        # [b, n, n] matrix of cosine similarities between the rows of each item.
        cosine_similarities = torch.bmm(normalized_heatmaps, normalized_heatmaps.transpose(1, 2))

        # Discard the diagonal (similarity of a row with itself).
        mask = torch.eye(n, dtype=torch.bool, device=cosine_similarities.device)
        off_diagonal = cosine_similarities.masked_fill(mask.unsqueeze(0), 0)

        # Mean of the squared similarities over all off-diagonal entries in the batch.
        return (off_diagonal ** 2).sum() / (b * n * (n - 1))

In [6]:
# Tokenizer that belongs to the configured architecture; used to tokenize the
# class descriptors in the evaluation cells.
tokenizer = get_tokenizer(model_name)

# Loss module applied to the per-concept maps in the evaluation cells.
loss_orth = BatchCosineOrthogonalLoss()

In [9]:
# Select the CIFAR-10 label -> class-name table and descriptor file, then
# rebuild the per-class descriptor lookup used by the evaluation cell.
label_to_classname = label_to_classname_cifar10
hparams['descriptor_fname'] = './descriptors/my_cifar10.json'
gpt_descriptions, unmodify_dict = load_gpt_descriptions(
    hparams,
    label_to_classname,
)

In [10]:
# CIFAR-10 split selected by train=False, run through the model's preprocessing transform.
dataset_cifar10_val = CIFAR10(
    root=CIFAR10_DIR,
    train=False,
    transform=preprocess,
)

# Loader consumed by the evaluation cell that follows.
dataloader_val = DataLoader(
    dataset_cifar10_val,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [None]:
# Orthogonality score of the per-concept maps over `dataloader_val`.
#
# For every image, the first `num_concepts` descriptors of its class are encoded
# with model.encode_text and multiplied with the output of prs.finalize, giving
# one map per descriptor. `loss_orth` measures the mean squared cosine similarity
# between those maps; the cell displays 1 - (mean loss), so a higher value means
# the maps are closer to mutually orthogonal.
num_concepts = 5  # descriptors kept per class

total_loss = 0.0  # sum of per-batch losses, each weighted by its batch size
num_samples = 0
with torch.no_grad():
    for images, labels in tqdm(dataloader_val):
        # Class name of every sample in the batch.
        texts = np.array(label_to_classname)[labels].tolist()

        # One token tensor per sample. Stacking requires every class in the
        # batch to supply the same number of descriptors.
        tokenized_concepts_list = torch.stack(
            [tokenizer(gpt_descriptions[text][:num_concepts]) for text in texts]
        )

        images = images.to(device)
        prs.reinit()
        representation = model.encode_image(
            images, attn_method="head", normalize=False
        )
        attentions = prs.finalize(representation)

        # Merge the (sample, concept) axes so the text encoder sees one flat
        # batch, then restore the per-sample grouping. Sizes are taken from the
        # tensors instead of hard-coding the context length and embedding width.
        node_text_embeddings = model.encode_text(
            tokenized_concepts_list.flatten(0, 1).to(device)
        )
        node_text_embeddings = node_text_embeddings.reshape(
            len(images), -1, node_text_embeddings.shape[-1]
        )

        # Per sample: skip index 0 of the token axis (presumably the class
        # token - TODO confirm), sum over the axes on either side of it, and
        # project onto each concept embedding. Result per sample: [concepts, tokens].
        attentions_maps = torch.stack([
            (attention[:, 1:, :].sum(axis=(0, 2)) @ text_embedding.T).permute(1, 0)
            for attention, text_embedding in zip(attentions, node_text_embeddings)
        ])

        orth_loss = loss_orth(attentions_maps)

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += orth_loss.item() * len(images)
        num_samples += len(images)

if num_samples == 0:
    raise ValueError("dataloader_val yielded no batches")

# Average over the samples actually processed. (Dividing by the last
# enumerate() index, as before, used one less than the number of batches.)
avg_orth = total_loss / num_samples
1 - avg_orth

In [12]:
# Select the CIFAR-100 label -> class-name table and descriptor file, then
# rebuild the per-class descriptor lookup used by the evaluation cell.
label_to_classname = label_to_classname_cifar100
hparams['descriptor_fname'] = './descriptors/my_cifar100.json'
gpt_descriptions, unmodify_dict = load_gpt_descriptions(
    hparams,
    label_to_classname,
)

# CIFAR-100 split selected by train=False, run through the model's preprocessing transform.
dataset_cifar100_val = CIFAR100(
    root=CIFAR100_DIR,
    train=False,
    transform=preprocess,
)
dataloader_val = DataLoader(
    dataset_cifar100_val,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [None]:
# Orthogonality score of the per-concept maps over `dataloader_val`.
#
# For every image, the first `num_concepts` descriptors of its class are encoded
# with model.encode_text and multiplied with the output of prs.finalize, giving
# one map per descriptor. `loss_orth` measures the mean squared cosine similarity
# between those maps; the cell displays 1 - (mean loss), so a higher value means
# the maps are closer to mutually orthogonal.
num_concepts = 4  # descriptors kept per class (this cell uses 4, not 5)

total_loss = 0.0  # sum of per-batch losses, each weighted by its batch size
num_samples = 0
with torch.no_grad():
    for images, labels in tqdm(dataloader_val):
        # Class name of every sample in the batch.
        texts = np.array(label_to_classname)[labels].tolist()

        # One token tensor per sample. Stacking requires every class in the
        # batch to supply the same number of descriptors.
        tokenized_concepts_list = torch.stack(
            [tokenizer(gpt_descriptions[text][:num_concepts]) for text in texts]
        )

        images = images.to(device)
        prs.reinit()
        representation = model.encode_image(
            images, attn_method="head", normalize=False
        )
        attentions = prs.finalize(representation)

        # Merge the (sample, concept) axes so the text encoder sees one flat
        # batch, then restore the per-sample grouping. Sizes are taken from the
        # tensors instead of hard-coding the context length and embedding width.
        node_text_embeddings = model.encode_text(
            tokenized_concepts_list.flatten(0, 1).to(device)
        )
        node_text_embeddings = node_text_embeddings.reshape(
            len(images), -1, node_text_embeddings.shape[-1]
        )

        # Per sample: skip index 0 of the token axis (presumably the class
        # token - TODO confirm), sum over the axes on either side of it, and
        # project onto each concept embedding. Result per sample: [concepts, tokens].
        attentions_maps = torch.stack([
            (attention[:, 1:, :].sum(axis=(0, 2)) @ text_embedding.T).permute(1, 0)
            for attention, text_embedding in zip(attentions, node_text_embeddings)
        ])

        orth_loss = loss_orth(attentions_maps)

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += orth_loss.item() * len(images)
        num_samples += len(images)

if num_samples == 0:
    raise ValueError("dataloader_val yielded no batches")

# Average over the samples actually processed. (Dividing by the last
# enumerate() index, as before, used one less than the number of batches.)
avg_orth = total_loss / num_samples
1 - avg_orth

In [14]:
# Select the CUB label -> class-name table and descriptor file, then rebuild
# the per-class descriptor lookup used by the evaluation cell.
label_to_classname = label_to_classname_cub
hparams['descriptor_fname'] = './descriptors/my_cub.json'
gpt_descriptions, unmodify_dict = load_gpt_descriptions(
    hparams,
    label_to_classname,
)

# CUB split selected by train=False, run through the model's preprocessing transform.
dataset_cub_val = CUBDataset(
    CUB_DIR,
    train=False,
    transform=preprocess,
)
dataloader_val = DataLoader(
    dataset_cub_val,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [None]:
# Orthogonality score of the per-concept maps over `dataloader_val`.
#
# For every image, the first `num_concepts` descriptors of its class are encoded
# with model.encode_text and multiplied with the output of prs.finalize, giving
# one map per descriptor. `loss_orth` measures the mean squared cosine similarity
# between those maps; the cell displays 1 - (mean loss), so a higher value means
# the maps are closer to mutually orthogonal.
num_concepts = 5  # descriptors kept per class

total_loss = 0.0  # sum of per-batch losses, each weighted by its batch size
num_samples = 0
with torch.no_grad():
    for images, labels in tqdm(dataloader_val):
        # Class name of every sample in the batch.
        texts = np.array(label_to_classname)[labels].tolist()

        # One token tensor per sample. Stacking requires every class in the
        # batch to supply the same number of descriptors.
        tokenized_concepts_list = torch.stack(
            [tokenizer(gpt_descriptions[text][:num_concepts]) for text in texts]
        )

        images = images.to(device)
        prs.reinit()
        representation = model.encode_image(
            images, attn_method="head", normalize=False
        )
        attentions = prs.finalize(representation)

        # Merge the (sample, concept) axes so the text encoder sees one flat
        # batch, then restore the per-sample grouping. Sizes are taken from the
        # tensors instead of hard-coding the context length and embedding width.
        node_text_embeddings = model.encode_text(
            tokenized_concepts_list.flatten(0, 1).to(device)
        )
        node_text_embeddings = node_text_embeddings.reshape(
            len(images), -1, node_text_embeddings.shape[-1]
        )

        # Per sample: skip index 0 of the token axis (presumably the class
        # token - TODO confirm), sum over the axes on either side of it, and
        # project onto each concept embedding. Result per sample: [concepts, tokens].
        attentions_maps = torch.stack([
            (attention[:, 1:, :].sum(axis=(0, 2)) @ text_embedding.T).permute(1, 0)
            for attention, text_embedding in zip(attentions, node_text_embeddings)
        ])

        orth_loss = loss_orth(attentions_maps)

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += orth_loss.item() * len(images)
        num_samples += len(images)

if num_samples == 0:
    raise ValueError("dataloader_val yielded no batches")

# Average over the samples actually processed. (Dividing by the last
# enumerate() index, as before, used one less than the number of batches.)
avg_orth = total_loss / num_samples
1 - avg_orth

In [16]:
# Select the Caltech101 label -> class-name table and descriptor file, then
# rebuild the per-class descriptor lookup used by the evaluation cell.
label_to_classname = label_to_classname_caltech101
hparams['descriptor_fname'] = './descriptors/my_caltech101.json'
gpt_descriptions, unmodify_dict = load_gpt_descriptions(
    hparams,
    label_to_classname,
)

# No split argument is passed, so whatever Caltech101 returns by default is
# evaluated. NOTE(review): `torchvision` itself is not imported in this
# notebook's import cell; presumably it arrives via `from load import *` - confirm.
dataset_caltech101 = torchvision.datasets.Caltech101(
    root=CALTECH101_DIR,
    transform=preprocess,
)
dataloader_val = DataLoader(
    dataset_caltech101,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [None]:
# Orthogonality score of the per-concept maps over `dataloader_val`.
#
# For every image, the first `num_concepts` descriptors of its class are encoded
# with model.encode_text and multiplied with the output of prs.finalize, giving
# one map per descriptor. `loss_orth` measures the mean squared cosine similarity
# between those maps; the cell displays 1 - (mean loss), so a higher value means
# the maps are closer to mutually orthogonal.
num_concepts = 5  # descriptors kept per class

total_loss = 0.0  # sum of per-batch losses, each weighted by its batch size
num_samples = 0
with torch.no_grad():
    for images, labels in tqdm(dataloader_val):
        # Class name of every sample in the batch.
        texts = np.array(label_to_classname)[labels].tolist()

        # One token tensor per sample. Stacking requires every class in the
        # batch to supply the same number of descriptors.
        tokenized_concepts_list = torch.stack(
            [tokenizer(gpt_descriptions[text][:num_concepts]) for text in texts]
        )

        images = images.to(device)
        prs.reinit()
        representation = model.encode_image(
            images, attn_method="head", normalize=False
        )
        attentions = prs.finalize(representation)

        # Merge the (sample, concept) axes so the text encoder sees one flat
        # batch, then restore the per-sample grouping. Sizes are taken from the
        # tensors instead of hard-coding the context length and embedding width.
        node_text_embeddings = model.encode_text(
            tokenized_concepts_list.flatten(0, 1).to(device)
        )
        node_text_embeddings = node_text_embeddings.reshape(
            len(images), -1, node_text_embeddings.shape[-1]
        )

        # Per sample: skip index 0 of the token axis (presumably the class
        # token - TODO confirm), sum over the axes on either side of it, and
        # project onto each concept embedding. Result per sample: [concepts, tokens].
        attentions_maps = torch.stack([
            (attention[:, 1:, :].sum(axis=(0, 2)) @ text_embedding.T).permute(1, 0)
            for attention, text_embedding in zip(attentions, node_text_embeddings)
        ])

        orth_loss = loss_orth(attentions_maps)

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += orth_loss.item() * len(images)
        num_samples += len(images)

if num_samples == 0:
    raise ValueError("dataloader_val yielded no batches")

# Average over the samples actually processed. (Dividing by the last
# enumerate() index, as before, used one less than the number of batches.)
avg_orth = total_loss / num_samples
1 - avg_orth

In [18]:
# Select the Oxford-IIIT Pet label -> class-name table and descriptor file,
# then rebuild the per-class descriptor lookup used by the evaluation cell.
label_to_classname = label_to_classname_oxfordpets
hparams['descriptor_fname'] = './descriptors/my_oxfordpet.json'
gpt_descriptions, unmodify_dict = load_gpt_descriptions(
    hparams,
    label_to_classname,
)

# 'test' split, run through the model's preprocessing transform.
# NOTE(review): `torchvision` itself is not imported in this notebook's import
# cell; presumably it arrives via `from load import *` - confirm.
dataset_oxfordpets_tst = torchvision.datasets.OxfordIIITPet(
    root=OXFORDPET_DIR,
    transform=preprocess,
    split='test',
)
dataloader_val = DataLoader(
    dataset_oxfordpets_tst,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [None]:
# Orthogonality score of the per-concept maps over `dataloader_val`.
#
# For every image, the first `num_concepts` descriptors of its class are encoded
# with model.encode_text and multiplied with the output of prs.finalize, giving
# one map per descriptor. `loss_orth` measures the mean squared cosine similarity
# between those maps; the cell displays 1 - (mean loss), so a higher value means
# the maps are closer to mutually orthogonal.
num_concepts = 5  # descriptors kept per class

total_loss = 0.0  # sum of per-batch losses, each weighted by its batch size
num_samples = 0
with torch.no_grad():
    for images, labels in tqdm(dataloader_val):
        # Class name of every sample in the batch.
        texts = np.array(label_to_classname)[labels].tolist()

        # One token tensor per sample. Stacking requires every class in the
        # batch to supply the same number of descriptors.
        tokenized_concepts_list = torch.stack(
            [tokenizer(gpt_descriptions[text][:num_concepts]) for text in texts]
        )

        images = images.to(device)
        prs.reinit()
        representation = model.encode_image(
            images, attn_method="head", normalize=False
        )
        attentions = prs.finalize(representation)

        # Merge the (sample, concept) axes so the text encoder sees one flat
        # batch, then restore the per-sample grouping. Sizes are taken from the
        # tensors instead of hard-coding the context length and embedding width.
        node_text_embeddings = model.encode_text(
            tokenized_concepts_list.flatten(0, 1).to(device)
        )
        node_text_embeddings = node_text_embeddings.reshape(
            len(images), -1, node_text_embeddings.shape[-1]
        )

        # Per sample: skip index 0 of the token axis (presumably the class
        # token - TODO confirm), sum over the axes on either side of it, and
        # project onto each concept embedding. Result per sample: [concepts, tokens].
        attentions_maps = torch.stack([
            (attention[:, 1:, :].sum(axis=(0, 2)) @ text_embedding.T).permute(1, 0)
            for attention, text_embedding in zip(attentions, node_text_embeddings)
        ])

        orth_loss = loss_orth(attentions_maps)

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += orth_loss.item() * len(images)
        num_samples += len(images)

if num_samples == 0:
    raise ValueError("dataloader_val yielded no batches")

# Average over the samples actually processed. (Dividing by the last
# enumerate() index, as before, used one less than the number of batches.)
avg_orth = total_loss / num_samples
1 - avg_orth

In [20]:
# Select the Food101 label -> class-name table and descriptor file, then
# rebuild the per-class descriptor lookup used by the evaluation cell.
label_to_classname = label_to_classname_food101
hparams['descriptor_fname'] = './descriptors/my_food101.json'
gpt_descriptions, unmodify_dict = load_gpt_descriptions(
    hparams,
    label_to_classname,
)

# 'test' split, run through the model's preprocessing transform.
# NOTE(review): `torchvision` itself is not imported in this notebook's import
# cell; presumably it arrives via `from load import *` - confirm.
dataset_food101_tst = torchvision.datasets.Food101(
    root=FOOD101_DIR,
    transform=preprocess,
    split='test',
)
dataloader_val = DataLoader(
    dataset_food101_tst,
    batch_size=batch_size,
    shuffle=True,
    num_workers=16,
    pin_memory=True,
)

In [None]:
# Orthogonality score of the per-concept maps over `dataloader_val`.
#
# For every image, the first `num_concepts` descriptors of its class are encoded
# with model.encode_text and multiplied with the output of prs.finalize, giving
# one map per descriptor. `loss_orth` measures the mean squared cosine similarity
# between those maps; the cell displays 1 - (mean loss), so a higher value means
# the maps are closer to mutually orthogonal.
num_concepts = 5  # descriptors kept per class

total_loss = 0.0  # sum of per-batch losses, each weighted by its batch size
num_samples = 0
with torch.no_grad():
    for images, labels in tqdm(dataloader_val):
        # Class name of every sample in the batch.
        texts = np.array(label_to_classname)[labels].tolist()

        # One token tensor per sample. Stacking requires every class in the
        # batch to supply the same number of descriptors.
        tokenized_concepts_list = torch.stack(
            [tokenizer(gpt_descriptions[text][:num_concepts]) for text in texts]
        )

        images = images.to(device)
        prs.reinit()
        representation = model.encode_image(
            images, attn_method="head", normalize=False
        )
        attentions = prs.finalize(representation)

        # Merge the (sample, concept) axes so the text encoder sees one flat
        # batch, then restore the per-sample grouping. Sizes are taken from the
        # tensors instead of hard-coding the context length and embedding width.
        node_text_embeddings = model.encode_text(
            tokenized_concepts_list.flatten(0, 1).to(device)
        )
        node_text_embeddings = node_text_embeddings.reshape(
            len(images), -1, node_text_embeddings.shape[-1]
        )

        # Per sample: skip index 0 of the token axis (presumably the class
        # token - TODO confirm), sum over the axes on either side of it, and
        # project onto each concept embedding. Result per sample: [concepts, tokens].
        attentions_maps = torch.stack([
            (attention[:, 1:, :].sum(axis=(0, 2)) @ text_embedding.T).permute(1, 0)
            for attention, text_embedding in zip(attentions, node_text_embeddings)
        ])

        orth_loss = loss_orth(attentions_maps)

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += orth_loss.item() * len(images)
        num_samples += len(images)

if num_samples == 0:
    raise ValueError("dataloader_val yielded no batches")

# Average over the samples actually processed. (Dividing by the last
# enumerate() index, as before, used one less than the number of batches.)
avg_orth = total_loss / num_samples
1 - avg_orth

In [22]:
# Select the SUN397 descriptor file and label -> class-name table, then rebuild
# the per-class descriptor lookup used by the evaluation cell.
hparams['descriptor_fname'] = './descriptors/my_sun397.json'
label_to_classname = label_to_classname_sun397
gpt_descriptions, unmodify_dict = load_gpt_descriptions(hparams, label_to_classname)

# `dataset_sun397` is not created in the visible cells of this notebook; it is
# expected to be in scope already (presumably via `from load import *` - TODO confirm).
#
# The split is drawn from a seeded generator so the held-out subset, and hence
# the reported score, is identical on every run; an unseeded random_split picks
# a different subset each time. NOTE(review): if another notebook trains on the
# complementary split, it must use the same seed - confirm.
num_trn = 100000
split_generator = torch.Generator().manual_seed(0)
dataset_sun397_trn, dataset_sun397_tst = torch.utils.data.random_split(
    dataset_sun397,
    # Remainder goes to the test subset (8754 for the 108754-image dataset the
    # original literal lengths add up to).
    [num_trn, len(dataset_sun397) - num_trn],
    generator=split_generator,
)
dataloader_trn = DataLoader(dataset_sun397_trn, batch_size, shuffle=True, num_workers=16, pin_memory=True)
dataloader_val = DataLoader(dataset_sun397_tst, batch_size, shuffle=True, num_workers=16, pin_memory=True)

In [None]:
# Orthogonality score of the per-concept maps over `dataloader_val`.
#
# For every image, the first `num_concepts` descriptors of its class are encoded
# with model.encode_text and multiplied with the output of prs.finalize, giving
# one map per descriptor. `loss_orth` measures the mean squared cosine similarity
# between those maps; the cell displays 1 - (mean loss), so a higher value means
# the maps are closer to mutually orthogonal.
num_concepts = 5  # descriptors kept per class

total_loss = 0.0  # sum of per-batch losses, each weighted by its batch size
num_samples = 0
with torch.no_grad():
    for images, labels in tqdm(dataloader_val):
        # Class name of every sample in the batch.
        texts = np.array(label_to_classname)[labels].tolist()

        # One token tensor per sample. Stacking requires every class in the
        # batch to supply the same number of descriptors.
        tokenized_concepts_list = torch.stack(
            [tokenizer(gpt_descriptions[text][:num_concepts]) for text in texts]
        )

        images = images.to(device)
        prs.reinit()
        representation = model.encode_image(
            images, attn_method="head", normalize=False
        )
        attentions = prs.finalize(representation)

        # Merge the (sample, concept) axes so the text encoder sees one flat
        # batch, then restore the per-sample grouping. Sizes are taken from the
        # tensors instead of hard-coding the context length and embedding width.
        node_text_embeddings = model.encode_text(
            tokenized_concepts_list.flatten(0, 1).to(device)
        )
        node_text_embeddings = node_text_embeddings.reshape(
            len(images), -1, node_text_embeddings.shape[-1]
        )

        # Per sample: skip index 0 of the token axis (presumably the class
        # token - TODO confirm), sum over the axes on either side of it, and
        # project onto each concept embedding. Result per sample: [concepts, tokens].
        attentions_maps = torch.stack([
            (attention[:, 1:, :].sum(axis=(0, 2)) @ text_embedding.T).permute(1, 0)
            for attention, text_embedding in zip(attentions, node_text_embeddings)
        ])

        orth_loss = loss_orth(attentions_maps)

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += orth_loss.item() * len(images)
        num_samples += len(images)

if num_samples == 0:
    raise ValueError("dataloader_val yielded no batches")

# Average over the samples actually processed. (Dividing by the last
# enumerate() index, as before, used one less than the number of batches.)
avg_orth = total_loss / num_samples
1 - avg_orth

In [24]:
# Select the Stanford Cars descriptor file and label -> class-name table, then
# rebuild the per-class descriptor lookup used by the evaluation cell.
hparams['descriptor_fname'] = './descriptors/my_stanfordcars.json'
label_to_classname = label_to_classname_stanfordcars
gpt_descriptions, unmodify_dict = load_gpt_descriptions(hparams, label_to_classname)

# `dataset_stanfordcars` is not created in the visible cells of this notebook;
# it is expected to be in scope already (presumably via `from load import *` -
# TODO confirm).
#
# The split is drawn from a seeded generator so the held-out subset, and hence
# the reported score, is identical on every run; an unseeded random_split picks
# a different subset each time. NOTE(review): if another notebook trains on the
# complementary split, it must use the same seed - confirm.
num_trn = 6000
split_generator = torch.Generator().manual_seed(0)
dataset_stanfordcars_trn, dataset_stanfordcars_tst = torch.utils.data.random_split(
    dataset_stanfordcars,
    # Remainder goes to the test subset (2144 for the 8144-image dataset the
    # original literal lengths add up to).
    [num_trn, len(dataset_stanfordcars) - num_trn],
    generator=split_generator,
)
dataloader_val = DataLoader(dataset_stanfordcars_tst, batch_size, shuffle=True, num_workers=16, pin_memory=True)
dataloader_trn = DataLoader(dataset_stanfordcars_trn, batch_size, shuffle=True, num_workers=16, pin_memory=True)

In [None]:
# Orthogonality score of the per-concept maps over `dataloader_val`.
#
# For every image, the first `num_concepts` descriptors of its class are encoded
# with model.encode_text and multiplied with the output of prs.finalize, giving
# one map per descriptor. `loss_orth` measures the mean squared cosine similarity
# between those maps; the cell displays 1 - (mean loss), so a higher value means
# the maps are closer to mutually orthogonal.
num_concepts = 5  # descriptors kept per class

total_loss = 0.0  # sum of per-batch losses, each weighted by its batch size
num_samples = 0
with torch.no_grad():
    for images, labels in tqdm(dataloader_val):
        # Class name of every sample in the batch.
        texts = np.array(label_to_classname)[labels].tolist()

        # One token tensor per sample. Stacking requires every class in the
        # batch to supply the same number of descriptors.
        tokenized_concepts_list = torch.stack(
            [tokenizer(gpt_descriptions[text][:num_concepts]) for text in texts]
        )

        images = images.to(device)
        prs.reinit()
        representation = model.encode_image(
            images, attn_method="head", normalize=False
        )
        attentions = prs.finalize(representation)

        # Merge the (sample, concept) axes so the text encoder sees one flat
        # batch, then restore the per-sample grouping. Sizes are taken from the
        # tensors instead of hard-coding the context length and embedding width.
        node_text_embeddings = model.encode_text(
            tokenized_concepts_list.flatten(0, 1).to(device)
        )
        node_text_embeddings = node_text_embeddings.reshape(
            len(images), -1, node_text_embeddings.shape[-1]
        )

        # Per sample: skip index 0 of the token axis (presumably the class
        # token - TODO confirm), sum over the axes on either side of it, and
        # project onto each concept embedding. Result per sample: [concepts, tokens].
        attentions_maps = torch.stack([
            (attention[:, 1:, :].sum(axis=(0, 2)) @ text_embedding.T).permute(1, 0)
            for attention, text_embedding in zip(attentions, node_text_embeddings)
        ])

        orth_loss = loss_orth(attentions_maps)

        # Weight by batch size so a smaller final batch does not skew the mean.
        total_loss += orth_loss.item() * len(images)
        num_samples += len(images)

if num_samples == 0:
    raise ValueError("dataloader_val yielded no batches")

# Average over the samples actually processed. (Dividing by the last
# enumerate() index, as before, used one less than the number of batches.)
avg_orth = total_loss / num_samples
1 - avg_orth