In [1]:
import tensorflow as tf
import numpy as np
from enum import Enum

# Display the installed TensorFlow version as the cell's output.
tf.__version__

'1.14.0'

In [2]:
class TrainArg(Enum):
    """Three-valued flag intended to be passed as a layer's ``training`` argument.

    Members and their values:
        FALSE            -> 0
        TRUE_UPDATE_U    -> 1
        TRUE_NO_UPDATE_U -> 2

    Truthiness follows the member's value, so ``bool(TrainArg.FALSE)`` is
    ``False`` while both ``TRUE_*`` members are ``True``.
    """

    FALSE = 0
    TRUE_UPDATE_U = 1
    TRUE_NO_UPDATE_U = 2

    def __bool__(self):
        """Return the truthiness of the member's underlying integer value."""
        # Members of a plain Enum are truthy by default regardless of their
        # value; without this override ``bool(TrainArg.FALSE)`` would be True.
        return bool(self.value)

Some layers, in particular the `BatchNormalization` layer and the `Dropout` layer, have different behaviors during training and inference. For such layers, it is standard practice to expose a `training` (boolean) argument in the `call` method.

By exposing this argument in `call`, you enable the built-in training and evaluation loops to correctly use the layer in training and inference.

In [3]:
class Wrapper(tf.keras.layers.Layer):
    """Keras layer that passes its input through one BatchNormalization sublayer.

    ``call`` accepts a ``training`` argument that may be a bool, ``None`` or an
    ``Enum`` member; it is reduced to a plain bool before being handed to the
    sublayer.
    """

    def __init__(self, **kwargs):
        super(Wrapper, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the BatchNormalization sublayer, then mark this layer as built."""
        self.bn = tf.keras.layers.BatchNormalization(momentum=0.9, epsilon=1e-5)
        super(Wrapper, self).build(input_shape)

    def call(self, inputs, training=None):
        """Apply batch normalization to ``inputs``.

        Args:
            inputs: Tensor forwarded to the BatchNormalization sublayer.
            training: bool, ``None`` or an ``Enum`` member. Enum members are
                judged by their ``.value`` (0 -> False, non-zero -> True);
                ``None`` is coerced to ``False``.

        Returns:
            The output of the BatchNormalization sublayer.
        """
        # Members of a plain Enum are always truthy, so bool(member) would be
        # True even for a member whose value is 0. Unwrap to the value first.
        flag = training.value if isinstance(training, Enum) else training
        x = self.bn(inputs, training=bool(flag))
        # Report the argument exactly as it was received, not the coerced bool.
        print('training mode: {}'.format(training))
        return x

In [4]:
def make_model():
    """Build a Keras model: Input(10) -> Dense(3) -> Wrapper named 'wrapper'.

    Returns:
        A ``tf.keras.Model`` mapping the 10-feature input to the wrapper output.
    """
    net_in = tf.keras.Input(shape=(10,))
    dense_out = tf.keras.layers.Dense(3)(net_in)
    net_out = Wrapper(name='wrapper')(dense_out)
    return tf.keras.Model(net_in, net_out)

# Build the model once. Constructing it invokes Wrapper.call, which prints the
# `training` value it received (see the cell output).
model = make_model()

W1220 13:04:10.469505 12100 deprecation.py:506] From f:\anaconda3\envs\tensorflow1.14\lib\site-packages\tensorflow\python\ops\init_ops.py:1251: calling VarianceScaling.__init__ (from tensorflow.python.ops.init_ops) with dtype is deprecated and will be removed in a future version.
Instructions for updating:
Call initializer instance with the dtype argument instead of passing it to the constructor


training mode: None


In [5]:
# Re-apply the model to its own input tensor, passing TrainArg.TRUE_UPDATE_U
# as the `training` argument.
train_update_u = model(model.input, training=TrainArg.TRUE_UPDATE_U)

training mode: TrainArg.TRUE_UPDATE_U


In [6]:
# Same call as above, this time passing TrainArg.TRUE_NO_UPDATE_U as `training`.
train_no_update_u = model(model.input, training=TrainArg.TRUE_NO_UPDATE_U)

training mode: TrainArg.TRUE_NO_UPDATE_U


In [7]:
# Same call again, passing TrainArg.FALSE as the `training` argument.
inference = model(model.input, training=TrainArg.FALSE)

training mode: TrainArg.FALSE


In [8]:
# List the variables registered in the GLOBAL_VARIABLES collection after the
# model has been built and called with each TrainArg value.
tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)

[<tf.Variable 'dense/kernel:0' shape=(10, 3) dtype=float32>,
 <tf.Variable 'dense/bias:0' shape=(3,) dtype=float32>,
 <tf.Variable 'wrapper/batch_normalization/gamma:0' shape=(3,) dtype=float32>,
 <tf.Variable 'wrapper/batch_normalization/beta:0' shape=(3,) dtype=float32>,
 <tf.Variable 'wrapper/batch_normalization/moving_mean:0' shape=(3,) dtype=float32>,
 <tf.Variable 'wrapper/batch_normalization/moving_variance:0' shape=(3,) dtype=float32>]