In [None]:
import os
import numpy as np

from imagebind import data
import torch
from imagebind.models import imagebind_model
from imagebind.models.imagebind_model import ModalityType
from imagebind.models.multimodal_preprocessors import SimpleTokenizer

import torchvision
import torch

from tqdm import tqdm

# Pick the GPU when PyTorch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Benchmarks this notebook knows how to load.
SUPPORTED_DATASETS = ('imagenet', 'food101', 'gtsrb', 'places365')

dataset = 'gtsrb'  # choose from SUPPORTED_DATASETS
base_dir = 'data'  # root folder containing the per-dataset subdirectories
imagebind_path = 'ImageBind/'  # ImageBind checkout; the BPE vocab file is read from here (keep the trailing '/')

# Validate now so a typo fails before the pretrained model is downloaded/loaded.
if dataset not in SUPPORTED_DATASETS:
    raise ValueError(f"Unknown dataset {dataset!r}; choose from {SUPPORTED_DATASETS}")

In [None]:
# Instantiate the pretrained ImageBind "huge" model in inference mode on `device`.
model = imagebind_model.imagebind_huge(pretrained=True)
model.eval()
model.to(device)

# BPE tokenizer using the vocabulary file shipped in the ImageBind checkout.
tokenizer = SimpleTokenizer(bpe_path=f"{imagebind_path}bpe/bpe_simple_vocab_16e6.txt.gz")

# Bicubic resize to 224, center-crop 224x224, convert to a tensor and normalize
# with the per-channel mean/std below.
imagebind_preprocess = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize(
            224, interpolation=torchvision.transforms.InterpolationMode.BICUBIC
        ),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)


def load_eval_dataset(name, root, transform):
    """Load the evaluation split of one supported benchmark.

    Args:
        name: one of 'imagenet', 'food101', 'gtsrb', 'places365'.
        root: directory containing the per-dataset subfolders.
        transform: image transform passed to the torchvision dataset.

    Returns:
        (classes, data_dir, images): the class names used to build text prompts,
        the dataset folder, and the torchvision dataset object.

    Raises:
        ValueError: if `name` is not a supported dataset.
    """
    if name == 'imagenet':
        from data.imagenet_labels import imagenet_classes as classes
        data_dir = f'{root}/ImageNet_Validation/'
        images = torchvision.datasets.ImageNet(data_dir, split='val', transform=transform)
    elif name == 'food101':
        data_dir = f'{root}/Food101/'
        images = torchvision.datasets.Food101(data_dir, split='test', transform=transform)
        # Food101's class names use underscores; use spaces in the text prompts.
        classes = [c.replace('_', ' ') for c in images.classes]
    elif name == 'places365':
        from data.places365_labels import places_classlist as classes
        data_dir = f'{root}/Places365/'
        images = torchvision.datasets.Places365(data_dir, split='val', transform=transform)
    elif name == 'gtsrb':
        from data.gtsrb_labels import gtsrb_classes as classes
        data_dir = f'{root}/GTSRB/'
        images = torchvision.datasets.GTSRB(data_dir, split='test', transform=transform)
    else:
        # Fail loudly: printing and continuing would leave `class_list` undefined
        # and crash later with an unrelated-looking NameError.
        raise ValueError(f'DATASET NOT IMPLEMENTED: {name!r}')
    return classes, data_dir, images


class_list, data_dir, all_images = load_eval_dataset(dataset, base_dir, imagebind_preprocess)

# One "a photo of a <class>." prompt per class; each tokenized prompt becomes one
# row of `tokens` (stacked along dim 0).
token_text = [f'a photo of a {x}.' for x in class_list]
num_base_classes = len(class_list)

tokens = torch.cat([tokenizer(t).unsqueeze(0).to(device) for t in token_text], dim=0)

## Standard test (no negative embeddings added)

In [None]:
# Set to False to reuse cached prediction files when they already exist.
retest = True

pred_dir = f'pred_files/imagebind/{dataset}'
gt_path = f'{pred_dir}/standard_gt_imagebind.npy'
cosine_path = f'{pred_dir}/standard_cosine_imagebind.npy'

if retest or not (os.path.isfile(gt_path) and os.path.isfile(cosine_path)):
    # Create the output folder; saving would otherwise fail on a fresh checkout.
    os.makedirs(pred_dir, exist_ok=True)
    loader = torch.utils.data.DataLoader(all_images, batch_size=64, num_workers=4)

    # The class prompts are identical for every batch, so embed them once.
    with torch.no_grad():
        text_embeddings = model({ModalityType.TEXT: tokens})[ModalityType.TEXT]

    all_cosine = []
    gt_labels = []
    for images, target in tqdm(loader):
        with torch.no_grad():
            # Respect the configured device instead of hard-coding .cuda().
            image_embeddings = model({ModalityType.VISION: images.to(device)})[ModalityType.VISION]
            # One row per image, one column per class prompt.
            logits_per_image = image_embeddings @ text_embeddings.T

        all_cosine += logits_per_image.cpu().tolist()
        gt_labels += target.tolist()

    gt_labels = np.array(gt_labels)
    all_cosine = np.array(all_cosine)

    np.save(gt_path, gt_labels)
    np.save(cosine_path, all_cosine)

## Testing Negative Embeddings (class prompts plus random-token or random-word distractor prompts)

In [None]:
# Set to False to reuse cached prediction files when they already exist.
retest = True

pred_dir = f'pred_files/imagebind/{dataset}'
os.makedirs(pred_dir, exist_ok=True)

# Alphabet and length distribution for the random "word" negatives.
all_letters = list('abcdefghijklmnopqrstuvwxyz')
lengths = [2, 3, 4, 5, 6, 7, 8]
# NOTE(review): the original weight list had six entries for seven lengths (summing
# to 0.95) and was passed positionally as `replace`, so it was silently ignored and
# lengths were drawn uniformly. The trailing 0.05 for length 8 is inferred so the
# weights sum to 1 -- confirm the intended distribution.
l_p = [0.01, 0.04, 0.2, 0.25, 0.25, 0.2, 0.05]

# Evaluate on the full dataset (not a leftover batch tensor from an earlier cell).
loader = torch.utils.data.DataLoader(all_images, batch_size=64, num_workers=4)

for neg_count in [10, 50, 100, 250, 1000, 2500]:
    # 'embedding': append rows of random token ids; 'word': append prompts built
    # from random letter strings.
    for neg_type in ['embedding', 'word']:
        other_type = f'{neg_count}_{neg_type}'
        gt_path = f'{pred_dir}/{other_type}_gt_imagebind.npy'
        cosine_path = f'{pred_dir}/{other_type}_cosine_imagebind.npy'
        if not retest and os.path.isfile(gt_path) and os.path.isfile(cosine_path):
            continue

        extra_words = []
        if neg_type == 'word':
            for i in range(neg_count):
                # Per-word seed keeps each random word reproducible without
                # touching NumPy's global random state.
                word_rng = np.random.RandomState(i)
                word_len = word_rng.choice(lengths, p=l_p)
                word = ''.join(word_rng.choice(all_letters, word_len))
                extra_words.append(f'a photo of a {word}.')

        tokens_other = torch.cat(
            [tokenizer(t).unsqueeze(0).to(device) for t in token_text + extra_words], dim=0
        )

        if neg_type == 'embedding':
            # Random token ids in [0, 49407), one row per negative and as wide as
            # the class prompt tokens; cast to the prompt tokens' dtype before
            # concatenating.
            torch.manual_seed(0)
            other_emb = torch.randint(49407, (neg_count, tokens.size(1)), dtype=torch.int32)
            tokens_other = torch.cat(
                (tokens_other, other_emb.to(device=device, dtype=tokens_other.dtype))
            )

        # The prompts are fixed for this configuration, so embed them once.
        with torch.no_grad():
            text_embeddings = model({ModalityType.TEXT: tokens_other})[ModalityType.TEXT]

        all_cosine = []
        gt_labels = []
        for images, target in tqdm(loader, desc=other_type):
            with torch.no_grad():
                image_embeddings = model({ModalityType.VISION: images.to(device)})[ModalityType.VISION]
                # One row per image; columns are the class prompts followed by the negatives.
                logits_per_image = image_embeddings @ text_embeddings.T

            all_cosine += logits_per_image.cpu().tolist()
            gt_labels += target.tolist()

        np.save(gt_path, np.array(gt_labels))
        np.save(cosine_path, np.array(all_cosine))
                