/
quantizers.py
1756 lines (1452 loc) · 56.8 KB
/
quantizers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2019 Google LLC
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import warnings
import numpy as np
import six
from six.moves import range
import tensorflow.compat.v2 as tf
from tensorflow.keras import initializers
import tensorflow.keras.backend as K
from tensorflow.keras.utils import deserialize_keras_object
from tensorflow.python.keras.utils import tf_utils
from .safe_eval import safe_eval
#
# Library of auxiliary functions
#
def get_weight_scale(quantizer, x=None):
  """Returns the weight scale recorded on a quantizer, or 1.0 if there is none.

  Arguments:
    quantizer: A quantizer object, e.g. a (stochastic_)binary or ternary
      quantizer. Its `scale` attribute is read when present.
    x: A weight tensor. Unused; kept for backward compatibility.

  Returns:
    `K.eval(quantizer.scale)` when the quantizer has a `scale` attribute that
    is not None, otherwise 1.0.
  """
  recorded_scale = getattr(quantizer, "scale", None)
  if recorded_scale is None:
    return 1.0
  return K.eval(recorded_scale)
def _get_scale(alpha, x, q):
  """Gets the scaling factor for scaling the tensor per channel.

  Arguments:
    alpha: None, a number, a numpy array, or one of the strings "auto" /
      "auto_po2". For the string modes,
        scale = mean(x * q, axis) / (mean(q * q, axis) + epsilon)
      where axis covers every axis but the last for "channels_last" and every
      axis but the first otherwise. "auto_po2" additionally rounds the scale
      to the nearest power of two.
    x: A tensor object. Its elements are in float.
    q: A tensor object. Its elements are in quantized format of x.

  Returns:
    A scaling factor tensor or scalar: 1.0 for None, the array itself for a
    numpy array, float(alpha) for any other non-"auto" value.
  """
  auto_mode = isinstance(alpha, six.string_types) and "auto" in alpha
  if not auto_mode:
    if alpha is None:
      return 1.0
    if isinstance(alpha, np.ndarray):
      return alpha
    return float(alpha)

  assert alpha in ["auto", "auto_po2"]

  rank = len(x.shape.as_list())
  if rank > 1:
    if K.image_data_format() == "channels_last":
      reduce_axes = list(range(rank - 1))
    else:
      reduce_axes = list(range(1, rank))
  else:
    reduce_axes = 0

  xq_mean = K.mean(tf.math.multiply(x, q), axis=reduce_axes, keepdims=True)
  qq_mean = K.mean(tf.math.multiply(q, q), axis=reduce_axes, keepdims=True)
  channel_scale = xq_mean / (qq_mean + K.epsilon())

  if alpha == "auto_po2":
    # Snap the scale to the closest power of two (rounding in log2 space).
    log2_scale = K.log(channel_scale + K.epsilon()) / np.log(2.0)
    channel_scale = K.pow(2.0, tf.math.round(log2_scale))
  return channel_scale
def smooth_sigmoid(x):
  """Linear sigmoid approximation 0.1875 * x + 0.5, clipped to [0, 1]."""
  # Author's note: with a clipping point around 2.65 the MSE w.r.t. the real
  # sigmoid is smaller than for hard_sigmoid, and the slope 0.1875 can be
  # computed as (x >> 3) + (x >> 4), so the arithmetic stays cheap.
  linear_part = 0.1875 * x + 0.5
  return tf.keras.backend.clip(linear_part, 0.0, 1.0)
def hard_sigmoid(x):
  """Linear sigmoid approximation 0.5 * x + 0.5, saturating at 0 and 1."""
  linear_part = 0.5 * x + 0.5
  return tf.keras.backend.clip(linear_part, 0.0, 1.0)
def binary_sigmoid(x):
  """Rounds hard_sigmoid(x) using the straight-through rounding helper."""
  saturated = hard_sigmoid(x)
  return _round_through(saturated)
# Approximated sigmoid used by the quantizers in this file whenever they are
# configured not to use the real sigmoid. It defaults to hard_sigmoid and can
# be rebound to another implementation with set_internal_sigmoid().
_sigmoid = hard_sigmoid
def set_internal_sigmoid(mode):
  """Sets the module-level `_sigmoid` to the real, hard or smooth variant.

  Arguments:
    mode: One of "real" (tf.keras.backend.sigmoid), "hard" (hard_sigmoid) or
      "smooth" (smooth_sigmoid).

  Raises:
    ValueError: If mode is not one of the three supported names.
  """
  global _sigmoid

  if mode not in ("real", "hard", "smooth"):
    # The message names every accepted mode; "real" used to be missing.
    raise ValueError("mode has to be 'real', 'hard' or 'smooth'.")

  if mode == "hard":
    _sigmoid = hard_sigmoid
  elif mode == "smooth":
    _sigmoid = smooth_sigmoid
  else:
    _sigmoid = tf.keras.backend.sigmoid
def binary_tanh(x):
  """Rescales binary_sigmoid(x) as 2 * b - 1 so that it outputs -1 and 1."""
  bit = binary_sigmoid(x)
  return 2.0 * bit - 1.0
def hard_tanh(x):
  """Rescales hard_sigmoid(x) as 2 * s - 1, saturating between -1 and 1."""
  saturated = hard_sigmoid(x)
  return 2.0 * saturated - 1.0
def smooth_tanh(x):
  """Rescales smooth_sigmoid(x) as 2 * s - 1, saturating between -1 and 1."""
  saturated = smooth_sigmoid(x)
  return 2.0 * saturated - 1.0
def stochastic_round(x, precision=0.5):
  """Stochastically rounds x to a neighbouring multiple of `precision`.

  x is expressed in units of `precision`; a uniform random draw is compared
  against the fractional part to choose between the floor and the ceiling,
  and the chosen value is converted back to the original units.

  Arguments:
    x: tensor to round.
    precision: spacing of the rounding grid.

  Returns:
    Tensor of the same shape as x holding the rounded values.
  """
  units_per_one = 1.0 / precision
  in_units = x * units_per_one
  frac = in_units - tf.math.floor(in_units)
  # The larger the fractional part, the less likely the draw exceeds it, so
  # the ceiling is picked with probability equal to the fractional part.
  take_floor = frac < tf.random.uniform(tf.shape(x))
  floor_val = tf.math.floor(in_units)
  ceil_val = tf.math.ceil(in_units)
  rounded = tf.where(take_floor, floor_val, ceil_val)
  return rounded / units_per_one
def stochastic_round_po2(x):
  """Performs stochastic rounding for the power of two.

  For each element, y = |x| is bracketed by two consecutive powers of two. A
  value is sampled uniformly between them and compared against y to select
  the lower or the upper exponent.

  Arguments:
    x: tensor whose magnitudes are rounded.

  Returns:
    Tensor of the selected exponents (not the powers themselves); the sign of
    x is not applied.
  """
  # TODO(hzhuang): test stochastic_round_po2 and constraint.
  # because quantizer is applied after constraint.
  magnitude = tf.abs(x)
  epsilon = tf.keras.backend.epsilon()
  ln_two = tf.keras.backend.log(2.0)
  nearest_exp = tf.round(tf.keras.backend.log(magnitude + epsilon) / ln_two)
  nearest_po2 = tf.cast(
      pow(2.0, tf.cast(nearest_exp, dtype="float32")), dtype="float32")

  # When the nearest power of two overshoots y the bracket is
  # [nearest_exp - 1, nearest_exp], otherwise [nearest_exp, nearest_exp + 1].
  lower_exp = tf.where(nearest_po2 > magnitude, nearest_exp - 1, nearest_exp)
  upper_exp = tf.where(nearest_po2 > magnitude, nearest_exp, nearest_exp + 1)

  # sampling in [2**lower_exp, 2**upper_exp].
  low_bound = 2 ** lower_exp
  high_bound = 2 ** upper_exp
  sample = tf.random.uniform(
      tf.shape(magnitude), minval=low_bound, maxval=high_bound)

  # use y as a threshold so that the probabilities of picking the lower and
  # the upper bound make the mean of the rounded value equal to y.
  chosen_exp = tf.where(magnitude < sample, lower_exp, upper_exp)
  return chosen_exp
def _round_through(x, use_stochastic_rounding=False, precision=0.5):
  """Rounds x using the straight-through estimator (STE).

  We use the trick from [Sergey Ioffe](http://stackoverflow.com/a/36480182):
  the result is written as x + stop_gradient(round(x) - x), so the forward
  value is the rounded one while the gradient flows as if the op were the
  identity.

  The straight-through estimator is a biased estimator for the rounding
  operation described in Hinton's Coursera Lecture 9c, where dL/dx is made
  equal to dL/dy for y = f(x) during gradient computation, f(x) being a
  non-differentiable function. That is, df/dx = 1 is assumed in:

  dL   dL df   dL
  -- = -- -- = --
  dx   df dx   dy

  (https://www.youtube.com/watch?v=LN0xtUuJsEI&list=PLoRl3Ht4JOcdU872GhiYWf6jwrk_SNhz9&index=41)

  Arguments:
    x: tensor to perform round operation with straight through gradient.
    use_stochastic_rounding: if true, stochastic rounding is used while
      K.learning_phase() selects training; plain rounding is used otherwise.
    precision: by default we will use 0.5 as precision, but that can be
      overridden by the user. Only used by the stochastic path.

  Returns:
    Rounded tensor.
  """

  def _deterministic():
    return x + tf.stop_gradient(-x + tf.round(x))

  if not use_stochastic_rounding:
    return _deterministic()

  def _stochastic():
    return x + tf.stop_gradient(-x + stochastic_round(x, precision))

  return tf_utils.smart_cond(K.learning_phase(), _stochastic, _deterministic)
def _sign_through(x):
  """Applies tf.sign to x using the straight-through estimator."""
  # tf.sign produces -1, 0 or +1, so this helper is not suitable when the
  # output must be strictly -1 or +1.
  sign_of_x = tf.sign(x)
  forward_correction = tf.stop_gradient(-x + sign_of_x)
  return x + forward_correction
def _ceil_through(x):
  """Computes the ceiling operation using the straight-through estimator.

  Arguments:
    x: tensor to apply the ceiling to.

  Returns:
    Tensor whose forward value is ceil(x) while the gradient is passed
    through unchanged.
  """
  # `tf` is bound to tensorflow.compat.v2 in this file, which exposes the
  # ceiling op as tf.math.ceil; there is no top-level tf.ceil alias.
  return x + tf.stop_gradient(-x + tf.math.ceil(x))
#
# Activation functions for quantized networks.
#
# Please note some of these functions can be used as well
# as quantizer functions for weights of dense and convolutional
# layers.
#
class BaseQuantizer(object):
  """Base class for quantizers.

  Defines the behavior all quantizers should follow. Both methods are no-ops
  here and are meant to be overridden by subclasses when needed.
  """

  def __init__(self):
    """Does nothing; exists so that subclasses can chain to it."""

  def _set_trainable_parameter(self):
    """Hook that has no effect in the base class."""
    return None
class quantized_bits(BaseQuantizer):  # pylint: disable=invalid-name
  """Quantizes the number to a number of bits.

  In general, we want to use a quantization function like:

  a = (pow(2,bits) - 1 - 0) / (max(x) - min(x))
  b = -min(x) * a

  in the equation:

  xq = a x + b

  This requires multiplication, which is undesirable. So, we
  enforce weights to be between -1 and 1 (max(x) = 1 and min(x) = -1),
  and separating the sign from the rest of the number as we make this function
  symmetric, thus resulting in the following approximation.

  1) max(x) = +1, min(x) = -1
  2) max(x) = -min(x)

  a = pow(2,bits-1)
  b = 0

  Finally, just remember that to represent the number with sign, the
  largest representation is -pow(2,bits) to pow(2, bits-1)

  Symmetric and keep_negative allow us to generate numbers that are symmetric
  (same number of negative and positive representations), and numbers that
  are positive.

  Note:
    the behavior of quantized_bits is different than Catapult HLS ac_fixed
    or Vivado HLS ap_fixed. For ac_fixed<word_length, integer_length, signed>,
    when signed = true, it is equivalent to
    quantized_bits(word_length, integer_length-1, keep_negative=1)

  Attributes:
    bits: number of bits to perform quantization.
    integer: number of bits to the left of the decimal point.
    symmetric: if true, we will have the same number of values for positive
      and negative numbers. Forced to True when alpha is a string.
    alpha: a tensor, "auto", "auto_po2" or None; the scaling factor per
      channel. If None, the scaling factor is 1 for all channels.
    keep_negative: if true, we do not clip negative numbers.
    use_stochastic_rounding: if true, we perform stochastic rounding.
    scale: the scale used by the most recent call, or None before any call.

  Returns:
    Function that computes fixed-point quantization with bits.
  """

  def __init__(self, bits=8, integer=0, symmetric=0, keep_negative=1,
               alpha=None, use_stochastic_rounding=False):
    super(quantized_bits, self).__init__()
    self.bits = bits
    self.integer = integer
    self.symmetric = symmetric
    self.keep_negative = (keep_negative > 0)
    self.alpha = alpha
    self.use_stochastic_rounding = use_stochastic_rounding
    # "auto*" |-> symmetric
    if isinstance(self.alpha, six.string_types):
      self.symmetric = True
    self.scale = None

  def __str__(self):
    """Returns a constructor-like string listing the non-default flags."""
    flags = [str(self.bits), str(self.integer), str(int(self.symmetric))]
    if not self.keep_negative:
      flags.append("keep_negative=" + str(int(self.keep_negative)))
    # Compare against None instead of relying on truthiness: the truth value
    # of a multi-element array or tensor alpha is ambiguous and would raise.
    if self.alpha is not None:
      alpha = str(self.alpha)
      if isinstance(self.alpha, six.string_types):
        alpha = "'" + alpha + "'"
      flags.append("alpha=" + alpha)
    if self.use_stochastic_rounding:
      flags.append("use_stochastic_rounding=" +
                   str(int(self.use_stochastic_rounding)))
    return "quantized_bits(" + ",".join(flags) + ")"

  def __call__(self, x):
    """Computes fixedpoint quantization of x.

    The scale that was used is stored in `self.scale` as a side effect.

    Arguments:
      x: tensor to quantize.

    Returns:
      A tensor whose forward value is the quantized x and whose gradient is
      that of x (straight-through).
    """
    # quantized_bits with "1" bit becomes a binary implementation.
    unsigned_bits = self.bits - self.keep_negative
    m = pow(2, unsigned_bits)
    m_i = pow(2, self.integer)

    if self.alpha is None:
      scale = 1.0
    elif isinstance(self.alpha, six.string_types):
      # We only deal with the symmetric case right now.
      assert self.symmetric
      len_axis = len(x.shape)
      if len_axis > 1:
        if K.image_data_format() == "channels_last":
          axis = list(range(len_axis - 1))
        else:
          axis = list(range(1, len_axis))
      else:
        axis = [0]

      x = x / m_i

      # we will use this implementation for the scale for QKeras 0.7
      levels = 2**self.bits - 1
      scale = (K.max(x, axis=axis, keepdims=True) -
               K.min(x, axis=axis, keepdims=True)) / levels
      if "po2" in self.alpha:
        scale = K.pow(2.0,
                      tf.math.round(K.log(scale + K.epsilon()) / np.log(2.0)))
        # Alternate between quantizing with the current scale and refitting
        # a power-of-two scale to the quantized values.
        for _ in range(5):
          v = tf.floor(tf.abs(x) / scale + 0.5)
          mask = v < (levels - 1) / 2
          z = tf.sign(x) * tf.where(mask, v, tf.ones_like(v) * (levels - 1) / 2)
          scale = _get_scale("auto_po2", x, z)
      else:
        # "auto" keeps the floating-point scale computed from the value
        # range, so the integer levels are derived from it once. Without this
        # branch `z` would be unbound when it is used below.
        v = tf.floor(tf.abs(x) / scale + 0.5)
        mask = v < (levels - 1) / 2
        z = tf.sign(x) * tf.where(mask, v, tf.ones_like(v) * (levels - 1) / 2)

      # z is an integer number, so we must make the scale * m and z / m
      scale = scale * m

      # we will not use "z" right now because of stochastic_rounding
      # this is still under test.

      # if "new" in self.alpha:
      #   z = z / m
      #   self.scale = scale
      #   return x + tf.stop_gradient(-x + scale * z)
      x = m_i * x
      xq = m_i * z / m
      self.scale = scale
      return x + tf.stop_gradient(-x + scale * xq)
    else:
      scale = self.alpha

    # quantized_bits with "1" bit becomes a binary implementation.
    if unsigned_bits > 0:
      p = x * m / m_i
      xq = m_i * tf.keras.backend.clip(
          _round_through(p, self.use_stochastic_rounding, precision=1.0),
          self.keep_negative * (-m + self.symmetric), m - 1) / m
    else:
      xq = tf.sign(x)
      # tf.sign maps 0 to 0; push those entries to +1.
      xq += (1.0 - tf.abs(xq))
      if not self.keep_negative:
        xq = (xq + 1.0) / 2.0
    self.scale = scale
    return x + tf.stop_gradient(-x + scale * xq)

  def _set_trainable_parameter(self):
    """Switches an unset alpha to "auto_po2" and forces symmetric mode."""
    if self.alpha is None:
      self.alpha = "auto_po2"
      self.symmetric = True

  def max(self):
    """Get maximum value that quantized_bits class can represent."""
    unsigned_bits = self.bits - self.keep_negative
    if unsigned_bits > 0:
      return max(1.0, np.power(2.0, self.integer))
    else:
      return 1.0

  def min(self):
    """Get minimum value that quantized_bits class can represent."""
    if not self.keep_negative:
      return 0.0
    unsigned_bits = self.bits - self.keep_negative
    if unsigned_bits > 0:
      return -max(1.0, np.power(2.0, self.integer))
    else:
      return -1.0

  @classmethod
  def from_config(cls, config):
    """Builds a quantizer from a dictionary produced by get_config()."""
    return cls(**config)

  def get_config(self):
    """Returns the constructor arguments as a dictionary."""
    config = {
        "bits": self.bits,
        "integer": self.integer,
        "symmetric": self.symmetric,
        "alpha": self.alpha,
        "keep_negative": self.keep_negative,
        "use_stochastic_rounding": self.use_stochastic_rounding
    }
    return config
class bernoulli(object):  # pylint: disable=invalid-name
  """Computes a Bernoulli sample with probability sigmoid(x).

  This computation uses ST approximation.

  To do that, we compute sigmoid(x) and a random sample z ~ U[0,1]. As
  p in [0,1] and z in [0,1], p - z in [-1,1]. However, -1 will
  never appear because to get -1 we would need sigmoid(-inf) - z == 1.
  As a result, the range will be in practical terms [0,1].

  The noise introduced by z can be seen as a regularizer to the weights W of
  y = Wx as y = Wx + Wz for some noise z with mean mu(z) and var(z). As a
  result, W**2 var(z) is added to the variance of y, which has the same effect
  as a regularizer on L2 with lambda = var(z), as presented in Hinton's
  Coursera Lecture 9c.

  Remember that E[dL/dy] = E[dL/dx] once we add stochastic sampling.

  Attributes:
    alpha: allows one to specify multiplicative factor for number generation
      of "auto" or "auto_po2".
    temperature: amplifier factor for sigmoid function, making stochastic
      less stochastic as it moves away from 0.
    use_real_sigmoid: use real sigmoid for probability.
    scale: the scale used by the most recent call, or None before any call.

  Returns:
    Computation of round with stochastic sampling with straight through
    gradient.
  """

  def __init__(self, alpha=None, temperature=6.0, use_real_sigmoid=True):
    self.alpha = alpha
    self.bits = 1
    self.temperature = temperature
    self.use_real_sigmoid = use_real_sigmoid
    self.default_alpha = 1.0
    self.scale = None

  def __str__(self):
    """Returns a constructor-like string listing the non-default flags."""
    flags = []
    if self.alpha is not None:
      alpha = str(self.alpha)
      if isinstance(self.alpha, six.string_types):
        alpha = "'" + alpha + "'"
      flags.append("alpha=" + alpha)
    if self.temperature != 6.0:
      flags.append("temperature=" + str(self.temperature))
    if not self.use_real_sigmoid:
      flags.append("use_real_sigmoid=" + str(int(self.use_real_sigmoid)))
    return "bernoulli(" + ",".join(flags) + ")"

  def __call__(self, x):
    """Samples a {0, 1} tensor from x and scales it.

    The scale that was used is stored in `self.scale` as a side effect.

    Arguments:
      x: tensor to quantize.

    Returns:
      A tensor whose forward value is scale * sample and whose gradient is
      that of x (straight-through).
    """
    if isinstance(self.alpha, six.string_types):
      assert self.alpha in ["auto", "auto_po2"]
      # In the "auto" modes x is normalized by its per-channel standard
      # deviation before being fed to the sigmoid.
      len_axis = len(x.shape)
      if len_axis > 1:
        if K.image_data_format() == "channels_last":
          axis = list(range(len_axis - 1))
        else:
          axis = list(range(1, len_axis))
      else:
        axis = [0]
      std = K.std(x, axis=axis, keepdims=True) + K.epsilon()
    else:
      std = 1.0

    if self.use_real_sigmoid:
      p = tf.keras.backend.sigmoid(self.temperature * x / std)
    else:
      p = _sigmoid(self.temperature * x/std)
    r = tf.random.uniform(tf.shape(x))
    # sign(p - r) is mapped from {-1, 0, +1} to {0, 1}, zeros counting as 1.
    q = tf.sign(p - r)
    q += (1.0 - tf.abs(q))
    q = (q + 1.0) / 2.0

    q_non_stochastic = tf.sign(x)
    q_non_stochastic += (1.0 - tf.abs(q_non_stochastic))
    q_non_stochastic = (q_non_stochastic + 1.0) / 2.0

    # if we use non stochastic binary to compute alpha,
    # this function seems to behave better
    scale = _get_scale(self.alpha, x, q_non_stochastic)
    self.scale = scale
    return x + tf.stop_gradient(-x + scale * q)

  def _set_trainable_parameter(self):
    """Switches an unset alpha to "auto_po2"."""
    if self.alpha is None:
      self.alpha = "auto_po2"

  def max(self):
    """Get the maximum value bernoulli class can represent."""
    if self.alpha is None or isinstance(self.alpha, six.string_types):
      return 1.0
    else:
      return max(1.0, self.alpha)

  def min(self):
    """Get the minimum value bernoulli class can represent."""
    return 0.0

  @classmethod
  def from_config(cls, config):
    """Builds a quantizer from a dictionary produced by get_config()."""
    return cls(**config)

  def get_config(self):
    """Returns the constructor arguments as a dictionary.

    Every constructor argument is included so that from_config(get_config())
    reproduces the temperature and sigmoid choice, not only alpha.
    """
    config = {
        "alpha": self.alpha,
        "temperature": self.temperature,
        "use_real_sigmoid": self.use_real_sigmoid
    }
    return config
class ternary(BaseQuantizer):  # pylint: disable=invalid-name
  """Computes an activation function returning -alpha, 0 or +alpha.

  Two kinds of use are assumed. For parameters, alpha is "auto" or "auto_po2",
  the threshold is left unset, and stochastic rounding may be on. For
  activations, alpha and threshold are numbers (or None) and stochastic
  rounding must be off.

  Attributes:
    bits: number of bits to perform quantization (always 2).
    alpha: ternary is -alpha or +alpha. Alpha can be "auto" or "auto_po2".
    threshold: threshold to apply "dropout" or dead band (0 value). When it
      is None and alpha is not a string, default_threshold is used.
    use_stochastic_rounding: if true, we perform stochastic rounding.
    number_of_unrolls: number of scale/threshold refinement iterations used
      when alpha is "auto" or "auto_po2".
    scale: the scale used by the most recent call, or None before any call.

  Returns:
    Computation of sign within the threshold.
  """

  def __init__(self, alpha=None, threshold=None, use_stochastic_rounding=False,
               number_of_unrolls=5):
    super(ternary, self).__init__()
    self.bits = 2
    self.alpha = alpha
    self.threshold = threshold
    self.use_stochastic_rounding = use_stochastic_rounding
    self.number_of_unrolls = number_of_unrolls
    self.default_alpha = 1.0
    self.default_threshold = 0.33
    self.scale = None

  def __str__(self):
    """Returns a constructor-like string listing the non-default flags."""
    parts = []
    if self.alpha is not None:
      alpha_text = str(self.alpha)
      if isinstance(self.alpha, six.string_types):
        alpha_text = "'" + alpha_text + "'"
      parts.append("alpha=" + alpha_text)
    if self.threshold is not None:
      parts.append("threshold=" + str(self.threshold))
    if self.use_stochastic_rounding:
      parts.append(
          "use_stochastic_rounding=" + str(int(self.use_stochastic_rounding)))
    if self.number_of_unrolls != 5:
      parts.append("number_of_unrolls=" + str(int(self.number_of_unrolls)))
    return "ternary(" + ",".join(parts) + ")"

  def __call__(self, x):
    """Quantizes x to {-scale, 0, +scale} with a straight-through gradient.

    The scale that was used is stored in `self.scale` as a side effect.
    """
    alpha_is_auto = isinstance(self.alpha, six.string_types)

    if alpha_is_auto:
      # parameters
      assert self.alpha in ["auto", "auto_po2"]
      assert self.threshold is None
    else:
      # activations
      assert not self.use_stochastic_rounding
      assert not isinstance(self.threshold, six.string_types)

    if alpha_is_auto:
      # This is an approximation from https://arxiv.org/abs/1605.04711
      # First, find which axes are reduced to get one scale per channel.
      # TODO(hzhuang): support channels_first
      rank = len(x.shape.as_list())
      if rank == 1:
        reduce_axes = None
      elif K.image_data_format() == "channels_last":
        reduce_axes = list(range(rank - 1))
      else:
        reduce_axes = list(range(1, rank))

      # This approximation is exact if x ~ U[-m, m]. For x ~ N(0, m)
      # we need to iterate a few times before we can converge.
      peak = K.max(tf.abs(x), axis=reduce_axes, keepdims=True)
      scale = 2 * peak / 3.0
      if "po2" in self.alpha:
        log2_scale = K.log(scale + K.epsilon()) / np.log(2.0)
        scale = K.pow(2.0, tf.math.round(log2_scale))

      for _ in range(self.number_of_unrolls):
        dead_band = scale / 2.0
        # once we scale the number precision == 0.33 works
        # well for Uniform and Normal distribution of input
        rounded = scale * _round_through(
            x / scale,
            use_stochastic_rounding=self.use_stochastic_rounding,
            precision=1. / 3.)
        q = K.cast(tf.abs(rounded) >= dead_band, K.floatx()) * tf.sign(x)
        scale = _get_scale(self.alpha, x, q)
    else:
      if self.alpha is None:
        scale = 1.0
      elif isinstance(self.alpha, np.ndarray):
        scale = self.alpha
      else:
        scale = float(self.alpha)

      dead_band = self.threshold
      if dead_band is None:
        dead_band = self.default_threshold
      q = K.cast(tf.abs(x) >= dead_band, K.floatx()) * tf.sign(x)

    # ternary ranges from -1 to +1, so we use tanh(x) to be a differentiable
    # version of that.
    if self.alpha is None:
      x = K.tanh(x)

    self.scale = scale
    return x + tf.stop_gradient(-x + scale * q)

  def _set_trainable_parameter(self):
    """Switches an unset alpha to "auto_po2"."""
    if self.alpha is None:
      self.alpha = "auto_po2"

  def max(self):
    """Get the maximum value that ternary can represent."""
    if self.alpha is None or isinstance(self.alpha, six.string_types):
      return 1.0
    return max(1.0, self.alpha)

  def min(self):
    """Get the minimum value that ternary can represent."""
    if self.alpha is None or isinstance(self.alpha, six.string_types):
      return -1.0
    return -max(1.0, self.alpha)

  @classmethod
  def from_config(cls, config):
    """Builds a quantizer from a dictionary produced by get_config()."""
    return cls(**config)

  def get_config(self):
    """Returns the constructor arguments as a dictionary."""
    return {
        "alpha": self.alpha,
        "threshold": self.threshold,
        "use_stochastic_rounding": self.use_stochastic_rounding,
        "number_of_unrolls": self.number_of_unrolls
    }
class stochastic_ternary(ternary):  # pylint: disable=invalid-name
  """Computes a stochastic activation function returning -alpha, 0 or +alpha.

  Computes straight-through approximation using random sampling to make
  E[dL/dy] = E[dL/dx], and computing the sign function (the bernoulli
  quantizer's docstring explains the idea).

  The stochastic path is used when K.learning_phase() selects training;
  otherwise the call is delegated to ternary.__call__.

  Attributes:
    bits: number of bits to perform quantization (always 2).
    alpha: ternary is -alpha or +alpha, or "auto" or "auto_po2". The
      stochastic path currently only accepts "auto" and "auto_po2".
    threshold: (1-threshold) specifies the spread of the +1 and -1 values.
    temperature: amplifier factor for sigmoid function, making stochastic
      less stochastic as it moves away from 0.
    use_real_sigmoid: use real sigmoid for probability.
    number_of_unrolls: number of times we iterate between scale and threshold.
    scale: the scale used by the most recent call, or None before any call.

  Returns:
    Computation of sign with stochastic sampling with straight through gradient.
  """

  def __init__(self, alpha=None, threshold=None, temperature=8.0,
               use_real_sigmoid=True, number_of_unrolls=5):
    # The parent constructor already stores bits, alpha, threshold, the
    # defaults, number_of_unrolls and scale, so only the extra attributes
    # are set here.
    super(stochastic_ternary, self).__init__(
        alpha=alpha,
        threshold=threshold,
        number_of_unrolls=number_of_unrolls)
    assert threshold != 1.0
    self.temperature = temperature
    self.use_real_sigmoid = use_real_sigmoid

  def __str__(self):
    """Returns a constructor-like string listing the non-default flags."""
    flags = []
    if self.alpha is not None:
      alpha = str(self.alpha)
      if isinstance(self.alpha, six.string_types):
        alpha = "'" + alpha + "'"
      flags.append("alpha=" + alpha)
    if self.threshold is not None:
      flags.append("threshold=" + str(self.threshold))
    if self.temperature != 8.0:
      flags.append("temperature=" + str(self.temperature))
    if not self.use_real_sigmoid:
      flags.append("use_real_sigmoid=0")
    if self.number_of_unrolls != 5:
      flags.append("number_of_unrolls=" + str(self.number_of_unrolls))
    return "stochastic_ternary(" + ",".join(flags) + ")"

  def __call__(self, x):
    """Quantizes x stochastically in training, deterministically otherwise.

    The scale that was used is stored in `self.scale` as a side effect.
    """

    def stochastic_output():
      # right now we only accept alpha = "auto" or "auto_po2"
      assert isinstance(self.alpha, six.string_types)
      assert self.alpha in ["auto", "auto_po2"]

      len_axis = len(x.shape)
      if len_axis > 1:
        if K.image_data_format() == "channels_last":
          axis = list(range(len_axis - 1))
        else:
          axis = list(range(1, len_axis))
      else:
        axis = [0]

      x_std = K.std(x, axis=axis, keepdims=True)

      # Initial per-channel scale; refined below by alternating between the
      # non-stochastic ternary assignment and the scale that fits it.
      m = K.max(tf.abs(x), axis=axis, keepdims=True)
      scale = 2.*m/3.
      if self.alpha == "auto_po2":
        scale = K.pow(2.0,
                      tf.math.round(K.log(scale + K.epsilon()) / np.log(2.0)))
      for _ in range(self.number_of_unrolls):
        T = scale / 2.0
        q_ns = K.cast(tf.abs(x) >= T, K.floatx()) * K.sign(x)
        scale = _get_scale(self.alpha, x, q_ns)

      # Work on x normalized by its standard deviation; the threshold is
      # expressed in the same normalized units.
      x_norm = x / (x_std + K.epsilon())
      T = scale / (2.0 * (x_std + K.epsilon()))

      if self.use_real_sigmoid:
        p0 = tf.keras.backend.sigmoid(self.temperature * (x_norm - T))
        p1 = tf.keras.backend.sigmoid(self.temperature * (x_norm + T))
      else:
        p0 = _sigmoid(self.temperature * (x_norm - T))
        p1 = _sigmoid(self.temperature * (x_norm + T))
      r0 = tf.random.uniform(tf.shape(p0))
      r1 = tf.random.uniform(tf.shape(p1))
      # Two {-1, +1} samples are averaged, which yields -1, 0 or +1.
      q0 = tf.sign(p0 - r0)
      q0 += (1.0 - tf.abs(q0))
      q1 = tf.sign(p1 - r1)
      q1 += (1.0 - tf.abs(q1))

      q = (q0 + q1) / 2.0
      self.scale = scale
      return x + tf.stop_gradient(-x + scale * q)

    output = tf_utils.smart_cond(
        K.learning_phase(),
        stochastic_output,
        lambda: ternary.__call__(self, x))
    return output

  def _set_trainable_parameter(self):
    """Switches an unset alpha to "auto_po2"."""
    if self.alpha is None:
      self.alpha = "auto_po2"

  def max(self):
    """Get the maximum value that stochastic_ternary can represent."""
    if self.alpha is None or isinstance(self.alpha, six.string_types):
      return 1.0
    else:
      return max(1.0, self.alpha)

  def min(self):
    """Get the minimum value that stochastic_ternary can represent."""
    if self.alpha is None or isinstance(self.alpha, six.string_types):
      return -1.0
    else:
      return -max(1.0, self.alpha)

  @classmethod
  def from_config(cls, config):
    """Builds a quantizer from a dictionary produced by get_config()."""
    return cls(**config)

  def get_config(self):
    """Returns the constructor arguments as a dictionary."""
    config = {
        "alpha": self.alpha,
        "threshold": self.threshold,
        "temperature": self.temperature,
        "use_real_sigmoid": self.use_real_sigmoid,
        "number_of_unrolls": self.number_of_unrolls
    }
    return config
class binary(BaseQuantizer):  # pylint: disable=invalid-name
  """Computes the sign(x) returning a value between -alpha and alpha.

  Although we cannot guarantee E[dL/dy] = E[dL/dx] if we do not use the
  stochastic sampling, we still use the ST approximation.

  Modified from original binary to match QNN implementation.

  Attributes:
    bits: number of bits to perform quantization (always 1).
    use_01: if True, return {0,1} instead of {-1,+1}.
    alpha: binary is -alpha or +alpha, or "auto", "auto_po2" to compute
      automatically.
    use_stochastic_rounding: if true, we perform stochastic rounding.
    scale: the scale used by the most recent call, or None before any call.

  Returns:
    Computation of sign operation with straight through gradient.
  """

  def __init__(self, use_01=False, alpha=None, use_stochastic_rounding=False):
    super(binary, self).__init__()
    self.use_01 = use_01
    self.bits = 1
    self.alpha = alpha
    self.use_stochastic_rounding = use_stochastic_rounding
    self.default_alpha = 1.0
    self.scale = None

  def __str__(self):
    """Returns a constructor-like string listing the non-default flags."""
    flags = []
    if self.use_01:
      flags.append("use_01=" + str(int(self.use_01)))
    if self.alpha is not None:
      alpha = str(self.alpha)
      if isinstance(self.alpha, six.string_types):
        alpha = "'" + alpha + "'"
      flags.append("alpha=" + alpha)
    if self.use_stochastic_rounding:
      flags.append(
          "use_stochastic_rounding=" + str(self.use_stochastic_rounding))
    return "binary(" + ",".join(flags) + ")"

  def __call__(self, x):
    """Binarizes x with a straight-through gradient.

    The scale that was used is stored in `self.scale` as a side effect.

    Arguments:
      x: tensor to quantize.

    Returns:
      A tensor whose forward value is scale * sign(x) (mapped to {0, 1} when
      use_01 is set) and whose gradient is that of x, or of tanh(x) when
      alpha is None.
    """
    if isinstance(self.alpha, six.string_types):
      assert self.alpha in ["auto", "auto_po2"]

    if self.use_stochastic_rounding:
      len_axis = len(x.shape.as_list())
      if len_axis == 1:
        axis = None
      elif K.image_data_format() == "channels_last":
        axis = list(range(len_axis - 1))
      else:
        axis = list(range(1, len_axis))

      # if stochastic_round is through, we need to scale
      # number so that the precision is small enough.
      # This is especially important if range of x is very
      # small, which occurs during initialization of weights.
      m = K.max(tf.abs(x), axis=axis, keepdims=True)
      m = tf.where(m > 1.0, tf.ones_like(m), m)
      f = 2 * m

      x = tf_utils.smart_cond(
          K.learning_phase(),
          lambda: f * _round_through(
              x / f, use_stochastic_rounding=True, precision=0.125),
          lambda: x)

    k_sign = tf.sign(x)
    if self.use_stochastic_rounding:
      # in inference, we use a biased "1" for stochastic rounding right now
      # The inference fill must have the shape of x: ones_like(tf.shape(x))
      # would be a vector with one entry per dimension of x instead.
      k_sign += (1.0 - tf.abs(k_sign)) * tf_utils.smart_cond(
          K.learning_phase(),
          lambda: 2.0 * tf.round(tf.random.uniform(tf.shape(x))) - 1.0,
          lambda: tf.ones_like(x, dtype=K.floatx()))
    # if something still remains, just make it positive for now.
    k_sign += (1.0 - tf.abs(k_sign))
    if self.use_01:
      k_sign = (k_sign + 1.0) / 2.0

    # approximate binary by tanh(x) as it has limited range between -1 and +1.
    if self.alpha is None:
      x = K.tanh(x)

    # _get_scale covers every alpha kind (None, number, array, "auto*"), so
    # no preliminary scale needs to be chosen before this point.
    scale = _get_scale(self.alpha, x, k_sign)
    self.scale = scale
    return x + tf.stop_gradient(-x + scale * k_sign)

  def _set_trainable_parameter(self):
    """Switches an unset alpha to "auto_po2"."""
    if self.alpha is None:
      self.alpha = "auto_po2"

  def max(self):
    """Get maximum value that binary class can represent."""
    if self.alpha is None or isinstance(self.alpha, six.string_types):
      return 1.0
    else:
      return max(1.0, self.alpha)

  def min(self):
    """Get minimum value that binary class can represent."""
    if self.use_01:
      return 0.0
    elif self.alpha is None or isinstance(self.alpha, six.string_types):
      return -1.0
    else:
      return -max(1.0, self.alpha)

  @classmethod
  def from_config(cls, config):
    """Builds a quantizer from a dictionary produced by get_config()."""
    return cls(**config)

  def get_config(self):
    """Returns the constructor arguments as a dictionary."""
    config = {
        "use_01": self.use_01,
        "alpha": self.alpha,
        "use_stochastic_rounding": self.use_stochastic_rounding
    }
    return config
class stochastic_binary(binary): # pylint: disable=invalid-name
"""Computes a stochastic activation function returning -alpha or +alpha.
Computes straight-through approximation using random sampling to make
E[dL/dy] = E[dL/dx], and computing the sign function. See explanation above.
Attributes: