# Pretrained Models for Transfer Learning
- If there is not enough training data, it is a good idea to reuse the lower layers of a pretrained model (transfer learning).
- Example: an Xception model pretrained on ImageNet, reused to classify the tf_flowers dataset:

In [12]:
import tensorflow_datasets as tfds

# tf_flowers provides only a "train" split (per the TFDS catalog), so slice it
# into 75% train / 15% validation / 10% test subsets.
(train_set, valid_set, test_set), info = tfds.load(
    'tf_flowers',
    split=['train[:75%]', 'train[75%:90%]', 'train[90%:]'],
    as_supervised=True,  # yield (image, label) tuples instead of feature dicts
    with_info=True,      # also return dataset metadata (splits, label names, ...)
)
dataset_size = info.splits["train"].num_examples  # 3670
class_names = info.features["label"].names  # e.g. ["dandelion", "daisy", ...]
class_anems = class_names  # backward-compatible alias for the original misspelled name
n_classes = info.features["label"].num_classes  # 5

In [13]:
def preprocess(image, label):
    """Resize an image to 224x224 and apply Xception's input preprocessing.

    Intended for use with ``tf.data.Dataset.map`` on (image, label) pairs.

    Args:
        image: image tensor taken from the dataset.
        label: class label; returned unchanged.

    Returns:
        Tuple ``(preprocessed_image, label)``.
    """
    target_size = [224, 224]
    image_224 = tf.image.resize(image, target_size)
    # Xception-specific normalization from keras.applications.
    return keras.applications.xception.preprocess_input(image_224), label

In [17]:
import tensorflow as tf
from tensorflow import keras

batch_size = 32
shuffle_buffer_size = 1000  # number of examples buffered when shuffling the training split
shuffle_seed = 42  # fixed seed so the training shuffle order is reproducible across runs


def _make_input_pipeline(dataset, shuffle=False):
    """Turn an (image, label) dataset into preprocessed, batched, prefetched batches.

    Args:
        dataset: a tf.data.Dataset of (image, label) pairs.
        shuffle: if True, shuffle the examples (before preprocessing) using
            ``shuffle_buffer_size`` and ``shuffle_seed``.

    Returns:
        The transformed tf.data.Dataset.
    """
    if shuffle:
        dataset = dataset.shuffle(shuffle_buffer_size, seed=shuffle_seed)
    # prefetch(1) lets tf.data prepare the next batch while the current one is
    # being consumed by the model (it overlaps work; it does not cache the data).
    return dataset.map(preprocess).batch(batch_size).prefetch(1)


# NOTE: this cell overwrites train_set / valid_set / test_set in place, so
# re-running it would preprocess and batch the data a second time. Re-run the
# tfds.load cell first.
train_set = _make_input_pipeline(train_set, shuffle=True)
valid_set = _make_input_pipeline(valid_set)
test_set = _make_input_pipeline(test_set)

In [19]:
# Load an Xception convolutional base pretrained on ImageNet. include_top=False
# drops Xception's own classification head (its global average pooling + Dense
# layers) so a new head can be attached for our n_classes classes.
base_model = keras.applications.xception.Xception(include_top=False, weights="imagenet")

# New head: global average pooling over the base model's output, then a softmax
# classifier with one unit per class.
avg = keras.layers.GlobalAveragePooling2D()(base_model.output)
output = keras.layers.Dense(units=n_classes, activation="softmax")(avg)

# Build the full model from the base model's input tensor to the new output, so
# the base model's layers are used directly rather than wrapping base_model itself.
model = keras.Model(inputs=base_model.input, outputs=output)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5


In [None]:
# Freeze the weights of the pretrained Xception layers so that only the new
# pooling + Dense head remains trainable.
# NOTE: per the Keras docs, changes to `trainable` only take effect once the
# model is (re)compiled.
for layer in base_model.layers:
    layer.trainable = False

# Sanity check: with the base frozen, the only trainable weights left should be
# the new Dense layer's kernel and bias (global average pooling has no weights).
assert len(model.trainable_weights) == 2, "expected only the Dense head to be trainable"