In [1]:
import os
# Switch the working directory to the parent directory. Later cells open
# relative paths such as 'data/20k.txt' and 'data/imagenet_labels.txt', which
# are resolved from the new working directory (presumably the project-local
# `utils` / `similarity` imports below rely on it too -- TODO confirm).
# NOTE: the path is relative, so re-running this cell without restarting the
# kernel moves one more level up each time.
os.chdir("..")

import math
import numpy as np
import torch
import pandas as pd

import clip
import utils
import similarity

In [2]:
#Arguments
# Names of the CLIP model, the model being probed, and the probed layer.
clip_name = 'ViT-B/16'
target_name = 'resnet50'
target_layer = 'fc'
# Probing dataset; other values previously tried in this notebook:
# "cifar100_train", "imagenet_val", "broden", "imagenet_broden"
d_probe = 'imagenet_broden'

batch_size = 200
# Use the GPU when one is available, otherwise fall back to the CPU so the
# notebook does not fail on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pool_mode = 'avg'

# Directory name handed to the utils helpers via `save_dir`.
save_dir = 'saved_activations'
target_preprocess = utils.get_resnet_imagenet_preprocess()
similarity_fn = similarity.soft_wpmi #wpmi, rank_reorder

# Cosine similarities

In [3]:
concept_set = 'data/20k.txt'

# Keyword arguments accepted by both utils.save_activations and
# utils.get_save_names, collected once so the two calls cannot drift apart.
shared_args = dict(
    clip_name=clip_name,
    target_name=target_name,
    d_probe=d_probe,
    concept_set=concept_set,
    pool_mode=pool_mode,
    save_dir=save_dir,
)

utils.save_activations(
    target_layers=[target_layer],
    batch_size=batch_size,
    device=device,
    target_preprocess=target_preprocess,
    **shared_args,
)

save_names = utils.get_save_names(target_layer=target_layer, **shared_args)
target_save_name, clip_save_name, text_save_name = save_names

similarities, target_feats = utils.get_similarity_from_activations(
    target_save_name, clip_save_name, text_save_name, similarity_fn
)

# Split the concept file on '\n' only; `words` is later indexed with the
# argmax over the columns of `similarities`.
with open(concept_set, 'r') as f:
    concept_text = f.read()
words = concept_text.split('\n')

100%|██████████| 1000/1000 [00:01<00:00, 843.98it/s]


In [4]:
from sentence_transformers import SentenceTransformer

# Both models are passed to utils.get_cos_similarity in the next cell.
model = SentenceTransformer('all-mpnet-base-v2')
clip_model, _ = clip.load(clip_name, device=device)

# One list entry per '\n'-separated line of the labels file.
with open("data/imagenet_labels.txt", "r") as f:
    label_text = f.read()
cls_id_to_name = label_text.split("\n")

In [5]:
# For every row of `similarities`, take the column with the highest score and
# look up the corresponding entry of `words`.
top_concept_ids = torch.argmax(similarities, dim=1)
clip_preds = [words[concept_id] for concept_id in top_concept_ids.tolist()]

clip_cos, mpnet_cos = utils.get_cos_similarity(
    clip_preds, cls_id_to_name, clip_model, model, device, batch_size
)
print(f"CLIP-Dissect - Clip similarity: {clip_cos:.4f}, mpnet similarity: {mpnet_cos:.4f}")

CLIP-Dissect - Clip similarity: 0.7900, mpnet similarity: 0.5233


# Accuracies

In [6]:
concept_set = 'data/imagenet_labels.txt'

# Keyword arguments accepted by both utils.save_activations and
# utils.get_save_names, collected once so the two calls cannot drift apart.
shared_args = dict(
    clip_name=clip_name,
    target_name=target_name,
    d_probe=d_probe,
    concept_set=concept_set,
    pool_mode=pool_mode,
    save_dir=save_dir,
)

utils.save_activations(
    target_layers=[target_layer],
    batch_size=batch_size,
    device=device,
    target_preprocess=target_preprocess,
    **shared_args,
)

save_names = utils.get_save_names(target_layer=target_layer, **shared_args)
target_save_name, clip_save_name, text_save_name = save_names

similarities, target_feats = utils.get_similarity_from_activations(
    target_save_name, clip_save_name, text_save_name, similarity_fn
)

# Split the concept file on '\n' only, giving one list entry per line.
with open(concept_set, 'r') as f:
    concept_text = f.read()
words = concept_text.split('\n')

100%|██████████| 1000/1000 [00:00<00:00, 4374.03it/s]


In [7]:
from sklearn import metrics

def get_topk_acc(sim, k=5):
    """Top-k accuracy (in percent) where the correct column of row i is i.

    Args:
        sim: 2-D torch tensor of scores; row ``i`` counts as correct when
            column ``i`` is among its ``k`` highest-scoring columns.
        k: number of top-scoring columns inspected per row.

    Returns:
        Percentage (0-100, float) of rows whose own index is in their top-k.

    Raises:
        ValueError: if ``sim`` has no rows.
    """
    # Derive the row count from the input instead of assuming 1000 rows.
    n_rows = sim.shape[0]
    if n_rows == 0:
        raise ValueError("sim must contain at least one row")

    # One batched top-k over all rows instead of a Python loop per row.
    _, topk_ids = torch.topk(sim, k=k, dim=1)
    gtruth = torch.arange(n_rows, device=topk_ids.device).unsqueeze(1)
    # top-k indices within a row are distinct, so each row matches at most once.
    correct = int((topk_ids == gtruth).any(dim=1).sum())
    return (correct / n_rows) * 100

def get_correct_rank_mean_median(sim):
    """Mean and median 1-based rank of column i within row i of ``sim``.

    Each row is sorted in descending order; the rank of row ``i`` is the
    1-based position at which column index ``i`` appears.

    Args:
        sim: 2-D torch tensor of scores with at least as many columns as rows.

    Returns:
        ``(mean, median)`` where ``mean`` is a float and ``median`` is the
        element at index ``len(ranks) // 2`` of the sorted ranks (the upper
        median for an even number of rows, i.e. index 500 for 1000 rows).

    Raises:
        ValueError: if ``sim`` has no rows, or fewer columns than rows (some
            row index would then not exist as a column).
    """
    # Derive sizes from the input instead of assuming a 1000-row matrix.
    n_rows, n_cols = sim.shape[0], sim.shape[1]
    if n_rows == 0:
        raise ValueError("sim must contain at least one row")
    if n_cols < n_rows:
        raise ValueError("sim needs at least as many columns as rows")

    _, order = torch.sort(sim, dim=1, descending=True)
    gtruth = torch.arange(n_rows, device=order.device).unsqueeze(1)
    # Every row of `order` is a permutation of the column indices, so the row's
    # own index matches exactly once; argmax finds the position of that match.
    positions = (order == gtruth).int().argmax(dim=1)
    ranks = (positions + 1).tolist()

    mean = sum(ranks) / len(ranks)
    median = sorted(ranks)[len(ranks) // 2]
    return mean, median

def get_auc(sim):
    """ROC AUC of the per-row maximum score as a predictor of correctness.

    For every row the highest score and its column are taken; the row counts
    as correct when that column equals the row index. The maximum scores are
    then scored against these correct/incorrect labels with
    ``metrics.roc_auc_score``.

    Args:
        sim: 2-D torch tensor of scores (moved to the CPU before use).

    Returns:
        The value returned by ``metrics.roc_auc_score``.
    """
    max_sim, preds = torch.max(sim.cpu(), dim=1)
    # Ground truth is the row index; sized from the input rather than fixed at 1000.
    gtruth = torch.arange(sim.shape[0])
    correct = (preds == gtruth)
    # The previous roc_curve call was dropped: its outputs were never used.
    return metrics.roc_auc_score(correct.numpy(), max_sim.numpy())

In [8]:
# Top-1 / top-5 accuracy of the predicted label against the row index.
top1_acc = get_topk_acc(similarities, k=1)
print(f"CLIP-Dissect Top 1 acc:{top1_acc:.4f}")
top5_acc = get_topk_acc(similarities, k=5)
print(f"CLIP-Dissect Top 5 acc:{top5_acc:.4f}")

# Rank statistics of the correct class, followed by the AUC.
mean, median = get_correct_rank_mean_median(similarities)
print(f"Mean rank of correct class:{mean}, Median rank of correct class:{median}")
auc_value = get_auc(similarities)
print(f"AUC:{auc_value:.4f}")

CLIP-Dissect Top 1 acc:95.4000
CLIP-Dissect Top 5 acc:99.0000
Mean rank of correct class:1.194, Median rank of correct class:1
AUC:0.9166
