fix(encoder): add netvlad and netfv register class

Larryjianfeng committed Jul 31, 2019
1 parent 92500f0 commit 679915336a2d3d99041844717723e8a06dae5899
Showing with 46 additions and 10 deletions.
+46 −10 gnes/encoder/video/incep_mixture.py
@@ -18,11 +18,11 @@
 import numpy as np
 from PIL import Image

-from ..base import BaseVideoEncoder
-from ...helper import batching, batch_iterator, get_first_available_gpu
+from gnes.encoder.base import BaseImageEncoder
+from gnes.helper import batching, batch_iterator, get_first_available_gpu


-class IncepMixtureEncoder(BaseVideoEncoder):
+class IncepMixtureEncoder(BaseImageEncoder):

     def __init__(self, model_dir_inception: str,
                  model_dir_mixture: str,
@@ -32,9 +32,11 @@ def __init__(self, model_dir_inception: str,
                  feature_size: int = 300,
                  vocab_size: int = 28,
                  cluster_size: int = 256,
-                 method: str = 'netvlad',
+                 method: str = 'fvnet',
                  input_size: int = 1536,
-                 multitask_method: str = 'Attention'
+                 vocab_size_2: int = 174,
+                 max_frames: int = 30,
+                 multitask_method: str = 'Attention',
                  *args, **kwargs):
         super().__init__(*args, **kwargs)
         self.model_dir_inception = model_dir_inception
@@ -48,12 +50,16 @@ def __init__(self, model_dir_inception: str,
         self.method = method
         self.input_size = input_size
         self.multitask_method = multitask_method
         self.inception_size_x = 299
         self.inception_size_y = 299
+        self.max_frames = max_frames
+        self.vocab_size_2 = vocab_size_2

     def post_init(self):
         import tensorflow as tf
-        from ..image.inception_cores.inception_v4 import inception_v4
-        from ..image.inception_cores.inception_utils import inception_arg_scope
-        from .mixture_core.incep_mixture import *
+        from gnes.encoder.image.inception_cores.inception_v4 import inception_v4
+        from gnes.encoder.image.inception_cores.inception_utils import inception_arg_scope
+        from gnes.encoder.video.mixture_core.model import NetFV
         import os
         os.environ['CUDA_VISIBLE_DEVICES'] = str(get_first_available_gpu())

@@ -71,7 +77,7 @@ def post_init(self):
                                                            dropout_keep_prob=1.0)

             config = tf.ConfigProto(log_device_placement=False)
-            if self._use_cuda:
+            if self.use_cuda:
                 config.gpu_options.allow_growth = True
             self.sess = tf.Session(config=config)
             self.saver = tf.train.Saver()
@@ -80,17 +86,47 @@ def post_init(self):
         g2 = tf.Graph()
         with g2.as_default():
             config = tf.ConfigProto(log_device_placement=False)
-            if self._use_cuda:
+            if self.use_cuda:
                 config.gpu_options.allow_growth = True
             self.sess2 = tf.Session(config=config)
             self.mix_model = NetFV(feature_size=self.feature_size,
                                    cluster_size=self.cluster_size,
                                    vocab_size=self.vocab_size,
                                    input_size=self.input_size,
                                    use_2nd_label=True,
                                    vocab_size_2=self.vocab_size_2,
                                    multitask_method=self.multitask_method,
                                    method=self.method,
                                    is_training=False)
             saver = tf.train.Saver(max_to_keep=1)
             self.sess2.run(tf.global_variables_initializer())
+            saver.restore(self.sess2, self.model_dir_mixture)

+    @batching
+    def encode(self, videos: List['np.ndarray'], *args, **kwargs) -> np.ndarray:
+        ret = []
+        v_len = [len(v) for v in videos]
+        pos_start = [0] + [sum(v_len[:i]) for i in range(1, len(v_len))]
+        pos_end = [sum(v_len[:i]) for i in range(1, len(v_len) + 1)]
+        max_len = min(max(v_len), self.max_frames)
+
+        img = [im for v in videos for im in v]
+        img = [(np.array(Image.fromarray(im).resize((self.inception_size_x,
+                                                      self.inception_size_y)),
+                         dtype=np.float32) * 2 / 255. - 1.) for im in img]
+        for _im in batch_iterator(img, self.batch_size):
+            _, end_points_ = self.sess.run((self.logits, self.end_points),
+                                           feed_dict={self.inputs: _im})
+            ret.append(end_points_[self.select_layer])
+        v = [_ for vi in ret for _ in vi]
+
+        v_input = [v[s:e] for s, e in zip(pos_start, pos_end)]
+        v_input = [(vi + [[0.0] * self.input_size] * (max_len - len(vi)))[:max_len] for vi in v_input]
+        v_input = [np.array(vi, dtype=np.float32) for vi in v_input]
+
+        ret = []
+        for _vi in batch_iterator(v_input, self.batch_size):
+            repre = self.sess2.run(self.mix_model.repre,
+                                   feed_dict={self.mix_model.feeds: _vi})
+            ret.append(repre)
+        return np.concatenate(ret, axis=0).astype(np.float32)
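
For context, a minimal usage sketch, not part of the commit, of how the patched encoder could be driven directly. The checkpoint paths, the frame shapes, and the explicit post_init() call are illustrative assumptions; in a GNES pipeline the framework typically constructs the component from a YAML spec and handles post_init itself.

import numpy as np
from gnes.encoder.video.incep_mixture import IncepMixtureEncoder

# Hypothetical checkpoint locations; replace with real trained models.
encoder = IncepMixtureEncoder(
    model_dir_inception='/tmp/inception_v4.ckpt',  # assumed Inception-V4 TF checkpoint
    model_dir_mixture='/tmp/mixture_netfv',        # assumed NetFV/NetVLAD mixture checkpoint
    method='fvnet',       # default after this commit; 'netvlad' is the other option
    max_frames=30)

# Builds the two TF graphs and restores both checkpoints; shown explicitly here,
# although the GNES framework normally triggers it when the component is loaded.
encoder.post_init()

# Each video is a sequence of RGB frames; lengths may differ, and every frame
# is resized to 299x299 internally before the Inception-V4 pass.
videos = [np.random.randint(0, 255, size=(12, 120, 160, 3), dtype=np.uint8),
          np.random.randint(0, 255, size=(8, 120, 160, 3), dtype=np.uint8)]

vecs = encoder.encode(videos)  # one fixed-length vector per input video
print(vecs.shape)

Judging by the method parameter, 'fvnet' and 'netvlad' appear to select NetFV or NetVLAD pooling inside the mixture model; either way, the per-frame Inception-V4 features are aggregated into a single video-level representation.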
