# In[ ]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models, Input
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

# Mixed precision (float16 compute, float32 variables) can significantly speed up
# training on GPUs; float16 works best on NVIDIA GPUs with Tensor Cores. On CPU-only
# machines it typically makes training slower, so it is enabled only when a GPU is found.
# NOTE(review): under mixed_float16 the model's final layer should output float32
# (e.g. Activation('sigmoid', dtype='float32')) for a numerically stable
# binary_crossentropy; the model is built elsewhere — confirm there.
from tensorflow.keras.mixed_precision import Policy, set_global_policy

# Query the device list once and reuse it for the report and the policy choice.
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
if gpus:
    print("GPU is available, using GPU for training.")
else:
    print("GPU is not available, using CPU for training.  Training will be slower.")

# 'float32' is the Keras default policy, i.e. mixed precision stays off without a GPU.
policy = Policy('mixed_float16' if gpus else 'float32')
set_global_policy(policy)

# 1. Data Loading and Preprocessing (moved to function)
def create_data_generators(data_dir, image_size, batch_size, validation_split=0.2):
    """
    Create training and validation generators that read images from a directory.

    Both generators come from one ImageDataGenerator. They share the same rescaling,
    and ``validation_split`` holds out that fraction of every class folder for validation.

    Args:
        data_dir (str): Path to the directory containing the data. Must have one
                        subdirectory per class (e.g., data_dir/class1, data_dir/class2).
                        Keras maps these subdirectories to label indices in
                        alphanumeric order of their names.
        image_size (tuple): Target size of the images as (height, width), the order
                            expected by Keras' ``target_size``.
        batch_size (int): Batch size for both generators.
        validation_split (float): Fraction of data to use for validation (0.0 to 1.0).

    Returns:
        tuple: (train_generator, validation_generator). The validation generator is
        not shuffled, so the outputs of ``model.predict(validation_generator)`` line up
        with ``validation_generator.classes``.
    """
    datagen = ImageDataGenerator(
        rescale=1.0 / 255,  # Normalize pixel values to [0, 1]
        validation_split=validation_split,  # Hold out this fraction for validation
    )

    train_generator = datagen.flow_from_directory(
        data_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode='binary',  # Binary classification: real/fake
        subset='training',
        shuffle=True,  # Reshuffle training samples every epoch
        seed=42,  # Reproducible shuffling
    )

    validation_generator = datagen.flow_from_directory(
        data_dir,
        target_size=image_size,
        batch_size=batch_size,
        class_mode='binary',  # Binary classification: real/fake
        subset='validation',
        # flow_from_directory shuffles by default. Keep validation in file order so
        # predictions align with validation_generator.classes during evaluation.
        shuffle=False,
        seed=42,
    )
    return train_generator, validation_generator


# 2. Model Creation (Assuming create_efficientnetb4() is defined elsewhere)

# 3. Model Training (moved to function)
def train_model(model, train_generator, validation_generator, epochs, batch_size,
                early_stopping_patience=5, reduce_lr_factor=0.2, reduce_lr_patience=3,
                min_lr=0.00001):
    """
    Compile and train a Keras model with early stopping and learning-rate reduction.

    The model is compiled with Adam and binary cross-entropy, its summary is printed,
    and it is then fit on ``train_generator``. Validation uses ``validation_generator``.
    Both callbacks monitor ``val_loss``. EarlyStopping restores the best weights seen.

    Args:
        model (keras.Model): The Keras model to train.
        train_generator (keras.preprocessing.image.DirectoryIterator): Training data generator.
        validation_generator (keras.preprocessing.image.DirectoryIterator): Validation data generator.
        epochs (int): Maximum number of epochs to train for.
        batch_size (int): Kept for backward compatibility. Step counts are now taken
                          from each generator's own length, so a value that disagrees
                          with the generators' batch size can no longer skew them.
        early_stopping_patience (int): Patience for the EarlyStopping callback.
        reduce_lr_factor (float): Factor for the ReduceLROnPlateau callback.
        reduce_lr_patience (int): Patience for the ReduceLROnPlateau callback.
        min_lr (float): Lower bound on the learning rate for ReduceLROnPlateau.

    Returns:
        keras.callbacks.History: Training history object.
    """
    early_stopping = EarlyStopping(monitor='val_loss', patience=early_stopping_patience,
                                   restore_best_weights=True)
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=reduce_lr_factor,
                                  patience=reduce_lr_patience, min_lr=min_lr)

    model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    model.summary()  # Display model summary

    # len(generator) is the number of batches in one full pass, including the final
    # partial batch. With `samples // batch_size`, the remainder images were skipped
    # every epoch, and a split smaller than one batch produced 0 steps.
    history = model.fit(
        train_generator,
        steps_per_epoch=len(train_generator),
        epochs=epochs,
        validation_data=validation_generator,
        validation_steps=len(validation_generator),
        callbacks=[early_stopping, reduce_lr],
    )
    return history


# 4. Model Evaluation (moved to function)
def evaluate_model(model, validation_generator, target_names=('Real', 'Fake')):
    """
    Evaluate the model on the validation set and report per-class metrics.

    Prints a classification report and the raw confusion matrix, then shows the
    confusion matrix as a heatmap. Labels and predictions are gathered batch by batch
    from the same generator pass, so they stay aligned even if the generator shuffles.

    Args:
        model (keras.Model): The trained Keras model. Assumed to output one sigmoid
                             probability per image (binary task); thresholded at 0.5.
        validation_generator (keras.preprocessing.image.DirectoryIterator): Validation data generator.
        target_names (sequence of str): Display names for class indices 0, 1, ... in order.
            NOTE(review): the default assumes index 0 is "Real". flow_from_directory
            assigns indices alphanumerically by folder name, so folders named
            'fake'/'real' would map fake -> 0. Verify against
            ``validation_generator.class_indices``.

    Raises:
        ValueError: If the generator yields no batches.
    """
    names = list(target_names)
    num_batches = len(validation_generator)
    if num_batches == 0:
        raise ValueError("validation_generator yielded no batches; nothing to evaluate.")

    true_batches = []
    prob_batches = []
    for batch_index in range(num_batches):
        # Images and labels of one batch come from the same index slice, so pairing
        # them here cannot misalign, unlike comparing predict() output with .classes.
        images, labels = validation_generator[batch_index]
        true_batches.append(np.asarray(labels).ravel())
        prob_batches.append(np.asarray(model.predict_on_batch(images)).ravel())

    true_labels = np.concatenate(true_batches).astype(int)
    # Flatten (N, 1) sigmoid outputs to (N,) before thresholding.
    predicted_labels = (np.concatenate(prob_batches) > 0.5).astype(int)

    # Fix the label set so the report and matrix keep their shape even when a class
    # is missing from the predictions or the ground truth.
    class_ids = list(range(len(names)))

    print(classification_report(true_labels, predicted_labels,
                                labels=class_ids, target_names=names))

    cm = confusion_matrix(true_labels, predicted_labels, labels=class_ids)
    print("Confusion Matrix:")
    print(cm)

    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=names, yticklabels=names)
    plt.xlabel('Predicted Label')
    plt.ylabel('True Label')
    plt.title('Confusion Matrix')
    plt.show()


# 5. Plot Training History (moved to function)
def plot_training_history(history):
    """
    Draw side-by-side training/validation curves for accuracy and loss.

    Args:
        history (keras.callbacks.History): Object returned by ``model.fit()``. Its
            ``history`` dict must contain 'accuracy', 'val_accuracy', 'loss' and
            'val_loss' series.
    """
    metrics_by_epoch = history.history
    panels = [('accuracy', 'Accuracy'), ('loss', 'Loss')]

    plt.figure(figsize=(12, 4))

    # One subplot per metric: left panel is accuracy, right panel is loss.
    for position, (key, title_word) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(metrics_by_epoch[key], label=f'Training {title_word}')
        plt.plot(metrics_by_epoch[f'val_{key}'], label=f'Validation {title_word}')
        plt.xlabel('Epoch')
        plt.ylabel(title_word)
        plt.legend()
        plt.title(f'{title_word} vs. Epoch')

    plt.tight_layout()
    plt.show()

# 6. Main execution Block

