{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Dataset Exploration\n",
    "\n",
    "This notebook explores the datasets used for deepfake detection training and evaluation. It provides statistics, visualizations, and insights into the data distributions and characteristics.\n",
    "\n",
    "## Datasets:\n",
    "- FaceForensics++ - Real videos plus versions manipulated by several face-forgery methods (this notebook loads Deepfakes, Face2Face, FaceSwap and NeuralTextures)\n",
    "- Celeb-DF - High-quality celebrity deepfake videos\n",
    "- Combined datasets for more robust training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Import necessary libraries\n",
    "import os\n",
    "import sys\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from PIL import Image\n",
    "import cv2\n",
    "import torch\n",
    "from torch.utils.data import DataLoader\n",
    "import torchvision.transforms as transforms\n",
    "from tqdm.notebook import tqdm\n",
    "from IPython.display import display  # used by the summary table in section 8\n",
    "\n",
    "# Add parent directory to path to enable imports from project\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "\n",
    "# Import project modules\n",
    "from data.datasets.faceforensics import FaceForensicsDataset\n",
    "from data.datasets.celebdf import CelebDFDataset\n",
    "from data.datasets.custom_dataset import DeepfakeDataset\n",
    "\n",
    "# Set plot style\n",
    "plt.style.use('fivethirtyeight')\n",
    "sns.set(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Configure Dataset Paths\n",
    "\n",
    "Set the paths to the datasets on your system."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Configure dataset paths - update these to your local paths\n",
    "FACEFORENSICS_ROOT = \"/path/to/datasets/FaceForensics\"\n",
    "CELEBDF_ROOT = \"/path/to/datasets/CelebDF\"\n",
    "\n",
    "# Check if directories exist\n",
    "ff_exists = os.path.exists(FACEFORENSICS_ROOT)\n",
    "celebdf_exists = os.path.exists(CELEBDF_ROOT)\n",
    "\n",
    "print(f\"FaceForensics++ path exists: {ff_exists}\")\n",
    "print(f\"Celeb-DF path exists: {celebdf_exists}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load Datasets\n",
    "\n",
    "Load each dataset and examine its structure."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Only execute if datasets exist\n",
    "# Create transform for visualization\n",
    "vis_transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor()\n",
    "])\n",
    "\n",
    "datasets = {}\n",
    "\n",
    "# Load FaceForensics++ dataset if available\n",
    "if ff_exists:\n",
    "    print(\"Loading FaceForensics++ dataset...\")\n",
    "    methods = [\"Deepfakes\", \"Face2Face\", \"FaceSwap\", \"NeuralTextures\"]\n",
    "    \n",
    "    ff_dataset = FaceForensicsDataset(\n",
    "        root=FACEFORENSICS_ROOT,\n",
    "        split=\"train\",  # Use train split for exploration\n",
    "        img_size=224,\n",
    "        transform=vis_transform,\n",
    "        methods=methods\n",
    "    )\n",
    "    \n",
    "    datasets[\"faceforensics\"] = ff_dataset\n",
    "    print(f\"FaceForensics++ dataset loaded with {len(ff_dataset)} samples\")\n",
    "else:\n",
    "    print(\"FaceForensics++ dataset path not found. Skipping...\")\n",
    "\n",
    "# Load Celeb-DF dataset if available\n",
    "if celebdf_exists:\n",
    "    print(\"Loading Celeb-DF dataset...\")\n",
    "    \n",
    "    celebdf_dataset = CelebDFDataset(\n",
    "        root=CELEBDF_ROOT,\n",
    "        split=\"train\",  # Use train split for exploration\n",
    "        img_size=224,\n",
    "        transform=vis_transform\n",
    "    )\n",
    "    \n",
    "    datasets[\"celebdf\"] = celebdf_dataset\n",
    "    print(f\"Celeb-DF dataset loaded with {len(celebdf_dataset)} samples\")\n",
    "else:\n",
    "    print(\"Celeb-DF dataset path not found. Skipping...\")"
   ]
  },
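  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Every analysis below reads images through the dataset, and each access decodes and resizes a frame, so full passes over FaceForensics++ can be slow. A minimal sketch, assuming a fixed-size random subset is representative enough for exploration: `torch.utils.data.Subset` wraps any map-style dataset with a list of indices. The helper `make_exploration_subset` and its default of 2000 samples are illustrative choices, not part of the project code, and the demo uses a synthetic `TensorDataset` so it runs without the real data. To apply it, wrap each entry, e.g. `datasets = {k: make_exploration_subset(v) for k, v in datasets.items()}`. A random subset preserves the real/fake ratio only in expectation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import torch\n",
    "from torch.utils.data import Subset, TensorDataset\n",
    "\n",
    "def make_exploration_subset(dataset, max_samples=2000, seed=0):\n",
    "    # Hypothetical helper: reproducible random Subset with at most max_samples items\n",
    "    n = len(dataset)\n",
    "    if n <= max_samples:\n",
    "        return dataset\n",
    "    generator = torch.Generator().manual_seed(seed)\n",
    "    indices = torch.randperm(n, generator=generator)[:max_samples].tolist()\n",
    "    return Subset(dataset, indices)\n",
    "\n",
    "# Demo on synthetic data: 5000 tiny 3x8x8 'images' with random 0/1 labels\n",
    "toy = TensorDataset(torch.rand(5000, 3, 8, 8), torch.randint(0, 2, (5000,)))\n",
    "toy_subset = make_exploration_subset(toy, max_samples=500)\n",
    "print(len(toy_subset))  # 500"
   ]
  },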
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Analyze Class Distribution\n",
    "\n",
    "Check the balance between real and fake samples in each dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def analyze_class_distribution(dataset, name):\n",
    "    \"\"\"Analyze and visualize class distribution in a dataset\"\"\"\n",
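    "    # Count real (0) and fake (1) samples. This loads every image, so it can be slow on large datasets\n",
    "    real_count = 0\n",
    "    fake_count = 0\n",
    "    \n",
    "    # Index explicitly instead of relying on the legacy __getitem__ iteration protocol\n",
    "    for idx in tqdm(range(len(dataset)), desc=f\"Analyzing {name}\"):\n",
    "        _, label = dataset[idx]\n",
    "        if int(label) == 0:\n",
    "            real_count += 1\n",
    "        else:\n",
    "            fake_count += 1\n",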
    "    \n",
    "    total = real_count + fake_count\n",
    "    real_percent = (real_count / total) * 100\n",
    "    fake_percent = (fake_count / total) * 100\n",
    "    \n",
    "    print(f\"\\n{name} Dataset Distribution:\")\n",
    "    print(f\"Total samples: {total}\")\n",
    "    print(f\"Real samples: {real_count} ({real_percent:.2f}%)\")\n",
    "    print(f\"Fake samples: {fake_count} ({fake_percent:.2f}%)\")\n",
    "    \n",
    "    # Visualize distribution\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    sns.barplot(x=['Real', 'Fake'], y=[real_count, fake_count])\n",
    "    plt.title(f\"{name} Dataset Class Distribution\")\n",
    "    plt.ylabel('Count')\n",
    "    plt.show()\n",
    "    \n",
    "    return {'real': real_count, 'fake': fake_count, 'total': total}\n",
    "\n",
    "# Analyze each dataset\n",
    "distributions = {}\n",
    "for name, dataset in datasets.items():\n",
    "    distributions[name] = analyze_class_distribution(dataset, name)"
   ]
  },
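  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If the counts above show an imbalance (with four manipulation methods per original video, fake samples can outnumber real ones), two standard remedies are a class-weighted loss and oversampling. A minimal sketch, assuming the `{'real': ..., 'fake': ...}` dictionary returned by `analyze_class_distribution` and labels 0 = real, 1 = fake: inverse-frequency class weights can be passed to `torch.nn.CrossEntropyLoss(weight=...)`, and the same weights indexed per sample can drive `torch.utils.data.WeightedRandomSampler`. The helper `inverse_frequency_weights` is hypothetical and the counts in the demo are made up for illustration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import torch\n",
    "from torch.utils.data import WeightedRandomSampler\n",
    "\n",
    "def inverse_frequency_weights(dist):\n",
    "    # dist is a dict like the one returned by analyze_class_distribution (hypothetical helper)\n",
    "    counts = torch.tensor([dist['real'], dist['fake']], dtype=torch.float)\n",
    "    # weight_c = total / (num_classes * count_c); a balanced dataset gets 1.0 for every class\n",
    "    return counts.sum() / (len(counts) * counts)\n",
    "\n",
    "example_dist = {'real': 1000, 'fake': 4000, 'total': 5000}  # illustrative numbers only\n",
    "class_weights = inverse_frequency_weights(example_dist)\n",
    "print(class_weights)  # tensor([2.5000, 0.6250])\n",
    "\n",
    "# Option 1: weighted loss (expects logits of shape [N, 2] and integer class labels)\n",
    "criterion = torch.nn.CrossEntropyLoss(weight=class_weights)\n",
    "\n",
    "# Option 2: oversampling; each sample is drawn with probability proportional to its class weight\n",
    "labels = torch.tensor([0] * 1000 + [1] * 4000)\n",
    "sample_weights = class_weights[labels]\n",
    "sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)\n",
    "drawn = labels[torch.tensor(list(sampler))]\n",
    "print(drawn.float().mean())  # roughly 0.5: both classes are now drawn about equally often"
   ]
  },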
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Visualize Sample Images\n",
    "\n",
    "Display sample images from each dataset to better understand the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def visualize_samples(dataset, name, num_samples=5):\n",
    "    \"\"\"Visualize random samples from the dataset\"\"\"\n",
    "    # Draw one shuffled batch; with class imbalance it may hold fewer than num_samples of one class\n",
    "    dataloader = DataLoader(dataset, batch_size=num_samples*2, shuffle=True)\n",
    "    \n",
    "    # Get a batch\n",
    "    images, labels = next(iter(dataloader))\n",
    "    \n",
    "    # Separate real and fake samples\n",
    "    real_images = [img for img, label in zip(images, labels) if label == 0]\n",
    "    fake_images = [img for img, label in zip(images, labels) if label == 1]\n",
    "    \n",
    "    # Keep only num_samples of each\n",
    "    real_images = real_images[:num_samples]\n",
    "    fake_images = fake_images[:num_samples]\n",
    "    \n",
    "    # Plot\n",
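    "    # squeeze=False keeps axes 2-D even when num_samples == 1\n",
    "    fig, axes = plt.subplots(2, num_samples, figsize=(15, 6), squeeze=False)\n",
    "    fig.suptitle(f\"{name} Dataset Samples\", fontsize=16)\n",
    "    # Hide every axis up front so slots left empty (too few samples of a class) stay blank\n",
    "    for ax in axes.flat:\n",
    "        ax.axis('off')\n",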
    "    \n",
    "    # Plot real samples (top row)\n",
    "    for i, img in enumerate(real_images):\n",
    "        img = img.permute(1, 2, 0).numpy()  # CHW -> HWC\n",
    "        axes[0, i].imshow(img)\n",
    "        axes[0, i].set_title(\"Real\")\n",
    "        axes[0, i].axis('off')\n",
    "    \n",
    "    # Plot fake samples (bottom row)\n",
    "    for i, img in enumerate(fake_images):\n",
    "        img = img.permute(1, 2, 0).numpy()  # CHW -> HWC\n",
    "        axes[1, i].imshow(img)\n",
    "        axes[1, i].set_title(\"Fake\")\n",
    "        axes[1, i].axis('off')\n",
    "    \n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "    plt.show()\n",
    "\n",
    "# Visualize samples from each dataset\n",
    "for name, dataset in datasets.items():\n",
    "    visualize_samples(dataset, name, num_samples=5)"
   ]
  },
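  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `visualize_samples` takes a single shuffled batch of `num_samples*2` items, an imbalanced dataset can leave one row partly empty. A sketch of an alternative, assuming you already have the labels as a list (for example, collected during a pass like the one in `analyze_class_distribution`): pick up to `k` random indices per class and fetch exactly those items. `balanced_indices` is a hypothetical helper, not part of the project; the demo uses toy labels."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "\n",
    "def balanced_indices(labels, k, seed=0):\n",
    "    # Hypothetical helper: up to k random indices per class, returned as {class: [indices]}\n",
    "    rng = np.random.default_rng(seed)\n",
    "    labels = np.asarray(labels)\n",
    "    picked = {}\n",
    "    for cls in np.unique(labels):\n",
    "        cls_idx = np.flatnonzero(labels == cls)\n",
    "        size = min(k, len(cls_idx))\n",
    "        picked[int(cls)] = rng.choice(cls_idx, size=size, replace=False).tolist()\n",
    "    return picked\n",
    "\n",
    "# Demo: heavily imbalanced toy labels (2 real, 20 fake)\n",
    "toy_labels = [0] * 2 + [1] * 20\n",
    "picked = balanced_indices(toy_labels, k=5)\n",
    "print({cls: len(idx) for cls, idx in picked.items()})  # {0: 2, 1: 5}\n",
    "\n",
    "# With a real dataset: real_images = [dataset[i][0] for i in picked[0]]"
   ]
  },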
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Analyze Image Quality and Properties\n",
    "\n",
    "Compare brightness (mean grayscale intensity), contrast (grayscale standard deviation) and sharpness (variance of the Laplacian) between real and fake samples."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def calculate_image_metrics(image_tensor):\n",
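    "    \"\"\"Calculate image metrics: brightness, contrast and sharpness (variance of the Laplacian)\"\"\"\n",
    "    # Convert a CHW tensor in [0, 1] to a contiguous HWC uint8 array\n",
    "    img = image_tensor.permute(1, 2, 0).numpy()\n",
    "    img = np.ascontiguousarray(np.clip(img * 255, 0, 255).astype(np.uint8))  # Scale to 0-255\n",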
    "    \n",
    "    # Convert to grayscale for some metrics\n",
    "    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)\n",
    "    \n",
    "    # Calculate brightness (mean pixel value)\n",
    "    brightness = np.mean(gray)\n",
    "    \n",
    "    # Calculate contrast (standard deviation)\n",
    "    contrast = np.std(gray)\n",
    "    \n",
    "    # Variance of the Laplacian: higher values mean a sharper image, lower values a blurrier one\n",
    "    blur = cv2.Laplacian(gray, cv2.CV_64F).var()\n",
    "    \n",
    "    return {\n",
    "        'brightness': brightness,\n",
    "        'contrast': contrast,\n",
    "        'blur': blur\n",
    "    }\n",
    "\n",
    "def analyze_image_properties(dataset, name, num_samples=100):\n",
    "    \"\"\"Analyze image properties for a subset of the dataset\"\"\"\n",
    "    # Create DataLoader\n",
    "    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "    \n",
    "    # Initialize lists for metrics\n",
    "    metrics_real = []\n",
    "    metrics_fake = []\n",
    "    \n",
    "    # Analyze samples\n",
    "    real_count = 0\n",
    "    fake_count = 0\n",
    "    target_count = num_samples // 2  # Target for each class\n",
    "    \n",
    "    for img, label in tqdm(dataloader, desc=f\"Analyzing {name} properties\"):\n",
    "        label = label.item()\n",
    "        \n",
    "        # Compute metrics only for samples whose class quota is not yet filled\n",
    "        if label == 0 and real_count < target_count:\n",
    "            metrics_real.append(calculate_image_metrics(img[0]))\n",
    "            real_count += 1\n",
    "        elif label == 1 and fake_count < target_count:\n",
    "            metrics_fake.append(calculate_image_metrics(img[0]))\n",
    "            fake_count += 1\n",
    "        \n",
    "        # Check if we have enough samples\n",
    "        if real_count >= target_count and fake_count >= target_count:\n",
    "            break\n",
    "    \n",
    "    # Convert to DataFrames\n",
    "    df_real = pd.DataFrame(metrics_real)\n",
    "    df_real['class'] = 'Real'\n",
    "    \n",
    "    df_fake = pd.DataFrame(metrics_fake)\n",
    "    df_fake['class'] = 'Fake'\n",
    "    \n",
    "    # Combine\n",
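    "    # ignore_index avoids duplicate index labels, which some seaborn versions reject\n",
    "    df = pd.concat([df_real, df_fake], ignore_index=True)\n",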
    "    \n",
    "    # Plot distributions\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
    "    fig.suptitle(f\"{name} Dataset Image Properties\", fontsize=16)\n",
    "    \n",
    "    # Brightness\n",
    "    sns.boxplot(x='class', y='brightness', data=df, ax=axes[0])\n",
    "    axes[0].set_title('Brightness Distribution')\n",
    "    \n",
    "    # Contrast\n",
    "    sns.boxplot(x='class', y='contrast', data=df, ax=axes[1])\n",
    "    axes[1].set_title('Contrast Distribution')\n",
    "    \n",
    "    # Blur\n",
    "    sns.boxplot(x='class', y='blur', data=df, ax=axes[2])\n",
    "    axes[2].set_title('Blur Level Distribution')\n",
    "    \n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "    plt.show()\n",
    "    \n",
    "    return df\n",
    "\n",
    "# Analyze image properties for each dataset\n",
    "image_properties = {}\n",
    "for name, dataset in datasets.items():\n",
    "    image_properties[name] = analyze_image_properties(dataset, name, num_samples=100)"
   ]
  },
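  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `blur` column is the variance of the Laplacian, which grows with edge energy: higher values indicate a sharper image, lower values a blurrier one. A quick synthetic check of that direction, using only NumPy and OpenCV: random noise stands in for a sharp image, `cv2.GaussianBlur` with a 9x9 kernel blurs it, and the metric drops. The image size, kernel size and the helper name `laplacian_variance` are arbitrary choices for this demo."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "import cv2\n",
    "\n",
    "def laplacian_variance(gray):\n",
    "    # Same sharpness measure as calculate_image_metrics: variance of the Laplacian response\n",
    "    return cv2.Laplacian(gray, cv2.CV_64F).var()\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "sharp = rng.integers(0, 256, size=(224, 224), dtype=np.uint8)  # high-frequency noise\n",
    "blurred = cv2.GaussianBlur(sharp, (9, 9), 0)  # sigma 0 means: derive it from the kernel size\n",
    "\n",
    "print(f'sharp:   {laplacian_variance(sharp):.1f}')\n",
    "print(f'blurred: {laplacian_variance(blurred):.1f}')  # much smaller than the sharp value"
   ]
  },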
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Face Analysis\n",
    "\n",
    "Detect faces and 68-point facial landmarks with dlib, and show real and fake samples side by side with their landmarks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def analyze_faces(dataset, name, num_samples=20):\n",
    "    \"\"\"Detect faces with dlib and plot 68-point landmarks for real and fake samples\"\"\"\n",
    "    try:\n",
    "        import dlib\n",
    "        print(\"Using dlib for face analysis...\")\n",
    "    except ImportError:\n",
    "        print(\"dlib not installed. Please install dlib for face analysis:\")\n",
    "        print(\"pip install dlib\")\n",
    "        return\n",
    "    \n",
    "    # Load face detector and landmark predictor\n",
    "    try:\n",
    "        detector = dlib.get_frontal_face_detector()\n",
    "        predictor_path = \"shape_predictor_68_face_landmarks.dat\"  # You need to download this file\n",
    "        if os.path.exists(predictor_path):\n",
    "            predictor = dlib.shape_predictor(predictor_path)\n",
    "        else:\n",
    "            print(f\"Landmark predictor model not found at {predictor_path}\")\n",
    "            print(\"Please download it from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2\")\n",
    "            return\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading dlib models: {e}\")\n",
    "        return\n",
    "    \n",
    "    # Create dataloader\n",
    "    dataloader = DataLoader(dataset, batch_size=1, shuffle=True)\n",
    "    \n",
    "    # Function to draw landmarks\n",
    "    def draw_landmarks(img, landmarks):\n",
    "        img_copy = img.copy()\n",
    "        for i in range(68):\n",
    "            x, y = landmarks.part(i).x, landmarks.part(i).y\n",
    "            cv2.circle(img_copy, (x, y), 2, (0, 255, 0), -1)\n",
    "        return img_copy\n",
    "    \n",
    "    # Analyze samples\n",
    "    real_samples = []\n",
    "    fake_samples = []\n",
    "    real_count = 0\n",
    "    fake_count = 0\n",
    "    target_count = num_samples // 2\n",
    "    \n",
    "    for img, label in tqdm(dataloader, desc=f\"Analyzing {name} faces\"):\n",
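    "        img_np = img[0].permute(1, 2, 0).numpy()\n",
    "        # dlib expects a C-contiguous 8-bit RGB array; the permuted tensor view is not contiguous\n",
    "        img_np = np.ascontiguousarray(np.clip(img_np * 255, 0, 255).astype(np.uint8))\n",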
    "        \n",
    "        # Detect faces\n",
    "        faces = detector(img_np)\n",
    "        if len(faces) == 0:\n",
    "            continue\n",
    "        \n",
    "        face = faces[0]  # Use first face\n",
    "        landmarks = predictor(img_np, face)\n",
    "        \n",
    "        # Draw landmarks\n",
    "        img_landmarks = draw_landmarks(img_np, landmarks)\n",
    "        \n",
    "        # Add to appropriate list\n",
    "        if label.item() == 0 and real_count < target_count:\n",
    "            real_samples.append((img_np, img_landmarks))\n",
    "            real_count += 1\n",
    "        elif label.item() == 1 and fake_count < target_count:\n",
    "            fake_samples.append((img_np, img_landmarks))\n",
    "            fake_count += 1\n",
    "        \n",
    "        # Check if we have enough samples\n",
    "        if real_count >= target_count and fake_count >= target_count:\n",
    "            break\n",
    "    \n",
    "    # Visualize results\n",
    "    if real_samples and fake_samples:\n",
    "        rows = min(len(real_samples), len(fake_samples))\n",
    "        # squeeze=False keeps axes 2-D even when only one row is available\n",
    "        fig, axes = plt.subplots(rows, 4, figsize=(15, rows*3), squeeze=False)\n",
    "        fig.suptitle(f\"{name} Dataset Face Analysis\", fontsize=16)\n",
    "        \n",
    "        for i in range(rows):\n",
    "            # Real samples\n",
    "            axes[i, 0].imshow(real_samples[i][0])\n",
    "            axes[i, 0].set_title(\"Real\")\n",
    "            axes[i, 0].axis('off')\n",
    "            \n",
    "            axes[i, 1].imshow(real_samples[i][1])\n",
    "            axes[i, 1].set_title(\"Real (Landmarks)\")\n",
    "            axes[i, 1].axis('off')\n",
    "            \n",
    "            # Fake samples\n",
    "            axes[i, 2].imshow(fake_samples[i][0])\n",
    "            axes[i, 2].set_title(\"Fake\")\n",
    "            axes[i, 2].axis('off')\n",
    "            \n",
    "            axes[i, 3].imshow(fake_samples[i][1])\n",
    "            axes[i, 3].set_title(\"Fake (Landmarks)\")\n",
    "            axes[i, 3].axis('off')\n",
    "        \n",
    "        plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "        plt.show()\n",
    "\n",
    "# Try analyzing faces if dlib is installed\n",
    "try:\n",
    "    import dlib\n",
    "    for name, dataset in datasets.items():\n",
    "        analyze_faces(dataset, name, num_samples=10)\n",
    "except ImportError:\n",
    "    print(\"dlib not installed. Skipping face analysis.\")"
   ]
  },
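  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The landmark overlays above are qualitative. To compare faces numerically, one option is to convert the dlib `full_object_detection` into an array and compute simple geometry. A minimal sketch, assuming the 68-point iBUG layout used by `shape_predictor_68_face_landmarks.dat`, where indices 36-41 and 42-47 outline the two eyes: `landmarks_to_array` relies only on `num_parts` and `part(i)`, and the inter-ocular distance (between the two eye centers) is one illustrative measure that this notebook does not otherwise compute. Both helper names are hypothetical, and the demo uses synthetic points so it runs without dlib."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "\n",
    "def landmarks_to_array(shape):\n",
    "    # Convert a dlib full_object_detection into an (N, 2) array of (x, y) pixel coordinates\n",
    "    return np.array([[shape.part(i).x, shape.part(i).y] for i in range(shape.num_parts)], dtype=np.float64)\n",
    "\n",
    "def inter_ocular_distance(points):\n",
    "    # Distance between eye centers in the 68-point layout (eyes at indices 36-41 and 42-47)\n",
    "    eye_a = points[36:42].mean(axis=0)\n",
    "    eye_b = points[42:48].mean(axis=0)\n",
    "    return float(np.linalg.norm(eye_a - eye_b))\n",
    "\n",
    "# Synthetic check: one eye around x=80, the other around x=140, both on row y=100\n",
    "pts = np.zeros((68, 2))\n",
    "pts[36:42] = [80.0, 100.0]\n",
    "pts[42:48] = [140.0, 100.0]\n",
    "print(inter_ocular_distance(pts))  # 60.0\n",
    "\n",
    "# With dlib: inter_ocular_distance(landmarks_to_array(predictor(img_np, face)))"
   ]
  },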
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Cross-Dataset Analysis\n",
    "\n",
    "Compare properties across different datasets to understand their similarities and differences."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Skip if we don't have multiple datasets to compare\n",
    "if len(datasets) > 1 and len(image_properties) > 1:\n",
    "    # Prepare data for comparison\n",
    "    comparison_data = []\n",
    "    \n",
    "    for name, df in image_properties.items():\n",
    "        df_copy = df.copy()\n",
    "        df_copy['dataset'] = name\n",
    "        comparison_data.append(df_copy)\n",
    "    \n",
    "    # Combine data\n",
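    "    # ignore_index avoids duplicate index labels, which some seaborn versions reject with hue\n",
    "    combined_df = pd.concat(comparison_data, ignore_index=True)\n",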
    "    \n",
    "    # Create comparison plots\n",
    "    fig, axes = plt.subplots(1, 3, figsize=(18, 6))\n",
    "    fig.suptitle(\"Cross-Dataset Comparison\", fontsize=16)\n",
    "    \n",
    "    # Brightness comparison\n",
    "    sns.boxplot(x='dataset', y='brightness', hue='class', data=combined_df, ax=axes[0])\n",
    "    axes[0].set_title('Brightness Comparison')\n",
    "    \n",
    "    # Contrast comparison\n",
    "    sns.boxplot(x='dataset', y='contrast', hue='class', data=combined_df, ax=axes[1])\n",
    "    axes[1].set_title('Contrast Comparison')\n",
    "    \n",
    "    # Blur comparison\n",
    "    sns.boxplot(x='dataset', y='blur', hue='class', data=combined_df, ax=axes[2])\n",
    "    axes[2].set_title('Blur Level Comparison')\n",
    "    \n",
    "    plt.tight_layout(rect=[0, 0, 1, 0.95])\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"Need at least 2 datasets for cross-dataset comparison.\")"
   ]
  },
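  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Boxplots show whether distributions overlap, but not whether a difference could be sampling noise, which matters with only 50 samples per class. One option, assuming SciPy is installed, is the two-sided Mann-Whitney U test (`scipy.stats.mannwhitneyu`), which does not assume normally distributed metrics. `compare_classes` is a hypothetical helper that expects a DataFrame shaped like those returned by `analyze_image_properties` (columns `brightness`, `contrast`, `blur`, `class`); the demo builds such a frame from synthetic numbers, so its p-values say nothing about the real datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from scipy.stats import mannwhitneyu\n",
    "\n",
    "def compare_classes(df, metrics=('brightness', 'contrast', 'blur')):\n",
    "    # Two-sided Mann-Whitney U test of Real vs Fake for each metric column\n",
    "    rows = []\n",
    "    for metric in metrics:\n",
    "        real = df.loc[df['class'] == 'Real', metric].dropna()\n",
    "        fake = df.loc[df['class'] == 'Fake', metric].dropna()\n",
    "        stat, p_value = mannwhitneyu(real, fake, alternative='two-sided')\n",
    "        rows.append({'metric': metric, 'U': stat, 'p_value': p_value})\n",
    "    return pd.DataFrame(rows)\n",
    "\n",
    "# Synthetic demo: fake brightness shifted up by 15, other metrics identically distributed\n",
    "rng = np.random.default_rng(0)\n",
    "demo = pd.DataFrame({\n",
    "    'brightness': np.r_[rng.normal(110, 10, 50), rng.normal(125, 10, 50)],\n",
    "    'contrast': rng.normal(50, 5, 100),\n",
    "    'blur': rng.normal(300, 50, 100),\n",
    "    'class': ['Real'] * 50 + ['Fake'] * 50,\n",
    "})\n",
    "print(compare_classes(demo))\n",
    "\n",
    "# On the real data: for name, df in image_properties.items(): display(compare_classes(df))"
   ]
  },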
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Summary and Insights\n",
    "\n",
    "Summarize key findings and insights about the datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Create summary table\n",
    "if distributions:\n",
    "    summary_data = []\n",
    "    \n",
    "    for name, dist in distributions.items():\n",
    "        row = {\n",
    "            'Dataset': name,\n",
    "            'Total Samples': dist['total'],\n",
    "            'Real Samples': dist['real'],\n",
    "            'Fake Samples': dist['fake'],\n",
    "            'Real %': round((dist['real'] / dist['total']) * 100, 2),\n",
    "            'Fake %': round((dist['fake'] / dist['total']) * 100, 2)\n",
    "        }\n",
    "        summary_data.append(row)\n",
    "    \n",
    "    summary_df = pd.DataFrame(summary_data)\n",
    "    print(\"Dataset Summary:\")\n",
    "    display(summary_df)\n",
    "    \n",
    "    # Add insights\n",
    "    print(\"\\nKey Insights:\")\n",
    "    for name in distributions.keys():\n",
    "        print(f\"\\n{name.upper()} Dataset:\")\n",
    "        print(f\"- Contains {distributions[name]['total']} total samples\")\n",
    "        print(f\"- Class balance: {distributions[name]['real']} real vs {distributions[name]['fake']} fake samples\")\n",
    "        \n",
    "        if name in image_properties:\n",
    "            df = image_properties[name]\n",
    "            real_df = df[df['class'] == 'Real']\n",
    "            fake_df = df[df['class'] == 'Fake']\n",
    "            \n",
    "            # Brightness comparison\n",
    "            real_brightness = real_df['brightness'].mean()\n",
    "            fake_brightness = fake_df['brightness'].mean()\n",
    "            brightness_diff = abs(real_brightness - fake_brightness)\n",
    "            \n",
    "            if brightness_diff > 10:\n",
    "                print(f\"- Notable brightness difference between real and fake samples: {brightness_diff:.1f}\")\n",
    "                \n",
    "            # Contrast comparison\n",
    "            real_contrast = real_df['contrast'].mean()\n",
    "            fake_contrast = fake_df['contrast'].mean()\n",
    "            contrast_diff = abs(real_contrast - fake_contrast)\n",
    "            \n",
    "            if contrast_diff > 5:\n",
    "                print(f\"- Notable contrast difference between real and fake samples: {contrast_diff:.1f}\")\n",
    "                \n",
    "            # Sharpness comparison (variance of the Laplacian; the threshold of 100 is a rough heuristic)\n",
    "            real_blur = real_df['blur'].mean()\n",
    "            fake_blur = fake_df['blur'].mean()\n",
    "            blur_diff = abs(real_blur - fake_blur)\n",
    "            \n",
    "            if blur_diff > 100:\n",
    "                print(f\"- Notable sharpness (Laplacian variance) difference between real and fake samples: {blur_diff:.1f}\")\n",
    "else:\n",
    "    print(\"No dataset distributions available for summary.\")"
   ]
  },
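  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The fixed thresholds in the summary (10 for brightness, 5 for contrast, 100 for blur) are in raw units, so they mean different things for each metric and dataset. A scale-free alternative is a standardized effect size such as Cohen's d: the difference in means divided by the pooled standard deviation, where |d| of about 0.2, 0.5 and 0.8 is conventionally read as small, medium and large. The sketch below is illustrative; `cohens_d` and `effect_sizes` are hypothetical helpers that assume the same DataFrame layout as `image_properties`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "def cohens_d(a, b):\n",
    "    # Standardized mean difference (a - b) using the pooled sample standard deviation\n",
    "    a = np.asarray(a, dtype=float)\n",
    "    b = np.asarray(b, dtype=float)\n",
    "    n_a, n_b = len(a), len(b)\n",
    "    pooled_var = ((n_a - 1) * a.var(ddof=1) + (n_b - 1) * b.var(ddof=1)) / (n_a + n_b - 2)\n",
    "    return (a.mean() - b.mean()) / np.sqrt(pooled_var)\n",
    "\n",
    "def effect_sizes(df, metrics=('brightness', 'contrast', 'blur')):\n",
    "    # Cohen's d of Real vs Fake for each metric column\n",
    "    real = df[df['class'] == 'Real']\n",
    "    fake = df[df['class'] == 'Fake']\n",
    "    return {m: cohens_d(real[m], fake[m]) for m in metrics}\n",
    "\n",
    "# Worked example: group means 3 and 5, both sample std 2, so d = (3 - 5) / 2 = -1.0\n",
    "print(cohens_d([1, 3, 5], [3, 5, 7]))  # -1.0\n",
    "\n",
    "# On the real data: effect_sizes(image_properties[name])"
   ]
  },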
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Potential Improvements\n",
    "\n",
    "Based on the analysis, here are some potential improvements for data preparation:"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Augmentation Suggestions\n",
    "\n",
    "1. **Address class imbalance**: If there's a significant imbalance between real and fake samples, consider techniques like:\n",
    "   - Oversampling the minority class\n",
    "   - Undersampling the majority class\n",
    "   - Using weighted loss functions\n",
    "\n",
    "2. **Normalize image properties**: If there are significant differences in brightness, contrast, or blur levels between real and fake samples, consider:\n",
    "   - Applying consistent normalization\n",
    "   - Using augmentations that specifically target these differences\n",
    "\n