# TensorFlow: Image classification with transfer learning using a pre-trained model from TensorFlow Hub

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt

In [None]:
# Must be set before TensorFlow is imported so its C++ logger honours it
# ("1" hides INFO-level messages).
os.environ.update(TF_CPP_MIN_LOG_LEVEL="1")

import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_hub as hub
import tf_keras as keras
from tensorboard import program

# Report the TensorFlow version, whether eager execution is on, and GPU visibility.
print("TF Version: ", tf.__version__)
print("TF Eager mode: ", tf.executing_eagerly())
print("TF GPU is", "not available" if not tf.config.list_physical_devices("GPU") else "available")

__Variable definitions__

In [None]:
# Target (height, width) every image is resized to before entering the model.
IMAGE_SIZE = (224, 224)
# Number of examples grouped into each training / validation batch.
BATCH_SIZE = 32

# Prepare dataset

In [None]:
# tf_flowers ships a single "train" split: carve it 80/20 into training and
# validation subsets, returning (image, label) pairs plus dataset metadata.
(raw_train_ds, raw_val_ds), ds_info = tfds.load(
    "tf_flowers",
    split=["train[:80%]", "train[80%:]"],
    as_supervised=True,
    with_info=True,
)

In [None]:
# Declare pre-processing function
def preprocess_image(image, label):
    """Resize an image to IMAGE_SIZE and scale its pixel values by 1/255.

    Args:
        image: Image tensor yielded by the dataset. Presumably uint8 pixels
            in [0, 255], which makes the output lie in [0, 1] — TODO confirm.
        label: Class label; passed through untouched.

    Returns:
        ``(image, label)``, where ``image`` is float32 with spatial size
        IMAGE_SIZE.
    """
    resized = tf.image.resize(image, size=IMAGE_SIZE)
    scaled = tf.cast(resized, tf.float32) / 255.0
    return scaled, label

In [None]:
BUFFER_SIZE = 10000

# Training pipeline. cache() must come BEFORE shuffle(): a cache placed after
# the shuffle stores the first epoch's random order and replays that exact
# order every epoch. With the cache first, shuffle() (reshuffle_each_iteration
# defaults to True) draws a fresh order each epoch, while the expensive resize
# still runs only once.
# NOTE: cache() without a filename keeps the preprocessed float32 images in
# memory.
train_ds = (raw_train_ds
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE))

# Validation pipeline: same preprocessing, deterministic order (no shuffle).
val_ds = (raw_val_ds
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .cache()
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE))

In [None]:
# Human-readable label names from the dataset metadata; entry i names label i.
class_names = ds_info.features["label"].names
print("Class names: {}".format(class_names))

# Create model

Download a pre-trained MobileNet V2 feature-vector model (the image classifier without its classification head) from TensorFlow Hub, now hosted on Kaggle Models.

In [None]:
# Kaggle Models download URL of the TF2 MobileNet V2 feature-vector model
# (version 4); handed to hub.KerasLayer below.
mobilenet_v2 = (
    "https://www.kaggle.com/api/v1/models/google/mobilenet-v2/"
    "tensorFlow2/tf2-preview-feature-vector/4/download"
)

In [None]:
# Full per-example input shape: (height, width, 3 colour channels).
IMAGE_SHAPE = (*IMAGE_SIZE, 3)

# Expose the downloaded model as a Keras layer. trainable=False freezes its
# weights, so only the layers stacked on top of it are updated in training.
feature_extractor_layer = hub.KerasLayer(
    mobilenet_v2,
    trainable=False,
    input_shape=IMAGE_SHAPE,
)

In [None]:
num_classes = len(class_names)
# Report the class count. The original message interpolated class_names here,
# which contradicted the wording of the message.
print(f"The number of classes to predict: {num_classes}")

# Frozen feature extractor followed by a trainable dense head. The Dense layer
# has no activation, so it emits one raw logit per class.
model = keras.Sequential([
    feature_extractor_layer,
    keras.layers.Dense(num_classes),
])

In [None]:
# Adam with default hyper-parameters; sparse categorical cross-entropy over
# integer labels, computed directly from the head's un-normalised logits.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Train model

In [None]:
LOGS_DIR = os.path.join("logs", "tf_transfer_learning_hub")

callbacks = [
    # Stop when the monitored metric (EarlyStopping's default monitor) fails
    # to improve by at least 0.01 for 5 consecutive epochs; when it stops, it
    # rolls the model back to the best weights it saw.
    keras.callbacks.EarlyStopping(
        min_delta=1e-2,
        patience=5,
        verbose=1,
        restore_best_weights=True,
    ),
    # Log scalars once per epoch into LOGS_DIR; histogram and embedding
    # logging are switched off.
    keras.callbacks.TensorBoard(
        log_dir=LOGS_DIR,
        update_freq="epoch",
        histogram_freq=0,
        embeddings_freq=0,
    ),
]

# Train for at most 10 epochs; EarlyStopping may end the run sooner.
history = model.fit(train_ds, epochs=10, validation_data=val_ds, callbacks=callbacks)

In [None]:
# Launch an in-process TensorBoard server over LOGS_DIR (with the --load_fast
# loader disabled) and print the URL it serves on.
tb = program.TensorBoard()
tb.configure(argv=[None, "--load_fast", "false", "--logdir", LOGS_DIR])
url = tb.launch()
print("TensorBoard listening on {}".format(url))

__Evaluate model__

In [None]:
acc, val_acc = history.history["accuracy"], history.history["val_accuracy"]
loss, val_loss = history.history["loss"], history.history["val_loss"]


def _plot_metric_panel(axis, train_values, valid_values, train_label, valid_label, title, ylabel):
    """Draw training (dots) vs. validation (solid line) curves on one axis."""
    axis.plot(train_values, "bo", label=train_label)
    axis.plot(valid_values, "b", label=valid_label)
    axis.set_title(title)
    axis.set_xlabel("epochs")
    axis.set_ylabel(ylabel)
    axis.legend()


fig, ax = plt.subplots(1, 2, figsize=(12, 6))
# Left panel: accuracy per epoch; right panel: loss per epoch.
_plot_metric_panel(ax[0], acc, val_acc, "Training accuracy", "Validation accuracy",
                   "Training and validation accuracy", "accuracy")
_plot_metric_panel(ax[1], loss, val_loss, "Training Loss", "Validation Loss",
                   "Training and validation loss", "loss")

plt.show()

__Evaluate predictions__

In [None]:
# Pull the first validation batch as NumPy arrays and get the model's
# per-class logits for it.
image_batch, label_batch = next(iter(val_ds.take(1).as_numpy_iterator()))
predictions_batch = model.predict(x=image_batch)

In [None]:
# Convert to an array so the names can be gathered by an index vector.
class_names = np.array(class_names)
# Highest-scoring class index per example, then its human-readable name.
ids = tf.math.argmax(predictions_batch, axis=-1)
predictions_labels = class_names[ids]

In [None]:
plt.figure(figsize=(10, 9))
plt.subplots_adjust(hspace=0.5)

# Show up to 30 validation images in a 6x5 grid, each titled with the model's
# predicted class. min() guards against a batch smaller than the grid (e.g. if
# BATCH_SIZE is lowered below 30); the original range(30) would IndexError.
for n in range(min(30, len(image_batch))):
    plt.subplot(6, 5, n + 1)
    plt.imshow(image_batch[n])
    plt.title(predictions_labels[n].title())
    plt.axis("off")

# Render explicitly, matching the training-curve cell (needed outside notebooks).
plt.show()
