##### Imports and constant globals

In [None]:
import tensorflow as tf
from tensorflow.keras import mixed_precision

# Switch Keras to the 'mixed_float16' global dtype policy before any layer is built.
mixed_precision.set_global_policy('mixed_float16')

# Chooses which dataset the loading cell reads: the "large" directories when
# True, the "small" ones otherwise.
USING_LARGE = True

# Image-loading parameters shared by the data generators.
BATCH_SIZE = 8
IMAGE_SIZE = (254, 254)
# 1/255 -- presumably scales 8-bit pixel values into [0, 1]; confirm image format.
RESCALING_FACTOR = 1.0 / 255

# Locations of the small dataset.
TRAIN_IMGS_DIRECTORY_SMALL = "data-small/train/"
VALIDATION_IMGS_DIRECTORY_SMALL = "data-small/valid/"

# Locations of the large dataset.
TRAIN_IMGS_DIRECTORY_LARGE = "data-large/train/"
VALIDATION_IMGS_DIRECTORY_LARGE = "data-large/valid/"
# NOTE(review): identical to VALIDATION_IMGS_DIRECTORY_LARGE, so the "test"
# evaluation reuses the validation images -- confirm whether a separate
# held-out directory was intended.
TEST_IMGS_DIRECTORY_LARGE = "data-large/valid/"

#### Data augmentation

In [None]:
from keras.preprocessing.image import img_to_array, load_img
import os

# Generator that applies random geometric transformations (rotation, shifts,
# shear, zoom, horizontal flip); save_augs uses it to create augmented images.
data_gen = tf.keras.preprocessing.image.ImageDataGenerator(
    fill_mode='nearest',
    horizontal_flip=True,
    zoom_range=0.2,
    shear_range=0.2,
    height_shift_range=0.2,
    width_shift_range=0.2,
    rotation_range=40,
)

def save_augs(src_dir, dest_dir, count_aug, max_to_aug):
    """Write randomly augmented copies of the images in ``src_dir`` to ``dest_dir``.

    Files whose name starts with ``"aug"`` are treated as the output of an
    earlier run and are skipped, so the same directory can be used for both
    input and output without re-augmenting previous results.

    Args:
        src_dir: Directory containing the source images.
        dest_dir: Directory the augmented JPEG files (prefix ``aug``) are
            written to.
        count_aug: Number of augmented images to generate per source image.
        max_to_aug: Maximum number of source images to augment.
    """
    augmented_sources = 0

    for name in os.listdir(src_dir):
        # Must be a logical test: `~` on a bool is bitwise inversion and
        # yields -1 or -2, both truthy, so it would never skip anything.
        if name.startswith("aug"):
            continue

        # Check the cap before doing any image I/O for this file.
        if augmented_sources >= max_to_aug:
            break

        img = load_img(f"{src_dir}/{name}")  # PIL image
        x = img_to_array(img)  # NumPy array holding the pixel data
        x = x.reshape((1,) + x.shape)  # add a leading batch axis of size 1

        flow = data_gen.flow(
            x,
            batch_size=1,
            save_to_dir=dest_dir,
            save_prefix='aug',
            save_format='jpg',
        )
        # The generator would loop indefinitely, so pull exactly `count_aug`
        # batches; each pull saves one transformed image to `dest_dir`.
        # `range` comes first so zip stops without advancing the generator
        # one extra time.
        for _ in zip(range(count_aug), flow):
            pass

        augmented_sources += 1
                
#save_augs(f"{TRAIN_IMGS_DIRECTORY_LARGE}fake", f"{TRAIN_IMGS_DIRECTORY_LARGE}fake", 4, 2500)

#### Read train, validation and test data

In [None]:


# Keyword arguments that every flow_from_directory call in this cell shares.
_FLOW_OPTIONS = dict(
    target_size=IMAGE_SIZE,
    class_mode="binary",
    seed=0,
    batch_size=BATCH_SIZE,
)

if not USING_LARGE:
    # Small dataset: validation images are carved out of the training
    # directory via validation_split; the "valid" directory serves as test set.
    image_gen = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=RESCALING_FACTOR,
        validation_split=0.15,
    )

    train_dataset = image_gen.flow_from_directory(
        directory=TRAIN_IMGS_DIRECTORY_SMALL,
        subset="training",
        shuffle=True,
        **_FLOW_OPTIONS,
    )

    validation_dataset = image_gen.flow_from_directory(
        directory=TRAIN_IMGS_DIRECTORY_SMALL,
        subset="validation",
        shuffle=True,
        **_FLOW_OPTIONS,
    )

    test_dataset = image_gen.flow_from_directory(
        directory=VALIDATION_IMGS_DIRECTORY_SMALL,
        **_FLOW_OPTIONS,
    )
else:
    # Large dataset: train/validation/test each come from their own
    # directory constant.
    # seems like shuffle has a negative effect on RAM usage and may cause OOM
    image_gen = tf.keras.preprocessing.image.ImageDataGenerator(
        rescale=RESCALING_FACTOR,
    )

    train_dataset = image_gen.flow_from_directory(
        directory=TRAIN_IMGS_DIRECTORY_LARGE,
        shuffle=True,
        **_FLOW_OPTIONS,
    )

    validation_dataset = image_gen.flow_from_directory(
        directory=VALIDATION_IMGS_DIRECTORY_LARGE,
        shuffle=False,
        **_FLOW_OPTIONS,
    )

    test_dataset = image_gen.flow_from_directory(
        directory=TEST_IMGS_DIRECTORY_LARGE,
        shuffle=True,
        **_FLOW_OPTIONS,
    )

#### Data visualization

In [None]:
import matplotlib.pyplot as plt 
import tensorflow as tf

def visualize_image_dataset(dataset):
    """Plot the first batch of ``dataset`` as a grid of labelled images.

    Args:
        dataset: Iterable yielding batches whose first element is a sequence
            of images and whose second element is the matching sequence of
            labels. A label equal to ``0.0`` is titled "Real", any other
            value "AI".
    """
    import math

    plt.figure()

    # Only the first batch is shown; an empty dataset leaves the figure blank.
    first_batch = next(iter(dataset), None)
    if first_batch is None:
        return

    img_batch, label_batch = first_batch[0], first_batch[1]

    # Keep the 3x3 layout for up to nine images and grow the square grid for
    # larger batches, so plt.subplot never receives an out-of-range index.
    side = max(3, math.ceil(math.sqrt(len(img_batch))))

    for idx, (img, label) in enumerate(zip(img_batch, label_batch)):
        plt.subplot(side, side, idx + 1)
        plt.imshow(img)
        plt.title("Real" if label == 0.0 else "AI")
        plt.axis("off")


# Preview the first training batch to sanity-check the images and their labels.
visualize_image_dataset(train_dataset)

#### Create model

In [None]:
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, BatchNormalization

# Model input: IMAGE_SIZE plus a trailing axis of 3 channels.
INPUT_SHAPE = (*IMAGE_SIZE, 3)

# DenseNet121 feature extractor (weights loaded from a local file, no
# classification top) followed by a small dense head ending in a single
# sigmoid unit for binary classification.
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=INPUT_SHAPE),
    tf.keras.applications.DenseNet121(
        weights="imagenet/densenet121_weights_tf_dim_ordering_tf_kernels_notop.h5",
        include_top=False,
    ),
    GlobalAveragePooling2D(),
    Dense(512, activation="relu"),
    BatchNormalization(),
    Dropout(0.2),
    Dense(64, activation="relu"),
    # The global policy is 'mixed_float16'; keep the output layer in float32
    # so the sigmoid probabilities fed to the loss are not computed in
    # half precision.
    Dense(1, activation="sigmoid", dtype="float32"),
])


model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[
        tf.keras.metrics.BinaryAccuracy(name="acc"),
        tf.keras.metrics.FalseNegatives(),
        tf.keras.metrics.FalsePositives(),
    ],
)


model.summary()

#### Train logic and options

In [None]:
from datetime import datetime

# Multiply the learning rate by 0.2 when val_loss has not improved for 3 epochs.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=3,
)

# Stop after 5 epochs without val_loss improvement and restore the best weights.
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
)

# Save the full model (not only weights) whenever val_loss reaches a new minimum.
# NOTE(review): newer Keras releases may require this filepath to end in
# ".keras" or ".h5" -- confirm against the installed version.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="checkpoints",
    monitor='val_loss',
    mode='min',
    save_best_only=True,
    save_weights_only=False,
)

training_callbacks = [early_stopping_cb, reduce_lr, checkpoint_cb]

history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=16,
    callbacks=training_callbacks,
)

# Persist the final model under a name stamped with today's date.
now = f"{datetime.today():%Y-%m-%d}"
tf.keras.saving.save_model(model, f"model_{now}.keras")

#### Evaluate model

In [None]:
# Report the compiled loss and metrics (accuracy, false negatives, false
# positives) on the test generator.
model.evaluate(test_dataset)