In [19]:
# Third-party dependencies, gathered ahead of any executable statement.
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import BatchNormalization, Dense
from tensorflow.keras.models import Sequential

# Record the library versions this notebook was executed against.
print("tensorflow version is", tf.__version__)
print("keras version is", tf.keras.__version__)

tensorflow version is 2.1.0
keras version is 2.2.4-tf


In [83]:
def check_weight_update(model, operation):
    """Report, weight by weight, whether ``operation`` changed ``model``.

    The values of ``model.weights`` are captured before and after calling
    ``operation(model)``, and one line per weight is printed stating whether
    its value changed.

    Args:
        model: Object exposing a ``weights`` sequence whose items have a
            ``name`` attribute and a ``numpy()`` method returning an array.
        operation: Callable invoked exactly once as ``operation(model)``.

    Returns:
        None. The report is written to stdout as
        ``weight <index> (<name>)\\tupdated`` or ``...\\tnot updated``.
    """
    def snapshot():
        # Keep an ordered list of (name, value) pairs rather than a dict so
        # that weights sharing a name are not merged into a single entry.
        # np.array(...) copies, so later in-place changes cannot leak into
        # the "before" snapshot.
        return [(w.name, np.array(w.numpy())) for w in model.weights]

    before = snapshot()
    operation(model)
    after = snapshot()

    for i, ((name, old), (_, new)) in enumerate(zip(before, after)):
        # Exact comparison: any change, however small, counts as an update.
        # A tolerance-based check (np.allclose) would hide tiny updates.
        status = "not updated" if np.array_equal(old, new) else "updated"
        print("weight %d (%s)\t%s" % (i, name, status))

In [68]:
def make_model():
    """Build a fresh Dense(3) -> BatchNormalization -> Dense(1) model.

    The first layer declares a 2-feature input. A new, uncompiled
    ``Sequential`` is returned on every call.
    """
    # Layers are constructed in the same order as they are stacked.
    hidden = Dense(3, input_dim=2)
    norm = BatchNormalization()
    head = Dense(1)
    return Sequential([hidden, norm, head])

def operation(model):
    """Run a single training step with the whole model marked trainable.

    The model is compiled with MSE loss and the Adam optimizer, then trained
    on one random batch of 10 samples (2 input features, 1 target).
    """
    model.trainable = True
    model.compile(optimizer="adam", loss="mse")
    batch_size = 10
    inputs = np.random.randn(batch_size, 2)
    targets = np.random.randn(batch_size, 1)
    model.train_on_batch(inputs, targets)


# Fresh model, so earlier experiments cannot influence the report.
model = make_model()
check_weight_update(model=model, operation=operation)

weight 0 (dense_24/kernel:0)	updated
weight 1 (dense_24/bias:0)	updated
weight 2 (batch_normalization_12/gamma:0)	updated
weight 3 (batch_normalization_12/beta:0)	updated
weight 4 (batch_normalization_12/moving_mean:0)	updated
weight 5 (batch_normalization_12/moving_variance:0)	updated
weight 6 (dense_25/kernel:0)	updated
weight 7 (dense_25/bias:0)	updated


In [91]:
def operation(model):
    """Call ``predict`` on a trainable, compiled model.

    The model is marked trainable and compiled with MSE loss and Adam, then
    asked for predictions on one random batch of 10 two-feature samples.
    The predictions themselves are discarded.
    """
    model.trainable = True
    model.compile(optimizer="adam", loss="mse")
    inputs = np.random.randn(10, 2)
    model.predict(inputs)


# Fresh model, so earlier experiments cannot influence the report.
model = make_model()
check_weight_update(model=model, operation=operation)

weight 0 (dense_44/kernel:0)	not updated
weight 1 (dense_44/bias:0)	not updated
weight 2 (batch_normalization_22/gamma:0)	not updated
weight 3 (batch_normalization_22/beta:0)	not updated
weight 4 (batch_normalization_22/moving_mean:0)	not updated
weight 5 (batch_normalization_22/moving_variance:0)	not updated
weight 6 (dense_45/kernel:0)	not updated
weight 7 (dense_45/bias:0)	not updated


In [93]:
def operation(model):
    """Call ``evaluate`` on a trainable, compiled model.

    The model is marked trainable and compiled with MSE loss and Adam, then
    evaluated on one random batch of 10 samples (2 input features,
    1 target) with progress output disabled. The loss value is discarded.
    """
    model.trainable = True
    model.compile(optimizer="adam", loss="mse")
    batch_size = 10
    inputs = np.random.randn(batch_size, 2)
    targets = np.random.randn(batch_size, 1)
    model.evaluate(inputs, targets, verbose=False)


# Fresh model, so earlier experiments cannot influence the report.
model = make_model()
check_weight_update(model=model, operation=operation)

weight 0 (dense_48/kernel:0)	not updated
weight 1 (dense_48/bias:0)	not updated
weight 2 (batch_normalization_24/gamma:0)	not updated
weight 3 (batch_normalization_24/beta:0)	not updated
weight 4 (batch_normalization_24/moving_mean:0)	not updated
weight 5 (batch_normalization_24/moving_variance:0)	not updated
weight 6 (dense_49/kernel:0)	not updated
weight 7 (dense_49/bias:0)	not updated


In [69]:
def operation(model):
    """Run a single training step with the whole model frozen.

    ``trainable`` is switched off before compiling (MSE loss, Adam), then
    ``train_on_batch`` is called on one random batch of 10 samples
    (2 input features, 1 target).
    """
    # The flag is set before compile(), matching the order used throughout
    # this notebook.
    model.trainable = False
    model.compile(optimizer="adam", loss="mse")
    batch_size = 10
    inputs = np.random.randn(batch_size, 2)
    targets = np.random.randn(batch_size, 1)
    model.train_on_batch(inputs, targets)


# Fresh model, so earlier experiments cannot influence the report.
model = make_model()
check_weight_update(model=model, operation=operation)

weight 0 (dense_26/kernel:0)	not updated
weight 1 (dense_26/bias:0)	not updated
weight 2 (batch_normalization_13/gamma:0)	not updated
weight 3 (batch_normalization_13/beta:0)	not updated
weight 4 (batch_normalization_13/moving_mean:0)	not updated
weight 5 (batch_normalization_13/moving_variance:0)	not updated
weight 6 (dense_27/kernel:0)	not updated
weight 7 (dense_27/bias:0)	not updated


In [71]:
def operation(model):
    """Run a single training step with only the layer at index 1 frozen.

    The model as a whole is marked trainable, then ``model.layers[1]`` (the
    BatchNormalization layer when the model comes from ``make_model``) is
    frozen before compiling with MSE loss and Adam. One random batch of
    10 samples (2 input features, 1 target) is used for ``train_on_batch``.
    """
    model.trainable = True
    frozen_layer = model.layers[1]
    frozen_layer.trainable = False
    model.compile(optimizer="adam", loss="mse")
    batch_size = 10
    inputs = np.random.randn(batch_size, 2)
    targets = np.random.randn(batch_size, 1)
    model.train_on_batch(inputs, targets)


# Fresh model, so earlier experiments cannot influence the report.
model = make_model()
check_weight_update(model=model, operation=operation)

weight 0 (dense_30/kernel:0)	updated
weight 1 (dense_30/bias:0)	updated
weight 2 (batch_normalization_15/gamma:0)	not updated
weight 3 (batch_normalization_15/beta:0)	not updated
weight 4 (batch_normalization_15/moving_mean:0)	not updated
weight 5 (batch_normalization_15/moving_variance:0)	not updated
weight 6 (dense_31/kernel:0)	updated
weight 7 (dense_31/bias:0)	updated


In [89]:
def make_composite_model():
    """Build two small sub-models and a model that chains them.

    Returns:
        A tuple ``(m1, m2, m3)`` where ``m1`` maps 3 inputs to 2 outputs via
        a 5-unit hidden layer, ``m2`` maps 2 inputs to 1 output via a 7-unit
        hidden layer, and ``m3`` is ``Sequential([m1, m2])``, i.e. it wraps
        the very same ``m1`` and ``m2`` objects.
    """
    def stack(n_inputs, n_hidden, n_outputs):
        # Dense -> BatchNormalization -> Dense, same layout for both halves.
        return Sequential([
            Dense(n_hidden, input_dim=n_inputs),
            BatchNormalization(),
            Dense(n_outputs),
        ])

    front = stack(3, 5, 2)
    back = stack(2, 7, 1)
    combined = Sequential([front, back])
    return front, back, combined


m1, m2, m3 = make_composite_model()

# Print the three summaries, separated by a marker line.
m1.summary()
print("***")
m2.summary()
print("***")
m3.summary()

Model: "sequential_25"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_40 (Dense)             (None, 5)                 20        
_________________________________________________________________
batch_normalization_20 (Batc (None, 5)                 20        
_________________________________________________________________
dense_41 (Dense)             (None, 2)                 12        
Total params: 52
Trainable params: 42
Non-trainable params: 10
_________________________________________________________________
***
Model: "sequential_26"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
dense_42 (Dense)             (None, 7)                 21        
_________________________________________________________________
batch_normalization_21 (Batc (None, 7)                 28        
_____________________________

In [90]:
def operation(*args):
    """Train the composite model once with ``m1`` trainable and ``m2`` frozen.

    Any positional arguments are accepted and ignored: the function acts on
    the notebook-level ``m1``, ``m2`` and ``m3`` objects directly, so those
    must already exist when it is called. ``m3`` is compiled with MSE loss
    and Adam and trained on one random batch of 10 samples (3 input
    features, 1 target).
    """
    # Flags are set on the sub-models before the outer model is compiled.
    m1.trainable, m2.trainable = True, False
    m3.compile(optimizer="adam", loss="mse")

    batch_size = 10
    inputs = np.random.randn(batch_size, 3)
    targets = np.random.randn(batch_size, 1)
    m3.train_on_batch(inputs, targets)


check_weight_update(model=m3, operation=operation)

weight 0 (dense_40/kernel:0)	updated
weight 1 (dense_40/bias:0)	updated
weight 2 (batch_normalization_20/gamma:0)	updated
weight 3 (batch_normalization_20/beta:0)	updated
weight 4 (batch_normalization_20/moving_mean:0)	updated
weight 5 (batch_normalization_20/moving_variance:0)	updated
weight 6 (dense_41/kernel:0)	updated
weight 7 (dense_41/bias:0)	updated
weight 8 (dense_42/kernel:0)	not updated
weight 9 (dense_42/bias:0)	not updated
weight 10 (batch_normalization_21/gamma:0)	not updated
weight 11 (batch_normalization_21/beta:0)	not updated
weight 12 (batch_normalization_21/moving_mean:0)	not updated
weight 13 (batch_normalization_21/moving_variance:0)	not updated
weight 14 (dense_43/kernel:0)	not updated
weight 15 (dense_43/bias:0)	not updated
