Skip to content

Commit

Permalink
Made an abstract class for upsampling. (#11065)
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse authored and fchollet committed Sep 5, 2018
1 parent 23c20f7 commit c2a6caa
Showing 1 changed file with 61 additions and 68 deletions.
129 changes: 61 additions & 68 deletions keras/layers/convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,7 +1868,52 @@ def get_config(self):
return config


class UpSampling1D(Layer):
class _UpSampling(Layer):
"""Abstract nD UpSampling layer (private, used as implementation base).
# Arguments
size: Tuple of ints.
data_format: A string,
one of `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape
`(batch, ..., channels)` while `"channels_first"` corresponds to
inputs with shape `(batch, channels, ...)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
"""
def __init__(self, size, data_format=None, **kwargs):
# self.rank is 1 for UpSampling1D, 2 for UpSampling2D.
self.rank = len(size)
self.size = size
self.data_format = K.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=self.rank + 2)
super(_UpSampling, self).__init__(**kwargs)

def call(self, inputs):
raise NotImplementedError

def compute_output_shape(self, input_shape):
size_all_dims = (1,) + self.size + (1,)
spatial_axes = list(range(1, 1 + self.rank))
size_all_dims = transpose_shape(size_all_dims,
self.data_format,
spatial_axes)
output_shape = list(input_shape)
for dim in range(len(output_shape)):
if output_shape[dim] is not None:
output_shape[dim] *= size_all_dims[dim]
return tuple(output_shape)

def get_config(self):
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(_UpSampling, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class UpSampling1D(_UpSampling):
"""Upsampling layer for 1D inputs.
Repeats each temporal step `size` times along the time axis.
Expand All @@ -1885,25 +1930,20 @@ class UpSampling1D(Layer):

@interfaces.legacy_upsampling1d_support
def __init__(self, size=2, **kwargs):
super(UpSampling1D, self).__init__(**kwargs)
self.size = int(size)
self.input_spec = InputSpec(ndim=3)

def compute_output_shape(self, input_shape):
size = self.size * input_shape[1] if input_shape[1] is not None else None
return (input_shape[0], size, input_shape[2])
super(UpSampling1D, self).__init__((int(size),), 'channels_last', **kwargs)

def call(self, inputs):
output = K.repeat_elements(inputs, self.size, axis=1)
output = K.repeat_elements(inputs, self.size[0], axis=1)
return output

def get_config(self):
config = {'size': self.size}
base_config = super(UpSampling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
config = super(UpSampling1D, self).get_config()
config['size'] = self.size[0]
config.pop('data_format')
return config


class UpSampling2D(Layer):
class UpSampling2D(_UpSampling):
"""Upsampling layer for 2D inputs.
Repeats the rows and columns of the data
Expand Down Expand Up @@ -1943,43 +1983,24 @@ class UpSampling2D(Layer):

@interfaces.legacy_upsampling2d_support
def __init__(self, size=(2, 2), data_format=None, interpolation='nearest', **kwargs):
super(UpSampling2D, self).__init__(**kwargs)
self.data_format = K.normalize_data_format(data_format)
self.size = conv_utils.normalize_tuple(size, 2, 'size')
self.input_spec = InputSpec(ndim=4)
normalized_size = conv_utils.normalize_tuple(size, 2, 'size')
super(UpSampling2D, self).__init__(normalized_size, data_format, **kwargs)
if interpolation not in ['nearest', 'bilinear']:
raise ValueError('interpolation should be one '
'of "nearest" or "bilinear".')
self.interpolation = interpolation

def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
height = self.size[0] * input_shape[2] if input_shape[2] is not None else None
width = self.size[1] * input_shape[3] if input_shape[3] is not None else None
return (input_shape[0],
input_shape[1],
height,
width)
elif self.data_format == 'channels_last':
height = self.size[0] * input_shape[1] if input_shape[1] is not None else None
width = self.size[1] * input_shape[2] if input_shape[2] is not None else None
return (input_shape[0],
height,
width,
input_shape[3])

def call(self, inputs):
return K.resize_images(inputs, self.size[0], self.size[1],
self.data_format, self.interpolation)

def get_config(self):
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(UpSampling2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
config = super(UpSampling2D, self).get_config()
config['interpolation'] = self.interpolation
return config


class UpSampling3D(Layer):
class UpSampling3D(_UpSampling):
"""Upsampling layer for 3D inputs.
Repeats the 1st, 2nd and 3rd dimensions
Expand Down Expand Up @@ -2016,42 +2037,14 @@ class UpSampling3D(Layer):

@interfaces.legacy_upsampling3d_support
def __init__(self, size=(2, 2, 2), data_format=None, **kwargs):
self.data_format = K.normalize_data_format(data_format)
self.size = conv_utils.normalize_tuple(size, 3, 'size')
self.input_spec = InputSpec(ndim=5)
super(UpSampling3D, self).__init__(**kwargs)

def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
dim1 = self.size[0] * input_shape[2] if input_shape[2] is not None else None
dim2 = self.size[1] * input_shape[3] if input_shape[3] is not None else None
dim3 = self.size[2] * input_shape[4] if input_shape[4] is not None else None
return (input_shape[0],
input_shape[1],
dim1,
dim2,
dim3)
elif self.data_format == 'channels_last':
dim1 = self.size[0] * input_shape[1] if input_shape[1] is not None else None
dim2 = self.size[1] * input_shape[2] if input_shape[2] is not None else None
dim3 = self.size[2] * input_shape[3] if input_shape[3] is not None else None
return (input_shape[0],
dim1,
dim2,
dim3,
input_shape[4])
normalized_size = conv_utils.normalize_tuple(size, 3, 'size')
super(UpSampling3D, self).__init__(normalized_size, data_format, **kwargs)

def call(self, inputs):
return K.resize_volumes(inputs,
self.size[0], self.size[1], self.size[2],
self.data_format)

def get_config(self):
config = {'size': self.size,
'data_format': self.data_format}
base_config = super(UpSampling3D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class _ZeroPadding(Layer):
"""Abstract nD ZeroPadding layer (private, used as implementation base).
Expand Down

0 comments on commit c2a6caa

Please sign in to comment.