
Commit

Merge pull request #5213 from crcrpar/feature/vgg19
Add `L.VGG19Layers`
kmaehashi committed Sep 5, 2018
2 parents d1385a9 + c84b903 commit 96f68cd
Showing 4 changed files with 211 additions and 51 deletions.
1 change: 1 addition & 0 deletions chainer/links/__init__.py
@@ -51,6 +51,7 @@
from chainer.links.model.vision.resnet import ResNet152Layers # NOQA
from chainer.links.model.vision.resnet import ResNet50Layers # NOQA
from chainer.links.model.vision.vgg import VGG16Layers # NOQA
from chainer.links.model.vision.vgg import VGG19Layers # NOQA
from chainer.links.normalization.batch_normalization import BatchNormalization # NOQA
from chainer.links.normalization.batch_renormalization import BatchRenormalization # NOQA
from chainer.links.normalization.group_normalization import GroupNormalization # NOQA
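With this one-line change, the new link is importable directly from ``chainer.links``, next to the existing ``VGG16Layers``. A minimal sketch of the resulting API (illustrative only, not part of the commit; ``pretrained_model=None`` is used here just to skip the weight download):

import chainer.links as L

# Both links are now exposed in the chainer.links namespace and share
# the new VGGLayers base class introduced in vgg.py below.
# pretrained_model=None initializes weights randomly instead of
# downloading and converting the caffemodel.
vgg16 = L.VGG16Layers(pretrained_model=None)
vgg19 = L.VGG19Layers(pretrained_model=None)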
241 changes: 196 additions & 45 deletions chainer/links/model/vision/vgg.py
@@ -31,17 +31,19 @@
from chainer.variable import Variable


class VGG16Layers(link.Chain):
class VGGLayers(link.Chain):

"""A pre-trained CNN model with 16 layers provided by VGG team.
"""A pre-trained CNN model provided by VGG team.
During initialization, this chain model automatically downloads
the pre-trained caffemodel, converts it to a chainer model,
stores it in your local directory, and initializes all the parameters
with it. This model would be useful when you want to extract a semantic
You can use ``VGG16Layers`` or ``VGG19Layers`` for concrete
implementations. During initialization, this chain model
automatically downloads the pre-trained caffemodel, converts it to
a chainer model, stores it in your local directory,
and initializes all the parameters with it.
This model would be useful when you want to extract a semantic
feature vector from a given image, or fine-tune the model
on a different dataset.
Note that this pre-trained model is released under Creative Commons
Note that these pre-trained models are released under Creative Commons
Attribution License.
If you want to manually convert the pre-trained caffemodel to a chainer
@@ -66,14 +68,17 @@ class VGG16Layers(link.Chain):
are not initialized by the pre-trained model, but the default
initializer used in the original paper, i.e.,
``chainer.initializers.Normal(scale=0.01)``.
n_layers (int): The number of layers of this model. It should be
either 16 or 19.
Attributes:
available_layers (list of str): The list of available layer names
used by ``forward`` and ``extract`` methods.
"""

def __init__(self, pretrained_model='auto'):
def __init__(self, pretrained_model='auto', n_layers=16):
super(VGGLayers, self).__init__()
if pretrained_model:
# As a sampling process is time-consuming,
# we employ a zero initializer for faster computation.
@@ -85,7 +90,12 @@ def __init__(self, pretrained_model='auto'):
'initialW': normal.Normal(0.01),
'initial_bias': constant.Zero(),
}
super(VGG16Layers, self).__init__()

if n_layers not in [16, 19]:
raise ValueError(
'The n_layers argument should be either 16 or 19, '
'but {} was given.'.format(n_layers)
)

with self.init_scope():
self.conv1_1 = Convolution2D(3, 64, 3, 1, 1, **kwargs)
@@ -104,42 +114,31 @@ def __init__(self, pretrained_model='auto'):
self.fc6 = Linear(512 * 7 * 7, 4096, **kwargs)
self.fc7 = Linear(4096, 4096, **kwargs)
self.fc8 = Linear(4096, 1000, **kwargs)
if n_layers == 19:
self.conv3_4 = Convolution2D(256, 256, 3, 1, 1, **kwargs)
self.conv4_4 = Convolution2D(512, 512, 3, 1, 1, **kwargs)
self.conv5_4 = Convolution2D(512, 512, 3, 1, 1, **kwargs)

if pretrained_model == 'auto':
_retrieve(
'VGG_ILSVRC_16_layers.npz',
'https://www.robots.ox.ac.uk/%7Evgg/software/very_deep/'
'caffe/VGG_ILSVRC_16_layers.caffemodel',
self)
if n_layers == 16:
_retrieve(
'VGG_ILSVRC_16_layers.npz',
'https://www.robots.ox.ac.uk/%7Evgg/software/very_deep/'
'caffe/VGG_ILSVRC_16_layers.caffemodel',
self)
else:
_retrieve(
'VGG_ILSVRC_19_layers.npz',
'http://www.robots.ox.ac.uk/%7Evgg/software/very_deep/'
'caffe/VGG_ILSVRC_19_layers.caffemodel',
self)
elif pretrained_model:
npz.load_npz(pretrained_model, self)

@property
def functions(self):
return collections.OrderedDict([
('conv1_1', [self.conv1_1, relu]),
('conv1_2', [self.conv1_2, relu]),
('pool1', [_max_pooling_2d]),
('conv2_1', [self.conv2_1, relu]),
('conv2_2', [self.conv2_2, relu]),
('pool2', [_max_pooling_2d]),
('conv3_1', [self.conv3_1, relu]),
('conv3_2', [self.conv3_2, relu]),
('conv3_3', [self.conv3_3, relu]),
('pool3', [_max_pooling_2d]),
('conv4_1', [self.conv4_1, relu]),
('conv4_2', [self.conv4_2, relu]),
('conv4_3', [self.conv4_3, relu]),
('pool4', [_max_pooling_2d]),
('conv5_1', [self.conv5_1, relu]),
('conv5_2', [self.conv5_2, relu]),
('conv5_3', [self.conv5_3, relu]),
('pool5', [_max_pooling_2d]),
('fc6', [self.fc6, relu, dropout]),
('fc7', [self.fc7, relu, dropout]),
('fc8', [self.fc8]),
('prob', [softmax]),
])
# This class will not be used directly.
raise NotImplementedError

@property
def available_layers(self):
@@ -168,18 +167,20 @@ def forward(self, x, layers=None, **kwargs):
.. warning::
``test`` argument is not supported anymore since v2.
Instead, use ``chainer.using_config('train', train)``.
Instead, use ``chainer.using_config('train', False)``
to run in test mode.
See :func:`chainer.using_config`.
Args:
x (~chainer.Variable): Input variable. It should be prepared by
``prepare`` function.
layers (list of str): The list of layer names you want to extract.
If ``None``, 'prob' will be used as layers.
Returns:
Dictionary of ~chainer.Variable: A directory in which
the key contains the layer name and the value contains
the corresponding feature map variable.
Dictionary of ~chainer.Variable: A dictionary in which
the key contains the layer and the value contains the
corresponding feature map variable.
"""

@@ -189,7 +190,8 @@ def forward(self, x, layers=None, **kwargs):
if kwargs:
argument.check_unexpected_kwargs(
kwargs, test='test argument is not supported anymore. '
'Use chainer.using_config')
'Use chainer.using_config'
)
argument.assert_kwargs_empty(kwargs)

h = x
@@ -224,7 +226,7 @@ def extract(self, images, layers=None, size=(224, 224), **kwargs):
.. code-block:: python
# model is an instance of VGG16Layers
# model is an instance of VGGLayers (16 or 19 layers)
with chainer.using_config('train', False):
with chainer.using_config('enable_backprop', False):
feature = model.extract([image])
@@ -312,6 +314,155 @@ def predict(self, images, oversample=True):
return y


class VGG16Layers(VGGLayers):

"""A pre-trained CNN model with 16 layers provided by VGG team.
During initialization, this chain model automatically downloads
the pre-trained caffemodel, converts it to a chainer model,
stores it in your local directory, and initializes all the parameters
with it. This model would be useful when you want to extract a semantic
feature vector from a given image, or fine-tune the model
on a different dataset.
Note that this pre-trained model is released under Creative Commons
Attribution License.
If you want to manually convert the pre-trained caffemodel to a chainer
model that can be specified in the constructor,
please use ``convert_caffemodel_to_npz`` classmethod instead.
See: K. Simonyan and A. Zisserman, `Very Deep Convolutional Networks
for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`_
Args:
pretrained_model (str): the destination of the pre-trained
chainer model serialized as a ``.npz`` file.
If this argument is specified as ``auto``,
it automatically downloads the caffemodel from the internet.
Note that in this case the converted chainer model is stored
on ``$CHAINER_DATASET_ROOT/pfnet/chainer/models`` directory,
where ``$CHAINER_DATASET_ROOT`` is set as
``$HOME/.chainer/dataset`` unless you specify another value
as an environment variable. The converted chainer model is
automatically used from the second time.
If the argument is specified as ``None``, all the parameters
are not initialized by the pre-trained model, but the default
initializer used in the original paper, i.e.,
``chainer.initializers.Normal(scale=0.01)``.
Attributes:
available_layers (list of str): The list of available layer names
used by ``forward`` and ``extract`` methods.
"""

def __init__(self, pretrained_model='auto'):
super(VGG16Layers, self).__init__(pretrained_model, 16)

@property
def functions(self):
return collections.OrderedDict([
('conv1_1', [self.conv1_1, relu]),
('conv1_2', [self.conv1_2, relu]),
('pool1', [_max_pooling_2d]),
('conv2_1', [self.conv2_1, relu]),
('conv2_2', [self.conv2_2, relu]),
('pool2', [_max_pooling_2d]),
('conv3_1', [self.conv3_1, relu]),
('conv3_2', [self.conv3_2, relu]),
('conv3_3', [self.conv3_3, relu]),
('pool3', [_max_pooling_2d]),
('conv4_1', [self.conv4_1, relu]),
('conv4_2', [self.conv4_2, relu]),
('conv4_3', [self.conv4_3, relu]),
('pool4', [_max_pooling_2d]),
('conv5_1', [self.conv5_1, relu]),
('conv5_2', [self.conv5_2, relu]),
('conv5_3', [self.conv5_3, relu]),
('pool5', [_max_pooling_2d]),
('fc6', [self.fc6, relu, dropout]),
('fc7', [self.fc7, relu, dropout]),
('fc8', [self.fc8]),
('prob', [softmax]),
])


class VGG19Layers(VGGLayers):

"""A pre-trained CNN model with 19 layers provided by VGG team.
During initialization, this chain model automatically downloads
the pre-trained caffemodel, converts it to a chainer model,
stores it in your local directory, and initializes all the parameters
with it. This model would be useful when you want to extract a semantic
feature vector from a given image, or fine-tune the model
on a different dataset.
Note that this pre-trained model is released under Creative Commons
Attribution License.
If you want to manually convert the pre-trained caffemodel to a chainer
model that can be specified in the constructor,
please use ``convert_caffemodel_to_npz`` classmethod instead.
See: K. Simonyan and A. Zisserman, `Very Deep Convolutional Networks
for Large-Scale Image Recognition <https://arxiv.org/abs/1409.1556>`_
Args:
pretrained_model (str): the destination of the pre-trained
chainer model serialized as a ``.npz`` file.
If this argument is specified as ``auto``,
it automatically downloads the caffemodel from the internet.
Note that in this case the converted chainer model is stored
on ``$CHAINER_DATASET_ROOT/pfnet/chainer/models`` directory,
where ``$CHAINER_DATASET_ROOT`` is set as
``$HOME/.chainer/dataset`` unless you specify another value
as an environment variable. The converted chainer model is
automatically used from the second time.
If the argument is specified as ``None``, all the parameters
are not initialized by the pre-trained model, but the default
initializer used in the original paper, i.e.,
``chainer.initializers.Normal(scale=0.01)``.
Attributes:
available_layers (list of str): The list of available layer names
used by ``forward`` and ``extract`` methods.
"""

def __init__(self, pretrained_model='auto'):
super(VGG19Layers, self).__init__(pretrained_model, 19)

@property
def functions(self):
return collections.OrderedDict([
('conv1_1', [self.conv1_1, relu]),
('conv1_2', [self.conv1_2, relu]),
('pool1', [_max_pooling_2d]),
('conv2_1', [self.conv2_1, relu]),
('conv2_2', [self.conv2_2, relu]),
('pool2', [_max_pooling_2d]),
('conv3_1', [self.conv3_1, relu]),
('conv3_2', [self.conv3_2, relu]),
('conv3_3', [self.conv3_3, relu]),
('conv3_4', [self.conv3_4, relu]),
('pool3', [_max_pooling_2d]),
('conv4_1', [self.conv4_1, relu]),
('conv4_2', [self.conv4_2, relu]),
('conv4_3', [self.conv4_3, relu]),
('conv4_4', [self.conv4_4, relu]),
('pool4', [_max_pooling_2d]),
('conv5_1', [self.conv5_1, relu]),
('conv5_2', [self.conv5_2, relu]),
('conv5_3', [self.conv5_3, relu]),
('conv5_4', [self.conv5_4, relu]),
('pool5', [_max_pooling_2d]),
('fc6', [self.fc6, relu, dropout]),
('fc7', [self.fc7, relu, dropout]),
('fc8', [self.fc8]),
('prob', [softmax]),
])
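``available_layers`` -- which, per its docstring, lists the layer names usable by ``forward`` and ``extract`` -- therefore differs between the two subclasses: the updated tests below assert 22 entries for ``VGG16Layers`` and 25 for ``VGG19Layers``. A tiny illustrative check (assuming ``pretrained_model=None`` to avoid downloading weights):

from chainer.links import VGG16Layers, VGG19Layers

v16 = VGG16Layers(pretrained_model=None)
v19 = VGG19Layers(pretrained_model=None)

print(len(v16.available_layers))  # 22
print(len(v19.available_layers))  # 25
# The extra layers are the fourth convolutions of blocks 3-5:
print(set(v19.available_layers) - set(v16.available_layers))
# {'conv3_4', 'conv4_4', 'conv5_4'} (set order may vary)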


def prepare(image, size=(224, 224)):
"""Converts the given image to the numpy array for VGG models.
@@ -366,7 +517,7 @@ def _make_npz(path_npz, url, model):
sys.stderr.write(
'Now loading caffemodel (usually it may take a few minutes)\n')
sys.stderr.flush()
VGG16Layers.convert_caffemodel_to_npz(path_caffemodel, path_npz)
VGGLayers.convert_caffemodel_to_npz(path_caffemodel, path_npz)
npz.load_npz(path_npz, model)
return model

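To summarize the new API, here is a short usage sketch of ``VGG19Layers`` for feature extraction, following the ``extract`` pattern shown in the docstring above. The dummy image and ``pretrained_model=None`` are illustrative assumptions (use ``'auto'`` to download the converted caffemodel); the layer names come from the ``functions`` table of ``VGG19Layers``:

import numpy as np
import chainer
from chainer.links import VGG19Layers

# pretrained_model=None gives randomly initialized weights; 'auto'
# downloads VGG_ILSVRC_19_layers.caffemodel and converts it on first use.
model = VGG19Layers(pretrained_model=None)

# A dummy HWC uint8 image; extract() resizes it to 224x224 via prepare().
image = np.random.uniform(0, 255, (320, 240, 3)).astype(np.uint8)

with chainer.using_config('train', False), \
        chainer.using_config('enable_backprop', False):
    features = model.extract([image], layers=['conv5_4', 'fc7'])

print(features['conv5_4'].shape)  # (1, 512, 14, 14)
print(features['fc7'].shape)      # (1, 4096)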
5 changes: 3 additions & 2 deletions docs/source/reference/links.rst
@@ -111,14 +111,15 @@ where ``fc7`` denotes a layer before the last fully-connected layer.
Unlike the usual links, these classes automatically load all the
parameters from the pre-trained models during initialization.

VGG16Layers
~~~~~~~~~~~
VGG Networks
~~~~~~~~~~~~

.. autosummary::
:toctree: generated/
:nosignatures:

chainer.links.VGG16Layers
chainer.links.VGG19Layers
chainer.links.model.vision.vgg.prepare

GoogLeNet
15 changes: 11 additions & 4 deletions tests/chainer_tests/links_tests/model_tests/test_vision.py
@@ -166,24 +166,31 @@ def test_copy_gpu(self):


@testing.parameterize(*testing.product({
'n_layers': [16, 19],
'dtype': [numpy.float16, numpy.float32],
}))
@unittest.skipUnless(resnet.available, 'Pillow is required')
@attr.slow
class TestVGG16Layers(unittest.TestCase):
class TestVGGs(unittest.TestCase):

def setUp(self):
self._config_user = chainer.using_config('dtype', self.dtype)
self._config_user.__enter__()
self.link = vgg.VGG16Layers(pretrained_model=None)
if self.n_layers == 16:
self.link = vgg.VGG16Layers(pretrained_model=None)
elif self.n_layers == 19:
self.link = vgg.VGG19Layers(pretrained_model=None)

def tearDown(self):
self._config_user.__exit__(None, None, None)

def test_available_layers(self):
result = self.link.available_layers
assert isinstance(result, list)
assert len(result) == 22
if self.n_layers == 16:
assert len(result) == 22
elif self.n_layers == 19:
assert len(result) == 25

def check_call(self):
xp = self.link.xp
@@ -222,7 +229,7 @@ def test_prepare(self):
assert y4.dtype == self.dtype
y5 = vgg.prepare(x5, size=None)
assert y5.shape == (3, 160, 120)
assert y5.dtype == numpy.float32
assert y5.dtype == self.dtype

def check_extract(self):
x1 = numpy.random.uniform(0, 255, (320, 240, 3)).astype(numpy.uint8)
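The final hunk changes the expected dtype of ``prepare(x, size=None)`` from a hard-coded ``numpy.float32`` to the dtype configured via ``chainer.using_config('dtype', ...)`` in ``setUp``. A small sketch of what that assertion implies (the input shape is inferred from the asserted output shape; the behaviour is taken from the test, not verified independently):

import numpy
import chainer
from chainer.links.model.vision import vgg

# HWC uint8 image; size=None keeps the original resolution.
x = numpy.random.uniform(0, 255, (160, 120, 3)).astype(numpy.uint8)

with chainer.using_config('dtype', numpy.float16):
    y = vgg.prepare(x, size=None)

print(y.shape)  # (3, 160, 120) -- transposed to CHW, resolution unchanged
print(y.dtype)  # float16, following the configured dtype per this test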
