## Similarity function comparison

In [None]:
import os
# Change the process working directory one level up (presumably from the
# notebook's folder to the repository root) so relative paths used below,
# such as 'data/...' and 'experiments/...', resolve; presumably also so local
# modules like `utils` and `similarity` can be imported.
# NOTE: this really changes the cwd; re-running this cell moves up again.
os.chdir("..")

import torch
from sentence_transformers import SentenceTransformer
from sklearn import metrics

import clip
import utils
import similarity

## Settings

In [None]:
# Names of functions in the local `similarity` module to compare.
# Other candidates: 'wpmi', 'rank_reorder'.
similarity_fns = ["soft_wpmi"]
# Probe datasets, each passed as `d_probe` to utils.save_activations.
d_probes = ['cifar100_train', 'imagenet_val', 'imagenet_broden',
            'broden']

clip_name = 'ViT-B/16'
target_name = 'deit-tiny-relu'
target_layer = 'head'
batch_size = 100
# Fall back to CPU so the notebook still runs on machines without CUDA
# (a hard-coded 'cuda' makes clip.load and every later cell fail there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
pool_mode = 'avg'
# Show the working directory after the chdir in the import cell.
print(os.getcwd())


In [None]:
# Text-embedding models passed to utils.get_cos_similarity below to score how
# close each predicted concept is to the neuron's ground-truth name.
model = SentenceTransformer('all-mpnet-base-v2')
clip_model, _ = clip.load(clip_name, device=device)

# Ground-truth names for the target layer's neurons. The target layer is the
# classification head, and the accuracy section scores neuron i against
# ImageNet label i (data/imagenet_labels.txt). So the reference names must be
# the ImageNet labels, not the Broden concept vocabulary (comparing Broden
# predictions against the Broden list itself is meaningless).
with open("data/imagenet_labels.txt", "r") as f:
    cls_id_to_name = f.read().split("\n")

# Cosine similarities

In [None]:
# Concept vocabulary from which each neuron's description is chosen.
concept_set = 'broden_labels_clean.txt'

with open('data/concept_sets/' + concept_set, 'r') as f:
    words = f.read().split('\n')

for similarity_fn in similarity_fns:
    # Resolve the function by name with getattr instead of eval(): same lookup
    # on the `similarity` module, without executing an arbitrary string.
    similarity_func = getattr(similarity, similarity_fn)
    for d_probe in d_probes:
        save_dir = 'experiments/{}_FINAL/{}/saved_activations'.format(target_name, d_probe)

        # Save activations for this probe dataset under save_dir.
        utils.save_activations(clip_name = clip_name, target_name = target_name, target_layers = [target_layer],
                               d_probe = d_probe, concept_set = concept_set, batch_size = batch_size,
                               device = device, pool_mode=pool_mode, save_dir = save_dir)

        target_save_name, clip_save_name, text_save_name = utils.get_save_names(
            clip_name = clip_name, target_name = target_name,
            target_layer = target_layer, d_probe = d_probe,
            concept_set = concept_set, pool_mode=pool_mode,
            save_dir = save_dir)

        similarities, target_feats = utils.get_similarity_from_activations(target_save_name, clip_save_name,
                                                                           text_save_name,
                                                                           similarity_func,
                                                                           device=device)

        # Best-matching concept per row. Assumes similarities is
        # (n_neurons, n_concepts), consistent with indexing `words` by it.
        clip_preds = [words[int(pred)] for pred in torch.argmax(similarities, dim=1)]

        clip_cos, mpnet_cos = utils.get_cos_similarity(clip_preds, cls_id_to_name, clip_model, model, device, batch_size)
        print("Similarity fn: {}, D_probe: {}".format(similarity_fn, d_probe))
        print("Clip similarity: {:.4f}, mpnet similarity: {:.4f}".format(clip_cos, mpnet_cos))

# Accuracies

In [None]:
def get_topk_acc(sim, k=5, num_classes=1000):
    """Return the top-k accuracy, in percent, of neuron-to-class predictions.

    Row ``i`` of ``sim`` scores neuron ``i`` against every concept. The row
    counts as correct when concept index ``i`` is among its ``k`` highest
    scores.

    Args:
        sim: 2-D tensor of similarities; only the first ``num_classes`` rows
            are scored.
        k: number of top-scoring concepts considered per row.
        num_classes: rows to evaluate. Defaults to 1000, the value this
            function previously hard-coded.

    Returns:
        Percentage (float) of evaluated rows whose own index is in their top-k.

    Raises:
        ValueError: if ``sim`` has fewer than ``num_classes`` rows.
    """
    if sim.shape[0] < num_classes:
        raise ValueError(
            "sim has {} rows, need at least {}".format(sim.shape[0], num_classes))
    rows = sim[:num_classes]
    top_ids = torch.topk(rows, k=k, dim=1).indices
    targets = torch.arange(num_classes, device=rows.device).unsqueeze(1)
    # topk indices are distinct per row, so a row matches at most once.
    correct = int((top_ids == targets).any(dim=1).sum())
    return (correct / num_classes) * 100

def get_correct_rank_mean_median(sim, num_classes=1000):
    """Return the mean and median 1-based rank of each row's own index.

    For row ``i`` the columns are ranked by descending similarity, and the
    rank of column ``i`` is recorded (1 means it scored highest).

    Args:
        sim: 2-D tensor with at least ``num_classes`` rows and columns.
        num_classes: rows to evaluate. Defaults to 1000, the value this
            function previously hard-coded.

    Returns:
        ``(mean, median)``. The median is the standard one: the average of
        the two middle ranks when ``num_classes`` is even. Previously the
        upper middle element was returned.

    Raises:
        ValueError: if ``sim`` has fewer than ``num_classes`` rows or columns.
    """
    import statistics

    if sim.shape[0] < num_classes or sim.shape[1] < num_classes:
        raise ValueError(
            "sim shape {} too small for {} classes".format(tuple(sim.shape), num_classes))
    rows = sim[:num_classes]
    order = torch.argsort(rows, dim=1, descending=True)
    targets = torch.arange(num_classes, device=rows.device).unsqueeze(1)
    # Each row of `order` is a permutation of the columns, so exactly one
    # position per row equals the row index; nonzero() yields them row by row.
    ranks = ((order == targets).nonzero()[:, 1] + 1).tolist()

    mean = sum(ranks) / len(ranks)
    median = statistics.median(ranks)
    return mean, median

def get_auc(sim, num_classes=1000):
    """Return the ROC AUC of row-wise max similarity as a correctness score.

    A row ``i`` is "correct" when its argmax column equals ``i``. The AUC
    measures how well the row's max similarity separates correct rows from
    incorrect ones.

    Args:
        sim: 2-D tensor with at least ``num_classes`` rows; only the first
            ``num_classes`` rows are used.
        num_classes: rows to evaluate. Defaults to 1000, the value this
            function previously hard-coded.

    Returns:
        The AUC from ``sklearn.metrics.roc_auc_score``.

    Raises:
        ValueError: if ``sim`` has fewer than ``num_classes`` rows. sklearn
            also raises ValueError when all rows are correct or all incorrect.
    """
    if sim.shape[0] < num_classes:
        raise ValueError(
            "sim has {} rows, need at least {}".format(sim.shape[0], num_classes))
    rows = sim[:num_classes].cpu()
    max_sim, preds = torch.max(rows, dim=1)
    correct = preds == torch.arange(num_classes)
    # The previous roc_curve() call computed fpr/tpr/thresholds that were
    # never used; only the AUC is needed.
    return metrics.roc_auc_score(correct.numpy(), max_sim.numpy())

In [None]:
# Use ImageNet class names as the concept set, so neuron i of the head is
# correctly described when concept i scores highest.
concept_set = 'imagenet_labels.txt'
with open('data/' + concept_set, 'r') as f:
    words = f.read().split('\n')  # not used in this cell's loop

for similarity_fn in similarity_fns:
    # getattr instead of eval(): same lookup on the `similarity` module.
    similarity_func = getattr(similarity, similarity_fn)
    for d_probe in d_probes:
        save_dir = 'experiments/{}_FINAL/{}/saved_activations'.format(target_name, d_probe)

        utils.save_activations(clip_name = clip_name, target_name = target_name, target_layers = [target_layer],
                               d_probe = d_probe, concept_set = concept_set, batch_size = batch_size,
                               device = device, pool_mode=pool_mode, save_dir = save_dir)

        target_save_name, clip_save_name, text_save_name = utils.get_save_names(
            clip_name = clip_name, target_name = target_name,
            target_layer = target_layer, d_probe = d_probe,
            concept_set = concept_set, pool_mode=pool_mode,
            save_dir = save_dir)

        similarities, target_feats = utils.get_similarity_from_activations(target_save_name, clip_save_name,
                                                                           text_save_name,
                                                                           similarity_func,
                                                                           device=device)

        # Report once; the previous version printed (and recomputed) this
        # block twice per configuration.
        top1 = get_topk_acc(similarities, k=1)
        top5 = get_topk_acc(similarities, k=5)
        print("Similarity fn: {}, D_probe: {}".format(similarity_fn, d_probe))
        print("Top 1 acc: {:.2f}%, Top 5 acc: {:.2f}%".format(top1, top5))

        # Optional extra metrics (disabled):
        #mean, median = get_correct_rank_mean_median(similarities)
        #print("Mean rank of correct class: {:.2f}, Median rank of correct class: {}".format(mean, median))
        #print("AUC: {:.4f}".format(get_auc(similarities)))

