### Freezing layers: understanding the `trainable` attribute

In [129]:
import tensorflow as tf

In [130]:
# Example: a Dense layer has 2 trainable weights (kernel & bias)

layer = tf.keras.layers.Dense(3)
layer.build((None, 4))  # Input shape: (batch_size, 4) -> kernel (4, 3), bias (3,)

In [131]:
# List every trainable variable of the layer by name and shape.
print("Trainable weights:")
for weight in layer.trainable_weights:
    print(weight.name, weight.shape)

Trainable weights:
kernel:0 (4, 3)
bias:0 (3,)


In [132]:
# `weights` lists all variables of the layer (trainable and non-trainable) with their current values.
print(f"Weights: {layer.weights}")

Weights: [<tf.Variable 'kernel:0' shape=(4, 3) dtype=float32, numpy=
array([[ 0.5232246 ,  0.90533984,  0.05055439],
       [-0.8959446 ,  0.43206775, -0.20747948],
       [-0.44649592, -0.04036075, -0.21143943],
       [-0.13137019, -0.5553851 , -0.7530747 ]], dtype=float32)>, <tf.Variable 'bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]


In [133]:
# Counts weight variables (kernel and bias), not individual scalar parameters.
print(f"Length of weights: {len(layer.weights)}")

Length of weights: 2


In [134]:
# The layer is not frozen, so the trainable list is the same as `layer.weights`.
print(f"Trainable weights: {layer.trainable_weights}")

Trainable weights: [<tf.Variable 'kernel:0' shape=(4, 3) dtype=float32, numpy=
array([[ 0.5232246 ,  0.90533984,  0.05055439],
       [-0.8959446 ,  0.43206775, -0.20747948],
       [-0.44649592, -0.04036075, -0.21143943],
       [-0.13137019, -0.5553851 , -0.7530747 ]], dtype=float32)>, <tf.Variable 'bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]


In [135]:
# Both variables (kernel and bias) are trainable at this point.
print(f"Length of trainable weights: {len(layer.trainable_weights)}")

Length of trainable weights: 2


In [136]:
# Nothing has been frozen yet, so this Dense layer has no non-trainable variables.
print(f"Non-trainable weights: {layer.non_trainable_weights}")


Non-trainable weights: []


In [137]:
# Length of the (empty) non-trainable list for the unfrozen Dense layer.
print(f"Length of non-trainable weights: {len(layer.non_trainable_weights)}")

Length of non-trainable weights: 0


In [138]:
# count_params() returns the total number of scalar parameters in the layer
# (kernel 4*3 = 12 plus bias 3 = 15), not the number of weight variables (2).
# NOTE(review): the variable name and the printed label say "trainable variables",
# which does not match what is being counted — consider renaming both.
num_trainable_vars = layer.count_params()
print(f"Number of trainable variables: {num_trainable_vars}")

Number of trainable variables: 15


### The BatchNormalization layer has 2 trainable weights and 2 non-trainable weights

In [139]:
# Rebind `layer` (previously the Dense layer) to a BatchNormalization layer.
layer = tf.keras.layers.BatchNormalization()
layer.build((None, 4))  # Input shape: (batch_size, 4)

In [140]:
# Report how many variables of the current `layer` fall into each category.
weight_groups = {
    "weight": layer.weights,
    "trainable weight": layer.trainable_weights,
    "non-trainable weight": layer.non_trainable_weights,
}
for label, variables in weight_groups.items():
    print(label, len(variables))

weight 4
trainable weight 2
non-trainable weight 2


### Setting trainable to False

In [141]:
# Build a fresh Dense layer, then freeze it so that none of its variables are trainable.
layer = tf.keras.layers.Dense(3)
layer.build((None, 4))  # Input shape: (batch_size, 4)
layer.trainable = False # Freeze the layer: kernel and bias move to non_trainable_weights

In [142]:
# Report how many variables of the current `layer` fall into each category.
weight_groups = {
    "weight": layer.weights,
    "trainable weight": layer.trainable_weights,
    "non-trainable weight": layer.non_trainable_weights,
}
for label, variables in weight_groups.items():
    print(label, len(variables))

weight 2
trainable weight 0
non-trainable weight 2


In [143]:
# Note: when a trainable weight becomes non-trainable, its value is no longer updated during training.

In [144]:
# Make a model with 2 Dense layers on top of a 3-feature input
model_input = tf.keras.Input(shape=(3,))
layer1 = tf.keras.layers.Dense(3, activation='relu')  # 3*3 kernel + 3 bias = 12 params
layer2 = tf.keras.layers.Dense(2, activation='sigmoid')  # 3*2 kernel + 2 bias = 8 params

# Keep the layer objects in variables so that individual layers can be frozen later.
model = tf.keras.Sequential([model_input, layer1, layer2])

model.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_21 (Dense)            (None, 3)                 12        
                                                                 
 dense_22 (Dense)            (None, 2)                 8         
                                                                 
Total params: 20 (80.00 Byte)
Trainable params: 20 (80.00 Byte)
Non-trainable params: 0 (0.00 Byte)
_________________________________________________________________


In [145]:
# All variables of the model: kernel and bias for each of the two Dense layers.
model.weights

[<tf.Variable 'dense_21/kernel:0' shape=(3, 3) dtype=float32, numpy=
 array([[-0.8312981 ,  0.23459387, -0.7852793 ],
        [-0.82814074,  0.99332094,  0.7527175 ],
        [ 0.9581959 ,  0.8446417 ,  0.30326033]], dtype=float32)>,
 <tf.Variable 'dense_21/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
 <tf.Variable 'dense_22/kernel:0' shape=(3, 2) dtype=float32, numpy=
 array([[-0.0340203 , -0.41568005],
        [ 0.6868131 ,  0.9702604 ],
        [ 0.5082257 , -1.0260777 ]], dtype=float32)>,
 <tf.Variable 'dense_22/bias:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]

In [146]:
# 2 layers x (kernel + bias) = 4 weight variables.
len(model.weights)

4

In [147]:
# Before any layer is frozen, every model variable is trainable.
model.trainable_weights

[<tf.Variable 'dense_21/kernel:0' shape=(3, 3) dtype=float32, numpy=
 array([[-0.8312981 ,  0.23459387, -0.7852793 ],
        [-0.82814074,  0.99332094,  0.7527175 ],
        [ 0.9581959 ,  0.8446417 ,  0.30326033]], dtype=float32)>,
 <tf.Variable 'dense_21/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
 <tf.Variable 'dense_22/kernel:0' shape=(3, 2) dtype=float32, numpy=
 array([[-0.0340203 , -0.41568005],
        [ 0.6868131 ,  0.9702604 ],
        [ 0.5082257 , -1.0260777 ]], dtype=float32)>,
 <tf.Variable 'dense_22/bias:0' shape=(2,) dtype=float32, numpy=array([0., 0.], dtype=float32)>]

In [148]:
# Before any layer is frozen, the non-trainable list is empty.
model.non_trainable_weights


[]

In [149]:
# Baseline count of trainable variables, for comparison after freezing layer1.
len(model.trainable_weights)

4

In [150]:
# Freeze the first layer: its kernel and bias become non-trainable weights of the model.

layer1.trainable = False

In [151]:
# With layer1 frozen, its 12 parameters are reported as non-trainable; layer2's 8 stay trainable.
model.summary()

Model: "sequential_4"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 dense_21 (Dense)            (None, 3)                 12        
                                                                 
 dense_22 (Dense)            (None, 2)                 8         
                                                                 
Total params: 20 (80.00 Byte)
Trainable params: 8 (32.00 Byte)
Non-trainable params: 12 (48.00 Byte)
_________________________________________________________________


### Inspecting the model's weights after freezing the first layer

In [152]:
# Only layer2's kernel and bias remain trainable.
len(model.trainable_weights)

2

In [153]:
# The frozen layer1's kernel and bias now appear here.
model.non_trainable_weights

[<tf.Variable 'dense_21/kernel:0' shape=(3, 3) dtype=float32, numpy=
 array([[-0.8312981 ,  0.23459387, -0.7852793 ],
        [-0.82814074,  0.99332094,  0.7527175 ],
        [ 0.9581959 ,  0.8446417 ,  0.30326033]], dtype=float32)>,
 <tf.Variable 'dense_21/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>]

In [154]:
# layer1's kernel and bias -> 2 non-trainable variables.
len(model.non_trainable_weights)

2

In [155]:
# Train the model
# compile() is called after layer1 was frozen, so training runs with layer1 non-trainable.
model.compile(optimizer='adam', loss='mse', metrics=['mae'])

# Fit for 2 epochs on 10 random samples: inputs of shape (10, 3), targets of shape (10, 2).
# NOTE(review): no random seed is set in this cell, so the data differ between runs.
model.fit(tf.random.normal((10, 3)), 
          tf.random.normal((10, 2)), epochs=2)

Epoch 1/2
Epoch 2/2


<keras.src.callbacks.History at 0x738f8c146a10>

In [156]:
# Check that the weights of layer1 have not changed during training:
# the first layer's kernel and bias should match the values displayed before fit(),
# while the second (unfrozen) layer's kernel and bias are expected to differ.

model.weights

[<tf.Variable 'dense_21/kernel:0' shape=(3, 3) dtype=float32, numpy=
 array([[-0.8312981 ,  0.23459387, -0.7852793 ],
        [-0.82814074,  0.99332094,  0.7527175 ],
        [ 0.9581959 ,  0.8446417 ,  0.30326033]], dtype=float32)>,
 <tf.Variable 'dense_21/bias:0' shape=(3,) dtype=float32, numpy=array([0., 0., 0.], dtype=float32)>,
 <tf.Variable 'dense_22/kernel:0' shape=(3, 2) dtype=float32, numpy=
 array([[-0.03202101, -0.4176796 ],
        [ 0.6848131 ,  0.96826047],
        [ 0.50622576, -1.0280776 ]], dtype=float32)>,
 <tf.Variable 'dense_22/bias:0' shape=(2,) dtype=float32, numpy=array([-0.0019997 , -0.00199982], dtype=float32)>]