In [3]:
import argparse
import json
import logging
import os
import random
from io import open
import math
import sys

from time import gmtime, strftime
from timeit import default_timer as timer

import numpy as np
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange

import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
from pytorch_pretrained_bert import BertModel

from multimodal_bert.datasets import ConceptCapLoaderTrain, ConceptCapLoaderVal
from multimodal_bert.multi_modal_bert import BertForMultiModalPreTraining, BertConfig
import matplotlib.pyplot as plt
import PIL
%matplotlib inline  


05/23/2019 12:36:01 - INFO - multimodal_bert.multi_modal_bert -   Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex .


In [4]:
from types import SimpleNamespace

# Notebook configuration: mirrors the command-line arguments of the training script so the
# same model / dataset constructors can be reused interactively.
args = SimpleNamespace(
    validation_file="data/conceptual_caption/validation",
    pretrained_weight="save/3layer_4connection/pytorch_model_6.bin",
    bert_model="bert-base-uncased",
    output_dir='save',
    config_file="config/3layer_4connection.json",
    max_seq_length=36,
    train_batch_size=1,
    do_lower_case=True,
    predict_feature=False,
    seed=42,
    num_workers=0,
    from_pretrained=True,
    baseline=False,
    img_weight=1,
)

# Apply the declared seed (it was previously defined but never used), so any stochastic
# step downstream is reproducible across Restart & Run All.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# `baseline` switches between the plain BERT classes and the multi-modal ones.
if args.baseline:
    from pytorch_pretrained_bert.modeling import BertConfig
    from multimodal_bert.bert import BertForMultiModalPreTraining
else:
    from multimodal_bert.multi_modal_bert import BertForMultiModalPreTraining, BertConfig


In [5]:
# Model configuration. The visual prediction target is either the 2048-d region feature
# (same size as "v_feature_size" in the config) when predict_feature is on, or a
# 1601-way output otherwise (presumably detector object classes incl. background —
# TODO confirm).
config = BertConfig.from_json_file(args.config_file)
config.v_target_size = 2048 if args.predict_feature else 1601
config.predict_feature = bool(args.predict_feature)

tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

# Validation split of Conceptual Captions, tokenised with the tokenizer above.
validation_dataset = ConceptCapLoaderVal(
    args.validation_file,
    tokenizer,
    seq_len=args.max_seq_length,
    batch_size=args.train_batch_size,
    predict_feature=args.predict_feature,
    num_workers=args.num_workers,
)


05/23/2019 12:36:01 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/jiasen/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


Loading from data/conceptual_caption/validation_feat_all.lmdb
[32m[0523 12:36:01 @format.py:92][0m Found 6130 entries in data/conceptual_caption/validation_feat_all.lmdb


In [6]:
# Build the multi-modal model, either from the checkpoint at `args.pretrained_weight` or
# directly from `config` (presumably freshly initialised weights), then switch it to
# eval mode and move it to the GPU.
# NOTE: the recorded run raised "CUDA error: out of memory" in this cell; make sure the
# GPU has free memory before re-running.
model = (
    BertForMultiModalPreTraining.from_pretrained(args.pretrained_weight, config)
    if args.from_pretrained
    else BertForMultiModalPreTraining(config)
)

model.eval()
model.cuda()

05/23/2019 12:36:02 - INFO - multimodal_bert.multi_modal_bert -   loading archive file save/3layer_4connection/pytorch_model_6.bin
05/23/2019 12:36:02 - INFO - multimodal_bert.multi_modal_bert -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "bi_attention_type": 1,
  "bi_hidden_size": 1024,
  "bi_intermediate_size": 3072,
  "bi_num_attention_heads": 16,
  "fast_mode": false,
  "fixed_t_layer": 0,
  "fixed_v_layer": 0,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.2,
  "hidden_size": 768,
  "in_batch_pairs": false,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "predict_feature": false,
  "t_biattention_id": [
    9,
    10,
    11,
    12
  ],
  "type_vocab_size": 2,
  "v_attention_probs_dropout_prob": 0.1,
  "v_biattention_id": [
    0,
    1,
    2,
    3
  ],
  "v_feature_size": 2048,
  "v_hidden_act": "gelu",
  "v_hidden_dropout_prob": 0.2,
  "v_hidden_size": 2048,

model's option for predict_feature is  False


RuntimeError: CUDA error: out of memory

In [None]:
caption_path = "data/conceptual_caption/caption_val.json"
# Context manager so the caption file handle is closed once loaded.
with open(caption_path, 'r') as caption_file:
    captions = json.load(caption_file)

# How many validation examples to run through the model. The variables left over from the
# last iteration (img, image_loc, words, all_attention_mask_c, ...) are used by later cells.
num_examples_to_show = 3

for step, batch in enumerate(validation_dataset):
    # The last batch element holds the image ids; everything before it is moved to the GPU.
    image_ids = batch[-1]
    batch = tuple(t.cuda() for t in list(batch)[:-1])

    input_ids, input_mask, segment_ids, lm_label_ids, is_next, image_feat, image_loc, image_target, image_label, image_mask = (
        batch
    )

    # Inference only: without no_grad() autograd keeps every intermediate activation alive,
    # wasting GPU memory (this notebook already ran out of GPU memory once).
    with torch.no_grad():
        prediction_scores_t, prediction_scores_v, seq_relationship_score, all_attention_mask = model(
            input_ids,
            image_feat,
            image_loc,
            segment_ids,
            input_mask,
            image_mask,
            output_all_attention_masks=True,
        )
    all_attention_mask_t, all_attention_mask_v, all_attention_mask_c = all_attention_mask

    # Visualisation of the first example in the batch.
    image_id = image_ids[0]
    print(image_id)

    image_path = 'data/conceptual_caption/validation/%s' % image_id
    img = PIL.Image.open(image_path).convert('RGB')
    img = torch.tensor(np.array(img))
    plt.imshow(img)
    plt.show()
    # np.array(PIL image) is (rows, cols, channels): height comes first.
    height, width, _ = img.shape

    # Keep the first `num_box` boxes (presumably the unpadded ones, given the mask sum)
    # and their first four coordinates.
    num_box = image_mask.sum().item()
    image_loc = image_loc.squeeze(0)[:num_box, :4]

    words = tokenizer.convert_ids_to_tokens(input_ids.cpu().numpy()[0])
    print(words)

    attention_mask = all_attention_mask_c[0][0][0]
    print(attention_mask.shape)

    print(step)
    if step + 1 >= num_examples_to_show:
        break
            

    

In [7]:
# def bottomup_heatmap_image(img, att, spatial, cm=plt.get_cmap('viridis')):
#     global hi1
#     # color the background according to the colormap because everything's too small to see
#     img[:] = 255.
#     H, W, _ = img.shape
#     n_head, _ = att.shape
#     n_obj = spatial.shape[0]
#     att_img = torch.zeros(img.shape)
#     left =   (spatial[:, 0] * W).to(torch.int).clamp(0, W-1)
#     top =    (spatial[:, 1] * H).to(torch.int).clamp(0, H-1)
#     right =  (spatial[:, 2] * W).to(torch.int).clamp(0, W-1)
#     bottom = (spatial[:, 3] * H).to(torch.int).clamp(0, H-1)
# #     for i in range(n_head):
#     if True:
#     i = 1
#         for k in range(n_obj):
#             t, b, l, r = top[k], bottom[k], left[k], right[k]
#             # this version is just black and white without using a color map
#             att_img[t:b, l:r] += att[i,k] * img[t:b, l:r].to(att.dtype)
#     att_img = att_img[:, :, 0].detach().numpy() / 255
#     att_img = 255 * torch.tensor(cm(att_img))[:, :, :3]
#     return att_img.to(img.dtype)

In [8]:
def bottomup_heatmap_image(img, att, spatial, cm=None):
    """Render per-box attention weights as a colour-mapped heatmap the size of `img`.

    Every box k (row k of `spatial`) adds ``att[k]`` to all pixels it covers, so pixels
    covered by several boxes accumulate the sum of their weights. The accumulated map is
    passed through the colormap `cm` and scaled to 0-255.

    Args:
        img: image tensor of shape (H, W, C). Only its shape and dtype are used; the pixel
            values are ignored and the tensor is NOT modified.
        att: 1-D tensor of per-box weights; ``att[k]`` is read for each box k. May live on
            the GPU.
        spatial: tensor of shape (n_obj, >=4) whose first four columns are box corners
            (x1, y1, x2, y2), presumably normalised to [0, 1] since they are scaled by the
            image width/height. May live on the GPU.
        cm: matplotlib-style colormap (callable mapping an array to RGBA values).
            Defaults to ``plt.get_cmap('viridis')``.

    Returns:
        Tensor of shape (H, W, 3) with ``img.dtype``.
    """
    if cm is None:
        cm = plt.get_cmap('viridis')
    # Work on CPU copies: the accumulator is a CPU tensor and callers pass model outputs
    # that may still be on the GPU (mixing devices raises).
    att = att.detach().cpu()
    spatial = spatial.detach().cpu()

    H, W, _ = img.shape
    left = (spatial[:, 0] * W).to(torch.int).clamp(0, W - 1).tolist()
    top = (spatial[:, 1] * H).to(torch.int).clamp(0, H - 1).tolist()
    # right/bottom are exclusive slice ends, so W/H are valid values; clamping them to
    # W-1/H-1 would drop the last column/row of boxes reaching the image border.
    right = (spatial[:, 2] * W).to(torch.int).clamp(0, W).tolist()
    bottom = (spatial[:, 3] * H).to(torch.int).clamp(0, H).tolist()

    # Single-channel accumulator; the image itself is never written to.
    heat = torch.zeros(H, W)
    for k, (t, b, l, r) in enumerate(zip(top, bottom, left, right)):
        heat[t:b, l:r] += att[k]

    coloured = 255 * torch.tensor(cm(heat.numpy()))[:, :, :3]
    return coloured.to(img.dtype)

In [9]:
# Directly plot the co-attention weights of one caption token, one figure per
# co-attention layer in `all_attention_mask_c`:
#   top panel    - the full (rows x regions) weight matrix for that token
#                  (rows presumably attention heads — TODO confirm);
#   bottom panel - heatmap of the row holding the single strongest weight, painted onto
#                  the box layout of the image.
token_idx = 10  # which caption token to inspect (index into `words`)

for layer, layer_attention in enumerate(all_attention_mask_c):
    attention_mask = layer_attention[0][0]
    attention = attention_mask[:, token_idx, :]
    print(words[token_idx])

    fig, (ax1, ax2) = plt.subplots(2, 1)
    ax1.imshow(attention.cpu().detach().numpy())
    ax1.set_title('layer %d, token %r' % (layer, words[token_idx]))
    ax1.axis('off')

    # Row whose largest weight is the largest overall.
    max_over_head, _ = torch.max(attention, dim=1)
    _, idx = torch.max(max_over_head, dim=0)

    # Move inputs to the CPU (the model outputs live on the GPU) and pass a copy of `img`
    # so the helper cannot overwrite the image shown in earlier cells.
    aimg = bottomup_heatmap_image(img.clone(), attention[idx].detach().cpu(), image_loc.detach().cpu())
    ax2.imshow(aimg)
    ax2.axis('off')

    fig.tight_layout()
    plt.show()


NameError: name 'all_attention_mask_c' is not defined

In [None]:
def crop(img, spatial):
    """Cut one patch out of `img` for every box in `spatial`.

    Args:
        img: image tensor of shape (H, W, C), rows first, as produced by
            ``torch.tensor(np.array(pil_image))``.
        spatial: tensor of shape (n_boxes, >=4); columns 0-3 are box corners
            (x1, y1, x2, y2), presumably normalised to [0, 1] since they are scaled by
            the image width/height.

    Returns:
        List with one slice of `img` per box, each of shape (box_h, box_w, C).
        An empty `spatial` gives an empty list.
    """
    H, W, _ = img.shape
    left = (spatial[:, 0] * W).to(torch.int).clamp(0, W - 1).tolist()
    top = (spatial[:, 1] * H).to(torch.int).clamp(0, H - 1).tolist()
    # right/bottom are exclusive slice ends, so W/H are valid values.
    right = (spatial[:, 2] * W).to(torch.int).clamp(0, W).tolist()
    bottom = (spatial[:, 3] * H).to(torch.int).clamp(0, H).tolist()
    # Rows are indexed by y (top:bottom) and columns by x (left:right); the previous
    # version swapped the two axes.
    return [img[t:b, l:r] for l, t, r, b in zip(left, top, right, bottom)]


In [None]:
# crop the image based on the bounding box

# print(image_loc)
# patches = crop(img, image_loc)

# num = len(patches)
# for i, patch in enumerate(patches):
#     f = plt.figure()
#     ax1 = f.add_subplot(1,num, i+1)
#     plt.axis('off')
#     plt.imshow(patches[i])
#     ax2 = f.add_subplot(1,num, i+1)
#     plt.axis('off')
#     plt.imshow(patches[i])
#     # asp = np.diff(ax2.get_xlim())[0]/np.diff(ax2.get_ylim())[0]
#     # ax2.set_aspect(asp)

# #print(image_loc.shape[0])
# # figs, axses = plt.subplots(1, image_loc.shape[0])#(1, image_loc.size(0), sharex='col', sharey='row',
# #                         #gridspec_kw={'hspace': 0, 'wspace': 0})

# # for i, axs in enumerate(axses):
# #     axs.imshow(patches[i])
    