In [None]:
from preprocessing.patch_generator import smash_n_reconstruct
import preprocessing.filters as f
import tensorflow as tf
from keras import layers,Model
from keras.callbacks import ModelCheckpoint, EarlyStopping
import os

In [None]:
@tf.function
def hard_tanh(x):
    """Hard-tanh activation: clamp every element of ``x`` into [-1, 1].

    Used (via ``layers.Lambda``) as the final non-linearity of
    ``featureExtractionLayer``. Values inside the range pass through
    unchanged; values above 1 become 1 and values below -1 become -1.
    """
    return tf.clip_by_value(x, clip_value_min=-1, clip_value_max=1)

class featureExtractionLayer(layers.Layer):
    """Per-branch feature extractor: Conv2D(32, 3x3, ReLU) -> BatchNorm -> hard-tanh.

    Every instance owns its own convolution and batch-norm weights, so the
    rich- and poor-texture branches that each build one do not share weights.
    All constructor arguments are forwarded unchanged to ``layers.Layer``
    (e.g. ``name``).
    """

    def __init__(self, *args, **kwargs):
        super(featureExtractionLayer, self).__init__(*args, **kwargs)
        # Sub-layers are built once here so Keras tracks their weights.
        # Attribute names and creation order are part of the saved-weights
        # layout, so they are kept exactly as-is.
        self.conv = layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu')
        self.bn = layers.BatchNormalization()
        self.activation = layers.Lambda(hard_tanh)

    def call(self, inputs):
        """Apply convolution, then batch normalization, then hard-tanh."""
        return self.activation(self.bn(self.conv(inputs)))
        

In [None]:
input1 = layers.Input(shape=(256,256,1),name="rich_texture")
input2 = layers.Input(shape=(256,256,1),name="poor_texture")


def conv_bn_stack(x, n_blocks, filters=32, kernel_size=(3, 3)):
    """Apply ``n_blocks`` x (Conv2D with ReLU -> BatchNormalization) to ``x``.

    Args:
        x: input tensor of the stack.
        n_blocks: number of Conv2D/BatchNormalization pairs to append.
        filters: number of convolution filters in every Conv2D.
        kernel_size: kernel size of every Conv2D (default 'valid' padding).

    Returns:
        The output tensor of the last BatchNormalization (``x`` itself when
        ``n_blocks`` is 0).
    """
    for _ in range(n_blocks):
        x = layers.Conv2D(filters=filters, kernel_size=kernel_size, activation='relu')(x)
        x = layers.BatchNormalization()(x)
    return x


# Two separate (non-weight-sharing) extractors, one per texture view.
l1 = featureExtractionLayer(name="feature_extraction_layer_rich_texture")(input1)
l2 = featureExtractionLayer(name="feature_extraction_layer_poor_texture")(input2)

# Element-wise difference between the rich- and poor-texture feature maps.
contrast = layers.subtract([l1, l2])

# Four Conv/BN pairs on the contrast map (layer creation order is unchanged,
# so auto-generated layer names and the weight layout are the same as before).
x = conv_bn_stack(contrast, 4)
# NOTE(review): this BatchNormalization directly follows another one and is
# redundant; kept so existing checkpoints / saved models remain loadable.
x = layers.BatchNormalization()(x)

x = conv_bn_stack(x, 4)
x = layers.AveragePooling2D()(x)

x = conv_bn_stack(x, 2)
x = layers.AveragePooling2D()(x)

x = conv_bn_stack(x, 2)
x = layers.GlobalAveragePooling2D()(x)

# NOTE(review): GlobalAveragePooling2D already outputs (batch, channels), so this
# Flatten is a no-op; kept to preserve the layer layout.
x = layers.Flatten()(x)
x = layers.Dense(1,activation='sigmoid')(x)

model = Model(inputs=(input1,input2), outputs=x, name="rich_texture_poor_texture_contrast")
model.compile(
    optimizer='adam',
    loss='BinaryCrossentropy',
    # Keras 3 rejects a bare string for `metrics`; a list works in Keras 2 and 3.
    metrics=['binary_accuracy'],
)
model.summary()

In [None]:
# Number of files per class held out (from the end of each sorted listing)
# for validation.
N_VALIDATE = 21

path_ai = './test_imgs/dataset/fakeV2/fake-v2/'
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort
# so the train/validation split is reproducible across machines and runs.
ai_imgs = [os.path.join(path_ai, img) for img in sorted(os.listdir(path_ai))]
ai_label = [1] * len(ai_imgs)      # label 1 for the AI-generated (fake) images
path_real = './test_imgs/dataset/real/'
real_imgs = [os.path.join(path_real, img) for img in sorted(os.listdir(path_real))]
real_label = [0] * len(real_imgs)  # label 0 for the real images
print(len(real_imgs), len(ai_imgs))

assert len(ai_imgs) > N_VALIDATE and len(real_imgs) > N_VALIDATE, (
    "each class needs more than N_VALIDATE images to leave a non-empty training split"
)

# Explicit cut indices (rather than [:-N]) stay correct even if N_VALIDATE is 0.
ai_cut = len(ai_imgs) - N_VALIDATE
real_cut = len(real_imgs) - N_VALIDATE
X_train = ai_imgs[:ai_cut] + real_imgs[:real_cut]
y_train = ai_label[:ai_cut] + real_label[:real_cut]
X_validate = ai_imgs[ai_cut:] + real_imgs[real_cut:]
y_validate = ai_label[ai_cut:] + real_label[real_cut:]
len(X_train),len(y_train),len(X_validate),len(y_validate)

In [None]:
def preprocess(path, label: tf.Tensor):
    """Load one image and build its filtered rich/poor-texture model inputs.

    Meant to run inside ``tf.py_function``, so both arguments arrive as eager
    tensors: ``path`` is a scalar string tensor (decoded as UTF-8 here) and
    ``label`` is passed through unchanged.

    Returns:
        (rich, poor, label): the two textures produced by
        ``smash_n_reconstruct``, each passed through ``f.apply_all_filters``,
        given a trailing channel axis and cast to float64, followed by
        ``label``.
    """
    rich_texture, poor_texture = smash_n_reconstruct(path.numpy().decode('utf-8'))

    def _to_model_input(texture):
        # float64 must match the Tout that the tf.py_function wrappers declare
        # for the two texture outputs.
        filtered = f.apply_all_filters(texture)
        return tf.cast(tf.expand_dims(filtered, axis=-1), dtype=tf.float64)

    return _to_model_input(rich_texture), _to_model_input(poor_texture), label

In [None]:
def dict_map(X1, X2, y):
    """Repackage a (rich, poor, label) triple as (features_dict, label).

    The dict keys are the names of the model's two ``layers.Input`` tensors,
    which is how Keras routes each tensor to its input; they must stay
    'rich_texture' and 'poor_texture'.
    """
    features = dict(rich_texture=X1, poor_texture=X2)
    return features, y

## Building the data pipeline

In [None]:
batch_size = 32
# Batch size used for the validation set.
validation_batch_size = 10


def make_dataset(paths, labels, batch_size, shuffle=False):
    """Build a batched, prefetched dataset of ({'rich_texture', 'poor_texture'}, label).

    Each file path is turned into the two filtered texture tensors by
    ``preprocess`` (run eagerly through ``tf.py_function``) and then
    repackaged into the model's input dict by ``dict_map``.

    Args:
        paths: list of image file paths.
        labels: list of integer labels, aligned with ``paths``.
        batch_size: number of examples per batch.
        shuffle: if True, shuffle with a buffer as large as the dataset.
    """
    ds = tf.data.Dataset.from_tensor_slices((paths, labels))
    if shuffle:
        ds = ds.shuffle(len(paths))
    ds = ds.map(
        lambda filepath, label: tf.py_function(
            preprocess, [filepath, label], [tf.float64, tf.float64, tf.int32]
        )
    ).map(dict_map)
    return ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)


dataset = make_dataset(X_train, y_train, batch_size, shuffle=True)
validation_set = make_dataset(X_validate, y_validate, validation_batch_size)

In [None]:
# Keras 3 requires weights-only checkpoint paths to end in ".weights.h5"; the
# name still ends in ".h5", so Keras 2 also writes HDF5 here.
checkpoint_path = "./checkpoints/model_checkpoint.weights.h5"
# Create the checkpoint directory up front; not every Keras version creates it
# before writing the HDF5 file.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)
checkpoint_callback = ModelCheckpoint(filepath=checkpoint_path,
                                      monitor='val_loss',
                                      save_best_only=True,
                                      save_weights_only=True,
                                      verbose=1)

# NOTE(review): with patience=5 and epochs=5 in the fit cell, this callback can
# never stop training early; whether the best weights are restored at the end
# then depends on the Keras version. Lower patience or raise epochs if early
# stopping is intended.
early_stopping_callback = EarlyStopping(monitor='val_loss',
                                        patience=5,
                                        verbose=1,
                                        restore_best_weights=True)


## Training the model

In [None]:
# No batch_size argument: both datasets are already batched, and Keras documents
# that batch_size must not be given for tf.data inputs (batch_size=1 was
# misleading; the real training batch size is set in the pipeline cell).
history = model.fit(
    dataset,
    epochs=5,
    validation_data=validation_set,
    callbacks=[checkpoint_callback, early_stopping_callback],
)

In [None]:
# Save the full trained model; the ".h5" suffix selects the HDF5 format.
# NOTE(review): reloading presumably needs featureExtractionLayer / hard_tanh
# passed as custom_objects — confirm.
classifier_path = './classifier.h5'
model.save(classifier_path)