## Clone and Dataset

In [1]:
!git clone https://github.com/YiwuZhong/Sub-GC.git 

Cloning into 'Sub-GC'...
remote: Enumerating objects: 265, done.[K
remote: Counting objects: 100% (78/78), done.[K
remote: Compressing objects: 100% (71/71), done.[K
remote: Total 265 (delta 10), reused 44 (delta 6), pack-reused 187[K
Receiving objects: 100% (265/265), 125.17 MiB | 9.54 MiB/s, done.
Resolving deltas: 100% (36/36), done.


In [7]:
%cd Sub-GC

/home/jaleed/Jaleed/SubGC/Sub-GC


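Before running the cells below, manually copy the locally prepared assets into the cloned repository. This step is not automated by the notebook, and the source paths are specific to this machine, so adjust them to your environment:

- Copy all contents of '/home/jaleed/Jaleed/ModelsNDatasets/SubGC/data/' to 'data/'
- Copy all contents of '/home/jaleed/Jaleed/ModelsNDatasets/SubGC/pretrained/' to 'pretrained/'
- Copy all contents of '/home/jaleed/Jaleed/ModelsNDatasets/SubGC/misc/' to 'misc/'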

## Code Edits

In [10]:
%%writefile models/AttModel.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F
from misc.utils import obj_edge_vectors, load_word_vectors
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence, pad_packed_sequence
import numpy as np

from models.CaptionModel import CaptionModel
import models.lib.gcn_backbone as GBackbone
import models.lib.gpn as GPN

def sort_pack_padded_sequence(input, lengths):
    """
    Sort a batch-first padded tensor by decreasing length and pack it.

    Args:
        input: batch-first padded tensor; rows are reordered by length.
        lengths: 1-D tensor with the valid length of each row of `input`.

    Returns:
        (packed, inv_ix): the PackedSequence of the sorted rows, and the inverse
        permutation that maps the sorted order back to the original row order.
    """
    lengths_desc, order = lengths.sort(descending=True)
    packed = pack_padded_sequence(input[order], lengths_desc.cpu(), batch_first=True)
    # `order` is a permutation, so its argsort is its inverse (restore[order[j]] == j)
    restore = order.argsort()
    return packed, restore

def pad_unsort_packed_sequence(input, inv_ix):
    """
    Unpack a PackedSequence into a batch-first padded tensor and reorder its rows.

    Args:
        input: PackedSequence to unpack.
        inv_ix: row permutation applied after padding (e.g. the inverse permutation
            returned by sort_pack_padded_sequence).

    Returns:
        The padded tensor indexed by `inv_ix` along the batch dimension.
    """
    padded = pad_packed_sequence(input, batch_first=True)[0]
    return padded[inv_ix]

def pack_wrapper(module, att_feats, att_masks):
    """
    Apply `module` only to the valid positions of a padded batch.

    For batch computation, sequences of different lengths are packed, which fixes the
    batch size at each time step. When `att_masks` is given, each row's length is the
    sum of its mask along dim 1 (the number of valid positions for a 0/1 mask). The rows
    are packed in length order, `module` runs on the packed data, and the result is
    padded and restored to the original row order. Without masks, `module` is applied
    to `att_feats` directly.
    """
    if att_masks is None:
        return module(att_feats)

    seq_lengths = att_masks.data.long().sum(1)
    packed, restore_ix = sort_pack_padded_sequence(att_feats, seq_lengths)
    # packed[0] is the flat data, packed[1] the per-step batch sizes
    transformed = PackedSequence(module(packed[0]), packed[1])
    return pad_unsort_packed_sequence(transformed, restore_ix)

"""
Captioning model using image scene graph
"""
class AttModel(CaptionModel):
    def __init__(self, opt):
        super(AttModel, self).__init__()
        self.vocab_size = opt.vocab_size
        self.input_encoding_size = opt.input_encoding_size
        self.rnn_size = opt.rnn_size 
        self.num_layers = opt.num_layers  
        self.drop_prob_lm = opt.drop_prob_lm 
        self.seq_length = opt.max_length or opt.seq_length 
        self.fc_feat_size = opt.fc_feat_size
        self.att_feat_size = opt.att_feat_size 
        self.att_hid_size = opt.att_hid_size 
        self.use_bn = opt.use_bn 
        self.ss_prob = opt.sampling_prob 
        
        self.gpn = True if opt.use_gpn == 1 else False 
        self.embed_dim = opt.embed_dim 
        self.GCN_dim = opt.gcn_dim  
        self.noun_fuse = True if opt.noun_fuse == 1 else False  
        self.pred_emb_type = opt.pred_emb_type 
        self.GCN_layers = opt.gcn_layers 
        self.GCN_residual = opt.gcn_residual  
        self.GCN_use_bn = False if opt.gcn_bn == 0 else True   

        self.test_LSTM = False if getattr(opt, 'test_LSTM', 0) == 0 else True 
        self.topk_sampling = False if getattr(opt, 'use_topk_sampling', 0) == 0 else True
        self.topk_temp = getattr(opt, 'topk_temp', 0.6)
        self.the_k = getattr(opt, 'the_k', 3)
        self.sct = False if getattr(opt, 'sct', 0) == 0 else True # show-control-tell testing mode

        # feature fusion layer
        self.obj_v_proj = nn.Linear(self.att_feat_size, self.GCN_dim)
        object_names = np.load(opt.obj_name_path,encoding='latin1') # [0] is 'background'
        self.sg_obj_cnt = object_names.shape[0]
        if self.noun_fuse:
            embed_vecs = obj_edge_vectors(list(object_names), wv_dim=self.embed_dim)
            self.sg_obj_embed = nn.Embedding(self.sg_obj_cnt, self.embed_dim)
            self.sg_obj_embed.weight.data = embed_vecs.clone()
            self.obj_emb_proj = nn.Linear(self.embed_dim, self.GCN_dim)
            self.relu = nn.ReLU(inplace=True)
        predicate_names = np.load(opt.rel_name_path,encoding='latin1') # [0] is 'background'
        self.sg_pred_cnt = predicate_names.shape[0]
        p_embed_vecs = obj_edge_vectors(list(predicate_names), wv_dim=self.embed_dim)
        self.sg_pred_embed = nn.Embedding(predicate_names.shape[0], self.embed_dim)
        self.sg_pred_embed.weight.data = p_embed_vecs.clone()
        self.pred_emb_prj = nn.Linear(self.embed_dim, self.GCN_dim)

        # GCN backbone
        self.gcn_backbone = GBackbone.gcn_backbone(GCN_layers=self.GCN_layers, GCN_dim=self.GCN_dim, \
                                                   GCN_residual=self.GCN_residual, GCN_use_bn=self.GCN_use_bn)

        # GPN (sGPN)
        if self.gpn:
            self.gpn_layer = GPN.gpn_layer(GCN_dim=self.GCN_dim, hid_dim=self.att_hid_size, \
                                           test_LSTM=self.test_LSTM, use_nms=False if self.sct else True, \
                                           iou_thres=getattr(opt, 'gpn_nms_thres', 0.75), \
                                           max_subgraphs=getattr(opt, 'gpn_max_subg', 1), \
                                           use_sGPN_score=True if getattr(opt, 'use_gt_subg', 0) == 0 else False)
        else:
            self.read_out_proj = nn.Sequential(nn.Linear(self.GCN_dim, self.att_hid_size), nn.Linear(self.att_hid_size,self.GCN_dim*2))
            nn.init.constant_(self.read_out_proj[0].bias, 0)
            nn.init.constant_(self.read_out_proj[1].bias, 0)
        
        # projection layers in attention-based LSTM
        self.logit = nn.Linear(self.rnn_size, self.vocab_size + 1)
        self.embed = nn.Sequential(nn.Embedding(self.vocab_size + 1, self.input_encoding_size),
                                nn.ReLU(),
                                nn.Dropout(self.drop_prob_lm))
        self.fc_embed = nn.Sequential(nn.Linear(self.att_feat_size, self.fc_feat_size),
                                    nn.ReLU(),
                                    nn.Linear(self.fc_feat_size, self.rnn_size),
                                    nn.ReLU(),
                                    nn.Dropout(self.drop_prob_lm))
        self.att_embed = nn.Sequential(*(
                                    ((nn.BatchNorm1d(self.att_feat_size),) if self.use_bn else ())+
                                    (nn.Linear(self.GCN_dim, self.rnn_size),
                                    nn.ReLU(),
                                    nn.Dropout(self.drop_prob_lm))+
                                    ((nn.BatchNorm1d(self.rnn_size),) if self.use_bn==2 else ())))
        self.ctx2att = nn.Linear(self.rnn_size, self.att_hid_size)        

    def _forward(self, fc_feats, att_feats, seq, att_masks=None, trip_pred=None, obj_dist=None, obj_box=None, rel_ind=None, \
                 pred_fmap=None, pred_dist=None, gpn_obj_ind=None, gpn_pred_ind=None, gpn_nrel_ind=None,gpn_pool_mtx=None):
        """
        Model feedforward: input scene graph features and sub-graph indices, output token probabilities
        fusion layers --> GCN backbone --> GPN (sGPN) --> attention-based LSTM
        """
        # fuse features (visual, embedding) for each node in graph
        att_feats, pred_fmap = self.feat_fusion(obj_dist, att_feats, pred_dist)
        b = att_feats.size(0); N = att_feats.size(1); K = rel_ind.size(1); L = self.GCN_dim

        # GCN backbone (will expand feats to 5 counterparts)
        att_feats, x_pred = self.gcn_backbone(b,N,K,L,att_feats, obj_dist, pred_fmap, rel_ind)
        b = att_feats.size(0) # has expanded to 5 counterparts

        # sGPN
        if self.gpn:
            gpn_loss, subgraph_score, att_feats, fc_feats, att_masks = \
                self.gpn_layer(b,N,K,L,gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind,gpn_pool_mtx,att_feats,x_pred,fc_feats,att_masks)
        else: # no gpn module, baseline model with full scene graph
            gpn_loss = None
            subgraph_score = None

            # mean pooling, wo global img feats
            read_out = torch.mean(att_feats,1).detach()  # mean pool over full scene graph
            fc_feats = self.read_out_proj(read_out) 

            att_masks = att_masks[:,0,0]
            att_masks[:,:36].fill_(1.0).float()  
        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)
        outputs = fc_feats.new_zeros(batch_size, seq.size(1) - 1, self.vocab_size+1)
        
        # Prepare the features for attention-based LSTM
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        for i in range(seq.size(1) - 1):
            if self.training and i >= 1 and self.ss_prob > 0.0: # otherwiste no need to sample
                sample_prob = fc_feats.new(batch_size).uniform_(0, 1)
                sample_mask = sample_prob < self.ss_prob
                if sample_mask.sum() == 0:
                    it = seq[:, i].clone()
                else:
                    sample_ind = sample_mask.nonzero().view(-1)
                    it = seq[:, i].data.clone()
                    prob_prev = torch.exp(outputs[:, i-1].detach()) # fetch prev distribution: shape Nx(M+1)
                    it.index_copy_(0, sample_ind, torch.multinomial(prob_prev, 1).view(-1).index_select(0, sample_ind))
            else:
                it = seq[:, i].clone()          
            # break if all the sequences end
            if i >= 1 and seq[:, i].sum() == 0:
                break

            output, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)
            outputs[:, i] = output  # output is probability after log_softmax at current time step, sized [batch, self.vocab_size+1]
        
        return outputs, gpn_loss, subgraph_score

    def _sample_sentences(self, fc_feats, att_feats, att_masks=None, trip_pred=None, obj_dist=None, obj_box=None, rel_ind=None, \
                                pred_fmap=None, pred_dist=None, gpn_obj_ind=None, gpn_pred_ind=None, gpn_nrel_ind=None,gpn_pool_mtx=None, opt={}):
        """
        Model inference / sentence decoding: generate captions with beam size > 1
        """
        # fuse features (visual, embedding) for each node in graph
        att_feats, pred_fmap = self.feat_fusion(obj_dist, att_feats, pred_dist)
        b = att_feats.size(0); N = att_feats.size(1); K = rel_ind.size(1); L = self.GCN_dim
        
        # GCN backbone
        att_feats, x_pred = self.gcn_backbone(b,N,K,L,att_feats, obj_dist, pred_fmap, rel_ind)
        b = att_feats.size(0) # has expanded to 5 counterparts
        
        # GPN
        if self.gpn:
            gpn_loss, subgraph_score, att_feats, fc_feats, att_masks, keep_ind = \
                self.gpn_layer(b,N,K,L,gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind,gpn_pool_mtx,att_feats,x_pred,fc_feats,att_masks)
        else: # no gpn module, baseline model that use full graph
            gpn_loss = None
            att_feats = att_feats[0:1] # use one of 5 counterparts

            read_out = torch.mean(att_feats,1)  # mean pool over full scene graph
            fc_feats = self.read_out_proj(read_out) 

            att_masks = att_masks[0:1,0,0] 
            att_masks[:,:36].fill_(1.0).float()
            keep_ind = torch.arange(att_feats.size(0)).type_as(gpn_obj_ind)  
            subgraph_score = torch.arange(att_feats.size(0)).fill_(1.0).type_as(att_feats)

        beam_size = opt.get('beam_size', 10)
        batch_size = fc_feats.size(0)
        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        seq = torch.LongTensor(self.seq_length, batch_size).zero_()
        seqLogprobs = torch.FloatTensor(self.seq_length, batch_size)

        self.done_beams = [[] for _ in range(batch_size)]
        for k in range(batch_size):
            state = self.init_hidden(beam_size)
            tmp_fc_feats = p_fc_feats[k:k+1].expand(beam_size, p_fc_feats.size(1))
            tmp_att_feats = p_att_feats[k:k+1].expand(*((beam_size,)+p_att_feats.size()[1:])).contiguous()
            tmp_p_att_feats = pp_att_feats[k:k+1].expand(*((beam_size,)+pp_att_feats.size()[1:])).contiguous()
            tmp_att_masks = p_att_masks[k:k+1].expand(*((beam_size,)+p_att_masks.size()[1:])).contiguous() if att_masks is not None else None

            for t in range(1):
                if t == 0: # input <bos>
                    it = fc_feats.new_zeros([beam_size], dtype=torch.long)

                logprobs, state = self.get_logprobs_state(it, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, state)
            
            self.done_beams[k] = self.beam_search(state, logprobs, tmp_fc_feats, tmp_att_feats, tmp_p_att_feats, tmp_att_masks, None, None, opt=opt)
            seq[:, k] = self.done_beams[k][0]['seq'] # the first beam has highest cumulative score
            seqLogprobs[:, k] = self.done_beams[k][0]['logps']

        # return the samples and their log likelihoods
        return seq.transpose(0, 1), seqLogprobs.transpose(0, 1), subgraph_score, keep_ind 

    def _sample(self, fc_feats, att_feats, att_masks=None, trip_pred=None, obj_dist=None, obj_box=None, rel_ind=None, \
                pred_fmap=None, pred_dist=None, gpn_obj_ind=None, gpn_pred_ind=None, gpn_nrel_ind=None,gpn_pool_mtx=None, opt={}):
        """
        Model inference / sentence decoding: generate captions with beam size == 1 (disabling beam search)
        """
        sample_max = opt.get('sample_max', 1)
        beam_size = opt.get('beam_size', 1)
        return_att = True if opt.get('return_att', 0) == 1 else False

        if beam_size > 1:
            return self._sample_sentences(fc_feats, att_feats, att_masks, trip_pred, obj_dist, obj_box, rel_ind, pred_fmap, pred_dist, \
                                      gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind,gpn_pool_mtx, opt)
        
        # fuse features (visual, embedding) for each node in graph
        att_feats, pred_fmap = self.feat_fusion(obj_dist, att_feats, pred_dist)
        b = att_feats.size(0); N = att_feats.size(1); K = rel_ind.size(1); L = self.GCN_dim
        
        # GCN backbone
        att_feats, x_pred = self.gcn_backbone(b,N,K,L,att_feats, obj_dist, pred_fmap, rel_ind)
        b = att_feats.size(0) # has expanded to 5 counterparts
        
        # GPN
        if self.gpn:
            gpn_loss, subgraph_score, att_feats, fc_feats, att_masks, keep_ind = \
                self.gpn_layer(b,N,K,L,gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind,gpn_pool_mtx,att_feats,x_pred,fc_feats,att_masks)
        else: # no gpn module, baseline model that use full graph
            gpn_loss = None
            att_feats = att_feats[0:1] # use one of 5 counterparts

            read_out = torch.mean(att_feats,1)  # mean pool over full scene graph
            fc_feats = self.read_out_proj(read_out) 

            att_masks = att_masks[0:1,0,0] 
            att_masks[:,:36].fill_(1.0).float()
            keep_ind = torch.arange(att_feats.size(0)).type_as(gpn_obj_ind)  
            subgraph_score = torch.arange(att_feats.size(0)).fill_(1.0).type_as(att_feats)

        batch_size = fc_feats.size(0)
        state = self.init_hidden(batch_size)

        p_fc_feats, p_att_feats, pp_att_feats, p_att_masks = self._prepare_feature(fc_feats, att_feats, att_masks)

        seq = fc_feats.new_zeros((batch_size, self.seq_length), dtype=torch.long)
        seqLogprobs = fc_feats.new_zeros(batch_size, self.seq_length)
        att2_weights = []
        
        for t in range(self.seq_length + 1):
            if t == 0: # input <bos>
                it = fc_feats.new_zeros(batch_size, dtype=torch.long)
            if return_att:
                logprobs, state, att_weight = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state, return_att=True)
                att2_weights.append(att_weight)
            else:
                logprobs, state = self.get_logprobs_state(it, p_fc_feats, p_att_feats, pp_att_feats, p_att_masks, state)

            # sample the next word
            if t == self.seq_length: # skip if we achieve maximum length
                break

            if self.topk_sampling:  # sample top-k word from a re-normalized probability distribution
                logprobs = F.log_softmax(logprobs / float(self.topk_temp), dim=1)
                tmp = torch.empty_like(logprobs).fill_(float('-inf'))
                topk, indices = torch.topk(logprobs, self.the_k, dim=1)
                tmp = tmp.scatter(1, indices, topk)
                logprobs = tmp
                # sample the word index according to log probability (negative values)
                it = torch.distributions.Categorical(logits=logprobs.data).sample() # logits: log(probability) are negative values
                sampleLogprobs = logprobs.gather(1, it.unsqueeze(1)) # gather the logprobs at sampled positions 
            else:                
                if sample_max: # True (greedy decoding)
                    sampleLogprobs, it = torch.max(logprobs.data, 1)
                    it = it.view(-1).long()               

            # stop when all finished, unfinished: 0 or 1
            if t == 0:
                unfinished = it > 0
            else:
                unfinished = unfinished * (it > 0)
            it = it * unfinished.type_as(it)
            seq[:,t] = it
            seqLogprobs[:,t] = sampleLogprobs.view(-1)
            # early quit loop if all sequences have finished
            if unfinished.sum() == 0:
                break
      
        if return_att:
            # attention weights [b,20+1,N]
            att2_weights = torch.cat([_.unsqueeze(1) for _ in att2_weights], 1)
            return seq, seqLogprobs, subgraph_score, keep_ind, att2_weights
        else:
            return seq, seqLogprobs, subgraph_score, keep_ind 

    def get_logprobs_state(self, it, fc_feats, att_feats, p_att_feats, att_masks, state, sg_emb=None, p_sg_emb=None,return_att=False):
        """
        Attention-based LSTM feedforward
        """
        xt = self.embed(it) # 'it' contains a word index
        
        if return_att:
            output, state, att_weight = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks, return_att=return_att)
            logprobs = F.log_softmax(self.logit(output), dim=1)
            return logprobs, state, att_weight
        else:
            output, state = self.core(xt, fc_feats, att_feats, p_att_feats, state, att_masks)
            logprobs = F.log_softmax(self.logit(output), dim=1)
            return logprobs, state

    def init_hidden(self, bsz):
        weight = self.logit.weight if hasattr(self.logit, "weight") else self.logit[0].weight
        return (weight.new_zeros(self.num_layers, bsz, self.rnn_size),
                weight.new_zeros(self.num_layers, bsz, self.rnn_size))

    def clip_att(self, att_feats, att_masks):
        # Clip the length of att_masks and att_feats to the maximum length
        if att_masks is not None:
            max_len = att_masks.data.long().sum(1).max()
            att_feats = att_feats[:, :max_len].contiguous()
            att_masks = att_masks[:, :max_len].contiguous()
        return att_feats, att_masks

    def _prepare_feature(self, fc_feats, att_feats, att_masks, sg_emb=None):
        """
        Project features and prepare for the inputs of attention-based LSTM
        """
        att_feats, att_masks = self.clip_att(att_feats, att_masks)

        # embed fc and att feats
        fc_feats = self.fc_embed(fc_feats)
        att_feats = pack_wrapper(self.att_embed, att_feats, att_masks) # pack sequences with different length
        # Project the attention feats first to reduce memory and computation comsumptions.
        p_att_feats = self.ctx2att(att_feats)
        
        return fc_feats, att_feats, p_att_feats, att_masks

    def feat_fusion(self, obj_dist, att_feats, pred_dist):
        """
        Fuse visual and word embedding features for nodes and edges
        """
        # fuse features (visual, embedding) for each node in graph
        if self.noun_fuse: # Sub-GC
            obj_emb = self.obj_emb_proj(self.sg_obj_embed(obj_dist.view(-1, self.sg_obj_cnt)[:,1:].max(1)[1] + 1)).view(obj_dist.size(0), obj_dist.size(1), self.GCN_dim)
            att_feats = self.obj_v_proj(att_feats)
            att_feats = self.relu(att_feats + obj_emb)
        else: # GCN-LSTM baseline that use full graph
            att_feats = self.obj_v_proj(att_feats)
        
        if self.pred_emb_type == 1: # hard emb, not including background
            pred_emb = self.sg_pred_embed(pred_dist.view(-1, self.sg_pred_cnt)[:,1:].max(1)[1] + 1)
        elif self.pred_emb_type == 2: # hard emb, including background
            pred_emb = self.sg_pred_embed(pred_dist.view(-1, self.sg_pred_cnt).max(1)[1])
        pred_fmap = self.pred_emb_prj(pred_emb).view(pred_dist.size(0), pred_dist.size(1), self.GCN_dim) 
        return att_feats, pred_fmap

"""
Attention-based LSTM
"""
class TopDownCore(nn.Module):
    def __init__(self, opt, use_maxout=False):
        super(TopDownCore, self).__init__()
        self.drop_prob_lm = opt.drop_prob_lm
        self.attention = Attention(opt)
        self.att_lstm = nn.LSTMCell(opt.input_encoding_size + opt.rnn_size * 2, opt.rnn_size) 
        self.lang_lstm = nn.LSTMCell(opt.rnn_size * 2, opt.rnn_size) 

    def forward(self, xt, fc_feats, att_feats, p_att_feats, state, att_masks=None, sg_emb=None, p_sg_emb=None,return_att=False):
        """
        prev_h: h_lang output of previous language LSTM
        fc_feats: vector after pooling over K regions, one vector per image 
        xt: embedding of previous word
        att_feats: packed region features 
        p_att_feats: projected [packed region features]
        h_att, c_att: hidden state and cell state of attention LSTM
        h_lang, c_lang: hidden state and cell state of language LSTM
        """
        prev_h = state[0][-1]
        att_lstm_input = torch.cat([prev_h, fc_feats, xt], 1)

        h_att, c_att = self.att_lstm(att_lstm_input, (state[0][0], state[1][0])) # the 2nd arg is from previous att_lstm
        
        # attended region features
        if return_att:
            att, att_weight = self.attention(h_att, att_feats, p_att_feats, att_masks, return_att=return_att) 
        else:
            att = self.attention(h_att, att_feats, p_att_feats, att_masks)

        lang_lstm_input = torch.cat([att, h_att], 1)

        h_lang, c_lang = self.lang_lstm(lang_lstm_input, (state[0][1], state[1][1])) # the 2nd arg is from previous lang_lstm

        output = F.dropout(h_lang, self.drop_prob_lm, self.training)
        state = (torch.stack([h_att, h_lang]), torch.stack([c_att, c_lang])) 
        
        if return_att:
            return output, state, att_weight
        else:
            return output, state

"""
Attention module in attention-based LSTM
"""
class Attention(nn.Module):

    def __init__(self, opt):
        super(Attention, self).__init__()
        self.rnn_size = opt.rnn_size 
        self.att_hid_size = opt.att_hid_size 
        self.h2att = nn.Linear(self.rnn_size, self.att_hid_size)
        self.alpha_net = nn.Linear(self.att_hid_size, 1)

    def forward(self, h, att_feats, p_att_feats, att_masks=None, return_att=False):
        """
        Input hidden state and region features, output the attended visual features
        """
        # The p_att_feats here is already projected
        att_size = att_feats.numel() // att_feats.size(0) // att_feats.size(-1)
        att = p_att_feats.view(-1, att_size, self.att_hid_size)
        
        att_h = self.h2att(h)                        # [batch,512]
        att_h = att_h.unsqueeze(1).expand_as(att)            # [batch, K, 512]
        dot = att + att_h                                   # [batch, K, 512]
        dot = torch.tanh(dot) #F.tanh(dot)                  # [batch, K, 512]
        dot = dot.view(-1, self.att_hid_size)               # [(batch * K), 512]
        dot = self.alpha_net(dot)                           # [(batch * K), 1]
        dot = dot.view(-1, att_size)                        # [batch, K]
        
        weight = F.softmax(dot, dim=1)                             # [batch, K]
        if att_masks is not None:  # necessary since empty box proposals (att_mask) may exist
            weight = weight * att_masks.view(-1, att_size).float()
            weight = weight / weight.sum(1, keepdim=True) # normalize to 1
        att_feats_ = att_feats.view(-1, att_size, att_feats.size(-1)) # [batch, K, 1000]
        att_res = torch.bmm(weight.unsqueeze(1), att_feats_).squeeze(1) # [batch, 1000]

        if return_att:
            return att_res, weight
        else:
            return att_res

"""
Captioning model wrapper
"""
class TopDownModel(AttModel):
    def __init__(self, opt):
        super(TopDownModel, self).__init__(opt)
        self.num_layers = 2
        self.core = TopDownCore(opt)


Overwriting models/AttModel.py


In [20]:
%%writefile dataloaders/dataloader_test.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import h5py
import os
import numpy as np
import random

import torch
import torch.utils.data as data

class HybridLoader:
    """
    Load per-key feature files from a directory.

    The file for key `k` is os.path.join(db_path, k + ext), and the loading method depends
    on the extension:
      - '.npy': the array stored in the file.
      - any other extension (e.g. '.npz'): the archive's 'feat' entry. For directories whose
        path contains "sg" or "graph" (scene-graph outputs / graph masks), the entry is
        loaded with allow_pickle=True and converted with .tolist(). Otherwise it is returned
        as a plain array.
    """
    def __init__(self, db_path, ext):
        self.db_path = db_path
        self.ext = ext
        if self.ext == '.npy':
            self.loader = self._load_npy
        elif "sg" in db_path or "graph" in db_path:
            # each substring must be tested explicitly: `"sg" or "graph" in db_path` is always truthy
            self.loader = self._load_npz_object
        else:
            self.loader = self._load_npz_array
        self.db_type = 'dir'

    @staticmethod
    def _load_npy(path):
        """Load a .npy array."""
        return np.load(path, encoding='latin1')

    @staticmethod
    def _load_npz_object(path):
        """Return the pickled 'feat' entry of an .npz archive via .tolist(), closing the archive."""
        with np.load(path, allow_pickle=True, encoding='latin1') as archive:
            return archive['feat'].tolist()

    @staticmethod
    def _load_npz_array(path):
        """Return the 'feat' array of an .npz archive, closing the archive afterwards."""
        with np.load(path, encoding='latin1') as archive:
            return archive['feat']

    def get(self, key):
        """Load and return the features stored for `key` in db_path."""
        f_input = os.path.join(self.db_path, key + self.ext)
        return self.loader(f_input)

class pklLoader:
    """Keeps a feature directory path and a `loader` callable returning a numpy file's 'feat' entry."""
    def __init__(self, db_path):
        self.db_path = db_path

        def load_feat(path):
            return np.load(path, encoding='latin1')['feat']

        self.loader = load_feat

class DataLoader(data.Dataset):
    """
    Test-time dataset for Sub-GC.

    Serves, per image, the scene-graph features, the sub-graph masks and the
    caption labels. Keeps one prefetching BlobFetcher per split.

    Splits 'train', 'val' and 'test' come from the MRNN split dict or from the
    json 'split' field. The extra 'single' split holds only the image whose id
    equals ``opt.test_id``.
    """

    def reset_iterator(self, split):
        """Rewind ``split`` to its first image by replacing its BlobFetcher."""
        del self._prefetch_process[split]
        self._prefetch_process[split] = BlobFetcher(split, self, split == 'single')  # shuffle flag only set for 'single'
        self.iterators[split] = 0

    def get_vocab_size(self):
        """Number of entries in ``ix_to_word``."""
        return self.vocab_size

    def get_vocab(self):
        """Index-to-word mapping read from the input json."""
        return self.ix_to_word

    def get_seq_length(self):
        """Maximum caption length (second dimension of the h5 'labels' dataset)."""
        return self.seq_length

    @staticmethod
    def _dataset_name_from_path(label_h5_path):
        """Map the label h5 path to the dataset name used in the feature directory names.

        Raises:
            ValueError: if the path contains neither 'flickr' nor 'coco'.
        """
        if 'flickr' in label_h5_path:
            return 'flickr30k'
        if 'coco' in label_h5_path:
            return 'COCO'
        raise ValueError("Cannot infer dataset from input_label_h5 {!r}: "
                         "expected 'flickr' or 'coco' in the path".format(label_h5_path))

    def __init__(self, opt):
        self.opt = opt
        self.batch_size = self.opt.batch_size
        self.seq_per_img = opt.seq_per_img

        # id of the image placed in the 'single' split
        self.test_id = opt.test_id

        # load the json file which contains additional information about the dataset
        print('DataLoader loading json file: ', opt.input_json)
        with open(self.opt.input_json) as json_file:  # close the handle once parsed
            self.info = json.load(json_file)
        self.ix_to_word = self.info['ix_to_word']
        self.vocab_size = len(self.ix_to_word)
        print('vocab size is ', self.vocab_size)

        # open the hdf5 file
        print('DataLoader loading h5 file: ', opt.input_label_h5)
        self.h5_label_file = h5py.File(self.opt.input_label_h5, 'r', driver='core')

        # fail fast instead of hitting an unbound `dataset_name` below
        dataset_name = self._dataset_name_from_path(opt.input_label_h5)

        use_MRNN_split = opt.use_MRNN_split
        mask_version = '1000'
        # one .npz per image id: scene-graph outputs and sub-graph masks
        self.trip_loader = HybridLoader("data/{}_sg_output_64".format(dataset_name), ".npz")
        self.subgraph_mask = HybridLoader("data/{}_graph_mask_{}_rm_duplicate".format(dataset_name, mask_version), ".npz")

        # load in the sequence data
        self.label = self.h5_label_file['labels'][:]
        seq_size = self.h5_label_file['labels'].shape
        self.seq_length = seq_size[1]
        print('max sequence length in data is', self.seq_length)
        # load the pointers in full to RAM (should be small enough)
        self.label_start_ix = self.h5_label_file['label_start_ix'][:]
        self.label_end_ix = self.h5_label_file['label_end_ix'][:]

        self.num_images = self.label_start_ix.shape[0]
        print('read %d image features' % (self.num_images))

        # Single image: the 'single' split holds every position whose id is test_id
        self.split_ix = {'train': [], 'val': [], 'test': [], 'single': []}
        for ix, img in enumerate(self.info['images']):
            if img["id"] == self.test_id:
                self.split_ix['single'].append(ix)

        # Assign the remaining splits from one source of truth per mode
        if use_MRNN_split:
            MRNN_split_dict = np.load('data/MRNN_split_dict.npy', allow_pickle=True, encoding='latin1').tolist()
            split_of = lambda img: MRNN_split_dict[img['id']]
        else:
            split_of = lambda img: img['split']
        for ix, img in enumerate(self.info['images']):
            img_split = split_of(img)
            if img_split in ('train', 'val', 'test'):
                self.split_ix[img_split].append(ix)
            elif opt.train_only == 0:  # restval
                self.split_ix['train'].append(ix)

        print('assigned %d images to split train' % len(self.split_ix['train']))
        print('assigned %d images to split val' % len(self.split_ix['val']))
        print('assigned %d images to split test' % len(self.split_ix['test']))
        print('assigned %d image to split single' % len(self.split_ix['single']))

        self.iterators = {'train': 0, 'val': 0, 'test': 0, 'single': 0}

        self._prefetch_process = {}  # one BlobFetcher per split
        for split in self.iterators.keys():
            self._prefetch_process[split] = BlobFetcher(split, self, split == 'single')

        # Drop the fetchers when the interpreter exits
        def cleanup():
            print('Terminating BlobFetcher')
            for split in self.iterators.keys():
                del self._prefetch_process[split]
        import atexit
        atexit.register(cleanup)

        self.obj_num = opt.obj_num
        self.rel_num = opt.rel_num
        self.half_mini_batch = None

    def get_captions(self, ix, seq_per_img):
        """Return ``seq_per_img`` label rows (truncated to ``seq_length``) for image ``ix``.

        When the image has fewer captions than requested, rows are sampled with
        replacement. Otherwise the first ``seq_per_img`` captions are taken.
        """
        # fetch the sequence labels
        ix1 = self.label_start_ix[ix] - 1  # label_start_ix starts from 1
        ix2 = self.label_end_ix[ix] - 1
        ncap = ix2 - ix1 + 1  # number of captions available for this image
        assert ncap > 0, 'an image does not have any label. this can be handled but right now isn\'t'

        if ncap < seq_per_img:
            # we need to subsample (with replacement)
            seq = np.zeros([seq_per_img, self.seq_length], dtype='int')
            for q in range(seq_per_img):
                ixl = random.randint(ix1, ix2)
                seq[q, :] = self.label[ixl, :self.seq_length]
        else:
            ixl = ix1
            seq = self.label[ixl: ixl + seq_per_img, :self.seq_length]
        return seq

    def get_batch(self, split, batch_size=None, seq_per_img=None):
        """Fetch the next collated batch of ``split`` and reshape it into model inputs.

        ``batch_size`` and ``seq_per_img`` are kept for API compatibility but are
        not used: the batch size was fixed when the split's BlobFetcher was built.

        Returns:
            dict of model inputs plus 'gts' (all ground-truth captions per image),
            'infos' (ix / id / file_path) and 'bounds' (iterator position,
            split size, and whether this batch ends the split).
        """
        # one call fetches a whole batch; the last element is the wrapped flag
        fetched = self._prefetch_process[split].get()
        wrapped = fetched[-1]

        tmp_fc, tmp_att, tmp_object_dist, tmp_rel_ind, tmp_pred_dist, tmp_label, tmp_masks, tmp_ix, \
        tmp_gpn_obj_ind, tmp_gpn_pred_ind, tmp_gpn_nrel_ind, tmp_att_mask, tmp_gpn_pool_mtx, tmp_this_mini_batch = fetched[:-1]

        batch_data = {}
        # merge features
        batch_data['trip_pred'] = None
        batch_data['pred_fmap'] = None
        batch_data['fc_feats'] = tmp_fc.view(-1, 2048)
        batch_data['att_feats'] = tmp_att.view(-1, self.obj_num, 2048)
        batch_data['obj_dist'] = tmp_object_dist.view(-1, self.obj_num, 1599)
        batch_data['rel_ind'] = tmp_rel_ind.view(-1, self.rel_num, 2)
        batch_data['pred_dist'] = tmp_pred_dist.view(-1, self.rel_num, 21)
        batch_data['labels'] = tmp_label.view(-1, self.seq_length + 2)
        batch_data['masks'] = tmp_masks.view(-1, self.seq_length + 2)
        batch_data['att_masks'] = tmp_att_mask.view(-1, 2, tmp_this_mini_batch, self.obj_num)
        batch_data['gpn_obj_ind'] = tmp_gpn_obj_ind.view(-1, 2, tmp_this_mini_batch, self.obj_num)
        batch_data['gpn_pred_ind'] = tmp_gpn_pred_ind.view(-1, 2, tmp_this_mini_batch, self.rel_num)
        batch_data['gpn_nrel_ind'] = tmp_gpn_nrel_ind.view(-1, 2, tmp_this_mini_batch, self.rel_num, 2)
        batch_data['gpn_pool_mtx'] = tmp_gpn_pool_mtx.view(-1, 2, tmp_this_mini_batch, self.obj_num, self.obj_num)
        batch_data['obj_box'] = None

        # batch data not in pin_memory, which stays as list
        batch_data['gts'] = []  # all ground truth captions of each images
        batch_data['infos'] = []
        for ix in tmp_ix.view(-1).numpy():
            batch_data['gts'].append(self.label[self.label_start_ix[ix] - 1: self.label_end_ix[ix]])
            batch_data['infos'].append({'ix': ix, 'id': self.info['images'][ix]['id'], 'file_path': self.info['images'][ix]['file_path']})

        batch_data['bounds'] = {'it_pos_now': self.iterators[split], 'it_max': len(self.split_ix[split]), 'wrapped': wrapped}

        return batch_data

    def __getitem__(self, index):
        """Build every array for image ``index`` (a position in ``self.info['images']``).

        Returns a 14-tuple that BlobFetcher collates and ``get_batch`` unpacks:
        fc_feat, pad_object_fmap, pad_object_dist, pad_rel_ind, pad_pred_dist,
        label, mask_batch, this_ix, gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind,
        gpn_att_mask, gpn_pool_mtx, np.array([this_mini_batch]).
        """
        ix = index
        this_ix = np.array([index])
        img_id = self.info['images'][ix]['id']

        ############### load subgraph mini-batch for each sentence  ###############
        # 1. sentence id within 5 sentences
        subgraph_dict = self.subgraph_mask.get(str(img_id))
        this_mini_batch = int(subgraph_dict['node_iou_mtx'][:, 5:].shape[1] / 2)
        mask_idx = np.full((self.seq_per_img, this_mini_batch, 2), -1)
        # original pos are occupied by first part of sub-graphs; original neg occupied by second part of sub-graphs
        mask_idx[:, :, 0] = np.repeat(np.arange(this_mini_batch).reshape(1, -1), self.seq_per_img, axis=0)
        mask_idx[:, :, 1] = mask_idx[:, :, 0] + this_mini_batch
        # Note: shift back to original index which includes sentence noun subgraph
        mask_idx = mask_idx + 5

        # get the mask of mini-batch
        mask_info = subgraph_dict['subgraph_mask_list']
        gpn_obj_ind = np.full((self.seq_per_img, 2, this_mini_batch, self.obj_num), self.obj_num - 1)
        gpn_att_mask = np.full((self.seq_per_img, 2, this_mini_batch, self.obj_num), 0).astype('float32')
        gpn_pred_ind = np.full((self.seq_per_img, 2, this_mini_batch, self.rel_num), self.rel_num - 1)
        gpn_nrel_ind = np.full((self.seq_per_img, 2, this_mini_batch, self.rel_num, 2), self.obj_num - 1)
        gpn_pool_mtx = np.zeros((self.seq_per_img, 2, this_mini_batch, self.obj_num, self.obj_num)).astype('float32')

        for i in range(self.seq_per_img):
            for k in range(this_mini_batch):
                # obj
                tmp = mask_info[mask_idx[i, k, 0]][1].nonzero()[0]
                if tmp.shape[0] != 0:
                    gpn_obj_ind[i, 0, k, :tmp.shape[0]] = tmp  # pos obj
                else:
                    print("error: no object node in this sub-graph!")
                gpn_att_mask[i, 0, k, :tmp.shape[0]] = 1  # pos obj used in LSTM
                gpn_pool_mtx[i, 0, k, np.arange(tmp.shape[0]), np.arange(tmp.shape[0])] = 1  # sparse scatter mtx

                tmp = mask_info[mask_idx[i, k, 1]][1].nonzero()[0]
                if tmp.shape[0] != 0:
                    gpn_obj_ind[i, 1, k, :tmp.shape[0]] = tmp  # neg obj
                else:
                    print("error: no object node in this sub-graph!")
                gpn_att_mask[i, 1, k, :tmp.shape[0]] = 1  # neg obj used in LSTM
                gpn_pool_mtx[i, 1, k, np.arange(tmp.shape[0]), np.arange(tmp.shape[0])] = 1  # sparse scatter mtx

                # predicate
                tmp = mask_info[mask_idx[i, k, 0]][2].nonzero()[0]
                if tmp.shape[0] != 0:
                    gpn_pred_ind[i, 0, k, :tmp.shape[0]] = tmp  # pos pred
                tmp = mask_info[mask_idx[i, k, 1]][2].nonzero()[0]
                if tmp.shape[0] != 0:
                    gpn_pred_ind[i, 1, k, :tmp.shape[0]] = tmp  # neg pred

                # new rel ind
                tmp = mask_info[mask_idx[i, k, 0]][3]
                if tmp.shape[0] != 0:
                    gpn_nrel_ind[i, 0, k, :tmp.shape[0]] = tmp
                tmp = mask_info[mask_idx[i, k, 1]][3]
                if tmp.shape[0] != 0:
                    gpn_nrel_ind[i, 1, k, :tmp.shape[0]] = tmp
        ############### load subgraph mini-batch for each sentence  ###############

        ############### load full SG with dummy node and dummpy predicate/rel_ind  ###############
        # 2. object related data
        sg_output = self.trip_loader.get(str(img_id))
        object_fmap = sg_output['object_fmap'][:self.obj_num, :]
        object_dist = sg_output['object_dist'][:self.obj_num, :]

        pad_object_fmap = np.full((1, self.obj_num, object_fmap.shape[1]), 0).astype('float32')  # pad with the dummy/empty node
        pad_object_dist = np.concatenate((np.ones((1, self.obj_num, 1)), np.zeros((1, self.obj_num, object_dist.shape[1] - 1))), axis=2).astype('float32')
        fc_feat = np.full((1, object_fmap.shape[1]), 0).astype('float32')

        # NOTE(review): assumes the file stores exactly obj_num - 1 objects; confirm
        pad_object_fmap[0, :(self.obj_num - 1), :] = object_fmap
        pad_object_dist[0, :(self.obj_num - 1), :] = object_dist

        # 3. predicate related data
        pred_dist = sg_output['pred_dist']
        rel_ind = sg_output['rel_ind']
        pad_rel_ind = np.full((1, self.rel_num, rel_ind.shape[1]), self.obj_num - 1)  # pad the rel_ind with the dummy/empty node index
        pad_pred_dist = np.concatenate((np.ones((1, self.rel_num, 1)), np.zeros((1, self.rel_num, pred_dist.shape[1] - 1))), axis=2).astype('float32')

        this_len = min(rel_ind.shape[0], self.rel_num - 1)
        pad_pred_dist[0, :this_len, :] = pred_dist[:this_len]
        pad_rel_ind[0, :this_len, :] = rel_ind[:this_len]
        ############### load full SG with dummy node and dummpy predicate/rel_ind  ###############

        label = np.zeros([self.seq_per_img, self.seq_length + 2], dtype='int64')
        label[:, 1: self.seq_length + 1] = self.get_captions(ix, self.seq_per_img)
        nonzeros = np.array(list(map(lambda x: (x != 0).sum() + 2, label)))
        mask_batch = np.zeros([label.shape[0], self.seq_length + 2], dtype='float32')
        for idx, row in enumerate(mask_batch):
            row[:nonzeros[idx]] = 1  # keep the 'start' + sentence + 'end', and mask out the rest

        # return as tmp in BlobFetcher.get()
        return fc_feat, pad_object_fmap, pad_object_dist, pad_rel_ind, pad_pred_dist, label, mask_batch, this_ix, \
        gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind, gpn_att_mask, gpn_pool_mtx, np.array([this_mini_batch])

    def __len__(self):
        """Number of images listed in the input json."""
        return len(self.info['images'])

class SubsetSampler(torch.utils.data.sampler.Sampler):
    r"""Yields the given indices in their stored order (no shuffling, no replacement).

    Arguments:
        indices (sequence): the indices to emit, in order
    """

    def __init__(self, indices):
        self.indices = indices

    def __len__(self):
        return len(self.indices)

    def __iter__(self):
        # The length is fixed when the iterator is created; elements are read lazily.
        count = len(self.indices)
        return (self.indices[pos] for pos in range(count))

def _init_fn(worker_id):  
    # worker seed
    pass
    #np.random.seed(2019)

class BlobFetcher():
    """Serve collated batches for one split of a dataset object.

    Wraps a ``torch.utils.data.DataLoader`` iterator over
    ``dataloader.split_ix[split]`` and keeps ``dataloader.iterators[split]`` in
    step with the number of images consumed.
    """
    def __init__(self, split, dataloader, if_shuffle=False):
        """
        Args:
            split: one of 'train', 'val', 'test', 'single'.
            dataloader: owning dataset; must provide ``batch_size``, ``split_ix``,
                ``iterators`` and (for ``reset``) ``opt.num_workers``.
            if_shuffle: shuffle ``split_ix[split]`` in place whenever the split wraps.
        """
        self.split = split
        self.dataloader = dataloader
        self.if_shuffle = if_shuffle
        self.batch_size = dataloader.batch_size

    def reset(self):
        """
        (Re)create the torch DataLoader iterator starting at the split's current position.

        Two cases for this function to be triggered:
        1. not hasattr(self, 'split_loader'): Resume from previous training. Create the dataset given the saved split_ix and iterator
        2. wrapped: a new epoch, the split_ix and iterator have been updated in the get_minibatch_inds already.
        """
        start = self.dataloader.iterators[self.split]
        self.split_loader = iter(data.DataLoader(dataset=self.dataloader,
                                                 batch_size=self.batch_size,  # must match ri_next = ri + self.batch_size
                                                 sampler=SubsetSampler(self.dataloader.split_ix[self.split][start:]),
                                                 shuffle=False,
                                                 pin_memory=True,
                                                 worker_init_fn=_init_fn,
                                                 num_workers=self.dataloader.opt.num_workers))

    def _get_next_minibatch_inds(self):
        """Advance the split position by one batch.

        'train'/'val' drop an incomplete final batch (position reaching the
        split size counts as the end). 'test'/'single' serve every image (only a
        position beyond the split size counts as the end).

        Returns:
            (ri_next, wrapped, last_batch): the new position, whether the split
            wrapped back to 0, and whether the batch about to be served is the
            last one before wrapping.

        Raises:
            AssertionError: for an unknown split name. Raised explicitly, so it
                still fires under ``python -O``.
        """
        if self.split in ('train', 'val'):
            drop_last = True
        elif self.split in ('test', 'single'):
            drop_last = False
        else:
            raise AssertionError("\nMode isn't correct! \n")

        max_index = len(self.dataloader.split_ix[self.split])

        def past_end(pos):
            return pos >= max_index if drop_last else pos > max_index

        ri = self.dataloader.iterators[self.split]  # count of images
        ri_next = ri + self.batch_size  # should same as the number in "batch_size=self.batch_size,"
        wrapped = False
        if past_end(ri_next):
            ri_next = 0
            if self.if_shuffle:
                random.shuffle(self.dataloader.split_ix[self.split])
            wrapped = True

        self.dataloader.iterators[self.split] = ri_next  # shadow #data loaded by the dataloader

        # the next call will wrap, so the current batch is the last one to be used
        last_batch = (not wrapped) and past_end(ri_next + self.batch_size)
        return ri_next, wrapped, last_batch

    def get(self):
        """Return the next collated batch as a list with the ``wrapped`` flag appended.

        ``wrapped`` is True on the final batch of the split, so callers can stop.
        """
        if not hasattr(self, 'split_loader'):
            self.reset()

        _, wrapped, last_batch = self._get_next_minibatch_inds()

        if wrapped:
            # iterators[split] is already 0: start the new epoch, dropping the incomplete batch
            self.reset()
            _, wrapped, last_batch = self._get_next_minibatch_inds()

        # builtin next(): modern torch DataLoader iterators have no `.next()` method
        tmp = next(self.split_loader)

        if last_batch:
            wrapped = True

        return list(tmp) + [wrapped]

Overwriting dataloaders/dataloader_test.py


In [34]:
%%writefile misc/eval_utils.py


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np
import json
import random
import time
import os
import sys
import misc.utils as utils
import math
from collections import defaultdict
from misc.sentence_utils import *
from misc.grd_utils import *

# Fixed seed for the Python and NumPy global RNGs so evaluation is repeatable.
random_seed = 2019
random.seed(random_seed)
np.random.seed(random_seed)

def eval_split(model, crit, loader, eval_kwargs={}, opt=None, val_model=None):
    '''
    This function contains 2 branches: 
    1. model inference: validation or testing
    2. evaluation for input sentences

    Args:
        model: captioning model to run, or None to skip inference and only
            score the captions previously saved for ``infos_path`` (branch 2).
        crit: loss criterion, called as ``crit(lang_output, labels[:,1:], masks[:,1:])``
            on the validation path.
        loader: data loader exposing ``reset_iterator``, ``get_batch``,
            ``batch_size`` and ``get_vocab``.
        eval_kwargs: evaluation options (split, beam_size, verbose, return_att,
            sct, infos_path, ...); also forwarded to the model as ``opt``.
        opt, val_model: not used in this function body.

    Returns:
        The mean validation loss when the validation path ran; otherwise None.
        The test path saves predictions with np.save (and a grounding json when
        ``return_att`` is 1) under the first two components of ``infos_path``.

    Side effects:
        Sets the ``REMOVE_BAD_ENDINGS`` environment variable. With ``verbose``
        it rewrites ``subgc_output_caption.txt`` in the working directory.
    '''
    # evaluation options; the second argument of .get is the default
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 1)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
    os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings) # a global configuration
    
    # grounding experiments
    return_att_weight = True if eval_kwargs.get('return_att', 0) == 1 else False
    if return_att_weight:
        assert beam_size == 1, "GVD repo only supports grounding evaluation with beam size as 1"
        gvd_all_dict = np.load('data/gvd_all_dict.npy', allow_pickle=True,encoding='latin1').tolist()
        ind_to_wd = gvd_all_dict['ind_to_wd']
        wd_to_lemma = gvd_all_dict['wd_to_lemma']
        lemma_det_id_dict = gvd_all_dict['lemma_det_id_dict']
        det_id_to_det_wd = gvd_all_dict['det_id_to_det_wd']
        grd_output = defaultdict(list)
        # consensus re-ranking is used only if its file exists next to the model
        model_path = eval_kwargs['infos_path'].split('/')
        consensus_rerank_file = model_path[0] + '/' + model_path[1] + '/consensus_rerank_ind.npy'
        grd_sGPN_consensus = True if os.path.isfile(consensus_rerank_file) else False

    # controllability experiments
    sct_mode = True if eval_kwargs.get('sct', 0) == 1 else False

    n = 0  # images processed so far (advanced by loader.batch_size per batch)
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8  # epsilon keeps loss_sum/loss_evals finite when nothing was evaluated
    predictions = []        
    
    # 1. run model in inference mode
    if model is not None:
        model.eval()
        loader.reset_iterator(split)
        while True:
            data = loader.get_batch(split)
            n = n + loader.batch_size
            
            # NOTE(review): dataloader_test.get_batch always fills 'labels', so with
            # verbose_loss=1 this takes the validation path; testing needs verbose_loss=0 — confirm.
            if data.get('labels', None) is not None and verbose_loss: # model validation
                tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'], data['trip_pred'],\
                      data['obj_dist'], data['obj_box'], data['rel_ind'], data['pred_fmap'], data['pred_dist'],\
                      data['gpn_obj_ind'], data['gpn_pred_ind'], data['gpn_nrel_ind'],data['gpn_pool_mtx']]
                tmp = [_.cuda() if _ is not None else _ for _ in tmp]
                fc_feats, att_feats, labels, masks, att_masks, trip_pred, obj_dist, obj_box, rel_ind, pred_fmap, pred_dist,\
                gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind, gpn_pool_mtx = tmp

                with torch.no_grad():
                    # teacher-forced forward pass; loss ignores the first (start) token
                    lang_output, _, _ = model(fc_feats, att_feats, labels, att_masks, trip_pred,\
                               obj_dist, obj_box, rel_ind, pred_fmap, pred_dist, gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind, gpn_pool_mtx)
                    loss = crit(lang_output, labels[:,1:], masks[:,1:]).item()
                loss_sum += loss  # only use validation loss
                loss_evals += 1
            else: # model testing
                tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'], data['trip_pred'],\
                      data['obj_dist'], data['obj_box'], data['rel_ind'], data['pred_fmap'], data['pred_dist'],\
                      data['gpn_obj_ind'], data['gpn_pred_ind'], data['gpn_nrel_ind'], data['gpn_pool_mtx']]
                tmp = [_.cuda() if _ is not None else _ for _ in tmp]
                fc_feats, att_feats, labels, masks, att_masks, trip_pred, obj_dist, obj_box, rel_ind, pred_fmap, pred_dist,\
                gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind, gpn_pool_mtx = tmp
                
                # send all subgraphs of a image to generate sentences
                with torch.no_grad():
                    if return_att_weight:  # grounding experiments
                        seqq, seqLogprobs, subgraph_score, keep_nms_ind, att_weights = model(fc_feats, att_feats, att_masks, trip_pred,\
                                   obj_dist, obj_box, rel_ind, pred_fmap, pred_dist, gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind, gpn_pool_mtx,\
                                   opt=eval_kwargs, mode='sample')
                    else:
                        seqq, seqLogprobs, subgraph_score, keep_nms_ind = model(fc_feats, att_feats, att_masks, trip_pred,\
                                   obj_dist, obj_box, rel_ind, pred_fmap, pred_dist, gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind, gpn_pool_mtx,\
                                   opt=eval_kwargs, mode='sample')
                    if not sct_mode:
                        if model.gpn: # sub-graph captioning model
                            # order sentences by descending sub-graph score
                            sorted_score, sort_ind = torch.sort(subgraph_score,descending=True)
                            seq = seqq[sort_ind].data
                            subgraph_score = sorted_score.data
                            sorted_subgraph_ind = keep_nms_ind[sort_ind] # the indices are to index sub-graph in original order
                        else: # model that use full graph
                            sort_ind = torch.arange(subgraph_score.size(0)).type_as(keep_nms_ind)
                            seq = seqq.data
                            sorted_subgraph_ind = keep_nms_ind.data                               
                    else: # for show control and tell, order should be same as inputs and thus no sorting
                        # keep only the first half of the outputs
                        valid_num = int(subgraph_score.size(0) / 2)
                        seq = seqq.data[:valid_num]
                        subgraph_score = subgraph_score.data[:valid_num]
                        sorted_subgraph_ind = keep_nms_ind[:valid_num]
                        sort_ind = keep_nms_ind[:valid_num].long()                            

                print('\nNo {}:'.format(n))
                
                if beam_size > 1 and verbose_beam:
                    # print all beams of one randomly chosen sub-graph
                    keep_ind = sort_ind.cpu().numpy()
                    print('beam seach sentences of image {}:'.format(data['infos'][0]['id']))
                    for i in np.random.choice(keep_ind, size=1, replace=True):
                        print('subgraph {}'.format(i))
                        print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                        print('--' * 10)
                
                sents = utils.decode_sequence(loader.get_vocab(), seq)  # use the first beam which has highest cumulative score
                
                # save best sentence generated by all subgraphs of a image
                # NOTE(review): only data['infos'][0] is recorded, i.e. this assumes one image per batch — confirm
                entry = {'image_id': data['infos'][0]['id'], 'caption': []}
                entry['subgraph_score'] = subgraph_score.cpu().numpy()
                entry['sorted_subgraph_ind'] = sorted_subgraph_ind.cpu().numpy()

                for k, sent in enumerate(sents):
                    entry['caption'].append(sent)
                predictions.append(entry)
                if verbose:
                    print('keeping {} subgraphs'.format(len(sents)))
                    #print(subgraph_score)
                    # NOTE(review): opened with 'w' once per image, so the file ends up
                    # holding only the last image's captions.
                    with open('subgc_output_caption.txt', 'w') as fp:
                        for i in range(len(entry['caption'])): #range(3)
                            #print(entry['caption'][i])
                            fp.write("%s\n" % entry['caption'][i])
                        print('Captions saved to subgc_output_caption.txt.')                    
                    print('--' * 20)
                # collect grounding material for grounding evaluation
                if return_att_weight:
                    get_grounding_material(eval_kwargs['infos_path'], data, sents, sorted_subgraph_ind, att_weights, sort_ind, \
                        wd_to_lemma, lemma_det_id_dict, det_id_to_det_wd, grd_output, \
                        use_full_graph=not model.gpn, grd_sGPN_consensus=grd_sGPN_consensus)

            # stop at the batch the loader flags as the end of the split, or after num_images
            if data['bounds']['wrapped']:
                break
            if num_images >= 0 and n >= num_images:
                break

        # save prediction results
        if data.get('labels', None) is not None and verbose_loss:  # after model validation, switch back to training mode
            model.train()
            return loss_sum/loss_evals
        else:  # after model testing, save generated results
            # file suffix is taken from the last path component between the first '-' and the first '.'
            save_path = eval_kwargs['infos_path'].split('/')

            if not sct_mode:  # sub-graph captioning
                np.save(save_path[0] + '/' + save_path[1] + '/' + 'captions_{}.npy'.format(save_path[-1].split('-')[1].split('.')[0]),predictions)
            else:  # sct mode, controllability experiments
                np.save(save_path[0] + '/' + save_path[1] + '/' + 'ctl_captions_{}.npy'.format(save_path[-1].split('-')[1].split('.')[0]),predictions)

            if return_att_weight:  # grounding experiments
                with open(save_path[0] + '/' + save_path[1] + '/' + 'grounding_file.json', 'w') as f:
                    json.dump({'results':grd_output, 'eval_mode':'gen', 'external_data':{'used':True, 'details':'grounding experiment'}}, f)

    # 2. only evaluate the generated sentences
    if model is None:
        oracle_num = eval_kwargs.get('oracle_num', 1)
        sent_cnt = []
        align_pred = []
        save_path = eval_kwargs['infos_path'].split('/')
        predictions = np.load(save_path[0] + '/' + save_path[1] + '/' + 'captions_{}.npy'.format(\
            save_path[-1].split('-')[1].split('.')[0]), allow_pickle=True,encoding='latin1').tolist()
        # keep the top `oracle_num` captions per image, padding with the first caption when short
        for p_i in range(len(predictions)):
            sent_cnt.append(len(predictions[p_i]['caption']))
            entry = {'image_id': predictions[p_i]['image_id'], 'caption': predictions[p_i]['caption'][:oracle_num]} 
            if len(entry['caption']) < oracle_num: # if subgraphs aren't engough
                for p_j in range(oracle_num)[len(entry['caption']):]:
                    entry['caption'].append(predictions[p_i]['caption'][0]) # pad with first sentence
            assert len(entry['caption']) == oracle_num
            align_pred.append(entry)
        if lang_eval == 1:
            language_eval(dataset, align_pred, eval_kwargs['id'], split, save_path, \
                            is_flickr='coco' not in eval_kwargs['input_label_h5'])

Overwriting misc/eval_utils.py


In [25]:
%%writefile test.py


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import numpy as np
import random

import time
import os

import opts
import models

from misc import eval_utils
import argparse
import misc.utils as utils
import torch

# Reproducibility: seed Python's `random`, NumPy and torch (CPU + CUDA) with a
# single value, and ask cuDNN for deterministic kernels without autotuning.
random_seed = 2019
for _seed_rng in (torch.manual_seed, torch.cuda.manual_seed, np.random.seed, random.seed):
    _seed_rng(random_seed)
del _seed_rng
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Input arguments and options
# Options that also exist in the checkpoint's infos['opt'] are checked for
# consistency in __main__; options only present in the checkpoint are copied over.
parser = argparse.ArgumentParser()
####### Original hyper-parameters #######
# Input paths
parser.add_argument('--model', type=str, default='',
                help='path to model to evaluate')
parser.add_argument('--cnn_model', type=str,  default='resnet101',
                help='resnet101, resnet152')
parser.add_argument('--infos_path', type=str, default='',
                help='path to infos to evaluate')
# Basic options
parser.add_argument('--batch_size', type=int, default=0,
                help='if > 0 then overrule, otherwise load from checkpoint.')
parser.add_argument('--num_images', type=int, default=-1,
                help='how many images to use when periodically evaluating the loss? (-1 = all)')
parser.add_argument('--language_eval', type=int, default=0,
                help='Evaluate language as well (1 = yes, 0 = no)? BLEU/CIDEr/METEOR/ROUGE_L? requires coco-caption code from Github.')
parser.add_argument('--dump_images', type=int, default=1,
                help='Dump images into vis/imgs folder for vis? (1=yes,0=no)')
parser.add_argument('--dump_json', type=int, default=1,
                help='Dump json with predictions into vis folder? (1=yes,0=no)')
parser.add_argument('--dump_path', type=int, default=0,
                help='Write image paths along with predictions into vis json? (1=yes,0=no)')
# Sampling options
parser.add_argument('--sample_max', type=int, default=1,
                help='1 = sample argmax words. 0 = sample from distributions.')
parser.add_argument('--beam_size', type=int, default=2,
                help='used when sample_max = 1, indicates number of beams in beam search. Usually 2 or 3 works well. More is not better. Set this to 1 for faster runtime but a bit worse performance.')
parser.add_argument('--max_length', type=int, default=20,
                help='Maximum length during sampling')
parser.add_argument('--length_penalty', type=str, default='',
                help='wu_X or avg_X, X is the alpha')
parser.add_argument('--group_size', type=int, default=1,
                help='used for diverse beam search. if group_size is 1, then it\'s normal beam search')
parser.add_argument('--diversity_lambda', type=float, default=0.5,
                help='used for diverse beam search. Usually from 0.2 to 0.8. Higher value of lambda produces a more diverse list')
parser.add_argument('--temperature', type=float, default=1.0,
                help='temperature when sampling from distributions (i.e. when sample_max = 0). Lower = "safer" predictions.')
parser.add_argument('--decoding_constraint', type=int, default=0,
                help='If 1, not allowing same word in a row')
parser.add_argument('--block_trigrams', type=int, default=0,
                help='block repeated trigram.')
parser.add_argument('--remove_bad_endings', type=int, default=0,
                help='Remove bad endings')
# For evaluation on a folder of images:
parser.add_argument('--image_folder', type=str, default='',
                help='If this is nonempty then will predict on the images in this folder path')
parser.add_argument('--image_root', type=str, default='',
                help='In case the image paths have to be preprended with a root path to an image folder')

# For evaluation on a single image:
# The default is None, not '': argparse applies `type` to string defaults,
# so int('') made the script exit whenever --test_id was omitted.
parser.add_argument('--test_id', type=int, default=None,
                help='Gives caption of particular image id')

# For evaluation on MSCOCO images from some split:
parser.add_argument('--input_fc_dir', type=str, default='',
                help='path to the h5file containing the preprocessed dataset')
parser.add_argument('--input_att_dir', type=str, default='',
                help='path to the h5file containing the preprocessed dataset')
parser.add_argument('--input_box_dir', type=str, default='',
                help='path to the h5file containing the preprocessed dataset')
parser.add_argument('--input_label_h5', type=str, default='',
                help='path to the h5file containing the preprocessed dataset')
parser.add_argument('--input_json', type=str, default='',
                help='path to the json file containing additional info and vocab. empty = fetch from model checkpoint.')
parser.add_argument('--split', type=str, default='single', ### split = test
                help='if running on MSCOCO images, which split to use: val|test|train')
parser.add_argument('--coco_json', type=str, default='',
                help='if nonempty then use this file in DataLoaderRaw (see docs there). Used only in MSCOCO test evaluation, where we have a specific json file of only test set images.')
# misc
parser.add_argument('--id', type=str, default='',
                help='an id identifying this run/job. used only if language_eval = 1 for appending to intermediate files')
parser.add_argument('--verbose_beam', type=int, default=1,
                help='if we need to print out all beam search beams.')
parser.add_argument('--verbose_loss', type=int, default=0,
                help='if we need to calculate loss.')

####### Graph captioning model hyper-parameters #######
parser.add_argument('--use_gpn', type=int, default=1,
                help='1: use GPN module in the captioning model')
parser.add_argument('--embed_dim', type=int, default=300,
                help='dim of word embeddings')
parser.add_argument('--gcn_dim', type=int, default=1024,
                help='dim of the node/edge features in GCN')
parser.add_argument('--noun_fuse', type=int, default=1,
                help='1: fuse the word embedding with visual features for noun nodes')
parser.add_argument('--pred_emb_type', type=int, default=1,
                help='predicate embedding type')
parser.add_argument('--gcn_layers', type=int, default=2,
                help='the layer number of GCN')
parser.add_argument('--gcn_residual', type=int, default=2,
                help='2: there is a skip connection every 2 GCN layers')
parser.add_argument('--gcn_bn', type=int, default=0,
                help='0: not use BN in GCN layers')
parser.add_argument('--sampling_prob', type=float, default=0.0,
                help='Schedule sampling probability')
parser.add_argument('--obj_name_path', type=str, default='data/object_names_1600-0-20.npy',
                help='the file path for object names')
parser.add_argument('--rel_name_path', type=str, default='data/predicate_names_1600-0-20.npy',
                help='the file path for predicate names')

# parser.add_argument('--gpn_label_thres', type=float, default=0.75,
#                 help='the threshold of positive/negative sub-graph labels during training')
parser.add_argument('--use_MRNN_split', action='store_true',
                help='use the split of MRNN on COCO Caption dataset')
parser.add_argument('--use_gt_subg', action='store_true',
                help='(Sup. model for SCT) use the ground-truth sub-graphs without neighbors and same-cls nodes')
parser.add_argument('--use_greedy_subg', action='store_true',
                help='(Unsup. model for SCT) use gt box to greedily find the sub-graphs with neighbors and same-cls nodes')
# parser.add_argument('--gpn_batch', type=int, default=2,
#                 help='the batch size for positive/negative sub-graphs during training')
parser.add_argument('--obj_num', type=int, default=37,
                help='the number of detected objects + 1 dummy object')
parser.add_argument('--rel_num', type=int, default=65,
                help='the number of detected relationships + 1 dummy relationship')

parser.add_argument('--num_workers', type=int, default=6,
                help='number of workers to use')

####### Hyper-parameters that only belongs to evaluation #######
parser.add_argument('--test_LSTM', type=int, default=1,
                help='1: generate captions, used during evaluation (testing)')
parser.add_argument('--use_topk_sampling', type=int, default=0,
                help='1: use topk sampling during decoding each word')
parser.add_argument('--topk_temp', type=float, default=0.6,
                help='the temperature used in topk sampling')
parser.add_argument('--the_k', type=int, default=3,
                help='k top candidates are used in sampling')
parser.add_argument('--gpn_nms_thres', type=float, default=0.75,
                help='the threshold in sub-graph NMS during testing')
parser.add_argument('--gpn_max_subg', type=int, default=1,
                help='the maximum number of sub-graphs to be kept during testing')

# sentence evaluation
parser.add_argument('--only_sent_eval', type=int, default=0,
                help='evaluate sentence scores: 1, only run sentence evaluation; 0, only generate sentences')
parser.add_argument('--oracle_num', type=int, default=1,
                help='how many sentences are used to calculate the top-1 accuracy')
# grounding attention triger
parser.add_argument('--return_att', type=int, default=0,
                help='1: return attention weight for each time step, for grounding evaluation')
# show-control-tell mode triger
parser.add_argument('--sct', type=int, default=0,
                help='1: use sct mode where not sorting the sub-graphs and ensure the order is same as input region sets; for controllability experiments')

opt = parser.parse_args()

if __name__ == '__main__':
    # --image_folder is accepted by the parser, but no raw-image loader is set up
    # below, so `loader` would be undefined. Fail fast, before loading the model.
    if len(opt.image_folder) != 0:
        raise NotImplementedError('--image_folder is not supported by test.py; '
                                  'use the preprocessed features via --split / --test_id instead')
    if opt.only_sent_eval not in (0, 1):
        raise ValueError('--only_sent_eval must be 0 or 1, got {}'.format(opt.only_sent_eval))

    # Load infos from trained model files.
    # NOTE(security): utils.pickle_load presumably unpickles this file, which can
    # execute arbitrary code -- only load checkpoints from a trusted source.
    with open(opt.infos_path, 'rb') as f:
        infos = utils.pickle_load(f)

    # override and collect parameters: empty data paths fall back to the checkpoint's
    if len(opt.input_fc_dir) == 0:
        opt.input_fc_dir = infos['opt'].input_fc_dir
        opt.input_att_dir = infos['opt'].input_att_dir
        opt.input_box_dir = getattr(infos['opt'], 'input_box_dir', '')
        opt.input_label_h5 = infos['opt'].input_label_h5
    if len(opt.input_json) == 0:
        opt.input_json = infos['opt'].input_json
    # NOTE: despite the --batch_size help text, 0 falls back to 1 here, not to
    # the checkpoint's batch size.
    if opt.batch_size == 0:
        opt.batch_size = 1
    if len(opt.id) == 0:
        opt.id = infos['opt'].id
    # options allowed to differ between training and evaluation
    ignore = ["id", "batch_size", "beam_size", "start_from", "language_eval", "block_trigrams"]

    # Ensure the common vars are the same; for vars only in train, copy to the opt in eval; for vars only in eval, no overrriding
    for k in vars(infos['opt']).keys():
        if k not in ignore:
            if k in vars(opt):
                assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent'
            else:
                vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model

    vocab = infos['vocab'] # ix -> word mapping

    if opt.only_sent_eval == 1:  # no model inference, only evaluate generated sentences
        model = None
    else:  # only_sent_eval == 0: set up the model for inference
        model = models.setup(opt)
        model.load_state_dict(torch.load(opt.model))
        model.cuda()
        model.eval()

    crit = utils.LanguageModelCriterion()

    # Create the Data Loader instance (explicit import instead of `import *`)
    if opt.sct == 0: # normal mode
        from dataloaders.dataloader_test import DataLoader
    else:  # sct mode
        from dataloaders.dataloader_test_sct import DataLoader
    loader = DataLoader(opt)

    # When eval using provided pretrained model, the vocab may be different from what you have in your cocotalk.json
    # So make sure to use the vocab in infos file.
    loader.ix_to_word = vocab

    eval_utils.eval_split(model, crit, loader, vars(opt))

Overwriting test.py


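## MS COCO

Sanity-check both pretrained checkpoints (MRNN split and Karpathy split) by captioning a single COCO image.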

In [48]:
test_id = 483108

#print GT captions

!(python test.py --dump_images 0 --dump_json 1 --num_images -1 \
   --test_id {test_id} --num_workers 2 --language_eval 1 --beam_size 1 \
   --gpn_nms_thres 0.55 --gpn_max_subg 1000 --use_MRNN_split \
   --model pretrained/sub_gc_MRNN/model-60000.pth --oracle_num 5\
   --infos_path pretrained/sub_gc_MRNN/infos_topdown-60000.pkl \
   --only_sent_eval 0)

loading word vectors from data/glove.6B.300d.pt
Fail on __background__
loading word vectors from data/glove.6B.300d.pt
Fail on __background__
DataLoader loading json file:  data/cocotalk.json
vocab size is  9487
DataLoader loading h5 file:  data/cocotalk_label.h5
max sequence length in data is 16
read 123287 image features
assigned 118287 images to split train
assigned 4000 images to split val
assigned 1000 images to split test
assigned 1 image to split single

No 1:
keeping 29 subgraphs
Captions saved to subgc_output_caption.txt.
----------------------------------------
Terminating BlobFetcher


In [49]:
test_id = 483108

#print GT captions

!(python test.py --dump_images 0 --dump_json 1 --num_images -1 \
   --test_id {test_id} --num_workers 2 --language_eval 1 --beam_size 1 \
   --gpn_nms_thres 0.55 --gpn_max_subg 1000 --use_MRNN_split \
   --model pretrained/sub_gc_karpathy/model-60000.pth --oracle_num 5\
   --infos_path pretrained/sub_gc_karpathy/infos_topdown-60000.pkl \
   --only_sent_eval 0)

loading word vectors from data/glove.6B.300d.pt
Fail on __background__
loading word vectors from data/glove.6B.300d.pt
Fail on __background__
DataLoader loading json file:  data/cocotalk.json
vocab size is  9487
DataLoader loading h5 file:  data/cocotalk_label.h5
max sequence length in data is 16
read 123287 image features
assigned 118287 images to split train
assigned 4000 images to split val
assigned 1000 images to split test
assigned 1 image to split single

No 1:
keeping 31 subgraphs
Captions saved to subgc_output_caption.txt.
----------------------------------------
Terminating BlobFetcher


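## Run on COCO Dataset

Caption every image in the Eval_IO `0_images` folder. Image files are named by their Visual Genome id, which is mapped to a COCO id via VG's `image_data.json` before calling `test.py`; the captions for each image are saved to `3_captions/<filename>.txt`, and images that already have a caption file are skipped.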

In [83]:
import os, json
image_data = json.load(open('/home/jaleed/Jaleed/ModelsNDatasets/VG/image_data.json'))
vg_ids = [data['image_id'] for data in image_data]
coco_ids = [data['coco_id'] for data in image_data]
dir_eval_io = '/home/jaleed/Jaleed/Eval_IO/coco/'
for filename in os.listdir(f'{dir_eval_io}0_images/'):
    img_path = f'{dir_eval_io}0_images/{filename}'
    img_id = int(filename.split('.jpg')[0])
    cap_path = f'{dir_eval_io}3_captions/{filename}.txt'
    if os.path.exists(cap_path):
        print('***** EXISTS: '+ cap_path +' *****')
        continue
    else:
        print('***** PROCESSING '+ img_path +' *****')
        coco_id = coco_ids[vg_ids.index(img_id)]
        !(python test.py --dump_images 0 --dump_json 1 --num_images -1 \
           --test_id {coco_id} --num_workers 2 --language_eval 1 --beam_size 1 \
           --gpn_nms_thres 0.55 --gpn_max_subg 1000 --use_MRNN_split \
           --model pretrained/sub_gc_karpathy/model-60000.pth --oracle_num 5\
           --infos_path pretrained/sub_gc_karpathy/infos_topdown-60000.pkl \
           --only_sent_eval 0)
        shutil.move('/home/jaleed/Jaleed/SubGC/Sub-GC/subgc_output_caption.txt', cap_path)
        print('Saved '+ cap_path)
    print('\n')

***** EXISTS: /home/jaleed/Jaleed/Eval_IO_1000/coco/3_captions/150268.jpg.txt *****
***** EXISTS: /home/jaleed/Jaleed/Eval_IO_1000/coco/3_captions/2415125.jpg.txt *****
***** EXISTS: /home/jaleed/Jaleed/Eval_IO_1000/coco/3_captions/150536.jpg.txt *****
***** EXISTS: /home/jaleed/Jaleed/Eval_IO_1000/coco/3_captions/150269.jpg.txt *****
***** EXISTS: /home/jaleed/Jaleed/Eval_IO_1000/coco/3_captions/498411.jpg.txt *****
***** PROCESSING /home/jaleed/Jaleed/Eval_IO_1000/coco/0_images/150277.jpg *****
loading word vectors from data/glove.6B.300d.pt
Fail on __background__
loading word vectors from data/glove.6B.300d.pt
Fail on __background__
DataLoader loading json file:  data/cocotalk.json
vocab size is  9487
DataLoader loading h5 file:  data/cocotalk_label.h5
max sequence length in data is 16
read 123287 image features
assigned 118287 images to split train
assigned 4000 images to split val
assigned 1000 images to split test
assigned 1 image to split single

No 1:
keeping 21 subgraphs
Capti