In [2]:
import os
import numpy as np
import matplotlib.pyplot as plt
import itertools
import tensorflow as tf
from keras import Sequential
from keras.layers import Conv2D, AvgPool2D, Flatten, Dense, Dropout
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping, ReduceLROnPlateau
from keras.utils import to_categorical
from PIL import Image, ImageChops
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report
from tqdm import tqdm

# Seed both NumPy and TensorFlow: Keras weight initialisation and Dropout draw
# from TensorFlow's RNG, so seeding NumPy alone does not make training
# repeatable across Restart & Run All. (GPU kernels may still be
# nondeterministic.)
SEED = 2
np.random.seed(SEED)
tf.random.set_seed(SEED)

2023-10-06 21:34:05.784469: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-10-06 21:34:05.784513: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-10-06 21:34:05.787194: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-10-06 21:34:05.997595: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
def preprocess_images(dataset_path, path_original, path_tampered,
                      valid_extensions=('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')):
    """Collect the image file paths of the pristine and tampered subsets.

    Args:
        dataset_path: Root directory of the dataset.
        path_original: Sub-directory (relative to ``dataset_path``) holding the
            authentic images.
        path_tampered: Sub-directory (relative to ``dataset_path``) holding the
            tampered images.
        valid_extensions: File extensions (including the leading dot) treated
            as images. Matching is case-insensitive. ``.tif`` is included
            alongside ``.tiff`` so 3-letter TIFF files are not silently skipped.

    Returns:
        A ``(pristine_images, fake_images)`` tuple of path lists, each sorted
        by file name.
    """
    extensions = {ext.lower() for ext in valid_extensions}

    def _list_images(sub_dir):
        # Return the regular files in ``sub_dir`` whose extension is accepted.
        directory = os.path.join(dataset_path, sub_dir)
        # os.listdir order is filesystem-dependent; sorting keeps the seeded
        # train/dev split downstream reproducible across machines.
        return [
            os.path.join(directory, name)
            for name in sorted(os.listdir(directory))
            if os.path.splitext(name)[1].lower() in extensions
            and os.path.isfile(os.path.join(directory, name))
        ]

    return _list_images(path_original), _list_images(path_tampered)


def load_and_preprocess_image(image_path, target_size=(224, 224)):
    """Load an image from disk as a fixed-size RGB array.

    Args:
        image_path: Path of the image file to read.
        target_size: ``(width, height)`` passed to ``PIL.Image.resize``. The
            default matches the ``(224, 224, 3)`` input of ``cnn_model``.

    Returns:
        A ``numpy.ndarray`` of shape ``(height, width, 3)`` built from the
        resized RGB image.
    """
    # The context manager closes the file handle deterministically instead of
    # relying on garbage collection while thousands of files are decoded.
    with Image.open(image_path) as img:
        rgb = img.convert('RGB').resize(target_size)
    return np.array(rgb)


def cnn_model():
    """Build the (uncompiled) two-class image classifier.

    Two stages of 5x5 ReLU convolution followed by 2x2 average pooling feed a
    dense head with 50% dropout. A final 2-unit softmax yields one probability
    per class (``main`` labels pristine images 0 and tampered images 1).

    Returns:
        A ``keras.Sequential`` model expecting ``(224, 224, 3)`` inputs.
    """
    feature_extractor = [
        Conv2D(filters=128, kernel_size=(5, 5), activation='relu', input_shape=(224, 224, 3)),
        AvgPool2D(pool_size=(2, 2)),
        Conv2D(filters=256, kernel_size=(5, 5), activation='relu'),
        AvgPool2D(pool_size=(2, 2)),
    ]
    # NOTE(review): with Conv2D's default 'valid' padding the flattened map is
    # 53 * 53 * 256 = 719,104 values, so the first Dense layer holds ~92M
    # weights -- consider more pooling or GlobalAveragePooling2D.
    classifier_head = [
        Flatten(),
        Dense(units=128, activation='relu'),
        Dropout(0.5),
        Dense(units=32, activation='relu'),
        Dropout(0.5),
        Dense(units=2, activation='softmax'),
    ]
    return Sequential(feature_extractor + classifier_head)


def plot_metrics(history):
    """Plot training vs. validation accuracy and loss, one figure each.

    Args:
        history: The ``History`` object returned by ``model.fit``. Its
            ``history`` dict must contain ``accuracy``, ``val_accuracy``,
            ``loss`` and ``val_loss``, as produced when ``fit`` is given
            ``validation_data`` and the model is compiled with
            ``metrics=['accuracy']``.
    """
    for metric, title in (('accuracy', 'Accuracy'), ('loss', 'Loss')):
        fig, ax = plt.subplots()
        ax.plot(history.history[metric], label='train')
        # The val_* curves come from ``validation_data`` (the dev split), not a
        # held-out test set, so they are labelled accordingly.
        ax.plot(history.history['val_' + metric], label='validation')
        ax.set(title=title, xlabel='epoch', ylabel=metric)
        ax.legend()
        plt.show()
        # Free the figure once displayed so repeated calls do not accumulate.
        plt.close(fig)

In [None]:
def _load_split(dataset_path, path_original, path_tampered, dev_fraction=0.2, seed=133):
    """Load the dataset, split it into train/dev, and return model-ready arrays.

    Returns:
        ``(x_train, x_dev, y_train, y_dev)``. The ``x`` arrays are float32
        images scaled to [0, 1]; the ``y`` arrays are one-hot with 2 classes.

    Raises:
        FileNotFoundError: if either subset contains no images.
    """
    pristine_images, fake_images = preprocess_images(dataset_path, path_original, path_tampered)
    if not pristine_images or not fake_images:
        raise FileNotFoundError(
            f"No images found under {dataset_path!r}: {len(pristine_images)} in "
            f"{path_original!r}, {len(fake_images)} in {path_tampered!r}")

    # Label 0 = pristine, 1 = tampered.
    all_images = pristine_images + fake_images
    all_labels = [0] * len(pristine_images) + [1] * len(fake_images)

    # Split on paths first so each image is decoded only once. stratify keeps
    # the pristine/tampered ratio identical in both splits.
    train_paths, dev_paths, train_labels, dev_labels = train_test_split(
        all_images, all_labels, test_size=dev_fraction, random_state=seed,
        shuffle=True, stratify=all_labels)

    def _to_array(paths, desc):
        # Decode images, then scale to [0, 1] in place to avoid a second float copy.
        pixels = np.array([load_and_preprocess_image(p) for p in tqdm(paths, desc=desc)])
        pixels = pixels.astype('float32')
        pixels /= 255.0
        return pixels

    x_train = _to_array(train_paths, 'loading train')
    x_dev = _to_array(dev_paths, 'loading dev')
    y_train = to_categorical(train_labels, 2)
    y_dev = to_categorical(dev_labels, 2)
    return x_train, x_dev, y_train, y_dev


def _train(x_train, y_train, x_dev, y_dev, epochs=30, learning_rate=1e-4):
    """Compile and fit ``cnn_model`` and return ``(model, history)``."""
    model = cnn_model()
    # A 2-unit softmax with one-hot targets pairs with categorical
    # cross-entropy; Keras then resolves 'accuracy' to categorical accuracy.
    model.compile(optimizer=Adam(learning_rate=learning_rate),
                  loss='categorical_crossentropy', metrics=['accuracy'])
    early_stop = EarlyStopping(monitor='val_accuracy', patience=6, verbose=1,
                               restore_best_weights=True)
    # min_lr must be below the initial learning rate or the scheduler can never
    # lower it. Its patience is shorter than EarlyStopping's so at least one
    # reduction can happen before training is stopped.
    reduce_lr = ReduceLROnPlateau(monitor='val_accuracy', factor=0.22, patience=3,
                                  verbose=1, min_delta=0.0001, min_lr=1e-6)
    history = model.fit(x_train, y_train, epochs=epochs, validation_data=(x_dev, y_dev),
                        callbacks=[early_stop, reduce_lr], verbose=1, shuffle=True)
    return model, history


def _evaluate(model, x_dev, y_dev):
    """Print the confusion matrix and classification report on the dev split."""
    y_pred = np.argmax(model.predict(x_dev), axis=1)
    y_true = np.argmax(y_dev, axis=1)
    confusion_mtx = confusion_matrix(y_true, y_pred)
    print(confusion_mtx)
    print(classification_report(y_true, y_pred, target_names=['pristine', 'tampered']))
    return confusion_mtx


def main():
    """Train and evaluate the pristine-vs-tampered CNN on the CASIA2 images."""
    # Overridable via CASIA2_PATH so the notebook is not tied to one machine.
    # The join with '' guarantees a trailing separator for path concatenation.
    dataset_path = os.path.join(
        os.environ.get('CASIA2_PATH', '/home/jaki/Dev/cnn_ela/CASIA2/'), '')
    x_train, x_dev, y_train, y_dev = _load_split(dataset_path, 'Au/', 'Tp/')
    model, history = _train(x_train, y_train, x_dev, y_dev)
    plot_metrics(history)
    _evaluate(model, x_dev, y_dev)


# Run the full pipeline when this file is executed directly. Inside a Jupyter
# kernel __name__ is also "__main__", so running this cell starts training.
if __name__ == "__main__":
    main()
