In [45]:
import tensorflow as tf

import numpy as np
import matplotlib.pyplot as plt

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

%matplotlib inline
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


## data download

In [46]:
# Split specification handed to `tfds.load` below: two sub-splits of TRAIN with weights 70 and 30.
# NOTE(review): `subsplit` is the legacy TFDS split API; newer TFDS releases appear to
# expect slicing strings such as 'train[:70%]' / 'train[70%:]' instead — confirm against
# the installed TFDS version before re-running on a fresh environment.
splits = tfds.Split.TRAIN.subsplit([70, 30])

In [47]:
# Load both sub-splits as (image, label) pairs, plus the dataset metadata.
(training_set, validation_set), dataset_info = tfds.load(
    'tf_flowers',
    split=splits,
    as_supervised=True,
    with_info=True,
)

In [48]:
# Number of label classes, read from the dataset metadata; it sizes the model's output layer.
num_classes = dataset_info.features['label'].num_classes

In [49]:
num_classes

5

In [50]:
# Count the examples by streaming through each split once. Unlike `len(list(...))`,
# this never keeps every decoded image in memory at the same time.
num_training_examples = sum(1 for _ in training_set)
num_validation_examples = sum(1 for _ in validation_set)

In [51]:
num_training_examples, num_validation_examples

(2590, 1080)

In [52]:
dataset_info.splits['train'].num_examples

3670

In [53]:
# Sanity check computed from the variables (not hard-coded numbers, which go stale):
# the two splits should add up to the full dataset, and the validation share is
# almost 30% but not quite.
(num_training_examples + num_validation_examples,
 num_validation_examples / dataset_info.splits['train'].num_examples)

(3670, 0.29427792915531337)

In [54]:
# Peek at the first five training examples: raw image shapes vary, labels are integer ids.
for index, (image, label) in enumerate(training_set.take(5), start=1):
    print('Image {} shape: {} label: {}'.format(index, image.shape, label))

Image 1 shape: (240, 320, 3) label: 4
Image 2 shape: (375, 500, 3) label: 1
Image 3 shape: (333, 500, 3) label: 1
Image 4 shape: (227, 320, 3) label: 4
Image 5 shape: (246, 320, 3) label: 1


## data preprocessing

In [55]:
# Side length in pixels: images are resized to IMAGE_RES x IMAGE_RES, and the same
# value is used for the feature extractor's input shape.
IMAGE_RES = 224

In [56]:
def format_image(image, label):
    """Resize `image` to IMAGE_RES x IMAGE_RES and divide its values by 255.0.

    The label is returned unchanged, so the function can be passed directly to
    `Dataset.map` on a dataset of (image, label) pairs.
    """
    resized = tf.image.resize(image, (IMAGE_RES, IMAGE_RES))
    return resized / 255.0, label

In [57]:
# Number of examples per batch in both the training and validation pipelines.
BATCH_SIZE = 32

In [58]:
# Training pipeline: shuffle with a buffer of a quarter of the training set,
# preprocess each image, batch, and prefetch one batch ahead.
train_batches = (
    training_set
    .shuffle(num_training_examples // 4)
    .map(format_image)
    .batch(BATCH_SIZE)
    .prefetch(1)
)

In [59]:
# Validation pipeline: same preprocessing and batching as training, but no shuffling.
validation_batches = (
    validation_set
    .map(format_image)
    .batch(BATCH_SIZE)
    .prefetch(1)
)

## transfer learning

### train model

In [60]:
# TF Hub handle of the MobileNet V2 feature-vector module used as the frozen base model.
URL = "https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/4"

In [61]:
# Wrap the hub module as a Keras layer; trainable=False keeps its weights frozen
# so only the layers stacked on top of it are trained.
feature_extractor = hub.KerasLayer(
    URL,
    input_shape=(IMAGE_RES, IMAGE_RES, 3),
    trainable=False,
)

In [65]:
# Frozen feature extractor followed by a softmax layer with one unit per class.
model = tf.keras.Sequential()
model.add(feature_extractor)
model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))

In [66]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
keras_layer (KerasLayer)     (None, 1280)              2257984   
_________________________________________________________________
dense (Dense)                (None, 5)                 6405      
Total params: 2,264,389
Trainable params: 6,405
Non-trainable params: 2,257,984
_________________________________________________________________


In [67]:
# Number of training epochs passed to `model.fit`.
EPOCHS = 6

In [68]:
# The labels are integer class ids (not one-hot vectors), hence the sparse
# variant of categorical cross-entropy.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
# `train_batches` is a tf.data.Dataset that is already batched with BATCH_SIZE, so
# `batch_size` must not be passed here: Keras raises a ValueError when `batch_size`
# is given together with a dataset input.
history = model.fit(
    train_batches,
    epochs=EPOCHS,
    validation_data=validation_batches,
)

### plot metrics

In [None]:
# Per-epoch metrics recorded by `model.fit`.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

# Derive the x-axis from the recorded history rather than from the EPOCHS constant,
# so the curves still plot correctly if EPOCHS was edited after training finished.
epochs_range = range(len(acc))

fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(15, 10))

# Left panel: accuracy curves.
ax_acc.plot(epochs_range, acc, label='Training Accuracy')
ax_acc.plot(epochs_range, val_acc, label='Validation Accuracy')
ax_acc.set_xlabel('Epoch')
ax_acc.set_ylabel('Accuracy')
ax_acc.legend(loc='lower right')
ax_acc.set_title('Training and Validation Accuracy')

# Right panel: loss curves.
ax_loss.plot(epochs_range, loss, label='Training Loss')
ax_loss.plot(epochs_range, val_loss, label='Validation Loss')
ax_loss.set_xlabel('Epoch')
ax_loss.set_ylabel('Loss')
ax_loss.legend(loc='upper right')
ax_loss.set_title('Training and Validation Loss')
plt.show()

### check predictions

In [None]:
# Label names as a NumPy array so they can be indexed with an array of predicted class ids.
class_names = np.array(dataset_info.features['label'].names)

In [None]:
class_names

In [None]:
# Pull a single batch from the training pipeline to inspect predictions on.
image_batch, label_batch = next(iter(train_batches))

image_batch = image_batch.numpy()
label_batch = label_batch.numpy()

# `model.predict` already returns a NumPy array of shape (batch, num_classes), so no
# tensor round-trip is needed. Squeezing it would also drop the batch axis for a
# one-image batch and break the per-example indexing below.
predicted_batch = model.predict(image_batch)

# Most probable class id per example, then its human-readable name.
predicted_ids = np.argmax(predicted_batch, axis=-1)
predicted_class_names = class_names[predicted_ids]

In [None]:
predicted_class_names

In [None]:
# Print true and predicted class ids on aligned rows for a quick visual comparison.
print("Labels:           ", label_batch)
print("Predicted labels: ", predicted_ids)

In [None]:
# Show up to 30 images in a 6 x 5 grid. Capping at the batch length avoids an
# IndexError when the batch holds fewer than 30 images (e.g. a smaller BATCH_SIZE).
num_shown = min(30, len(image_batch))

plt.figure(figsize=(10,9))
for n in range(num_shown):
    plt.subplot(6,5,n+1)
    plt.imshow(image_batch[n])
    # Title colour encodes correctness of the prediction.
    color = "blue" if predicted_ids[n] == label_batch[n] else "red"
    plt.title(predicted_class_names[n].title(), color=color)
    plt.axis('off')
# Spacing only needs to be set once for the whole figure, not once per subplot.
plt.subplots_adjust(hspace = 0.3)
_ = plt.suptitle("Model predictions (blue: correct, red: incorrect)")