In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Demo\n",
    "\n",
    "This notebook provides a demonstration of the deepfake detection system. It allows you to:\n",
    "\n",
    "1. Load pretrained deepfake detection models\n",
    "2. Analyze single images for deepfake detection\n",

In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Demo\n",
    "\n",
    "This notebook provides a demonstration of the deepfake detection system. It allows you to:\n",
    "\n",
    "1. Load pretrained deepfake detection models\n",
    "2. Analyze single images for deepfake detection\n",
    "3. Process videos frame-by-frame for deepfake detection\n",
    "4. Visualize model decisions and explain results\n",
    "5. Test with your own images or sample images\n",
    "\n",
    "The demo uses the transformer-based deepfake detection models (ViT, DeiT, and Swin) and includes visualization tools to help understand the decision-making process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Import necessary libraries\n",
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import cv2\n",
    "from tqdm.notebook import tqdm\n",
    "import time\n",
    "import argparse\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output, HTML\n",
    "\n",
    "# Add parent directory to path for importing project modules\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "\n",
    "# Import project modules\n",
    "from models.vit.model import ViT\n",
    "from models.deit.model import DeiT\n",
    "from models.swin.model import SwinTransformer\n",
    "from models.model_zoo.model_factory import create_model\n",
    "from data.preprocessing.face_extraction import setup_face_detector, extract_faces\n",
    "from data.preprocessing.normalization import normalize_face\n",
    "from evaluation.visualization.attention_maps import visualize_attention_maps\n",
    "from evaluation.visualization.grad_cam import visualize_grad_cam\n",
    "\n",
    "# Set device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Pretrained Models\n",
    "\n",
    "First, let's load the pretrained models for deepfake detection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Checkpoint locations -- edit these to point at your own trained weights\n",
    "CHECKPOINT_DIR = \"../trained_models\"\n",
    "\n",
    "# Each path is paired with a flag telling whether the file is present on disk\n",
    "VIT_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"vit_celebdf/checkpoints/best.pth\")\n",
    "vit_exists = os.path.exists(VIT_CHECKPOINT)\n",
    "\n",
    "DEIT_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"deit_celebdf/checkpoints/best.pth\")\n",
    "deit_exists = os.path.exists(DEIT_CHECKPOINT)\n",
    "\n",
    "SWIN_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"swin_celebdf/checkpoints/best.pth\")\n",
    "swin_exists = os.path.exists(SWIN_CHECKPOINT)\n",
    "\n",
    "# Availability report\n",
    "print(f\"ViT checkpoint exists: {vit_exists}\")\n",
    "print(f\"DeiT checkpoint exists: {deit_exists}\")\n",
    "print(f\"Swin checkpoint exists: {swin_exists}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_checkpoint(model, checkpoint_path, device):\n",
    "    \"\"\"Restore weights from ``checkpoint_path`` into ``model``.\n",
    "\n",
    "    The checkpoint may hold the state dict under the key 'model' or\n",
    "    'model_state_dict', or be the state dict itself. On success the model is\n",
    "    moved to ``device``, put in eval mode and returned. If the file is missing\n",
    "    or anything fails while loading, a message is printed and None is returned.\n",
    "    \"\"\"\n",
    "    if not os.path.exists(checkpoint_path):\n",
    "        print(f\"Checkpoint not found at {checkpoint_path}\")\n",
    "        return None\n",
    "\n",
    "    try:\n",
    "        payload = torch.load(checkpoint_path, map_location=device)\n",
    "\n",
    "        # Unwrap the state dict from whichever container key is present\n",
    "        state_dict = payload\n",
    "        for wrapper_key in ('model', 'model_state_dict'):\n",
    "            if wrapper_key in payload:\n",
    "                state_dict = payload[wrapper_key]\n",
    "                break\n",
    "        model.load_state_dict(state_dict)\n",
    "\n",
    "        restored = model.to(device)\n",
    "        restored.eval()  # inference only\n",
    "        print(f\"Model loaded successfully from {checkpoint_path}\")\n",
    "        return restored\n",
    "    except Exception as err:\n",
    "        print(f\"Error loading checkpoint: {err}\")\n",
    "        return None\n",
    "\n",
    "# Registry of successfully loaded detectors, keyed by short model name\n",
    "models = {}\n",
    "\n",
    "# ViT: built and restored only when its checkpoint file was found\n",
    "if vit_exists:\n",
    "    vit_model = load_checkpoint(\n",
    "        ViT(\n",
    "            img_size=224, patch_size=16, in_channels=3, num_classes=1,\n",
    "            embed_dim=768, depth=12, num_heads=12,\n",
    "        ),\n",
    "        VIT_CHECKPOINT,\n",
    "        device,\n",
    "    )\n",
    "    if vit_model is not None:\n",
    "        models['vit'] = vit_model\n",
    "\n",
    "# DeiT: same layout as ViT, with the distillation option enabled\n",
    "if deit_exists:\n",
    "    deit_model = load_checkpoint(\n",
    "        DeiT(\n",
    "            img_size=224, patch_size=16, in_channels=3, num_classes=1,\n",
    "            embed_dim=768, depth=12, num_heads=12,\n",
    "            distillation=True,\n",
    "        ),\n",
    "        DEIT_CHECKPOINT,\n",
    "        device,\n",
    "    )\n",
    "    if deit_model is not None:\n",
    "        models['deit'] = deit_model\n",
    "\n",
    "# Swin: per-stage depths and head counts are given as lists\n",
    "if swin_exists:\n",
    "    swin_model = load_checkpoint(\n",
    "        SwinTransformer(\n",
    "            img_size=224, patch_size=4, in_channels=3, num_classes=1,\n",
    "            embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],\n",
    "            window_size=7,\n",
    "        ),\n",
    "        SWIN_CHECKPOINT,\n",
    "        device,\n",
    "    )\n",
    "    if swin_model is not None:\n",
    "        models['swin'] = swin_model\n",
    "\n",
    "print(f\"Loaded {len(models)} models: {list(models.keys())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Set up Face Detection\n",
    "\n",
    "We need to set up face detection to extract faces from images and videos."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Build the face detector on the CPU; fall back to None when setup fails so\n",
    "# that later cells can still run on whole images.\n",
    "face_detector = None\n",
    "try:\n",
    "    face_detector = setup_face_detector(device='cpu')\n",
    "except Exception as e:\n",
    "    print(f\"Error setting up face detector: {e}\")\n",
    "    print(\"Please install the required dependencies: pip install facenet-pytorch opencv-python\")\n",
    "else:\n",
    "    print(\"Face detector set up successfully.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Image Preprocessing Functions\n",
    "\n",
    "Define functions for preprocessing images for the deepfake detection models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def preprocess_image(image_path, face_detector=None, target_size=224):\n",
    "    \"\"\"Turn an image into normalized face tensors for the detection models.\n",
    "\n",
    "    Args:\n",
    "        image_path: Path to an image file, or an image array. A 3-channel\n",
    "            uint8 array is treated as BGR (OpenCV order) and converted to RGB\n",
    "            here, so callers must NOT convert such arrays to RGB beforehand.\n",
    "            A 2-D (grayscale) array is expanded to three channels.\n",
    "        face_detector: Optional detector handed to ``extract_faces``. When it\n",
    "            is None, or no face is found, the whole image is used instead.\n",
    "        target_size: Side length of the square each face is resized to.\n",
    "\n",
    "    Returns:\n",
    "        A list of dicts with the keys 'original' (the face image), 'resized'\n",
    "        (the target_size x target_size version) and 'tensor' (float tensor in\n",
    "        channel-first layout, ImageNet-normalized, without a batch dimension).\n",
    "        Returns None when the image is missing, empty or cannot be decoded.\n",
    "    \"\"\"\n",
    "    if isinstance(image_path, str):\n",
    "        # Load from file\n",
    "        if not os.path.exists(image_path):\n",
    "            print(f\"Image not found at {image_path}\")\n",
    "            return None\n",
    "\n",
    "        img = cv2.imread(image_path)\n",
    "        # cv2.imread reports an unreadable or unsupported file by returning None\n",
    "        if img is None:\n",
    "            print(f\"Could not read image at {image_path}\")\n",
    "            return None\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "    else:\n",
    "        # Assume numpy array\n",
    "        img = image_path\n",
    "        if img is None or getattr(img, 'size', 0) == 0:\n",
    "            print(\"Empty image array provided\")\n",
    "            return None\n",
    "        if img.ndim == 2:\n",
    "            # Grayscale input: the normalization below needs three channels\n",
    "            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)\n",
    "        elif img.ndim == 3 and img.shape[2] == 3 and img.dtype == np.uint8:\n",
    "            # Likely BGR format from OpenCV\n",
    "            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "    # Extract faces if a detector is provided; otherwise use the whole image\n",
    "    if face_detector is not None:\n",
    "        faces = extract_faces(img, face_detector)\n",
    "        if not faces:\n",
    "            print(\"No faces detected in the image\")\n",
    "            # Just use the whole image if no faces detected\n",
    "            faces = [img]\n",
    "    else:\n",
    "        faces = [img]\n",
    "\n",
    "    # ImageNet channel statistics, hoisted out of the loop\n",
    "    imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
    "    imagenet_std = np.array([0.229, 0.224, 0.225])\n",
    "\n",
    "    processed_faces = []\n",
    "    for face in faces:\n",
    "        # A degenerate (zero-area) crop cannot be resized; skip it\n",
    "        if face is None or face.size == 0:\n",
    "            continue\n",
    "\n",
    "        # Resize to target size\n",
    "        face_resized = cv2.resize(face, (target_size, target_size))\n",
    "\n",
    "        # Scale to [0, 1], then normalize per channel\n",
    "        face_float = face_resized.astype(np.float32) / 255.0\n",
    "        face_normalized = (face_float - imagenet_mean) / imagenet_std\n",
    "\n",
    "        # Channel-last -> channel-first; the batch dimension is added by the caller\n",
    "        face_tensor = torch.from_numpy(face_normalized.transpose(2, 0, 1)).float()\n",
    "\n",
    "        processed_faces.append({\n",
    "            'original': face,\n",
    "            'resized': face_resized,\n",
    "            'tensor': face_tensor\n",
    "        })\n",
    "\n",
    "    return processed_faces"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Deepfake Detection Functions\n",
    "\n",
    "Define functions for performing deepfake detection on images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def detect_deepfake(face_tensor, models, device):\n",
    "    \"\"\"Score one face with every model and, when several ran, their average.\n",
    "\n",
    "    Returns a dict keyed by model name, plus 'ensemble' when more than one\n",
    "    model produced a result. Each value holds 'probability' (sigmoid of the\n",
    "    model output), 'prediction' ('Fake' above 0.5, otherwise 'Real') and\n",
    "    'confidence'. Returns None when ``models`` is empty.\n",
    "    \"\"\"\n",
    "    if not models:\n",
    "        print(\"No models available for detection\")\n",
    "        return None\n",
    "\n",
    "    def _describe(prob):\n",
    "        # Shared result layout for single models and for the ensemble\n",
    "        return {\n",
    "            'probability': prob,\n",
    "            'prediction': 'Fake' if prob > 0.5 else 'Real',\n",
    "            'confidence': max(prob, 1 - prob)\n",
    "        }\n",
    "\n",
    "    # Move to the target device and prepend the batch dimension\n",
    "    batch = face_tensor.to(device).unsqueeze(0)\n",
    "\n",
    "    results = {}\n",
    "    for name, model in models.items():\n",
    "        with torch.no_grad():\n",
    "            results[name] = _describe(torch.sigmoid(model(batch)).item())\n",
    "\n",
    "    # Simple averaging ensemble, only meaningful with two or more models\n",
    "    if len(results) > 1:\n",
    "        results['ensemble'] = _describe(\n",
    "            np.mean([entry['probability'] for entry in results.values()])\n",
    "        )\n",
    "\n",
    "    return results\n",
    "\n",
    "def visualize_detection_results(face_dict, results):\n",
    "    \"\"\"Plot the input face next to a bar chart of per-model fake probabilities.\n",
    "\n",
    "    Also prints a per-model summary and an overall verdict, taken from the\n",
    "    'ensemble' entry when present and from the first entry otherwise.\n",
    "    Returns the matplotlib figure, or None when either argument is None.\n",
    "    \"\"\"\n",
    "    if face_dict is None or results is None:\n",
    "        print(\"No face or results to visualize\")\n",
    "        return\n",
    "\n",
    "    fig = plt.figure(figsize=(12, 8))\n",
    "\n",
    "    # Left panel: the analysed face\n",
    "    plt.subplot(1, 2, 1)\n",
    "    plt.imshow(face_dict['original'])\n",
    "    plt.title('Input Face')\n",
    "    plt.axis('off')\n",
    "\n",
    "    # Right panel: one horizontal bar per result entry\n",
    "    plt.subplot(1, 2, 2)\n",
    "    names = list(results.keys())\n",
    "    fake_probs = [results[name]['probability'] for name in names]\n",
    "\n",
    "    # Red marks anything not predicted 'Real', green marks 'Real'\n",
    "    bar_colors = ['red' if results[name]['prediction'] != 'Real' else 'green' for name in names]\n",
    "\n",
    "    bars = plt.barh(names, fake_probs, color=bar_colors)\n",
    "    plt.xlim(0, 1)\n",
    "    plt.xlabel('Probability of Fake')\n",
    "    plt.axvline(x=0.5, color='black', linestyle='--')  # decision threshold\n",
    "\n",
    "    # Annotate each bar with its value and label\n",
    "    for name, prob, bar in zip(names, fake_probs, bars):\n",
    "        plt.text(prob + 0.01, bar.get_y() + bar.get_height()/2,\n",
    "                f\"{prob:.2f} ({results[name]['prediction']})\",\n",
    "                va='center')\n",
    "\n",
    "    plt.grid(True, linestyle='--', alpha=0.7)\n",
    "    plt.title('Deepfake Detection Results')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    print(\"\\nDetection Summary:\")\n",
    "    for name, entry in results.items():\n",
    "        print(f\"• {name}: {entry['prediction']} with {entry['confidence']*100:.1f}% confidence\")\n",
    "\n",
    "    # The verdict comes from the ensemble when available, else the first entry\n",
    "    verdict_key = 'ensemble' if 'ensemble' in results else list(results.keys())[0]\n",
    "    verdict = results[verdict_key]['prediction']\n",
    "    confidence = results[verdict_key]['confidence']\n",
    "\n",
    "    if confidence > 0.8:\n",
    "        confidence_level = \"high\"\n",
    "    elif confidence > 0.6:\n",
    "        confidence_level = \"moderate\"\n",
    "    else:\n",
    "        confidence_level = \"low\"\n",
    "    print(f\"\\nOverall verdict: Image is {verdict} with {confidence_level} confidence ({confidence*100:.1f}%)\")\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Explanation Functions\n",
    "\n",
    "Define functions for explaining model decisions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def find_target_layer(model, model_type):\n",
    "    \"\"\"Return the module whose output Grad-CAM should inspect.\n",
    "\n",
    "    'vit' and 'deit' use the last entry of ``model.blocks``; 'swin' uses the\n",
    "    last entry of ``model.layers``. Any other ``model_type`` raises ValueError.\n",
    "    \"\"\"\n",
    "    if model_type in ('vit', 'deit'):\n",
    "        return model.blocks[-1]\n",
    "    if model_type == 'swin':\n",
    "        return model.layers[-1]\n",
    "    raise ValueError(f\"Unknown model type: {model_type}\")\n",
    "\n",
    "class GradCAM:\n",
    "    \"\"\"Grad-CAM for token-based (transformer) and feature-map (CNN) layers.\n",
    "\n",
    "    A forward hook on ``target_layer`` stores the layer output and attaches a\n",
    "    tensor hook that stores the gradient reaching that output. Calling the\n",
    "    object combines both into a heatmap normalized to [0, 1] at the spatial\n",
    "    size of the input. Hooks are attached only for the duration of a call so\n",
    "    that repeated use does not pile up hooks on the model.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, target_layer):\n",
    "        self.model = model\n",
    "        self.target_layer = target_layer\n",
    "        self.gradients = None\n",
    "        self.activations = None\n",
    "        self._hook_handles = []\n",
    "\n",
    "        # Register hooks\n",
    "        self.register_hooks()\n",
    "\n",
    "    def _store_gradient(self, grad):\n",
    "        # Tensor hook: called with the gradient w.r.t. the target layer output\n",
    "        self.gradients = grad.detach()\n",
    "\n",
    "    def register_hooks(self):\n",
    "        \"\"\"Attach the capturing hook to the target layer (no-op if attached).\"\"\"\n",
    "        if self._hook_handles:\n",
    "            return\n",
    "\n",
    "        def forward_hook(module, input, output):\n",
    "            # Some layers return a tuple; the first element is taken as the\n",
    "            # feature tensor -- TODO confirm for the project's Swin layers.\n",
    "            features = output[0] if isinstance(output, (tuple, list)) else output\n",
    "            self.activations = features.detach()\n",
    "            # A tensor hook replaces the deprecated module-level\n",
    "            # register_backward_hook; it is skipped under torch.no_grad().\n",
    "            if features.requires_grad:\n",
    "                features.register_hook(self._store_gradient)\n",
    "\n",
    "        self._hook_handles.append(self.target_layer.register_forward_hook(forward_hook))\n",
    "\n",
    "    def remove_hooks(self):\n",
    "        \"\"\"Detach every hook this object registered on the model.\"\"\"\n",
    "        for handle in self._hook_handles:\n",
    "            handle.remove()\n",
    "        self._hook_handles = []\n",
    "\n",
    "    def __call__(self, x, class_idx=None):\n",
    "        \"\"\"Compute the heatmap for a single-image batch.\n",
    "\n",
    "        Args:\n",
    "            x: 4-D input tensor; only the first batch element is explained.\n",
    "            class_idx: Class to explain. When None, the model's own prediction\n",
    "                is used. For single-logit models 1 is the positive class\n",
    "                (sigmoid > 0.5, i.e. logit > 0) and 0 the negative class.\n",
    "\n",
    "        Returns:\n",
    "            2-D numpy array with the input's height and width, values in [0, 1].\n",
    "\n",
    "        Raises:\n",
    "            RuntimeError: if the target layer produced no activations or\n",
    "                received no gradient during the pass.\n",
    "        \"\"\"\n",
    "        b, c, h, w = x.size()\n",
    "\n",
    "        self.register_hooks()\n",
    "        self.activations = None\n",
    "        self.gradients = None\n",
    "        try:\n",
    "            with torch.enable_grad():\n",
    "                # Guarantee a graph even when the model parameters are frozen\n",
    "                if not x.requires_grad:\n",
    "                    x = x.detach().requires_grad_(True)\n",
    "                logits = self.model(x)\n",
    "                multi_class = logits.dim() > 1 and logits.shape[1] > 1\n",
    "\n",
    "                # If class_idx is None, use the model's prediction\n",
    "                if class_idx is None:\n",
    "                    if multi_class:\n",
    "                        class_idx = torch.argmax(logits, dim=1).item()\n",
    "                    else:\n",
    "                        # Threshold the raw logit at 0, matching sigmoid > 0.5\n",
    "                        class_idx = (logits > 0).long().item()\n",
    "\n",
    "                # Backward pass\n",
    "                self.model.zero_grad()\n",
    "                if multi_class:\n",
    "                    target = torch.zeros_like(logits)\n",
    "                    target[0, class_idx] = 1\n",
    "                else:\n",
    "                    # The negative class is explained through -logit; an\n",
    "                    # all-zero gradient would give an empty map.\n",
    "                    target = torch.ones_like(logits)\n",
    "                    if class_idx != 1:\n",
    "                        target = -target\n",
    "\n",
    "                logits.backward(gradient=target)\n",
    "        finally:\n",
    "            # Leave the model without our hooks, even on failure\n",
    "            self.remove_hooks()\n",
    "\n",
    "        if self.activations is None or self.gradients is None:\n",
    "            raise RuntimeError(\"Grad-CAM captured no activations/gradients for the target layer\")\n",
    "\n",
    "        activations = self.activations\n",
    "        gradients = self.gradients\n",
    "\n",
    "        if activations.dim() == 3:  # token sequence [B, L, D]\n",
    "            # Leading tokens beyond the square patch grid are dropped (1 class\n",
    "            # token, 2 with a distillation token, 0 without either).\n",
    "            # Assumes extra tokens come first and the grid is square.\n",
    "            num_tokens = activations.shape[1]\n",
    "            grid = int(np.sqrt(num_tokens))\n",
    "            num_prefix = num_tokens - grid * grid\n",
    "\n",
    "            # Reshape to [B, H, W, D]\n",
    "            activations = activations[:, num_prefix:, :].reshape(b, grid, grid, -1)\n",
    "            gradients = gradients[:, num_prefix:, :].reshape(b, grid, grid, -1)\n",
    "\n",
    "            # Per-channel weights, kept as [B, 1, 1, D] so they broadcast\n",
    "            weights = gradients.mean(dim=(1, 2), keepdim=True)\n",
    "            cam = (weights * activations).sum(dim=3)[0].cpu().numpy()\n",
    "        elif activations.dim() == 4:  # feature map, assumed [B, C, H, W]\n",
    "            weights = gradients.mean(dim=(2, 3), keepdim=True)\n",
    "            cam = (weights * activations).sum(dim=1)[0].cpu().numpy()\n",
    "        else:\n",
    "            # Fallback for unsupported format\n",
    "            cam = np.zeros((h, w), dtype=np.float32)\n",
    "\n",
    "        # Resize to original image size\n",
    "        cam = cv2.resize(np.ascontiguousarray(cam, dtype=np.float32), (w, h))\n",
    "\n",
    "        # Apply ReLU and normalize\n",
    "        cam = np.maximum(cam, 0)\n",
    "        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)\n",
    "\n",
    "        return cam\n",
    "\n",
    "def get_attention_maps(model, img_tensor, model_type):\n",
    "    \"\"\"Collect the output of every block's ``attn`` module for one image.\n",
    "\n",
    "    Only 'vit' and 'deit' are supported; other types print a notice and\n",
    "    return None. ``img_tensor`` is moved to the notebook-level ``device`` and\n",
    "    given a batch dimension here.\n",
    "\n",
    "    Returns a list with one CPU tensor per entry of ``model.blocks``, in order.\n",
    "    NOTE(review): whether these are attention weights or attended features\n",
    "    depends on what ``block.attn`` returns -- confirm against the model code.\n",
    "    \"\"\"\n",
    "    # Check if model type is supported\n",
    "    if model_type not in ['vit', 'deit']:\n",
    "        print(f\"Attention map visualization not supported for {model_type}\")\n",
    "        return None\n",
    "\n",
    "    # Move image to device\n",
    "    img_tensor = img_tensor.to(device).unsqueeze(0)  # Add batch dimension\n",
    "\n",
    "    attention_maps = []\n",
    "\n",
    "    def attention_hook(module, input, output):\n",
    "        attention_maps.append(output.detach().cpu())\n",
    "\n",
    "    hooks = []\n",
    "    try:\n",
    "        for block in model.blocks:\n",
    "            hooks.append(block.attn.register_forward_hook(attention_hook))\n",
    "\n",
    "        # Forward pass\n",
    "        with torch.no_grad():\n",
    "            model(img_tensor)\n",
    "    finally:\n",
    "        # Always detach the hooks, even if the forward pass fails\n",
    "        for hook in hooks:\n",
    "            hook.remove()\n",
    "\n",
    "    return attention_maps\n",
    "\n",
    "def explain_prediction(face_dict, model_name, model, device):\n",
    "    \"\"\"Explain model prediction using visualization techniques\"\"\"\n",
    "    # Get face tensor\n",
    "    face_tensor = face_dict['tensor'].to(device).unsqueeze(0)  # Add batch dimension\n",
    "    \n",
    "    # Get prediction\n",
    "    with torch.no_grad():\n",
    "        output = model(face_tensor)\n",
    "        prob = torch.sigmoid(output).item()\n",
    "        pred = 'Fake' if prob > 0.5 else 'Real'\n",
    "    \n",
    "    # Set up figure\n",
    "    fig = plt.figure(figsize=(15, 10))\n",
    "    \n",
    "    # 1. Original image\n",
    "    plt.subplot(2, 3, 1)\n",
    "    plt.imshow(face_dict['original'])\n",
    "    plt.title(f'Original Image\\nPrediction: {pred} ({prob:.2f})')\n",
    "    plt.axis('off')\n",
    "    \n",
    "    # 2. Grad-CAM visualization\n",
    "    try:\n",
    "        target_layer = find_target_layer(model, model_name)\n",
    "        grad_cam = GradCAM(model, target_layer)\n",
    "        cam = grad_cam(face_tensor)\n",
    "        \n",
    "        # Convert to heatmap\n",
    "        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)\n",
    "        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)\n",
    "        \n",
    "        # Resize to match image size\n",
    "        heatmap = cv2.resize(heatmap, (face_dict['original'].shape[1], face_dict['original'].shape[0]))\n",
    "        \n",
    "        # Overlay heatmap on image\n",
    "        superimposed = heatmap * 0.4 + face_dict['original'] * 0.6\n",
    "        superimposed = np.clip(superimposed, 0, 255).astype(np.uint8)\n",
    "        \n",
    "        plt.subplot(2, 3, 2)\n",
    "        plt.imshow(heatmap)\n",
    "        plt.title('Grad-CAM Heatmap')\n",
    "        plt.axis('off')\n",
    "        \n",
    "        plt.subplot(2, 3, 3)\n",
    "        plt.imshow(superimposed)\n",
    "        plt.title('Grad-CAM Overlay')\n",
    "        plt.axis('off')\n",
    "    except Exception as e:\n",
    "        print(f\"Error generating Grad-CAM: {e}\")\n",
    "    \n",
    "    # 3. Attention map visualization (for ViT/DeiT)\n",
    "    if model_name in ['vit', 'deit']:\n",
    "        try:\n",
    "            attention_maps = get_attention_maps(model, face_dict['tensor'], model_name)\n",
    "            \n",
    "            if attention_maps and len(attention_maps) > 0:\n",
    "                # Use the last attention map\n",
    "                attn_map = attention_maps[-1][0]  # First batch\n",
    "                \n",
    "                # Average over heads\n",
    "                avg_attn = attn_{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Demo\n",
    "\n",
    "This notebook provides a demonstration of the deepfake detection system. It allows you to:\n",
    "\n",
    "1. Load pretrained deepfake detection models\n",
    "2. Analyze single images for deepfake detection\n",

In [None]:
def download_sample_images():
    """Download the built-in sample images into ``sample_images/``.

    Every URL in the table below is fetched and re-saved as
    ``sample_images/<name>.png``. A failure is reported for that image only
    and does not stop the remaining downloads. Returns None.
    """
    import requests
    from PIL import Image
    from io import BytesIO

    # Create a directory for sample images
    os.makedirs('sample_images', exist_ok=True)

    # Sample image URLs (replace these with your own samples or use a dataset)
    sample_urls = {
        'real_1': 'https://raw.githubusercontent.com/ondyari/FaceForensics/master/dataset/sample_images/original.png',
        'fake_1': 'https://raw.githubusercontent.com/ondyari/FaceForensics/master/dataset/sample_images/manipulated.png'
    }

    downloaded = 0
    for name, url in sample_urls.items():
        try:
            # Without a timeout requests.get can wait on a stalled server forever
            response = requests.get(url, timeout=30)
            if response.status_code == 200:
                img = Image.open(BytesIO(response.content))
                img_path = f'sample_images/{name}.png'
                img.save(img_path)
                print(f"Downloaded {name} to {img_path}")
                downloaded += 1
            else:
                print(f"Failed to download {name}: HTTP {response.status_code}")
        except Exception as e:
            print(f"Error downloading {name}: {e}")

    # Only claim success when at least one image actually arrived
    if downloaded:
        print("\nSample images downloaded. You can use these for testing the detection system.")
    else:
        print("\nNo sample images could be downloaded.")

# Button that fetches the sample images when clicked
download_button = widgets.Button(
    description='Download Samples',
    disabled=False,
    button_style='info',
    tooltip='Click to download sample images',
    icon='download'
)


def on_download_button_clicked(b):
    """Click handler: the button argument is unused; just fetch the samples."""
    download_sample_images()


# Wire the handler up and show the button
download_button.on_click(on_download_button_clicked)
display(download_button)

# List image files that are already present in the sample directory
if os.path.exists('sample_images'):
    sample_files = list(filter(
        lambda f: f.endswith(('.png', '.jpg', '.jpeg')),
        os.listdir('sample_images'),
    ))
    if sample_files:
        print("Existing sample images:")
        for file in sample_files:
            print(f"• {file}")

## 9. Conclusion and Further Resources

# The closing notes are emitted with print(), so the '#' heading markers and
# list syntax inside the string appear literally in the cell output rather
# than being rendered as markdown.
print("""
# Conclusion

This demo demonstrates the use of transformer-based models for deepfake detection. The system supports:

- Single image analysis with multiple models
- Video analysis with frame-by-frame detection
- Visual explanations using Grad-CAM and attention maps
- Ensemble predictions by combining multiple models

# Further Resources

To learn more about deepfake detection:

1. FaceForensics++ dataset: https://github.com/ondyari/FaceForensics
2. Celeb-DF dataset: https://github.com/yuezunli/celeb-deepfakeforensics
3. Vision Transformers (ViT): https://arxiv.org/abs/2010.11929
4. Data-efficient Image Transformers (DeiT): https://arxiv.org/abs/2012.12877
5. Swin Transformer: https://arxiv.org/abs/2103.14030
6. Grad-CAM: https://arxiv.org/abs/1610.02391

# Next Steps

- Train models on additional datasets for better generalization
- Implement temporal analysis for video detection
- Explore multi-modal approaches combining image and audio analysis
- Deploy the system as a web service or mobile application
""")
## 7. Run Demo

# One tab per detection mode; each tab body comes from its interface factory
tab_titles = ['Image Detection', 'Video Detection']
tabs = widgets.Tab()
tabs.children = [
    build()
    for build in (create_image_detection_interface, create_video_detection_interface)
]

# Label the tabs in the same order as their children
for i, title in enumerate(tab_titles):
    tabs.set_title(i, title)

display(tabs)

# Usage instructions shown below the widget
print(
    "\nDeepfake Detection Demo is ready to use!",
    "1. Upload an image or video file",
    "2. Select model(s) to use for detection",
    "3. Click 'Run Detection' to analyze the content",
    "4. For images, toggle 'Show Explanation' to visualize how the model made its decision",
    sep="\n",
)
print("\nNote: Face detection is a required step. If no faces are detected, the system will use the whole image.")


def create_image_detection_interface():
    """Create a demo interface for image-based deepfake detection.

    Builds an upload widget, a model multi-select, an explanation checkbox, a
    run button and an output area. Clicking the button decodes the uploaded
    image, preprocesses it, runs the selected models on the first face and
    shows the results (plus explanations when the checkbox is ticked).

    Returns:
        The ``widgets.VBox`` containing the whole interface.
    """
    # Create file upload widget
    file_upload = widgets.FileUpload(
        accept='image/*',
        multiple=False,
        description='Upload Image:'
    )

    # Create model selection widget
    model_options = list(models.keys())
    if 'ensemble' not in model_options and len(model_options) > 1:
        model_options.append('ensemble')

    model_select = widgets.SelectMultiple(
        options=model_options,
        value=model_options,
        description='Select Models:',
        disabled=False
    )

    # Create explanation toggle
    explain_toggle = widgets.Checkbox(
        value=False,
        description='Show Explanation',
        disabled=False
    )

    # Create run button
    run_button = widgets.Button(
        description='Run Detection',
        disabled=False,
        button_style='success',
        tooltip='Click to run detection',
        icon='check'
    )

    # Create output widget
    output = widgets.Output()

    def _first_upload(value):
        """Return (file name, raw bytes) of the first uploaded file."""
        if isinstance(value, dict):
            # ipywidgets 7: dict keyed by file name, name stored under 'metadata'
            entry = next(iter(value.values()))
            return entry['metadata']['name'], entry['content']
        # ipywidgets 8: tuple of entries with 'name' and a memoryview 'content'
        entry = value[0]
        return entry['name'], bytes(entry['content'])

    # Define run button callback
    def on_run_button_clicked(b):
        # Clear previous output
        output.clear_output()

        with output:
            # Check if file is uploaded
            if not file_upload.value:
                print("Please upload an image first")
                return

            # Check if models are selected
            if not model_select.value:
                print("Please select at least one model")
                return

            try:
                file_name, content = _first_upload(file_upload.value)

                # Decode the uploaded bytes into an image array
                img_array = np.frombuffer(content, dtype=np.uint8)
                img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
                # cv2.imdecode returns None when the data is not a readable image
                if img is None:
                    print(f"Could not decode {file_name} as an image")
                    return

                print(f"Processing image: {file_name} ({img.shape[1]}x{img.shape[0]})")

                # Keep the array in BGR: preprocess_image converts 3-channel uint8
                # arrays from BGR to RGB itself, so converting here as well would
                # swap the channels back and feed the models wrong colors.
                faces = preprocess_image(img, face_detector)

                if not faces:
                    print("No faces detected or error in processing image")
                    return

                # For simplicity, use the first face
                face = faces[0]

                # Get selected models
                selected_models = {name: models[name] for name in model_select.value if name in models}
                if not selected_models:
                    # Only the 'ensemble' pseudo-entry was chosen. It is not a key
                    # of `models`, so run every loaded model; detect_deepfake adds
                    # the averaged 'ensemble' result when several models ran.
                    selected_models = dict(models)

                # Detect deepfake
                results = detect_deepfake(face['tensor'], selected_models, device)

                # Visualize results
                visualize_detection_results(face, results)

                # Show explanation if toggled
                if explain_toggle.value:
                    print("\nGenerating explanation visualizations...")

                    for name, model in selected_models.items():
                        print(f"\nExplanation for {name.upper()} model:")
                        explain_prediction(face, name, model, device)

            except Exception as e:
                print(f"Error: {e}")

    # Register callback
    run_button.on_click(on_run_button_clicked)

    # Arrange widgets
    ui = widgets.VBox([
        widgets.HBox([file_upload]),
        widgets.HBox([model_select, explain_toggle]),
        widgets.HBox([run_button]),
        output
    ])

    return ui

def _analyze_video_frames(video_path, selected_models, step):
    """Run deepfake detection on every `step`-th frame of a video.

    Args:
        video_path: Path of a video file that OpenCV can open.
        selected_models: Mapping of model name -> model, handed to
            `detect_deepfake`.
        step: Sampling interval; frames 0, step, 2*step, ... are analysed.

    Returns:
        Three parallel lists `(frame_indices, predictions, probabilities)`,
        or None (after printing a message) if the video cannot be opened.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        # Check if video opened successfully
        if not cap.isOpened():
            print("Error opening video file")
            return None

        # Get video properties
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        print(f"Video properties: {frame_count} frames, {fps} fps")

        # Frames 0, step, 2*step, ... are sampled, i.e. ceil(frame_count / step)
        # of them. Guard against a non-positive reported count by letting the
        # progress bar run without a total.
        total = -(-frame_count // step) if frame_count > 0 else None

        frame_indices = []
        predictions = []
        probabilities = []

        with tqdm(total=total) as pbar:
            frame_idx = 0

            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_idx % step == 0:
                    # Hand over the raw BGR frame: preprocess_image converts
                    # 3-channel uint8 arrays from BGR to RGB itself, so
                    # converting here as well would swap the channels back.
                    faces = preprocess_image(frame, face_detector)

                    if faces:
                        # For simplicity, only the first face is analysed
                        results = detect_deepfake(faces[0]['tensor'], selected_models, device)

                        # Prefer the ensemble entry; otherwise the single model's
                        key = 'ensemble' if 'ensemble' in results else list(results.keys())[0]

                        frame_indices.append(frame_idx)
                        predictions.append(results[key]['prediction'])
                        probabilities.append(results[key]['probability'])

                    pbar.update(1)

                frame_idx += 1

        return frame_indices, predictions, probabilities
    finally:
        # Always release the capture, including on early return or error
        cap.release()


def _plot_video_timeline(frame_indices, probabilities, predictions):
    """Scatter the fake probability of each analysed frame over the frame number.

    `probabilities` are used as they are: detect_deepfake stores the sigmoid
    output and labels a frame 'Fake' when it exceeds 0.5, so the values already
    mean "probability of fake" and line up with the 0.5 threshold line.
    """
    from matplotlib.lines import Line2D

    plt.figure(figsize=(12, 6))

    # Plot timeline with colored points (green = Real, red = Fake)
    colors = ['green' if pred == 'Real' else 'red' for pred in predictions]
    plt.scatter(frame_indices, probabilities, c=colors, alpha=0.7)

    # Connect points with line
    plt.plot(frame_indices, probabilities, 'k--', alpha=0.3)

    # Add threshold line
    plt.axhline(y=0.5, color='blue', linestyle='--', alpha=0.5)

    # Add labels
    plt.xlabel('Frame Number')
    plt.ylabel('Fake Probability')
    plt.title('Deepfake Detection Results Across Video Frames')

    # Add legend
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='green', label='Real', markersize=10),
        Line2D([0], [0], marker='o', color='w', markerfacecolor='red', label='Fake', markersize=10)
    ]
    plt.legend(handles=legend_elements)

    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()


def _print_video_summary(predictions):
    """Print per-class frame counts and an overall verdict.

    `predictions` must be a non-empty list of 'Real' / 'Fake' labels.
    """
    real_count = predictions.count('Real')
    fake_count = predictions.count('Fake')
    total_count = len(predictions)

    real_percent = real_count / total_count * 100
    fake_percent = fake_count / total_count * 100

    print("\nVIDEO ANALYSIS SUMMARY:")
    print(f"Processed {total_count} frames from the video")
    print(f"• Real frames: {real_count} ({real_percent:.1f}%)")
    print(f"• Fake frames: {fake_count} ({fake_percent:.1f}%)")

    # Overall verdict, driven by the share of frames labelled fake
    if fake_percent > 70:
        verdict = "FAKE"
        confidence = "high"
    elif fake_percent > 40:
        verdict = "FAKE"
        confidence = "moderate"
    elif fake_percent > 20:
        verdict = "SUSPICIOUS"
        confidence = "low"
    else:
        verdict = "REAL"
        confidence = "high" if real_percent > 80 else "moderate"

    print(f"\nOverall verdict: Video is likely {verdict} with {confidence} confidence")


def create_video_detection_interface():
    """Create a demo interface for video-based deepfake detection.

    Builds an upload widget, a model dropdown (with an extra 'ensemble' entry
    when more than one model is loaded), a frame sampling-rate slider, a run
    button and an output area. Clicking the button analyses the uploaded video
    with the notebook-level `models`, `face_detector` and `device`, then plots
    a per-frame timeline and prints a summary.

    Returns:
        The assembled `widgets.VBox`.
    """
    import tempfile
    import traceback

    # Create file upload widget
    file_upload = widgets.FileUpload(
        accept='video/*',
        multiple=False,
        description='Upload Video:'
    )

    # Create model selection widget
    model_options = list(models.keys())
    if 'ensemble' not in model_options and len(model_options) > 1:
        model_options.append('ensemble')

    model_select = widgets.Dropdown(
        options=model_options,
        value=model_options[0] if model_options else None,
        description='Model:',
        disabled=False
    )

    # Create sample rate slider (analyse every N-th frame)
    sample_rate = widgets.IntSlider(
        value=30,
        min=1,
        max=100,
        step=1,
        description='Sample Rate:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d'
    )

    # Create run button
    run_button = widgets.Button(
        description='Run Detection',
        disabled=False,
        button_style='success',
        tooltip='Click to run detection',
        icon='check'
    )

    # Create output widget
    output = widgets.Output()

    # Define run button callback
    def on_run_button_clicked(b):
        # Clear previous output
        output.clear_output()

        with output:
            # Check if file is uploaded
            if not file_upload.value:
                print("Please upload a video first")
                return

            # Check if model is selected
            if not model_select.value:
                print("Please select a model")
                return

            video_path = None
            try:
                # Get uploaded file
                uploaded_file = list(file_upload.value.values())[0]
                content = uploaded_file['content']

                # OpenCV reads videos by path, so spill the upload to disk.
                # delete=False keeps the file after closing; it is removed in
                # the finally clause below.
                with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
                    video_path = temp_file.name
                    temp_file.write(content)

                print(f"Processing video: {uploaded_file['metadata']['name']}")

                # The selection does not change per frame, so resolve it once
                if model_select.value == 'ensemble':
                    selected_models = models
                else:
                    selected_models = {model_select.value: models[model_select.value]}

                analysis = _analyze_video_frames(video_path, selected_models, sample_rate.value)
                if analysis is None:
                    return
                frame_indices, predictions, probabilities = analysis

                # Visualize results
                if frame_indices:
                    _plot_video_timeline(frame_indices, probabilities, predictions)
                    _print_video_summary(predictions)
                else:
                    print("No faces detected in the sampled frames")

            except Exception as e:
                print(f"Error: {e}")
                traceback.print_exc()
            finally:
                # Remove the temporary file on every exit path
                if video_path is not None and os.path.exists(video_path):
                    os.unlink(video_path)

    # Register callback
    run_button.on_click(on_run_button_clicked)

    # Arrange widgets
    ui = widgets.VBox([
        widgets.HBox([file_upload]),
        widgets.HBox([model_select, sample_rate]),
        widgets.HBox([run_button]),
        output
    ])

    return ui
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Demo\n",
    "\n",
    "This notebook provides a demonstration of the deepfake detection system. It allows you to:\n",
    "\n",
    "1. Load pretrained deepfake detection models\n",
    "2. Analyze single images for deepfake detection\n",
    "3. Process videos frame-by-frame for deepfake detection\n",
    "4. Visualize model decisions and explain results\n",
    "5. Test with your own images or sample images\n",
    "\n",
    "The demo uses the transformer-based deepfake detection models (ViT, DeiT, and Swin) and includes visualization tools to help understand the decision-making process."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Import necessary libraries\n",
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "import cv2\n",
    "from tqdm.notebook import tqdm\n",
    "import time\n",
    "import argparse\n",
    "import ipywidgets as widgets\n",
    "from IPython.display import display, clear_output, HTML\n",
    "\n",
    "# Add parent directory to path for importing project modules\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "\n",
    "# Import project modules\n",
    "from models.vit.model import ViT\n",
    "from models.deit.model import DeiT\n",
    "from models.swin.model import SwinTransformer\n",
    "from models.model_zoo.model_factory import create_model\n",
    "from data.preprocessing.face_extraction import setup_face_detector, extract_faces\n",
    "from data.preprocessing.normalization import normalize_face\n",
    "from evaluation.visualization.attention_maps import visualize_attention_maps\n",
    "from evaluation.visualization.grad_cam import visualize_grad_cam\n",
    "\n",
    "# Set device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Pretrained Models\n",
    "\n",
    "First, let's load the pretrained models for deepfake detection."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Configure paths - update these to your checkpoint paths\n",
    "CHECKPOINT_DIR = \"../trained_models\"\n",
    "VIT_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"vit_celebdf/checkpoints/best.pth\")\n",
    "DEIT_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"deit_celebdf/checkpoints/best.pth\")\n",
    "SWIN_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"swin_celebdf/checkpoints/best.pth\")\n",
    "\n",
    "# Check if checkpoints exist\n",
    "vit_exists = os.path.exists(VIT_CHECKPOINT)\n",
    "deit_exists = os.path.exists(DEIT_CHECKPOINT)\n",
    "swin_exists = os.path.exists(SWIN_CHECKPOINT)\n",
    "\n",
    "print(f\"ViT checkpoint exists: {vit_exists}\")\n",
    "print(f\"DeiT checkpoint exists: {deit_exists}\")\n",
    "print(f\"Swin checkpoint exists: {swin_exists}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_checkpoint(model, checkpoint_path, device):\n",
    "    \"\"\"Restore weights into `model` from a checkpoint file.\n",
    "\n",
    "    Args:\n",
    "        model: Module whose `load_state_dict` receives the stored weights.\n",
    "        checkpoint_path: Path of the checkpoint file.\n",
    "        device: Passed to `torch.load` as `map_location`; the model is moved there too.\n",
    "\n",
    "    Returns:\n",
    "        The model on `device` in eval mode, or None when the file is missing\n",
    "        or loading fails (a message is printed in both cases).\n",
    "    \"\"\"\n",
    "    if not os.path.exists(checkpoint_path):\n",
    "        print(f\"Checkpoint not found at {checkpoint_path}\")\n",
    "        return None\n",
    "\n",
    "    try:\n",
    "        # NOTE(review): torch.load unpickles the file - only load trusted checkpoints\n",
    "        checkpoint = torch.load(checkpoint_path, map_location=device)\n",
    "\n",
    "        # The weights may sit under 'model' or 'model_state_dict';\n",
    "        # otherwise the checkpoint itself is used as the state dict\n",
    "        state_dict = checkpoint\n",
    "        for key in ('model', 'model_state_dict'):\n",
    "            if key in checkpoint:\n",
    "                state_dict = checkpoint[key]\n",
    "                break\n",
    "        model.load_state_dict(state_dict)\n",
    "\n",
    "        model = model.to(device)\n",
    "        model.eval()  # Set to evaluation mode\n",
    "        print(f\"Model loaded successfully from {checkpoint_path}\")\n",
    "        return model\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading checkpoint: {e}\")\n",
    "        return None\n",
    "\n",
    "# Initialize models\n",
    "models = {}\n",
    "\n",
    "# Load ViT model\n",
    "if vit_exists:\n",
    "    vit_model = ViT(\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=768,\n",
    "        depth=12,\n",
    "        num_heads=12\n",
    "    )\n",
    "    vit_model = load_checkpoint(vit_model, VIT_CHECKPOINT, device)\n",
    "    if vit_model is not None:\n",
    "        models['vit'] = vit_model\n",
    "\n",
    "# Load DeiT model\n",
    "if deit_exists:\n",
    "    deit_model = DeiT(\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=768,\n",
    "        depth=12,\n",
    "        num_heads=12,\n",
    "        distillation=True\n",
    "    )\n",
    "    deit_model = load_checkpoint(deit_model, DEIT_CHECKPOINT, device)\n",
    "    if deit_model is not None:\n",
    "        models['deit'] = deit_model\n",
    "\n",
    "# Load Swin model\n",
    "if swin_exists:\n",
    "    swin_model = SwinTransformer(\n",
    "        img_size=224,\n",
    "        patch_size=4,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=96,\n",
    "        depths=[2, 2, 6, 2],\n",
    "        num_heads=[3, 6, 12, 24],\n",
    "        window_size=7\n",
    "    )\n",
    "    swin_model = load_checkpoint(swin_model, SWIN_CHECKPOINT, device)\n",
    "    if swin_model is not None:\n",
    "        models['swin'] = swin_model\n",
    "\n",
    "print(f\"Loaded {len(models)} models: {list(models.keys())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Set up Face Detection\n",
    "\n",
    "We need to set up face detection to extract faces from images and videos."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Set up face detector\n",
    "try:\n",
    "    face_detector = setup_face_detector(device='cpu')\n",
    "    print(\"Face detector set up successfully.\")\n",
    "except Exception as e:\n",
    "    print(f\"Error setting up face detector: {e}\")\n",
    "    print(\"Please install the required dependencies: pip install facenet-pytorch opencv-python\")\n",
    "    face_detector = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Image Preprocessing Functions\n",
    "\n",
    "Define functions for preprocessing images for the deepfake detection models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def preprocess_image(image_path, face_detector=None, target_size=224):\n",
    "    \"\"\"Turn an image into normalized face tensors for the detection models.\n",
    "\n",
    "    Args:\n",
    "        image_path: Either a path to an image file or an image array. A\n",
    "            3-channel uint8 array is assumed to be BGR (OpenCV order) and is\n",
    "            converted to RGB here, so callers must not convert it beforehand.\n",
    "            A 2-D (grayscale) array is expanded to three channels.\n",
    "        face_detector: Optional detector handed to `extract_faces`; when None,\n",
    "            or when no face is found, the whole image is used instead.\n",
    "        target_size: Side length, in pixels, of the square model input.\n",
    "\n",
    "    Returns:\n",
    "        A list with one dict per face holding 'original' (the crop), 'resized'\n",
    "        (the crop at target_size x target_size) and 'tensor' (a float tensor in\n",
    "        C, H, W order, normalized with the ImageNet mean/std), or None if the\n",
    "        image file is missing or cannot be decoded.\n",
    "    \"\"\"\n",
    "    # Load image\n",
    "    if isinstance(image_path, str):\n",
    "        # Load from file\n",
    "        if not os.path.exists(image_path):\n",
    "            print(f\"Image not found at {image_path}\")\n",
    "            return None\n",
    "\n",
    "        img = cv2.imread(image_path)\n",
    "        if img is None:\n",
    "            # cv2.imread reports an unreadable or unsupported file by returning None\n",
    "            print(f\"Could not read image at {image_path}\")\n",
    "            return None\n",
    "        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "    else:\n",
    "        # Assume numpy array\n",
    "        img = image_path\n",
    "        if img.ndim == 2:\n",
    "            # Grayscale: replicate to three channels so the per-channel\n",
    "            # normalization below applies\n",
    "            img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)\n",
    "        elif img.shape[2] == 3 and img.dtype == np.uint8:\n",
    "            # Likely BGR format from OpenCV\n",
    "            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n",
    "\n",
    "    # Extract face if detector is provided\n",
    "    if face_detector is not None:\n",
    "        # presumably extract_faces returns a list of face crops as arrays - TODO confirm\n",
    "        faces = extract_faces(img, face_detector)\n",
    "        if not faces:\n",
    "            print(\"No faces detected in the image\")\n",
    "            # Just use the whole image if no faces detected\n",
    "            faces = [img]\n",
    "    else:\n",
    "        # Use the whole image if no detector\n",
    "        faces = [img]\n",
    "\n",
    "    # Process each face\n",
    "    processed_faces = []\n",
    "    for face in faces:\n",
    "        # Resize to target size\n",
    "        face_resized = cv2.resize(face, (target_size, target_size))\n",
    "\n",
    "        # Convert to float in [0, 1]\n",
    "        face_float = face_resized.astype(np.float32) / 255.0\n",
    "\n",
    "        # Normalize with ImageNet mean and std\n",
    "        face_normalized = (face_float - np.array([0.485, 0.456, 0.406])) / np.array([0.229, 0.224, 0.225])\n",
    "\n",
    "        # Convert H, W, C -> C, H, W tensor (no batch dimension; callers add it)\n",
    "        face_tensor = torch.from_numpy(face_normalized.transpose(2, 0, 1)).float()\n",
    "\n",
    "        processed_faces.append({\n",
    "            'original': face,\n",
    "            'resized': face_resized,\n",
    "            'tensor': face_tensor\n",
    "        })\n",
    "\n",
    "    return processed_faces"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Deepfake Detection Functions\n",
    "\n",
    "Define functions for performing deepfake detection on images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def detect_deepfake(face_tensor, models, device):\n",
    "    \"\"\"Score one face with every model and, if there are several, their average.\n",
    "\n",
    "    Args:\n",
    "        face_tensor: Face tensor without batch dimension; one is added here.\n",
    "        models: Mapping of model name -> model. Each model's output is passed\n",
    "            through a sigmoid and read as the probability of 'Fake'.\n",
    "        device: Device the tensor is moved to before the forward passes.\n",
    "\n",
    "    Returns:\n",
    "        Mapping of model name -> {'probability', 'prediction', 'confidence'},\n",
    "        plus an 'ensemble' entry (mean probability) when more than one model\n",
    "        ran; None when `models` is empty.\n",
    "    \"\"\"\n",
    "    if not models:\n",
    "        print(\"No models available for detection\")\n",
    "        return None\n",
    "\n",
    "    def summarize(prob):\n",
    "        # 'Fake' above the 0.5 threshold; confidence is the winning class's share\n",
    "        return {\n",
    "            'probability': prob,\n",
    "            'prediction': 'Fake' if prob > 0.5 else 'Real',\n",
    "            'confidence': max(prob, 1 - prob)\n",
    "        }\n",
    "\n",
    "    # Move to the target device and add the batch dimension\n",
    "    batch = face_tensor.to(device).unsqueeze(0)\n",
    "\n",
    "    results = {}\n",
    "    with torch.no_grad():\n",
    "        for name, model in models.items():\n",
    "            results[name] = summarize(torch.sigmoid(model(batch)).item())\n",
    "\n",
    "    # Simple averaging ensemble over the individual model probabilities\n",
    "    if len(results) > 1:\n",
    "        results['ensemble'] = summarize(np.mean([r['probability'] for r in results.values()]))\n",
    "\n",
    "    return results\n",
    "\n",
    "def visualize_detection_results(face_dict, results):\n",
    "    \"\"\"Show the input face beside a bar chart of per-model fake probabilities.\n",
    "\n",
    "    Also prints a per-model summary and an overall verdict, taken from the\n",
    "    'ensemble' entry when present and from the first entry otherwise.\n",
    "\n",
    "    Args:\n",
    "        face_dict: Dict whose 'original' entry is the image to display.\n",
    "        results: Mapping of model name -> dict with 'probability',\n",
    "            'prediction' and 'confidence'.\n",
    "\n",
    "    Returns:\n",
    "        The matplotlib figure, or None when either argument is None.\n",
    "    \"\"\"\n",
    "    if face_dict is None or results is None:\n",
    "        print(\"No face or results to visualize\")\n",
    "        return\n",
    "\n",
    "    fig = plt.figure(figsize=(12, 8))\n",
    "\n",
    "    # Left panel: the face that was analysed\n",
    "    plt.subplot(1, 2, 1)\n",
    "    plt.imshow(face_dict['original'])\n",
    "    plt.title('Input Face')\n",
    "    plt.axis('off')\n",
    "\n",
    "    # Right panel: horizontal bars, one per model\n",
    "    plt.subplot(1, 2, 2)\n",
    "    model_names = list(results.keys())\n",
    "    fake_probs = [results[name]['probability'] for name in model_names]\n",
    "    labels = [results[name]['prediction'] for name in model_names]\n",
    "\n",
    "    # Green bars for 'Real' predictions, red for everything else\n",
    "    bar_colors = ['green' if label == 'Real' else 'red' for label in labels]\n",
    "\n",
    "    bars = plt.barh(model_names, fake_probs, color=bar_colors)\n",
    "    plt.xlim(0, 1)\n",
    "    plt.xlabel('Probability of Fake')\n",
    "    plt.axvline(x=0.5, color='black', linestyle='--')\n",
    "\n",
    "    # Annotate each bar with its value and label\n",
    "    for bar, prob, label in zip(bars, fake_probs, labels):\n",
    "        plt.text(prob + 0.01, bar.get_y() + bar.get_height()/2,\n",
    "                 f\"{prob:.2f} ({label})\",\n",
    "                 va='center')\n",
    "\n",
    "    plt.grid(True, linestyle='--', alpha=0.7)\n",
    "    plt.title('Deepfake Detection Results')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    # Textual summary, one line per entry\n",
    "    print(\"\\nDetection Summary:\")\n",
    "    for name, entry in results.items():\n",
    "        print(f\"• {name}: {entry['prediction']} with {entry['confidence']*100:.1f}% confidence\")\n",
    "\n",
    "    # Overall verdict: the ensemble if there is one, else the first entry\n",
    "    overall = results['ensemble'] if 'ensemble' in results else results[list(results.keys())[0]]\n",
    "    verdict = overall['prediction']\n",
    "    confidence = overall['confidence']\n",
    "\n",
    "    if confidence > 0.8:\n",
    "        confidence_level = \"high\"\n",
    "    elif confidence > 0.6:\n",
    "        confidence_level = \"moderate\"\n",
    "    else:\n",
    "        confidence_level = \"low\"\n",
    "    print(f\"\\nOverall verdict: Image is {verdict} with {confidence_level} confidence ({confidence*100:.1f}%)\")\n",
    "\n",
    "    return fig"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Explanation Functions\n",
    "\n",
    "Define functions for explaining model decisions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def find_target_layer(model, model_type):\n",
    "    \"\"\"Return the module Grad-CAM should hook for the given model type.\n",
    "\n",
    "    'vit' and 'deit' use the last entry of `model.blocks`; 'swin' uses the\n",
    "    last entry of `model.layers`.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If model_type is none of 'vit', 'deit' or 'swin'.\n",
    "    \"\"\"\n",
    "    if model_type in ('vit', 'deit'):\n",
    "        # ViT and DeiT: last transformer block\n",
    "        return model.blocks[-1]\n",
    "    if model_type == 'swin':\n",
    "        # Swin: last layer\n",
    "        return model.layers[-1]\n",
    "    raise ValueError(f\"Unknown model type: {model_type}\")\n",
    "\n",
    "class GradCAM:\n",
    "    \"\"\"Grad-CAM for token-sequence (transformer) and convolutional feature maps.\n",
    "\n",
    "    Hooks on `target_layer` capture its output and the gradient arriving at\n",
    "    that output. Calling the instance turns both into a heatmap at the\n",
    "    input's spatial size, normalized to [0, 1].\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, model, target_layer):\n",
    "        self.model = model\n",
    "        self.target_layer = target_layer\n",
    "        self.gradients = None\n",
    "        self.activations = None\n",
    "        self._hook_handles = []\n",
    "\n",
    "        # Register hooks\n",
    "        self.register_hooks()\n",
    "\n",
    "    def register_hooks(self):\n",
    "        \"\"\"Attach the capture hooks, replacing any this instance attached before.\"\"\"\n",
    "        self.remove_hooks()\n",
    "\n",
    "        def forward_hook(module, input, output):\n",
    "            self.activations = output.detach()\n",
    "\n",
    "        def backward_hook(module, grad_input, grad_output):\n",
    "            self.gradients = grad_output[0].detach()\n",
    "\n",
    "        # Keep the handles so the hooks can be detached again\n",
    "        self._hook_handles = [\n",
    "            self.target_layer.register_forward_hook(forward_hook),\n",
    "            self.target_layer.register_backward_hook(backward_hook),\n",
    "        ]\n",
    "\n",
    "    def remove_hooks(self):\n",
    "        \"\"\"Detach this instance's hooks from the target layer.\"\"\"\n",
    "        for handle in self._hook_handles:\n",
    "            handle.remove()\n",
    "        self._hook_handles = []\n",
    "\n",
    "    def __call__(self, x, class_idx=None):\n",
    "        \"\"\"Compute the heatmap for the first item of the batch `x`.\n",
    "\n",
    "        Args:\n",
    "            x: Input of shape (B, C, H, W).\n",
    "            class_idx: Class to explain. Defaults to the model's prediction;\n",
    "                for a single-logit model that is 1 when the logit is positive\n",
    "                (sigmoid above 0.5) and 0 otherwise.\n",
    "\n",
    "        Returns:\n",
    "            Numpy array of shape (H, W) with values in [0, 1].\n",
    "        \"\"\"\n",
    "        # Hooks are detached after every call, so re-attach when needed\n",
    "        if not self._hook_handles:\n",
    "            self.register_hooks()\n",
    "\n",
    "        try:\n",
    "            # Forward pass\n",
    "            b, c, h, w = x.size()\n",
    "            logits = self.model(x)\n",
    "            multi_class = logits.dim() > 1 and logits.shape[1] > 1\n",
    "\n",
    "            # If class_idx is None, use the model's prediction\n",
    "            if class_idx is None:\n",
    "                if multi_class:\n",
    "                    class_idx = torch.argmax(logits, dim=1).item()\n",
    "                else:\n",
    "                    # The output is a pre-sigmoid logit: probability 0.5 <=> logit 0\n",
    "                    class_idx = (logits > 0).long().item()\n",
    "\n",
    "            # Backward pass\n",
    "            self.model.zero_grad()\n",
    "            if multi_class:\n",
    "                target = torch.zeros_like(logits)\n",
    "                target[0, class_idx] = 1\n",
    "            else:\n",
    "                # Class 0 is scored by -logit; an all-zero gradient here would\n",
    "                # zero every weight and yield a blank map\n",
    "                target = torch.ones_like(logits) if class_idx == 1 else -torch.ones_like(logits)\n",
    "\n",
    "            logits.backward(gradient=target, retain_graph=True)\n",
    "\n",
    "            activations = self.activations\n",
    "            gradients = self.gradients\n",
    "\n",
    "            if activations.dim() == 3:  # [B, L, D] token sequence\n",
    "                # assumes a square patch grid with any extra tokens (class,\n",
    "                # distillation) placed first - TODO confirm against the models\n",
    "                grid = int(np.sqrt(activations.shape[1]))\n",
    "                num_prefix = activations.shape[1] - grid * grid\n",
    "\n",
    "                # Drop the extra tokens and reshape to [B, H, W, D]\n",
    "                activations = activations[:, num_prefix:, :].reshape(b, grid, grid, -1)\n",
    "                gradients = gradients[:, num_prefix:, :].reshape(b, grid, grid, -1)\n",
    "\n",
    "                # Per-channel weights, kept as [B, 1, 1, D] so they broadcast\n",
    "                # over the patch grid\n",
    "                weights = gradients.mean(dim=(1, 2), keepdim=True)\n",
    "\n",
    "                # Compute weighted activation map\n",
    "                cam = (weights * activations).sum(dim=3)[0].detach().cpu().numpy()\n",
    "\n",
    "                # Resize to original image size\n",
    "                cam = cv2.resize(cam, (w, h))\n",
    "            elif hasattr(self.model, 'blocks'):\n",
    "                # Fallback for unsupported format\n",
    "                cam = np.zeros((h, w), dtype=np.float32)\n",
    "            else:\n",
    "                # Traditional CNN approach on [B, C, H, W] feature maps\n",
    "                channel_weights = gradients.mean(dim=(2, 3))[0]\n",
    "\n",
    "                # Compute weighted activation map\n",
    "                cam = (channel_weights[:, None, None] * activations[0]).sum(dim=0)\n",
    "                cam = cam.detach().cpu().numpy()\n",
    "                cam = cv2.resize(cam, (w, h))\n",
    "        finally:\n",
    "            # Do not leave hooks behind on a model shared with other code\n",
    "            self.remove_hooks()\n",
    "\n",
    "        # Apply ReLU and normalize\n",
    "        cam = np.maximum(cam, 0)\n",
    "        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-8)\n",
    "\n",
    "        return cam\n",
    "\n",
    "def get_attention_maps(model, img_tensor, model_type):\n",
    "    \"\"\"Collect the output of every block's `attn` module for one image.\n",
    "\n",
    "    Args:\n",
    "        model: Model exposing `blocks`, each with an `attn` submodule.\n",
    "        img_tensor: Image tensor without batch dimension; one is added here\n",
    "            and the tensor is moved to the notebook-level `device`.\n",
    "        model_type: 'vit' or 'deit'; any other value is reported as\n",
    "            unsupported.\n",
    "\n",
    "    Returns:\n",
    "        A list with one CPU tensor per block, in block order, or None for an\n",
    "        unsupported model type.\n",
    "    \"\"\"\n",
    "    # Check if model type is supported\n",
    "    if model_type not in ['vit', 'deit']:\n",
    "        print(f\"Attention map visualization not supported for {model_type}\")\n",
    "        return None\n",
    "\n",
    "    # Move image to device\n",
    "    img_tensor = img_tensor.to(device).unsqueeze(0)  # Add batch dimension\n",
    "\n",
    "    # Initialize list to store attention maps\n",
    "    attention_maps = []\n",
    "\n",
    "    # Hook function to extract attention maps\n",
    "    def attention_hook(module, input, output):\n",
    "        # NOTE(review): this stores whatever block.attn returns; whether that is\n",
    "        # the attention matrix or the attended tokens depends on the model code - confirm\n",
    "        attention_maps.append(output.detach().cpu())\n",
    "\n",
    "    hooks = []\n",
    "    try:\n",
    "        # Register hooks for attention layers\n",
    "        for block in model.blocks:\n",
    "            hooks.append(block.attn.register_forward_hook(attention_hook))\n",
    "\n",
    "        # Forward pass\n",
    "        with torch.no_grad():\n",
    "            model(img_tensor)\n",
    "    finally:\n",
    "        # Remove hooks even when the forward pass fails, so they do not stay\n",
    "        # attached to the model\n",
    "        for hook in hooks:\n",
    "            hook.remove()\n",
    "\n",
    "    return attention_maps\n",
    "\n",
    "def explain_prediction(face_dict, model_name, model, device):\n",
    "    \"\"\"Explain model prediction using visualization techniques\"\"\"\n",
    "    # Get face tensor\n",
    "    face_tensor = face_dict['tensor'].to(device).unsqueeze(0)  # Add batch dimension\n",
    "    \n",
    "    # Get prediction\n",
    "    with torch.no_grad():\n",
    "        output = model(face_tensor)\n",
    "        prob = torch.sigmoid(output).item()\n",
    "        pred = 'Fake' if prob > 0.5 else 'Real'\n",
    "    \n",
    "    # Set up figure\n",
    "    fig = plt.figure(figsize=(15, 10))\n",
    "    \n",
    "    # 1. Original image\n",
    "    plt.subplot(2, 3, 1)\n",
    "    plt.imshow(face_dict['original'])\n",
    "    plt.title(f'Original Image\\nPrediction: {pred} ({prob:.2f})')\n",
    "    plt.axis('off')\n",
    "    \n",
    "    # 2. Grad-CAM visualization\n",
    "    try:\n",
    "        target_layer = find_target_layer(model, model_name)\n",
    "        grad_cam = GradCAM(model, target_layer)\n",
    "        cam = grad_cam(face_tensor)\n",
    "        \n",
    "        # Convert to heatmap\n",
    "        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)\n",
    "        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)\n",
    "        \n",
    "        # Resize to match image size\n",
    "        heatmap = cv2.resize(heatmap, (face_dict['original'].shape[1], face_dict['original'].shape[0]))\n",
    "        \n",
    "        # Overlay heatmap on image\n",
    "        superimposed = heatmap * 0.4 + face_dict['original'] * 0.6\n",
    "        superimposed = np.clip(superimposed, 0, 255).astype(np.uint8)\n",
    "        \n",
    "        plt.subplot(2, 3, 2)\n",
    "        plt.imshow(heatmap)\n",
    "        plt.title('Grad-CAM Heatmap')\n",
    "        plt.axis('off')\n",
    "        \n",
    "        plt.subplot(2, 3, 3)\n",
    "        plt.imshow(superimposed)\n",
    "        plt.title('Grad-CAM Overlay')\n",
    "        plt.axis('off')\n",
    "    except Exception as e:\n",
    "        print(f\"Error generating Grad-CAM: {e}\")\n",
    "    \n",
    "    # 3. Attention map visualization (for ViT/DeiT)\n",
    "    if model_name in ['vit', 'deit']:\n",
    "        try:\n",
    "            attention_maps = get_attention_maps(model, face_dict['tensor'], model_name)\n",
    "            \n",
    "            if attention_maps and len(attention_maps) > 0:\n",
    "                # Use the last attention map\n",
    "                attn_map = attention_maps[-1][0]  # First batch\n",
    "                \n",
    "                # Average over heads\n",
    "                avg_attn = attn_{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Demo\n",
    "\n",
    "This notebook provides a demonstration of the deepfake detection system. It allows you to:\n",
    "\n",
    "1. Load pretrained deepfake detection models\n",
    "2. Analyze single images for deepfake detection\n",