In [None]:
import os
import numpy as np

import clip

import torchvision
import torch

from tqdm import tqdm

# Run on the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [None]:
# Root folder under which the per-dataset image directories are looked up.
base_dir = 'data'

# Benchmark to evaluate; choose from ['imagenet', 'food101', 'gtsrb', 'places365'].
dataset = 'gtsrb'

In [None]:
# Load CLIP ViT-B/32 together with its matching image preprocessing transform
# and switch the model to eval mode.
clip_model, clip_preprocess = clip.load("ViT-B/32", device=device)
clip_model.eval()

# Build the evaluation split (`all_images`) and the list of class names
# (`class_list`) for the selected dataset. Every dataset uses `clip_preprocess`
# as its image transform.
if dataset == 'imagenet':
    from data.imagenet_labels import imagenet_classes as class_list
    data_dir = f'{base_dir}/ImageNet_Validation/'
    all_images = torchvision.datasets.ImageNet(data_dir, split='val', transform=clip_preprocess)

elif dataset == 'food101':
    data_dir = f'{base_dir}/Food101/'
    all_images = torchvision.datasets.Food101(data_dir, split='test', transform=clip_preprocess)

    # Class names come from the dataset object; underscores are replaced with
    # spaces so they read naturally inside the text prompt.
    class_list = [c.replace('_', ' ') for c in all_images.classes]

elif dataset == 'places365':
    from data.places365_labels import places_classlist as class_list

    data_dir = f'{base_dir}/Places365/'
    all_images = torchvision.datasets.Places365(data_dir, split='val', transform=clip_preprocess)

elif dataset == 'gtsrb':
    from data.gtsrb_labels import gtsrb_classes as class_list

    data_dir = f'{base_dir}/GTSRB/'
    all_images = torchvision.datasets.GTSRB(data_dir, split='test', transform=clip_preprocess)

else:
    # Fail here with a clear message: merely printing would let the cell continue
    # and crash below with a NameError on `class_list`.
    raise ValueError(
        f"DATASET NOT IMPLEMENTED: {dataset!r} "
        "(choose from ['imagenet', 'food101', 'gtsrb', 'places365'])"
    )

# One text prompt per class, in class-index order.
token_text = [f'a photo of a {x}.' for x in class_list]
num_base_classes = len(class_list)

# Tokenized class prompts, moved to the same device as the model.
with torch.no_grad():
    clip_text = clip.tokenize(token_text).to(device)

## Standard test (no negative embeddings added)

In [None]:
# Set to True to recompute the predictions even when cached files already exist.
retest = True

# Output files of the standard evaluation (class prompts only).
pred_dir = f'pred_files/clip/{dataset}'
gt_path = f'{pred_dir}/standard_gt_clip.npy'
cosine_path = f'{pred_dir}/standard_cosine_clip.npy'

if retest or not os.path.isfile(gt_path) or not os.path.isfile(cosine_path):
    # Create the output folder up front; otherwise the save at the end fails on a
    # fresh checkout after the whole inference pass has already been paid for.
    os.makedirs(pred_dir, exist_ok=True)

    loader = torch.utils.data.DataLoader(all_images, batch_size=64, num_workers=4)

    all_cosine = []  # one row of per-class logits for every image
    gt_labels = []   # ground-truth class index for every image

    for images, target in tqdm(loader):
        # Follow `device` instead of hard-coding CUDA so the cell also runs on
        # CPU-only machines (the model and `clip_text` already live on `device`).
        image = images.to(device)

        with torch.no_grad():
            logits_per_image, logits_per_text = clip_model(image, clip_text)

        # NOTE(review): the values stored as "cosine" are the model's
        # `logits_per_image` output; in OpenAI CLIP these are presumably cosine
        # similarities multiplied by the learned logit scale -- confirm before
        # treating them as raw cosines.
        all_cosine += logits_per_image.cpu().tolist()
        # Labels are only recorded, so they never need to leave the CPU.
        gt_labels += target.tolist()

    gt_labels = np.array(gt_labels)
    all_cosine = np.array(all_cosine)

    with open(gt_path, 'wb') as f:
        np.save(f, gt_labels)

    with open(cosine_path, 'wb') as f:
        np.save(f, all_cosine)

## Testing Negative Embeddings

In [None]:
# Set to True to recompute the predictions even when cached files already exist.
retest = True

pred_dir = f'pred_files/clip/{dataset}'

# The loader iterates the *dataset* (`all_images`) and is identical for every
# configuration, so it is built once outside the loops.
loader = torch.utils.data.DataLoader(all_images, batch_size=64, num_workers=4)

# Alphabet and candidate word lengths used to build random nonsense words.
all_letters = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']
lengths = [2, 3, 4, 5, 6, 7, 8]
l_p = [0.01, 0.04, 0.2, 0.25, 0.25, 0.2]

for neg_count in [10, 50, 100, 250, 1000, 2500]:
    # 'word': append prompts built from random letter strings.
    # 'embedding': append rows of random token ids to the tokenized prompts
    #              (despite the name, these are token ids, not embedding vectors).
    for neg_type in ['embedding', 'word']:
        other_type = f'{neg_count}_{neg_type}'
        gt_path = f'{pred_dir}/{other_type}_gt_clip.npy'
        cosine_path = f'{pred_dir}/{other_type}_cosine_clip.npy'

        if not (retest or not os.path.isfile(gt_path) or not os.path.isfile(cosine_path)):
            continue  # both result files are cached and no retest was requested

        # Create the output folder before the expensive inference pass.
        os.makedirs(pred_dir, exist_ok=True)

        extra_words = []
        if neg_type == 'word':
            for word_idx in range(neg_count):
                # Re-seeding per word makes word k identical for every neg_count.
                np.random.seed(word_idx)

                # NOTE(review): the third positional argument of np.random.choice
                # is `replace`, not `p`, so `l_p` is NOT applied as sampling
                # weights here and lengths are drawn uniformly. `l_p` also has 6
                # entries for 7 lengths (sum 0.95), so it cannot be passed as `p`
                # as-is. The call is left unchanged so regenerated files match
                # earlier ones; decide on the intended distribution before
                # switching to `p=`.
                num = np.random.choice(lengths, 1, l_p)
                letters = np.random.choice(all_letters, num)
                word = ''.join(letters)
                extra_words.append(f'a photo of a {word}.')

        # Class prompts first, negative prompts (if any) after them, so the first
        # len(token_text) logit columns still correspond to the real classes.
        with torch.no_grad():
            clip_text_other = clip.tokenize(token_text + extra_words).to(device)

        if neg_type == 'embedding':
            text_dims = clip_text_other.size(1)

            # Fixed seed so the random token rows are reproducible across runs.
            torch.manual_seed(0)
            # Token ids are drawn from [0, 49407); 49407 is presumably tied to
            # CLIP's vocabulary size -- confirm.
            other_emb = torch.randint(49407, (neg_count, text_dims), dtype=torch.int32)
            # Match the tokenized prompts' device and dtype before concatenating
            # instead of assuming CUDA / int32.
            other_emb = other_emb.to(device=device, dtype=clip_text_other.dtype)
            clip_text_other = torch.cat((clip_text_other, other_emb))

        all_cosine = []  # per-image logits over class + negative prompts
        gt_labels = []   # ground-truth class index for every image

        for images, target in tqdm(loader):
            # Follow `device` so the cell also runs on CPU-only machines.
            image = images.to(device)

            with torch.no_grad():
                logits_per_image, logits_per_text = clip_model(image, clip_text_other)

            all_cosine += logits_per_image.cpu().tolist()
            # Labels are only recorded, so they never need to leave the CPU.
            gt_labels += target.tolist()

        gt_labels = np.array(gt_labels)
        all_cosine = np.array(all_cosine)

        with open(gt_path, 'wb') as f:
            np.save(f, gt_labels)

        with open(cosine_path, 'wb') as f:
            np.save(f, all_cosine)
                