# Deepfake Image Detector Using DL

**Download the dataset from Kaggle using kagglehub**

In [None]:
!pip install kagglehub

In [None]:
import kagglehub

# Fetch the latest published version of the dataset; kagglehub returns the
# local directory that holds the downloaded files.
dataset_handle = "manjilkarki/deepfake-and-real-images"
path = kagglehub.dataset_download(dataset_handle)

print("Path to dataset files:", path)

You might need to run the above kagglehub cell again to download the dataset.

**Import the required libraries**

In [None]:
# Core Libraries
import numpy as np
import tensorflow as tf

# CNN, CapsNet, and Functional API Tools
from tensorflow.keras.layers import Conv2D, Dense, Flatten, Input, GlobalAveragePooling2D
from tensorflow.keras.layers import Layer, Lambda, Reshape, BatchNormalization, Softmax
from tensorflow.keras import layers, models

# Xception Model (Transfer Learning)
from tensorflow.keras.applications import Xception
from tensorflow.keras.applications.xception import preprocess_input
from tensorflow.keras.models import Model

# Visualization
import matplotlib.pyplot as plt

# Utilities for Handling Image Files and Paths
import os
from tensorflow.keras.preprocessing.image import load_img, img_to_array


**Load the images from the Train, Validation, and Test directories**

In [None]:
# Define the paths for Train, Validation, and Test directories.
# They are derived from `path`, the directory kagglehub reported for the
# download, instead of a hard-coded "/root/.cache/.../versions/1" location,
# so the notebook keeps working for other users, cache directories and
# dataset versions.
dataset_path = os.path.join(path, "Dataset")
train_path = os.path.join(dataset_path, 'Train')
val_path = os.path.join(dataset_path, 'Validation')
test_path = os.path.join(dataset_path, 'Test')

# Build a batched tf.data pipeline from an image directory
def load_tf_dataset(data_path, batch_size=32, img_size=(299, 299)):
    """Load the images under ``data_path`` as a shuffled, batched dataset.

    Args:
        data_path: Directory passed to ``image_dataset_from_directory``.
        batch_size: Number of images per batch.
        img_size: ``(height, width)`` every image is resized to.

    Returns:
        The dataset created by ``image_dataset_from_directory`` with binary
        labels (Real/Fake classification) and shuffling enabled.
    """
    loader_options = {
        "label_mode": 'binary',
        "batch_size": batch_size,
        "image_size": img_size,
        "shuffle": True,  # shuffle to improve training performance
    }
    return tf.keras.preprocessing.image_dataset_from_directory(data_path, **loader_options)

# Load the Train, Validation, and Test datasets (in that order)
train_dataset, val_dataset, test_dataset = (
    load_tf_dataset(split_dir) for split_dir in (train_path, val_path, test_path)
)

In [None]:
# Check the class names detected for each split
for split_name, split_dataset in (
    ("Train", train_dataset),
    ("Validation", val_dataset),
    ("Test", test_dataset),
):
    print(f"{split_name} Class Names: {split_dataset.class_names}")

## Exploratory Data Analysis (EDA)

**1. Class Distribution Plot**

This checks that the dataset is balanced, i.e., that it contains an equal (or nearly equal) number of 'Real' and 'Fake' images.

In [None]:
def plot_class_distribution(dataset, dataset_name):
    """Plot how many images of each class ``dataset`` contains.

    Args:
        dataset: Iterable of ``(images, labels)`` batches whose labels are
            integer-valued class indices (0/1 for the binary label mode).
        dataset_name: Split name used in the plot title (e.g. "Train").
    """
    # Collect every label into one flat integer array. Binary label batches
    # arrive as (batch, 1) float tensors, so flatten them before counting.
    label_batches = [batch_labels.numpy().ravel() for _, batch_labels in dataset]
    if label_batches:
        labels = np.concatenate(label_batches).astype(int)
    else:
        labels = np.zeros(0, dtype=int)
    counts = np.bincount(labels, minlength=2)

    # Label indices follow the order of `dataset.class_names`, so take the bar
    # labels from there instead of assuming a fixed "Real=0, Fake=1" mapping.
    known_names = list(getattr(dataset, "class_names", None) or [])
    bar_names = [
        known_names[index] if index < len(known_names) else f"Class {index}"
        for index in range(len(counts))
    ]
    # Keep the colour tied to the class name rather than to the bar position.
    palette = {'Real': 'skyblue', 'Fake': 'salmon'}
    bar_colors = [palette.get(name, 'gray') for name in bar_names]

    # Plot the class distribution
    plt.figure(figsize=(6, 4))
    plt.bar(bar_names, counts, color=bar_colors)
    plt.xlabel('Class')
    plt.ylabel('No of Images')
    plt.title(f'Class Distribution in {dataset_name} Dataset')
    plt.show()

# Plot class distributions for Train, Validation, and Test datasets
plot_class_distribution(train_dataset, "Train")
plot_class_distribution(val_dataset, "Validation")
plot_class_distribution(test_dataset, "Test")


From the bar graphs, we can see that the dataset is balanced: the numbers of 'Real' and 'Fake' images are equal.

**2. Display Some Sample Images**

To confirm that the images are loaded correctly and match their labels.

In [None]:
def show_samples(dataset, class_names, num_samples=6):
    """Show up to ``num_samples`` images from the first batch of ``dataset``.

    Args:
        dataset: Batched dataset yielding ``(images, labels)``.
        class_names: Sequence mapping a label index to its class name.
        num_samples: Number of images to show; capped at the batch size.
    """
    for images, labels in dataset.take(1):
        # Never index past the end of the batch.
        count = min(num_samples, int(images.shape[0]))
        if count <= 0:
            return

        # Three columns and at least three rows (the layout used for up to
        # nine images); extra rows are added when more images are requested.
        n_cols = 3
        n_rows = max(3, (count + n_cols - 1) // n_cols)
        plt.figure(figsize=(12, 8 * n_rows / 3))

        for i in range(count):
            plt.subplot(n_rows, n_cols, i + 1)
            plt.imshow(images[i].numpy().astype("uint8"))
            # Binary labels come as shape-(1,) tensors; extract the scalar
            # explicitly rather than relying on implicit array-to-int conversion.
            label_index = int(labels[i].numpy().ravel()[0])
            plt.title(f"{class_names[label_index]}")
            plt.axis("off")
        plt.show()

# Display 9 samples from the Train dataset
show_samples(train_dataset, train_dataset.class_names, num_samples=9)



**3. Check Image Shapes and Label Types**

In [None]:
# Inspect a single batch: report the tensor shapes and the label dtype
for image_batch, label_batch in train_dataset.take(1):
    batch_report = (
        f"Image batch shape: {image_batch.shape}",
        f"Label batch shape: {label_batch.shape}",
        f"Label Data Type: {label_batch.dtype}",
    )
    print("\n".join(batch_report))

**4. Plot the Pixel Value Distribution**

In [None]:
def plot_pixel_distribution(dataset):
    """Plot a histogram of the pixel values in the first batch of ``dataset``.

    Args:
        dataset: Batched dataset yielding ``(images, labels)``.
    """
    # Keep the pixels in one flat NumPy array. Extending a Python list would
    # box every pixel of the batch as a separate object, which is slow and
    # memory-hungry for millions of values.
    flat_batches = [images.numpy().ravel() for images, _ in dataset.take(1)]
    pixel_values = np.concatenate(flat_batches) if flat_batches else np.empty(0)

    plt.figure(figsize=(8, 6))
    plt.hist(pixel_values, bins=50, color='coral')
    plt.xlabel('Pixel Value')
    plt.ylabel('Frequency')
    plt.title('Pixel Value Distribution')
    plt.show()

# Plot the pixel distribution for the Train dataset
plot_pixel_distribution(train_dataset)

## Model 1: CNN

In [None]:
# Define CNN model architecture
def create_cnn_model(input_shape=(299, 299, 3)):
    """Build the baseline CNN for binary Real/Fake classification.

    Three Conv2D + MaxPooling2D stages are followed by a 64-unit dense layer
    and a single sigmoid output unit.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.

    Returns:
        An uncompiled ``Sequential`` model.
    """
    return models.Sequential([
        # Declare the input explicitly instead of passing `input_shape=` to
        # the first layer, which Keras 3 warns about in Sequential models.
        layers.Input(shape=input_shape),
        # The notebook feeds raw 0-255 pixels straight from
        # image_dataset_from_directory; scale them to [0, 1] inside the model.
        layers.Rescaling(1.0 / 255),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D((2, 2)),
        layers.Flatten(),
        layers.Dense(64, activation='relu'),
        layers.Dense(1, activation='sigmoid'),
    ])

# Create an instance of the CNN model
cnn_model = create_cnn_model()

In [None]:
# Configure training: Adam optimizer, binary cross-entropy loss, accuracy metric
cnn_compile_options = {
    "optimizer": 'adam',
    "loss": 'binary_crossentropy',
    "metrics": ['accuracy'],
}
cnn_model.compile(**cnn_compile_options)

In [None]:
# Display the CNN model summary
cnn_model.summary()

In [None]:
# The full dataset is very large, so only `subset_size` batches of the
# Train, Validation and Test datasets are used
subset_size = 15

# Take the subset, then cache and prefetch it for performance
train_subset = train_dataset.take(subset_size).cache().prefetch(buffer_size=tf.data.AUTOTUNE)
val_subset = val_dataset.take(subset_size).cache().prefetch(buffer_size=tf.data.AUTOTUNE)
test_subset = test_dataset.take(subset_size).cache().prefetch(buffer_size=tf.data.AUTOTUNE)

# Model training
history = cnn_model.fit(train_subset, epochs=5, validation_data=val_subset)

In [None]:
# Evaluate the CNN on the test subset and report the results
cnn_test_metrics = cnn_model.evaluate(test_subset)
loss, accuracy = cnn_test_metrics
for metric_name, metric_value in (("Loss", loss), ("Accuracy", accuracy)):
    print(f"Test {metric_name}: {metric_value:.4f}")

## Model 2: CapsNets

In [None]:
from tensorflow.keras import backend as K

# Custom squash function for non-linear activation
def squash(vectors, axis=-1):
    """Apply the capsule "squash" non-linearity along ``axis``.

    Each vector keeps its direction while its length ``n`` is rescaled to
    roughly ``n**2 / (1 + n**2)``, which is always below 1.
    """
    sq_norm = K.sum(K.square(vectors), axis=axis, keepdims=True)
    shrink = sq_norm / (1 + sq_norm)
    # K.epsilon() keeps the division finite for all-zero vectors
    scale = shrink / K.sqrt(sq_norm + K.epsilon())
    return scale * vectors

# Capsule Layer with correct batch_dot dimensions
class CapsuleLayer(Layer):
    """Capsule layer using routing-by-agreement over its input capsules.

    Expects inputs of shape ``(batch, num_input_capsules, input_dim)`` and
    returns ``(batch, num_capsules, dim_capsule)``. One transformation matrix
    is shared by all input capsules.

    Args:
        num_capsules: Number of output capsules.
        dim_capsule: Dimensionality of each output capsule vector.
        routings: Number of routing iterations.
    """

    def __init__(self, num_capsules, dim_capsule, routings=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsules = num_capsules
        self.dim_capsule = dim_capsule
        self.routings = routings

    def build(self, input_shape):
        # Maps every input capsule to its predictions for all output capsules.
        self.kernel = self.add_weight(
            name='capsule_kernel',
            shape=(input_shape[-1], self.num_capsules * self.dim_capsule),
            initializer='glorot_uniform',
            trainable=True
        )

    def call(self, inputs):
        # Prediction vectors: (batch, num_input_capsules, num_capsules, dim_capsule)
        u_hat = K.dot(inputs, self.kernel)
        u_hat = K.reshape(u_hat, (-1, inputs.shape[1], self.num_capsules, self.dim_capsule))

        b = K.zeros_like(u_hat[:, :, :, 0])  # Initialize routing logits

        # Dynamic routing process
        for i in range(self.routings):
            c = tf.nn.softmax(b, axis=2)  # Softmax over the num_capsules dimension
            s = K.sum(c[..., None] * u_hat, axis=1)  # Weighted sum over the input capsules
            v = squash(s)  # Squash so each capsule's length is below 1
            if i < self.routings - 1:
                # Raise the logits where a prediction agrees with the output
                b += K.sum(u_hat * v[:, None, :, :], axis=-1)

        return v

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsules, self.dim_capsule)

    def get_config(self):
        # Serialise the constructor arguments so that models containing this
        # layer can be saved and re-created.
        config = super(CapsuleLayer, self).get_config()
        config.update({
            'num_capsules': self.num_capsules,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
        })
        return config

# Define the CapsNet architecture
def create_capsnet(input_shape, n_classes):
    """Build and compile the CapsNet classifier.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        n_classes: Number of output classes (one output capsule per class).

    Returns:
        A compiled Keras model producing ``n_classes`` softmax probabilities.
    """
    inputs = Input(shape=input_shape)

    # The notebook feeds raw 0-255 pixels straight from
    # image_dataset_from_directory; scale them to [0, 1] inside the model.
    scaled = layers.Rescaling(1.0 / 255)(inputs)

    # First convolutional layer
    conv1 = Conv2D(64, (9, 9), strides=2, padding='valid', activation='relu')(scaled)

    # Primary Capsule Layer: conv features regrouped into 8-dimensional capsules
    primary_caps = Conv2D(128, (9, 9), strides=2, padding='valid', activation='relu')(conv1)
    primary_caps = Reshape((-1, 8))(primary_caps)

    # Capsule Layer with routing
    caps_layer = CapsuleLayer(num_capsules=n_classes, dim_capsule=16, routings=3)(primary_caps)

    # Length of each capsule vector. The epsilon keeps the gradient of the
    # square root finite when a capsule is exactly zero (otherwise NaN).
    outputs = Lambda(lambda z: K.sqrt(K.sum(K.square(z), axis=-1) + K.epsilon()))(caps_layer)

    # Dense softmax layer turning the capsule lengths into class probabilities
    outputs = Dense(n_classes, activation='softmax')(outputs)

    model = models.Model(inputs, outputs)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

# Create the CapsNet model instance
capsnet_model = create_capsnet(input_shape=(299, 299, 3), n_classes=2)
# Display the model summary
capsnet_model.summary()

In [None]:
# Display the CapsNet model summary
capsnet_model.summary()

In [None]:
# Prepare data subsets. Each split is limited to `subset_size` batches
# *before* caching, so the cache holds exactly the batches that get reused.
subset_size = 15
train_subset = train_dataset.take(subset_size).cache().prefetch(buffer_size=tf.data.AUTOTUNE)
val_subset = val_dataset.take(subset_size).cache().prefetch(buffer_size=tf.data.AUTOTUNE)
# take() must come before cache(): truncating a cached dataset afterwards
# means the cache is never read to the end and so never completes.
test_subset = test_dataset.take(subset_size).cache().prefetch(buffer_size=tf.data.AUTOTUNE)

# Train the CapsNet model
history = capsnet_model.fit(train_subset, epochs=5, validation_data=val_subset)

In [None]:
# Evaluate the CapsNet on at most `subset_size` batches of the test set
test_subset = test_subset.take(subset_size)
capsnet_test_metrics = capsnet_model.evaluate(test_subset)
loss, accuracy = capsnet_test_metrics
for metric_name, metric_value in (("Loss", loss), ("Accuracy", accuracy)):
    print(f"Test {metric_name}: {metric_value:.4f}")

## Model 3: Xception

In [None]:
# Transfer-learning classifier on top of a frozen, ImageNet-pretrained Xception
def create_xception_model(input_shape=(299, 299, 3), n_classes=2):
    """Build and compile an Xception-based classifier.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        n_classes: Number of softmax output units.

    Returns:
        ``(model, base_model)``: the compiled full model and the frozen
        Xception backbone, returned separately for later fine-tuning.
    """
    # Pre-trained backbone without its fully connected top layers
    backbone = Xception(include_top=False, weights='imagenet', input_shape=input_shape)

    # Freeze the backbone so the pre-trained weights are retained
    backbone.trainable = False

    # Custom classification head on top of the backbone features
    image_input = Input(shape=input_shape)
    features = backbone(preprocess_input(image_input), training=False)
    pooled = GlobalAveragePooling2D()(features)
    hidden = Dense(128, activation='relu')(pooled)
    class_probs = Dense(n_classes, activation='softmax')(hidden)

    # Assemble and compile the final model
    classifier = Model(image_input, class_probs)
    classifier.compile(
        loss='sparse_categorical_crossentropy',
        optimizer='adam',
        metrics=['accuracy'],
    )
    return classifier, backbone

# Create the Xception model instance
xception_model, base_model = create_xception_model()

In [None]:
# Display the Xception model summary
xception_model.summary()

In [None]:
# Prepare data subsets for training and validation
subset_size = 15
train_subset = (
    train_dataset
    .take(subset_size)
    .cache()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
val_subset = (
    val_dataset
    .take(subset_size)
    .cache()
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

# Train the Xception model
history = xception_model.fit(train_subset, validation_data=val_subset, epochs=5)

Epoch 1/5
15/15 ━━━━━━━━━━━━━━━━━━━━ 443s 30s/step - accuracy: 0.6114 - loss: 0.6715 - val_accuracy: 0.6792 - val_loss: 0.6053
Epoch 2/5
15/15 ━━━━━━━━━━━━━━━━━━━━ 476s 33s/step - accuracy: 0.8316 - loss: 0.4188 - val_accuracy: 0.6938 - val_loss: 0.5856
Epoch 3/5
15/15 ━━━━━━━━━━━━━━━━━━━━ 481s 33s/step - accuracy: 0.8788 - loss: 0.3259 - val_accuracy: 0.6917 - val_loss: 0.6109
Epoch 4/5
10/15 ━━━━━━━━━━━━━━━━━━━━ 1:10 14s/step - accuracy: 0.9210 - loss: 0.2543

In [None]:
# Evaluate the model on the test set.
# The test subset is built directly from `test_dataset`, so this cell does not
# depend on a `test_subset` left over from an earlier model's section and does
# not stack a second cache/prefetch on top of an existing one.
test_subset = test_dataset.take(subset_size).cache().prefetch(buffer_size=tf.data.AUTOTUNE)
loss, accuracy = xception_model.evaluate(test_subset)
print(f"Test Accuracy: {accuracy:.4f}")
print(f"Test Loss: {loss:.4f}")

In [None]:
from tensorflow.keras.applications import Xception
from tensorflow.keras.layers import Conv2D, Input, Reshape, Dense, Flatten
from tensorflow.keras.models import Model
import tensorflow.keras.backend as K
import tensorflow as tf

# Custom squash function
def squash(vectors, axis=-1):
    """Apply the capsule "squash" non-linearity along ``axis``.

    Each vector keeps its direction while its length ``n`` is rescaled to
    roughly ``n**2 / (1 + n**2)``, which is always below 1.
    """
    sq_norm = K.sum(K.square(vectors), axis=axis, keepdims=True)
    shrink = sq_norm / (1 + sq_norm)
    # K.epsilon() keeps the division finite for all-zero vectors
    scale = shrink / K.sqrt(sq_norm + K.epsilon())
    return scale * vectors

# Capsule Layer
class CapsuleLayer(tf.keras.layers.Layer):
    """Capsule layer using routing-by-agreement over its input capsules.

    Expects inputs of shape ``(batch, num_input_capsules, input_dim)`` and
    returns ``(batch, num_capsules, dim_capsule)``. One transformation matrix
    is shared by all input capsules.

    Args:
        num_capsules: Number of output capsules.
        dim_capsule: Dimensionality of each output capsule vector.
        routings: Number of routing iterations.
    """

    def __init__(self, num_capsules, dim_capsule, routings=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsules = num_capsules
        self.dim_capsule = dim_capsule
        self.routings = routings

    def build(self, input_shape):
        # Maps every input capsule to its predictions for all output capsules.
        self.kernel = self.add_weight(
            name='capsule_kernel',
            shape=(input_shape[-1], self.num_capsules * self.dim_capsule),
            initializer='glorot_uniform',
            trainable=True
        )

    def call(self, inputs):
        # Prediction vectors: (batch, num_input_capsules, num_capsules, dim_capsule)
        u_hat = K.dot(inputs, self.kernel)
        u_hat = K.reshape(u_hat, (-1, inputs.shape[1], self.num_capsules, self.dim_capsule))

        b = K.zeros_like(u_hat[:, :, :, 0])  # Routing logits start at zero

        for i in range(self.routings):
            c = tf.nn.softmax(b, axis=2)  # Coupling coefficients over the output capsules
            s = K.sum(c[..., None] * u_hat, axis=1)  # Weighted sum over the input capsules
            v = squash(s)
            if i < self.routings - 1:
                # Raise the logits where a prediction agrees with the output
                b += K.sum(u_hat * v[:, None, :, :], axis=-1)
        return v

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsules, self.dim_capsule)

    def get_config(self):
        # Serialise the constructor arguments so that models containing this
        # layer can be saved and re-created.
        config = super(CapsuleLayer, self).get_config()
        config.update({
            'num_capsules': self.num_capsules,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
        })
        return config

# Hybrid Model: Xception + CapsNet
def create_hybrid_model(input_shape=(299, 299, 3), num_classes=2):
    """Build and compile a frozen-Xception feature extractor with a capsule head.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        num_classes: Number of output classes (one output capsule per class).

    Returns:
        A compiled Keras model producing ``num_classes`` softmax probabilities.
    """
    # Xception as feature extractor (without top layers)
    base_model = Xception(weights="imagenet", include_top=False, input_shape=input_shape)
    base_model.trainable = False  # Freeze pretrained layers

    inputs = Input(shape=input_shape)
    # Apply Xception's own input preprocessing, as `create_xception_model`
    # does. The function is referenced by its full path so that a later
    # `preprocess_input` import for another architecture cannot shadow it.
    x = tf.keras.applications.xception.preprocess_input(inputs)
    x = base_model(x, training=False)
    x = Conv2D(128, (1, 1), activation='relu')(x)  # Feature refinement
    # Regroup the features into 8-dimensional primary capsules; Reshape
    # already flattens the spatial grid, so no separate Flatten is needed.
    x = Reshape((-1, 8))(x)

    # Capsule Network
    caps_layer = CapsuleLayer(num_capsules=num_classes, dim_capsule=16, routings=3)(x)
    # Length of each capsule vector. The epsilon keeps the gradient of the
    # square root finite when a capsule is exactly zero (otherwise NaN).
    output = Lambda(lambda z: K.sqrt(K.sum(K.square(z), axis=-1) + K.epsilon()))(caps_layer)
    output = Dense(num_classes, activation='softmax')(output)  # Final classification

    model = Model(inputs=inputs, outputs=output)
    model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    return model

# Create model
hybrid_model = create_hybrid_model()
hybrid_model.summary()


In [None]:
import tensorflow as tf
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Reshape, Lambda
from tensorflow.keras import Model
import tensorflow.keras.backend as K

# Custom squash function
def squash(vectors, axis=-1):
    """Apply the capsule "squash" non-linearity along ``axis``.

    Each vector keeps its direction while its length ``n`` is rescaled to
    roughly ``n**2 / (1 + n**2)``, which is always below 1.
    """
    sq_norm = K.sum(K.square(vectors), axis=axis, keepdims=True)
    shrink = sq_norm / (1 + sq_norm)
    # K.epsilon() keeps the division finite for all-zero vectors
    scale = shrink / K.sqrt(sq_norm + K.epsilon())
    return scale * vectors

# Capsule Layer
class CapsuleLayer(tf.keras.layers.Layer):
    """Capsule layer using routing-by-agreement over its input capsules.

    Expects inputs of shape ``(batch, num_input_capsules, input_dim)`` and
    returns ``(batch, num_capsules, dim_capsule)``. One transformation matrix
    is shared by all input capsules.

    Args:
        num_capsules: Number of output capsules.
        dim_capsule: Dimensionality of each output capsule vector.
        routings: Number of routing iterations.
    """

    def __init__(self, num_capsules, dim_capsule, routings=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsules = num_capsules
        self.dim_capsule = dim_capsule
        self.routings = routings

    def build(self, input_shape):
        # Maps every input capsule to its predictions for all output capsules.
        self.kernel = self.add_weight(
            name="capsule_kernel",
            shape=(input_shape[-1], self.num_capsules * self.dim_capsule),
            initializer="glorot_uniform",
            trainable=True
        )

    def call(self, inputs):
        # Prediction vectors: (batch, num_input_capsules, num_capsules, dim_capsule)
        u_hat = K.dot(inputs, self.kernel)
        u_hat = K.reshape(u_hat, (-1, inputs.shape[1], self.num_capsules, self.dim_capsule))

        b = K.zeros_like(u_hat[:, :, :, 0])  # Routing logits start at zero
        for i in range(self.routings):
            c = tf.nn.softmax(b, axis=2)  # Coupling coefficients over the output capsules
            s = K.sum(c[..., None] * u_hat, axis=1)  # Weighted sum over the input capsules
            v = squash(s)
            if i < self.routings - 1:
                # Raise the logits where a prediction agrees with the output
                b += K.sum(u_hat * v[:, None, :, :], axis=-1)
        return v

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsules, self.dim_capsule)

    def get_config(self):
        # Serialise the constructor arguments so that models containing this
        # layer can be saved and re-created.
        config = super(CapsuleLayer, self).get_config()
        config.update({
            "num_capsules": self.num_capsules,
            "dim_capsule": self.dim_capsule,
            "routings": self.routings,
        })
        return config

# EfficientNet-based Hybrid Model
def create_efficientnet_capsnet(input_shape=(299, 299, 3), num_classes=2):
    """Build and compile a frozen EfficientNetB0 backbone with a capsule head.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        num_classes: Number of output classes (one output capsule per class).

    Returns:
        A compiled Keras model producing ``num_classes`` softmax probabilities.
    """
    inputs = Input(shape=input_shape)

    # EfficientNet feature extractor, built directly on `inputs`
    base_model = EfficientNetB0(weights="imagenet", include_top=False, input_tensor=inputs)
    base_model.trainable = False  # Freeze EfficientNet layers

    x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
    x = Reshape((-1, 8))(x)  # Split the pooled features into 8-dimensional capsules

    # Capsule Network Layer
    caps_layer = CapsuleLayer(num_capsules=num_classes, dim_capsule=16, routings=3)(x)
    # Length of each capsule vector. The epsilon keeps the gradient of the
    # square root finite when a capsule is exactly zero (otherwise NaN).
    outputs = Lambda(lambda z: K.sqrt(K.sum(K.square(z), axis=-1) + K.epsilon()))(caps_layer)
    outputs = Dense(num_classes, activation="softmax")(outputs)

    model = Model(inputs, outputs)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model

# Create and summarize the model
hybrid_model = create_efficientnet_capsnet()
hybrid_model.summary()


In [None]:
from tensorflow.keras.utils import plot_model
# Render the hybrid model's layer graph, with shapes and layer names, to hybrid_model.png
plot_model(hybrid_model, to_file="hybrid_model.png", show_shapes=True, show_layer_names=True)


In [None]:
# Use tf.data.AUTOTUNE, the name the other cells of this notebook already
# use, rather than the older tf.data.experimental alias of the same constant.
AUTOTUNE = tf.data.AUTOTUNE

def configure_dataset(dataset):
    """Return ``dataset`` with prefetching enabled (auto-tuned buffer size)."""
    return dataset.prefetch(buffer_size=AUTOTUNE)

train_dataset = configure_dataset(train_dataset)
val_dataset = configure_dataset(val_dataset)
test_dataset = configure_dataset(test_dataset)


In [None]:
# Cap the number of batches drawn from the training and validation splits
train_dataset = train_dataset.take(count=1000)  # at most 1000 training batches
val_dataset = val_dataset.take(count=500)  # at most 500 validation batches


In [None]:
# Reload all three splits with a batch size of 64
train_dataset, val_dataset, test_dataset = (
    load_tf_dataset(split_dir, batch_size=64)
    for split_dir in (train_path, val_path, test_path)
)


In [None]:
def process_img(img, label, img_size=(299, 299)):
    """Resize ``img`` to ``img_size`` and divide the pixel values by 255.

    Args:
        img: Image tensor (a single image or a batch).
        label: Label passed through unchanged.
        img_size: Target ``(height, width)``.

    Returns:
        ``(img, label)`` with the resized, rescaled image.
    """
    # NOTE(review): Keras' EfficientNet models rescale their inputs
    # internally and expect 0-255 pixels; confirm that datasets normalised
    # here are not fed to the EfficientNet-based models.
    img = tf.image.resize(img, img_size) / 255.0  # Normalize images
    return img, label

def load_tf_dataset(data_path, batch_size=32, img_size=(299, 299)):
    """Load a shuffled, batched, normalised dataset from ``data_path``.

    ``img_size`` is accepted again (it was a parameter of the earlier
    definition of this function) so existing callers that pass it keep working.

    Args:
        data_path: Directory passed to ``image_dataset_from_directory``.
        batch_size: Number of images per batch.
        img_size: ``(height, width)`` every image is resized to.

    Returns:
        A prefetched dataset of ``(normalised_images, binary_labels)`` batches.
    """
    dataset = tf.keras.preprocessing.image_dataset_from_directory(
        data_path,
        label_mode='binary',
        batch_size=batch_size,
        image_size=img_size,
        shuffle=True
    ).map(
        lambda img, label: process_img(img, label, img_size),
        num_parallel_calls=tf.data.AUTOTUNE
    ).prefetch(buffer_size=tf.data.AUTOTUNE)

    return dataset


In [None]:
# Fine-tune only the last 50 EfficientNet layers of `hybrid_model`.
# `create_efficientnet_capsnet` keeps its backbone in a local variable and
# builds it on the model's own input tensor, so the EfficientNet layers are
# direct members of `hybrid_model`. The global `base_model` is the Xception
# backbone returned by `create_xception_model`, not EfficientNet.
# The "stem_" / "block" / "top_" prefixes are presumed to cover Keras'
# EfficientNet layer names - verify against `hybrid_model.summary()`.
backbone_layers = [
    layer for layer in hybrid_model.layers
    if layer.name.startswith(("stem_", "block", "top_"))
]
for layer in backbone_layers[:-50]:  # everything except the last 50 stays frozen
    layer.trainable = False
for layer in backbone_layers[-50:]:
    layer.trainable = True
# NOTE: the model has to be recompiled before these flags affect training.


In [None]:
from tensorflow.keras.applications import EfficientNetB0
from tensorflow.keras.layers import Input, Conv2D, Flatten, Dense, Reshape, Lambda
from tensorflow.keras import Model
import tensorflow.keras.backend as K

# Custom squash function
def squash(vectors, axis=-1):
    """Apply the capsule "squash" non-linearity along ``axis``.

    Each vector keeps its direction while its length ``n`` is rescaled to
    roughly ``n**2 / (1 + n**2)``, which is always below 1.
    """
    sq_norm = K.sum(K.square(vectors), axis=axis, keepdims=True)
    shrink = sq_norm / (1 + sq_norm)
    # K.epsilon() keeps the division finite for all-zero vectors
    scale = shrink / K.sqrt(sq_norm + K.epsilon())
    return scale * vectors

# Capsule Layer
class CapsuleLayer(tf.keras.layers.Layer):
    """Capsule layer using routing-by-agreement over its input capsules.

    Expects inputs of shape ``(batch, num_input_capsules, input_dim)`` and
    returns ``(batch, num_capsules, dim_capsule)``. One transformation matrix
    is shared by all input capsules.

    Args:
        num_capsules: Number of output capsules.
        dim_capsule: Dimensionality of each output capsule vector.
        routings: Number of routing iterations.
    """

    def __init__(self, num_capsules, dim_capsule, routings=3, **kwargs):
        super(CapsuleLayer, self).__init__(**kwargs)
        self.num_capsules = num_capsules
        self.dim_capsule = dim_capsule
        self.routings = routings

    def build(self, input_shape):
        # Maps every input capsule to its predictions for all output capsules.
        self.kernel = self.add_weight(
            name="capsule_kernel",
            shape=(input_shape[-1], self.num_capsules * self.dim_capsule),
            initializer="glorot_uniform",
            trainable=True
        )

    def call(self, inputs):
        # Prediction vectors: (batch, num_input_capsules, num_capsules, dim_capsule)
        u_hat = K.dot(inputs, self.kernel)
        u_hat = K.reshape(u_hat, (-1, inputs.shape[1], self.num_capsules, self.dim_capsule))
        b = K.zeros_like(u_hat[:, :, :, 0])  # Routing logits start at zero
        for i in range(self.routings):
            c = tf.nn.softmax(b, axis=2)  # Coupling coefficients over the output capsules
            s = K.sum(c[..., None] * u_hat, axis=1)  # Weighted sum over the input capsules
            v = squash(s)
            if i < self.routings - 1:
                # Raise the logits where a prediction agrees with the output
                b += K.sum(u_hat * v[:, None, :, :], axis=-1)
        return v

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsules, self.dim_capsule)

    def get_config(self):
        # Serialise the constructor arguments so that models containing this
        # layer can be saved and re-created.
        config = super(CapsuleLayer, self).get_config()
        config.update({
            "num_capsules": self.num_capsules,
            "dim_capsule": self.dim_capsule,
            "routings": self.routings,
        })
        return config

# Updated EfficientNet-based Hybrid Model with input_shape=(224, 224, 3)
def create_efficientnet_capsnet(input_shape=(224, 224, 3), num_classes=2):
    """Build and compile a frozen EfficientNetB0 backbone with a capsule head.

    Args:
        input_shape: ``(height, width, channels)`` of the input images.
        num_classes: Number of output classes (one output capsule per class).

    Returns:
        A compiled Keras model producing ``num_classes`` softmax probabilities.
    """
    inputs = Input(shape=input_shape)
    # EfficientNetB0 as feature extractor (default input shape for pretrained weights is 224x224)
    base_model = EfficientNetB0(weights="imagenet", include_top=False, input_tensor=inputs)
    base_model.trainable = False  # Freeze EfficientNet layers

    x = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
    x = Reshape((-1, 8))(x)  # Split the pooled features into 8-dimensional capsules

    # Capsule Network Layer
    caps_layer = CapsuleLayer(num_capsules=num_classes, dim_capsule=16, routings=3)(x)
    # Length of each capsule vector. The epsilon keeps the gradient of the
    # square root finite when a capsule is exactly zero (otherwise NaN).
    outputs = Lambda(lambda z: K.sqrt(K.sum(K.square(z), axis=-1) + K.epsilon()))(caps_layer)
    outputs = Dense(num_classes, activation="softmax")(outputs)

    model = Model(inputs, outputs)
    model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
    return model

# Re-create the hybrid model with the new input shape
hybrid_model = create_efficientnet_capsnet()
hybrid_model.summary()


In [None]:
# Reload the test split with a batch size of 32
test_dataset = load_tf_dataset(test_path, batch_size=32)



In [None]:
# List every layer of the hybrid model (index, name, class) to help decide
# which ones to unfreeze; nothing is unfrozen in this cell
for index, model_layer in enumerate(hybrid_model.layers):
    print(index, model_layer.name, type(model_layer).__name__)

# Recompile with a lower learning rate for fine-tuning
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
hybrid_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'],
)


In [None]:
# Unfreeze the EfficientNet base layers, identified by their name prefixes
backbone_prefixes = ("stem_", "block")
for model_layer in hybrid_model.layers:
    if model_layer.name.startswith(backbone_prefixes):
        model_layer.trainable = True

# Print each layer's name and trainable status to verify the result
for model_layer in hybrid_model.layers:
    print(model_layer.name, model_layer.trainable)

# Recompile the model with a lower learning rate for fine-tuning
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-5)
hybrid_model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=optimizer,
    metrics=['accuracy'],
)


In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications.efficientnet import EfficientNetB0, preprocess_input

# -------------------------------
# 1. Define Dataset Paths
# -------------------------------
dataset_path = "/root/.cache/kagglehub/datasets/manjilkarki/deepfake-and-real-images/versions/1/Dataset"
train_path, val_path, test_path = (
    os.path.join(dataset_path, split_dir)
    for split_dir in ('Train', 'Validation', 'Test')
)

# Echo the resolved directory of every split
for split_label, split_path in (
    ("Train", train_path),
    ("Validation", val_path),
    ("Test", test_path),
):
    print(f"{split_label} Path:", split_path)

# -------------------------------
# 2. Load and Preprocess Datasets
# -------------------------------
def load_raw_dataset(data_path, batch_size=32, img_size=(224,224)):
    """Load the images under ``data_path`` as a shuffled, batched dataset.

    No preprocessing is applied here. Labels use the binary label mode
    (Real vs Fake) and images are resized to ``img_size``.
    """
    loader_options = {
        "label_mode": 'binary',
        "batch_size": batch_size,
        "image_size": img_size,
        "shuffle": True,
    }
    return tf.keras.preprocessing.image_dataset_from_directory(data_path, **loader_options)

# Load the raw datasets first: `class_names` is read here, before any
# preprocessing is applied to the datasets
raw_loader_options = {"batch_size": 32, "img_size": (224, 224)}

raw_train_dataset = load_raw_dataset(train_path, **raw_loader_options)
raw_val_dataset = load_raw_dataset(val_path, **raw_loader_options)
raw_test_dataset = load_raw_dataset(test_path, **raw_loader_options)

train_class_names = raw_train_dataset.class_names
val_class_names = raw_val_dataset.class_names
test_class_names = raw_test_dataset.class_names

# Now apply EfficientNet preprocessing and prefetch
def preprocess_dataset(raw_dataset):
    """Return ``raw_dataset`` with ``preprocess_input`` applied to the images and prefetching enabled."""
    def _preprocess_batch(images, labels):
        return preprocess_input(images), labels

    mapped = raw_dataset.map(_preprocess_batch, num_parallel_calls=tf.data.AUTOTUNE)
    return mapped.prefetch(tf.data.AUTOTUNE)

# Preprocess the three raw splits (Train, Validation, Test - in that order)
train_dataset, val_dataset, test_dataset = map(
    preprocess_dataset,
    (raw_train_dataset, raw_val_dataset, raw_test_dataset),
)

# Report the class names captured from the raw datasets
print("Train Classes:", train_class_names)
print("Validation Classes:", val_class_names)
print("Test Classes:", test_class_names)

# -------------------------------
# 3. Define the CNN+BiLSTM Model
# -------------------------------
input_shape = (224, 224, 3)
inputs = tf.keras.Input(shape=input_shape)

# Frozen, ImageNet-pretrained EfficientNetB0 (without its top) as the feature extractor
base_model = EfficientNetB0(include_top=False, weights="imagenet", input_tensor=inputs)
base_model.trainable = False

# Backbone feature map: (None, H, W, channels), e.g. (None, 7, 7, 1280)
x = base_model.output
shape = tf.keras.backend.int_shape(x)

# Treat every spatial position as one time step: (batch, H*W, channels)
time_steps = shape[1] * shape[2]
x = layers.Reshape((time_steps, shape[3]))(x)

# Summarise the sequence with a bidirectional LSTM, then apply dropout
sequence_encoder = layers.Bidirectional(layers.LSTM(128, return_sequences=False))
x = sequence_encoder(x)
x = layers.Dropout(0.5)(x)

# Two-way softmax head: Real and Fake
outputs = layers.Dense(2, activation="softmax")(x)

# Build and compile the model
bilstm_model = models.Model(inputs=inputs, outputs=outputs)
bilstm_optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
bilstm_model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=bilstm_optimizer,
    metrics=["accuracy"],
)
bilstm_model.summary()

# -------------------------------
# 4. Train the Model
# -------------------------------
callbacks = [
    tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True),
    tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3)
]

# With `steps_per_epoch` / `validation_steps`, Keras draws a fixed number of
# batches per epoch, so the datasets must be able to supply
# steps * epochs batches. `.repeat()` guarantees that; without it training
# is interrupted as soon as the data runs out.
history_bilstm = bilstm_model.fit(
    train_dataset.repeat(),
    validation_data=val_dataset.repeat(),
    epochs=700,              # Upper bound; EarlyStopping normally ends training sooner
    steps_per_epoch=50,      # Adjust based on your dataset size
    validation_steps=10,     # Adjust accordingly
    callbacks=callbacks,     # Early stopping + learning-rate reduction on val_loss plateaus
)

# -------------------------------
# 5. Evaluate the Model
# -------------------------------
from sklearn.metrics import classification_report, confusion_matrix

# The test dataset is shuffled and reshuffles on every pass, so predictions
# and ground-truth labels must be gathered in the SAME pass. Iterating twice
# (once to predict, once to read labels) pairs predictions with the wrong labels.
y_true_batches = []
y_pred_batches = []
for batch_images, batch_labels in test_dataset:
    batch_probs = bilstm_model.predict_on_batch(batch_images)
    y_pred_batches.append(np.argmax(batch_probs, axis=1))
    # Binary labels arrive as (batch, 1) floats; flatten to integer class indices.
    y_true_batches.append(batch_labels.numpy().ravel().astype(int))

y_pred = np.concatenate(y_pred_batches)
y_true = np.concatenate(y_true_batches)

# Label indices follow the order of the dataset's `class_names`, so the
# report uses the names captured from the raw test dataset rather than a
# hard-coded ['Real', 'Fake'] list.
class_indices = list(range(len(test_class_names)))

print("Classification Report for CNN+BiLSTM Model:")
print(classification_report(y_true, y_pred, labels=class_indices, target_names=test_class_names))
print("Confusion Matrix:")
print(confusion_matrix(y_true, y_pred, labels=class_indices))

# -------------------------------
# 6. Plot Training History
# -------------------------------
def plot_training_history(history, title="CNN+BiLSTM Training History"):
    """Plot train/validation accuracy and loss curves side by side.

    Args:
        history: Object whose ``history`` dict holds the ``accuracy``,
            ``val_accuracy``, ``loss`` and ``val_loss`` series.
        title: Prefix for the two subplot titles.
    """
    recorded = history.history
    panels = (("accuracy", "Accuracy"), ("loss", "Loss"))

    plt.figure(figsize=(12, 5))
    for position, (metric_key, metric_label) in enumerate(panels, start=1):
        # One panel per metric: accuracy on the left, loss on the right
        plt.subplot(1, 2, position)
        plt.plot(recorded[metric_key], label='Train ' + metric_label)
        plt.plot(recorded['val_' + metric_key], label='Validation ' + metric_label)
        plt.xlabel('Epochs')
        plt.ylabel(metric_label)
        plt.legend()
        plt.title(title + " - " + metric_label)

    plt.show()

plot_training_history(history_bilstm)
