In [53]:
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

import pathlib

## An ImageNet classifier

In [6]:
# TF Hub handle of the MobileNetV2 ImageNet classification model (TF2 preview, version 4).
classifier_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/4"

In [8]:
# Image resolution (height, width) in pixels; reused for the feature extractor's
# input shape and as the resize target of the image loader.
IMAGE_SHAPE = [224, 224]

# Append the channel axis (3 = RGB) to obtain the per-image input shape.
classifier_input_shape = IMAGE_SHAPE + [3]
classifier_layer = hub.KerasLayer(classifier_url, input_shape=classifier_input_shape)

# A one-layer Keras model wrapping the pretrained TF Hub classifier.
classifier = tf.keras.Sequential([classifier_layer])

## Simple transfer learning

In [21]:
# Download the flower photos archive and extract it next to the downloaded file.
archive_path = tf.keras.utils.get_file(
    'flower_photos.tgz',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    extract=True,
)
# NOTE(review): assumes get_file returns the path of the .tgz itself, so the extracted
# folder is its sibling; newer Keras releases may return the extraction directory
# instead — verify against the installed version.
data_root = pathlib.Path(archive_path).parent / 'flower_photos'

In [24]:
# Images per training batch yielded by the directory iterator.
BATCH_SIZE = 32

# Scale 8-bit pixel values by 1/255 into the [0, 1] range.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1 / 255)

# Labels come from the sub-directory names of data_root (one folder per class);
# every image is resized to IMAGE_SHAPE.
image_data = image_generator.flow_from_directory(
    data_root,
    target_size=IMAGE_SHAPE,
    batch_size=BATCH_SIZE,
)

Found 3670 images belonging to 5 classes.


__Run the classifier on a batch of images__

In [34]:
# image_data[0] is the first batch; element 0 holds its images (element 1 the labels).
first_images = image_data[0][0]
classifier.predict_on_batch(first_images).shape

(32, 1001)

### Download the headless model

In [35]:
# TF Hub handle of the headless MobileNetV2: yields a feature vector per image
# instead of ImageNet class scores.
feature_extractor_url = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"

# Same (height, width, RGB) input shape as the classifier above.
feature_input_shape = IMAGE_SHAPE + [3]
feature_extractor_layer = hub.KerasLayer(feature_extractor_url, input_shape=feature_input_shape)

__Run the feature extractor on a batch of images__

In [38]:
# Push the images of the first batch through the headless model and inspect the result shape.
sample_images = image_data[0][0]
sample_features = feature_extractor_layer(sample_images)
sample_features.shape

TensorShape([32, 1280])

__Freeze the feature-extractor layer and attach a new classification head for transfer learning__

In [41]:
# Keep the pretrained MobileNetV2 weights fixed; only the new head is trained.
feature_extractor_layer.trainable = False

num_classes = image_data.num_classes  # one output unit per flower class folder

model = tf.keras.Sequential()
model.add(feature_extractor_layer)
# No activation: the head emits raw logits (the loss is built with from_logits=True).
model.add(tf.keras.layers.Dense(num_classes))

### Train the transfer model

In [45]:
# Adam with its default settings.
adam_optimizer = tf.keras.optimizers.Adam()
# The Dense head outputs logits, so the loss applies the softmax itself.
categorical_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

# 'accuracy' is also the key CollectBatchStats reads from the batch logs.
model.compile(optimizer=adam_optimizer, loss=categorical_loss, metrics=['accuracy'])

In [63]:
class CollectBatchStats(tf.keras.callbacks.Callback):
    """Keras callback that records the loss and accuracy logged after every training batch.

    Attributes:
        batch_losses: ``logs['loss']`` values, one entry per training batch.
        batch_acc: ``logs['accuracy']`` values, one entry per training batch.
    """

    def __init__(self):
        super().__init__()
        self.batch_losses = []
        self.batch_acc = []

    def on_train_batch_end(self, batch, logs=None):
        """Append this batch's loss and accuracy, then reset the model's metrics.

        ``logs`` must provide ``'loss'`` and ``'accuracy'`` entries; despite the
        ``None`` default, calling without a logs dict raises ``TypeError``.
        """
        for store, key in ((self.batch_losses, 'loss'), (self.batch_acc, 'accuracy')):
            store.append(logs[key])
        # Presumably reset so the next batch's logged values describe that batch
        # alone rather than a running epoch average — confirm for the Keras version.
        self.model.reset_metrics()

In [64]:
# Batches per epoch: round up so the final, partial batch is still visited.
# np.ceil returns a float, but Keras expects an integer step count, hence int().
steps_per_epoch = int(np.ceil(image_data.samples / image_data.batch_size))

# Collects per-batch loss/accuracy during model.fit.
batch_stats_callback = CollectBatchStats()

In [65]:
# Train only the new classification head for two epochs. batch_stats_callback records
# per-batch metrics. Keep the returned History (epoch-level metrics) instead of
# displaying its repr as the cell output.
history = model.fit(
    image_data,
    epochs=2,
    steps_per_epoch=steps_per_epoch,
    callbacks=[batch_stats_callback],
)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x7fb241060b50>