/
metrics.py
89 lines (74 loc) · 3.17 KB
/
metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""We add metrics specific to extremely quantized networks using a
`larq.context.metrics_scope` rather than through the `metrics` parameter of
`model.compile()`, where most common metrics reside. This is because, to calculate
metrics like the `flip_ratio`, we need a layer's kernel or activation and not just the
`y_true` and `y_pred` that Keras passes to metrics defined in the usual way.
"""
import numpy as np
import tensorflow as tf
from larq import utils
@utils.register_alias("flip_ratio")
@utils.register_keras_custom_object
class FlipRatio(tf.keras.metrics.Metric):
"""Computes the mean ratio of changed values in a given tensor.
!!! example
```python
m = metrics.FlipRatio()
m.update_state((1, 1)) # result: 0
m.update_state((2, 2)) # result: 1
m.update_state((1, 2)) # result: 0.75
print('Final result: ', m.result().numpy()) # Final result: 0.75
```
# Arguments
name: Name of the metric.
values_dtype: Data type of the tensor for which to track changes.
dtype: Data type of the moving mean.
"""
def __init__(self, values_dtype="int8", name="flip_ratio", dtype=None):
super().__init__(name=name, dtype=dtype)
self.built = False
self.values_dtype = tf.as_dtype(values_dtype)
def build(self, input_shape):
self._previous_values = self.add_weight(
"previous_values",
shape=input_shape,
dtype=self.values_dtype,
initializer=tf.keras.initializers.zeros,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
)
self.total = self.add_weight(
"total",
initializer=tf.keras.initializers.zeros,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
)
self.count = self.add_weight(
"count",
initializer=tf.keras.initializers.zeros,
aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
)
self._size = tf.cast(np.prod(input_shape), self.dtype)
self.built = True
def update_state(self, values, sample_weight=None):
values = tf.cast(values, self.values_dtype)
if not self.built:
with tf.name_scope(self.name), tf.init_scope():
self.build(values.shape)
unchanged_values = tf.math.count_nonzero(
tf.equal(self._previous_values, values)
)
flip_ratio = 1 - (
tf.cast(unchanged_values, self.dtype) / tf.cast(self._size, self.dtype)
)
update_total_op = self.total.assign_add(flip_ratio * tf.sign(self.count))
with tf.control_dependencies([update_total_op]):
update_count_op = self.count.assign_add(1)
with tf.control_dependencies([update_count_op]):
return self._previous_values.assign(values)
def result(self):
return tf.compat.v1.div_no_nan(self.total, self.count - 1)
def reset_states(self):
tf.keras.backend.batch_set_value(
[(v, 0) for v in self.variables if v is not self._previous_values]
)
def get_config(self):
return {**super().get_config(), "values_dtype": self.values_dtype.name}