Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
remove predict and use SequentialFeatureExtractionChain
Browse files Browse the repository at this point in the history
  • Loading branch information
yuyu2172 committed Jun 19, 2017
1 parent 6a2b43c commit 724dcb5
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 219 deletions.
78 changes: 78 additions & 0 deletions chainercv/links/model/sequential_feature_extraction_chain.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import copy

import chainer


class SequentialFeatureExtractionChain(chainer.Chain):

def __init__(self, feature_names, link_generators):
if (not isinstance(feature_names, str) and
all([isinstance(feature, str) for feature in feature_names])):
return_tuple = True
else:
return_tuple = False
feature_names = [feature_names]
self._return_tuple = return_tuple
self._feature_names = feature_names

super(SequentialFeatureExtractionChain, self).__init__()

if any([name not in self.functions for
name in self._feature_names]):
raise ValueError('Elements of `feature_names` shuold be one of '
'{}.'.format(self.functions.keys()))

# Remove all functions that are not necessary.
self._unused_function_names = []
pop_funcs = False
features = list(self._feature_names)
for name in self.functions.keys():
if pop_funcs:
self._unused_function_names.append(name)

if name in features:
features.remove(name)
if len(features) == 0:
pop_funcs = True

with self.init_scope():
for name, link_gen in link_generators.items():
# Ignore layers whose names match functions that are removed.
if name not in self._unused_function_names:
setattr(self, name, link_gen())

@property
def functions(self):
raise NotImplementedError

def __call__(self, x):
"""Forward the model.
Args:
x (~chainer.Variable): Batch of image variables.
Returns:
Variable or tuple of Variable:
A batch of features or tuple of batched features.
The returned features are selected by :obj:`feature_names` that
is passed to :meth:`__init__`.
"""
functions = copy.copy(self.functions)
for name in self._unused_function_names:
functions.pop(name)

features = {}
h = x
for name, funcs in functions.items():
for func in funcs:
h = func(h)
if name in self._feature_names:
features[name] = h

if self._return_tuple:
features = tuple(
[features[name] for name in features.keys()])
else:
features = list(features.values())[0]
return features
169 changes: 7 additions & 162 deletions chainercv/links/model/vgg/vgg16.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,19 @@

import collections

import numpy as np

import chainer
from chainer import cuda
import chainer.functions as F
from chainer.initializers import constant
from chainer.initializers import normal
import chainer.links as L

from chainercv.transforms import center_crop
from chainercv.transforms import scale
from chainercv.transforms import ten_crop

from chainercv.utils import download_model


# RGB order
_imagenet_mean = np.array(
[123.68, 116.779, 103.939], dtype=np.float32)[:, np.newaxis, np.newaxis]
from chainercv.links.model.sequential_feature_extraction_chain import \
SequentialFeatureExtractionChain


class VGG16Layers(chainer.Chain):
class VGG16Layers(SequentialFeatureExtractionChain):

"""VGG16 Network for classification and feature extraction.
Expand Down Expand Up @@ -86,24 +77,11 @@ class VGG16Layers(chainer.Chain):
}

def __init__(self, pretrained_model=None, n_class=None,
features='prob', initialW=None, initial_bias=None,
mean=_imagenet_mean, do_ten_crop=False):
if (not isinstance(features, str) and
all([isinstance(feature, str) for feature in features])):
return_tuple = True
else:
return_tuple = False
features = [features]
self._return_tuple = return_tuple
self._features = features

self.mean = mean
self.do_ten_crop = do_ten_crop

feature_names='prob', initialW=None, initial_bias=None):
if n_class is None:
if (pretrained_model is None and
all([feature not in ['fc8', 'prob']
for feature in features])):
for feature in feature_names])):
# fc8 layer is not used in this case.
pass
elif pretrained_model not in self._models:
Expand All @@ -127,8 +105,6 @@ def __init__(self, pretrained_model=None, n_class=None,
initial_bias = constant.Zero()
kwargs = {'initialW': initialW, 'initial_bias': initial_bias}

super(VGG16Layers, self).__init__()

link_generators = {
'conv1_1': lambda: L.Convolution2D(3, 64, 3, 1, 1, **kwargs),
'conv1_2': lambda: L.Convolution2D(64, 64, 3, 1, 1, **kwargs),
Expand All @@ -147,10 +123,7 @@ def __init__(self, pretrained_model=None, n_class=None,
'fc7': lambda: L.Linear(4096, 4096, **kwargs),
'fc8': lambda: L.Linear(4096, n_class, **kwargs)
}
with self.init_scope():
for name, link_gen in link_generators.items():
if name in self.functions:
setattr(self, name, link_gen())
super(VGG16Layers, self).__init__(feature_names, link_generators)

if pretrained_model in self._models:
path = download_model(self._models[pretrained_model]['url'])
Expand All @@ -163,7 +136,7 @@ def functions(self):
def _getattr(name):
return getattr(self, name, None)

funcs = collections.OrderedDict([
return collections.OrderedDict([
('conv1_1', [_getattr('conv1_1'), F.relu]),
('conv1_2', [_getattr('conv1_2'), F.relu]),
('pool1', [_max_pooling_2d]),
Expand All @@ -187,134 +160,6 @@ def _getattr(name):
('fc8', [_getattr('fc8')]),
('prob', [F.softmax]),
])
if any([name not in funcs for name in self._features]):
raise ValueError('Elements of `features` shuold be one of '
'{}.'.format(funcs.keys()))

# Remove all functions that are not necessary.
pop_funcs = False
features = list(self._features)
for name in list(funcs.keys()):
if pop_funcs:
funcs.pop(name)

if name in features:
features.remove(name)
if len(features) == 0:
pop_funcs = True

return funcs

def __call__(self, x):
"""Forward VGG16.
Args:
x (~chainer.Variable): Batch of image variables.
Returns:
Variable or tuple of Variable:
A batch of features or tuple of them.
The features to output are selected by :obj:`features` option
of :meth:`__init__`.
"""
activations = {}
h = x
for name, funcs in self.functions.items():
for func in funcs:
h = func(h)
if name in self._features:
activations[name] = h

if self._return_tuple:
activations = tuple(
[activations[name] for name in activations.keys()])
else:
activations = list(activations.values())[0]
return activations

def _prepare(self, img):
"""Transform an image to the input for VGG network.
Args:
img (~numpy.ndarray): An image. This is in CHW and RGB format.
The range of its value is :math:`[0, 255]`.
Returns:
~numpy.ndarray:
A preprocessed image.
"""
img = scale(img, size=256)
img = img - self.mean

return img

def _average_ten_crop(self, y):
xp = chainer.cuda.get_array_module(y)
n = y.shape[0] // 10
y_shape = y.shape[1:]
y = y.reshape((n, 10) + y_shape)
y = xp.sum(y, axis=1) / 10
return y

def predict(self, imgs):
"""Predict features from images.
When :obj:`self.do_ten_crop == True`, this extracts features from
patches that are ten-cropped from images.
Otherwise, this extracts features from center-crop of the images.
When using patches from ten-crop, the features for each crop
is averaged to compute one feature.
Ten-crop mode is only supported for calculation of features
:math:`fc6, fc7, fc8, prob`.
Given :math:`N` input images, this outputs a batched array with
batchsize :math:`N`.
Args:
imgs (iterable of numpy.ndarray): Array-images.
All images are in CHW and RGB format
and the range of their value is :math:`[0, 255]`.
Returns:
Variable or tuple of Variable:
A batch of features or tuple of them.
The features to output are selected by :obj:`features` option
of :meth:`__init__`.
"""
if (self.do_ten_crop and
any([feature not in ['fc6', 'fc7', 'fc8', 'prob']
for feature in self._features])):
raise ValueError

imgs = [self._prepare(img) for img in imgs]
if self.do_ten_crop:
imgs = [ten_crop(img, (224, 224)) for img in imgs]
else:
imgs = [center_crop(img, (224, 224)) for img in imgs]
imgs = self.xp.asarray(imgs).reshape(-1, 3, 224, 224)

with chainer.function.no_backprop_mode():
imgs = chainer.Variable(imgs)
activations = self(imgs)

if isinstance(activations, tuple):
output = []
for activation in activations:
activation = activation.data
if self.do_ten_crop:
activation = self._average_ten_crop(activation)
output.append(cuda.to_cpu(activation))
output = tuple(output)
else:
output = cuda.to_cpu(activations.data)
if self.do_ten_crop:
output = self._average_ten_crop(output)

return output


def _max_pooling_2d(x):
Expand Down

0 comments on commit 724dcb5

Please sign in to comment.