Skip to content

Commit

Permalink
Merge pull request #899 from dwf/conv_improvements
Browse files Browse the repository at this point in the history
Convolutional network pooling improvements
  • Loading branch information
dwf committed Nov 6, 2015
2 parents 24a489a + 769eb70 commit 1d63fa8
Show file tree
Hide file tree
Showing 2 changed files with 245 additions and 29 deletions.
165 changes: 141 additions & 24 deletions blocks/bricks/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,31 +162,44 @@ def get_dim(self, name):
self.step, self.border_mode))
return super(Convolutional, self).get_dim(name)

@property
def num_output_channels(self):
    """Output channel count, read from `num_filters`."""
    filters = self.num_filters
    return filters

class MaxPooling(Initializable, Feedforward):
"""Max pooling layer.

Parameters
----------
pooling_size : tuple
The height and width of the pooling region i.e. this is the factor
by which your input's last two dimensions will be downscaled.
step : tuple, optional
The vertical and horizontal shift (stride) between pooling regions.
By default this is equal to `pooling_size`. Setting this to a lower
number results in overlapping pooling regions.
input_dim : tuple, optional
A tuple of integers representing the shape of the input. The last
two dimensions will be used to calculate the output dimension.
class Pooling(Initializable, Feedforward):
"""Base Brick for pooling operations.
"""
@lazy(allocation=['pooling_size'])
def __init__(self, pooling_size, step=None, input_dim=None, **kwargs):
super(MaxPooling, self).__init__(**kwargs)
This should generally not be instantiated directly; see
:class:`MaxPooling`.
self.input_dim = input_dim
"""
@lazy(allocation=['mode', 'pooling_size'])
def __init__(self, mode, pooling_size, step, input_dim, ignore_border,
padding, **kwargs):
super(Pooling, self).__init__(**kwargs)
self.pooling_size = pooling_size
self.mode = mode
self.step = step
self.input_dim = input_dim if input_dim is not None else (None,) * 3
self.ignore_border = ignore_border
self.padding = padding

@property
def image_size(self):
    """The last two entries of `input_dim`."""
    dims = self.input_dim
    return dims[-2:]

@image_size.setter
def image_size(self, value):
    # Keep every entry of `input_dim` except the last two, which are
    # replaced by `value`.
    leading = self.input_dim[:-2]
    self.input_dim = leading + value

@property
def num_channels(self):
    """The first entry of `input_dim`."""
    dims = self.input_dim
    return dims[0]

@num_channels.setter
def num_channels(self, value):
    # Replace only the first entry of `input_dim`; the rest is kept.
    trailing = self.input_dim[1:]
    self.input_dim = (value,) + trailing

@application(inputs=['input_'], outputs=['output'])
def apply(self, input_):
Expand All @@ -207,16 +220,112 @@ def apply(self, input_):
with the last two dimensions downsampled.
"""
output = max_pool_2d(input_, self.pooling_size, st=self.step)
output = max_pool_2d(input_, self.pooling_size, st=self.step,
mode=self.mode, padding=self.padding,
ignore_border=self.ignore_border)
return output

def get_dim(self, name):
if name == 'input_':
return self.input_dim
if name == 'output':
return tuple(DownsampleFactorMax.out_shape(self.input_dim,
self.pooling_size,
st=self.step))
return tuple(DownsampleFactorMax.out_shape(
self.input_dim, self.pooling_size, st=self.step,
ignore_border=self.ignore_border, padding=self.padding))

@property
def num_output_channels(self):
    """Output channel count, read from the first entry of `input_dim`."""
    dims = self.input_dim
    return dims[0]


class MaxPooling(Pooling):
    """Pooling brick configured with the ``'max'`` mode.

    Parameters
    ----------
    pooling_size : tuple
        Height and width of the pooling region, i.e. the factor by
        which the last two dimensions of the input are downscaled.
    step : tuple, optional
        Vertical and horizontal shift (stride) between pooling regions.
        By default this equals `pooling_size`; a smaller value makes
        the pooling regions overlap.
    input_dim : tuple, optional
        Integers giving the shape of the input. The output dimension
        is computed from the last two entries.
    padding : tuple, optional
        Vertical and horizontal zero-padding, applied to both edges
        along each axis. For example, (4, 3) pads the top and bottom
        edges with 4 pixels each and the left and right edges with 3
        pixels each. No padding is performed by default.
    ignore_border : bool, optional
        Whether to drop pooling regions whose extent reaches beyond
        the edge of the image. With `True`, a (5, 5) image pooled with
        (2, 2) regions and a (2, 2) step is downsampled to shape
        (2, 2); otherwise it is downsampled to (3, 3). `True` by
        default.

    Notes
    -----
    .. warning::
        As of this writing, setting `ignore_border` to `False` with a
        step not equal to the pooling size will force Theano to
        perform pooling computations on CPU rather than GPU, even if
        you have specified a GPU as your computation device.
        Additionally, Theano will only use [cuDNN]_ (if available) for
        pooling computations with `ignore_border` set to `True`. You
        can ensure that the entire input is captured by at least one
        pool by using the `padding` argument to add zero padding prior
        to pooling being performed.

    .. [cuDNN] `NVIDIA cuDNN <https://developer.nvidia.com/cudnn>`_.

    """
    @lazy(allocation=['pooling_size'])
    def __init__(self, pooling_size, step=None, input_dim=None,
                 ignore_border=True, padding=(0, 0),
                 **kwargs):
        parent_init = super(MaxPooling, self).__init__
        parent_init('max', pooling_size,
                    step=step,
                    input_dim=input_dim,
                    ignore_border=ignore_border,
                    padding=padding,
                    **kwargs)

    def __setstate__(self, state):
        """Restore pickled state, filling in attributes added by #899."""
        self.__dict__.update(state)
        # Objects created before pull request #899 lack these
        # attributes; any that is missing after the update gets the
        # fallback listed here, anything already present is kept.
        legacy_defaults = (('mode', 'max'),
                           ('padding', (0, 0)),
                           ('ignore_border', False))
        for attribute, fallback in legacy_defaults:
            setattr(self, attribute, getattr(self, attribute, fallback))


class AveragePooling(Pooling):
    """Pooling brick configured with one of the averaging modes.

    Parameters
    ----------
    include_padding : bool, optional
        Whether zeros introduced by the `padding` argument count
        towards the average. A value of `True` is only accepted if
        `ignore_border` is also `True`. `False` by default.

    Notes
    -----
    For documentation on the remainder of the arguments to this
    class, see :class:`MaxPooling`.

    """
    @lazy(allocation=['pooling_size'])
    def __init__(self, pooling_size, step=None, input_dim=None,
                 ignore_border=True, padding=(0, 0),
                 include_padding=False, **kwargs):
        # `include_padding` picks which averaging mode is handed to
        # the base class.
        if include_padding:
            mode = 'average_inc_pad'
        else:
            mode = 'average_exc_pad'
        parent_init = super(AveragePooling, self).__init__
        parent_init(mode, pooling_size,
                    step=step,
                    input_dim=input_dim,
                    ignore_border=ignore_border,
                    padding=padding,
                    **kwargs)


class _AllocationMixin(object):
Expand All @@ -226,6 +335,14 @@ def _push_allocation_config(self):
'tied_biases', 'use_bias']:
setattr(self.convolution, attr, getattr(self, attr))

@property
def num_output_channels(self):
    """Output channel count, read from `num_filters`."""
    # This relies on the activation function being elementwise.
    # Supporting something like maxout would need a change here, and
    # also a way to ask the activation function for this information.
    filters = self.num_filters
    return filters


class ConvolutionalActivation(_AllocationMixin, Sequence, Initializable):
"""A convolution followed by an activation function.
Expand Down Expand Up @@ -426,7 +543,7 @@ def _push_allocation_config(self):
if layer.image_size is not None:
output_shape = layer.get_dim('output')
image_size = output_shape[1:]
num_channels = layer.num_filters
num_channels = layer.num_output_channels


class Flattener(Brick):
Expand Down
109 changes: 104 additions & 5 deletions tests/bricks/test_conv.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pickle
import numpy
from nose.tools import assert_raises_regexp

Expand All @@ -8,7 +9,8 @@

from blocks.bricks import Rectifier
from blocks.bricks.conv import (Convolutional, ConvolutionalLayer, MaxPooling,
ConvolutionalActivation, ConvolutionalSequence)
AveragePooling, ConvolutionalActivation,
ConvolutionalSequence)
from blocks.initialization import Constant
from blocks.graph import ComputationGraph

Expand Down Expand Up @@ -117,13 +119,110 @@ def test_max_pooling():
dtype=theano.config.floatX)
assert_allclose(func(x_val),
numpy.ones((batch_size, num_channels,
x_size / pool_size + 1,
y_size / pool_size + 1)))
x_size / pool_size,
y_size / pool_size)))
pool.input_dim = (x_size, y_size)
pool.get_dim('output') == (num_channels, x_size / pool_size + 1,
y_size / pool_size + 1)


def test_max_pooling_ignore_border_true():
    """(3, 4) pooling of a 10x13 image gives 3x3 with `ignore_border`."""
    images = tensor.tensor4('x')
    pooled = MaxPooling((3, 4), ignore_border=True).apply(images)
    batch = numpy.zeros((8, 3, 10, 13), dtype=theano.config.floatX)
    result = pooled.eval({images: batch})
    assert result.shape == (8, 3, 3, 3)


def test_max_pooling_ignore_border_false():
    """(5, 7) pooling of a 12x15 image gives 3x3 without `ignore_border`."""
    images = tensor.tensor4('x')
    pooled = MaxPooling((5, 7), ignore_border=False).apply(images)
    batch = numpy.zeros((4, 6, 12, 15), dtype=theano.config.floatX)
    result = pooled.eval({images: batch})
    assert result.shape == (4, 6, 3, 3)


def test_max_pooling_padding():
    """(6, 2) pooling with (3, 1) padding maps a 6x10 image to 2x6."""
    images = tensor.tensor4('x')
    pooled = MaxPooling((6, 2), padding=(3, 1),
                        ignore_border=True).apply(images)
    batch = numpy.zeros((2, 3, 6, 10), dtype=theano.config.floatX)
    result = pooled.eval({images: batch})
    assert result.shape == (2, 3, 2, 6)


def test_max_pooling_old_pickle():
    """A `MaxPooling` pickled without the #899 attributes still loads."""
    brick = MaxPooling((3, 4))
    brick.allocate()
    # Remove the attributes added by #899 to imitate an old pickle.
    for stale in ('ignore_border', 'mode', 'padding'):
        delattr(brick, stale)
    # Round-trip the brick through pickle in this broken state.
    loaded = pickle.loads(pickle.dumps(brick))
    # Pooling size and step must come back unchanged.
    assert brick.pooling_size == loaded.pooling_size
    assert brick.step == loaded.step
    # The attributes that were removed must be present again.
    assert hasattr(loaded, 'padding')
    assert loaded.padding == (0, 0)
    assert hasattr(loaded, 'mode')
    assert loaded.mode == 'max'
    assert hasattr(loaded, 'ignore_border')
    assert not loaded.ignore_border
    try:
        loaded.apply(tensor.tensor4())
    except Exception:
        raise AssertionError("failed to apply on unpickled MaxPooling")
    # Attributes that were set explicitly must not be overwritten.
    padded = MaxPooling((4, 3), padding=(2, 1))
    padded_roundtrip = pickle.loads(pickle.dumps(padded))
    assert padded_roundtrip.padding == (2, 1)
    assert padded_roundtrip.ignore_border


def test_average_pooling():
    """(2, 2) average pooling of a 4x4 ramp gives the block means."""
    images = tensor.tensor4('x')
    pooled = AveragePooling((2, 2)).apply(images)
    ramp = numpy.arange(16, dtype=theano.config.floatX).reshape(1, 1, 4, 4)
    batch = numpy.tile(ramp, [2, 3, 1, 1])
    result = pooled.eval({images: batch})
    # Each entry is the sum of one 2x2 block of the ramp divided by 4.
    block_means = numpy.array([[10 / 4., 18 / 4.],
                               [42 / 4., 50 / 4.]])
    difference = result - block_means
    assert_allclose(difference, numpy.zeros_like(result))


def test_average_pooling_inc_padding():
    """With `include_padding`, a padded 2x2 image of 3s averages to 0.75."""
    images = tensor.tensor4('x')
    pooling = AveragePooling((2, 2), ignore_border=True, padding=(1, 1),
                             include_padding=True)
    pooled = pooling.apply(images)
    threes = 3 * numpy.ones((1, 1, 2, 2), dtype=theano.config.floatX)
    result = pooled.eval({images: threes})
    expected = numpy.array([0.75, 0.75, 0.75, 0.75]).reshape(1, 1, 2, 2)
    assert_allclose(expected, result)


def test_average_pooling_exc_padding():
    """Without `include_padding`, the padded 2x2 image of 3s is unchanged."""
    images = tensor.tensor4('x')
    pooling = AveragePooling((2, 2), ignore_border=True, padding=(1, 1),
                             include_padding=False)
    pooled = pooling.apply(images)
    threes = 3 * numpy.ones((1, 1, 2, 2), dtype=theano.config.floatX)
    result = pooled.eval({images: threes})
    assert_allclose(threes, result)


def test_pooling_works_in_convolutional_sequence():
    """Average then max pooling inside a `ConvolutionalSequence`."""
    images = tensor.tensor4('x')
    layers = [AveragePooling((2, 2), step=(2, 2)),
              MaxPooling((4, 4), step=(2, 2), ignore_border=True)]
    sequence = ConvolutionalSequence(layers, image_size=(16, 32),
                                     num_channels=3)
    sequence.allocate()
    pooled = sequence.apply(images)
    # Only the output shape is checked, so the input values are arbitrary.
    batch = numpy.empty((2, 3, 16, 32), dtype=theano.config.floatX)
    result = pooled.eval({images: batch})
    assert result.shape == (2, 3, 3, 7)


def test_convolutional_layer():
x = tensor.tensor4('x')
num_channels = 4
Expand All @@ -147,7 +246,7 @@ def test_convolutional_layer():
x_val = numpy.ones((batch_size, num_channels, 17, 13),
dtype=theano.config.floatX)
assert_allclose(func(x_val), numpy.prod(filter_size) * num_channels *
numpy.ones((batch_size, num_filters, 5, 4)) + 5)
numpy.ones((batch_size, num_filters, 5, 3)) + 5)

assert_equal(conv.convolution.batch_size, batch_size)
assert_equal(conv.pooling.batch_size, batch_size)
Expand Down Expand Up @@ -178,7 +277,7 @@ def test_convolutional_sequence():
func = function([x], y)

x_val = numpy.ones((batch_size, 4, 17, 13), dtype=theano.config.floatX)
y_val = (numpy.ones((batch_size, 4, 4, 3)) *
y_val = (numpy.ones((batch_size, 4, 4, 2)) *
(9 * 4 + 5) * 4 * 5)
assert_allclose(func(x_val), y_val)

Expand Down

0 comments on commit 1d63fa8

Please sign in to comment.