In [1]:
# Change to the repository root so the relative paths used later in this
# notebook ('func/nets/', 'script/training/', 'data/...') resolve.
# NOTE(review): hard-coded, user-specific absolute path; consider making it configurable.
% cd /home/mayu-ot/durga/Experiments/loc_iparaphrasing/

/mnt/fs1/mayu-ot/Experiments/loc_iparaphrasing


In [8]:
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import initializers
from chainercv.utils import read_image
import numpy as np
import sys
sys.path.append('func/nets/')

In [3]:
from faster_rcnn import FasterRCNNExtractor

In [6]:
class FullModel(chainer.Chain):
    """Classify a pair of image-grounded phrases as VGP or non-VGP.

    Each phrase feature is fused with the Faster R-CNN feature of its RoI.
    The two fused vectors are then combined and fed to a small MLP that
    produces 2-way logits. ``__call__`` returns the softmax cross-entropy loss.
    """

    def __init__(self):
        super(FullModel, self).__init__()
        with self.init_scope():
            self.frcnn = FasterRCNNExtractor(n_fg_class=20)
            # The third positional argument of L.Linear is `nobias`, so the
            # initializer must be passed by keyword (passing it positionally
            # silently disabled the bias and ignored the initializer).
            # NOTE(review): fc_l / fc_v are not used in __call__ yet.
            self.fc_l = L.Linear(None, 1000, initialW=initializers.HeNormal())
            self.fc_v = L.Linear(None, 1000, initialW=initializers.HeNormal())

            # Fusion net: language and visual projections are summed, so one
            # bias (on fuse_v) is enough and fuse_l has none.
            self.fuse_l = L.Linear(None, 500, nobias=True, initialW=initializers.HeNormal())
            self.fuse_v = L.Linear(None, 500, initialW=initializers.HeNormal())

            # Classification net: same "sum of two projections" pattern for
            # the two fused phrase vectors, then a 2-way output layer.
            self.mlp_0_l = L.Linear(None, 500, nobias=True, initialW=initializers.HeNormal())
            self.mlp_0_r = L.Linear(None, 500, initialW=initializers.HeNormal())
            self.mlp_1 = L.Linear(None, 2, initialW=initializers.LeCunNormal())

    def __call__(self, Xim, Xp1, Xp2, roi1, roi2, L):
        """Return the softmax cross-entropy loss for a batch of phrase pairs.

        Args:
            Xim: batch of images; one RoI pair is used per image.
            Xp1, Xp2: phrase features for phrase 1 and phrase 2.
            roi1, roi2: one RoI per image for phrase 1 / phrase 2. Presumably
                in (ymin, xmin, ymax, xmax) order as Faster R-CNN expects;
                TODO confirm.
            L: integer class labels (1 = VGP, 0 = non-VGP, per DataPbd).
                Note this parameter shadows the `chainer.links as L` alias
                inside this method.

        Returns:
            Scalar loss variable.
        """
        # One RoI per image, so RoI k belongs to image k.
        roi_indices = self.xp.arange(len(Xim)).astype('f')
        h_v1 = self.frcnn.extract(Xim, roi1, roi_indices)
        # The backbone runs a second time on the same images here
        # (duplicate computation in the bottom layers).
        h_v2 = self.frcnn.extract(Xim, roi2, roi_indices)

        # Fuse visual and language features for each phrase.
        h_1 = F.relu(self.fuse_l(Xp1) + self.fuse_v(h_v1))
        h_2 = F.relu(self.fuse_l(Xp2) + self.fuse_v(h_v2))

        # Classification -> VGP or non-VGP.
        # mlp_0_l and mlp_0_r have separate weights, so swapping the phrase
        # order generally changes the prediction (not order-invariant).
        h = F.relu(self.mlp_0_l(h_1) + self.mlp_0_r(h_2))
        h = self.mlp_1(h)
        return F.softmax_cross_entropy(h, L)


In [7]:
# Instantiate the model. DataPbd.get_example reads this global `model`
# (it calls model.prepare), so this cell must run before the dataset is used.
model = FullModel()

In [48]:
sys.path.append('script/training/')
from train import get_dataset

In [None]:
'''
Todo:
chainer dataset実装
(参考) script/training/train.py の RegionEntityDatasetBase, EntityDatasetPhraseFeat あたりに似たようなものがある
get_roiとget_labelが未実装

phraseに対応するroiの取り方
データのあるファイル
- data/pl-clc/phrase_pair_wt_plclcbbox_<split>.csv
- data/region_feat/roi/full_<split>.h5

例）validation setのi番目のフレーズペアに対応するroi座標取得
import pandas as pd
import tables
df = pd.read_csv('data/pl-clc/phrase_pair_wt_plclcbbox_val.csv')
h5file = tables.open_file('data/region_feat/roi/full_val.h5')

row = df.iloc[0]
rindex_1, rindex_2 = row[['roi1', 'roi2']]

node = h5file.get_node('/', str(row['image']))
roi1 = node[rindex_1]
roi2 = node[rindex_2] # xmin, ymin, xmax, ymax

* faster rcnnは(ymin, xmin, ymax, xmax)の形式で受け取るので変換が必要
'''

class DataPbd(chainer.dataset.DatasetMixin):
    """Dataset of image-grounded phrase pairs for VGP classification.

    Each example is ``(image, phrase1_feat, phrase2_feat, roi1, roi2, label)``.

    NOTE(review): work in progress.
      * ``get_roi`` and ``get_label`` are not implemented yet.
      * ``get_example`` reads ``self.root``, ``self.image``, ``self._phrase1``
        and ``self._phrase2``, which ``__init__`` does not set yet.
      * ``__len__`` (required by DatasetMixin) is not defined yet.
    """

    def __init__(self, phrase_feature_file, unique_phrase_file):
        """Load the phrase vocabulary and phrase feature matrix.

        Args:
            phrase_feature_file: array file readable by ``np.load``; row k is
                the feature of the k-th phrase of ``unique_phrase_file``.
            unique_phrase_file: text file with one phrase per line.
        """
        # Map each phrase to its line number (a later duplicate wins).
        with open(unique_phrase_file) as f:
            self._p2i_dict = {line.rstrip(): i for i, line in enumerate(f)}
        self._feat = np.load(phrase_feature_file).astype(np.float32)

    def _phrase_feature(self, phrase):
        """Return the feature row of ``phrase``; raise KeyError if unknown."""
        # Unknown phrases used to map to -1, which silently returned the
        # feature of the *last* phrase; fail loudly instead.
        if phrase not in self._p2i_dict:
            raise KeyError('phrase not found in unique phrase list: %r' % (phrase,))
        return self._feat[self._p2i_dict[phrase]]

    def get_roi(self, i):
        """Return ``(roi1, roi2)``, the RoIs of phrase 1 and phrase 2 of example ``i``.

        Not implemented yet. Per the TODO note above, the stored boxes are
        (xmin, ymin, xmax, ymax) and must be converted to
        (ymin, xmin, ymax, xmax) for Faster R-CNN.
        """
        raise NotImplementedError('DataPbd.get_roi is not implemented yet')

    def get_label(self, i):
        """Return the binary label of example ``i``: 1 if the phrases are VGP, else 0.

        Not implemented yet.
        """
        raise NotImplementedError('DataPbd.get_label is not implemented yet')

    def get_example(self, i):
        """Return ``(x, p1, p2, roi1, roi2, l)`` for example ``i``.

        ``x`` is the prepared image, and the RoIs are rescaled to its size.
        """
        img = read_image(self.root + self.image[i], color=True)

        roi1, roi2 = self.get_roi(i)

        # NOTE(review): depends on the notebook-global `model`. FullModel
        # defines no `prepare` method itself; presumably the Faster R-CNN
        # part (`model.frcnn.prepare`) is intended — TODO confirm.
        x = model.prepare(img)
        # prepare() may resize the image; rescale the RoIs by the width ratio
        # (read_image returns CHW, so shape[-1] is the width).
        scale = x.shape[-1] * 1. / img.shape[-1]
        roi1 = roi1 * scale
        roi2 = roi2 * scale

        p1 = self._phrase_feature(self._phrase1[i])
        p2 = self._phrase_feature(self._phrase2[i])

        l = self.get_label(i)

        # Return the prepared image: the RoIs above are in its coordinate frame.
        return x, p1, p2, roi1, roi2, l