In [1]:
import tensorflow as tf
import tensorflow.contrib.eager as tfe
import tensorflow.keras as keras
import numpy as np
from riptide.binary.bit_approximations import load_clusters, load_bits
tf.enable_eager_execution()

In [2]:
def get_quantize_bits(x):
    """Split ``x`` into a magnitude scale and a tensor of +/-1 signs.

    Parameters
    ----------
    x : tensor
        Values to binarize.

    Returns
    -------
    mean : tensor
        For inputs of rank > 1, the mean of ``|x|`` taken over everything
        except the first axis, with trailing singleton axes appended so it
        broadcasts against ``x``. Otherwise the scalar mean of ``|x|``.
    bits : tensor
        float32 tensor holding +1 where ``x >= 0`` and -1 elsewhere.
    """
    rank = len(x.shape)
    magnitude = tf.abs(x)

    if rank > 1:
        # Flatten everything but the leading axis so each leading-axis slice
        # gets its own scale.
        per_slice = tf.reshape(magnitude, [x.shape[0], -1])
        mean = tf.reduce_mean(per_slice, axis=-1)
    else:
        mean = tf.reduce_mean(magnitude)

    # Restore trailing axes so `mean * bits` broadcasts correctly.
    for _ in range(rank - 1):
        mean = tf.expand_dims(mean, axis=-1)

    # Map {False, True} -> {-1, +1}.
    bits = (2 * tf.cast(x >= 0, tf.float32)) - 1
    return mean, bits

@tf.custom_gradient
def Quantize(x):
    """Binarize ``x`` to ``mean * sign`` with a clipped pass-through gradient.

    The forward value is the product of the scale and the +/-1 bits returned
    by ``get_quantize_bits``. The backward pass forwards the incoming gradient
    unchanged wherever ``|x| <= 1`` and zeroes it elsewhere.
    """
    scale, signs = get_quantize_bits(x)

    def grad_fn(dy):
        # Only elements with magnitude at most 1 receive gradient.
        passthrough = tf.cast(tf.abs(x) <= 1, tf.float32)
        return [dy * passthrough]

    return scale * signs, grad_fn

def get_HWGQ_bits(x, clusters):
    """Return, for every element of ``x``, the index of its nearest cluster.

    Parameters
    ----------
    x : tensor
        Values to quantize.
    clusters : array-like
        Candidate quantization levels.
        NOTE(review): presumably 1-D (one value per level) -- confirm.

    Returns
    -------
    tensor
        Integer indices into the last axis of ``clusters``, with the same
        shape as ``x``, selecting the level with the smallest absolute
        difference.
    """
    rank = len(x.shape)

    # Prepend one singleton axis per data axis so the cluster values sit on a
    # new trailing axis and broadcast against every element of x.
    levels = clusters
    for _ in range(rank):
        levels = tf.expand_dims(levels, axis=0)

    # Give x a matching trailing axis, then measure the gap to each level.
    gaps = tf.abs(tf.expand_dims(x, axis=-1) - levels)
    return tf.argmin(gaps, axis=-1)

@tf.custom_gradient
def HWGQuantize(x, clusters):
    """Snap every element of ``x`` to its nearest value in ``clusters``.

    Forward: looks up the nearest-cluster index via ``get_HWGQ_bits`` and
    gathers the corresponding cluster value.

    Backward: the incoming gradient is passed through for elements lying in
    ``[min(clusters), max(clusters)]`` and zeroed outside that range. No
    gradient is produced for ``clusters``.
    """
    nearest = get_HWGQ_bits(x, clusters)

    def grad_fn(dy):
        lower = tf.reduce_min(clusters)
        upper = tf.reduce_max(clusters)
        # Gradient only flows for inputs inside the representable range.
        in_range = tf.logical_and(lower <= x, x <= upper)
        mask = tf.cast(in_range, tf.float32)
        return [dy * mask, None]

    return tf.gather(clusters, nearest), grad_fn

In [13]:
from tensorflow.python.ops import nn

"""Quantization scope, defines the modification of operator"""
class Config(object):
    """Configuration scope selecting the quantizers used by new layers.

    Used as a context manager: while the ``with`` block is active the instance
    is available as ``Config.current``, and layers constructed inside the block
    read their quantizers from it. Scopes may be nested; leaving a block
    restores whichever Config was active before it.

    Parameters
    ----------
    actQ : callable, optional
        Activation quantizer, called as ``actQ(x)`` or ``actQ(x, clusters)``.
        Defaults to the identity on ``x``.

    weightQ : callable, optional
        Weight quantizer, called as ``weightQ(kernel)``. Defaults to the
        identity.

    activation : optional
        Stored on the instance as ``self.activation``; not otherwise
        interpreted by this class.

    clusters : optional
        Quantization levels handed to ``actQ`` as its second argument when not
        None.

    Example
    -------
    with Config(actQ=HWGQuantize, weightQ=Quantize, clusters=clusters):
        op = Conv2D(8, 1, use_bias=False)
    """
    current = None

    def __init__(self,
                 actQ=None,
                 weightQ=None,
                 activation=None,
                 clusters=None):
        def identity(x, *unused):
            # The quantizers are invoked with the tensor first and, for
            # activations, optionally the clusters second; the default must
            # accept both call shapes and always return the tensor.
            return x

        self.actQ = actQ if actQ else identity
        self.weightQ = weightQ if weightQ else identity
        # Previously accepted but silently dropped.
        self.activation = activation
        self.clusters = clusters

    def __enter__(self):
        # Remember the enclosing scope so nested `with` blocks unwind properly.
        self._old_manager = Config.current
        Config.current = self
        return self

    def __exit__(self, ptype, value, trace):
        Config.current = self._old_manager
        
class Conv2D(keras.layers.Conv2D):
    """Keras ``Conv2D`` that quantizes its inputs and kernel before convolving.

    The quantizers and clusters are captured from ``Config.current`` when the
    layer is constructed, so the layer must be created inside a
    ``with Config(...)`` block. Constructing it with no active Config fails
    when the quantizers are looked up.
    """

    def __init__(self, *args, **kwargs):
        super(Conv2D, self).__init__(*args, **kwargs)
        self.scope = Config.current
        self.actQ = self.scope.actQ
        self.weightQ = self.scope.weightQ
        self.clusters = self.scope.clusters

    def call(self, inputs):
        """Quantize ``inputs`` and the kernel, convolve, then add bias/activation.

        When the captured Config supplied clusters they are passed to the
        activation quantizer as its second argument; otherwise the quantizer
        is called with the inputs only.
        """
        if self.clusters is not None:
            # Use the clusters captured from this layer's Config rather than
            # whatever module-level `clusters` happens to exist.
            inputs = self.actQ(inputs, self.clusters)
        else:
            inputs = self.actQ(inputs)
        kernel = self.weightQ(self.kernel)
        outputs = self._convolution_op(inputs, kernel)

        if self.use_bias:
            bias_format = 'NCHW' if self.data_format == 'channels_first' else 'NHWC'
            outputs = nn.bias_add(outputs, self.bias, data_format=bias_format)

        if self.activation is not None:
            return self.activation(outputs)
        return outputs

In [29]:
# Build one quantized 1x1 convolution with 8 filters: activations go through
# HWGQuantize using the loaded clusters, weights through Quantize.
clusters = load_clusters(1)
actQ, weightQ = HWGQuantize, Quantize

# The layer reads its quantizers from the active Config at construction time.
with Config(actQ=actQ, weightQ=weightQ, clusters=clusters) as config:
    op = Conv2D(8, 1, use_bias=False)

In [30]:
# Non-negative random input for the quantized convolution (no seed is set, so
# values differ between runs).
# NOTE(review): `op` was built without an explicit data_format, so the shape
# [4, 3, 16, 16] is presumably read as channels-last (16 channels) -- confirm
# that this is the intended layout.
conv_data = tf.abs(tf.random_normal(shape=[4, 3, 16, 16]))
out = op(conv_data)

In [34]:
# Materialize the convolution output as a NumPy array.
# The first cell enables eager execution, under which tensors already hold
# values and cannot be fetched through a Session; fall back to the Session
# path only when running in graph mode.
if tf.executing_eagerly():
    out_np = out.numpy()
else:
    with tf.Session() as sess:
        # The layer's kernel must be initialized before it can be evaluated.
        sess.run(tf.global_variables_initializer())
        out_np = sess.run(out)

In [35]:
# Display the NumPy result of the quantized convolution evaluated above.
out_np

array([[[[ 7.50138104e-01, -1.25902057e-01,  1.65894186e+00,
          -1.29762125e+00, -8.66185069e-01, -9.32881236e-01,
           2.93057382e-01, -5.90425313e-01],
         [ 3.56550038e-01,  6.63847327e-01, -6.21592402e-01,
          -9.93205726e-01, -9.97219443e-01,  1.38616562e-03,
          -1.42939091e-02,  5.19096971e-01],
         [-2.51245499e-03,  4.70334321e-01,  1.29171550e+00,
          -7.15964675e-01, -6.44425273e-01, -9.72674131e-01,
           1.02150381e-01,  1.29098856e+00],
         [-9.78860378e-01,  6.49110079e-02,  4.36541498e-01,
          -2.40718460e+00, -1.96089756e+00, -2.07617059e-01,
           4.11892176e-01,  1.26819134e+00],
         [-2.41220117e-01,  9.49452162e-01,  1.45131433e+00,
          -2.28948450e+00, -2.00630844e-01, -1.35889220e+00,
          -6.89986348e-02, -2.97427177e-03],
         [ 1.01694250e+00, -9.83077437e-02, -2.16761991e-01,
          -1.32944751e+00, -1.82339042e-01, -2.05806673e-01,
           2.57584959e-01,  1.31746829e+00]

In [3]:
# Small non-negative random input used to inspect HWGQuantize's gradient
# (no seed is set, so values differ between runs).
data = tf.abs(tf.random_normal(shape=[2, 10]))

In [4]:
# Quantize `data` and take d(output)/d(data) to inspect the custom gradient.
# `tf.gradients` only works in graph mode; with eager execution enabled (as in
# the first cell) the gradient has to be recorded with a GradientTape instead.
# Either way `dy` is a one-element list, matching what `tf.gradients` returns.
if tf.executing_eagerly():
    with tf.GradientTape() as tape:
        # `data` is a plain tensor, not a variable, so it must be watched.
        tape.watch(data)
        output = HWGQuantize(data, clusters)
    dy = [tape.gradient(output, data)]
else:
    output = HWGQuantize(data, clusters)
    dy = tf.gradients(output, data)

In [11]:
# Fetch the input and its gradient as NumPy values.
# Eager tensors already hold concrete values; a Session is needed only in
# graph mode, where both are fetched in one run so the gradient corresponds
# to the same random sample of `data`.
if tf.executing_eagerly():
    data_np = data.numpy()
    grad_np = [grad.numpy() for grad in dy]
else:
    with tf.Session() as sess:
        data_np, grad_np = sess.run([data, dy])

In [12]:
# Display the sampled input values that were fed to HWGQuantize.
data_np

array([[0.14214745, 0.50518364, 0.02391862, 1.0399977 , 0.73923993,
        0.389728  , 0.4237679 , 0.11800878, 1.5546753 , 0.34932214],
       [0.7204177 , 0.21333353, 0.5899574 , 1.9131157 , 1.7082008 ,
        0.3200416 , 1.3184633 , 0.6811728 , 1.5945169 , 2.2726355 ]],
      dtype=float32)

In [13]:
# Display the gradient of the HWGQuantize output with respect to `data`.
grad_np

[array([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 0., 1., 1., 1., 1., 1., 0.]], dtype=float32)]

In [3]:
# Inspect the cluster values returned by the same call used to build `clusters`.
load_clusters(1)

array([0.        , 1.42882085])