In [None]:
# Setup libraries
## Install dependencies first: pip install -r requirements.txt
from __future__ import absolute_import, division, print_function, unicode_literals
import os
from operator import itemgetter

import matplotlib.pylab as plt
# %matplotlib widget
%matplotlib inline

import numpy as np
import tensorflow as tf
tf.random.set_seed(99)
from tensorflow.keras import layers, models

In [None]:
# Global configuration — edit these values before running the notebook/script.

# Dataset root handed to flow_from_directory in the dataset cell.
IMG_DATA = './dataset_tma/sampled_152'

# Images are resized to IMG_SIDE x IMG_SIDE (height, width) when loaded.
IMG_SIDE = 39
IMG_SHAPE = (IMG_SIDE, IMG_SIDE)

In [None]:
%%time
# Build training / validation iterators from the dataset directory.
dataset_root = os.path.abspath(os.path.expanduser(IMG_DATA))
print(f'Dataset root: {dataset_root}')

# A single generator serves both subsets: pixel values are multiplied by 1/255
# and 17% of the images are reserved for the 'validation' subset.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1/255,
    validation_split=0.17,
)
train_data = image_generator.flow_from_directory(
    dataset_root, target_size=IMG_SHAPE, subset='training')
validation_data = image_generator.flow_from_directory(
    dataset_root, target_size=IMG_SHAPE, subset='validation')

# Peek at one validation batch; image_batch / label_batch are reused by the next cell.
image_batch, label_batch = next(iter(validation_data))
print(f'Image batch shape: {image_batch.shape}')
print(f'Label batch shape: {label_batch.shape}')

# Class names ordered by their integer index, so class_names[i] names output i.
index_of = validation_data.class_indices
ordered_classes = sorted(index_of, key=index_of.get)
class_names = np.array([name.title() for name in ordered_classes])
print(f'Classes: {class_names}')

In [None]:
## Show sample images with their ground-truth labels.
# These labels come from the data generator (label_batch), not from the model,
# so they are the true classes — nothing has been predicted yet.
true_ids = np.argmax(label_batch, axis=-1)
true_label_batch = class_names[true_ids]

# Show up to a 5x5 grid; the batch can hold fewer than 25 images (e.g. a small
# validation subset), which previously raised IndexError.
n_show = min(25, len(image_batch))
fig1 = plt.figure(figsize=(10, 10))
for i in range(n_show):
    ax = fig1.add_subplot(5, 5, i + 1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)
    ax.imshow(image_batch[i])
    ax.set_xlabel(true_label_batch[i], color='brown')

In [None]:
# CNN classifier: three 3x3 conv layers (the first two each followed by 2x2
# max-pooling), then a 64-unit dense layer and one output unit per class.
# The output layer has no activation, i.e. it emits raw logits.
n_classes = len(validation_data.class_indices)
model = models.Sequential([
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=IMG_SHAPE + (3,)),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.MaxPooling2D((2, 2)),
    layers.Conv2D(64, (3, 3), activation='relu'),
    layers.Flatten(),
    layers.Dense(64, activation='relu'),
    layers.Dense(n_classes),
])
model.summary()

In [None]:
# Compile the model for training.
# Plain SGD at 0.01 (SGD's default). To try Adam instead (default 0.001), use
# tf.keras.optimizers.Adam(learning_rate=...).
base_learning_rate = 0.01
optimizer = tf.keras.optimizers.SGD(learning_rate=base_learning_rate)
# from_logits=True because the model's last Dense layer has no activation.
loss_fn = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
model.compile(optimizer=optimizer, loss=loss_fn, metrics=['accuracy'])

## Log class
### https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback
class CollectBatchStats(tf.keras.callbacks.Callback):
    """Record training/validation loss and accuracy at the end of every epoch.

    Despite the ``batch_*`` attribute names (kept because the plotting cell reads
    them), one value is appended per *epoch*, taken from the ``logs`` dict that
    Keras passes to ``on_epoch_end``.

    Attributes:
        batch_losses: training loss, one entry per epoch.
        batch_val_losses: validation loss, one entry per epoch.
        batch_acc: training accuracy, one entry per epoch.
        batch_val_acc: validation accuracy, one entry per epoch.
    """

    def __init__(self):
        # Initialise the base Callback so its own state (e.g. ``model``) exists
        # before Keras wires the callback into fit(); the original skipped this.
        super().__init__()
        self.batch_losses = []
        self.batch_val_losses = []
        self.batch_acc = []
        self.batch_val_acc = []

    def on_epoch_end(self, epoch, logs=None):
        """Append this epoch's loss/accuracy values.

        Args:
            epoch: index of the finished epoch (unused).
            logs: metric dict from Keras. It must contain 'loss', 'accuracy',
                'val_loss' and 'val_accuracy', which requires compiling with
                metrics=['accuracy'] and fitting with validation_data; a missing
                key raises KeyError.
        """
        logs = logs or {}
        self.batch_losses.append(logs['loss'])
        self.batch_acc.append(logs['accuracy'])
        self.batch_val_losses.append(logs['val_loss'])
        self.batch_val_acc.append(logs['val_accuracy'])
        # No manual model.reset_metrics() here: fit() already resets metrics at
        # the start of each epoch, so the former call was redundant.

In [None]:
# Train for at most `max_epochs` epochs. Stop early once val_loss has failed to
# improve by at least 1e-4 for 3 consecutive epochs.
steps_per_epoch = len(train_data)  # integer batch count covering the whole training set once
max_epochs = 100  # upper bound on epochs (not Keras' `initial_epoch`, which is the start index)
batch_stats_callback = CollectBatchStats()
# EarlyStopping was previously referenced without being imported (NameError).
earlystop_callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=0.0001, patience=3)

# Model.fit accepts Keras iterators/Sequences directly; fit_generator is
# deprecated and has been removed from recent Keras releases.
history = model.fit(train_data,
                    epochs=max_epochs,
                    steps_per_epoch=steps_per_epoch,
                    validation_data=validation_data,
                    callbacks=[batch_stats_callback, earlystop_callback])

In [None]:
# Plot the per-epoch learning curves collected by the stats callback.
acc = batch_stats_callback.batch_acc
val_acc = batch_stats_callback.batch_val_acc
loss = batch_stats_callback.batch_losses
val_loss = batch_stats_callback.batch_val_losses

fig2 = plt.figure(figsize=(8, 8))

# Top panel: accuracy, y-axis fixed to [0, 1].
ax1 = fig2.add_subplot(2, 1, 1)
for series, curve_label in ((acc, 'Training Accuracy'), (val_acc, 'Validation Accuracy')):
    ax1.plot(series, label=curve_label)
ax1.legend(loc='lower right')
ax1.set_title('Training and Validation Accuracy')
ax1.set_ylabel('Accuracy')
ax1.set_ylim([0, 1])

# Bottom panel: loss, y-axis anchored at zero up to the autoscaled maximum.
ax2 = fig2.add_subplot(2, 1, 2)
for series, curve_label in ((loss, 'Training Loss'), (val_loss, 'Validation Loss')):
    ax2.plot(series, label=curve_label)
ax2.legend(loc='upper right')
ax2.set_title('Training and Validation Loss')
ax2.set_ylabel('Cross Entropy')
loss_top = max(ax2.get_ylim())
ax2.set_ylim([0, loss_top])
ax2.set_xlabel('epoch')

In [None]:
# Report the best validation accuracy reached over all epochs.
# (To keep the curves, pickle (acc, val_acc, loss, val_loss) to a file.)
best_val_acc = max(val_acc)
print(best_val_acc)