"""We add metrics specific to extremely quantized networks using a `scope` rather than
through the `metrics` parameter of `model.compile()`, where most common metrics reside.
This is because, to calculate metrics like the `flip_ratio`, we need a layer's kernel or
activation and not just the `y_true` and `y_pred` that Keras passes to metrics defined
in the usual way.
"""
import tensorflow as tf
from larq import utils
import numpy as np
from contextlib import contextmanager
try:
    from tensorflow.keras.metrics import Metric
except ImportError:  # pragma: no cover
    # TensorFlow 1.13 doesn't export this as a public API
    from tensorflow.python.keras.metrics import Metric

__all__ = ["scope", "get_training_metrics"]

# Metrics active in the current `scope`; quantized layers consult this set.
_GLOBAL_TRAINING_METRICS = set()
# Metric names that `scope` accepts.
_AVAILABLE_METRICS = {"flip_ratio"}
@contextmanager
def scope(metrics=()):
    """A context manager to set the training metrics to be used in layers.

    !!! example
        ```python
        with larq.metrics.scope(["flip_ratio"]):
            model = tf.keras.models.Sequential(
                [larq.layers.QuantDense(3, kernel_quantizer="ste_sign", input_shape=(32,))]
            )
        model.compile(loss="mse", optimizer="sgd")
        ```

    # Arguments
        metrics: Iterable of metrics to add to layers defined inside this context.
            Currently only the `flip_ratio` metric is available.

    # Raises
        ValueError: If an unknown metric name is passed.
    """
    for metric in metrics:
        if metric not in _AVAILABLE_METRICS:
            raise ValueError(
                f"Unknown training metric '{metric}'. Available metrics: {_AVAILABLE_METRICS}."
            )
    backup = _GLOBAL_TRAINING_METRICS.copy()
    _GLOBAL_TRAINING_METRICS.update(metrics)
    try:
        yield _GLOBAL_TRAINING_METRICS
    finally:
        # Restore the enclosing scope even if the body raised, so a failing
        # model build cannot leak its metrics into subsequent scopes.
        _GLOBAL_TRAINING_METRICS.clear()
        _GLOBAL_TRAINING_METRICS.update(backup)
def get_training_metrics():
    """Retrieves a live reference to the training metrics in the current scope.

    Updating and clearing training metrics using `larq.metrics.scope` is preferred,
    but `get_training_metrics` can be used to directly access them.

    Note that this returns the mutable module-level set itself, not a copy:
    changes made to it are immediately visible to layers built afterwards.

    !!! example
        ```python
        get_training_metrics().clear()
        get_training_metrics().add("flip_ratio")
        ```

    # Returns
        A set of training metrics in the current scope.
    """
    return _GLOBAL_TRAINING_METRICS
class LarqMetric(Metric):
    """Metric with support for both 1.13 and 1.14+"""

    def add_weight(
        self,
        name,
        shape=(),
        aggregation=tf.VariableAggregation.SUM,
        synchronization=tf.VariableSynchronization.ON_READ,
        initializer=None,
        dtype=None,
    ):
        if not utils.tf_1_14_or_newer():  # pragma: no cover
            # TF 1.13's Metric.add_weight doesn't support setting a custom
            # dtype, so explicitly call the base Layer implementation instead.
            return tf.keras.layers.Layer.add_weight(
                self,
                name=name,
                shape=shape,
                dtype=dtype if dtype is not None else self._dtype,
                trainable=False,
                initializer=initializer,
                collections=[],
                synchronization=synchronization,
                aggregation=aggregation,
            )
        # TF 1.14+ forwards everything straight to the Metric base class.
        return super().add_weight(
            name=name,
            shape=shape,
            aggregation=aggregation,
            synchronization=synchronization,
            initializer=initializer,
            dtype=dtype,
        )
@utils.register_alias("flip_ratio")
@utils.register_keras_custom_object
class FlipRatio(LarqMetric):
    """Computes the mean ratio of changed values in a given tensor.

    !!! example
        ```python
        m = metrics.FlipRatio(values_shape=(2,))
        m.update_state((1, 1))  # result: 0
        m.update_state((2, 2))  # result: 1
        m.update_state((1, 2))  # result: 0.75
        print('Final result: ', m.result().numpy())  # Final result: 0.75
        ```

    # Arguments
        values_shape: Shape of the tensor for which to track changes.
        values_dtype: Data type of the tensor for which to track changes.
        name: Name of the metric.
        dtype: Data type of the moving mean.
    """

    def __init__(
        self, values_shape=(), values_dtype="int8", name="flip_ratio", dtype=None
    ):
        super().__init__(name=name, dtype=dtype)
        self.values_dtype = tf.as_dtype(values_dtype)
        self.values_shape = tf.TensorShape(values_shape).as_list()
        # Marks this as a per-weight metric; presumably read by the layer
        # machinery elsewhere in larq — TODO confirm.
        self.is_weight_metric = True
        # `init_scope` lifts variable creation out of any surrounding graph
        # or function context so the state variables are created once.
        with tf.init_scope():
            self._previous_values = self.add_weight(
                "previous_values",
                shape=values_shape,
                dtype=self.values_dtype,
                initializer=tf.keras.initializers.zeros,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )
            self.total = self.add_weight(
                "total",
                initializer=tf.keras.initializers.zeros,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )
            self.count = self.add_weight(
                "count",
                initializer=tf.keras.initializers.zeros,
                aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
            )
        self._size = np.prod(self.values_shape)

    def update_state(self, values, sample_weight=None):
        """Accumulates the flip ratio between `values` and the previous update."""
        values = tf.cast(values, self.values_dtype)
        # `tf.equal` + `count_nonzero` counts *unchanged* entries; the flip
        # ratio is the complement of the unchanged fraction.
        unchanged_values = tf.math.count_nonzero(
            tf.equal(self._previous_values, values)
        )
        flip_ratio = 1 - (tf.cast(unchanged_values, self.dtype) / self._size)
        # `tf.sign(self.count)` zeroes the contribution of the very first
        # update, whose comparison against the all-zeros initial
        # `_previous_values` is meaningless.
        update_total_op = self.total.assign_add(flip_ratio * tf.sign(self.count))
        with tf.control_dependencies([update_total_op]):
            update_count_op = self.count.assign_add(1)
            with tf.control_dependencies([update_count_op]):
                return self._previous_values.assign(values)

    def result(self):
        # Divide by `count - 1` because the first update contributed nothing
        # to `total` (see `update_state`); `div_no_nan` yields 0 until then.
        return tf.compat.v1.div_no_nan(self.total, self.count - 1)

    def reset_states(self):
        # Reset the accumulators but deliberately keep `_previous_values`, so
        # flips after a reset are still measured against the last seen tensor
        # rather than the all-zeros initializer.
        tf.keras.backend.batch_set_value(
            [(v, 0) for v in self.variables if v is not self._previous_values]
        )

    def get_config(self):
        return {
            **super().get_config(),
            "values_shape": self.values_shape,
            "values_dtype": self.values_dtype.name,
        }