# Introduction

[Recent work](https://arxiv.org/pdf/2112.00133.pdf) on binary neural networks proposes a new method for binary quantization. On the forward pass, the binary quantizer is a sign function; on the backward pass, its derivative is taken to be the indicator function of the interval $(-B, B)$, where $B$ is a hyperparameter.

[Others](https://arxiv.org/pdf/1705.09283.pdf) have proposed using a scaling factor $\alpha$ on top of the sign function, and choosing $\alpha$ to minimize an average squared error between the quantized and full-precision data.
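
As a rough illustration of what such a choice looks like, here is the generic least-squares solution for a scaled sign quantizer. This is a sketch under my own assumptions (a mean squared error over $n$ values $x_i$, and $\text{sign}(x_i) \in \{-1, 1\}$); the exact objective in the cited paper may differ.
$$
\alpha^* = \arg\min_{\alpha} \frac{1}{n} \sum_{i=1}^n \left( x_i - \alpha \, \text{sign}(x_i) \right)^2 = \frac{1}{n} \sum_{i=1}^n |x_i|.
$$
This follows because $\text{sign}(x_i)^2 = 1$ and $x_i \, \text{sign}(x_i) = |x_i|$, so the loss is a quadratic in $\alpha$ whose derivative $2\alpha - \frac{2}{n} \sum_i |x_i|$ vanishes at the mean absolute value.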

In this notebook I'm going to explain a potential issue I see with the first idea, which I will refer to as the PokeBNN quantizer. I will also propose a different implementation that elegantly addresses this issue and allows for scaling factors. The arguments will be primarily mathematical.

# Main Issue: [The FTC won't let me be](https://youtu.be/YVkUvmDQ3HY?t=82)

Derivatives are the fundamental ingredient of backpropagation, and the defining property of derivatives is the [Fundamental Theorem of Calculus](https://en.wikipedia.org/wiki/Fundamental_theorem_of_calculus#Formal_statements). It states that for a function $F$ with derivative $f$, we have
$$
\int_a^b f(x) dx = F(b) - F(a).
$$
When training quantized neural networks, our quantization functions can never perfectly obey this rule, since the derivative of any function with finitely many outputs is zero or undefined everywhere. However, when we construct our approximate gradients, it is best to keep them as close to the fundamentals of derivatives as possible. This is more than mathematical musing: The gradient informs the model approximately how much the function output will change in response to input changes, and that property is captured in its entirety in the above equation.

For example, for the straight-through estimator, we have $f(x) = 1$ and $F(x) = \text{round}(x)$. We can see that
$$
\int_a^b f(x) dx = b - a
$$
differs from
$$
F(b) - F(a) = \text{round}(b) - \text{round}(a)
$$
by at most 1 for all $a$ and $b$.
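
The bound can be checked numerically. The following is a minimal sketch in plain Python rather than TensorFlow; the helper name `ste_gap` and the choice of grid are mine, and Python's built-in `round` (which rounds halves to even) stands in for $\text{round}$.

```python
def ste_gap(a: float, b: float) -> float:
    """|integral of f over [a, b] - (F(b) - F(a))| for f(x) = 1, F(x) = round(x)."""
    integral = b - a                      # integral of the constant function 1
    forward_change = round(b) - round(a)  # change in the rounded output
    return abs(integral - forward_change)

# Multiples of 0.25 in [-5, 5]; these values are exact in binary floating point.
grid = [i * 0.25 for i in range(-20, 21)]
worst_gap = max(ste_gap(a, b) for a in grid for b in grid)

print(f"{worst_gap = }")
print(f"{ste_gap(0.5, 1.5) = }")  # a pair of endpoints that attains the bound
```

The worst case occurs when both endpoints sit on half-integers that round in opposite directions, which is why the bound of 1 is tight.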

The PokeBNN quantizer does poorly when examined in this light. For example, when $a \leq -B$ and $b > B$, we have
$$
\int_a^b \frac{\partial POKE(x)}{\partial x} \, dx = 2B
$$
and
$$
POKE(b) - POKE(a) = 2.
$$
The two quantities agree only when $B = 1$; in general they differ by a factor of $B$, which can be significant.
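
To make the mismatch concrete, here is a small plain-Python sketch. The names `poke_forward`, `poke_grad_integral`, and `bound` are hypothetical, the integral of the indicator is computed in closed form as the length of the overlap between $[a, b]$ and $(-B, B)$, and treating $\text{sign}(0)$ as $+1$ is an assumption on my part.

```python
def poke_forward(x: float) -> float:
    """Forward pass: sign function (sign(0) = +1 is an assumption)."""
    return 1.0 if x >= 0 else -1.0

def poke_grad_integral(a: float, b: float, bound: float) -> float:
    """Integral over [a, b] of the indicator of (-bound, bound), for a <= b."""
    return max(0.0, min(b, bound) - max(a, -bound))

a, b = -10.0, 10.0  # endpoints well outside (-B, B) for every B tried below
for bound in (0.5, 1.0, 3.0):
    integral = poke_grad_integral(a, b, bound)
    forward_change = poke_forward(b) - poke_forward(a)
    print(f"B={bound}: integral={integral}, forward change={forward_change}")
```

Only the $B = 1$ row has matching values; for the other rows the integral is $2B$ while the forward change stays at 2.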




# A [more robust](https://www.youtube.com/watch?v=gAjR4_CbPpQ) implementation

Below I've proposed an alternate definition $POKE'$ that solves both this problem and the problem with binary representation.
$$
POKE'(x) = B \cdot \left(\text{round}\left(\text{clip}\left( \frac{x}{B}, -0.5, 0.5 \right) - 0.5 \right) + 0.5\right)
$$

The forward pass of this function maps all nonnegative values to $B/2$ and all negative values to $-B/2$, with no exceptions. We can compute the derivative as follows:
$$
\frac{\partial POKE'(x)}{\partial x} = \cases{1 & -B/2 < x < B/2 \\ 0 & \text{otherwise.} }
$$
This has the nice property that for all $a \leq -B/2$ and $b > B/2$, we have
$$
\int_a^b \frac{\partial POKE'(x)}{\partial x} \, dx = B
$$
and
$$
POKE'(b) - POKE'(a) = B,
$$
which is in line with the fundamental theorem of calculus.
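
The same kind of check can be sketched for $POKE'$ in plain Python. The helper names are hypothetical, the forward pass is written directly from the description above (nonnegative inputs go to $B/2$, negative inputs to $-B/2$), and the integral is again the length of an overlap, this time with $(-B/2, B/2)$.

```python
def poke_prime_forward(x: float, bound: float) -> float:
    """Forward pass of POKE': +bound/2 for x >= 0, -bound/2 otherwise."""
    return bound / 2 if x >= 0 else -bound / 2

def poke_prime_grad_integral(a: float, b: float, bound: float) -> float:
    """Integral over [a, b] of the indicator of (-bound/2, bound/2), for a <= b."""
    half = bound / 2
    return max(0.0, min(b, half) - max(a, -half))

a, b = -10.0, 10.0  # endpoints outside (-B/2, B/2) for every B tried below
for bound in (0.5, 1.0, 3.0):
    integral = poke_prime_grad_integral(a, b, bound)
    forward_change = poke_prime_forward(b, bound) - poke_prime_forward(a, bound)
    print(f"B={bound}: integral={integral}, forward change={forward_change}")
```

Here both columns equal $B$ for every value tried, so the approximate gradient accounts for the full change in the output.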

We implement this function below.

In [3]:
import tensorflow as tf

def round_through(x):
  """Round function with straight through estimator"""

  return x + tf.stop_gradient(tf.round(x) - x)

def poke_prime(x, b):
  """POKE' auto-quantization function"""

  clipped_scaled_x = tf.clip_by_value(x / b, -0.5, 0.5)
  xq_scaled = round_through(clipped_scaled_x - 0.5) + 0.5
  xq = b * xq_scaled

  return xq

def poke_prime_autoscale(x):
  """Auto-scaling for POKE'

  Note: The auto-scaling computation may differ slightly here, depending on
  the goal. Since the clip bounds have changed to 0.5, it may be best to set $B$
  as twice the max absolute value of the data."""

  b = 2 * tf.reduce_max(tf.abs(x))
  return poke_prime(x, b)

# basic computations
x = tf.Variable([-5.0, -1.5, 0.0, 1.0, 6.0])
print(f'{x.numpy() = }')
print(f'{poke_prime_autoscale(x).numpy() = }')

# derivative computations
b = tf.constant(2.0)
with tf.GradientTape() as tape:
    y = poke_prime(x, b)
grads = tape.gradient(y, x)
print(f'For {b.numpy() = }, {grads.numpy() = }')

x.numpy() = array([-5. , -1.5,  0. ,  1. ,  6. ], dtype=float32)
poke_prime_autoscale(x).numpy() = array([-6., -6.,  6.,  6.,  6.], dtype=float32)
For b.numpy() = 2.0, grads.numpy() = array([0., 0., 1., 1., 0.], dtype=float32)
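
As a cross-check of the printed output, here is a framework-free sketch of the same computation. The function names are hypothetical and this is not part of the proposed implementation. Python's built-in `round` and `tf.round` both round halves to even, which is what sends $x = 0$ to $+B/2$. The gradient mask uses inclusive bounds as an assumption chosen to reproduce the printed gradient of 1 at $x = 1.0$, which sits exactly on $B/2$ when $B = 2$.

```python
def poke_prime_reference(x: float, b: float) -> float:
    """Plain-Python forward pass of POKE', following the formula term by term."""
    clipped = min(max(x / b, -0.5), 0.5)
    return b * (round(clipped - 0.5) + 0.5)

def poke_prime_grad_reference(x: float, b: float) -> float:
    """Gradient mask; inclusive bounds are an assumption matching the output above."""
    return 1.0 if -0.5 <= x / b <= 0.5 else 0.0

xs = [-5.0, -1.5, 0.0, 1.0, 6.0]

# Auto-scaling: B is twice the max absolute value, as in poke_prime_autoscale.
b_auto = 2 * max(abs(x) for x in xs)
forward = [poke_prime_reference(x, b_auto) for x in xs]
grads = [poke_prime_grad_reference(x, 2.0) for x in xs]

print(f"{forward = }")
print(f"{grads = }")
```

Both lists agree with the arrays printed by the TensorFlow version.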


# Conclusion

It appears that this implementation is more mathematically sound, but we have yet to see if it actually improves model performance. I have an outstanding [PR](https://github.com/google/qkeras/pull/118/files#diff-77a8e3576b51d24701cad688de6b996e0ba87c6f569194ebf016e62745d936aeR816-R864) for qkeras that, once merged, will allow us to run experiments using that framework.

It would also be interesting to see this implemented in PokeBNN. Given the impressive results seen so far by that framework, it would be great to see if this change could bring about further improvements.

# Appendix: [Another Perspective](https://www.youtube.com/watch?v=d3sA5plF6kE) on $POKE'$

Here is another way to think about what the gradient of the quantizer "should" be. Perhaps the simplest (almost) smooth approximation to a function that maps all positive values to $B/2$ and all negative values to $-B/2$ is
$$
POKE^{\text{approx}}(x) = \text{clip}(x, -B/2, B/2).
$$
The derivative of this function is
$$
\frac{\partial POKE^{\text{approx}}(x)}{\partial x} = \cases{1 & -B/2 < x < B/2 \\ 0 & \text{otherwise.} }
$$
This is the same as the derivative of $POKE'$.
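
This perspective also gives a bound for arbitrary endpoints $a < b$, not only those outside $(-B/2, B/2)$. The following is a short derivation sketch using only the definitions above. The clipped function satisfies the fundamental theorem of calculus exactly, and it never differs from $POKE'$ by more than $B/2$:
$$
\int_a^b \frac{\partial POKE^{\text{approx}}(x)}{\partial x} \, dx = POKE^{\text{approx}}(b) - POKE^{\text{approx}}(a),
\qquad
\left| POKE'(x) - POKE^{\text{approx}}(x) \right| \leq \frac{B}{2}.
$$
Since the two functions share the same derivative, applying the second inequality at both endpoints gives
$$
\left| \int_a^b \frac{\partial POKE'(x)}{\partial x} \, dx - \left( POKE'(b) - POKE'(a) \right) \right| \leq B.
$$
This mirrors the bound of 1 for the straight through estimator, scaled by the spacing $B$ between the two output levels.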
