In [15]:
import argparse
import json
import logging
import os
import random
from io import open
import math
import sys

from time import gmtime, strftime
from timeit import default_timer as timer

import numpy as np
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange

import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler
from torch.utils.data.distributed import DistributedSampler

from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
from pytorch_pretrained_bert import BertModel

from multimodal_bert.datasets import ConceptCapLoaderTrain, ConceptCapLoaderVal
from multimodal_bert.multi_modal_bert import BertForMultiModalPreTraining, BertConfig
import matplotlib.pyplot as plt


In [2]:
from types import SimpleNamespace

args = SimpleNamespace(validation_file="data/conceptual_caption/validation",
                       pretrained_weight= "save/3layer_4connection/pytorch_model_1.bin",
                       bert_model="bert-base-uncased",
                       output_dir='save',
                       config_file="config/3layer_4connection.json",
                       max_seq_length=36,
                       train_batch_size=10,
                       do_lower_case=True,
                       predict_feature=False,
                       seed=42,
                       num_workers=0,
                       from_pretrained=True,
                       baseline=False,
                       img_weight=1,
                      )

if args.baseline:
    from pytorch_pretrained_bert.modeling import BertConfig
    from multimodal_bert.bert import BertForMultiModalPreTraining
else:
    from multimodal_bert.multi_modal_bert import BertForMultiModalPreTraining, BertConfig


In [3]:
config = BertConfig.from_json_file(args.config_file)
tokenizer = BertTokenizer.from_pretrained(
    args.bert_model, do_lower_case=args.do_lower_case
)


validation_dataset = ConceptCapLoaderVal(
    args.validation_file,
    tokenizer,
    seq_len=args.max_seq_length,
    batch_size=args.train_batch_size,
    predict_feature=args.predict_feature,
    num_workers=args.num_workers,
)

if args.predict_feature:
    config.v_target_size = 2048
    config.predict_feature = True
else:
    config.v_target_size = 1601
    config.predict_feature = False
    
if args.from_pretrained:
    model = BertForMultiModalPreTraining.from_pretrained(args.pretrained_weight, config)
else:
    model = BertForMultiModalPreTraining(config)
    
model.eval()
model.cuda()

05/15/2019 21:38:28 - INFO - pytorch_pretrained_bert.tokenization -   loading vocabulary file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-vocab.txt from cache at /home/jiasen/.pytorch_pretrained_bert/26bc1ad6c0ac742e9b52263248f6d0f00068293b33709fae12320c0e35ccfbbb.542ce4285a40d23a559526243235df47c5f75c197f04f37d1a0c124c32c9a084


Loading from data/conceptual_caption/validation_feat_all.lmdb
[32m[0515 21:38:28 @format.py:92][0m Found 6130 entries in data/conceptual_caption/validation_feat_all.lmdb


05/15/2019 21:38:28 - INFO - multimodal_bert.multi_modal_bert -   loading archive file save/3layer_4connection/pytorch_model_1.bin
05/15/2019 21:38:28 - INFO - multimodal_bert.multi_modal_bert -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "bi_attention_type": 1,
  "bi_hidden_size": 1024,
  "bi_intermediate_size": 3072,
  "bi_num_attention_heads": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "predict_feature": false,
  "t_biattention_id": [
    9,
    10,
    11,
    12
  ],
  "type_vocab_size": 2,
  "v_attention_probs_dropout_prob": 0.1,
  "v_biattention_id": [
    0,
    1,
    2,
    3
  ],
  "v_feature_size": 2048,
  "v_hidden_act": "gelu",
  "v_hidden_dropout_prob": 0.1,
  "v_hidden_size": 2048,
  "v_initializer_range": 0.2,
  "v_intermediate_size": 3072,
  "v_num_attention_heads": 16,


model's option for predict_feature is  False


BertForMultiModalPreTraining(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (v_embeddings): BertImageEmbeddings(
      (image_location_embeddings): Linear(in_features=5, out_features=2048, bias=True)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1)
            )
            (output): BertSelfOutput(
       

In [16]:
for step, batch in enumerate(validation_dataset):
    batch = tuple(t.cuda() for t in batch)
    
    break
    
input_ids, input_mask, segment_ids, lm_label_ids, is_next, image_feat, image_loc, image_target, image_label, image_mask, image_ids = (
    batch
)

prediction_scores_t, prediction_scores_v, seq_relationship_score, all_attention_mask = model(
    input_ids,
    image_feat,
    image_target,
    image_loc,
    segment_ids,
    input_mask,
    image_mask,
    output_all_attention_masks=True
    )
all_attention_mask_t, all_attnetion_mask_v, all_attention_mask_c = all_attention_mask

In [17]:
# visualization of the attention mask

idx = 0
image_id = image_ids.long()[idx].item()
print(image_id)

image_path = 'data/conceptual_caption/validation/%d' %image_id


plt.imread(image_path)


1595581184


FileNotFoundError: [Errno 2] No such file or directory: 'data/conceptual_caption/validation/1595581184'

tensor([1595581184, 1506493312, 2841214720, 2703412992, 2227185920, 1465580928,
        3570136576, 1666482304, 1372388736, 3727110400], device='cuda:0')