In [None]:
import os
import numpy as np

from transformers import AlignProcessor, AlignModel

import torchvision
import torch

from tqdm import tqdm

# Prefer the GPU when torch can see one; otherwise everything runs on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Name of the dataset to evaluate; it selects the image folder and class list
# that are loaded, and the sub-folder of pred_files/align/ the results go to.
dataset = 'places365' #choose from ['imagenet', 'food101', 'gtsrb', 'places365']
# Root directory under which the per-dataset image folders are looked up.
base_dir = 'data'

In [None]:
# The processor and the model are loaded from the same pretrained checkpoint.
ALIGN_CHECKPOINT = "kakaobrain/align-base"

align_processor = AlignProcessor.from_pretrained(ALIGN_CHECKPOINT)
# Move the weights to the selected device and switch to evaluation mode
# (nn.Module.eval() returns the module itself, so the calls can be chained).
align_model = AlignModel.from_pretrained(ALIGN_CHECKPOINT).to(device).eval()

# Optional torchvision preprocessing pipeline.
# NOTE(review): `transform` is not referenced by the cells visible here.
transform_steps = [
    # you can add other transformations in this list
    torchvision.transforms.ToTensor(),
]
transform = torchvision.transforms.Compose(transform_steps)

# Pick the evaluation split (`all_images`) and the human-readable class names
# (`class_list`) that belong to the configured `dataset`.
if dataset == 'imagenet':
    from data.imagenet_labels import imagenet_classes as class_list
    data_dir = f'{base_dir}/ImageNet_Validation/'
    all_images = torchvision.datasets.ImageNet(data_dir, split='val')

elif dataset == 'food101':
    data_dir = f'{base_dir}/Food101/'
    all_images = torchvision.datasets.Food101(data_dir, split='test')

    # Food101 ships its own class names; swap underscores for spaces so the
    # names read naturally inside the text prompts.
    class_list = [c.replace('_', ' ') for c in all_images.classes]

elif dataset == 'places365':
    from data.places365_labels import places_classlist as class_list

    data_dir = f'{base_dir}/Places365/'
    all_images = torchvision.datasets.Places365(data_dir, split='val')

elif dataset == 'gtsrb':
    from data.gtsrb_labels import gtsrb_classes as class_list

    data_dir = f'{base_dir}/GTSRB/'
    all_images = torchvision.datasets.GTSRB(data_dir, split='test')

else:
    # Fail here with a clear message. Merely printing would let execution
    # continue and crash later with a NameError on `class_list`.
    raise ValueError(
        f"DATASET NOT IMPLEMENTED: {dataset!r} "
        "(choose from ['imagenet', 'food101', 'gtsrb', 'places365'])"
    )

# Number of real (in-distribution) classes for the selected dataset.
num_base_classes = len(class_list)

# One text prompt per class name, in the same order as `class_list`.
token_text = [f'a photo of a {x}.' for x in class_list]

# Tokenise the class prompts once and keep the result for later use.
with torch.no_grad():
    align_text = align_processor(text=token_text, return_tensors="pt")

## Standard test (no negative embeddings added)

In [None]:
# When False, the cell is skipped if both output files already exist.
retest = True
batch_size = 125
if dataset == 'gtsrb':
    batch_size = 30

# Output locations for the ground-truth labels and the per-image scores.
pred_dir = f'pred_files/align/{dataset}'
gt_path = f'{pred_dir}/standard_gt_align.npy'
cosine_path = f'{pred_dir}/standard_cosine_align.npy'

if retest or not (os.path.isfile(gt_path) and os.path.isfile(cosine_path)):
    # Create the output folder before the long scoring loop, so a missing
    # directory cannot make np.save fail after all the work has been done.
    os.makedirs(pred_dir, exist_ok=True)

    all_cosine = []
    gt_labels = []
    num_images = len(all_images)

    # The tokenised class prompts are identical for every batch, so they are
    # moved to the device once instead of once per batch.
    text_inputs = {k: align_text[k].to(device) for k in align_text.keys()}

    for i in tqdm(range(0, num_images, batch_size)):
        ims = []
        target = []
        # Clamp the upper bound so a final, shorter batch does not index past
        # the end of the dataset when its size is not a multiple of batch_size.
        for j in range(i, min(i + batch_size, num_images)):
            im, t = all_images[j]
            ims.append(im)
            target.append(t)

        with torch.no_grad():
            inputs = align_processor(images=ims, return_tensors="pt")

            inputs['pixel_values'] = inputs['pixel_values'].to(device)
            for k, v in text_inputs.items():
                inputs[k] = v

            outputs = align_model(**inputs)

            # Text-vs-image similarity scaled by the model temperature, then
            # transposed so that each row corresponds to one image.
            # NOTE(review): stored as "cosine"; assumes the model returns
            # L2-normalised embeddings -- confirm.
            logits_per_text = torch.matmul(outputs.text_embeds, outputs.image_embeds.t()) / align_model.temperature
            logits_per_image = logits_per_text.t()
            del logits_per_text, outputs

            all_cosine += logits_per_image.cpu().tolist()
            gt_labels += target

    gt_labels = np.array(gt_labels)
    all_cosine = np.array(all_cosine)

    with open(gt_path, 'wb') as f:
        np.save(f, gt_labels)

    with open(cosine_path, 'wb') as f:
        np.save(f, all_cosine)

## Testing Negative Embeddings

In [None]:
# When False, configurations whose two output files already exist are skipped.
retest = True

batch_size = 125
if dataset == 'gtsrb':
    batch_size = 30

# Create the output folder before any scoring, so a missing directory cannot
# make np.save fail after a long loop has finished.
pred_dir = f'pred_files/align/{dataset}'
os.makedirs(pred_dir, exist_ok=True)

num_images = len(all_images)

# Alphabet and candidate lengths for the random "words" used as negatives.
all_letters = list('abcdefghijklmnopqrstuvwxyz')
lengths = [2, 3, 4, 5, 6, 7, 8]
# Width of the random negative embeddings; it has to match the width of
# outputs.text_embeds so the two can be concatenated.
text_dims = 640

for neg_count in [10, 50, 100, 250, 1000, 2500]:
    # Named `neg_type` (not `type`) so the builtin is not shadowed.
    for neg_type in ['embedding', 'word']:
        other_type = f'{neg_count}_{neg_type}'
        gt_path = f'{pred_dir}/{other_type}_gt_align.npy'
        cosine_path = f'{pred_dir}/{other_type}_cosine_align.npy'

        if not retest and os.path.isfile(gt_path) and os.path.isfile(cosine_path):
            continue

        # --- 'word' negatives: prompts built from random letter strings ---
        extra_words = []
        if neg_type == 'word':
            for i in range(neg_count):
                # Word i is always drawn from seed i, so it does not depend on
                # neg_count. A private RandomState is used so the global NumPy
                # RNG state is left untouched.
                word_rng = np.random.RandomState(i)

                # NOTE(review): the length is drawn uniformly. The original
                # passed a weight list [0.01, 0.04, 0.2, 0.25, 0.25, 0.2] as
                # the third positional argument of np.random.choice, which is
                # `replace`, not `p`, so the weights were never applied. They
                # also cannot be applied as written (6 weights for 7 lengths,
                # summing to 0.95). Supply a valid 7-element `p=` here if a
                # weighted draw is wanted.
                num = word_rng.choice(lengths, 1)
                letters = word_rng.choice(all_letters, num)
                word = ''.join(letters)
                extra_words.append(f'a photo of a {word}.')

        # Class prompts followed by the (possibly empty) list of negative prompts.
        with torch.no_grad():
            align_text_other = align_processor(text=token_text + extra_words, return_tensors="pt")

        # The text inputs are the same for every batch: move them to the device once.
        text_inputs = {k: align_text_other[k].to(device) for k in align_text_other.keys()}

        # --- 'embedding' negatives: random unit vectors appended to the text embeddings ---
        if neg_type == 'embedding':
            torch.manual_seed(0)
            # The vectors are drawn with NumPy, so NumPy has to be seeded for
            # them to be reproducible; seeding torch alone has no effect here.
            emb_rng = np.random.RandomState(0)
            other_emb = torch.Tensor(emb_rng.normal(0, 0.25, (neg_count, text_dims)))
            # L2-normalise each row, then place it on the same device as the model.
            other_emb = (other_emb / other_emb.norm(p=2, dim=-1, keepdim=True)).to(device)
        else:
            other_emb = None

        all_cosine = []
        gt_labels = []
        for i in tqdm(range(0, num_images, batch_size)):
            ims = []
            target = []
            # Clamp the upper bound so a final, shorter batch does not index
            # past the end of the dataset.
            for j in range(i, min(i + batch_size, num_images)):
                im, t = all_images[j]
                ims.append(im)
                target.append(t)

            with torch.no_grad():
                inputs = align_processor(images=ims, return_tensors="pt")

                inputs['pixel_values'] = inputs['pixel_values'].to(device)
                for k, v in text_inputs.items():
                    inputs[k] = v

                outputs = align_model(**inputs)

                if other_emb is not None:
                    text_embeds = torch.cat((outputs.text_embeds, other_emb))
                else:
                    text_embeds = outputs.text_embeds

                # Text-vs-image similarity scaled by the model temperature,
                # transposed so that each row corresponds to one image.
                logits_per_text = torch.matmul(text_embeds, outputs.image_embeds.t()) / align_model.temperature
                logits_per_image = logits_per_text.t()
                del logits_per_text, outputs, text_embeds

            all_cosine += logits_per_image.cpu().tolist()
            gt_labels += target

        gt_labels = np.array(gt_labels)
        all_cosine = np.array(all_cosine)

        with open(gt_path, 'wb') as f:
            np.save(f, gt_labels)

        with open(cosine_path, 'wb') as f:
            np.save(f, all_cosine)
                