Skip to content

Commit

Permalink
Refactoring: Made an abstract class for all the cropping layers. (#10888
Browse files Browse the repository at this point in the history
)

* Made an abstract class for all the cropping layers.

* Added the rank in the config of _Cropping.

* Made self.cropping be a tuple of tuple in all cases (1d, 2d and 3d).

* Removed the rank argument in _Cropping, allowing to remove the
overwrides of get_config.
  • Loading branch information
gabrieldemarmiesse authored and fchollet committed Aug 21, 2018
1 parent 6746bda commit 23c40e6
Showing 1 changed file with 83 additions and 83 deletions.
166 changes: 83 additions & 83 deletions keras/layers/convolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -2310,7 +2310,68 @@ def get_config(self):
return dict(list(base_config.items()) + list(config.items()))


class Cropping1D(Layer):
class _Cropping(Layer):
"""Abstract nD copping layer (private, used as implementation base).
# Arguments
cropping: A tuple of tuples of 2 ints.
data_format: A string,
one of `"channels_last"` or `"channels_first"`.
The ordering of the dimensions in the inputs.
`"channels_last"` corresponds to inputs with shape
`(batch, ..., channels)` while `"channels_first"` corresponds to
inputs with shape `(batch, channels, ...)`.
It defaults to the `image_data_format` value found in your
Keras config file at `~/.keras/keras.json`.
If you never set it, then it will be "channels_last".
For Cropping1D, the data format is always `"channels_last"`.
"""

def __init__(self, cropping,
data_format=None,
**kwargs):
super(_Cropping, self).__init__(**kwargs)
# self.rank is 1 for Cropping1D, 2 for Cropping2D...
self.rank = len(cropping)
self.cropping = cropping
self.data_format = K.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=2 + self.rank)

def call(self, inputs):
slices_dims = []
for start, end in self.cropping:
if end == 0:
end = None
else:
end = -end
slices_dims.append(slice(start, end))

slices = [slice(None)] + slices_dims + [slice(None)]
slices = tuple(slices)
spatial_axes = list(range(1, 1 + self.rank))
slices = transpose_shape(slices, self.data_format, spatial_axes)
return inputs[slices]

def compute_output_shape(self, input_shape):
cropping_all_dims = ((0, 0),) + self.cropping + ((0, 0),)
spatial_axes = list(range(1, 1 + self.rank))
cropping_all_dims = transpose_shape(cropping_all_dims,
self.data_format,
spatial_axes)
output_shape = list(input_shape)
for dim in range(len(output_shape)):
if output_shape[dim] is not None:
output_shape[dim] -= sum(cropping_all_dims[dim])
return tuple(output_shape)

def get_config(self):
config = {'cropping': self.cropping,
'data_format': self.data_format}
base_config = super(_Cropping, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class Cropping1D(_Cropping):
"""Cropping layer for 1D input (e.g. temporal sequence).
It crops along the time dimension (axis 1).
Expand All @@ -2330,25 +2391,19 @@ class Cropping1D(Layer):
"""

def __init__(self, cropping=(1, 1), **kwargs):
super(Cropping1D, self).__init__(**kwargs)
self.cropping = conv_utils.normalize_tuple(cropping, 2, 'cropping')
self.input_spec = InputSpec(ndim=3)

def compute_output_shape(self, input_shape):
return _compute_output_shape_cropping(input_shape,
'channels_last',
(self.cropping,))

def call(self, inputs):
return _call_cropping(inputs, 'channels_last', (self.cropping,))
normalized_cropping = (conv_utils.normalize_tuple(cropping, 2, 'cropping'),)
super(Cropping1D, self).__init__(normalized_cropping,
'channels_last',
**kwargs)

def get_config(self):
config = {'cropping': self.cropping}
base_config = super(Cropping1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
base_config.pop('data_format')
base_config['cropping'] = base_config['cropping'][0]
return base_config


class Cropping2D(Layer):
class Cropping2D(_Cropping):
"""Cropping layer for 2D input (e.g. picture).
It crops along spatial dimensions, i.e. height and width.
Expand Down Expand Up @@ -2406,10 +2461,8 @@ class Cropping2D(Layer):
@interfaces.legacy_cropping2d_support
def __init__(self, cropping=((0, 0), (0, 0)),
data_format=None, **kwargs):
super(Cropping2D, self).__init__(**kwargs)
self.data_format = K.normalize_data_format(data_format)
if isinstance(cropping, int):
self.cropping = ((cropping, cropping), (cropping, cropping))
normalized_cropping = ((cropping, cropping), (cropping, cropping))
elif hasattr(cropping, '__len__'):
if len(cropping) != 2:
raise ValueError('`cropping` should have two elements. '
Expand All @@ -2420,32 +2473,20 @@ def __init__(self, cropping=((0, 0), (0, 0)),
width_cropping = conv_utils.normalize_tuple(
cropping[1], 2,
'2nd entry of cropping')
self.cropping = (height_cropping, width_cropping)
normalized_cropping = (height_cropping, width_cropping)
else:
raise ValueError('`cropping` should be either an int, '
'a tuple of 2 ints '
'(symmetric_height_crop, symmetric_width_crop), '
'or a tuple of 2 tuples of 2 ints '
'((top_crop, bottom_crop), (left_crop, right_crop)). '
'Found: ' + str(cropping))
self.input_spec = InputSpec(ndim=4)

def compute_output_shape(self, input_shape):
return _compute_output_shape_cropping(input_shape,
self.data_format,
self.cropping)

def call(self, inputs):
return _call_cropping(inputs, self.data_format, self.cropping)

def get_config(self):
config = {'cropping': self.cropping,
'data_format': self.data_format}
base_config = super(Cropping2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
super(Cropping2D, self).__init__(normalized_cropping,
data_format,
**kwargs)


class Cropping3D(Layer):
class Cropping3D(_Cropping):
"""Cropping layer for 3D data (e.g. spatial or spatio-temporal).
# Arguments
Expand Down Expand Up @@ -2488,12 +2529,11 @@ class Cropping3D(Layer):
@interfaces.legacy_cropping3d_support
def __init__(self, cropping=((1, 1), (1, 1), (1, 1)),
data_format=None, **kwargs):
super(Cropping3D, self).__init__(**kwargs)
self.data_format = K.normalize_data_format(data_format)
if isinstance(cropping, int):
self.cropping = ((cropping, cropping),
(cropping, cropping),
(cropping, cropping))
normalized_cropping = ((cropping, cropping),
(cropping, cropping),
(cropping, cropping))
elif hasattr(cropping, '__len__'):
if len(cropping) != 3:
raise ValueError('`cropping` should have 3 elements. '
Expand All @@ -2504,7 +2544,7 @@ def __init__(self, cropping=((1, 1), (1, 1), (1, 1)),
'2nd entry of cropping')
dim3_cropping = conv_utils.normalize_tuple(cropping[2], 2,
'3rd entry of cropping')
self.cropping = (dim1_cropping, dim2_cropping, dim3_cropping)
normalized_cropping = (dim1_cropping, dim2_cropping, dim3_cropping)
else:
raise ValueError('`cropping` should be either an int, '
'a tuple of 3 ints '
Expand All @@ -2514,49 +2554,9 @@ def __init__(self, cropping=((1, 1), (1, 1), (1, 1)),
' (left_dim2_crop, right_dim2_crop),'
' (left_dim3_crop, right_dim2_crop)). '
'Found: ' + str(cropping))
self.input_spec = InputSpec(ndim=5)

def compute_output_shape(self, input_shape):
return _compute_output_shape_cropping(input_shape,
self.data_format,
self.cropping)

def call(self, inputs):
return _call_cropping(inputs, self.data_format, self.cropping)

def get_config(self):
config = {'cropping': self.cropping,
'data_format': self.data_format}
base_config = super(Cropping3D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


def _call_cropping(inputs, data_format, cropping):
slices_dims = []
for start, end in cropping:
if end == 0:
end = None
else:
end = -end
slices_dims.append(slice(start, end))

slices = [slice(None)] + slices_dims + [slice(None)]
slices = tuple(slices)
spatial_axes = list(range(1, 1 + len(cropping)))
slices = transpose_shape(slices, data_format, spatial_axes)
return inputs[slices]


def _compute_output_shape_cropping(input_shape, data_format, cropping):
cropping_all_dims = ((0, 0),) + cropping + ((0, 0),)
spatial_axes = list(range(1, 1 + len(cropping)))
cropping_all_dims = transpose_shape(cropping_all_dims, data_format, spatial_axes)

output_shape = list(input_shape)
for dim in range(len(output_shape)):
if output_shape[dim] is not None:
output_shape[dim] -= sum(cropping_all_dims[dim])
return tuple(output_shape)
super(Cropping3D, self).__init__(normalized_cropping,
data_format,
**kwargs)


# Aliases
Expand Down

0 comments on commit 23c40e6

Please sign in to comment.