In [18]:
import pandas as pd
import numpy as np
import torch
import random
import pickle
from os.path import join
# import sys
# sys.path.insert(0, '/Users/anna/Documents/Code/ads-official/')
# from bert_utils import *
from src.paths import MAIN_DIR_PATH, OUTPUT_FOLDER, TRADE_PATH, DATA_FOLDER
from src.utils import open_json
import json

## Performance on TRADE

### CLIP

In [2]:
clip_outs = pickle.load(open(OUTPUT_FOLDER / 'clip/clip-vit-large-patch14-336_clip_score_trade.pkl', "rb"))['emb_dict']
trade = pd.read_csv(TRADE_PATH)

In [3]:
logits = np.vstack([clip_outs[k]['logits'] for k in clip_outs])
n, _ = logits.shape
print(f'CLIP achieves an accuracy of: {round((logits.argmax(axis=1)==0).sum().item()/n,2)}')

CLIP achieves an accuracy of: 0.35


### ALIGN

In [4]:
align_outs = pickle.load(open(OUTPUT_FOLDER / 'align/align-base_align_score_trade.pkl', "rb"))['emb_dict']

In [5]:
logits = np.vstack([align_outs[k]['logits'] for k in clip_outs])
n, _ = logits.shape
print(f'ALIGN achieves an accuracy of: {round((logits.argmax(axis=1)==0).sum().item()/n,2)}')

ALIGN achieves an accuracy of: 0.28


### ALBEF

In [6]:
embs = pickle.load(open(OUTPUT_FOLDER / 'albef/albef_retrieval_albef_score_trade.pkl', "rb"))
images = np.vstack([embs['emb_dict'][k]['image_embedding'] for k in embs['emb_dict']])
ar = np.vstack([embs['emb_dict'][k]['text_embedding'][0,:] for k in embs['emb_dict']])
d1 = np.vstack([embs['emb_dict'][k]['text_embedding'][1,:] for k in embs['emb_dict']])
d2 = np.vstack([embs['emb_dict'][k]['text_embedding'][2,:] for k in embs['emb_dict']])

In [7]:
n = images.shape[0]
emb_dim = images.shape[1]
cands = np.stack([ar, d1, d2], axis=1)
print(cands.shape)
temperature = 0.0045
end_mat = (np.matmul(images.reshape(n, 1, emb_dim),
          np.transpose(cands, (0,2,1))).reshape(n, 3)/temperature)
print(end_mat.shape)
print(f'ALBEF achieves an accuracy of: {round((np.argmax(end_mat,axis=1)==0).sum()/n,2)}')

(300, 3, 256)
(300, 3)
ALBEF achieves an accuracy of: 0.33


### LiT

In [8]:
embs = torch.load(join(OUTPUT_FOLDER, 'lit/lit_outputs_trade'), map_location=torch.device('cpu'))
print(embs.keys())
images = embs['images'].to(torch.float32)
ar = embs['ar'].to(torch.float32)
d1 = embs['dist1'].to(torch.float32)
d2 = embs['dist2'].to(torch.float32)

dict_keys(['images', 'ar', 'dist1', 'dist2', 'model_checkpoint', 'temperature'])


In [9]:
n = images.size()[0]
emb_dim = images.size()[1]
cands = torch.stack((ar, d1, d2), dim=1)

end_mat = torch.bmm(images.view((n,1,emb_dim)), 
                    cands.transpose(1,2)).view(n,3)
print(end_mat.size())
print(f'LiT achieves an accuracy of: {round((end_mat.argmax(dim=1)==0).sum().item()/n,2)}')

torch.Size([300, 3])
LiT achieves an accuracy of: 0.31


## Performance on TRADE control

In [30]:
def get_variance(ar_mat, im_mat, control_dict):

    ar_mat = torch.tensor(ar_mat.tolist())
    im_mat = torch.tensor(im_mat.tolist())
    n, emb_dim = ar_mat.shape
    acc_arr = torch.zeros(10)

    cands = torch.zeros(10*n, 3, emb_dim)
    for i in range(10):

        d1, d2 = ar_mat[control_dict[f'split_{i+1}']['dist1']], ar_mat[control_dict[f'split_{i+1}']['dist2']]
        curr_cands = torch.stack((ar_mat, d1, d2), dim=1)
        cands[(n*i):(n*(i+1))] = curr_cands
        end_mat = torch.bmm(im_mat.view((n,1,emb_dim)), 
                            curr_cands.transpose(1,2)).view(n,3)
        acc_arr[i] = (end_mat.argmax(dim=1)==0).sum().item()/n

    end_mat = torch.bmm(im_mat.repeat(10,1).view((n*10,1,emb_dim)), 
                            cands.transpose(1,2)).view(n*10,3)
    rank = ((torch.argsort(end_mat, dim=1, descending=True)+1)[:,0]).to(torch.float32)
    print(f'Acc: {round(acc_arr.mean().item(), 2)} ({round(acc_arr.std().item(), 2)})')
    print(f'Rank: {round(rank.mean().item(), 2)} ({round(rank.std().item(), 2)})')
  
    
    # return round(acc_arr.mean().item(),2), round(acc_arr.std().item(),2), round(rank_arr.mean().item(),2), round(std_rank_arr.mean().item(),2)
        

In [28]:
control_dict = open_json(DATA_FOLDER / 'TRADE' / 'trade_control.json')

In [31]:
# CLIP
embs = pickle.load(open(OUTPUT_FOLDER /'clip/clip-vit-large-patch14-336_clip_score_trade.pkl', "rb"))['emb_dict']
images = np.stack([embs[k]['image_embedding'].flatten() for k in embs])
ar = np.stack([embs[k]['text_embedding'][0].flatten() for k in embs])
get_variance(ar, images, control_dict)


Acc: 0.98 (0.01)
Rank: 1.03 (0.21)


In [32]:
# ALIGN
embs = pickle.load(open(OUTPUT_FOLDER /'align/align-base_align_score_trade.pkl', "rb"))['emb_dict']
images = np.stack([embs[k]['image_embedding'].flatten() for k in embs])
ar = np.stack([embs[k]['text_embedding'][0].flatten() for k in embs])
get_variance(ar, images, control_dict)


Acc: 0.97 (0.01)
Rank: 1.04 (0.26)


In [33]:
# ALBEF
embs = pickle.load(open(OUTPUT_FOLDER / 'albef/albef_retrieval_albef_score_trade.pkl', "rb"))['emb_dict']
images = np.vstack([embs[k]['image_embedding'] for k in embs])
ar = np.vstack([embs[k]['text_embedding'][0,:] for k in embs])
get_variance(ar, images, control_dict)


Acc: 0.87 (0.01)
Rank: 1.19 (0.53)


In [35]:
# LiT

embs = torch.load(join(OUTPUT_FOLDER, 'lit/lit_outputs_trade'), map_location=torch.device('cpu'))

images = embs['images'].to(torch.float32)
ar = embs['ar'].to(torch.float32)
get_variance(ar, images, control_dict)


Acc: 0.82 (0.02)
Rank: 1.26 (0.6)
