
In [6]:
from tensorflow import keras
import numpy as np

# Background
## Definition of transfer learning:
Transfer learning consists of taking features learned on one problem and leveraging them on a new, similar problem. For instance, features from a model that has learned to identify raccoons may be useful to kick-start a model meant to identify tanukis.

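In other words, what one model has learned is reused for another task, for example going from recognizing raccoons to recognizing tanukis.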

## Reason for transfer learning:
Transfer learning is usually done for tasks where your dataset has too little data to train a full-scale model from scratch.

## Most common incarnation of transfer learning (workflow):
1. Take layers from a previously trained model.
2. Freeze them, so as to avoid destroying any of the information they contain during future training rounds.
3. Add some new, trainable layers on top of the frozen layers. They will learn to turn the old features into predictions on a new dataset.
4. Train the new layers on your dataset.
5. __Fine-tuning__ (optional step): unfreeze the entire model you obtained above (or part of it) and re-train it on the new data with a very low learning rate. This can potentially achieve meaningful improvements by incrementally adapting the pretrained features to the new data.
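
The workflow above can be sketched in Keras roughly as follows. This is a minimal sketch under stated assumptions, not a definitive recipe: `base_model` is a hypothetical stand-in for a pretrained network (here a small, randomly initialized Dense layer so that nothing has to be downloaded), the data is random, and the learning rates are illustrative only.

```python
import numpy as np
from tensorflow import keras

# Step 1: take layers from a previously trained model.
# ASSUMPTION: `base_model` is a hypothetical stand-in; a real workflow
# would load a model that was actually trained on another dataset.
base_model = keras.Sequential(
    [keras.Input(shape=(4,)), keras.layers.Dense(8, activation="relu")]
)

# Step 2: freeze the base so training cannot overwrite its weights.
base_model.trainable = False

# Step 3: add a new, trainable head on top of the frozen base.
inputs = keras.Input(shape=(4,))
features = base_model(inputs, training=False)  # run the base in inference mode
outputs = keras.layers.Dense(1)(features)
model = keras.Model(inputs, outputs)

# Step 4: train only the new head on the new (here: random) dataset.
x = np.random.random((16, 4)).astype("float32")
y = np.random.random((16, 1)).astype("float32")
n_trainable_while_frozen = len(model.trainable_weights)
frozen_before = base_model.get_weights()
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-3), loss="mse")
model.fit(x, y, epochs=1, verbose=0)
frozen_after = base_model.get_weights()

# Step 5 (optional): unfreeze the base and fine-tune with a very low
# learning rate. Recompile so the change in `trainable` takes effect.
base_model.trainable = True
model.compile(optimizer=keras.optimizers.Adam(learning_rate=1e-5), loss="mse")
model.fit(x, y, epochs=1, verbose=0)

print("trainable weights while frozen:", n_trainable_while_frozen)
print("trainable weights while fine-tuning:", len(model.trainable_weights))
```

While the base is frozen, only the kernel and bias of the new head are trainable; after unfreezing, the weights of the base become trainable as well.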


# Understanding the `trainable` attribute
## example: Dense layer has 2 trainable weights (kernel & bias)

In [3]:
layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 2
trainable_weights: 2
non_trainable_weights: 0


## example: BatchNormalization layer has 2 trainable and 2 non-trainable weights

In [4]:
layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 4
trainable_weights: 2
non_trainable_weights: 2
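
To see which weights fall into which group, the variable names can be listed. The two trainable weights are the scale (gamma) and offset (beta), which are updated by backpropagation. The two non-trainable weights are the moving mean and moving variance, which the layer updates itself during training-mode forward passes rather than through gradient descent. A small sketch follows; the exact name strings may carry a prefix or suffix depending on the Keras version, so no literal output is shown.

```python
from tensorflow import keras

layer = keras.layers.BatchNormalization()
layer.build((None, 4))  # Create the weights

# Collect the variable names for each group of weights.
trainable_names = [w.name for w in layer.trainable_weights]
non_trainable_names = [w.name for w in layer.non_trainable_weights]

print("trainable:", trainable_names)
print("non-trainable:", non_trainable_names)
```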


## example: set trainable attribute to False:

In [5]:
layer = keras.layers.Dense(3)
layer.build((None, 4))  # Create the weights
layer.trainable = False  # Freeze the layer

print("weights:", len(layer.weights))
print("trainable_weights:", len(layer.trainable_weights))
print("non_trainable_weights:", len(layer.non_trainable_weights))

weights: 2
trainable_weights: 0
non_trainable_weights: 2


## example: when a trainable weight becomes non-trainable, its value is no longer updated during training

In [9]:
layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze the first layer
layer1.trainable = False

# Keep a copy of the weights of layer1 for later reference
initial_layer1_weights_values = layer1.get_weights()

# Train the model
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)))

# Check that the weights of layer1 have not changed during training
final_layer1_weights_values = layer1.get_weights()

# Assert that layer1's kernel and bias are identical before and after training
np.testing.assert_allclose(
    initial_layer1_weights_values[0], final_layer1_weights_values[0]
)
np.testing.assert_allclose(
    initial_layer1_weights_values[1], final_layer1_weights_values[1]
)
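
As a complementary check, the sketch below repeats the same setup and also compares the weights of the layer that was left trainable. It is an illustration on random data, so the size of the change is not meaningful; the point is only that the frozen layer stays fixed while the trainable one moves. The variable names `layer1_changed` and `layer2_changed` are introduced here for illustration.

```python
import numpy as np
from tensorflow import keras

layer1 = keras.layers.Dense(3, activation="relu")
layer2 = keras.layers.Dense(3, activation="sigmoid")
model = keras.Sequential([keras.Input(shape=(3,)), layer1, layer2])

# Freeze only the first layer; layer2 stays trainable
layer1.trainable = False

# Snapshot both layers before training
before1 = layer1.get_weights()
before2 = layer2.get_weights()

model.compile(optimizer="adam", loss="mse")
model.fit(np.random.random((2, 3)), np.random.random((2, 3)), verbose=0)

# True if any weight array of the layer differs after training
layer1_changed = any(
    not np.array_equal(b, a) for b, a in zip(before1, layer1.get_weights())
)
layer2_changed = any(
    not np.array_equal(b, a) for b, a in zip(before2, layer2.get_weights())
)

print("layer1 (frozen) changed:", layer1_changed)
print("layer2 (trainable) changed:", layer2_changed)
```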



## example: recursive setting of the trainable attribute

In [None]:
inner_model = keras.Sequential([
    keras.Input(shape=(3,)),
    keras.layers.Dense(3, activation="relu"),
    keras.layers.Dense(3, activation="relu"),
])

model = keras.Sequential([
    keras.Input(shape=(3,)),
    inner_model,
    keras.layers.Dense(3, activation="sigmoid"),
])

model.trainable = False  # Freeze the outer model

assert inner_model.trainable == False  # All layers in `model` are now frozen
assert inner_model.layers[0].trainable == False  # `trainable` is propagated recursively
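
The same recursive behavior can be used to unfreeze only part of a model, which is the "or part of it" case of the fine-tuning step. The sketch below is a minimal illustration under the same toy setup: it unfreezes the outer model and then re-freezes the nested `inner_model`, so that only the final Dense layer remains trainable.

```python
from tensorflow import keras

inner_model = keras.Sequential([
    keras.Input(shape=(3,)),
    keras.layers.Dense(3, activation="relu"),
    keras.layers.Dense(3, activation="relu"),
])

model = keras.Sequential([
    keras.Input(shape=(3,)),
    inner_model,
    keras.layers.Dense(3, activation="sigmoid"),
])

# Freeze everything: no trainable weights remain anywhere in `model`
model.trainable = False
assert len(model.trainable_weights) == 0

# Unfreeze everything, then re-freeze only the nested model
model.trainable = True
inner_model.trainable = False

# Only the kernel and bias of the last Dense layer are trainable now
assert len(model.trainable_weights) == 2
# The two inner Dense layers contribute 2 frozen weights each
assert len(model.non_trainable_weights) == 4
assert inner_model.layers[0].trainable == False
```

When training follows such a change, Keras guidance is to call `compile()` again so that the new `trainable` state is taken into account.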