In [1]:
%load_ext autoreload
%autoreload 2
%pylab inline
%matplotlib inline

import os
os.chdir('../')
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

Populating the interactive namespace from numpy and matplotlib


In [2]:
import argparse
import torch
import json
import pickle
import random
import time
import collections
import tqdm
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, ConcatDataset
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from collections import ChainMap

from horovod import torch as hvd

from uniter_model.data import ImageLmdbGroup
from transformers.tokenization_bert import BertTokenizer

from dvl.options import default_params, add_itm_params, add_logging_params, parse_with_config
from dvl.data.itm import itm_fast_collate
from dvl.models.bi_encoder import BiEncoder, setup_for_distributed_mode, load_biencoder_checkpoint
from dvl.utils import print_args, num_of_parameters, is_main_process, get_model_encoded_vecs, retrieve_query, display_img
from dvl.trainer import build_dataloader, load_dataset
from dvl.indexer.faiss_indexers import DenseFlatIndexer

from GLOBAL_VARIABLES import PROJECT_FOLDER

    
def train_parser(parser):
    """Register every command-line option used by this notebook on ``parser``.

    Applies the shared dvl option groups (defaults, ITM, logging), in that
    order, then adds ``--teacher_checkpoint``.

    Returns:
        The same ``parser`` object, for chaining.
    """
    for register_options in (default_params, add_itm_params, add_logging_params):
        register_options(parser)
    parser.add_argument('--teacher_checkpoint', type=str, default=None, help="")
    return parser


def display_res(res, args, top_k=10, ncols=5):
    """Plot the top-ranked images of a retrieval result as a grid.

    Args:
        res: search result as returned by ``indexer.search_knn``; ``res[0][0]``
            is read as the ranked list of image ids for the (single) query.
        args: namespace whose ``img_meta`` maps an image id to a dict holding
            an ``'img_file'`` path.
        top_k: maximum number of images to draw (default 10, as before).
        ncols: images per grid row; must be positive (default 5, as before).

    Every grid cell is blanked, so cells beyond the available hits stay empty
    instead of showing bare axes.
    """
    names = list(res[0][0][:top_k])
    nrows = max(1, (top_k + ncols - 1) // ncols)  # ceil(top_k / ncols), at least one row
    # squeeze=False keeps `axes` 2-D even for a single row, so `.flat` always works;
    # the per-cell size matches the original 30x15 figure for a 2x5 grid.
    _, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 7.5 * nrows), squeeze=False)
    for ax in axes.flat:
        ax.set_axis_off()
        ax.margins(x=0, y=0)
    # Row-major fill, same order as the original i/j counters.
    for ax, name in zip(axes.flat, names):
        ax.imshow(mpimg.imread(args.img_meta[name]['img_file']))

In [3]:
# Evaluation switches:
#   data_name -- benchmark to use: 'flickr' or 'coco'
#   Full      -- True: use every split in 'txt_dbs'/'img_dbs';
#                False: use only the test split ('test_txt_db'/'test_img_db')
#   EVAL      -- True: run the recall@K loop in the indexing cell
data_name, Full, EVAL = 'coco', True, False

# Per-benchmark command line (parsed by parse_with_config further down) and
# DB locations. txt_dbs[i] is paired with img_dbs[i].
DATASET_CONFIGS = {
    'flickr': {
        'cmd': ('--config ./config/flickr30k_eval_config.json '
                '--biencoder_checkpoint  /good_models/flickr_two-stream-add/biencoder.last.pt '
                '--teacher_checkpoint /pretrain/uniter_teacher_flickr.pt '
                '--img_meta /db/meta/flickr_meta.json'),
        'txt_dbs': [
            "/db/itm_flickr30k_train_base-cased.db",
            "/db/itm_flickr30k_val_base-cased.db",
            "/db/itm_flickr30k_test_base-cased.db",
        ],
        'img_dbs': [
            "/img/flickr30k/",
            "/img/flickr30k/",
            "/img/flickr30k/",
        ],
        'test_txt_db': '/db/itm_flickr30k_test_base-cased.db',
        'test_img_db': '/img/flickr30k/',
    },
    'coco': {
        'cmd': ('--config ./config/coco_eval_config.json '
                '--biencoder_checkpoint  /good_models/coco_two-stream-add/biencoder.last.pt '
                '--teacher_checkpoint /pretrain/uniter_teacher_coco.pt '
                '--img_meta /db/meta/coco_meta.json'),
        'txt_dbs': [
            "/db/itm_coco_train_base-cased.db",
            "/db/itm_coco_restval_base-cased.db",
            "/db/itm_coco_val_base-cased.db",
            "/db/itm_coco_test_base-cased.db",
        ],
        # NOTE(review): two coco_val2014 entries lack the trailing '/' used
        # elsewhere; kept as-is, confirm the loader treats both forms alike.
        'img_dbs': [
            "/img/coco_train2014/",
            "/img/coco_val2014",
            "/img/coco_val2014/",
            "/img/coco_val2014/",
        ],
        'test_txt_db': '/db/itm_coco_test_base-cased.db',
        'test_img_db': '/img/coco_val2014',
    },
}

# Fail loudly on a typo instead of silently falling back to COCO.
if data_name not in DATASET_CONFIGS:
    raise ValueError(f'unknown data_name {data_name!r}; expected one of {sorted(DATASET_CONFIGS)}')

dataset_cfg = DATASET_CONFIGS[data_name]
cmd = dataset_cfg['cmd']
if Full:
    # Copies, so later edits to these lists cannot corrupt DATASET_CONFIGS.
    txt_dbs, img_dbs = list(dataset_cfg['txt_dbs']), list(dataset_cfg['img_dbs'])
    assert len(txt_dbs) == len(img_dbs), 'txt_dbs and img_dbs must be paired one-to-one'
else:
    txt_db, img_db = dataset_cfg['test_txt_db'], dataset_cfg['test_img_db']


# Parse the dataset-specific command line on top of the dvl option groups.
parser = train_parser(argparse.ArgumentParser())
args = parse_with_config(parser, cmd.split())

# Options safeguard: image tokens + text tokens + 2 must fit in 512 positions.
# The image-side budget checked is max_bb when conf_th == -1, otherwise num_bb.
n_img_tokens = args.max_bb if args.conf_th == -1 else args.num_bb
assert n_img_tokens + args.max_txt_len + 2 <= 512

# Horovod / CUDA device selection.
hvd.init()
gpu_index = hvd.local_rank()
torch.cuda.set_device(gpu_index)
args.device = torch.device("cuda", gpu_index)
# NOTE(review): stores the *global* Horovod rank under the name `local_rank`;
# identical on a single node -- confirm what downstream dvl code expects.
args.local_rank = hvd.rank()
args.n_gpu = hvd.size()
# Dimension given to DenseFlatIndexer; presumably must equal the embedding
# size produced by the bi-encoder (BERT-base hidden size is 768) -- confirm
# when project_dim is set.
args.vector_size = 768
args.tokenizer = BertTokenizer.from_pretrained(args.txt_model_config)
print_args(args)

# Load the image metadata dict (this notebook reads its 'img_file' and
# 'annotation' fields). NOTE(review): it is read from args.itm_global_file and
# replaces whatever args.img_meta held (presumably the --img_meta path from cmd).
with open(args.itm_global_file) as meta_file:
    args.img_meta = json.load(meta_file)

# Embedding cache stored next to the checkpoint, e.g. <ckpt_dir>/coco.full.pkl.
split_tag = 'full' if Full else 'test'
EMBEDDED_FILE = os.path.join(os.path.dirname(args.biencoder_checkpoint), f'{data_name}.{split_tag}.pkl')

# Build the two-tower bi-encoder and load its trained weights.
bi_encoder = BiEncoder(args, args.fix_img_encoder, args.fix_txt_encoder, project_dim=args.project_dim)
load_biencoder_checkpoint(bi_encoder, args.biencoder_checkpoint)

# Move both towers to the selected GPU before setup_for_distributed_mode runs
# (the log below shows an AMP O1 setup being printed once per call).
img_model, txt_model = bi_encoder.img_model, bi_encoder.txt_model
for tower in (img_model, txt_model):
    tower.to(args.device)


def _prepare_for_inference(model):
    """Pass `model` through setup_for_distributed_mode (no optimizer, local_rank -1) and set eval mode."""
    prepared, _ = setup_for_distributed_mode(model, None, args.device, args.n_gpu, -1, args.fp16, args.fp16_opt_level)
    prepared.eval()
    return prepared


# Image tower first, then text tower (same order as before).
# NOTE(review): only the two towers are put in eval mode; later cells also call
# bi_encoder directly -- consider bi_encoder.eval() if it owns other layers.
img_model = _prepare_for_inference(img_model)
txt_model = _prepare_for_inference(txt_model)

loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /home/jjteam/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1
 **************** CONFIGURATION **************** 
biencoder_checkpoint           -->   /good_models/coco_two-stream-add/biencoder.last.pt
caption_score_weight           -->   0.0
cls_concat                     -->   
compressed_db                  -->   False
conf_th                        -->   0.2
config                         -->   ./config/coco_eval_config.json
device                         -->   cuda:0
expr_name_prefix               -->   
fix_img_encoder                -->   False
fix_txt_encoder                -->   False
fp16                           -->   True
fp16_opt_level                 -->   O1
gradient_accumulation_steps    -->   1
hard_negatives_sampling        -->   none
hnsw_index                

****************************************************************************************************
loading txt model
loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at /home/jjteam/.cache/torch/transformers/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.9da767be51e1327499df13488672789394e2ca38b877837e52618a67d7002391
Model config {
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": false,

Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights       

BertEncoder(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(28996, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)


In [4]:
if not os.path.isfile(EMBEDDED_FILE):
    # Cache miss: build the dataset(s), encode everything once, pickle the result.
    print('embedded file', EMBEDDED_FILE, 'not exist, creating one...')
    all_img_dbs = ImageLmdbGroup(args.conf_th, args.max_bb, args.min_bb, args.num_bb, args.compressed_db)

    if Full:
        # All splits; load_dataset is called with True as its 5th argument here
        # and the result exposes `.datasets` (one per split).
        dataset = load_dataset(all_img_dbs, txt_dbs, img_dbs, args, True)
        for d in dataset.datasets:
            d.new_epoch()
        txt_db_folders = txt_dbs
    else:
        dataset = load_dataset(all_img_dbs, txt_db, img_db, args, is_train=False)
        dataset.new_epoch()
        txt_db_folders = [txt_db]
    dataloader = build_dataloader(dataset, itm_fast_collate, False, args, batch_size=512)

    # Merge every DB's img2txts.json; on duplicate keys the earlier DB wins (ChainMap).
    # Files are closed explicitly instead of leaking the handles.
    img2txt_parts = []
    for db_folder in txt_db_folders:
        with open(os.path.join(db_folder, 'img2txts.json')) as fh:
            img2txt_parts.append(json.load(fh))
    img2txt = dict(collections.ChainMap(*img2txt_parts))

    print(f'dataset len = {len(dataset)}, dataloader len = {len(dataloader)}')

    # Inference only: no autograd graph is needed for the embeddings.
    with torch.no_grad():
        embeds = get_model_encoded_vecs(bi_encoder, dataloader)

    # Write to a temp file, then rename, so an interrupted dump cannot leave a
    # truncated cache that the `else` branch would pick up on the next run.
    tmp_file = EMBEDDED_FILE + '.tmp'
    with open(tmp_file, 'wb') as f:
        pickle.dump(embeds, f)
    os.replace(tmp_file, EMBEDDED_FILE)
else:
    print('embedded file found, loading...')
    # pickle.load can execute arbitrary code: only load caches this notebook wrote.
    with open(EMBEDDED_FILE, 'rb') as f:
        embeds = pickle.load(f)

embedded file found, loading...


In [5]:
# Index every image embedding; args.vector_size must match their dimension.
indexer = DenseFlatIndexer(args.vector_size)
indexer.index_data(list(embeds['img_embed'].items()))
all_keys = indexer.index_id_to_db_id  # indexed image ids

if EVAL:
    # Text -> image recall@K: every caption of every indexed image is a query,
    # and it counts as a hit at K when its own image id is in the top-K results.
    recall = {1: 0, 5: 0, 10: 0}
    counter_query = 0
    for key in tqdm.tqdm(all_keys):
        for q in args.img_meta[key]['annotation']:
            res = retrieve_query(bi_encoder, q, indexer, args)
            retrieved = res[0][0]
            for top in recall:
                recall[top] += key in retrieved[:top]
            counter_query += 1

    if counter_query == 0:
        # Avoid ZeroDivisionError when no indexed image has any caption.
        print('no annotated queries found; recall is undefined')
    else:
        for top, hits in recall.items():
            print(f'R@{top}: {hits / counter_query:.4f} ({hits}/{counter_query})')

Total data indexed 123287


In [None]:
# Collect every caption containing `keyword`, then report how many matched
# and print the first ten of them.
keyword = 'generous'
searched_queries = [
    caption
    for key in tqdm.tqdm(all_keys)
    for caption in args.img_meta[key]['annotation']
    if keyword in caption
]

print(len(searched_queries))
for caption in searched_queries[:10]:
    print(caption)

In [None]:
query = 'blue girl boy ball'
print('query =', query)

# Tokenize as a batch of one, with an all-ones attention mask and 0..L-1 positions.
input_ids = torch.LongTensor(args.tokenizer.encode(query)).to(args.device).unsqueeze(0)
seq_len = input_ids.size(1)
attn_mask = torch.ones(seq_len, dtype=torch.long, device=args.device).unsqueeze(0)
pos_ids = torch.arange(seq_len, dtype=torch.long, device=args.device).unsqueeze(0)

# Inference only: skip autograd bookkeeping for the query encoding.
with torch.no_grad():
    _, query_vector, _ = bi_encoder.txt_model(input_ids, None, attn_mask, pos_ids)
res = indexer.search_knn(query_vector.detach().cpu().numpy(), 100)
display_res(res, args)

In [None]:
# Pick a random caption as the query: uniformly among indexed images with at
# least one annotation (same distribution as the old rejection loop, but it
# cannot spin forever when no image is annotated), then one of its captions.
annotated_keys = [key for key in all_keys if len(args.img_meta[key]['annotation']) > 0]
if not annotated_keys:
    raise RuntimeError('no indexed image has an annotation to sample a query from')
sample_key = random.choice(annotated_keys)
query = random.choice(args.img_meta[sample_key]['annotation'])

# Show the image whose caption was sampled.
img = mpimg.imread(args.img_meta[sample_key]['img_file'])
fig, ax = plt.subplots()
ax.set_axis_off()
ax.margins(x=0, y=0)
ax.imshow(img)
print('query =', query)

# Encode the caption as a batch of one and retrieve the nearest images.
input_ids = torch.LongTensor(args.tokenizer.encode(query)).to(args.device).unsqueeze(0)
seq_len = input_ids.size(1)
attn_mask = torch.ones(seq_len, dtype=torch.long, device=args.device).unsqueeze(0)
pos_ids = torch.arange(seq_len, dtype=torch.long, device=args.device).unsqueeze(0)
with torch.no_grad():  # inference only
    _, query_vector, _ = bi_encoder.txt_model(input_ids, None, attn_mask, pos_ids)
res = indexer.search_knn(query_vector.detach().cpu().numpy(), 100)
display_res(res, args)

In [None]:
# Re-run retrieval for the current `query` (set by one of the cells above).
print('query =', query)
input_ids = torch.LongTensor(args.tokenizer.encode(query)).to(args.device).unsqueeze(0)
seq_len = input_ids.size(1)
attn_mask = torch.ones(seq_len, dtype=torch.long, device=args.device).unsqueeze(0)
pos_ids = torch.arange(seq_len, dtype=torch.long, device=args.device).unsqueeze(0)
with torch.no_grad():  # inference only
    _, query_vector, _ = bi_encoder.txt_model(input_ids, None, attn_mask, pos_ids)
res = indexer.search_knn(query_vector.detach().cpu().numpy(), 100)
display_res(res, args)

In [None]:
# List the top hits for the last encoded `query_vector` (from a query cell
# above): print each id with its score, then show the image.
# `top` used to leak from the EVAL loop (undefined when EVAL is False); 10 is
# the value that loop left behind.
top = 10
res = indexer.search_knn(query_vector.detach().cpu().numpy(), 100)
for name, score in zip(res[0][0][:top], res[0][1][:top]):
    print(name, score)
    # Was display(args.imddg_meta, name): a typo'd attribute that raised
    # AttributeError. Draw the image the same way display_res does.
    fig, ax = plt.subplots()
    ax.set_axis_off()
    ax.imshow(mpimg.imread(args.img_meta[name]['img_file']))
    plt.show()
    print('='*100)