In [None]:
from preprocessing.patch_generator import smash_n_reconstruct
import preprocessing.filters as f
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
from keras import layers
from keras.callbacks import ModelCheckpoint, EarlyStopping
import os
from tqdm import tqdm
from PIL import Image
import gc

In [None]:
def hard_tanh(x):
    """Hard tanh activation: saturate every element of `x` to the range [-1, 1]."""
    lower, upper = -1, 1
    return tf.clip_by_value(x, clip_value_min=lower, clip_value_max=upper)

# Feature extractor applied to both reconstructed images in `preprocess`:
# 256x256x1 input -> 3x3 'valid' Conv2D (32 filters, ReLU) -> BatchNorm -> clip to [-1, 1].
# Its 254x254x32 output is the input shape the classifier below expects.
trainable_model = keras.Sequential([
        layers.Input(shape=(256,256,1)),
        layers.Conv2D(filters=32, kernel_size=(3,3), activation='relu'),
        layers.BatchNormalization(),
        layers.Lambda(hard_tanh)
    ])

# Pass a loss *instance*, not the bare class: tf.keras 2.x treats a callable class
# as the loss function itself and would call its constructor with (y_true, y_pred)
# during training. An instance works across Keras versions.
# No explicit .build() is needed: a Sequential that starts with layers.Input is built.
trainable_model.compile(
    optimizer='adam',
    loss=keras.losses.BinaryCrossentropy(),
    metrics=['accuracy'],
)

In [None]:
def _conv_bn_stack(n_blocks, filters=32):
    """Return `n_blocks` consecutive (3x3 ReLU Conv2D -> BatchNormalization) layer pairs."""
    stack = []
    for _ in range(n_blocks):
        stack.append(layers.Conv2D(filters=filters, kernel_size=(3,3), activation='relu'))
        stack.append(layers.BatchNormalization())
    return stack


# Binary classifier over the 254x254x32 feature maps produced by `preprocess`.
# Four conv/BN stages (4, 4, 2, 2 pairs) separated by average pooling, then a
# global average pool and a single sigmoid unit.
classifier = keras.Sequential([
        layers.Input(shape=(254,254,32)),
        *_conv_bn_stack(4),
        layers.AveragePooling2D(),
        *_conv_bn_stack(4),
        layers.AveragePooling2D(),
        *_conv_bn_stack(2),
        layers.AveragePooling2D(),
        *_conv_bn_stack(2),
        layers.GlobalAveragePooling2D(),
        # No-op after global pooling (already rank 2); kept so the layer list
        # matches previously saved models.
        layers.Flatten(),
        layers.Dense(1,activation='sigmoid')
    ])

# `metrics` must be a list/tuple/dict (Keras 3 rejects a bare string); the loss is
# given as an instance for consistency with `trainable_model`.
classifier.compile(
    optimizer='adam',
    loss=keras.losses.BinaryCrossentropy(),
    metrics=['binary_accuracy'],
)

In [None]:
def preprocess(path, label, verbose: bool = False):
    """Turn one image path into a (feature map, label) training example.

    Meant to be run through ``tf.py_function`` by the data pipeline, so ``path``
    arrives as a scalar string tensor and ``label`` as a scalar int32 tensor.

    ``smash_n_reconstruct`` yields two reconstructed images (presumably the rich-
    and poor-texture reconstructions -- TODO confirm). Each is filtered with
    ``apply_all_filters`` and run through ``trainable_model``. The example is the
    difference of the two resulting feature maps.

    Args:
        path: scalar ``tf.string`` tensor holding a UTF-8 encoded file path.
        label: class label; returned unchanged.
        verbose: if True, print the path being processed. Off by default because
            this runs once per image and would flood the notebook output.

    Returns:
        ``(features, label)`` where ``features`` is a NumPy array with the shape of
        a single ``trainable_model`` output (254x254x32 for its 256x256x1 input).
    """
    if verbose:
        print(f'🖼️image path: - {path}')
    rt, pt = smash_n_reconstruct(path.numpy().decode('utf-8'))
    frt = tf.constant([f.apply_all_filters(rt)])
    fpt = tf.constant([f.apply_all_filters(pt)])
    # Call the model directly rather than via predict(): Keras recommends __call__
    # for inputs that fit in one batch, and predict() adds per-call overhead (data
    # adapter, progress bar) that is paid twice per image here.
    # training=False keeps BatchNormalization in inference mode, as predict() does.
    diff = trainable_model(frt, training=False) - trainable_model(fpt, training=False)
    return diff.numpy()[0], label

In [None]:
def dict_map(X, y):
    """Wrap a (features, label) pair into a dict with keys 'X' and 'y' (in that order)."""
    return dict(X=X, y=y)

In [None]:
# Number of images per class held out (from the end of each sorted list) for validation.
N_VALIDATE_PER_CLASS = 21

path_ai = './test_imgs/dataset/fakeV2/fake-v2'
# os.listdir order is arbitrary; sort so the train/validation split is reproducible.
ai_imgs = [os.path.join(path_ai, img) for img in sorted(os.listdir(path_ai))]
ai_label = [1] * len(ai_imgs)  # 1 = image from the fake (AI) folder
path_real = './test_imgs/dataset/real'
real_imgs = [os.path.join(path_real, img) for img in sorted(os.listdir(path_real))]
real_label = [0] * len(real_imgs)  # 0 = image from the real folder
print(len(real_imgs), len(ai_imgs))

# Explicit cut indices instead of `[:-n]` / `[-n:]`, which would put everything in
# validation when n == 0; max(..., 0) clamps when a class has fewer than n images.
ai_cut = max(len(ai_imgs) - N_VALIDATE_PER_CLASS, 0)
real_cut = max(len(real_imgs) - N_VALIDATE_PER_CLASS, 0)
X_train = ai_imgs[:ai_cut] + real_imgs[:real_cut]
y_train = ai_label[:ai_cut] + real_label[:real_cut]
X_validate = ai_imgs[ai_cut:] + real_imgs[real_cut:]
y_validate = ai_label[ai_cut:] + real_label[real_cut:]
assert len(X_train) == len(y_train) and len(X_validate) == len(y_validate)
len(X_train),len(y_train),len(X_validate),len(y_validate)

## Building the data pipeline

In [None]:
batch_size: int = 32  # training examples per batch in the `dataset` pipeline

In [None]:
def _load_example(filepath, label):
    """Run `preprocess` eagerly via tf.py_function and restore static shapes.

    tf.py_function returns tensors of unknown shape, so the shapes are pinned here:
    features are one trainable_model output (254x254x32, the classifier's Input
    shape) and the label is a scalar. tf.ensure_shape also fails loudly at runtime
    if a preprocessed image ever has a different shape.
    """
    features, label = tf.py_function(preprocess, [filepath, label], [tf.float64, tf.int32])
    features = tf.ensure_shape(features, (254, 254, 32))
    label = tf.ensure_shape(label, ())
    return features, label


validation_batch_size = 10

dataset = (tf.data.Dataset.from_tensor_slices((X_train,y_train))
           .shuffle(len(X_train))
           .map(_load_example)
           .batch(batch_size)
           .prefetch(tf.data.AUTOTUNE)
        )

validation_set = (tf.data.Dataset.from_tensor_slices((X_validate,y_validate))
           .map(_load_example)
           .batch(validation_batch_size)
           .prefetch(tf.data.AUTOTUNE)
        )

In [None]:
# NOTE(review): Keras 3 requires weights-only checkpoint paths to end in
# ".weights.h5"; the plain ".h5" suffix only works with tf.keras 2.x -- confirm version.
checkpoint_path = "./checkpoints/model_checkpoint.h5"
# Create the checkpoint folder up front: writing an HDF5 file does not create
# missing parent directories, so the first best-model save would fail otherwise.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Save weights only when val_loss improves on the best seen so far.
checkpoint_callback = ModelCheckpoint(filepath=checkpoint_path, 
                                      monitor='val_loss', 
                                      save_best_only=True,
                                      save_weights_only=True,
                                      verbose=1)

# Stop after 5 epochs without val_loss improvement and roll back to the best weights.
early_stopping_callback = EarlyStopping(monitor='val_loss', 
                                        patience=5,
                                        verbose=1, 
                                        restore_best_weights=True)


## Training the model

In [None]:
# Load the previously saved classifier from disk; this loaded `model` is what gets trained below.
model = keras.models.load_model(filepath='./classifier.h5')

In [None]:
# Smoke-test training: fit the loaded model on only the first shuffled training
# batch for 5 epochs, validating on the whole validation set each epoch.
# NOTE(review): only one batch is ever used; fit on `dataset` directly for full training.
for data in dataset.take(1):
    train_data, train_labels = data
    gc.collect()  # presumably to release memory before the fit call
    model.fit(
        x=train_data,
        y=train_labels,
        epochs=5,
        validation_data=validation_set,
        callbacks=[checkpoint_callback, early_stopping_callback],
    )
    gc.collect()

In [None]:
import preprocessing.patch_generator as p

In [None]:
# Scratch cell: inspect one sample image path (note the Windows-style separator)
# and display its file extension.
path = './test_imgs/dataset/real\\z6ewevdaap5a1.png'
os.path.splitext(path)[1]
# img = Image.open(path)
# img = img.resize((256,256))
# img = img.convert('RGB')
# np.array(img).shape
# img.size
# a1,a2 = p.img_to_patches(path)
# vv = [p.get_pixel_var_degree_for_patch(patch) for patch in a1]
# b1,b2 = p.extract_rich_and_poor_textures(vv,a2)
# c1 = p.get_complete_image(b1).shape
# c2 = p.get_complete_image(b2).shape
# c1,c2

In [None]:
# classifier.fit(dataset,epochs=5,validation_data=validation_set,callbacks=[checkpoint_callback, early_stopping_callback])

In [None]:
# classifier.save('./classifier.h5')