Skip to content

Commit

Permalink
Merge pull request #946 from vdumoulin/abstract_conv
Browse files Browse the repository at this point in the history
ConvolutionTranspose
  • Loading branch information
dwf committed Jan 23, 2016
2 parents 98797a1 + 648646d commit f5ee622
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 17 deletions.
120 changes: 106 additions & 14 deletions blocks/bricks/conv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from theano.tensor.nnet.conv import conv2d, get_conv_output_shape
from theano.tensor.signal.downsample import max_pool_2d, DownsampleFactorMax
from theano import tensor
from theano.tensor.nnet import conv2d
from theano.tensor.nnet.abstract_conv import (AbstractConv2d_gradInputs,
get_conv_output_shape)
from theano.tensor.signal.pool import pool_2d, Pool

from blocks.bricks import Initializable, Feedforward, Sequence
from blocks.bricks.base import application, Brick, lazy
Expand Down Expand Up @@ -49,7 +52,7 @@ class Convolutional(Initializable):
# to leverage features not yet available in Theano's standard conv2d.
# The function you override with here should accept at least the
# input and the kernels as positionals, and the keyword arguments
# image_shape, subsample, border_mode, and filter_shape. If some of
# input_shape, subsample, border_mode, and filter_shape. If some of
# these are unsupported they should still be accepted and ignored,
# e.g. with a wrapper function that swallows **kwargs.
conv2d_impl = staticmethod(conv2d)
Expand Down Expand Up @@ -140,14 +143,14 @@ def apply(self, input_):
W, = self.parameters

if self.image_size == (None, None):
image_shape = None
input_shape = None
else:
image_shape = (self.batch_size, self.num_channels)
image_shape += self.image_size
input_shape = (self.batch_size, self.num_channels)
input_shape += self.image_size

output = self.conv2d_impl(
input_, W,
image_shape=image_shape,
input_shape=input_shape,
subsample=self.step,
border_mode=self.border_mode,
filter_shape=((self.num_filters, self.num_channels) +
Expand All @@ -163,10 +166,10 @@ def get_dim(self, name):
if name == 'input_':
return (self.num_channels,) + self.image_size
if name == 'output':
image_shape = (None, self.num_channels) + self.image_size
input_shape = (None, self.num_channels) + self.image_size
kernel_shape = ((self.num_filters, self.num_channels) +
self.filter_size)
out_shape = self.get_output_shape(image_shape, kernel_shape,
out_shape = self.get_output_shape(input_shape, kernel_shape,
self.border_mode, self.step)
assert len(out_shape) == 4
return out_shape[1:]
Expand All @@ -177,6 +180,61 @@ def num_output_channels(self):
return self.num_filters


class ConvolutionalTranspose(Convolutional):
    """Brick computing the transpose of a 2D convolution.

    The parameters are described relative to the corresponding regular
    (non-transposed) convolution: the input of this brick plays the role
    of that convolution's output, and vice versa.

    Parameters
    ----------
    original_image_size : tuple
        Height and width of the image produced by the transposed
        operation, i.e. of the input to the original (non-transposed)
        convolution.
    num_filters : int
        Number of filters at the *output* of the transposed convolution,
        i.e. the number of channels in the corresponding convolution.
    num_channels : int
        Number of channels at the *input* of the transposed convolution,
        i.e. the number of output filters in the corresponding
        convolution.
    step : tuple, optional
        The step (or stride) of the corresponding *convolution*.
        Defaults to (1, 1).
    image_size : tuple, optional
        Image size of the input to the *transposed* convolution, i.e.
        the output of the corresponding convolution. Required for tied
        biases. Defaults to ``None``.

    See Also
    --------
    :class:`Convolutional` : For the documentation of other parameters.

    """
    @lazy(allocation=['original_image_size', 'filter_size', 'num_filters',
                      'num_channels'])
    def __init__(self, original_image_size, filter_size, num_filters,
                 num_channels, **kwargs):
        # All arguments except the target image size are forwarded,
        # unchanged, to the Convolutional constructor.
        super(ConvolutionalTranspose, self).__init__(
            filter_size, num_filters, num_channels, **kwargs)
        self.original_image_size = original_image_size

    def conv2d_impl(self, input_, W, input_shape, subsample, border_mode,
                    filter_shape):
        """Apply the transposed convolution via the gradInputs op.

        Parameters
        ----------
        input_ : variable
            The input of the transposed convolution.
        W : variable
            The kernels, laid out as this brick stores them.
        input_shape : tuple or None
            Accepted for interface compatibility; not used here.
        subsample : tuple
            Forwarded to the op as ``subsample``.
        border_mode : object
            Forwarded to the op as ``border_mode``.
        filter_shape : tuple
            Shape of `W`; its first two entries are swapped before being
            handed to the op as ``kshp``.

        Returns
        -------
        output : variable
            The result of calling `AbstractConv2d_gradInputs`.

        """
        # Shape of this brick's output without the batch axis.
        output_dim = self.get_dim('output')
        # The AbstractConv2d_gradInputs op takes a kernel that was used
        # for the **convolution**, so num_channels and num_filters have to
        # be exchanged both in the declared kernel shape and in W itself.
        kernel_shape = ((filter_shape[1], filter_shape[0]) +
                        filter_shape[2:])
        grad_inputs_op = AbstractConv2d_gradInputs(
            imshp=(None,) + output_dim,
            kshp=kernel_shape,
            border_mode=border_mode,
            subsample=subsample)
        # The last argument is the spatial part of the output shape.
        return grad_inputs_op(W.transpose(1, 0, 2, 3), input_,
                              output_dim[1:])

    def get_dim(self, name):
        """Return the dimensions of `name`.

        ``'output'`` is ``(num_filters,) + original_image_size``; any
        other name is resolved by :class:`Convolutional`.

        """
        if name != 'output':
            return super(ConvolutionalTranspose, self).get_dim(name)
        return (self.num_filters,) + self.original_image_size


class Pooling(Initializable, Feedforward):
"""Base Brick for pooling operations.
Expand Down Expand Up @@ -230,16 +288,16 @@ def apply(self, input_):
with the last two dimensions downsampled.
"""
output = max_pool_2d(input_, self.pooling_size, st=self.step,
mode=self.mode, padding=self.padding,
ignore_border=self.ignore_border)
output = pool_2d(input_, self.pooling_size, st=self.step,
mode=self.mode, padding=self.padding,
ignore_border=self.ignore_border)
return output

def get_dim(self, name):
if name == 'input_':
return self.input_dim
if name == 'output':
return tuple(DownsampleFactorMax.out_shape(
return tuple(Pool.out_shape(
self.input_dim, self.pooling_size, st=self.step,
ignore_border=self.ignore_border, padding=self.padding))

Expand Down Expand Up @@ -372,7 +430,7 @@ class ConvolutionalActivation(_AllocationMixin, Sequence, Initializable):
def __init__(self, activation, filter_size, num_filters, num_channels,
batch_size=None, image_size=None, step=(1, 1),
border_mode='valid', tied_biases=False, **kwargs):
self.convolution = Convolutional()
self._build_convolution()

self.filter_size = filter_size
self.num_filters = num_filters
Expand All @@ -387,6 +445,9 @@ def __init__(self, activation, filter_size, num_filters, num_channels,
application_methods=[self.convolution.apply, activation],
**kwargs)

def _build_convolution(self):
    """Create the wrapped convolution brick as ``self.convolution``.

    Kept as a separate hook so that subclasses can substitute a different
    convolution brick.

    """
    self.convolution = Convolutional()

def get_dim(self, name):
    """Return the dimensions of `name` as reported by ``self.convolution``."""
    # TODO The name of the activation output doesn't need to be `output`
    return self.convolution.get_dim(name)
Expand All @@ -396,6 +457,37 @@ def _push_allocation_config(self):
self.convolution.step = self.step


class ConvolutionalTransposeActivation(ConvolutionalActivation):
    """Transposed convolution chained with an activation function.

    Parameters
    ----------
    activation : :class:`.BoundApplication`
        The application method to apply after the transposed convolution
        (i.e. the nonlinear activation function).

    See Also
    --------
    :class:`ConvolutionalTranspose` : For the documentation of other
        parameters.

    """
    @lazy(allocation=['original_image_size', 'filter_size', 'num_filters',
                      'num_channels'])
    def __init__(self, activation, original_image_size, filter_size,
                 num_filters, num_channels, **kwargs):
        # The parent constructor receives everything but the target image
        # size, which is stored here and pushed to the child brick later.
        super(ConvolutionalTransposeActivation, self).__init__(
            activation, filter_size, num_filters, num_channels, **kwargs)
        self.original_image_size = original_image_size

    def _build_convolution(self):
        """Use a :class:`ConvolutionalTranspose` as the wrapped brick."""
        self.convolution = ConvolutionalTranspose()

    def _push_allocation_config(self):
        """Push the parent's configuration, then the original image size."""
        parent = super(ConvolutionalTransposeActivation, self)
        parent._push_allocation_config()
        self.convolution.original_image_size = self.original_image_size


class ConvolutionalSequence(Sequence, Initializable, Feedforward):
"""A sequence of convolutional (or pooling) operations.
Expand Down
63 changes: 60 additions & 3 deletions tests/bricks/test_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,12 @@
from theano import tensor
from theano import function

from blocks.bricks import Rectifier
from blocks.bricks.conv import (Convolutional, MaxPooling, AveragePooling,
ConvolutionalActivation, ConvolutionalSequence)
from blocks.bricks import Rectifier, Tanh
from blocks.bricks.conv import (Convolutional, ConvolutionalTranspose,
MaxPooling, AveragePooling,
ConvolutionalActivation,
ConvolutionalTransposeActivation,
ConvolutionalSequence)
from blocks.initialization import Constant
from blocks.graph import ComputationGraph

Expand Down Expand Up @@ -37,6 +40,33 @@ def test_convolutional():
assert conv.get_dim('output') == (num_filters, 15, 11)


def test_convolutional_transpose():
    """Check ConvolutionalTranspose output on an all-ones input."""
    num_channels, num_filters, batch_size = 4, 3, 5
    image_size = (8, 6)
    original_image_size = (17, 13)

    brick = ConvolutionalTranspose(
        original_image_size, (3, 3), num_filters, num_channels,
        step=(2, 2), image_size=image_size, weights_init=Constant(1.),
        biases_init=Constant(5.))
    brick.initialize()

    x = tensor.tensor4('x')
    func = function([x], brick.apply(x))

    x_val = numpy.ones((batch_size, num_channels) + image_size,
                       dtype=theano.config.floatX)

    # Every output position starts at num_channels. Rows 2, 4, ... and
    # columns 2, 4, ... (stopping two short of the edge) each get one more
    # num_channels, and positions lying on both get yet another one.
    inner = slice(2, -2, 2)
    everything = slice(None)
    expected_value = num_channels * numpy.ones(
        (batch_size, num_filters) + original_image_size)
    for rows, cols in [(inner, everything), (everything, inner),
                       (inner, inner)]:
        expected_value[:, :, rows, cols] += num_channels

    # The constant bias of 5 is added on top.
    assert_allclose(func(x_val), expected_value + 5)


def test_border_mode_not_pushed():
layers = [Convolutional(border_mode='full'),
ConvolutionalActivation(Rectifier().apply),
Expand Down Expand Up @@ -260,6 +290,33 @@ def test_convolutional_activation_use_bias():
assert len(ComputationGraph([act.apply(tensor.tensor4())]).parameters) == 1


def test_convolutional_transpose_activation():
    """Check ConvolutionalTransposeActivation (Tanh) on an all-ones input."""
    num_channels, num_filters, batch_size = 4, 3, 5
    image_size = (8, 6)
    original_image_size = (17, 13)

    brick = ConvolutionalTransposeActivation(
        Tanh().apply, original_image_size, (3, 3), num_filters,
        num_channels, step=(2, 2), image_size=image_size,
        weights_init=Constant(1.), biases_init=Constant(5.))
    brick.initialize()

    x = tensor.tensor4('x')
    func = function([x], brick.apply(x))

    x_val = numpy.ones((batch_size, num_channels) + image_size,
                       dtype=theano.config.floatX)

    # Pre-activation values: num_channels everywhere, plus one more
    # num_channels on rows 2, 4, ... and on columns 2, 4, ... (stopping
    # two short of the edge), plus another where both coincide.
    inner = slice(2, -2, 2)
    everything = slice(None)
    expected_value = num_channels * numpy.ones(
        (batch_size, num_filters) + original_image_size)
    for rows, cols in [(inner, everything), (everything, inner),
                       (inner, inner)]:
        expected_value[:, :, rows, cols] += num_channels

    # The bias of 5 is added before the tanh nonlinearity.
    assert_allclose(func(x_val), numpy.tanh(expected_value + 5))


def test_convolutional_sequence_use_bias():
cnn = ConvolutionalSequence(
[ConvolutionalActivation(activation=Rectifier().apply,
Expand Down

0 comments on commit f5ee622

Please sign in to comment.