Merged
38 commits
a441234
Added source code and docstring of new kernel quantizer.
Joschua-Conrad Apr 15, 2021
425dcfa
Fixed two typos
Joschua-Conrad Apr 16, 2021
1056c44
Added unit test for backwards compat alias of DoReFa kernel quantizer
Joschua-Conrad Apr 16, 2021
6215854
Added test for metrics of DoReFaKernel quantizer
Joschua-Conrad Apr 16, 2021
7569bae
Fixed bug in gradient computation of new kernel quantizer
Joschua-Conrad Apr 16, 2021
20c0956
Added test for gradient of new kernel quantizer
Joschua-Conrad Apr 16, 2021
0d3b04d
Explicitly using tf.math now
Joschua-Conrad Apr 16, 2021
a43ba40
Added quantization test for new kernel quantizer
Joschua-Conrad Apr 16, 2021
168ccfc
Removed checks for 0. in calls of tf.math.tanh
Joschua-Conrad Apr 16, 2021
2ce76f3
Compat alias name of new kernel quantizer added to __all__ quantizers.
Joschua-Conrad Apr 16, 2021
7dd5149
Fixed typo and one bad indent
Joschua-Conrad Apr 16, 2021
4ab5622
I hope this pleases the almightly linter. Only comments or blank line…
Joschua-Conrad Apr 16, 2021
58f07d4
black passes now its check after rewriting files with black
Joschua-Conrad Apr 16, 2021
deeb716
Remove alias: unittest
Joschua-Conrad Apr 27, 2021
57ba9c2
Remove alias: Definition
Joschua-Conrad Apr 27, 2021
b9549bc
Remove alias: Quantizerlist
Joschua-Conrad Apr 27, 2021
376597e
Improve readbility in weight preprocessing definition
Joschua-Conrad Apr 27, 2021
426c64e
Improved code readability
Joschua-Conrad Apr 27, 2021
b4a6e29
Divisions by zero are now skipped by using tf.math.divide_no_nan
Joschua-Conrad Apr 29, 2021
d545af6
Apply suggested changes on docstring syntax
Joschua-Conrad Apr 29, 2021
a26aba7
Fix docstring typo
Joschua-Conrad May 3, 2021
6791ad3
Remove linefeed
Joschua-Conrad May 3, 2021
9b85bc2
Move scale factor into division
Joschua-Conrad May 3, 2021
82df95f
Kernel quantizer now tested for multiple bitwidths.
Joschua-Conrad May 3, 2021
26a6119
Moved kernel quantizer logic and comments into activation quantizer
Joschua-Conrad May 4, 2021
fe00e6c
Moved docstrings from kernel to generic DoReFa quantizer and removed …
Joschua-Conrad May 4, 2021
924509e
Made old unittests pass again
Joschua-Conrad May 4, 2021
ad92f23
Added mode attribute to get_config
Joschua-Conrad May 4, 2021
a80acbd
Unified DoReFa acitvation and weight quantization test
Joschua-Conrad May 4, 2021
2bca107
Also merged test routines for gradients of DoReFa activation and weig…
Joschua-Conrad May 4, 2021
50f83fc
Apply suggestions from code review
Joschua-Conrad May 5, 2021
363c60c
Renamed kernel to weights
Joschua-Conrad May 5, 2021
b3fbc42
Added tests for both error messages regarding DoReFa quantizer mode
Joschua-Conrad May 5, 2021
ef76f89
Mode error messages now list available modes.
Joschua-Conrad May 5, 2021
b4bfd47
Added metrics test case for DoReFa quantizer in weights mode
Joschua-Conrad May 5, 2021
6b87b2a
Made black happy
Joschua-Conrad May 5, 2021
1b4d22a
Fixed unittest import order
Joschua-Conrad May 5, 2021
507cd97
Apply suggestions from code review
lgeiger May 5, 2021
98 changes: 92 additions & 6 deletions larq/quantizers.py
@@ -558,47 +558,133 @@ class DoReFa(_BaseQuantizer):
0 & \text{else}
\end{cases}\\]

Weights are quantized differently from activations: instead of limiting the
inputs (here: the weights) with a hard limiter, a hyperbolic tangent is applied.
This gives a softer limiting whose gradient is itself continuously
differentiable.

\\[
w_{lim}(w) = \tanh(w)
\\]

Furthermore, the weights of each layer are normalized such that the weight with
the largest magnitude maps to the largest or smallest (depending on its sign)
quantizable value. That way, the full quantizable numeric range is used.

\\[
w_{norm}(w) = \frac{w}{\max(|w|)}
\\]

The formulas can be found in section 2.3 of the paper. Note that the paper
quantizes weights on the numeric range [-1, 1] and activations on the numeric
range [0, 1]. This implementation uses the same ranges as specified in the paper.

The activation quantizer implements the function quantize_k() from the paper on
its numeric range of [0, 1]. The weight quantization mode adds pre- and
post-processing for numeric range adaptation, soft limiting and normalization.
The full quantization function including the adaptation of numeric ranges is

\\[
q(w) = 2 \, quantize_{k}(\frac{w_{norm}\left(w_{lim}\left(w\right)\right)}{2} + \frac{1}{2}) - 1
\\]
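
As a rough illustration of the formula above, the weight mode can be sketched in
plain Python without TensorFlow. This is a minimal sketch, not the implementation
in this class: the helper names `quantize_k` and `dorefa_weights_sketch` are made
up for the example, it works on a flat list of floats, and it relies on Python's
`round`, which like `tf.round` rounds halves to even.

```python
import math


def quantize_k(x, k_bit):
    # Uniform quantizer on [0, 1] with 2**k_bit levels.
    n = 2 ** k_bit - 1
    return round(x * n) / n


def dorefa_weights_sketch(weights, k_bit=2):
    # Soft-limit every weight to (-1, 1).
    limited = [math.tanh(w) for w in weights]
    # Largest magnitude after limiting; 0.0 only if all weights are 0.
    max_abs = max(abs(v) for v in limited)
    quantized = []
    for v in limited:
        # Mimic a division that yields 0 for a zero divisor.
        normed = v / max_abs if max_abs != 0.0 else 0.0
        # Map [-1, 1] to [0, 1], quantize, and map back to [-1, 1].
        quantized.append(2.0 * quantize_k(normed / 2.0 + 0.5, k_bit) - 1.0)
    return quantized


print(dorefa_weights_sketch([-2.0, 0.1, 0.5, 2.0], k_bit=2))
```

With `k_bit=2` the possible outputs of this sketch are -1, -1/3, 1/3 and 1.
Because `2 ** k_bit - 1` is always odd, 0 is never one of its output levels.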

!!! warning
While the DoReFa paper describes how to do quantization for both weights and
activations, this implementation is only valid for activations, and this
quantizer should therefore not be used as a kernel quantizer.
The weight mode works for weights on the range [-1, 1], which matches the
default setting of `constraints.weight_clip`. Do not use this quantizer
with a different constraint `clip_value` than the default one.

```plot-activation
quantizers.DoReFa
```

# Arguments
k_bit: number of bits for the quantization.
mode: `"activations"` for clipping inputs on [0, 1] range or `"weights"` for
soft-clipping and norming weights on [-1, 1] range before applying
quantization.
metrics: An array of metrics to add to the layer. If `None` the metrics set in
`larq.context.metrics_scope` are used. Currently only the `flip_ratio`
metric is available.

# Returns
Quantization function

# Raises
ValueError for bad value of `mode`.

# References
- [DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low
Bitwidth Gradients](https://arxiv.org/abs/1606.06160)
"""
precision = None

def __init__(self, k_bit: int = 2, **kwargs):
def __init__(self, k_bit: int = 2, mode: str = "activations", **kwargs):
self.precision = k_bit

if mode not in ("activations", "weights"):
raise ValueError(
f"Invalid DoReFa quantizer mode {mode}. "
"Valid values are 'activations' and 'weights'."
)
self.mode = mode

super().__init__(**kwargs)

def weight_preprocess(self, inputs):
# Limit inputs to [-1, 1] range
limited = tf.math.tanh(inputs)

# Divisor for the max-value normalization (stored in "dividend").
dividend = tf.math.reduce_max(tf.math.abs(limited))

# Stop the gradient here. Otherwise, for the element with the maximum
# magnitude, which supplies the dividend, normed is limited/limited.
# Simplifying y = x/x = 1 gives dy/dx = 0, but TF does NOT make this
# simplification when computing the gradient of the
# normed = limited/dividend operation. As a result, the gradient
# becomes complicated, because during the computation "dividend" is
# not a constant but depends on "limited". Here, tf.stop_gradient
# marks "dividend" as a constant explicitly.
dividend = tf.stop_gradient(dividend)

# Norm and then scale from value range [-1,1] to [0,1] (the range
# expected by the core quantization operation).
# If the dividend used for the norm operation is 0, all elements of
# the weight tensor are 0 and divide_no_nan returns 0 for all weights.
# So if all elements of the weight tensor are zero, nothing is normed.
return tf.math.divide_no_nan(limited, 2.0 * dividend) + 0.5

def call(self, inputs):
inputs = tf.clip_by_value(inputs, 0.0, 1.0)
# Depending on quantizer mode (activation or weight) just clip inputs
# on [0, 1] range or use weight preprocessing method.
if self.mode == "activations":
inputs = tf.clip_by_value(inputs, 0.0, 1.0)
elif self.mode == "weights":
inputs = self.weight_preprocess(inputs)
else:
raise ValueError(
f"Invalid DoReFa quantizer mode {self.mode}. "
"Valid values are 'activations' and 'weights'."
)

@tf.custom_gradient
def _k_bit_with_identity_grad(x):
n = 2 ** self.precision - 1
return tf.round(x * n) / n, lambda dy: dy

outputs = _k_bit_with_identity_grad(inputs)

# Scale weights from [0, 1] quantization range back to [-1,1] range
if self.mode == "weights":
outputs = 2.0 * outputs - 1.0

return super().call(outputs)

def get_config(self):
return {**super().get_config(), "k_bit": self.precision}
return {**super().get_config(), "k_bit": self.precision, "mode": self.mode}


# `DoReFa` used to be called `DoReFaQuantizer`; this alias is for
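The comment on `tf.stop_gradient` in `weight_preprocess` can be checked numerically without TensorFlow. The following is a minimal sketch under stated assumptions: `normed_element` and `central_difference` are hypothetical helpers, the weights are a plain list, and only the derivative of one normalized element with respect to its own weight is considered (the cross terms of a full gradient are ignored). Passing a fixed `divisor` plays the role of `tf.stop_gradient`.

```python
import math


def normed_element(weights, index, divisor=None):
    # tanh-limit the weights and normalize one element by the largest magnitude.
    limited = [math.tanh(w) for w in weights]
    if divisor is None:
        # Divisor recomputed from the inputs: it moves when the weights move.
        divisor = max(abs(v) for v in limited)
    return limited[index] / divisor


def central_difference(weights, index, freeze_divisor, eps=1e-6):
    # Numerical derivative of the normalized element w.r.t. its own weight.
    divisor = None
    if freeze_divisor:
        # Treat the divisor as a constant, as tf.stop_gradient would.
        divisor = max(abs(math.tanh(w)) for w in weights)
    up = list(weights)
    down = list(weights)
    up[index] += eps
    down[index] -= eps
    return (
        normed_element(up, index, divisor) - normed_element(down, index, divisor)
    ) / (2.0 * eps)


weights = [0.3, -0.8, 1.2]
max_index = 2  # 1.2 has the largest magnitude

free = central_difference(weights, max_index, freeze_divisor=False)
frozen = central_difference(weights, max_index, freeze_divisor=True)
expected = 1.0 / math.cosh(1.2) ** 2 / math.tanh(1.2)
print(free, frozen, expected)
```

With the divisor recomputed, the element with the largest magnitude always normalizes to 1, so its own derivative vanishes. With the divisor frozen, the derivative matches `1 / cosh(x) ** 2 / dividend`, which is the value the `tanh_grad` helper in the test file expects.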
50 changes: 41 additions & 9 deletions larq/quantizers_test.py
@@ -1,3 +1,5 @@
import functools

import numpy as np
import pytest
import tensorflow as tf
@@ -66,6 +68,12 @@ def test_invalid_usage(self):
lq.quantizers.get(42)
with pytest.raises(ValueError):
lq.quantizers.get("unknown")
with pytest.raises(ValueError):
lq.quantizers.DoReFa(k_bit=2, mode="unknown")
f = lq.quantizers.DoReFa(k_bit=2, mode="activations")
f.mode = "unknown"
with pytest.raises(ValueError):
f.call([0.0])

@pytest.mark.parametrize("quantizer", ["input_quantizer", "kernel_quantizer"])
def test_layer_as_quantizer(self, quantizer, keras_should_run_eagerly):
@@ -216,22 +224,34 @@ def test_ternarization_with_ternary_weight_networks(self):
assert not np.any(result > 1)
assert not np.any(result < -1)

def test_dorefa_quantize(self):
@pytest.mark.parametrize("k_bit", [1, 2, 4, 6, 8])
@pytest.mark.parametrize("mode", ["activations", "weights"])
def test_dorefa_quantize(self, k_bit, mode):
x = tf.keras.backend.placeholder(ndim=2)
f = tf.keras.backend.function([x], [lq.quantizers.DoReFa(2)(x)])
f = tf.keras.backend.function([x], [lq.quantizers.DoReFa(k_bit, mode)(x)])
real_values = testing_utils.generate_real_values_with_zeros()
result = f([real_values])[0]
k_bit = 2
n = 2 ** k_bit - 1
if mode == "weights":
# Create the preprocessed and scaled stimulus, which is then ready to
# go through the same test as the activation quantizer
divider = np.amax(np.abs(np.tanh(real_values)))
real_values = np.tanh(real_values) / divider
real_values = (real_values / 2.0) + 0.5
# The results, which are currently on the [-1, 1] range, get the same
# scaling, so they behave as if they were created on the activation
# range and can be tested the same way
result = result / 2.0 + 0.5
assert not np.any(result > 1)
assert not np.any(result < 0)
for i in range(n + 1):
assert np.all(
np.testing.assert_allclose(
result[
(real_values > (2 * i - 1) / (2 * n))
& (real_values < (2 * i + 1) / (2 * n))
]
== i / n
],
i / n,
atol=1e-6,
)


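The loop above only checks inputs that lie strictly inside a quantization bin, so values that sit exactly on a bin edge are skipped. A minimal pure-Python sketch of the same check, assuming a hypothetical helper `quantize_k` that mirrors `round(x * n) / n` and a hand-picked list of inputs in [0, 1]:

```python
def quantize_k(x, k_bit):
    # Uniform quantizer on [0, 1]; hypothetical stand-in for the core operation.
    n = 2 ** k_bit - 1
    return round(x * n) / n


def check_bins(values, k_bit, atol=1e-6):
    # For every level i / n, inputs strictly inside the surrounding bin
    # must be quantized to that level. Returns how many inputs were checked.
    n = 2 ** k_bit - 1
    checked = 0
    for i in range(n + 1):
        lower = (2 * i - 1) / (2 * n)
        upper = (2 * i + 1) / (2 * n)
        for x in values:
            if lower < x < upper:
                assert abs(quantize_k(x, k_bit) - i / n) <= atol
                checked += 1
    return checked


values = [0.0, 0.1, 0.2, 0.4, 0.5, 0.6, 0.9, 1.0]
for k_bit in (1, 2, 4, 6, 8):
    print(k_bit, check_bins(values, k_bit))
```

Skipping the edges keeps the check independent of how ties are rounded, which is one plausible reason for the strict inequalities in the test.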
@@ -325,19 +345,30 @@ def test_magnitude_aware_sign_grad(self):
grad.numpy(), np.where(abs(a) < 1, np.ones(a.shape) * scale_vector, 0)
)

def test_dorefa_ste_grad(self):
@pytest.mark.parametrize("mode", ["activations", "weights"])
def test_dorefa_ste_grad(self, mode):
@np.vectorize
def ste_grad(x):
if x <= 1 and x >= 0:
return 1.0
return 0.0

def tanh_grad(x):
# 1/(cosh**2) is the derivative of tanh. The gradients of the
# scaling operations cancel each other and the gradient of the
# quantize_k function is supposed to be 1 everywhere, because it
# is used on its linear region only. tanh does all the limiting.
dividend = np.amax(np.abs(np.tanh(x)))
return 1 / (np.cosh(x) ** 2.0) / dividend

expected_gradient = ste_grad if mode == "activations" else tanh_grad

x = testing_utils.generate_real_values_with_zeros(shape=(8, 3, 3, 16))
tf_x = tf.Variable(x)
with tf.GradientTape() as tape:
activation = lq.quantizers.DoReFa(2)(tf_x)
activation = lq.quantizers.DoReFa(2, mode)(tf_x)
grad = tape.gradient(activation, tf_x)
np.testing.assert_allclose(grad.numpy(), ste_grad(x))
np.testing.assert_allclose(grad.numpy(), expected_gradient(x))


@pytest.mark.parametrize(
@@ -350,6 +381,7 @@ def ste_grad(x):
("magnitude_aware_sign", lq.quantizers.MagnitudeAwareSign),
("ste_tern", lq.quantizers.SteTern),
("dorefa_quantizer", lq.quantizers.DoReFa),
("dorefa_quantizer", functools.partial(lq.quantizers.DoReFa, mode="weights")),
],
)
def test_metrics(quantizer):