In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import tensorflow_io as tfio

from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras import layers

In [None]:
# Number of images per batch for every dataset loaded in this notebook.
BATCH_SIZE = 32

# (height, width) each image is resized to when it is loaded.
IMG_SIZE = (224, 224)
# IMG_SIZE = (1024, 1024)  # larger alternative, currently disabled

# Seed passed to the train/validation split so both subsets come from the same partition.
SEED = 123

In [None]:
# Training subset: the part of the "train" directory that is NOT held out by
# validation_split. The seed must match the validation cell so the two
# subsets do not overlap.
train_dataset = image_dataset_from_directory(
    "train",
    validation_split=0.2,
    subset="training",
    seed=SEED,
    shuffle=True,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)

In [None]:
# Validation subset: the 20% of the "train" directory held out by
# validation_split. Uses the same seed as the training cell so the split is
# consistent between the two calls.
validation_dataset = image_dataset_from_directory(
    "train",
    validation_split=0.2,
    subset="validation",
    seed=SEED,
    shuffle=True,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)

In [None]:
# Held-out test images, read from the separate "test" directory (no split).
# NOTE(review): shuffle=True is kept from the original; shuffling is not
# needed for evaluation — confirm whether it is intentional.
test_dataset = image_dataset_from_directory(
    "test",
    shuffle=True,
    image_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
)

In [None]:
# Capture the class-name list now: train_dataset is rebound to its prefetched
# version in a later cell, which presumably no longer exposes `class_names`
# (TODO confirm).
class_names = train_dataset.class_names

In [None]:
# Preview up to nine images from the first training batch in a 3x3 grid,
# each titled with its class name.
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
    # Slice instead of indexing range(9): the batch can hold fewer than nine
    # images (small dataset or small BATCH_SIZE), and zip stops at the
    # shorter sequence, so this never indexes past the end of the batch.
    preview = zip(images.numpy()[:9], labels.numpy())
    for position, (image, label) in enumerate(preview, start=1):
        plt.subplot(3, 3, position)
        # Pixels are loaded as floats; cast so imshow treats them as 0-255.
        plt.imshow(image.astype("uint8"))
        plt.title(class_names[label])
        plt.axis("off")

In [None]:
# Sentinel that lets tf.data choose the prefetch buffer size at runtime.
AUTOTUNE = tf.data.experimental.AUTOTUNE

# Rebind all three splits to their prefetched versions in one step.
train_dataset, validation_dataset, test_dataset = (
    split.prefetch(buffer_size=AUTOTUNE)
    for split in (train_dataset, validation_dataset, test_dataset)
)

In [None]:
# Augmentation pipeline: a random horizontal flip followed by a random
# rotation with factor 0.2.
data_augmentation = tf.keras.Sequential()
data_augmentation.add(layers.experimental.preprocessing.RandomFlip('horizontal'))
data_augmentation.add(layers.experimental.preprocessing.RandomRotation(0.2))

In [None]:
# Create the pre-trained feature extractor (currently MobileNetV2).
# Exactly one `preprocess_input` / `base_model` pair below should be active;
# the others are kept as commented-out alternatives.
#
# Every variant uses include_top=False, i.e. no classification head, so the
# head-only arguments `classes` and `classifier_activation` are not passed:
# they have no effect without the head and only suggested a classifier exists.
#
# NOTE(review): `preprocess_input` is assigned here but the model-assembly
# cell applies its own rescaling pipeline instead of calling it.
IMG_SHAPE = IMG_SIZE + (3,)  # channels-last RGB: (height, width, 3)

# ResNet152V2
# preprocess_input = tf.keras.applications.resnet_v2.preprocess_input
# base_model = tf.keras.applications.ResNet152V2(
#     include_top=False, weights='imagenet',
#     input_shape=IMG_SHAPE, pooling=None
# )

# MobileNetV2 (active)
preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
base_model = tf.keras.applications.MobileNetV2(
    input_shape=IMG_SHAPE, alpha=1.0, include_top=False, weights='imagenet')

# DenseNet201
# preprocess_input = tf.keras.applications.densenet.preprocess_input
# base_model = tf.keras.applications.DenseNet201(
#     include_top=False, weights='imagenet',
#     input_shape=IMG_SHAPE
# )

# InceptionResNetV2
# preprocess_input = tf.keras.applications.inception_resnet_v2.preprocess_input
# base_model = tf.keras.applications.InceptionResNetV2(
#     include_top=False, weights='imagenet',
#     input_shape=IMG_SHAPE)

# EfficientNetB7
# preprocess_input = tf.keras.applications.efficientnet.preprocess_input
# base_model = tf.keras.applications.EfficientNetB7(
#     include_top=False, weights='imagenet',
#     input_shape=IMG_SHAPE, pooling=None)

In [None]:
# Sanity check: pull one batch from the training set, run it through the base
# model and print the shape of the resulting feature map.
# NOTE(review): image_batch is passed exactly as loaded (no rescaling), which
# is fine for checking shapes only.
image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)

In [None]:
# Freeze the pre-trained base model so its weights are not updated during
# training, then print its architecture.
base_model.trainable = False
base_model.summary()

In [None]:
# Pooling layer placed after the base model; it is applied here to the sample
# feature batch to print the pooled shape, and reused when the model is
# assembled.
global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)

In [None]:
# Classification head: a single Dense unit with no activation, so its output
# is a raw logit. Applied to the pooled sample batch to print the output shape.
# NOTE(review): one output unit presumes a two-class problem — confirm that
# len(class_names) == 2.
prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)

In [None]:
# Custom Keras layer that converts an RGB image to red only.
class RedAlert(layers.Layer):
    """Keep only the red channel, scale it, and clip the result to [0, 255].

    Args:
        img_shape: Shape the red-only mask is broadcast to; its last
            dimension must be 3 (e.g. ``(224, 224, 3)``).
        scale: Factor the red channel is multiplied by before clipping.
        name: Optional layer name, forwarded to ``Layer``.
        **kwargs: Remaining keyword arguments, forwarded to ``Layer``.
    """

    def __init__(self, img_shape, scale = 1.0, name=None, **kwargs):
        # Initialise the base Layer first so attribute tracking is set up
        # before any attribute is assigned.
        super().__init__(name=name, **kwargs)
        self.img_shape = img_shape
        self.scale = scale
        # 1 for the red channel, 0 for green and blue, at every pixel.
        self.mask = tf.broadcast_to(tf.constant([1., 0., 0.]), img_shape)

    def call(self, inputs):
        """Return ``inputs`` with green/blue zeroed and red scaled and clipped."""
        red_only = inputs * self.mask
        # Use the configured factor; previously a hard-coded 1.0 was applied,
        # so the `scale` argument had no effect.
        scaled = red_only * self.scale
        return tf.clip_by_value(scaled, clip_value_min=0, clip_value_max=255)

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update({"img_shape": tuple(self.img_shape), "scale": self.scale})
        return config

In [None]:
class Colorspace(layers.Layer):
    """Layer that applies a caller-supplied colour-space conversion function.

    Args:
        colorspace: Callable invoked on the layer's inputs; its return value
            is the layer's output.
        name: Optional layer name, forwarded to ``Layer``.
        **kwargs: Remaining keyword arguments, forwarded to ``Layer``.
    """

    def __init__(self, colorspace, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        self.colorspace = colorspace

    def call(self, inputs):
        """Return the result of calling the stored conversion on ``inputs``."""
        converted = self.colorspace(inputs)
        return converted
        

In [None]:
# Assemble the end-to-end model:
# augmentation -> colour-space preprocessing -> frozen base model
# -> global average pooling -> dropout -> single-logit head.
inputs = tf.keras.Input(shape=IMG_SHAPE)
x = data_augmentation(inputs)

process = tf.keras.Sequential([
    # Alternative preprocessing, currently disabled:
    # RedAlert(img_shape = IMG_SHAPE, scale = 1.3),
    # tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1),
    tf.keras.layers.experimental.preprocessing.Rescaling(1./255),  # [0, 255] -> [0, 1] for the RGB -> HSV conversion
    Colorspace(colorspace = tf.image.rgb_to_hsv),
    tf.keras.layers.experimental.preprocessing.Rescaling(2.0, offset=-1),  # [0, 1] -> [-1, 1] for the base model
])

x = process(x)
# training=False keeps the base model's BatchNormalization (and any dropout)
# layers in inference mode. That is what a frozen feature extractor needs, and
# it stays correct if the base model is later unfrozen for fine-tuning.
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

In [None]:
# Compile the model for binary classification on raw logits.
base_learning_rate = 0.0001
model.compile(
    # `learning_rate` is the supported argument name; `lr` is a deprecated alias.
    optimizer=tf.keras.optimizers.Adam(learning_rate=base_learning_rate),
    # from_logits=True because the prediction head has no activation.
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    # The model outputs logits, so the decision boundary is 0 (sigmoid(0) == 0.5).
    # The plain 'accuracy' string would threshold the logits at 0.5 and
    # report a wrong accuracy. The name keeps the history keys
    # 'accuracy' / 'val_accuracy' unchanged.
    metrics=[tf.keras.metrics.BinaryAccuracy(threshold=0.0, name='accuracy')],
)

In [None]:
# Print the layer-by-layer summary of the assembled model.
model.summary()

In [None]:
# Train for 50 epochs, evaluating on the validation split after each epoch;
# the returned History object is kept in `history`.
history = model.fit(
    x=train_dataset,
    validation_data=validation_dataset,
    epochs=50,
)