Modified from https://github.com/McGill-NLP/imagecode/blob/main/baselines/clip/evaluate_clip.py

## Evaluating Clip Non-Contrastive and Contrastive

In [None]:
!pip install torch torchvision 
!pip install git+https://github.com/openai/CLIP.git 


In [35]:
from tqdm import tqdm
import json
from collections import defaultdict
from glob import glob
import os
import numpy as np
import clip
import torch
from PIL import Image
from pathlib import Path
import statistics
import argparse

def encode_images(photos_batch):
    photos = [Image.open(photo_file) for photo_file in photos_batch]
    photos_preprocessed = torch.stack([preprocess(photo) for photo in photos]).to(device)

    with torch.no_grad():
        photos_features = model.encode_image(photos_preprocessed)
        photos_features /= photos_features.norm(dim=-1, keepdim=True)
    return photos_features.cpu().numpy()


def encode_text(search_query):
    with torch.no_grad():
        text_encoded = model.encode_text(clip.tokenize(search_query, truncate=True).to(device))
        text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
    return text_encoded.cpu().numpy()


def find_best_matches(text_features, photo_features):
    similarities = (photo_features @ text_features.T).squeeze(1)
    best_photo_idx = (-similarities).argsort()
    similarities = -similarities
    similarities.sort()
    return best_photo_idx, similarities

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'USING DEVICE: {device}')
model, preprocess = clip.load('ViT-B/16', device=device, jit=False )  # Must set jit=False for training

checkpoint = torch.load('/kaggle/input/imagecode-checkpoints/NOCONTRA_clip_best__36_4e-6_ViT-B16_--job_id_1258699.pt')
model.load_state_dict(checkpoint['model_state_dict'])
print(checkpoint['epoch'])
clip.model.convert_weights(model)  # Actually this line is unnecessary since clip by default already on float16
model.eval()

correct = 0
total = 0
vid_correct = 0
vid_total = 0
img_correct= 0
img_total = 0


img_dirs = '/kaggle/input/imagecode-simple/image-sets/image-sets'
descriptions = json.load(open('/kaggle/input/imagecode-simple/valid_simple.json', 'r'))

valid = []
for data in descriptions:
    valid.append((data['directory'], data['pos_idx'],data['neg_idx'], data['caption']))

results = defaultdict(dict)
for img_dir, pos_idx, neg_idx, text in tqdm(valid):
    text = [text]
    pos_idx = int(pos_idx)
    img_files = list((Path(img_dirs) / img_dir).glob("*.jpg"))
    img_files = sorted(img_files, key=lambda x: int(str(x).split('/')[-1].split('.')[0][3:]))
    img_files = [img_files[pos_idx], img_files[neg_idx]]
    
    images = [Image.open(photo_file) for photo_file in img_files]
    images = torch.stack([preprocess(photo) for photo in images]).to(device)
    text = clip.tokenize(text, truncate=True).to(device)
    with torch.no_grad():
        image_features = model.encode_image(images)
        text_features = model.encode_text(text)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logits = (image_features @ text_features.T).squeeze(1)
    
    pred = torch.argmax(logits).squeeze()


    if pred.item() == 0:
        correct += 1
    if 'open-images' in img_dir:
        img_total += 1
        if pred.item() == 0:
            img_correct += 1
    else:
        vid_total += 1
        if pred.item() == 0:
            vid_correct += 1     
    results[img_dir].update({f'raw_preds_{pos_idx}': logits.squeeze().tolist(), f'clip_pred_{pos_idx}': int(pred.item()) ,f'correct_{pos_idx}': 1 if pos_idx == pred else 0})


print('OVERALL ACC: ' + str(round(correct/len(valid),4)))
print('VIDEO ACC: ' + str(round(vid_correct/vid_total,4)))
print('IMG ACC: ' + str(round(img_correct/img_total,4)))
json.dump(results, open(f'nocontra-clip-valid-simple-data.json', 'w'), indent=2)

USING DEVICE: cuda
2


100%|██████████| 2302/2302 [01:49<00:00, 20.99it/s]

OVERALL ACC: 0.6338
VIDEO ACC: 0.5908
IMG ACC: 0.8209





## Evaluating Contextual and Temporal Contextual

Here we augment the evaluation_contextual.py in a similiar way that we did above. Instead of a classification task with 10 images, the problem is reduced to a binary classification.

In [5]:
!git clone https://github.com/McGill-NLP/imagecode.git

Cloning into 'imagecode'...
remote: Enumerating objects: 589, done.[K
remote: Counting objects: 100% (55/55), done.[K
remote: Compressing objects: 100% (18/18), done.[K
remote: Total 589 (delta 45), reused 37 (delta 37), pack-reused 534[K
Receiving objects: 100% (589/589), 18.64 MiB | 18.63 MiB/s, done.
Resolving deltas: 100% (290/290), done.


In [40]:
# inspired from: https://github.com/openai/CLIP/issues/83
# https://github.com/openai/CLIP/issues/83
import json
import os
import random
import wandb
import clip
from clip import model
import torch
from torch import autograd
import tqdm
from torch import nn, optim
from PIL import Image
from pathlib import Path
from collections import defaultdict
import sys
import argparse
print(os.getcwd())
from volta_src.config import BertConfig
from volta_src.embeddings import BertLayerNorm
from volta_src.encoders import GeLU
from extras import convert_sents_to_features, BertLayer


random.seed(10)
torch.manual_seed(10)

def find_best_matches(text_features, photo_features):
    similarities = (photo_features @ text_features.T).squeeze(1)
    best_photo_idx = (-similarities).argsort()
    similarities = -similarities
    similarities.sort()
    return best_photo_idx, similarities


def convert_models_to_fp32(model):
    for p in model.parameters():
        if p.grad is not None:
            p.data = p.data.float()
            p.grad.data = p.grad.data.float()

class ContextualCLIP(torch.nn.Module):
    def __init__(self, bert_config, args):
        super(ContextualCLIP, self).__init__()
        self.clip, self.preprocess = clip.load('ViT-B/16', device=device, jit=False)
        config = BertConfig.from_dict(bert_config)
        self.fusion = args.fusion
        if self.fusion == 'concat':
            hidden_size = 1024
        else:
            hidden_size = 512

        config.hidden_size =  hidden_size
        config.num_attention_heads = 8
        self.transformer = nn.ModuleList([BertLayer(config) for _ in range(args.transformer_layers)])
        self.transformer.cuda()
        self.prediction_layer = nn.Linear(config.hidden_size, 1).cuda()
        self.batch_size = 1
        self.logit_scale = float(args.logit_scale)
        self.frozen_clip = args.frozen_clip
        self.add_input = args.add_input
        self.positional = args.positional
        if args.positional:
            self.positional_emb = torch.nn.Embedding(10,hidden_size).cuda()

    def forward(self, images, text, pos_mask):
        if self.frozen_clip:
            with torch.no_grad():
                image_features = self.clip.encode_image(images)
                text_features = self.clip.encode_text(text)
        else:
            image_features = self.clip.encode_image(images)
            text_features = self.clip.encode_text(text)
        # normalized features
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        text_features = torch.cat(10 * [text_features])
        if self.fusion == 'concat':
            x = torch.cat((image_features, text_features), dim=1)
        else:
            x = (self.logit_scale * image_features) * text_features
        x_ = torch.unsqueeze(x,dim=0)
        if self.positional:
            embs = self.positional_emb(torch.arange(10).cuda())
            embs = embs * pos_mask
            x_pos = x_ + embs
        else:
            x_pos = x_
        attention_mask = torch.ones((self.batch_size,1,1,10)).cuda()
        x = self.transformer[0](x_pos, attention_mask)
        for layer_module in self.transformer[1:]:
            x = layer_module(x, attention_mask) #TODO: remove hard-coding of 10
        if self.add_input:
            x = x + x_
        preds = self.prediction_layer(x.half())
        return preds

    def encode_images(self, photos_batch):
        photos = [Image.open(photo_file) for photo_file in photos_batch]
        photos_preprocessed = torch.stack([self.preprocess(photo) for photo in photos]).to(device)

        with torch.no_grad():
            photos_features = self.clip.encode_image(photos_preprocessed)
            photos_features /= photos_features.norm(dim=-1, keepdim=True)
        return photos_features.cpu().numpy()

    def encode_text(self, search_query):
        with torch.no_grad():
            text_encoded = self.clip.encode_text(clip.tokenize(search_query, truncate=True).to(device))
            text_encoded /= text_encoded.norm(dim=-1, keepdim=True)
        return text_encoded.cpu().numpy()



ARGS = {
    'checkpoint': '/kaggle/input/imagecode-checkpoints/TEMP-CONTEXTUAL_clip_best__36_0.0001_2e-06_ViT-B16_mult_gelu_1000_True_True_True_True_1.0_1.0_2_1252131.pt',
    'test_descr_path': '/kaggle/input/imagecode-simple/valid_simple.json',
    'imgs_path': '/kaggle/input/imagecode-simple/image-sets/image-sets',
    'batchsize': 36,
    'fusion': 'mult',
    'activation': 'gelu',
    'logit_scale': 1000,
    'frozen_clip': True,
    'add_input': True,
    'positional': True,
    'head_scheduler': 1.0,
    'base_scheduler': 1.0,
    'transformer_layers': 2
}
args = argparse.Namespace(**ARGS)
assert args.fusion in ['concat', 'mult']
assert args.activation in ['leaky-relu', 'relu', 'gelu']


img_dirs = args.imgs_path
valid_data = json.load(open(args.test_descr_path, 'r'))
valid = []
for data in valid_data:
    valid.append((data['directory'], data['pos_idx'],data['neg_idx'], data['caption']))
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f'DEVICE USED: {device}')

bert_config = json.load(open('vilbert-and-bert-config.json', 'r'))
contextual_clip = ContextualCLIP(bert_config, args)
checkpoint = torch.load(args.checkpoint)
contextual_clip.load_state_dict(checkpoint['model_state_dict'])


if device == "cpu":
    contextual_clip.float()
else:
    clip.model.convert_weights(
        contextual_clip)  # Actually this line is unnecessary since clip by default already on float16


correct = 0
total = 0
vid_correct = 0
vid_total = 0
img_correct= 0
img_total = 0

results = defaultdict(dict)
for img_dir, pos_idx, neg_idx, text in tqdm.tqdm(valid):
    text = [text]
    img_idx = int(pos_idx)
    img_files = list((Path(img_dirs) / img_dir).glob("*.jpg"))
    img_files = sorted(img_files, key=lambda x: int(str(x).split('/')[-1].split('.')[0][3:]))
    # Only a binary case for simple, instead of changing anything, just repeat the same image 5 times to
    # maintain the original positional embeddings from the checkpoints provided
    img_files = [img_files[pos_idx]] * 5 + [img_files[neg_idx]] * 5 
    images = [Image.open(photo_file) for photo_file in img_files]
    images = torch.stack([contextual_clip.preprocess(photo) for photo in images]).to(device)
    text = clip.tokenize(text, truncate=True).to(device)
    if "open-images" in str(img_dir):
        pos_mask = torch.zeros((10,1)).cuda() # Change 10 to 1
    else:
        pos_mask = torch.ones((10,1)).cuda() # Change 10 to 1
    with torch.no_grad():
        logits = contextual_clip(images, text, pos_mask).squeeze()
    pred = torch.argmax(logits).squeeze()
    # Correct is always in the first half now
    if pred.item() < 5:
        correct += 1
    if 'open-images' in img_dir:
        img_total += 1
        if pred.item() < 5:
            img_correct += 1
    else:
        vid_total += 1
        if pred.item() < 5:
            vid_correct += 1     
    total += 1
    results[img_dir].update({f'raw_preds_{img_idx}': logits.squeeze().tolist(), f'clip_pred_{img_idx}': int(pred.item()) ,f'correct_{img_idx}': 1 if img_idx == pred else 0})

print('OVERALL ACC: ' + str(round(correct/len(valid),4)))
print('VIDEO ACC: ' + str(round(vid_correct/vid_total,4)))
print('IMG ACC: ' + str(round(img_correct/img_total,4)))

json.dump(results, open(f'CONTEXTUAL_valid_simple.json', 'w'), indent=2)

/kaggle/working/imagecode/baselines/clip
DEVICE USED: cuda


100%|██████████| 2302/2302 [06:22<00:00,  6.02it/s]

OVERALL ACC: 0.7033
VIDEO ACC: 0.6587
IMG ACC: 0.8977



