Skip to content

Commit

Permalink
Add SeparableConv1D (#8851)
Browse files Browse the repository at this point in the history
* Add SeparableConv1D

* Minor updates
  • Loading branch information
taehoonlee authored and fchollet committed Dec 26, 2017
1 parent 12a060f commit c57bba2
Show file tree
Hide file tree
Showing 5 changed files with 425 additions and 31 deletions.
5 changes: 5 additions & 0 deletions keras/backend/cntk_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,6 +1502,11 @@ def conv2d(x, kernel, strides=(1, 1), padding='valid',
return _postprocess_conv2d_output(x, data_format)


def separable_conv1d(x, depthwise_kernel, pointwise_kernel, strides=1,
padding='valid', data_format=None, dilation_rate=1):
raise NotImplementedError


def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
padding='valid', data_format=None, dilation_rate=(1, 1)):
raise NotImplementedError
Expand Down
72 changes: 72 additions & 0 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3162,6 +3162,27 @@ def in_top_k(predictions, targets, k):
# CONVOLUTIONS


def _preprocess_conv1d_input(x, data_format):
"""Transpose and cast the input before the conv1d.
# Arguments
x: input tensor.
data_format: string, `"channels_last"` or `"channels_first"`.
# Returns
A tensor.
"""
if dtype(x) == 'float64':
x = tf.cast(x, 'float32')
tf_data_format = 'NHWC' # to pass TF Conv2dNative operations
if data_format == 'channels_first':
if not _has_nchw_support():
x = tf.transpose(x, (0, 2, 1)) # NCW -> NWC
else:
tf_data_format = 'NCHW'
return x, tf_data_format


def _preprocess_conv2d_input(x, data_format):
"""Transpose and cast the input before the conv2d.
Expand Down Expand Up @@ -3362,6 +3383,57 @@ def conv2d_transpose(x, kernel, output_shape, strides=(1, 1),
return x


def separable_conv1d(x, depthwise_kernel, pointwise_kernel, strides=1,
padding='valid', data_format=None, dilation_rate=1):
"""1D convolution with separable filters.
# Arguments
x: input tensor
depthwise_kernel: convolution kernel for the depthwise convolution.
pointwise_kernel: kernel for the 1x1 convolution.
strides: stride integer.
padding: string, `"same"` or `"valid"`.
data_format: string, `"channels_last"` or `"channels_first"`.
dilation_rate: integer dilation rate.
# Returns
Output tensor.
# Raises
ValueError: if `data_format` is neither `channels_last` or `channels_first`.
"""
if data_format is None:
data_format = image_data_format()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('Unknown data_format ' + str(data_format))

x, tf_data_format = _preprocess_conv1d_input(x, data_format)
padding = _preprocess_padding(padding)
if tf_data_format == 'NHWC':
spatial_start_dim = 1
strides = (1, 1) + strides + (1,)
else:
spatial_start_dim = 2
strides = (1, 1, 1) + strides
x = tf.expand_dims(x, spatial_start_dim)
depthwise_kernel = tf.expand_dims(depthwise_kernel, 0)
pointwise_kernel = tf.expand_dims(pointwise_kernel, 0)
dilation_rate = (1,) + dilation_rate

x = tf.nn.separable_conv2d(x, depthwise_kernel, pointwise_kernel,
strides=strides,
padding=padding,
rate=dilation_rate,
data_format=tf_data_format)

x = tf.squeeze(x, [spatial_start_dim])

if data_format == 'channels_first' and tf_data_format == 'NHWC':
x = tf.transpose(x, (0, 2, 1)) # NWC -> NCW

return x


def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
padding='valid', data_format=None, dilation_rate=(1, 1)):
"""2D convolution with separable filters.
Expand Down
5 changes: 5 additions & 0 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1972,6 +1972,11 @@ def conv2d_transpose(x, kernel, output_shape, strides=(1, 1),
return conv_out


def separable_conv1d(x, depthwise_kernel, pointwise_kernel, strides=1,
padding='valid', data_format=None, dilation_rate=1):
raise NotImplementedError


def separable_conv2d(x, depthwise_kernel, pointwise_kernel, strides=(1, 1),
padding='valid', data_format=None, dilation_rate=(1, 1)):
raise NotImplementedError
Expand Down

0 comments on commit c57bba2

Please sign in to comment.