Skip to content

Commit

Permalink
Separable CNN implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
vloncar committed Oct 16, 2020
1 parent d6d0fe7 commit 1847f67
Showing 1 changed file with 336 additions and 1 deletion.
337 changes: 336 additions & 1 deletion qkeras/qconvolutional.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
from tensorflow.keras.layers import Conv1D
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Conv2DTranspose
from tensorflow.keras.layers import SeparableConv1D
from tensorflow.keras.layers import SeparableConv2D
from tensorflow.keras.layers import DepthwiseConv2D
from tensorflow.keras.layers import Dropout
from tensorflow.keras.layers import InputSpec
Expand Down Expand Up @@ -509,6 +511,339 @@ def get_prunable_weights(self):
return [self.kernel]


class QSeparableConv1D(SeparableConv1D, PrunableLayer):
"""Depthwise separable 1D convolution."""

# most of these parameters follow the implementation of SeparableConv1D
# in Keras, with the exception of depthwise_quantizer, pointwise_quantizer
# and bias_quantizer.
#
# depthwise_quantizer: quantizer function/class for depthwise spatial kernel
# pointwise_quantizer: quantizer function/class for pointwise kernel
# bias_quantizer: quantizer function/class for bias
#
# we refer the reader to the documentation of SeparableConv1D in Keras for
# the other parameters.
#

def __init__(self,
filters,
kernel_size,
strides=1,
padding='valid',
data_format=None,
dilation_rate=1,
depth_multiplier=1,
activation=None,
use_bias=True,
depthwise_initializer='glorot_uniform',
pointwise_initializer='glorot_uniform',
bias_initializer='zeros',
depthwise_regularizer=None,
pointwise_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
depthwise_constraint=None,
pointwise_constraint=None,
bias_constraint=None,
depthwise_quantizer=None,
pointwise_quantizer=None,
bias_quantizer=None,
**kwargs):

self.depthwise_quantizer = depthwise_quantizer
self.pointwise_quantizer = pointwise_quantizer
self.bias_quantizer = bias_quantizer

self.depthwise_quantizer_internal = get_quantizer(self.depthwise_quantizer)
self.pointwise_quantizer_internal = get_quantizer(self.pointwise_quantizer)
self.bias_quantizer_internal = get_quantizer(self.bias_quantizer)

# optimize parameter set to "auto" scaling mode if possible
if hasattr(self.depthwise_quantizer_internal, "_set_trainable_parameter"):
self.depthwise_quantizer_internal._set_trainable_parameter()

if hasattr(self.pointwise_quantizer_internal, "_set_trainable_parameter"):
self.pointwise_quantizer_internal._set_trainable_parameter()

self.quantizers = [
self.depthwise_quantizer_internal, self.pointwise_quantizer_internal,
self.bias_quantizer_internal
]

depthwise_constraint, depthwise_initializer = (
get_auto_range_constraint_initializer(self.depthwise_quantizer_internal,
depthwise_constraint,
depthwise_initializer))

pointwise_constraint, pointwise_initializer = (
get_auto_range_constraint_initializer(self.pointwise_quantizer_internal,
pointwise_constraint,
pointwise_initializer))

if use_bias:
bias_constraint, bias_initializer = (
get_auto_range_constraint_initializer(self.bias_quantizer_internal,
bias_constraint,
bias_initializer))

if activation is not None:
activation = get_quantizer(activation)

super(QSeparableConv1D, self).__init__(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
depth_multiplier=depth_multiplier,
activation=activation,
use_bias=use_bias,
depthwise_initializer=initializers.get(depthwise_initializer),
pointwise_initializer=initializers.get(pointwise_initializer),
bias_initializer=initializers.get(bias_initializer),
depthwise_regularizer=regularizers.get(depthwise_regularizer),
pointwise_regularizer=regularizers.get(pointwise_regularizer),
bias_regularizer=regularizers.get(bias_regularizer),
activity_regularizer=regularizers.get(activity_regularizer),
depthwise_constraint=constraints.get(depthwise_constraint),
pointwise_constraint=constraints.get(pointwise_constraint),
bias_constraint=constraints.get(bias_constraint),
**kwargs)

def call(self, inputs):
if self.padding == 'causal':
inputs = array_ops.pad(inputs, self._compute_causal_padding())

spatial_start_dim = 1 if self.data_format == 'channels_last' else 2

# Explicitly broadcast inputs and kernels to 4D.
inputs = array_ops.expand_dims(inputs, spatial_start_dim)
depthwise_kernel = array_ops.expand_dims(self.depthwise_kernel, 0)
pointwise_kernel = array_ops.expand_dims(self.pointwise_kernel, 0)
dilation_rate = (1,) + self.dilation_rate

if self.padding == 'causal':
op_padding = 'valid'
else:
op_padding = self.padding

if self.depthwise_quantizer:
quantized_depthwise_kernel = self.depthwise_quantizer_internal(
depthwise_kernel)
else:
quantized_depthwise_kernel = depthwise_kernel

if self.pointwise_quantizer:
quantized_pointwise_kernel = self.pointwise_quantizer_internal(
pointwise_kernel)
else:
quantized_pointwise_kernel = pointwise_kernel

outputs = tf.keras.backend.separable_conv2d(
inputs,
quantized_depthwise_kernel,
quantized_pointwise_kernel,
strides=self.strides * 2,
padding=op_padding,
dilation_rate=dilation_rate,
data_format=self.data_format)

if self.use_bias:
if self.bias_quantizer:
quantized_bias = self.bias_quantizer_internal(self.bias)
else:
quantized_bias = self.bias

outputs = tf.keras.backend.bias_add(
outputs,
quantized_bias,
data_format=self.data_format)

outputs = array_ops.squeeze(outputs, [spatial_start_dim])

if self.activation is not None:
return self.activation(outputs)
return outputs

def get_config(self):
config = {
"depthwise_quantizer":
constraints.serialize(self.depthwise_quantizer_internal),
"pointwise_quantizer":
constraints.serialize(self.pointwise_quantizer_internal),
"bias_quantizer":
constraints.serialize(self.bias_quantizer_internal)
}
base_config = super(QSeparableConv1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

def get_quantizers(self):
return self.quantizers

def get_prunable_weights(self):
return [self.depthwise_kernel, self.pointwise_kernel]


class QSeparableConv2D(SeparableConv2D, PrunableLayer):
"""Depthwise separable 2D convolution."""

# most of these parameters follow the implementation of SeparableConv2D
# in Keras, with the exception of depthwise_quantizer, pointwise_quantizer
# and bias_quantizer.
#
# depthwise_quantizer: quantizer function/class for depthwise spatial kernel
# pointwise_quantizer: quantizer function/class for pointwise kernel
# bias_quantizer: quantizer function/class for bias
#
# we refer the reader to the documentation of SeparableConv2D in Keras for
# the other parameters.
#

def __init__(self,
filters,
kernel_size,
strides=(1, 1),
padding='valid',
data_format=None,
dilation_rate=(1, 1),
depth_multiplier=1,
activation=None,
use_bias=True,
depthwise_initializer='glorot_uniform',
pointwise_initializer='glorot_uniform',
bias_initializer='zeros',
depthwise_regularizer=None,
pointwise_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
depthwise_constraint=None,
pointwise_constraint=None,
bias_constraint=None,
depthwise_quantizer=None,
pointwise_quantizer=None,
bias_quantizer=None,
**kwargs):

self.depthwise_quantizer = depthwise_quantizer
self.pointwise_quantizer = pointwise_quantizer
self.bias_quantizer = bias_quantizer

self.depthwise_quantizer_internal = get_quantizer(self.depthwise_quantizer)
self.pointwise_quantizer_internal = get_quantizer(self.pointwise_quantizer)
self.bias_quantizer_internal = get_quantizer(self.bias_quantizer)

# optimize parameter set to "auto" scaling mode if possible
if hasattr(self.depthwise_quantizer_internal, "_set_trainable_parameter"):
self.depthwise_quantizer_internal._set_trainable_parameter()

if hasattr(self.pointwise_quantizer_internal, "_set_trainable_parameter"):
self.pointwise_quantizer_internal._set_trainable_parameter()

self.quantizers = [
self.depthwise_quantizer_internal, self.pointwise_quantizer_internal,
self.bias_quantizer_internal
]

depthwise_constraint, depthwise_initializer = (
get_auto_range_constraint_initializer(self.depthwise_quantizer_internal,
depthwise_constraint,
depthwise_initializer))

pointwise_constraint, pointwise_initializer = (
get_auto_range_constraint_initializer(self.pointwise_quantizer_internal,
pointwise_constraint,
pointwise_initializer))

if use_bias:
bias_constraint, bias_initializer = (
get_auto_range_constraint_initializer(self.bias_quantizer_internal,
bias_constraint,
bias_initializer))

if activation is not None:
activation = get_quantizer(activation)

super(QSeparableConv2D, self).__init__(
filters=filters,
kernel_size=kernel_size,
strides=strides,
padding=padding,
data_format=data_format,
dilation_rate=dilation_rate,
depth_multiplier=depth_multiplier,
activation=activation,
use_bias=use_bias,
depthwise_initializer=initializers.get(depthwise_initializer),
pointwise_initializer=initializers.get(pointwise_initializer),
bias_initializer=initializers.get(bias_initializer),
depthwise_regularizer=regularizers.get(depthwise_regularizer),
pointwise_regularizer=regularizers.get(pointwise_regularizer),
bias_regularizer=regularizers.get(bias_regularizer),
activity_regularizer=regularizers.get(activity_regularizer),
depthwise_constraint=constraints.get(depthwise_constraint),
pointwise_constraint=constraints.get(pointwise_constraint),
bias_constraint=constraints.get(bias_constraint),
**kwargs)

def call(self, inputs):
# Apply the actual ops.
if self.depthwise_quantizer:
quantized_depthwise_kernel = self.depthwise_quantizer_internal(
self.depthwise_kernel)
else:
quantized_depthwise_kernel = self.depthwise_kernel

if self.pointwise_quantizer:
quantized_pointwise_kernel = self.pointwise_quantizer_internal(
self.pointwise_kernel)
else:
quantized_pointwise_kernel = self.pointwise_kernel

outputs = tf.keras.backend.separable_conv2d(
inputs,
quantized_depthwise_kernel,
quantized_pointwise_kernel,
strides=self.strides,
padding=self.padding,
dilation_rate=self.dilation_rate,
data_format=self.data_format)

if self.use_bias:
if self.bias_quantizer:
quantized_bias = self.bias_quantizer_internal(self.bias)
else:
quantized_bias = self.bias

outputs = tf.keras.backend.bias_add(
outputs,
quantized_bias,
data_format=self.data_format)

if self.activation is not None:
return self.activation(outputs)
return outputs

def get_config(self):
config = {
"depthwise_quantizer":
constraints.serialize(self.depthwise_quantizer_internal),
"pointwise_quantizer":
constraints.serialize(self.pointwise_quantizer_internal),
"bias_quantizer":
constraints.serialize(self.bias_quantizer_internal)
}
base_config = super(QSeparableConv2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))

def get_quantizers(self):
return self.quantizers

def get_prunable_weights(self):
return [self.depthwise_kernel, self.pointwise_kernel]


class QDepthwiseConv2D(DepthwiseConv2D, PrunableLayer):
"""Creates quantized depthwise conv2d. Copied from mobilenet."""

Expand Down Expand Up @@ -706,7 +1041,7 @@ def get_prunable_weights(self):
return []


def QSeparableConv2D(
def QMobileNetSeparableConv2D(
filters, # pylint: disable=invalid-name
kernel_size,
strides=(1, 1),
Expand Down

0 comments on commit 1847f67

Please sign in to comment.