
In [1]:
import tensorflow as tf  # needed for tf.GradientTape in the custom training loop
from tensorflow import keras
import numpy as np


# Typical transfer-learning workflows
## Standard workflow
1. Instantiate a base model and load pre-trained weights into it.
2. Freeze all layers in the base model by setting `trainable = False`.
3. Create a new model on top of the output of one (or several) layers from the base model.
4. Train your new model on your new dataset.

## Lightweight workflow
1. Instantiate a base model and load pre-trained weights into it.
2. Run your new dataset through it and record the output of one (or several) layers from the base model. This is called __feature extraction__.
3. Use that output as input data for a new, smaller model.
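The lightweight workflow can be sketched as follows. This is a minimal illustration, not the notebook's own code: it builds Xception with `weights=None` only so the snippet runs without downloading anything (in practice you would pass `weights="imagenet"`), and the 16 random images and labels are hypothetical toy data standing in for your new dataset. `pooling="avg"` makes the base model return one 2048-dimensional feature vector per image.

```python
import numpy as np
from tensorflow import keras

# Step 1: base model used as a fixed feature extractor.
# weights=None keeps this sketch offline; in practice pass weights="imagenet".
base_model = keras.applications.Xception(
    weights=None,
    input_shape=(150, 150, 3),
    include_top=False,
    pooling="avg",  # global-average-pool the last feature map -> (batch, 2048)
)

# Hypothetical toy data standing in for "your new dataset" (pixels in 0..255).
rng = np.random.default_rng(0)
images = rng.uniform(0, 255, size=(16, 150, 150, 3)).astype("float32")
labels = rng.integers(0, 2, size=(16, 1)).astype("float32")
images = keras.applications.xception.preprocess_input(images)  # scale to [-1, 1]

# Step 2: run the dataset through the base model ONCE and record its output.
features = base_model.predict(images, batch_size=8, verbose=0)

# Step 3: train a new, smaller model that takes the recorded features as input.
head = keras.Sequential([
    keras.Input(shape=features.shape[1:]),
    keras.layers.Dense(1),  # one logit for binary classification
])
head.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)
history = head.fit(features, labels, epochs=2, batch_size=8, verbose=0)
```

Because `features` is computed once, every epoch of `head.fit` reuses the same arrays instead of running Xception again, which is where the speed advantage comes from.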

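## Pros of the lightweight workflow
Faster and cheaper: the base model runs over the data only once, rather than once per epoch of training.
## Cons of the lightweight workflow
You cannot dynamically modify the input data of the new model during training, which is required for data augmentation.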
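To see what the lightweight workflow gives up, here is a sketch of the standard workflow with augmentation layers placed in front of the frozen base model. It assumes the Keras preprocessing layers `RandomFlip` and `RandomRotation`; the rotation factor 0.1 is an arbitrary choice, and `weights=None` again only avoids a download. Each training batch is transformed afresh before it reaches the base model, which is impossible once features have been precomputed.

```python
import numpy as np
from tensorflow import keras

# Augmentation layers: a new random flip/rotation is drawn for every batch.
data_augmentation = keras.Sequential([
    keras.layers.RandomFlip("horizontal"),
    keras.layers.RandomRotation(0.1),
])

base_model = keras.applications.Xception(
    weights=None,  # use "imagenet" in practice; None keeps the sketch offline
    input_shape=(150, 150, 3),
    include_top=False,
)
base_model.trainable = False

inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)        # augment raw images on the fly
x = base_model(x, training=False)    # frozen feature extractor
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
```

The augmentation layers are active only when the model is called with `training=True`, for example inside `fit()`; at inference they pass images through unchanged.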

In [None]:
# standard workflow

# Instantiate a base model and load pretrained weights
base_model = keras.applications.Xception(
    weights="imagenet",  # load weights pretrained on ImageNet
    input_shape=(150, 150, 3),
    include_top=False,  # drop the ImageNet classifier on top
)

# Freeze all weights in the base model
base_model.trainable = False

# Create a new small model on top
inputs = keras.Input(shape=(150, 150, 3))

# training=False keeps the base model in inference mode (e.g. BatchNormalization
# uses its stored statistics), even if the base model is later unfrozen
x = base_model(inputs, training=False)

x = keras.layers.GlobalAveragePooling2D()(x)

outputs = keras.layers.Dense(1)(x)  # one logit for binary classification

model = keras.Model(inputs, outputs)


# Compile the new model; only the Dense head is trainable
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
    metrics=[keras.metrics.BinaryAccuracy()],
)

# Train on your own data (new_dataset is a placeholder):
# model.fit(new_dataset, epochs=20, callbacks=..., validation_data=...)
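Setting `base_model.trainable = False` means `fit()` updates only the new head. A quick way to check this is sketched below; it swaps Xception for a small hypothetical stand-in base model (one `Conv2D` plus one `BatchNormalization`) and random toy data so that it runs in seconds. The `BatchNormalization` layer shows that its moving statistics also stay fixed, because the base model is called with `training=False`.

```python
import numpy as np
from tensorflow import keras

# Hypothetical small stand-in for the pretrained base model (keeps the check fast).
base_model = keras.Sequential([
    keras.Input(shape=(32, 32, 3)),
    keras.layers.Conv2D(8, 3, activation="relu"),
    keras.layers.BatchNormalization(),
])
base_model.trainable = False

inputs = keras.Input(shape=(32, 32, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
)

# Snapshot the frozen weights, train briefly on toy data, then compare.
before = [w.copy() for w in base_model.get_weights()]
rng = np.random.default_rng(0)
x_toy = rng.random((8, 32, 32, 3)).astype("float32")
y_toy = rng.integers(0, 2, size=(8, 1)).astype("float32")
model.fit(x_toy, y_toy, epochs=1, verbose=0)
after = base_model.get_weights()

frozen_unchanged = all(np.array_equal(b, a) for b, a in zip(before, after))
print(len(model.trainable_weights), frozen_unchanged)  # -> 2 True
```

Only the Dense kernel and bias are trainable, so the count is 2 and every base-model array, including the BatchNormalization moving mean and variance, is identical before and after training.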

## Transfer learning & fine-tuning with a custom training loop
If you use your own low-level training loop instead of `fit()`, the workflow stays essentially the same. Be careful to take into account only the list `model.trainable_weights` when applying gradient updates:

In [None]:
# Create base model
base_model = keras.applications.Xception(
    weights='imagenet',
    input_shape=(150, 150, 3),
    include_top=False)
# Freeze base model
base_model.trainable = False

# Create new model on top.
inputs = keras.Input(shape=(150, 150, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Iterate over the batches of a dataset.
# `new_dataset` is a placeholder for your own batched (images, labels) dataset.
for x_batch, y_batch in new_dataset:
    # Open a GradientTape to record the forward pass.
    with tf.GradientTape() as tape:
        # Forward pass (the base model still runs in inference mode).
        predictions = model(x_batch, training=True)
        # Compute the loss value for this batch.
        loss_value = loss_fn(y_batch, predictions)

    # Get gradients of the loss with respect to the *trainable* weights only.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    # Update the weights of the model.
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
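The loop above needs a concrete `new_dataset`. A self-contained sketch follows, assuming a hypothetical toy `tf.data.Dataset` of 24 random 32x32 images and a small stand-in base model instead of Xception. The loop body is wrapped in a `tf.function`, an optional change that compiles the step into a graph; `train_step` is a name introduced here for illustration. Only the Dense head's kernel and bias receive gradients, so the frozen weights come out unchanged.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical small stand-in for the frozen pretrained base model.
base_model = keras.Sequential([
    keras.Input(shape=(32, 32, 3)),
    keras.layers.Conv2D(8, 3, activation="relu"),
])
base_model.trainable = False

inputs = keras.Input(shape=(32, 32, 3))
x = base_model(inputs, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

loss_fn = keras.losses.BinaryCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

# Hypothetical toy dataset: 24 samples in batches of 8 -> 3 batches.
rng = np.random.default_rng(0)
images = rng.random((24, 32, 32, 3)).astype("float32")
labels = rng.integers(0, 2, size=(24, 1)).astype("float32")
new_dataset = tf.data.Dataset.from_tensor_slices((images, labels)).batch(8)

@tf.function  # optional: compile the training step into a graph
def train_step(x_batch, y_batch):
    with tf.GradientTape() as tape:
        predictions = model(x_batch, training=True)
        loss_value = loss_fn(y_batch, predictions)
    # Differentiate and update with respect to the trainable weights only.
    gradients = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(gradients, model.trainable_weights))
    return loss_value

frozen_before = [w.copy() for w in base_model.get_weights()]
losses = [float(train_step(x_b, y_b)) for x_b, y_b in new_dataset]
frozen_after = base_model.get_weights()
```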