In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Model Visualization\n",
    "\n",
    "This notebook provides visualizations and interpretability tools for the deepfake detection models. It includes:\n",
    "\n",
    "1. Model architecture visualization\n",
    "2. Attention map visualization\n",
    "3. Grad-CAM analysis\n",
    "4. Feature visualization\n",
    "5. Comparison of model behaviors\n",
    "\n",
    "These visualizations help understand how the models are making decisions and what features they are focusing on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Import necessary libraries\n",
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from PIL import Image\n",
    "import cv2\n",
    "from tqdm.notebook import tqdm\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Add parent directory to path to enable imports from project\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "\n",
    "# Import project modules\n",
    "from models.vit.model import ViT\n",
    "from models.deit.model import DeiT\n",
    "from models.swin.model import SwinTransformer\n",
    "from models.model_zoo.model_factory import create_model\n",
    "from data.datasets.faceforensics import FaceForensicsDataset\n",
    "from data.datasets.celebdf import CelebDFDataset\n",
    "from evaluation.visualization.attention_maps import visualize_attention_maps\n",
    "from evaluation.visualization.grad_cam import visualize_grad_cam\n",
    "from evaluation.visualization.feature_visualization import visualize_features\n",
    "\n",
    "# Set plot style\n",
    "plt.style.use('fivethirtyeight')\n",
    "sns.set(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Pretrained Models\n",
    "\n",
    "Load pretrained models for visualization. You need to have trained models saved in checkpoint format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Configure paths - update these to your checkpoint paths\n",
    "CHECKPOINT_DIR = \"../trained_models\"\n",
    "VIT_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"vit_celebdf/checkpoints/best.pth\")\n",
    "DEIT_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"deit_celebdf/checkpoints/best.pth\")\n",
    "SWIN_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"swin_celebdf/checkpoints/best.pth\")\n",
    "\n",
    "# Check if checkpoints exist\n",
    "vit_exists = os.path.exists(VIT_CHECKPOINT)\n",
    "deit_exists = os.path.exists(DEIT_CHECKPOINT)\n",
    "swin_exists = os.path.exists(SWIN_CHECKPOINT)\n",
    "\n",
    "print(f\"ViT checkpoint exists: {vit_exists}\")\n",
    "print(f\"DeiT checkpoint exists: {deit_exists}\")\n",
    "print(f\"Swin checkpoint exists: {swin_exists}\")\n",
    "\n",
    "# Set device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_checkpoint(model, checkpoint_path, device):\n",
    "    \"\"\"Load model from checkpoint\"\"\"\n",
    "    if not os.path.exists(checkpoint_path):\n",
    "        print(f\"Checkpoint not found at {checkpoint_path}\")\n",
    "        return None\n",
    "    \n",
    "    try:\n",
    "        checkpoint = torch.load(checkpoint_path, map_location=device)\n",
    "        \n",
    "        # Different checkpoint formats\n",
    "        if 'model' in checkpoint:\n",
    "            model.load_state_dict(checkpoint['model'])\n",
    "        elif 'model_state_dict' in checkpoint:\n",
    "            model.load_state_dict(checkpoint['model_state_dict'])\n",
    "        else:\n",
    "            model.load_state_dict(checkpoint)\n",
    "            \n",
    "        model = model.to(device)\n",
    "        model.eval()  # Set to evaluation mode\n",
    "        print(f\"Model loaded successfully from {checkpoint_path}\")\n",
    "        return model\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading checkpoint: {e}\")\n",
    "        return None\n",
    "\n",
    "# Initialize models\n",
    "models = {}\n",
    "\n",
    "# Load ViT model\n",
    "if vit_exists:\n",
    "    vit_model = ViT(\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=768,\n",
    "        depth=12,\n",
    "        num_heads=12\n",
    "    )\n",
    "    vit_model = load_checkpoint(vit_model, VIT_CHECKPOINT, device)\n",
    "    if vit_model is not None:\n",
    "        models['vit'] = vit_model\n",
    "\n",
    "# Load DeiT model\n",
    "if deit_exists:\n",
    "    deit_model = DeiT(\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=768,\n",
    "        depth=12,\n",
    "        num_heads=12,\n",
    "        distillation=True\n",
    "    )\n",
    "    deit_model = load_checkpoint(deit_model, DEIT_CHECKPOINT, device)\n",
    "    if deit_model is not None:\n",
    "        models['deit'] = deit_model\n",
    "\n",
    "# Load Swin model\n",
    "if swin_exists:\n",
    "    swin_model = SwinTransformer(\n",
    "        img_size=224,\n",
    "        patch_size=4,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=96,\n",
    "        depths=[2, 2, 6, 2],\n",
    "        num_heads=[3, 6, 12, 24],\n",
    "        window_size=7\n",
    "    )\n",
    "    swin_model = load_checkpoint(swin_model, SWIN_CHECKPOINT, device)\n",
    "    if swin_model is not None:\n",
    "        models['swin'] = swin_model\n",
    "\n",
    "print(f\"Loaded {len(models)} models: {list(models.keys())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load Sample Data\n",
    "\n",
    "Load some sample images to visualize the model behavior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Configure dataset paths - update these to your local paths\n",
    "FACEFORENSICS_ROOT = \"/path/to/datasets/FaceForensics\"\n",
    "CELEBDF_ROOT = \"/path/to/datasets/CelebDF\"\n",
    "\n",
    "# Choose one dataset to use for visualization\n",
    "DATASET_ROOT = CELEBDF_ROOT  # Change to the dataset you want to use\n",
    "DATASET_NAME = \"celebdf\"     # Change to match your dataset (\"faceforensics\" or \"celebdf\")\n",
    "\n",
    "# Check if directory exists\n",
    "dataset_exists = os.path.exists(DATASET_ROOT)\n",
    "print(f\"Dataset path exists: {dataset_exists}\")\n",
    "\n",
    "# Define transform\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "# Load dataset if it exists\n",
    "if dataset_exists:\n",
    "    if DATASET_NAME == \"faceforensics\":\n",
    "        dataset = FaceForensicsDataset(\n",
    "            root=DATASET_ROOT,\n",
    "            split=\"test\",  # Use test split for visualization\n",
    "            img_size=224,\n",
    "            transform=transform\n",
    "        )\n",
    "    elif DATASET_NAME == \"celebdf\":\n",
    "        dataset = CelebDFDataset(\n",
    "            root=DATASET_ROOT,\n",
    "            split=\"test\",  # Use test split for visualization\n",
    "            img_size=224,\n",
    "            transform=transform\n",
    "        )\n",
    "    else:\n",
    "        print(f\"Unknown dataset name: {DATASET_NAME}\")\n",
    "        dataset = None\n",
    "        \n",
    "    if dataset is not None:\n",
    "        print(f\"Dataset loaded with {len(dataset)} samples\")\n",
    "else:\n",
    "    print(\"Dataset path not found. Cannot load samples.\")\n",
    "    dataset = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def get_sample_batch(dataset, batch_size=4):\n",
    "    \"\"\"Get a sample batch with equal number of real and fake samples\"\"\"\n",
    "    if dataset is None:\n",
    "        return None, None\n",
    "    \n",
    "    dataloader = DataLoader(dataset, batch_size=batch_size*10, shuffle=True)\n",
    "    images, labels = next(iter(dataloader))\n",
    "    \n",
    "    # Separate real and fake\n",
    "    real_imgs = [img for img, label in zip(images, labels) if label == 0]\n",
    "    fake_imgs = [img for img, label in zip(images, labels) if label == 1]\n",
    "    \n",
    "    # Make sure we have enough samples\n",
    "    half_batch = batch_size // 2\n",
    "    if len(real_imgs) < half_batch or len(fake_imgs) < half_batch:\n",
    "        print(\"Not enough samples of each class. Getting more...\")\n",
    "        return get_sample_batch(dataset, batch_size)\n",
    "    \n",
    "    # Get equal number of each\n",
    "    real_imgs = real_imgs[:half_batch]\n",
    "    fake_imgs = fake_imgs[:half_batch]\n",
    "    \n",
    "    # Combine\n",
    "    sample_imgs = real_imgs + fake_imgs\n",
    "    sample_labels = [0] * half_batch + [1] * half_batch\n",
    "    \n",
    "    return torch.stack(sample_imgs), torch.tensor(sample_labels)\n",
    "\n",
    "# Get sample batch if dataset exists\n",
    "if dataset is not None:\n",
    "    sample_images, sample_labels = get_sample_batch(dataset, batch_size=8)\n",
    "    print(f\"Sample batch shape: {sample_images.shape}\")\n",
    "    print(f\"Sample labels: {sample_labels}\")\n",
    "    \n",
    "    # Visualize samples\n",
    "    plt.figure(figsize=(15, 8))\n",
    "    for i in range(len(sample_images)):\n",
    "        plt.subplot(2, 4, i+1)\n",
    "        img = sample_images[i].permute(1, 2, 0).numpy()\n",
    "        # Denormalize for visualization\n",
    "        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])\n",
    "        img = np.clip(img, 0, 1)\n",
    "        plt.imshow(img)\n",
    "        plt.title(\"Real\" if sample_labels[i] == 0 else \"Fake\")\n",
    "        plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"No dataset available. Skipping sample visualization.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Model Architecture Visualization\n",
    "\n",
    "Visualize the architecture of the models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def count_parameters(model):\n",
    "    \"\"\"Count number of trainable parameters\"\"\"\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "def print_model_summary(model, name):\n",
    "    \"\"\"Print model summary\"\"\"\n",
    "    print(f\"\\n{name} Model Summary:\")\n",
    "    print(f\"Total parameters: {count_parameters(model):,}\")\n",
    "    \n",
    "    # Print main component information\n",
    "    if name.lower() == 'vit':\n",
    "        print(f\"Image size: {model.img_size}\")\n",
    "        print(f\"Patch size: {model.patch_size}\")\n",
    "        print(f\"Number of patches: {model.num_patches}\")\n",
    "        print(f\"Embedding dimension: {model.embed_dim}\")\n",
    "        print(f\"Number of transformer blocks: {len(model.blocks)}\")\n",
    "        print(f\"Number of attention heads: {model.blocks[0].attn.num_heads}\")\n",
    "    elif name.lower() == 'deit':\n",
    "        print(f\"Image size: {model.img_size}\")\n",
    "        print(f\"Patch size: {model.patch_size}\")\n",
    "        print(f\"Number of patches: {model.num_patches}\")\n",
    "        print(f\"Embedding dimension: {model.embed_dim}\")\n",
    "        print(f\"Number of transformer blocks: {len(model.blocks)}\")\n",
    "        print(f\"Number of attention heads: {model.blocks[0].attn.num_heads}\")\n",
    "        print(f\"Using distillation: {model.distillation is not None}\")\n",
    "    elif name.lower() == 'swin':\n",
    "        print(f\"Image size: {model.img_size}\")\n",
    "        print(f\"Patch size: {model.patch_size}\")\n",
    "        print(f\"Embedding dimension: {model.embed_dim}\")\n",
    "        print(f\"Number of layers: {model.num_layers}\")\n",
    "        print(f\"Depths: {model.depths}\")\n",
    "\n",
    "# Visualize model architectures\n",
    "for name, model in models.items():\n",
    "    print_model_summary(model, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Try to generate a visual diagram of the model architecture\n",
    "try:\n",
    "    from torchviz import make_dot\n",
    "    from torch.autograd import Variable\n",
    "    \n",
    "    # Function to create model diagram\n",
    "    def visualize_model_graph(model, name):\n",
    "        # Create a sample input\n",
    "        x = Variable(torch.randn(1, 3, 224, 224)).to(device)\n",
    "        \n",
    "        # Generate output\n",
    "        y = model(x)\n",
    "        \n",
    "        # Create dot graph\n",
    "        dot = make_dot(y, params=dict(list(model.named_parameters())))\n",
    "        \n",
    "        # Save and display\n",
    "        dot.format = 'png'\n",
    "        dot.render(f\"{name}_architecture\", cleanup=True)\n",
    "        \n",
    "        # Display\n",
    "        from IPython.display import Image\n",
    "        return Image(filename=f\"{name}_architecture.png\")\n",
    "    \n",
    "    # Visualize each model\n",
    "    for name, model in models.items():\n",
    "        print(f\"\\nGenerating architecture diagram for {name}...\")\n",
    "        display(visualize_model_graph(model, name))\n",
    "except ImportError:\n",
    "    print(\"torchviz not installed. Install with: pip install torchviz\")\n",
    "    print(\"Also requires graphviz to be installed. See: https://graphviz.org/download/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Attention Map Visualization\n",
    "\n",
    "Visualize attention maps from transformer-based models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def get_attention_maps(model, img, model_type):\n",
    "    \"\"\"Get attention maps for the given model and image\"\"\"\n",
    "    # Move image to device\n",
    "    img = img.unsqueeze(0).to(device)\n",
    "    attention_maps = []\n",
    "    \n",
    "    # Hook function to retrieve attention maps\n",
    "    def hook_fn(module, input, output):\n",
    "        if model_type in ['vit', 'deit']:\n",
    "            # For ViT/DeiT, attention is of shape (batch_size, num_heads, seq_len, seq_len)\n",
    "            attention_maps.append(output.detach().cpu())\n",
    "        elif model_type == 'swin':\n",
    "            # For Swin, attention structure is different\n",
    "            # This is a simplified approach - may need adjustment\n",
    "            attention_maps.append(output.detach().cpu())\n",
    "    \n",
    "    # Register hooks\n",
    "    hooks = []\n",
    "    if model_type in ['vit', 'deit']:\n",
    "        for block in model.blocks:\n",
    "            hooks.append(block.attn.register_forward_hook(hook_fn))\n",
    "    elif model_type == 'swin':\n",
    "        # For Swin, we need to find the window attention modules\n",
    "        for layer in model.layers:\n",
    "            for block in layer.blocks:\n",
    "                hooks.append(block.attn.register_forward_hook(hook_fn))\n",
    "    \n",
    "    # Forward pass\n",
    "    with torch.no_grad():\n",
    "        model(img)\n",
    "    \n",
    "    # Remove hooks\n",
    "    for hook in hooks:\n",
    "        hook.remove()\n",
    "    \n",
    "    return attention_maps\n",
    "\n",
    "def visualize_attention(model, img, model_type, layer_idx=-1, head_idx=0):\n",
    "    \"\"\"Visualize attention map for a specific layer and head\"\"\"\n",
    "    attention_maps = get_attention_maps(model, img, model_type)\n",
    "    \n",
    "    if not attention_maps:\n",
    "        print(\"No attention maps retrieved.\")\n",
    "        return\n",
    "    \n",
    "    # Get attention map for the specified layer\n",
    "    if layer_idx < 0:\n",
    "        layer_idx = len(attention_maps) + layer_idx\n",
    "    \n",
    "    if layer_idx >= len(attention_maps):\n",
    "        print(f\"Layer index {layer_idx} out of range. Max: {len(attention_maps)-1}\")\n",
    "        return\n",
    "    \n",
    "    attention = attention_maps[layer_idx][0]  # Get the first batch\n",
    "    \n",
    "    # For ViT/DeiT\n",
    "    if model_type in ['vit', 'deit']:\n",
    "        num_heads = attention.shape[0]\n",
    "        if head_idx >= num_heads:\n",
    "            print(f\"Head index {head_idx} out of range. Max: {num_heads-1}\")\n",
    "            return\n",
    "        \n",
    "        # Get attention for specific head\n",
    "        attn_map = attention[head_idx].numpy()\n",
    "        \n",
    "        # The first row corresponds to the [CLS] token's attention to all other tokens\n",
    "        cls_attn = attn_map[0, 1:]  # Skip the attention to [CLS] itself\n",
    "        \n",
    "        # Reshape to image size for visualization\n",
    "        size = int(np.sqrt(len(cls_attn)))\n",
    "        cls_attn = cls_attn.reshape(size, size)\n",
    "        \n",
    "        # Visualize\n",
    "        plt.figure(figsize=(12, 5))\n",
    "        \n",
    "        plt.subplot(1, 2, 1)\n",
    "        # Denormalize image for visualization\n",
    "        img_np = img.permute(1, 2, 0).numpy()\n",
    "        img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])\n",
    "        img_np = np.clip(img_np, 0, 1)\n",
    "        plt.imshow(img_np)\n",
    "        plt.title(\"Input Image\")\n",
    "        plt.axis('off')\n",
    "        \n",
    "        plt.subplot(1, 2, 2)\n",
    "        plt.imshow(cls_attn, cmap='viridis')\n",
    "        plt.title(f\"Layer {layer_idx}, Head {head_idx} - CLS Token Attention\")\n",
    "        plt.colorbar(format='%.2f')\n",
    "        \n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "        \n",
    "        return cls_attn\n",
    "    else:\n",
    "        print(f\"Attention visualization for {model_type} is not implemented.\")\n",
    "        return None\n",
    "\n",
    "# Visualize attention maps if we have sample images and transformer models\n",
    "if dataset is not None and len(models) > 0:\n",
    "    # Get a sample image\n",
    "    sample_idx = 0  # Choose an index from the sample batch\n",
    "    sample_img = sample_images[sample_idx]\n",
    "    sample_label = sample_labels[sample_idx]\n",
    "    \n",
    "    print(f\"Visualizing attention for a {'real' if sample_label == 0 else 'fake'} sample\")\n",
    "    \n",
    "    # Visualize for each transformer model\n",
    "    for name, model in models.items():\n",
    "        if name in ['vit', 'deit']:  # Currently implemented for ViT and DeiT\n",
    "            print(f\"\\nAttention maps for {name.upper()}:\")\n",
    "            \n",
    "            # Visualize last layer attention\n",
    "            print(\"Last layer attention:\")\n",
    "            visualize_attention(model, sample_img, name, layer_idx=-1, head_idx=0)\n",
    "            \n",
    "            # Visualize middle layer attention\n",
    "            middle_layer = len(model.blocks) // 2\n",
    "            print(f\"Middle layer ({middle_layer}) attention:\")\n",
    "            visualize_attention(model, sample_img, name, layer_idx=middle_layer, head_idx=0)\n",
    "            \n",
    "            # Visualize first layer attention\n",
    "            print(\"First layer attention:\")\n",
    "            visualize_attention(model, sample_img, name, layer_idx=0, head_idx=0)\n",
    "else:\n",
    "    print(\"Cannot visualize attention maps without sample images and transformer models.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Grad-CAM Visualization\n",
    "\n",
    "Use Grad-CAM to visualize the regions of the image that are important for the model's decision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "class GradCAM:\n",
    "    \"\"\"Grad-CAM implementation for CNN-based models\"\"\"\n",
    "    \n",
    "    def __init__(self, model, target_layer):\n",
    "        self.model = model\n",
    "        self.target_layer = target_layer\n",
    "        self.gradients = None\n",
    "        self.activations = None\n",
    "        \n",
    "        # Register hooks\n",
    "        self.register_hooks()\n",
    "    \n",
    "    def register_hooks(self):\n",
    "        def forward_hook(module, input, output):\n",
    "            self.activations = output.detach()\n",
    "        \n",
    "        def backward_hook(module, grad_input, grad_output):\n",
    "            self.gradients = grad_output[0].detach()\n",
    "        \n",
    "        # Register hooks\n",
    "        self.target_layer.register_forward_hook(forward_hook)\n",
    "        self.target_layer.register_backward_hook(backward_hook)\n",
    "    \n",
    "    def __call__(self, x, class_idx=None):\n",
    "        # Forward pass\n",
    "        b, c, h, w = x.size()\n",
    "        logits = self.model(x)\n",
    "        \n",
    "        # If class_idx is None, use the model's prediction\n",
    "        if class_idx is None:\n",
    "            if logits.dim() > 1:\n",
    "                class_idx = torch.argmax(logits, dim=1).item()\n",
    "            else:\n",
    "                class_idx = (logits > 0.5).long().item()\n",
    "        \n",
    "        # Backward pass\n",
    "        self.model.zero_grad()\n",
    "        if logits.dim() > 1:\n",
    "            target = torch.zeros_like(logits)\n",
    "            target[0, class_idx] = 1\n",
    "        else:\n",
    "            target = torch.ones_like(logits) if class_idx == 1 else torch.zeros_like(logits)\n",
    "        \n",
    "        logits.backward(gradient=target, retain_graph=True)\n",

In [None]:
def visualize_activation_statistics(statistics, labels):
    """Plot per-class summaries of the activation statistics of each layer.

    For every layer and statistic this draws a grouped bar chart of the class-wise
    mean of the first 20 feature columns, followed by a box plot of the per-sample
    mean value split by class.

    Args:
        statistics: Mapping ``layer name -> {statistic name -> array}`` with one
            row per sample in every array.
        labels: Per-sample class ids; 0 is shown as 'Real' and 1 as 'Fake'.
    """
    # Imported here because pandas is not among the notebook's top-level imports
    import pandas as pd

    class_names = ['Real', 'Fake']
    # Boolean masks such as `labels == label` need an array, not a list
    labels = np.asarray(labels)
    label_names = [class_names[int(label)] for label in labels]

    for layer_name, layer_stats in statistics.items():
        print(f"\nActivation Statistics for {layer_name}:")

        for stat_name, stat_values in layer_stats.items():
            stat_values = np.asarray(stat_values)
            if stat_values.ndim == 1:
                # One value per sample: treat it as a single feature column
                stat_values = stat_values[:, np.newaxis]

            # Class-wise mean of every feature column
            feature_means = {}
            for label in np.unique(labels):
                class_name = class_names[int(label)]
                feature_means[class_name] = np.mean(stat_values[labels == label], axis=0)

            # Show at most the first 20 features
            num_features = min(20, stat_values.shape[1])
            x = np.arange(num_features)
            width = 0.35

            # Only plot classes that occur in `labels`; with both present the bars
            # sit at x - width/2 (Real) and x + width/2 (Fake)
            present = [name for name in class_names if name in feature_means]
            fig, ax = plt.subplots(figsize=(15, 6))
            for position, class_name in enumerate(present):
                offset = (position - (len(present) - 1) / 2) * width
                ax.bar(x + offset, feature_means[class_name][:num_features], width, label=class_name)

            ax.set_xlabel('Feature Index')
            ax.set_ylabel(f'Mean {stat_name}')
            ax.set_title(f'Feature-wise Mean {stat_name} - {layer_name}')
            ax.set_xticks(x)
            ax.set_xticklabels([str(i) for i in range(num_features)])
            ax.legend()

            plt.tight_layout()
            plt.show()

            # Distribution of the per-sample mean, split by class
            plt.figure(figsize=(12, 6))
            sns.boxplot(x='Class', y='Value', data=pd.DataFrame({
                'Class': label_names,
                'Value': np.mean(stat_values, axis=1)
            }))
            plt.title(f'Distribution of Mean {stat_name} - {layer_name}')
            plt.show()

# Function to visualize feature representations
def visualize_feature_space(model, dataloader, method='tsne'):
    \"\"\"Visualize feature space using dimensionality reduction\"\"\"
    # Dictionary to store features
    features_dict = {
        'embeddings': [],
        'labels': []
    }
    
    # Extract features
    model.eval()
    with torch.no_grad():
        for imgs, lbls in tqdm(dataloader, desc=\"Extracting features\"):
            # Move to device
            imgs = imgs.to(device)
            
            # Get embeddings
            if hasattr(model, 'extract_features'):
                # Use extract_features method if available
                features = model.extract_features(imgs)
            elif isinstance(model, (ViT, DeiT)):
                # For ViT/DeiT, use the output of forward_features
                features = model.forward_features(imgs)
                if isinstance(features, tuple):
                    features = features[0]  # For DeiT, take the class token features
            elif isinstance(model, SwinTransformer):
                # For Swin, use the output of forward_features
                features = model.forward_features(imgs)
            else:
                # Fall back to a simple forward pass and assume


def _summarize_activation(activation):
    """Reduce one batch of activations to per-sample (mean, std, min, max) numpy arrays."""
    if activation.dim() > 2:
        if activation.dim() == 3:
            # NOTE(review): assumes 3-D activations are (batch, tokens, features),
            # so the reduction runs over the token axis - confirm for these models.
            flat = activation.transpose(1, 2)
        else:
            # (batch, channels, spatial...): merge all spatial axes into one
            flat = activation.flatten(2)
        summary = (
            flat.mean(dim=2),
            flat.std(dim=2),
            flat.min(dim=2)[0],
            flat.max(dim=2)[0],
        )
    else:
        # 2-D activations: reduce over the second axis
        summary = (
            activation.mean(dim=1),
            activation.std(dim=1),
            activation.min(dim=1)[0],
            activation.max(dim=1)[0],
        )
    return tuple(values.cpu().numpy() for values in summary)


def get_activation_statistics(model, dataloader, layer_name=None):
    """Collect per-sample activation statistics from selected layers of ``model``.

    Args:
        model: Model to run. When ``layer_name`` is ``None`` the first and last entry
            of ``model.blocks`` (or of ``model.layers`` if there is no ``blocks``)
            are observed.
        dataloader: Iterable of ``(images, labels)`` batches.
        layer_name: Optional name resolved through ``get_layer_by_name``; when it
            cannot be resolved no layer is observed and the statistics stay empty.

    Returns:
        ``(statistics, labels)``: ``statistics`` maps each observed layer name to a
        dict with 'mean', 'std', 'min' and 'max' numpy arrays (one row per sample),
        and ``labels`` is a numpy array of all labels seen.
    """
    model.eval()

    # Latest output of each observed layer, overwritten on every forward pass
    activations = {}

    def hook_fn(name):
        def hook(module, input, output):
            # Some modules return a tuple; the first element is taken as the activation
            tensor = output[0] if isinstance(output, (tuple, list)) else output
            activations[name] = tensor.detach().clone()
        return hook

    # Decide which layers to observe
    targets = {}
    if layer_name is None:
        if hasattr(model, 'blocks'):  # ViT/DeiT
            targets = {'first_block': model.blocks[0], 'last_block': model.blocks[-1]}
        elif hasattr(model, 'layers'):  # Swin
            targets = {'first_layer': model.layers[0], 'last_layer': model.layers[-1]}
    else:
        layer = get_layer_by_name(model, layer_name)
        if layer is not None:
            targets = {layer_name: layer}

    statistics = {}
    labels = []
    hooks = []
    stat_names = ('mean', 'std', 'min', 'max')

    try:
        for name, layer in targets.items():
            hooks.append(layer.register_forward_hook(hook_fn(name)))

        with torch.no_grad():
            for imgs, lbls in tqdm(dataloader, desc="Collecting activations"):
                imgs = imgs.to(device)

                # The forward pass fills `activations` through the hooks
                model(imgs)
                labels.extend(lbls.cpu().numpy())

                for name, activation in activations.items():
                    layer_stats = statistics.setdefault(name, {key: [] for key in stat_names})
                    for key, values in zip(stat_names, _summarize_activation(activation)):
                        layer_stats[key].extend(values)
    finally:
        # Detach the hooks even if a batch fails, so they do not stay on the model
        for hook in hooks:
            hook.remove()

    # Convert the accumulated lists to numpy arrays
    for name in statistics:
        for stat_name in statistics[name]:
            statistics[name][stat_name] = np.array(statistics[name][stat_name])

    return statistics, np.array(labels)

def get_layer_by_name(model, name):
    """Resolve a symbolic layer name to a sub-module of ``model``.

    Args:
        model: Object that may expose a ``blocks`` and/or a ``layers`` sequence.
        name: One of ``'first_block'``, ``'last_block'``, ``'first_layer'`` or
            ``'last_layer'``.

    Returns:
        The first/last element of ``model.blocks`` or ``model.layers``.
        ``None`` (after printing a message) when ``name`` is not one of the
        names above or the model lacks the attribute that name refers to.
    """
    if name == 'first_block' and hasattr(model, 'blocks'):
        return model.blocks[0]
    elif name == 'last_block' and hasattr(model, 'blocks'):
        return model.blocks[-1]
    elif name == 'first_layer' and hasattr(model, 'layers'):
        return model.layers[0]
    elif name == 'last_layer' and hasattr(model, 'layers'):
        return model.layers[-1]
    else:
        # Also reached for a known name when the model has no matching attribute.
        print(f"Unknown layer name: {name}")
        return None


def find_target_layer(model, model_type):
    """Find the target layer for Grad-CAM based on model type.

    Args:
        model: Model exposing ``blocks`` (for 'vit'/'deit') or ``layers``
            (for 'swin').
        model_type: ``'vit'``, ``'deit'`` or ``'swin'``.

    Returns:
        ``model.blocks[-1]`` for 'vit' and 'deit', ``model.layers[-1]`` for 'swin'.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type in ('vit', 'deit'):
        # For ViT and DeiT, we use the output of the last transformer block
        return model.blocks[-1]
    elif model_type == 'swin':
        # For Swin, we use the output of the last layer
        return model.layers[-1]
    else:
        raise ValueError(f"Unknown model type: {model_type}")

# Apply Grad-CAM to our sample images
if dataset is not None and len(models) > 0:
    print(\"\\n\\nGrad-CAM Visualization:\")
    
    # Try for different samples (real and fake)
    for i in range(min(4, len(sample_images))):
        sample_img = sample_images[i]
        sample_label = sample_labels[i]
        
        print(f\"\\nGrad-CAM for {'real' if sample_label == 0 else 'fake'} sample {i+1}:\")
        
        for name, model in models.items():
            print(f\"\\n{name.upper()} model:\")
            try:
                # Find target layer
                target_layer = find_target_layer(model, name)
                
                # Visualize Grad-CAM
                visualize_gradcam(model, sample_img, target_layer, class_idx=sample_label)
            except Exception as e:
                print(f\"Error applying Grad-CAM to {name} model: {e}\")
else:
    print(\"Cannot visualize Grad-CAM without sample images and models.\"){
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Model Visualization\n",
    "\n",
    "This notebook provides visualizations and interpretability tools for the deepfake detection models. It includes:\n",
    "\n",
    "1. Model architecture visualization\n",
    "2. Attention map visualization\n",
    "3. Grad-CAM analysis\n",
    "4. Feature visualization\n",
    "5. Comparison of model behaviors\n",
    "\n",
    "These visualizations help understand how the models are making decisions and what features they are focusing on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Import necessary libraries\n",
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from PIL import Image\n",
    "import cv2\n",
    "from tqdm.notebook import tqdm\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Add parent directory to path to enable imports from project\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "\n",
    "# Import project modules\n",
    "from models.vit.model import ViT\n",
    "from models.deit.model import DeiT\n",
    "from models.swin.model import SwinTransformer\n",
    "from models.model_zoo.model_factory import create_model\n",
    "from data.datasets.faceforensics import FaceForensicsDataset\n",
    "from data.datasets.celebdf import CelebDFDataset\n",
    "from evaluation.visualization.attention_maps import visualize_attention_maps\n",
    "from evaluation.visualization.grad_cam import visualize_grad_cam\n",
    "from evaluation.visualization.feature_visualization import visualize_features\n",
    "\n",
    "# Set plot style\n",
    "plt.style.use('fivethirtyeight')\n",
    "sns.set(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Pretrained Models\n",
    "\n",
    "Load pretrained models for visualization. You need to have trained models saved in checkpoint format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Configure paths - update these to your checkpoint paths\n",
    "CHECKPOINT_DIR = \"../trained_models\"\n",
    "VIT_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"vit_celebdf/checkpoints/best.pth\")\n",
    "DEIT_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"deit_celebdf/checkpoints/best.pth\")\n",
    "SWIN_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"swin_celebdf/checkpoints/best.pth\")\n",
    "\n",
    "# Check if checkpoints exist\n",
    "vit_exists = os.path.exists(VIT_CHECKPOINT)\n",
    "deit_exists = os.path.exists(DEIT_CHECKPOINT)\n",
    "swin_exists = os.path.exists(SWIN_CHECKPOINT)\n",
    "\n",
    "print(f\"ViT checkpoint exists: {vit_exists}\")\n",
    "print(f\"DeiT checkpoint exists: {deit_exists}\")\n",
    "print(f\"Swin checkpoint exists: {swin_exists}\")\n",
    "\n",
    "# Set device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_checkpoint(model, checkpoint_path, device):\n",
    "    \"\"\"Load model from checkpoint\"\"\"\n",
    "    if not os.path.exists(checkpoint_path):\n",
    "        print(f\"Checkpoint not found at {checkpoint_path}\")\n",
    "        return None\n",
    "    \n",
    "    try:\n",
    "        checkpoint = torch.load(checkpoint_path, map_location=device)\n",
    "        \n",
    "        # Different checkpoint formats\n",
    "        if 'model' in checkpoint:\n",
    "            model.load_state_dict(checkpoint['model'])\n",
    "        elif 'model_state_dict' in checkpoint:\n",
    "            model.load_state_dict(checkpoint['model_state_dict'])\n",
    "        else:\n",
    "            model.load_state_dict(checkpoint)\n",
    "            \n",
    "        model = model.to(device)\n",
    "        model.eval()  # Set to evaluation mode\n",
    "        print(f\"Model loaded successfully from {checkpoint_path}\")\n",
    "        return model\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading checkpoint: {e}\")\n",
    "        return None\n",
    "\n",
    "# Initialize models\n",
    "models = {}\n",
    "\n",
    "# Load ViT model\n",
    "if vit_exists:\n",
    "    vit_model = ViT(\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=768,\n",
    "        depth=12,\n",
    "        num_heads=12\n",
    "    )\n",
    "    vit_model = load_checkpoint(vit_model, VIT_CHECKPOINT, device)\n",
    "    if vit_model is not None:\n",
    "        models['vit'] = vit_model\n",
    "\n",
    "# Load DeiT model\n",
    "if deit_exists:\n",
    "    deit_model = DeiT(\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=768,\n",
    "        depth=12,\n",
    "        num_heads=12,\n",
    "        distillation=True\n",
    "    )\n",
    "    deit_model = load_checkpoint(deit_model, DEIT_CHECKPOINT, device)\n",
    "    if deit_model is not None:\n",
    "        models['deit'] = deit_model\n",
    "\n",
    "# Load Swin model\n",
    "if swin_exists:\n",
    "    swin_model = SwinTransformer(\n",
    "        img_size=224,\n",
    "        patch_size=4,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=96,\n",
    "        depths=[2, 2, 6, 2],\n",
    "        num_heads=[3, 6, 12, 24],\n",
    "        window_size=7\n",
    "    )\n",
    "    swin_model = load_checkpoint(swin_model, SWIN_CHECKPOINT, device)\n",
    "    if swin_model is not None:\n",
    "        models['swin'] = swin_model\n",
    "\n",
    "print(f\"Loaded {len(models)} models: {list(models.keys())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load Sample Data\n",
    "\n",
    "Load some sample images to visualize the model behavior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Configure dataset paths - update these to your local paths\n",
    "FACEFORENSICS_ROOT = \"/path/to/datasets/FaceForensics\"\n",
    "CELEBDF_ROOT = \"/path/to/datasets/CelebDF\"\n",
    "\n",
    "# Choose one dataset to use for visualization\n",
    "DATASET_ROOT = CELEBDF_ROOT  # Change to the dataset you want to use\n",
    "DATASET_NAME = \"celebdf\"     # Change to match your dataset (\"faceforensics\" or \"celebdf\")\n",
    "\n",
    "# Check if directory exists\n",
    "dataset_exists = os.path.exists(DATASET_ROOT)\n",
    "print(f\"Dataset path exists: {dataset_exists}\")\n",
    "\n",
    "# Define transform\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "# Load dataset if it exists\n",
    "if dataset_exists:\n",
    "    if DATASET_NAME == \"faceforensics\":\n",
    "        dataset = FaceForensicsDataset(\n",
    "            root=DATASET_ROOT,\n",
    "            split=\"test\",  # Use test split for visualization\n",
    "            img_size=224,\n",
    "            transform=transform\n",
    "        )\n",
    "    elif DATASET_NAME == \"celebdf\":\n",
    "        dataset = CelebDFDataset(\n",
    "            root=DATASET_ROOT,\n",
    "            split=\"test\",  # Use test split for visualization\n",
    "            img_size=224,\n",
    "            transform=transform\n",
    "        )\n",
    "    else:\n",
    "        print(f\"Unknown dataset name: {DATASET_NAME}\")\n",
    "        dataset = None\n",
    "        \n",
    "    if dataset is not None:\n",
    "        print(f\"Dataset loaded with {len(dataset)} samples\")\n",
    "else:\n",
    "    print(\"Dataset path not found. Cannot load samples.\")\n",
    "    dataset = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def get_sample_batch(dataset, batch_size=4):\n",
    "    \"\"\"Get a sample batch with equal number of real and fake samples\"\"\"\n",
    "    if dataset is None:\n",
    "        return None, None\n",
    "    \n",
    "    dataloader = DataLoader(dataset, batch_size=batch_size*10, shuffle=True)\n",
    "    images, labels = next(iter(dataloader))\n",
    "    \n",
    "    # Separate real and fake\n",
    "    real_imgs = [img for img, label in zip(images, labels) if label == 0]\n",
    "    fake_imgs = [img for img, label in zip(images, labels) if label == 1]\n",
    "    \n",
    "    # Make sure we have enough samples\n",
    "    half_batch = batch_size // 2\n",
    "    if len(real_imgs) < half_batch or len(fake_imgs) < half_batch:\n",
    "        print(\"Not enough samples of each class. Getting more...\")\n",
    "        return get_sample_batch(dataset, batch_size)\n",
    "    \n",
    "    # Get equal number of each\n",
    "    real_imgs = real_imgs[:half_batch]\n",
    "    fake_imgs = fake_imgs[:half_batch]\n",
    "    \n",
    "    # Combine\n",
    "    sample_imgs = real_imgs + fake_imgs\n",
    "    sample_labels = [0] * half_batch + [1] * half_batch\n",
    "    \n",
    "    return torch.stack(sample_imgs), torch.tensor(sample_labels)\n",
    "\n",
    "# Get sample batch if dataset exists\n",
    "if dataset is not None:\n",
    "    sample_images, sample_labels = get_sample_batch(dataset, batch_size=8)\n",
    "    print(f\"Sample batch shape: {sample_images.shape}\")\n",
    "    print(f\"Sample labels: {sample_labels}\")\n",
    "    \n",
    "    # Visualize samples\n",
    "    plt.figure(figsize=(15, 8))\n",
    "    for i in range(len(sample_images)):\n",
    "        plt.subplot(2, 4, i+1)\n",
    "        img = sample_images[i].permute(1, 2, 0).numpy()\n",
    "        # Denormalize for visualization\n",
    "        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])\n",
    "        img = np.clip(img, 0, 1)\n",
    "        plt.imshow(img)\n",
    "        plt.title(\"Real\" if sample_labels[i] == 0 else \"Fake\")\n",
    "        plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"No dataset available. Skipping sample visualization.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Model Architecture Visualization\n",
    "\n",
    "Visualize the architecture of the models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def count_parameters(model):\n",
    "    \"\"\"Count number of trainable parameters\"\"\"\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "def print_model_summary(model, name):\n",
    "    \"\"\"Print model summary\"\"\"\n",
    "    print(f\"\\n{name} Model Summary:\")\n",
    "    print(f\"Total parameters: {count_parameters(model):,}\")\n",
    "    \n",
    "    # Print main component information\n",
    "    if name.lower() == 'vit':\n",
    "        print(f\"Image size: {model.img_size}\")\n",
    "        print(f\"Patch size: {model.patch_size}\")\n",
    "        print(f\"Number of patches: {model.num_patches}\")\n",
    "        print(f\"Embedding dimension: {model.embed_dim}\")\n",
    "        print(f\"Number of transformer blocks: {len(model.blocks)}\")\n",
    "        print(f\"Number of attention heads: {model.blocks[0].attn.num_heads}\")\n",
    "    elif name.lower() == 'deit':\n",
    "        print(f\"Image size: {model.img_size}\")\n",
    "        print(f\"Patch size: {model.patch_size}\")\n",
    "        print(f\"Number of patches: {model.num_patches}\")\n",
    "        print(f\"Embedding dimension: {model.embed_dim}\")\n",
    "        print(f\"Number of transformer blocks: {len(model.blocks)}\")\n",
    "        print(f\"Number of attention heads: {model.blocks[0].attn.num_heads}\")\n",
    "        print(f\"Using distillation: {model.distillation is not None}\")\n",
    "    elif name.lower() == 'swin':\n",
    "        print(f\"Image size: {model.img_size}\")\n",
    "        print(f\"Patch size: {model.patch_size}\")\n",
    "        print(f\"Embedding dimension: {model.embed_dim}\")\n",
    "        print(f\"Number of layers: {model.num_layers}\")\n",
    "        print(f\"Depths: {model.depths}\")\n",
    "\n",
    "# Visualize model architectures\n",
    "for name, model in models.items():\n",
    "    print_model_summary(model, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Try to generate a visual diagram of the model architecture\n",
    "try:\n",
    "    from torchviz import make_dot\n",
    "    from torch.autograd import Variable\n",
    "    \n",
    "    # Function to create model diagram\n",
    "    def visualize_model_graph(model, name):\n",
    "        # Create a sample input\n",
    "        x = Variable(torch.randn(1, 3, 224, 224)).to(device)\n",
    "        \n",
    "        # Generate output\n",
    "        y = model(x)\n",
    "        \n",
    "        # Create dot graph\n",
    "        dot = make_dot(y, params=dict(list(model.named_parameters())))\n",
    "        \n",
    "        # Save and display\n",
    "        dot.format = 'png'\n",
    "        dot.render(f\"{name}_architecture\", cleanup=True)\n",
    "        \n",
    "        # Display\n",
    "        from IPython.display import Image\n",
    "        return Image(filename=f\"{name}_architecture.png\")\n",
    "    \n",
    "    # Visualize each model\n",
    "    for name, model in models.items():\n",
    "        print(f\"\\nGenerating architecture diagram for {name}...\")\n",
    "        display(visualize_model_graph(model, name))\n",
    "except ImportError:\n",
    "    print(\"torchviz not installed. Install with: pip install torchviz\")\n",
    "    print(\"Also requires graphviz to be installed. See: https://graphviz.org/download/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Attention Map Visualization\n",
    "\n",
    "Visualize attention maps from transformer-based models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def get_attention_maps(model, img, model_type):\n",
    "    \"\"\"Get attention maps for the given model and image\"\"\"\n",
    "    # Move image to device\n",
    "    img = img.unsqueeze(0).to(device)\n",
    "    attention_maps = []\n",
    "    \n",
    "    # Hook function to retrieve attention maps\n",
    "    def hook_fn(module, input, output):\n",
    "        if model_type in ['vit', 'deit']:\n",
    "            # For ViT/DeiT, attention is of shape (batch_size, num_heads, seq_len, seq_len)\n",
    "            attention_maps.append(output.detach().cpu())\n",
    "        elif model_type == 'swin':\n",
    "            # For Swin, attention structure is different\n",
    "            # This is a simplified approach - may need adjustment\n",
    "            attention_maps.append(output.detach().cpu())\n",
    "    \n",
    "    # Register hooks\n",
    "    hooks = []\n",
    "    if model_type in ['vit', 'deit']:\n",
    "        for block in model.blocks:\n",
    "            hooks.append(block.attn.register_forward_hook(hook_fn))\n",
    "    elif model_type == 'swin':\n",
    "        # For Swin, we need to find the window attention modules\n",
    "        for layer in model.layers:\n",
    "            for block in layer.blocks:\n",
    "                hooks.append(block.attn.register_forward_hook(hook_fn))\n",
    "    \n",
    "    # Forward pass\n",
    "    with torch.no_grad():\n",
    "        model(img)\n",
    "    \n",
    "    # Remove hooks\n",
    "    for hook in hooks:\n",
    "        hook.remove()\n",
    "    \n",
    "    return attention_maps\n",
    "\n",
    "def visualize_attention(model, img, model_type, layer_idx=-1, head_idx=0):\n",
    "    \"\"\"Visualize attention map for a specific layer and head\"\"\"\n",
    "    attention_maps = get_attention_maps(model, img, model_type)\n",
    "    \n",
    "    if not attention_maps:\n",
    "        print(\"No attention maps retrieved.\")\n",
    "        return\n",
    "    \n",
    "    # Get attention map for the specified layer\n",
    "    if layer_idx < 0:\n",
    "        layer_idx = len(attention_maps) + layer_idx\n",
    "    \n",
    "    if layer_idx >= len(attention_maps):\n",
    "        print(f\"Layer index {layer_idx} out of range. Max: {len(attention_maps)-1}\")\n",
    "        return\n",
    "    \n",
    "    attention = attention_maps[layer_idx][0]  # Get the first batch\n",
    "    \n",
    "    # For ViT/DeiT\n",
    "    if model_type in ['vit', 'deit']:\n",
    "        num_heads = attention.shape[0]\n",
    "        if head_idx >= num_heads:\n",
    "            print(f\"Head index {head_idx} out of range. Max: {num_heads-1}\")\n",
    "            return\n",
    "        \n",
    "        # Get attention for specific head\n",
    "        attn_map = attention[head_idx].numpy()\n",
    "        \n",
    "        # The first row corresponds to the [CLS] token's attention to all other tokens\n",
    "        cls_attn = attn_map[0, 1:]  # Skip the attention to [CLS] itself\n",
    "        \n",
    "        # Reshape to image size for visualization\n",
    "        size = int(np.sqrt(len(cls_attn)))\n",
    "        cls_attn = cls_attn.reshape(size, size)\n",
    "        \n",
    "        # Visualize\n",
    "        plt.figure(figsize=(12, 5))\n",
    "        \n",
    "        plt.subplot(1, 2, 1)\n",
    "        # Denormalize image for visualization\n",
    "        img_np = img.permute(1, 2, 0).numpy()\n",
    "        img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])\n",
    "        img_np = np.clip(img_np, 0, 1)\n",
    "        plt.imshow(img_np)\n",
    "        plt.title(\"Input Image\")\n",
    "        plt.axis('off')\n",
    "        \n",
    "        plt.subplot(1, 2, 2)\n",
    "        plt.imshow(cls_attn, cmap='viridis')\n",
    "        plt.title(f\"Layer {layer_idx}, Head {head_idx} - CLS Token Attention\")\n",
    "        plt.colorbar(format='%.2f')\n",
    "        \n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "        \n",
    "        return cls_attn\n",
    "    else:\n",
    "        print(f\"Attention visualization for {model_type} is not implemented.\")\n",
    "        return None\n",
    "\n",
    "# Visualize attention maps if we have sample images and transformer models\n",
    "if dataset is not None and len(models) > 0:\n",
    "    # Get a sample image\n",
    "    sample_idx = 0  # Choose an index from the sample batch\n",
    "    sample_img = sample_images[sample_idx]\n",
    "    sample_label = sample_labels[sample_idx]\n",
    "    \n",
    "    print(f\"Visualizing attention for a {'real' if sample_label == 0 else 'fake'} sample\")\n",
    "    \n",
    "    # Visualize for each transformer model\n",
    "    for name, model in models.items():\n",
    "        if name in ['vit', 'deit']:  # Currently implemented for ViT and DeiT\n",
    "            print(f\"\\nAttention maps for {name.upper()}:\")\n",
    "            \n",
    "            # Visualize last layer attention\n",
    "            print(\"Last layer attention:\")\n",
    "            visualize_attention(model, sample_img, name, layer_idx=-1, head_idx=0)\n",
    "            \n",
    "            # Visualize middle layer attention\n",
    "            middle_layer = len(model.blocks) // 2\n",
    "            print(f\"Middle layer ({middle_layer}) attention:\")\n",
    "            visualize_attention(model, sample_img, name, layer_idx=middle_layer, head_idx=0)\n",
    "            \n",
    "            # Visualize first layer attention\n",
    "            print(\"First layer attention:\")\n",
    "            visualize_attention(model, sample_img, name, layer_idx=0, head_idx=0)\n",
    "else:\n",
    "    print(\"Cannot visualize attention maps without sample images and transformer models.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Grad-CAM Visualization\n",
    "\n",
    "Use Grad-CAM to visualize the regions of the image that are important for the model's decision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "class GradCAM:\n",
    "    \"\"\"Grad-CAM implementation for CNN-based models\"\"\"\n",
    "    \n",
    "    def __init__(self, model, target_layer):\n",
    "        self.model = model\n",
    "        self.target_layer = target_layer\n",
    "        self.gradients = None\n",
    "        self.activations = None\n",
    "        \n",
    "        # Register hooks\n",
    "        self.register_hooks()\n",
    "    \n",
    "    def register_hooks(self):\n",
    "        def forward_hook(module, input, output):\n",
    "            self.activations = output.detach()\n",
    "        \n",
    "        def backward_hook(module, grad_input, grad_output):\n",
    "            self.gradients = grad_output[0].detach()\n",
    "        \n",
    "        # Register hooks\n",
    "        self.target_layer.register_forward_hook(forward_hook)\n",
    "        self.target_layer.register_backward_hook(backward_hook)\n",
    "    \n",
    "    def __call__(self, x, class_idx=None):\n",
    "        # Forward pass\n",
    "        b, c, h, w = x.size()\n",
    "        logits = self.model(x)\n",
    "        \n",
    "        # If class_idx is None, use the model's prediction\n",
    "        if class_idx is None:\n",
    "            if logits.dim() > 1:\n",
    "                class_idx = torch.argmax(logits, dim=1).item()\n",
    "            else:\n",
    "                class_idx = (logits > 0.5).long().item()\n",
    "        \n",
    "        # Backward pass\n",
    "        self.model.zero_grad()\n",
    "        if logits.dim() > 1:\n",
    "            target = torch.zeros_like(logits)\n",
    "            target[0, class_idx] = 1\n",
    "        else:\n",
    "            target = torch.ones_like(logits) if class_idx == 1 else torch.zeros_like(logits)\n",
    "        \n",
    "        logits.backward(gradient=target, retain_graph=True)\n",

In [None]:
## 7. Summary and Insights

def generate_insights_summary(models):
    """Print a summary of insights from the model visualizations.

    Args:
        models: Mapping of model name (e.g. 'vit', 'deit', 'swin') to a model
            object. For each entry the trainable-parameter count is printed via
            ``count_parameters``; 'vit', 'deit' and 'swin' entries additionally
            print architecture details read from the model's attributes.

    Returns:
        None. All output goes to stdout. The "Model Comparison" section is
        printed only when more than one model is supplied.
    """
    print("\n\nSummary of Insights from Model Visualization:\n")

    # 1. Model Architecture Insights
    print("1. Model Architecture Insights:")
    for name, model in models.items():
        print(f"  - {name.upper()} model has {count_parameters(model):,} trainable parameters")

        if name == 'vit':
            print(f"    * Uses a transformer architecture with {len(model.blocks)} blocks")
            print(f"    * Uses {model.blocks[0].attn.num_heads} attention heads")
            print(f"    * Processes images as {model.num_patches} patches of size {model.patch_size}x{model.patch_size}")
        elif name == 'deit':
            print(f"    * Similar to ViT but with {len(model.blocks)} blocks")
            print(f"    * Uses {model.blocks[0].attn.num_heads} attention heads")
            print(f"    * Uses distillation token: {model.distillation is not None}")
        elif name == 'swin':
            print(f"    * Uses hierarchical structure with {model.num_layers} layers")
            print("    * Uses shifted windows for efficient attention")

    # 2. Attention Mechanism Insights (only for the models with a class token)
    print("\n2. Attention Mechanism Insights:")
    for name in models:
        if name in ['vit', 'deit']:
            print(f"  - {name.upper()} model:")
            print("    * Class token attention shows which image regions are important for classification")
            print("    * Early layers capture low-level features, while later layers focus on semantic regions")
            print("    * Different attention heads specialize in different aspects of the image")

    # 3. Grad-CAM Insights
    print("\n3. Grad-CAM Insights:")
    print("  - Highlights regions that influence the model's decision")
    print("  - For real faces, models often focus on natural facial features")
    print("  - For fake faces, models often focus on inconsistencies or artifacts")

    # 4. Feature Space Insights
    print("\n4. Feature Space Insights:")
    print("  - Models learn to separate real and fake samples in the feature space")
    print("  - The degree of separation indicates the model's confidence")
    print("  - Samples close to the decision boundary are more challenging to classify")

    # 5. Model Comparison Insights (needs at least two models to compare)
    if len(models) > 1:
        print("\n5. Model Comparison Insights:")
        print("  - Different models may focus on different aspects of the images")
        print("  - Agreement between models suggests more reliable predictions")
        print("  - Ensemble methods could improve overall performance by combining strengths")

    # 6. Challenging Cases
    print("\n6. Challenging Cases:")
    print("  - High-quality deepfakes may fool models by preserving natural facial features")
    print("  - Poor quality real images may be misclassified due to compression artifacts")
    print("  - Models may struggle with unusual lighting, poses, or facial expressions")

    # 7. Practical Recommendations
    print("\n7. Practical Recommendations:")
    print("  - Use multiple models for more reliable detection")
    print("  - Consider model confidence in the decision-making process")
    print("  - Further training on challenging examples could improve performance")
    print("  - Model interpretability tools can help understand and debug failures")

# Emit the insights summary only when at least one model was loaded
if models:
    generate_insights_summary(models)

# Closing banner, printed one line at a time
for closing_line in (
    "\n\nModel Visualization Notebook Complete!",
    "You can use the above visualizations and analyses to better understand how the deepfake detection models work.",
    "This understanding can help improve model performance, interpretability, and trust in the system.",
):
    print(closing_line)
# Model interpretability analysis
# NOTE: this calls analyze_model_interpretability, whose definition appears
# further down in this file; that definition must be executed before this code.
if dataset is not None and len(models) > 0:
    print("\n\nDetailed Model Interpretability Analysis:")

    # Collect up to two real (label 0) and two fake (label 1) samples,
    # remembering each sample's index in the batch.
    real_samples = []
    fake_samples = []

    for i in range(len(sample_images)):
        # Convert once to a plain int instead of comparing 0-d tensors repeatedly
        label = int(sample_labels[i])
        if label == 0 and len(real_samples) < 2:
            real_samples.append((sample_images[i], i))
        elif label == 1 and len(fake_samples) < 2:
            fake_samples.append((sample_images[i], i))

        # Stop scanning as soon as both quotas are filled
        if len(real_samples) >= 2 and len(fake_samples) >= 2:
            break

    # Combine samples (real first, then fake)
    selected_samples = real_samples + fake_samples

    # Analyze each sample with each model
    for img, idx in selected_samples:
        label = int(sample_labels[idx])
        label_name = "Real" if label == 0 else "Fake"
        print(f"\n\nAnalyzing {label_name} sample (index {idx}):")

        # Show image
        plt.figure(figsize=(6, 6))
        img_np = img.permute(1, 2, 0).numpy()
        # Undo the mean/std normalization so the image displays with natural colours
        img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img_np = np.clip(img_np, 0, 1)
        plt.imshow(img_np)
        plt.title(f"Sample Image (Ground Truth: {label_name})")
        plt.axis('off')
        plt.show()

        # Analyze with each model; the dictionary key doubles as the model type
        for name, model in models.items():
            analyze_model_interpretability(model, img, name)
else:
    print("Cannot perform interpretability analysis without sample images and models.")

## 6. Model Interpretability Analysis

def analyze_model_interpretability(model, img, model_type):
    """Run prediction, attention, Grad-CAM and patch-importance analysis on one image.

    Args:
        model: Network to analyse. It is switched to eval mode and called on a
            single-image batch.
        img: Image tensor without a batch dimension; a batch axis is added with
            ``unsqueeze(0)`` before the forward pass.
        model_type: Model family key (``'vit'``, ``'deit'``, ``'swin'``). The
            attention-based sections only run for ``'vit'`` and ``'deit'``.

    Returns:
        tuple: ``(pred, prob)`` where ``pred`` is the predicted class index
        (1 is reported as "Fake", anything else as "Real") and ``prob`` is the
        probability assigned to the fake class.
    """
    print(f"\nInterpretability Analysis for {model_type.upper()} model:")

    # Add the batch axis and move the image to the active device
    img_tensor = img.unsqueeze(0).to(device)

    # Get model prediction
    model.eval()
    with torch.no_grad():
        output = model(img_tensor)

        if output.dim() > 1 and output.shape[1] > 1:  # Multi-class logits
            class_probs = torch.softmax(output, dim=1)
            prob = class_probs[0, 1].item()  # Probability of fake class
            pred = output.argmax(dim=1).item()
            confidence = class_probs[0, pred].item()
        else:  # Single-logit binary output
            prob = torch.sigmoid(output).item()
            pred = int(prob > 0.5)
            # Confidence refers to the predicted class, not always to "fake"
            confidence = prob if pred == 1 else 1.0 - prob

    pred_class = "Fake" if pred == 1 else "Real"
    print(
        f"Model prediction: {pred_class} with confidence {confidence:.4f} "
        f"(P(fake) = {prob:.4f})"
    )

    # 1. Attention Map Visualization
    if model_type in ['vit', 'deit']:
        print("\n1. Attention Map Visualization:")
        visualize_attention(model, img, model_type, layer_idx=-1, head_idx=0)

    # 2. Grad-CAM Visualization
    print("\n2. Grad-CAM Visualization:")
    try:
        target_layer = find_target_layer(model, model_type)
        visualize_gradcam(model, img, target_layer, class_idx=pred)
    except Exception as e:
        # Grad-CAM is best-effort: report the failure and continue the report
        print(f"Error applying Grad-CAM: {e}")

    # 3. Feature Importance Analysis
    print("\n3. Feature Importance Analysis:")
    if model_type in ['vit', 'deit']:
        # For transformer models, analyze attention to different patches
        attention_maps = get_attention_maps(model, img, model_type)

        if attention_maps:
            # Last layer, first batch element; presumably (heads, tokens, tokens)
            # — TODO confirm against get_attention_maps
            attention = attention_maps[-1][0]

            # Average over heads and keep the row of the first (CLS) token.
            # detach/cpu make the NumPy conversion safe on any device.
            cls_row = attention.mean(dim=0)[0].detach().cpu()

            # The patch grid is the largest square that fits in the non-CLS
            # tokens. Taking the trailing size*size entries also skips any extra
            # prefix token (e.g. a distillation token), assuming prefix tokens
            # precede the patch tokens — TODO confirm for the DeiT model.
            size = int(np.sqrt(cls_row.numel() - 1))
            if size > 0:
                num_patches = size * size
                patch_attention = cls_row[-num_patches:].numpy()
                attn_grid = patch_attention.reshape(size, size)

                # Visualize as heatmap
                plt.figure(figsize=(8, 6))
                plt.imshow(attn_grid, cmap='viridis')
                plt.colorbar(format='%.2f')
                plt.title(f"{model_type.upper()} Patch Importance")
                plt.show()

                # Rank patches from most to least attended
                top_k = min(5, num_patches)
                ranked = np.argsort(patch_attention)[::-1][:top_k]

                print(f"Top {top_k} important patches:")
                for rank, flat_index in enumerate(ranked, start=1):
                    row, col = divmod(int(flat_index), size)
                    print(
                        f"  {rank}. Patch at position ({row}, {col}) "
                        f"with importance {patch_attention[flat_index]:.4f}"
                    )
            else:
                print("No patch tokens found in the attention map.")

    return pred, prob
# Compare model predictions
def compare_model_predictions(models, dataloader):
    """Collect predictions from every model on one dataloader and plot comparisons.

    A 2x2 figure is drawn: distribution of predicted fake-probabilities,
    pairwise prediction agreement, overall accuracy and per-class accuracy.

    Args:
        models: Mapping of model name to model. Every model is put in eval mode
            and called on each image batch.
        dataloader: Iterable yielding ``(images, labels)`` batches.

    Returns:
        tuple: ``(predictions, probabilities, labels)``. ``predictions`` and
        ``probabilities`` are dicts keyed by model name holding NumPy arrays of
        predicted class indices and fake-class probabilities; ``labels`` is the
        NumPy array of ground-truth labels.
    """
    # Imported locally because the notebook's import cell does not provide ``pd``
    import pandas as pd

    predictions = {name: [] for name in models}
    probabilities = {name: [] for name in models}
    labels = []

    # Eval mode only needs to be set once, not once per batch
    for model in models.values():
        model.eval()

    for imgs, lbls in tqdm(dataloader, desc="Getting predictions"):
        imgs = imgs.to(device)
        labels.extend(lbls.cpu().numpy())

        with torch.no_grad():
            for name, model in models.items():
                outputs = model(imgs)

                if outputs.dim() > 1 and outputs.shape[1] > 1:  # Multi-class logits
                    probs = torch.softmax(outputs, dim=1)[:, 1].cpu().numpy()  # Probability of fake class
                    preds = outputs.argmax(dim=1).cpu().numpy()
                else:  # Single-logit binary output
                    probs = torch.sigmoid(outputs).cpu().numpy().flatten()
                    preds = (probs > 0.5).astype(int)

                predictions[name].extend(preds)
                probabilities[name].extend(probs)

    # Convert to numpy arrays
    labels = np.array(labels)
    for name in models:
        predictions[name] = np.array(predictions[name])
        probabilities[name] = np.array(probabilities[name])

    plt.figure(figsize=(15, 10))

    # Panel 1: distribution of predicted probabilities
    plt.subplot(2, 2, 1)
    for name in models:
        sns.kdeplot(probabilities[name], label=name.upper())
    plt.xlabel('Probability of Fake')
    plt.ylabel('Density')
    plt.title('Distribution of Predicted Probabilities')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)

    # Panel 2: pairwise prediction agreement
    plt.subplot(2, 2, 2)
    if len(models) > 1:
        model_names = list(models.keys())
        agreement_matrix = np.zeros((len(model_names), len(model_names)))

        for i, name1 in enumerate(model_names):
            for j, name2 in enumerate(model_names):
                if i == j:
                    agreement_matrix[i, j] = 1.0
                else:
                    agreement_matrix[i, j] = np.mean(predictions[name1] == predictions[name2])

        sns.heatmap(agreement_matrix, annot=True, fmt='.2f',
                    xticklabels=[name.upper() for name in model_names],
                    yticklabels=[name.upper() for name in model_names],
                    cmap='YlGnBu')
        plt.title('Model Prediction Agreement')
    else:
        plt.text(0.5, 0.5, 'Need at least 2 models to compare',
                 ha='center', va='center', fontsize=12)
        plt.axis('off')

    # Panel 3: overall accuracy
    plt.subplot(2, 2, 3)
    accuracies = {name: np.mean(predictions[name] == labels) for name in models}

    plt.bar(range(len(accuracies)), list(accuracies.values()),
            tick_label=[name.upper() for name in accuracies])
    plt.xlabel('Model')
    plt.ylabel('Accuracy')
    plt.title('Model Accuracy Comparison')
    plt.ylim(0.5, 1.0)  # Start from 0.5 for better visual comparison

    # For each bar, add the accuracy value on top
    for i, v in enumerate(accuracies.values()):
        plt.text(i, v + 0.01, f'{v:.2f}', ha='center')

    # Panel 4: per-class accuracy (0 = real, 1 = fake)
    plt.subplot(2, 2, 4)
    per_class_acc = {}
    for name in models:
        correct = predictions[name] == labels
        class_acc = []
        for class_idx in (0, 1):
            mask = labels == class_idx
            # A class with no samples has no defined accuracy; report NaN
            # explicitly instead of averaging an empty array.
            class_acc.append(correct[mask].mean() if mask.any() else float('nan'))
        per_class_acc[name] = class_acc

    # Convert to DataFrame for easier plotting
    df = pd.DataFrame(per_class_acc, index=['Real', 'Fake']).T
    df.plot(kind='bar', ax=plt.gca())
    plt.xlabel('Model')
    plt.ylabel('Accuracy')
    plt.title('Per-class Accuracy Comparison')
    plt.ylim(0.5, 1.0)
    plt.legend(title='Class')

    plt.tight_layout()
    plt.show()

    return predictions, probabilities, labels


# Run the comparison (placed after the function definition so the call resolves)
if dataset is not None and len(models) > 1:
    # Imported locally because scikit-learn is not part of the notebook's import cell
    from sklearn.metrics import confusion_matrix

    print("\n\nComparing Model Predictions:")

    # Evaluate on a random subset of at most 300 samples
    test_dataset = torch.utils.data.Subset(
        dataset,
        indices=np.random.choice(len(dataset),
                                 size=min(300, len(dataset)),
                                 replace=False).tolist(),
    )
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Compare predictions
    predictions, probabilities, labels = compare_model_predictions(models, test_dataloader)

    # Analyze difficult examples
    print("\nAnalyzing Difficult Examples:")

    # Examples on which at least one pair of models disagrees
    disagreement = np.zeros(len(labels), dtype=bool)
    model_names = list(models.keys())
    for i in range(len(model_names)):
        for j in range(i + 1, len(model_names)):
            disagreement |= (predictions[model_names[i]] != predictions[model_names[j]])

    # Examples that every model gets wrong
    all_wrong = np.ones(len(labels), dtype=bool)
    for name in models:
        all_wrong &= (predictions[name] != labels)

    # Print statistics
    n_disagreement = disagreement.sum()
    n_all_wrong = all_wrong.sum()
    print(f"Models disagree on {n_disagreement} examples ({n_disagreement/len(labels)*100:.1f}%)")
    print(f"All models are wrong on {n_all_wrong} examples ({n_all_wrong/len(labels)*100:.1f}%)")

    # Get indices of disagreement and all-wrong examples
    disagreement_indices = np.where(disagreement)[0]
    all_wrong_indices = np.where(all_wrong)[0]

    # Print confusion matrix for each model
    print("\nConfusion Matrices:")
    for name in models:
        # Fixing labels=[0, 1] keeps the matrix 2x2 even when a class is absent,
        # so it always matches the Real/Fake tick labels below.
        cm = confusion_matrix(labels, predictions[name], labels=[0, 1])
        print(f"\n{name.upper()} Confusion Matrix:")
        print(cm)

        # Plot confusion matrix
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Real', 'Fake'], yticklabels=['Real', 'Fake'])
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title(f'{name.upper()} Confusion Matrix')
        plt.show()
else:
    print("Cannot compare model predictions without dataset or with less than 2 models.")

# Feature space visualization
if dataset is not None and len(models) > 0:
    print("\n\nFeature Space Visualization:")

    # Use a random subset of at most 200 samples to keep the projection fast
    vis_dataset = torch.utils.data.Subset(
        dataset,
        indices=np.random.choice(len(dataset),
                                 size=min(200, len(dataset)),
                                 replace=False).tolist(),
    )
    vis_dataloader = DataLoader(vis_dataset, batch_size=32, shuffle=False)

    # NOTE(review): visualize_feature_space is defined further down in this
    # file; if that definition has not been executed yet, the NameError is
    # caught below and only reported as a message.
    for name, model in models.items():
        print(f"\nVisualizing feature space for {name.upper()} model:")
        try:
            visualize_feature_space(model, vis_dataloader, method='tsne')
        except Exception as e:
            print(f"Error visualizing feature space for {name} model: {e}")
else:
    print("Cannot visualize feature space without dataset and models.")


def visualize_activation_statistics(statistics, labels):
    """Plot per-class activation statistics for every layer in ``statistics``.

    For each layer and each statistic two figures are shown: a grouped bar
    chart of the per-class mean of the first (at most 20) features, and a box
    plot of the per-sample mean grouped by class.

    Args:
        statistics: Mapping ``layer name -> {statistic name -> array}`` where
            each array has one row per sample (the structure returned by
            ``get_activation_statistics``).
        labels: Per-sample class labels; 0 is shown as "Real", 1 as "Fake".
    """
    # Imported locally because the notebook's import cell does not provide ``pd``
    import pandas as pd

    class_names = ['Real', 'Fake']
    labels = np.asarray(labels)  # boolean masking below needs an array
    label_names = [class_names[int(label)] for label in labels]
    present_labels = np.unique(labels)

    for layer_name, layer_stats in statistics.items():
        print(f"\nActivation Statistics for {layer_name}:")

        for stat_name, stat_values in layer_stats.items():
            stat_values = np.asarray(stat_values)
            # One value per sample: treat it as a single-feature matrix
            if stat_values.ndim == 1:
                stat_values = stat_values[:, np.newaxis]

            # Per-class mean of every feature, only for classes actually present
            feature_means = {
                class_names[int(label)]: np.mean(stat_values[labels == label], axis=0)
                for label in present_labels
            }

            # Number of features to show
            num_features = min(20, stat_values.shape[1])

            # Grouped bars, centred on each feature index
            x = np.arange(num_features)
            width = 0.35

            fig, ax = plt.subplots(figsize=(15, 6))
            for position, (class_name, means) in enumerate(feature_means.items()):
                offset = (position - (len(feature_means) - 1) / 2) * width
                ax.bar(x + offset, means[:num_features], width, label=class_name)

            ax.set_xlabel('Feature Index')
            ax.set_ylabel(f'Mean {stat_name}')
            ax.set_title(f'Feature-wise Mean {stat_name} - {layer_name}')
            ax.set_xticks(x)
            ax.set_xticklabels([str(i) for i in range(num_features)])
            ax.legend()

            plt.tight_layout()
            plt.show()

            # Plot overall distribution
            plt.figure(figsize=(12, 6))
            sns.boxplot(x='Class', y='Value', data=pd.DataFrame({
                'Class': label_names,
                'Value': np.mean(stat_values, axis=1)
            }))
            plt.title(f'Distribution of Mean {stat_name} - {layer_name}')
            plt.show()

# Function to visualize feature representations
def visualize_feature_space(model, dataloader, method='tsne'):
    """Project model embeddings to 2-D and show them as a class-coloured scatter plot.

    Args:
        model: Network used to extract features. ``extract_features`` is used
            when present, otherwise ``forward_features`` for ViT/DeiT/Swin
            instances, otherwise the plain forward pass.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        method: ``'tsne'`` or ``'pca'`` (case-insensitive).

    Returns:
        tuple: ``(reduced_features, labels)`` as NumPy arrays on success.
        ``None`` is returned when the method is unknown, scikit-learn is
        missing, or too few samples were collected.
    """
    # Validate the method first so an invalid name does not cost a full
    # feature-extraction pass.
    method_key = method.lower()
    if method_key not in ('tsne', 'pca'):
        print(f"Unknown dimensionality reduction method: {method}")
        return

    # Dictionary to store features
    features_dict = {
        'embeddings': [],
        'labels': []
    }

    # Extract features
    model.eval()
    with torch.no_grad():
        for imgs, lbls in tqdm(dataloader, desc="Extracting features"):
            imgs = imgs.to(device)

            if hasattr(model, 'extract_features'):
                # Use extract_features method if available
                features = model.extract_features(imgs)
            elif isinstance(model, (ViT, DeiT)):
                # For ViT/DeiT, use the output of forward_features
                features = model.forward_features(imgs)
                if isinstance(features, tuple):
                    features = features[0]  # For DeiT, take the class token features
            elif isinstance(model, SwinTransformer):
                # For Swin, use the output of forward_features
                features = model.forward_features(imgs)
            else:
                # Fall back to a simple forward pass and assume the model has a feature extractor
                print("Using generic feature extraction - this might not work as expected")
                features = model(imgs)

            features_dict['embeddings'].append(features.cpu().numpy())
            features_dict['labels'].append(lbls.numpy())

    if not features_dict['embeddings']:
        print("No samples available for feature space visualization.")
        return

    # Concatenate features and labels
    embeddings = np.concatenate(features_dict['embeddings'], axis=0)
    labels = np.concatenate(features_dict['labels'], axis=0)

    # t-SNE and PCA expect a (samples, features) matrix; flatten any extra axes
    embeddings = embeddings.reshape(len(embeddings), -1)

    n_samples = len(embeddings)
    if n_samples < 2:
        print("Need at least 2 samples for feature space visualization.")
        return

    # Apply dimensionality reduction
    if method_key == 'tsne':
        try:
            from sklearn.manifold import TSNE
        except ImportError:
            print("sklearn not installed. Install with: pip install scikit-learn")
            return
        print("Applying t-SNE dimensionality reduction...")
        # scikit-learn requires perplexity < n_samples; cap the default of 30
        # so that small subsets do not fail.
        perplexity = min(30.0, float(n_samples - 1))
        reduced_features = TSNE(n_components=2, perplexity=perplexity,
                                random_state=42).fit_transform(embeddings)
    else:
        try:
            from sklearn.decomposition import PCA
        except ImportError:
            print("sklearn not installed. Install with: pip install scikit-learn")
            return
        print("Applying PCA dimensionality reduction...")
        reduced_features = PCA(n_components=2, random_state=42).fit_transform(embeddings)

    # Visualize
    plt.figure(figsize=(10, 8))
    scatter = plt.scatter(reduced_features[:, 0], reduced_features[:, 1],
                          c=labels, cmap='viridis', alpha=0.8, s=30)

    # Add legend
    legend1 = plt.legend(*scatter.legend_elements(),
                         loc="upper right", title="Classes")
    plt.gca().add_artist(legend1)

    # Add title and labels
    plt.title(f"Feature Space Visualization using {method.upper()}")
    plt.xlabel(f"{method.upper()} Component 1")
    plt.ylabel(f"{method.upper()} Component 2")

    plt.grid(True, linestyle='--', alpha=0.7)
    plt.tight_layout()
    plt.show()

    return reduced_features, labels


def get_activation_statistics(model, dataloader, layer_name=None):
    """Collect per-sample activation statistics from hooked layers.

    Forward hooks capture the output of the selected layers while the model is
    run over ``dataloader``; mean, std, min and max are computed per sample.

    Args:
        model: Network to inspect; it is put in eval mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        layer_name: Name understood by ``get_layer_by_name``. When ``None``,
            the first and last entries of ``model.blocks`` (or ``model.layers``
            if there is no ``blocks`` attribute) are hooked.

    Returns:
        tuple: ``(statistics, labels)`` where ``statistics`` maps
        ``layer name -> {'mean' | 'std' | 'min' | 'max' -> NumPy array}`` with
        one row per sample, and ``labels`` is the NumPy array of labels.
    """
    # Set model to evaluation mode
    model.eval()

    # Latest captured output per hooked layer (overwritten on every forward pass)
    activations = {}

    def hook_fn(name):
        def hook(module, input, output):
            # NOTE(review): some layers may return a tuple; the first element is
            # presumed to be the activation tensor — confirm for the hooked layers.
            if isinstance(output, (tuple, list)):
                output = output[0]
            activations[name] = output.detach().clone()
        return hook

    # Register hooks
    hooks = []
    if layer_name is None:
        # Register hooks for multiple interesting layers
        if hasattr(model, 'blocks'):  # ViT/DeiT
            hooks.append(model.blocks[0].register_forward_hook(hook_fn('first_block')))
            hooks.append(model.blocks[-1].register_forward_hook(hook_fn('last_block')))
        elif hasattr(model, 'layers'):  # Swin
            hooks.append(model.layers[0].register_forward_hook(hook_fn('first_layer')))
            hooks.append(model.layers[-1].register_forward_hook(hook_fn('last_layer')))
    else:
        # Register hook for the specific layer
        layer = get_layer_by_name(model, layer_name)
        if layer is not None:
            hooks.append(layer.register_forward_hook(hook_fn(layer_name)))

    if not hooks:
        print("No layers were hooked; activation statistics will be empty.")

    statistics = {}
    labels = []

    try:
        with torch.no_grad():
            for imgs, lbls in tqdm(dataloader, desc="Collecting activations"):
                imgs = imgs.to(device)

                # Forward pass; the hooks fill ``activations`` as a side effect
                model(imgs)

                labels.extend(lbls.cpu().numpy())

                for name, activation in activations.items():
                    if name not in statistics:
                        statistics[name] = {'mean': [], 'std': [], 'min': [], 'max': []}

                    if activation.dim() > 2:
                        # Reduce over every axis after the first two. For 4-D
                        # tensors this equals reducing over dims 2 and 3; for
                        # 3-D tensors it reduces over the last axis, which the
                        # previous hard-coded dim=[2, 3] could not handle.
                        # NOTE(review): for (batch, tokens, dim) outputs this
                        # yields one value per token — confirm that is the
                        # intended "feature" axis.
                        flat = activation.flatten(start_dim=2)
                        batch_mean = flat.mean(dim=2).cpu().numpy()
                        batch_std = flat.std(dim=2).cpu().numpy()
                        batch_min = flat.min(dim=2)[0].cpu().numpy()
                        batch_max = flat.max(dim=2)[0].cpu().numpy()
                    else:  # 2-D activations: one scalar statistic per sample
                        batch_mean = activation.mean(dim=1).cpu().numpy()
                        batch_std = activation.std(dim=1).cpu().numpy()
                        batch_min = activation.min(dim=1)[0].cpu().numpy()
                        batch_max = activation.max(dim=1)[0].cpu().numpy()

                    # extend() appends one entry per sample of the batch
                    statistics[name]['mean'].extend(batch_mean)
                    statistics[name]['std'].extend(batch_std)
                    statistics[name]['min'].extend(batch_min)
                    statistics[name]['max'].extend(batch_max)
    finally:
        # Always detach the hooks, even if a forward pass raised
        for hook in hooks:
            hook.remove()

    # Convert lists to numpy arrays
    for name in statistics:
        for stat_name in statistics[name]:
            statistics[name][stat_name] = np.array(statistics[name][stat_name])

    return statistics, np.array(labels)

def get_layer_by_name(model, name):
    """Resolve a symbolic layer name to a module of ``model``.

    Args:
        model: Object that may expose a ``blocks`` or ``layers`` sequence.
        name: One of ``'first_block'``, ``'last_block'`` (looked up in
            ``model.blocks``) or ``'first_layer'``, ``'last_layer'`` (looked up
            in ``model.layers``).

    Returns:
        The selected element, or ``None`` (after printing a message) when the
        name is not recognised or the model lacks the matching attribute.
    """
    # symbolic name -> (container attribute, index inside the container)
    lookup = {
        'first_block': ('blocks', 0),
        'last_block': ('blocks', -1),
        'first_layer': ('layers', 0),
        'last_layer': ('layers', -1),
    }
    if name in lookup:
        attr, index = lookup[name]
        if hasattr(model, attr):
            return getattr(model, attr)[index]

    print(f"Unknown layer name: {name}")
    return None


def find_target_layer(model, model_type):
    """Find the target layer for Grad-CAM based on model type.

    Args:
        model: Model exposing ``blocks`` (ViT/DeiT) or ``layers`` (Swin).
        model_type: ``'vit'``, ``'deit'`` or ``'swin'``.

    Returns:
        The last transformer block for ViT/DeiT, or the last layer for Swin.

    Raises:
        ValueError: If ``model_type`` is not one of the supported keys.
    """
    if model_type in ('vit', 'deit'):
        # ViT and DeiT both use the output of the last transformer block
        return model.blocks[-1]
    if model_type == 'swin':
        # For Swin, we use the output of the last layer
        return model.layers[-1]
    raise ValueError(f"Unknown model type: {model_type}")

# Apply Grad-CAM to our sample images
if dataset is not None and len(models) > 0:
    print(\"\\n\\nGrad-CAM Visualization:\")
    
    # Try for different samples (real and fake)
    for i in range(min(4, len(sample_images))):
        sample_img = sample_images[i]
        sample_label = sample_labels[i]
        
        print(f\"\\nGrad-CAM for {'real' if sample_label == 0 else 'fake'} sample {i+1}:\")
        
        for name, model in models.items():
            print(f\"\\n{name.upper()} model:\")
            try:
                # Find target layer
                target_layer = find_target_layer(model, name)
                
                # Visualize Grad-CAM
                visualize_gradcam(model, sample_img, target_layer, class_idx=sample_label)
            except Exception as e:
                print(f\"Error applying Grad-CAM to {name} model: {e}\")
else:
    print(\"Cannot visualize Grad-CAM without sample images and models.\"){
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Deepfake Detection Model Visualization\n",
    "\n",
    "This notebook provides visualizations and interpretability tools for the deepfake detection models. It includes:\n",
    "\n",
    "1. Model architecture visualization\n",
    "2. Attention map visualization\n",
    "3. Grad-CAM analysis\n",
    "4. Feature visualization\n",
    "5. Comparison of model behaviors\n",
    "\n",
    "These visualizations help understand how the models are making decisions and what features they are focusing on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Import necessary libraries\n",
    "import os\n",
    "import sys\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from PIL import Image\n",
    "import cv2\n",
    "from tqdm.notebook import tqdm\n",
    "import torch.nn.functional as F\n",
    "import torchvision.transforms as transforms\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# Add parent directory to path to enable imports from project\n",
    "sys.path.append(os.path.abspath('..'))\n",
    "\n",
    "# Import project modules\n",
    "from models.vit.model import ViT\n",
    "from models.deit.model import DeiT\n",
    "from models.swin.model import SwinTransformer\n",
    "from models.model_zoo.model_factory import create_model\n",
    "from data.datasets.faceforensics import FaceForensicsDataset\n",
    "from data.datasets.celebdf import CelebDFDataset\n",
    "from evaluation.visualization.attention_maps import visualize_attention_maps\n",
    "from evaluation.visualization.grad_cam import visualize_grad_cam\n",
    "from evaluation.visualization.feature_visualization import visualize_features\n",
    "\n",
    "# Set plot style\n",
    "plt.style.use('fivethirtyeight')\n",
    "sns.set(style=\"whitegrid\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Pretrained Models\n",
    "\n",
    "Load pretrained models for visualization. You need to have trained models saved in checkpoint format."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Configure paths - update these to your checkpoint paths\n",
    "CHECKPOINT_DIR = \"../trained_models\"\n",
    "VIT_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"vit_celebdf/checkpoints/best.pth\")\n",
    "DEIT_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"deit_celebdf/checkpoints/best.pth\")\n",
    "SWIN_CHECKPOINT = os.path.join(CHECKPOINT_DIR, \"swin_celebdf/checkpoints/best.pth\")\n",
    "\n",
    "# Check if checkpoints exist\n",
    "vit_exists = os.path.exists(VIT_CHECKPOINT)\n",
    "deit_exists = os.path.exists(DEIT_CHECKPOINT)\n",
    "swin_exists = os.path.exists(SWIN_CHECKPOINT)\n",
    "\n",
    "print(f\"ViT checkpoint exists: {vit_exists}\")\n",
    "print(f\"DeiT checkpoint exists: {deit_exists}\")\n",
    "print(f\"Swin checkpoint exists: {swin_exists}\")\n",
    "\n",
    "# Set device\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def load_checkpoint(model, checkpoint_path, device):\n",
    "    \"\"\"Load model from checkpoint\"\"\"\n",
    "    if not os.path.exists(checkpoint_path):\n",
    "        print(f\"Checkpoint not found at {checkpoint_path}\")\n",
    "        return None\n",
    "    \n",
    "    try:\n",
    "        checkpoint = torch.load(checkpoint_path, map_location=device)\n",
    "        \n",
    "        # Different checkpoint formats\n",
    "        if 'model' in checkpoint:\n",
    "            model.load_state_dict(checkpoint['model'])\n",
    "        elif 'model_state_dict' in checkpoint:\n",
    "            model.load_state_dict(checkpoint['model_state_dict'])\n",
    "        else:\n",
    "            model.load_state_dict(checkpoint)\n",
    "            \n",
    "        model = model.to(device)\n",
    "        model.eval()  # Set to evaluation mode\n",
    "        print(f\"Model loaded successfully from {checkpoint_path}\")\n",
    "        return model\n",
    "    except Exception as e:\n",
    "        print(f\"Error loading checkpoint: {e}\")\n",
    "        return None\n",
    "\n",
    "# Initialize models\n",
    "models = {}\n",
    "\n",
    "# Load ViT model\n",
    "if vit_exists:\n",
    "    vit_model = ViT(\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=768,\n",
    "        depth=12,\n",
    "        num_heads=12\n",
    "    )\n",
    "    vit_model = load_checkpoint(vit_model, VIT_CHECKPOINT, device)\n",
    "    if vit_model is not None:\n",
    "        models['vit'] = vit_model\n",
    "\n",
    "# Load DeiT model\n",
    "if deit_exists:\n",
    "    deit_model = DeiT(\n",
    "        img_size=224,\n",
    "        patch_size=16,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=768,\n",
    "        depth=12,\n",
    "        num_heads=12,\n",
    "        distillation=True\n",
    "    )\n",
    "    deit_model = load_checkpoint(deit_model, DEIT_CHECKPOINT, device)\n",
    "    if deit_model is not None:\n",
    "        models['deit'] = deit_model\n",
    "\n",
    "# Load Swin model\n",
    "if swin_exists:\n",
    "    swin_model = SwinTransformer(\n",
    "        img_size=224,\n",
    "        patch_size=4,\n",
    "        in_channels=3,\n",
    "        num_classes=1,\n",
    "        embed_dim=96,\n",
    "        depths=[2, 2, 6, 2],\n",
    "        num_heads=[3, 6, 12, 24],\n",
    "        window_size=7\n",
    "    )\n",
    "    swin_model = load_checkpoint(swin_model, SWIN_CHECKPOINT, device)\n",
    "    if swin_model is not None:\n",
    "        models['swin'] = swin_model\n",
    "\n",
    "print(f\"Loaded {len(models)} models: {list(models.keys())}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Load Sample Data\n",
    "\n",
    "Load some sample images to visualize the model behavior."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Configure dataset paths - update these to your local paths\n",
    "FACEFORENSICS_ROOT = \"/path/to/datasets/FaceForensics\"\n",
    "CELEBDF_ROOT = \"/path/to/datasets/CelebDF\"\n",
    "\n",
    "# Choose one dataset to use for visualization\n",
    "DATASET_ROOT = CELEBDF_ROOT  # Change to the dataset you want to use\n",
    "DATASET_NAME = \"celebdf\"     # Change to match your dataset (\"faceforensics\" or \"celebdf\")\n",
    "\n",
    "# Check if directory exists\n",
    "dataset_exists = os.path.exists(DATASET_ROOT)\n",
    "print(f\"Dataset path exists: {dataset_exists}\")\n",
    "\n",
    "# Define transform\n",
    "transform = transforms.Compose([\n",
    "    transforms.Resize((224, 224)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
    "])\n",
    "\n",
    "# Load dataset if it exists\n",
    "if dataset_exists:\n",
    "    if DATASET_NAME == \"faceforensics\":\n",
    "        dataset = FaceForensicsDataset(\n",
    "            root=DATASET_ROOT,\n",
    "            split=\"test\",  # Use test split for visualization\n",
    "            img_size=224,\n",
    "            transform=transform\n",
    "        )\n",
    "    elif DATASET_NAME == \"celebdf\":\n",
    "        dataset = CelebDFDataset(\n",
    "            root=DATASET_ROOT,\n",
    "            split=\"test\",  # Use test split for visualization\n",
    "            img_size=224,\n",
    "            transform=transform\n",
    "        )\n",
    "    else:\n",
    "        print(f\"Unknown dataset name: {DATASET_NAME}\")\n",
    "        dataset = None\n",
    "        \n",
    "    if dataset is not None:\n",
    "        print(f\"Dataset loaded with {len(dataset)} samples\")\n",
    "else:\n",
    "    print(\"Dataset path not found. Cannot load samples.\")\n",
    "    dataset = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def get_sample_batch(dataset, batch_size=4):\n",
    "    \"\"\"Get a sample batch with equal number of real and fake samples\"\"\"\n",
    "    if dataset is None:\n",
    "        return None, None\n",
    "    \n",
    "    dataloader = DataLoader(dataset, batch_size=batch_size*10, shuffle=True)\n",
    "    images, labels = next(iter(dataloader))\n",
    "    \n",
    "    # Separate real and fake\n",
    "    real_imgs = [img for img, label in zip(images, labels) if label == 0]\n",
    "    fake_imgs = [img for img, label in zip(images, labels) if label == 1]\n",
    "    \n",
    "    # Make sure we have enough samples\n",
    "    half_batch = batch_size // 2\n",
    "    if len(real_imgs) < half_batch or len(fake_imgs) < half_batch:\n",
    "        print(\"Not enough samples of each class. Getting more...\")\n",
    "        return get_sample_batch(dataset, batch_size)\n",
    "    \n",
    "    # Get equal number of each\n",
    "    real_imgs = real_imgs[:half_batch]\n",
    "    fake_imgs = fake_imgs[:half_batch]\n",
    "    \n",
    "    # Combine\n",
    "    sample_imgs = real_imgs + fake_imgs\n",
    "    sample_labels = [0] * half_batch + [1] * half_batch\n",
    "    \n",
    "    return torch.stack(sample_imgs), torch.tensor(sample_labels)\n",
    "\n",
    "# Get sample batch if dataset exists\n",
    "if dataset is not None:\n",
    "    sample_images, sample_labels = get_sample_batch(dataset, batch_size=8)\n",
    "    print(f\"Sample batch shape: {sample_images.shape}\")\n",
    "    print(f\"Sample labels: {sample_labels}\")\n",
    "    \n",
    "    # Visualize samples\n",
    "    plt.figure(figsize=(15, 8))\n",
    "    for i in range(len(sample_images)):\n",
    "        plt.subplot(2, 4, i+1)\n",
    "        img = sample_images[i].permute(1, 2, 0).numpy()\n",
    "        # Denormalize for visualization\n",
    "        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])\n",
    "        img = np.clip(img, 0, 1)\n",
    "        plt.imshow(img)\n",
    "        plt.title(\"Real\" if sample_labels[i] == 0 else \"Fake\")\n",
    "        plt.axis('off')\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"No dataset available. Skipping sample visualization.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Model Architecture Visualization\n",
    "\n",
    "Visualize the architecture of the models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def count_parameters(model):\n",
    "    \"\"\"Count number of trainable parameters\"\"\"\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "\n",
    "def print_model_summary(model, name):\n",
    "    \"\"\"Print model summary\"\"\"\n",
    "    print(f\"\\n{name} Model Summary:\")\n",
    "    print(f\"Total parameters: {count_parameters(model):,}\")\n",
    "    \n",
    "    # Print main component information\n",
    "    if name.lower() == 'vit':\n",
    "        print(f\"Image size: {model.img_size}\")\n",
    "        print(f\"Patch size: {model.patch_size}\")\n",
    "        print(f\"Number of patches: {model.num_patches}\")\n",
    "        print(f\"Embedding dimension: {model.embed_dim}\")\n",
    "        print(f\"Number of transformer blocks: {len(model.blocks)}\")\n",
    "        print(f\"Number of attention heads: {model.blocks[0].attn.num_heads}\")\n",
    "    elif name.lower() == 'deit':\n",
    "        print(f\"Image size: {model.img_size}\")\n",
    "        print(f\"Patch size: {model.patch_size}\")\n",
    "        print(f\"Number of patches: {model.num_patches}\")\n",
    "        print(f\"Embedding dimension: {model.embed_dim}\")\n",
    "        print(f\"Number of transformer blocks: {len(model.blocks)}\")\n",
    "        print(f\"Number of attention heads: {model.blocks[0].attn.num_heads}\")\n",
    "        print(f\"Using distillation: {model.distillation is not None}\")\n",
    "    elif name.lower() == 'swin':\n",
    "        print(f\"Image size: {model.img_size}\")\n",
    "        print(f\"Patch size: {model.patch_size}\")\n",
    "        print(f\"Embedding dimension: {model.embed_dim}\")\n",
    "        print(f\"Number of layers: {model.num_layers}\")\n",
    "        print(f\"Depths: {model.depths}\")\n",
    "\n",
    "# Visualize model architectures\n",
    "for name, model in models.items():\n",
    "    print_model_summary(model, name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Try to generate a visual diagram of the model architecture\n",
    "try:\n",
    "    from torchviz import make_dot\n",
    "    from torch.autograd import Variable\n",
    "    \n",
    "    # Function to create model diagram\n",
    "    def visualize_model_graph(model, name):\n",
    "        # Create a sample input\n",
    "        x = Variable(torch.randn(1, 3, 224, 224)).to(device)\n",
    "        \n",
    "        # Generate output\n",
    "        y = model(x)\n",
    "        \n",
    "        # Create dot graph\n",
    "        dot = make_dot(y, params=dict(list(model.named_parameters())))\n",
    "        \n",
    "        # Save and display\n",
    "        dot.format = 'png'\n",
    "        dot.render(f\"{name}_architecture\", cleanup=True)\n",
    "        \n",
    "        # Display\n",
    "        from IPython.display import Image\n",
    "        return Image(filename=f\"{name}_architecture.png\")\n",
    "    \n",
    "    # Visualize each model\n",
    "    for name, model in models.items():\n",
    "        print(f\"\\nGenerating architecture diagram for {name}...\")\n",
    "        display(visualize_model_graph(model, name))\n",
    "except ImportError:\n",
    "    print(\"torchviz not installed. Install with: pip install torchviz\")\n",
    "    print(\"Also requires graphviz to be installed. See: https://graphviz.org/download/\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Attention Map Visualization\n",
    "\n",
    "Visualize attention maps from transformer-based models."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "def get_attention_maps(model, img, model_type):\n",
    "    \"\"\"Get attention maps for the given model and image\"\"\"\n",
    "    # Move image to device\n",
    "    img = img.unsqueeze(0).to(device)\n",
    "    attention_maps = []\n",
    "    \n",
    "    # Hook function to retrieve attention maps\n",
    "    def hook_fn(module, input, output):\n",
    "        if model_type in ['vit', 'deit']:\n",
    "            # For ViT/DeiT, attention is of shape (batch_size, num_heads, seq_len, seq_len)\n",
    "            attention_maps.append(output.detach().cpu())\n",
    "        elif model_type == 'swin':\n",
    "            # For Swin, attention structure is different\n",
    "            # This is a simplified approach - may need adjustment\n",
    "            attention_maps.append(output.detach().cpu())\n",
    "    \n",
    "    # Register hooks\n",
    "    hooks = []\n",
    "    if model_type in ['vit', 'deit']:\n",
    "        for block in model.blocks:\n",
    "            hooks.append(block.attn.register_forward_hook(hook_fn))\n",
    "    elif model_type == 'swin':\n",
    "        # For Swin, we need to find the window attention modules\n",
    "        for layer in model.layers:\n",
    "            for block in layer.blocks:\n",
    "                hooks.append(block.attn.register_forward_hook(hook_fn))\n",
    "    \n",
    "    # Forward pass\n",
    "    with torch.no_grad():\n",
    "        model(img)\n",
    "    \n",
    "    # Remove hooks\n",
    "    for hook in hooks:\n",
    "        hook.remove()\n",
    "    \n",
    "    return attention_maps\n",
    "\n",
    "def visualize_attention(model, img, model_type, layer_idx=-1, head_idx=0):\n",
    "    \"\"\"Visualize attention map for a specific layer and head\"\"\"\n",
    "    attention_maps = get_attention_maps(model, img, model_type)\n",
    "    \n",
    "    if not attention_maps:\n",
    "        print(\"No attention maps retrieved.\")\n",
    "        return\n",
    "    \n",
    "    # Get attention map for the specified layer\n",
    "    if layer_idx < 0:\n",
    "        layer_idx = len(attention_maps) + layer_idx\n",
    "    \n",
    "    if layer_idx >= len(attention_maps):\n",
    "        print(f\"Layer index {layer_idx} out of range. Max: {len(attention_maps)-1}\")\n",
    "        return\n",
    "    \n",
    "    attention = attention_maps[layer_idx][0]  # Get the first batch\n",
    "    \n",
    "    # For ViT/DeiT\n",
    "    if model_type in ['vit', 'deit']:\n",
    "        num_heads = attention.shape[0]\n",
    "        if head_idx >= num_heads:\n",
    "            print(f\"Head index {head_idx} out of range. Max: {num_heads-1}\")\n",
    "            return\n",
    "        \n",
    "        # Get attention for specific head\n",
    "        attn_map = attention[head_idx].numpy()\n",
    "        \n",
    "        # The first row corresponds to the [CLS] token's attention to all other tokens\n",
    "        cls_attn = attn_map[0, 1:]  # Skip the attention to [CLS] itself\n",
    "        \n",
    "        # Reshape to image size for visualization\n",
    "        size = int(np.sqrt(len(cls_attn)))\n",
    "        cls_attn = cls_attn.reshape(size, size)\n",
    "        \n",
    "        # Visualize\n",
    "        plt.figure(figsize=(12, 5))\n",
    "        \n",
    "        plt.subplot(1, 2, 1)\n",
    "        # Denormalize image for visualization\n",
    "        img_np = img.permute(1, 2, 0).numpy()\n",
    "        img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])\n",
    "        img_np = np.clip(img_np, 0, 1)\n",
    "        plt.imshow(img_np)\n",
    "        plt.title(\"Input Image\")\n",
    "        plt.axis('off')\n",
    "        \n",
    "        plt.subplot(1, 2, 2)\n",
    "        plt.imshow(cls_attn, cmap='viridis')\n",
    "        plt.title(f\"Layer {layer_idx}, Head {head_idx} - CLS Token Attention\")\n",
    "        plt.colorbar(format='%.2f')\n",
    "        \n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "        \n",
    "        return cls_attn\n",
    "    else:\n",
    "        print(f\"Attention visualization for {model_type} is not implemented.\")\n",
    "        return None\n",
    "\n",
    "# Visualize attention maps if we have sample images and transformer models\n",
    "if dataset is not None and len(models) > 0:\n",
    "    # Get a sample image\n",
    "    sample_idx = 0  # Choose an index from the sample batch\n",
    "    sample_img = sample_images[sample_idx]\n",
    "    sample_label = sample_labels[sample_idx]\n",
    "    \n",
    "    print(f\"Visualizing attention for a {'real' if sample_label == 0 else 'fake'} sample\")\n",
    "    \n",
    "    # Visualize for each transformer model\n",
    "    for name, model in models.items():\n",
    "        if name in ['vit', 'deit']:  # Currently implemented for ViT and DeiT\n",
    "            print(f\"\\nAttention maps for {name.upper()}:\")\n",
    "            \n",
    "            # Visualize last layer attention\n",
    "            print(\"Last layer attention:\")\n",
    "            visualize_attention(model, sample_img, name, layer_idx=-1, head_idx=0)\n",
    "            \n",
    "            # Visualize middle layer attention\n",
    "            middle_layer = len(model.blocks) // 2\n",
    "            print(f\"Middle layer ({middle_layer}) attention:\")\n",
    "            visualize_attention(model, sample_img, name, layer_idx=middle_layer, head_idx=0)\n",
    "            \n",
    "            # Visualize first layer attention\n",
    "            print(\"First layer attention:\")\n",
    "            visualize_attention(model, sample_img, name, layer_idx=0, head_idx=0)\n",
    "else:\n",
    "    print(\"Cannot visualize attention maps without sample images and transformer models.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Grad-CAM Visualization\n",
    "\n",
    "Use Grad-CAM to visualize the regions of the image that are important for the model's decision."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "class GradCAM:\n",
    "    \"\"\"Grad-CAM implementation for CNN-based models\"\"\"\n",
    "    \n",
    "    def __init__(self, model, target_layer):\n",
    "        self.model = model\n",
    "        self.target_layer = target_layer\n",
    "        self.gradients = None\n",
    "        self.activations = None\n",
    "        \n",
    "        # Register hooks\n",
    "        self.register_hooks()\n",
    "    \n",
    "    def register_hooks(self):\n",
    "        def forward_hook(module, input, output):\n",
    "            self.activations = output.detach()\n",
    "        \n",
    "        def backward_hook(module, grad_input, grad_output):\n",
    "            self.gradients = grad_output[0].detach()\n",
    "        \n",
    "        # Register hooks\n",
    "        self.target_layer.register_forward_hook(forward_hook)\n",
    "        self.target_layer.register_backward_hook(backward_hook)\n",
    "    \n",
    "    def __call__(self, x, class_idx=None):\n",
    "        # Forward pass\n",
    "        b, c, h, w = x.size()\n",
    "        logits = self.model(x)\n",
    "        \n",
    "        # If class_idx is None, use the model's prediction\n",
    "        if class_idx is None:\n",
    "            if logits.dim() > 1:\n",
    "                class_idx = torch.argmax(logits, dim=1).item()\n",
    "            else:\n",
    "                class_idx = (logits > 0.5).long().item()\n",
    "        \n",
    "        # Backward pass\n",
    "        self.model.zero_grad()\n",
    "        if logits.dim() > 1:\n",
    "            target = torch.zeros_like(logits)\n",
    "            target[0, class_idx] = 1\n",
    "        else:\n",
    "            target = torch.ones_like(logits) if class_idx == 1 else torch.zeros_like(logits)\n",
    "        \n",
    "        logits.backward(gradient=target, retain_graph=True)\n",