In [None]:
# Exploratory Data Analysis

# In this notebook, we'll explore and visualize the dataset to understand its characteristics and distribution.

## Import Libraries

import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from tensorflow.keras.preprocessing.image import ImageDataGenerator


In [None]:
# Visualize Sample Images
# Let's load and visualize a few sample images from each class.

def visualize_samples(directory, num_samples=5):
    """Display up to ``num_samples`` randomly chosen images from one folder.

    Args:
        directory: Path to a folder of images (for example a single class
            folder such as ``data/raw/train/real``). Images in nested
            sub-folders are included as well.
        num_samples: Maximum number of images to draw in a single row.

    Raises:
        ValueError: If ``num_samples`` is smaller than 1, if ``directory`` has
            no parent/leaf split (e.g. a filesystem root), or if no images are
            found under ``directory``.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be at least 1, got {num_samples}")

    # flow_from_directory only scans *sub-directories* of the path it is given,
    # even when class_mode=None. Pointing it straight at a leaf image folder
    # therefore finds nothing, so scan the parent and restrict it to this folder.
    parent_dir, folder_name = os.path.split(os.path.abspath(directory))
    if not folder_name:
        raise ValueError(f"Cannot derive an image folder name from {directory!r}")

    datagen = ImageDataGenerator(rescale=1./255)
    generator = datagen.flow_from_directory(parent_dir,
                                             classes=[folder_name],
                                             target_size=(256, 256),
                                             batch_size=num_samples,
                                             class_mode=None,
                                             shuffle=True)
    if generator.samples == 0:
        raise ValueError(f"No images found under {directory!r}")

    # The builtin next() works on every Keras version; the iterator's own
    # .next() method is not available on newer releases.
    images = next(generator)

    # The batch is shorter than num_samples when the folder holds fewer images.
    shown = min(num_samples, len(images))
    plt.figure(figsize=(10, 10))
    for i in range(shown):
        plt.subplot(1, shown, i + 1)
        plt.imshow(images[i])
        plt.axis('off')
    plt.show()

# Show one row of sample images per class folder: real first, then fake.
for heading, folder in (
    ("Real images:", 'data/raw/train/real'),
    ("Fake images:", 'data/raw/train/fake'),
):
    print(heading)
    visualize_samples(folder, num_samples=5)


In [None]:
# Dataset Summary
# Get an overview of the dataset distribution.

# Root folder of each dataset split.
train_dir = 'data/raw/train/'
validation_dir = 'data/raw/validation/'
test_dir = 'data/raw/test/'

# A single generator with no preprocessing options is enough here: the splits
# are only scanned so that their sample counts can be reported.
train_datagen = ImageDataGenerator()
train_generator, validation_generator, test_generator = (
    train_datagen.flow_from_directory(split_dir)
    for split_dir in (train_dir, validation_dir, test_dir)
)

# One line per split with the number of samples each generator found.
print(
    f"Training samples: {train_generator.samples}",
    f"Validation samples: {validation_generator.samples}",
    f"Test samples: {test_generator.samples}",
    sep="\n",
)


In [None]:
# Class Distribution
# Visualize the class distribution.

def plot_class_distribution(generator, title='Class Distribution'):
    """Draw a bar chart with the number of samples in each class.

    Args:
        generator: Iterator exposing ``class_indices`` (class name -> integer
            index) and ``classes`` (one integer class index per sample), such
            as the object returned by ``ImageDataGenerator.flow_from_directory``.
        title: Title drawn above the chart.
    """
    # Order the labels by class index so class_labels[i] matches class_counts[i].
    class_labels = sorted(generator.class_indices, key=generator.class_indices.get)

    # Count samples per class from the per-sample class indices. Using
    # generator.samples here would repeat the dataset total for every class
    # and make all bars the same height.
    class_counts = np.bincount(
        np.asarray(generator.classes, dtype=np.intp),
        minlength=len(class_labels),
    ).tolist()

    plt.figure(figsize=(8, 6))
    sns.barplot(x=class_labels, y=class_counts)
    plt.title(title)
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.show()

# Bar chart of the class distribution for the training generator.
plot_class_distribution(
    train_generator,
    title='Training Class Distribution',
)
