base_preprocessing_layer.py
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains the base ProcessingLayer and a subclass that uses Combiners."""
import abc
import tensorflow.compat.v2 as tf
from keras.engine import data_adapter
from keras.engine.base_layer import Layer
from keras.utils import version_utils
# isort: off
from tensorflow.python.eager import context
from tensorflow.python.util.tf_export import keras_export
from tensorflow.tools.docs import doc_controls
keras_kpl_gauge = tf.__internal__.monitoring.BoolGauge(
"/tensorflow/api/keras/layers/preprocessing",
"keras preprocessing layers usage",
"method",
)


@keras_export("keras.layers.experimental.preprocessing.PreprocessingLayer")
class PreprocessingLayer(Layer, metaclass=abc.ABCMeta):
    """Base class for Preprocessing Layers.

    **Don't use this class directly: it's an abstract base class!** You may
    be looking for one of the many built-in
    [preprocessing layers](https://keras.io/guides/preprocessing_layers/)
    instead.

    Preprocessing layers are layers whose state is computed before model
    training starts; it is not updated during training. Most preprocessing
    layers implement an `adapt()` method for state computation.

    The `PreprocessingLayer` class is the base class you would subclass to
    implement your own preprocessing layers.
    """

    _must_restore_from_config = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._is_compiled = False
        self._is_adapted = False

        # Sets `is_adapted=False` when `reset_state` is called.
        self._reset_state_impl = self.reset_state
        self.reset_state = self._reset_state_wrapper

        self._adapt_function = None

    @property
    def is_adapted(self):
        """Whether the layer has been fit to data already."""
        return self._is_adapted

    @doc_controls.do_not_generate_docs
    def update_state(self, data):
        """Accumulates statistics for the preprocessing layer.

        Arguments:
            data: A mini-batch of inputs to the layer.
        """
        raise NotImplementedError

    @doc_controls.do_not_generate_docs
    def reset_state(self):
        """Resets the statistics of the preprocessing layer."""
        raise NotImplementedError

    @doc_controls.do_not_generate_docs
    def finalize_state(self):
        """Finalizes the statistics for the preprocessing layer.

        This method is called at the end of `adapt` or after restoring a
        serialized preprocessing layer's state. It handles any one-time
        operations that should occur on the layer's state before
        `Layer.__call__`.
        """
        pass

    @doc_controls.do_not_generate_docs
    def make_adapt_function(self):
        """Creates a function to execute one step of `adapt`.

        This method can be overridden to support custom adapt logic.
        This method is called by `PreprocessingLayer.adapt`.

        Typically, this method directly controls `tf.function` settings,
        and delegates the actual state update logic to
        `PreprocessingLayer.update_state`.

        This function is cached the first time `PreprocessingLayer.adapt`
        is called. The cache is cleared whenever `PreprocessingLayer.compile`
        is called.

        Returns:
            Function. The function created by this method should accept a
            `tf.data.Iterator`, retrieve a batch, and update the state of
            the layer.
        """
        if self._adapt_function is not None:
            return self._adapt_function

        def adapt_step(iterator):
            data = next(iterator)
            self._adapt_maybe_build(data)
            self.update_state(data)

        if self._steps_per_execution.numpy().item() == 1:
            adapt_fn = adapt_step
        else:

            def adapt_fn(iterator):
                for _ in tf.range(self._steps_per_execution):
                    adapt_step(iterator)

        if not self._run_eagerly:
            adapt_fn = tf.function(adapt_fn)

        self._adapt_function = adapt_fn
        return self._adapt_function
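
    # Example (sketch): one way a subclass could override
    # `make_adapt_function`, assuming it only wants to customize tracing,
    # here forcing a plain eager step that processes one batch per call.
    # The override below is hypothetical, not used elsewhere in this module.
    #
    #     def make_adapt_function(self):
    #         if self._adapt_function is not None:
    #             return self._adapt_function
    #
    #         def adapt_fn(iterator):
    #             # Runs eagerly; `steps_per_execution` is ignored.
    #             data = next(iterator)
    #             self._adapt_maybe_build(data)
    #             self.update_state(data)
    #
    #         self._adapt_function = adapt_fn
    #         return self._adapt_function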

    def compile(self, run_eagerly=None, steps_per_execution=None):
        """Configures the layer for `adapt`.

        Arguments:
            run_eagerly: Bool. Defaults to `False`. If `True`, this layer's
                `adapt` logic will not be wrapped in a `tf.function`.
                Recommended to leave this as `None` unless your layer cannot
                be run inside a `tf.function`.
            steps_per_execution: Int. Defaults to 1. The number of batches to
                run during each `tf.function` call. Running multiple batches
                inside a single `tf.function` call can greatly improve
                performance on TPUs or small models with a large Python
                overhead.
        """
        if steps_per_execution is None:
            steps_per_execution = 1
        self._configure_steps_per_execution(steps_per_execution)

        if run_eagerly is None:
            run_eagerly = self.dynamic
        self._run_eagerly = run_eagerly

        self._is_compiled = True
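
    # Example (sketch): configuring `adapt` to process several batches per
    # `tf.function` call. `Normalization` stands in for any adaptable layer,
    # and `dataset` is assumed to be a batched `tf.data.Dataset`.
    #
    #     layer = tf.keras.layers.Normalization(axis=None)
    #     layer.compile(steps_per_execution=16)  # 16 batches per trace.
    #     layer.adapt(dataset)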

    def adapt(self, data, batch_size=None, steps=None):
        """Fits the state of the preprocessing layer to the data being passed.

        After calling `adapt` on a layer, a preprocessing layer's state will
        not update during training. In order to make preprocessing layers
        efficient in any distribution context, they are kept constant with
        respect to any compiled `tf.Graph`s that call the layer. This does
        not affect layers that are adapted only once, but if you adapt a
        layer multiple times you will need to take care to re-compile any
        compiled functions as follows:

        * If you are adding a preprocessing layer to a `keras.Model`, you
          need to call `model.compile` after each subsequent call to `adapt`.
        * If you are calling a preprocessing layer inside
          `tf.data.Dataset.map`, you should call `map` again on the input
          `tf.data.Dataset` after each `adapt`.
        * If you are using a `tf.function` directly which calls a
          preprocessing layer, you need to call `tf.function` again on your
          callable after each subsequent call to `adapt`.

        `tf.keras.Model` example with multiple adapts:

        >>> layer = tf.keras.layers.Normalization(
        ...     axis=None)
        >>> layer.adapt([0, 2])
        >>> model = tf.keras.Sequential(layer)
        >>> model.predict([0, 1, 2])
        array([-1.,  0.,  1.], dtype=float32)
        >>> layer.adapt([-1, 1])
        >>> model.compile()  # This is needed to re-compile model.predict!
        >>> model.predict([0, 1, 2])
        array([0., 1., 2.], dtype=float32)

        `tf.data.Dataset` example with multiple adapts:

        >>> layer = tf.keras.layers.Normalization(
        ...     axis=None)
        >>> layer.adapt([0, 2])
        >>> input_ds = tf.data.Dataset.range(3)
        >>> normalized_ds = input_ds.map(layer)
        >>> list(normalized_ds.as_numpy_iterator())
        [array([-1.], dtype=float32),
         array([0.], dtype=float32),
         array([1.], dtype=float32)]
        >>> layer.adapt([-1, 1])
        >>> normalized_ds = input_ds.map(layer)  # Re-map over the dataset.
        >>> list(normalized_ds.as_numpy_iterator())
        [array([0.], dtype=float32),
         array([1.], dtype=float32),
         array([2.], dtype=float32)]

        `adapt()` is meant only as a single-machine utility to compute layer
        state. To analyze a dataset that cannot fit on a single machine, see
        [TensorFlow Transform](
        https://www.tensorflow.org/tfx/transform/get_started)
        for a multi-machine, map-reduce solution.

        Arguments:
            data: The data to train on. It can be passed either as a
                `tf.data.Dataset` or as a NumPy array.
            batch_size: Integer or `None`. Number of samples per state
                update. If unspecified, `batch_size` will default to 32. Do
                not specify the `batch_size` if your data is in the form of
                datasets, generators, or `keras.utils.Sequence` instances
                (since they generate batches).
            steps: Integer or `None`. Total number of steps (batches of
                samples) to process. When training with input tensors such
                as TensorFlow data tensors, the default `None` is equal to
                the number of samples in your dataset divided by the batch
                size, or 1 if that cannot be determined. If `data` is a
                `tf.data` dataset and `steps` is `None`, adaptation will run
                until the input dataset is exhausted. When passing an
                infinitely repeating dataset, you must specify the `steps`
                argument. This argument is not supported with array inputs.
        """
        _disallow_inside_tf_function("adapt")
        if not version_utils.should_use_v2():
            raise RuntimeError("`adapt` is only supported in TensorFlow v2.")
        if not self._is_compiled:
            self.compile()  # Compile with defaults.
        if self.built:
            self.reset_state()
        data_handler = data_adapter.DataHandler(
            data,
            batch_size=batch_size,
            steps_per_epoch=steps,
            epochs=1,
            steps_per_execution=self._steps_per_execution,
            distribute=False,
        )
        self._adapt_function = self.make_adapt_function()
        for _, iterator in data_handler.enumerate_epochs():
            with data_handler.catch_stop_iteration():
                for _ in data_handler.steps():
                    self._adapt_function(iterator)
                    if data_handler.should_sync:
                        context.async_wait()
        self.finalize_state()
        self._is_adapted = True
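
    # Example (sketch): the `tf.function` case from the `adapt` docstring
    # above. After re-adapting the layer, wrap the callable in a new
    # `tf.function` so it re-traces against the updated state.
    #
    #     layer = tf.keras.layers.Normalization(axis=None)
    #     layer.adapt([0, 2])
    #     fn = tf.function(lambda x: layer(x))
    #     fn(tf.constant([0.0, 1.0, 2.0]))      # First adapt's state.
    #     layer.adapt([-1, 1])
    #     fn = tf.function(lambda x: layer(x))  # Re-wrap after `adapt`.
    #     fn(tf.constant([0.0, 1.0, 2.0]))      # New state.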

    def _reset_state_wrapper(self):
        """Calls `reset_state` and sets `is_adapted` to `False`."""
        self._reset_state_impl()
        self._is_adapted = False

    @tf.__internal__.tracking.no_automatic_dependency_tracking
    def _configure_steps_per_execution(self, steps_per_execution):
        self._steps_per_execution = tf.Variable(
            steps_per_execution,
            dtype="int64",
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )

    # TODO(omalleyt): Unify this logic with `Layer._maybe_build`.
    def _adapt_maybe_build(self, data):
        if not self.built:
            try:
                # If this is a Numpy array or tensor, we can get shape from
                # .shape. If not, an attribute error will be thrown.
                data_shape = data.shape
                data_shape_nones = tuple([None] * len(data.shape))
            except AttributeError:
                # The input has an unknown number of dimensions.
                data_shape = None
                data_shape_nones = None

            # TODO (b/159261555): move this to base layer build.
            batch_input_shape = getattr(self, "_batch_input_shape", None)
            if batch_input_shape is None:
                # Set the number of dimensions.
                self._batch_input_shape = data_shape_nones

            self.build(data_shape)
            self.built = True
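

# Example (sketch): a minimal `PreprocessingLayer` subclass following the
# contract above: `update_state` accumulates per-batch statistics,
# `reset_state` clears them, and `call` applies the adapted state.
# `RunningMax` and its attributes are hypothetical names, not part of this
# module.
#
#     class RunningMax(PreprocessingLayer):
#         def build(self, input_shape):
#             self.max = tf.Variable(float("-inf"), trainable=False)
#
#         def update_state(self, data):
#             # Accumulate the running maximum over adapted batches.
#             batch_max = tf.reduce_max(tf.cast(data, tf.float32))
#             self.max.assign(tf.maximum(self.max, batch_max))
#
#         def reset_state(self):
#             self.max.assign(float("-inf"))
#
#         def call(self, inputs):
#             # Scale inputs by the maximum seen during `adapt`.
#             return tf.cast(inputs, tf.float32) / self.max
#
#     layer = RunningMax()
#     layer.adapt([[1.0, 2.0], [3.0, 4.0]])
#     layer([2.0, 4.0])  # -> [0.5, 1.0]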


def _disallow_inside_tf_function(method_name):
    """Disallow calling a method inside a `tf.function`."""
    if tf.inside_function():
        error_msg = (
            "Detected a call to `PreprocessingLayer.{method_name}` inside a "
            "`tf.function`. `PreprocessingLayer.{method_name}` is a "
            "high-level endpoint that manages its own `tf.function`. Please "
            "move the call to `PreprocessingLayer.{method_name}` outside of "
            "all enclosing `tf.function`s. Note that you can call a "
            "`PreprocessingLayer` directly on `Tensor`s inside a "
            "`tf.function` like: `layer(x)`, or update its state like: "
            "`layer.update_state(x)`."
        ).format(method_name=method_name)
        raise RuntimeError(error_msg)
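
# Example (sketch): the error case this guard protects against. Calling
# `adapt` from inside a `tf.function` raises the `RuntimeError` built above,
# while calling the layer itself on tensors is fine.
#
#     layer = tf.keras.layers.Normalization(axis=None)
#     layer.adapt([0, 2])
#
#     @tf.function
#     def ok(x):
#         return layer(x)  # Allowed: direct call on tensors.
#
#     @tf.function
#     def bad(x):
#         layer.adapt(x)  # Raises RuntimeError via the guard above.
#
#     ok(tf.constant([1.0]))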