In [None]:
# Setup library
## install -r requirements.txt
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import time

import matplotlib.pylab as plt
# %matplotlib widget
%matplotlib inline
import PIL.Image as Image
import numpy as np
import pandas as pd

import tensorflow as tf
import tensorflow_hub as hub
from tensorflow.keras import layers

In [None]:
# Root directory of the image dataset; it is handed to flow_from_directory when the generators are built.
IMG_DATA = './dataset_growth_100'
# Spatial size of the images; used as target_size for the generators and,
# with 3 channels appended, as input_shape for the TF Hub layers.
IMG_SHAPE = (224, 224)

In [None]:
# Classifier from TF hub
## Wrap the TF Hub classification module in a one-layer Keras model.
classifier_url = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/classification/2' #@param {type:"string"}
classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_url, input_shape=IMG_SHAPE+(3, )) # Channel 3 RGB
])

## And labels
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt',
                                      'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
## Read the label file inside a context manager so the file handle is closed
## as soon as the contents are loaded (one label per line).
with open(labels_path) as labels_file:
    imagenet_labels = np.array(labels_file.read().splitlines())

In [None]:
# prepare dataset
## Resolve the configured dataset directory to an absolute path (expanding a leading ~).
dataset_root = os.path.abspath(os.path.expanduser(IMG_DATA))
print(f'Dataset root: {dataset_root}')

## A single generator serves both subsets: pixel values are multiplied by 1/255
## and 20% of the files are reserved for the validation subset.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1/255,
    validation_split=0.2,
)

## Options shared by the training and validation iterators.
common_flow_kwargs = dict(target_size=IMG_SHAPE, follow_links=True)
train_data = image_generator.flow_from_directory(
    dataset_root, subset='training', **common_flow_kwargs)
## shuffle=False keeps the validation order fixed.
validation_data = image_generator.flow_from_directory(
    dataset_root, subset='validation', shuffle=False, **common_flow_kwargs)

## Pull the first validation batch; later cells reuse it as a sample input.
image_batch, label_batch = next(iter(validation_data))
print(f'Image batch shape: {image_batch.shape}')
print(f'Label batch shape: {label_batch.shape}')

In [None]:
# Predict batch input example
## using original ImageNet classifier
result_batch = classifier.predict(image_batch)
print(f'Batch result shape: {result_batch.shape}')

## The highest-scoring index of each row selects the label to display.
predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]
print(f'Batch predicted class names: {predicted_class_names}')

## The grid has 6x5 = 30 slots; cap the loop at the batch size so a batch
## with fewer than 30 images does not raise an IndexError.
n_shown = min(30, len(image_batch))

fig1 = plt.figure(figsize=(10, 9))
fig1.subplots_adjust(hspace=0.5)
for n in range(n_shown):
    ax = fig1.add_subplot(6, 5, n+1)
    ax.imshow(image_batch[n])
    ax.set_title(predicted_class_names[n])
    ax.axis('off')
_ = fig1.suptitle('ImageNet predictions')

In [None]:
# Prepare transfer learning
## Download headless (without the top classification layer) model
feature_extractor_url = 'https://tfhub.dev/google/tf2-preview/mobilenet_v2/feature_vector/2' #@param {type:"string"}
feature_extractor_layer = hub.KerasLayer(
    feature_extractor_url,
    input_shape=IMG_SHAPE+(3, ),
)

## Run the sample batch through the extractor to inspect the feature shape.
feature_batch = feature_extractor_layer(image_batch)
print(f'Feature vector shape: {feature_batch.shape}')

## Frozen feature extraction layer
feature_extractor_layer.trainable = False # for transfer learning classifier

## Make a model for classification:
## the frozen extractor followed by a softmax head with one unit per dataset class.
model = tf.keras.Sequential()
model.add(feature_extractor_layer)
model.add(layers.Dense(train_data.num_classes, activation='softmax'))

In [None]:
## Check the model and prediction result
### Print the layer summary, then call the model directly on the sample batch
### to confirm the output shape.
model.summary()

predictions = model(image_batch)
prediction_shape = predictions.shape
print(f'Prediction shape: {prediction_shape}')

In [None]:
# Train build
## Compile model for train
## The classification head is a Dense layer with softmax activation, so the
## model emits probabilities. The loss must therefore NOT treat them as
## logits (from_logits=False); this also matches the fine-tuning compile step.
base_learning_rate = 0.001 # default
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
    loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
    metrics=['accuracy'])

## Log class
### https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback
class CollectBatchStats(tf.keras.callbacks.Callback):
    """Collect loss and accuracy values reported at the end of every epoch.

    Despite the ``batch_`` prefix of the attribute names, values are appended
    once per epoch (in ``on_epoch_end``), not once per batch. The lists are
    never cleared, so reusing one instance across several ``fit`` calls
    yields one continuous history.

    Attributes:
        batch_losses: training ``loss`` per epoch.
        batch_val_losses: ``val_loss`` per epoch.
        batch_acc: training ``accuracy`` per epoch.
        batch_val_acc: ``val_accuracy`` per epoch.
    """

    def __init__(self):
        # Let the base Callback initialise its own internal state first.
        super().__init__()
        self.batch_losses = []
        self.batch_val_losses = []
        self.batch_acc = []
        self.batch_val_acc = []

    def on_epoch_end(self, epoch, logs=None):
        """Append the metrics found in ``logs`` to the matching history list.

        Args:
            epoch: index of the epoch that just finished (unused).
            logs: mapping of metric name to value; may be ``None``.
        """
        logs = logs or {}
        # Record only the metrics that are actually present, so a run without
        # validation data (no 'val_*' keys) does not raise a KeyError.
        for key, store in (('loss', self.batch_losses),
                           ('accuracy', self.batch_acc),
                           ('val_loss', self.batch_val_losses),
                           ('val_accuracy', self.batch_val_acc)):
            if key in logs:
                store.append(logs[key])
        self.model.reset_metrics()

In [None]:
# Number of batches needed to see every training sample once per epoch.
# np.ceil returns a float, so convert to int: a step count is a whole number.
steps_per_epoch = int(np.ceil(train_data.samples / train_data.batch_size))
initial_epoch = 25
batch_stats_callback = CollectBatchStats()

# Train the classification head, recording per-epoch metrics through the callback.
history = model.fit(train_data,
                    epochs=initial_epoch,
                    steps_per_epoch=steps_per_epoch,
                    validation_data=validation_data,
                    callbacks=[batch_stats_callback])

In [None]:
# Draw learning curves chart
## Per-epoch histories recorded by the callback (these names reference the
## callback's own lists, not copies).
acc = batch_stats_callback.batch_acc
val_acc = batch_stats_callback.batch_val_acc
loss = batch_stats_callback.batch_losses
val_loss = batch_stats_callback.batch_val_losses

## Two stacked panels: accuracy on top, loss below.
fig2, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8))

ax1.plot(acc, label='Training Accuracy')
ax1.plot(val_acc, label='Validation Accuracy')
ax1.legend(loc='lower right')
## Keep the automatically chosen lower bound, but pin the top of the axis at 1.
ax1.set_ylim([min(ax1.get_ylim()), 1])
ax1.set(ylabel='Accuracy', title='Training and Validation Accuracy')

ax2.plot(loss, label='Training Loss')
ax2.plot(val_loss, label='Validation Loss')
ax2.legend(loc='upper right')
ax2.set_ylim([0, 1.0])
ax2.set(ylabel='Cross Entropy',
        title='Training and Validation Loss',
        xlabel='epoch')

In [None]:
# Plot results
## Order the class names by their integer index so class_names[i] is the
## display name of class id i.
class_names = sorted(validation_data.class_indices.items(), key=lambda pair:pair[1])
class_names = np.array([key.title() for key, value in class_names])
print(f'Classes: {class_names}')

## get result labels
predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]

## label_batch is reduced with argmax to recover the true class id of each image.
label_id = np.argmax(label_batch, axis=-1)

## plot
## The grid has 6x5 = 30 slots; cap the loop at the batch size so a batch
## with fewer than 30 images does not raise an IndexError.
n_shown = min(30, len(image_batch))

fig3 = plt.figure(figsize=(10,9))
fig3.subplots_adjust(hspace=0.5)
for n in range(n_shown):
    ax = fig3.add_subplot(6, 5, n+1)
    ax.imshow(image_batch[n])
    color = 'green' if predicted_id[n] == label_id[n] else 'red'
    # class_names entries are already title-cased, so no second .title() is needed.
    ax.set_title(predicted_label_batch[n], color=color)
    ax.axis('off')
_ = fig3.suptitle('Model predictions (green: correct, red: incorrect)')

In [None]:
# Export model
## The directory name combines a fixed prefix with the current Unix time
## (whole seconds), so separate runs get separate directories.
t = time.time()
# To choose the prefix interactively instead: prefix = input('model prefix: ').strip()
prefix = '100_tf'
export_dir_name = f'{prefix}-{int(t)}'
export_path = f'./saved_models/{export_dir_name}'
model.save(export_path, save_format='tf')

print(f'Export the model to {export_path}')

In [None]:
# Check the exported model
## Load the model back from disk and run both copies on the same sample batch.
reloaded = tf.keras.models.load_model(export_path)

result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)

## Largest element-wise absolute difference between the two outputs.
max_abs_diff = np.abs(reloaded_result_batch - result_batch).max()
print(f'Comparison between own model and exported model {max_abs_diff}')

In [None]:
## Unfrozen feature extraction layer
feature_extractor_layer.trainable = True

## Recompile after changing `trainable`, using one tenth of the base learning rate.
fine_tune_optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate/10)
fine_tune_loss = tf.keras.losses.CategoricalCrossentropy()
model.compile(optimizer=fine_tune_optimizer,
              loss=fine_tune_loss,
              metrics=['accuracy'])
model.summary()

In [None]:
finetune_epoch = 25

# Continue training from where the first fit() stopped: epoch numbering
# resumes at `initial_epoch` and runs for `finetune_epoch` more epochs.
# steps_per_epoch is cast to int because it may have been computed with
# np.ceil, which returns a float, and a step count is a whole number.
history_fine = model.fit(train_data,
                         epochs=initial_epoch + finetune_epoch,
                         initial_epoch=initial_epoch, # == history.epoch[-1]+1
                         steps_per_epoch=int(steps_per_epoch),
                         validation_data=validation_data,
                         callbacks=[batch_stats_callback])

In [None]:
# Draw learning curves chart
## The same callback instance is passed to both fit() calls and only ever
## appends, so these lists cover the initial training and the fine-tuning epochs.
fine_acc = batch_stats_callback.batch_acc
fine_val_acc = batch_stats_callback.batch_val_acc
fine_loss = batch_stats_callback.batch_losses
fine_val_loss = batch_stats_callback.batch_val_losses

fig4 = plt.figure(figsize=(8, 8))
ax1 = fig4.add_subplot(2, 1, 1)
## Plot the fine_* series explicitly rather than relying on names from an
## earlier cell that merely alias the callback lists.
ax1.plot(fine_acc, label='Training Accuracy')
ax1.plot(fine_val_acc, label='Validation Accuracy')
ax1.set_ylabel('Accuracy')
## Query the axes directly instead of the pyplot "current axes" state.
ax1.set_ylim([min(ax1.get_ylim()), 1])
## Vertical marker at the epoch where fine-tuning began.
ax1.plot([initial_epoch, initial_epoch],
         ax1.get_ylim(), label='Start Fine Tuning')
ax1.legend(loc='lower right')
ax1.set_title('Training and Validation Accuracy')

ax2 = fig4.add_subplot(2, 1, 2)
ax2.plot(fine_loss, label='Training Loss')
ax2.plot(fine_val_loss, label='Validation Loss')
ax2.set_ylabel('Cross Entropy')
ax2.set_ylim([0, 1.0])
ax2.plot([initial_epoch, initial_epoch],
         ax2.get_ylim(), label='Start Fine Tuning')
ax2.legend(loc='upper right')
ax2.set_title('Training and Validation Loss')
ax2.set_xlabel('epoch')

In [None]:
# Plot results
## Order the class names by their integer index so class_names[i] is the
## display name of class id i.
class_names = sorted(validation_data.class_indices.items(), key=lambda pair:pair[1])
class_names = np.array([key.title() for key, value in class_names])
print(f'Classes: {class_names}')

## get result labels
predicted_batch = model.predict(image_batch)
predicted_id = np.argmax(predicted_batch, axis=-1)
predicted_label_batch = class_names[predicted_id]

## label_batch is reduced with argmax to recover the true class id of each image.
label_id = np.argmax(label_batch, axis=-1)

## plot
## The grid has 6x5 = 30 slots; cap the loop at the batch size so a batch
## with fewer than 30 images does not raise an IndexError.
n_shown = min(30, len(image_batch))

fig5 = plt.figure(figsize=(10,9))
fig5.subplots_adjust(hspace=0.5)
for n in range(n_shown):
    ax = fig5.add_subplot(6, 5, n+1)
    ax.imshow(image_batch[n])
    color = 'green' if predicted_id[n] == label_id[n] else 'red'
    # class_names entries are already title-cased, so no second .title() is needed.
    ax.set_title(predicted_label_batch[n], color=color)
    ax.axis('off')
_ = fig5.suptitle('Model predictions (green: correct, red: incorrect)')

In [None]:
# Export model too
## The directory name combines a fixed prefix with the current Unix time
## (whole seconds), so separate runs get separate directories.
t = time.time()
# To choose the prefix interactively instead: prefix = input('model prefix: ').strip()
prefix = '100_tffinetune'
export_dir_name = f'{prefix}-{int(t)}'
export_path = f'./saved_models/{export_dir_name}'
model.save(export_path, save_format='tf')

print(f'Export the model to {export_path}')

In [None]:
# Check the exported model too
## Load the model back from disk and run both copies on the same sample batch.
reloaded = tf.keras.models.load_model(export_path)

result_batch = model.predict(image_batch)
reloaded_result_batch = reloaded.predict(image_batch)

## Largest element-wise absolute difference between the two outputs.
max_abs_diff = np.abs(reloaded_result_batch - result_batch).max()
print(f'Comparison between own model and exported model {max_abs_diff}')

In [None]:
# validation data classification result
## Rewind the iterator so predictions start from the first validation sample.
## The validation iterator was created with shuffle=False, so the prediction
## order lines up with validation_data.classes.
validation_data.reset()

predicted_batch = model.predict(validation_data)
predicted_id = np.argmax(predicted_batch, axis=-1)

label_id = validation_data.classes

## Fix the matrix size to the number of known classes. Without num_classes
## the size is inferred from the largest id seen, which can be smaller than
## len(class_names) and would then not fit the DataFrame's index/columns.
con_mat = tf.math.confusion_matrix(label_id, predicted_id,
                                   num_classes=len(class_names))

result_df = pd.DataFrame(con_mat.numpy(), index=class_names, columns=class_names, dtype=int)

print('-- Validation result (Row: Actual Class, Column: Predicted Class) --')
print(result_df)