<a href="https://colab.research.google.com/github/msrishav-28/DeiT_model/blob/main/Deepfake_Detection_DeiT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Comprehensive DeiT Deepfake Detection Framework
# For FaceForensics++ (c40) and Celeb-DF datasets

# Cell 1: Installation and Setup
# NOTE(review): everything below is wrapped in a bare triple-quoted string, so
# none of the pip installs and not the Google Drive mount actually execute --
# the cell only evaluates (and echoes) the string literal. On a fresh kernel
# this leaves packages such as facenet-pytorch uninstalled, which is why the
# import cell fails with ModuleNotFoundError. Run these installs (preferably
# with %pip so they target the kernel's environment) and mount Drive before
# continuing; check the pinned versions against the runtime's Python/NumPy
# versions first, since older pins may not be compatible with it.
"""
# Install required packages
!pip install timm==0.6.12
!pip install opencv-python==4.8.0.76
!pip install albumentations==1.3.1
!pip install facenet-pytorch==2.5.3
!pip install wandb==0.15.12
!pip install onnx==1.14.1
!pip install av==10.0.0
!pip install ffmpeg-python==0.2.0
!pip install scikit-learn==1.3.0
!pip install tqdm==4.66.1

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')
"""

"\n# Install required packages\n!pip install timm==0.6.12\n!pip install opencv-python==4.8.0.76\n!pip install albumentations==1.3.1\n!pip install facenet-pytorch==2.5.3\n!pip install wandb==0.15.12\n!pip install onnx==1.14.1\n!pip install av==10.0.0\n!pip install ffmpeg-python==0.2.0\n!pip install scikit-learn==1.3.0\n!pip install tqdm==4.66.1\n\n# Mount Google Drive\nfrom google.colab import drive\ndrive.mount('/content/drive')\n"

In [None]:
# Installs PyAV (unpinned) for video decoding in VideoProcessor; the pinned
# install list in Cell 1 sits inside a string literal and never runs.
# NOTE(review): facenet-pytorch (needed for MTCNN) is not installed by this
# cell, so the import cell still fails on a fresh kernel.
!pip install av

Collecting av
  Downloading av-14.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (4.7 kB)
Downloading av-14.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (35.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m35.2/35.2 MB[0m [31m48.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: av
Successfully installed av-14.3.0


In [None]:
# Cell 2: Import Libraries
import os
import sys
import json
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
from pathlib import Path
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torch.cuda.amp import autocast, GradScaler
from timm import create_model
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
import av
from facenet_pytorch import MTCNN
from sklearn.metrics import roc_curve, auc, precision_recall_curve, confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import warnings
warnings.filterwarnings('ignore')

ModuleNotFoundError: No module named 'facenet_pytorch'

In [None]:
# Cell 3: Set Random Seeds
def set_seed(seed=42):
    """Seed every RNG the pipeline touches and force deterministic cuDNN kernels.

    Args:
        seed: value applied to Python's ``random``, NumPy's global RNG and
            torch (CPU generator plus the current and all CUDA devices).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False

set_seed()

In [None]:
# Cell 4: Project Configuration
class Config:
    """Central configuration for the whole pipeline.

    All settings are class attributes, so every ``Config()`` instance shares
    them; the module-level ``config`` instance below is what the other cells use.
    """
    # Project paths
    # Root folder on Google Drive. Cell 1 keeps ``drive.mount`` inside a string
    # literal, so Drive must be mounted separately before this path exists.
    base_path = "/content/drive/MyDrive/deepfake_detection"

    # FaceForensics++ configuration
    ff_path = f"{base_path}/FaceForensics++"
    ff_compression = "c40"  # sub-folder name used when building the FF++ video paths
    ff_methods = ["Deepfakes", "Face2Face", "FaceSwap", "NeuralTextures"]
    ff_samples_per_method = 500  # 500 samples from each category (real and each method)

    # Celeb-DF configuration
    celebdf_path = f"{base_path}/Celeb-DF"  # used when prepare_celebdf_dataset(version != "v2")
    celebdf_v2_path = f"{base_path}/Celeb-DF-v2"
    celebdf_samples = 2000  # Samples to use from Celeb-DF (split evenly: // 2 real, // 2 fake)

    # Frame extraction
    frames_per_video = 20  # frames sampled at an even stride per video
    face_margin = 0.3  # Margin around detected face, as a fraction of box width/height per side
    min_face_size = 48  # Minimum face size to consider (passed to MTCNN)

    # Dataset configuration
    image_size = 224  # side length fed to the model by the albumentations pipelines
    batch_size = 32
    num_workers = 2
    train_ratio = 0.8
    val_ratio = 0.1
    # Informational: create_dataset_csv assigns everything after train + val to test.
    test_ratio = 0.1

    # Model configuration
    model_name = "deit_small_patch16_224"  # timm model id passed to create_model
    pretrained = True
    num_classes = 2  # label convention: 0 = real, 1 = fake

    # Training configuration
    num_epochs = 20
    learning_rate = 2e-5  # peak LR (OneCycleLR max_lr when warmup_epochs > 0)
    weight_decay = 1e-4  # AdamW weight decay
    # > 0 selects OneCycleLR with warm-up fraction warmup_epochs / num_epochs;
    # 0 selects epoch-level CosineAnnealingLR.
    warmup_epochs = 2
    checkpoint_dir = f"{base_path}/checkpoints"
    results_dir = f"{base_path}/results"

    # Advanced training features
    mixed_precision = True  # enables autocast + GradScaler in the trainer
    # Gradient accumulation factor (mini-batches per optimizer update); see
    # DeepfakeTrainer.train_epoch for how it is applied.
    gradient_accumulation_steps = 1
    gradient_clip_val = 1.0  # max grad-norm; 0 disables clipping
    label_smoothing = 0.1  # CrossEntropyLoss label smoothing
    mixup_alpha = 0.2  # Beta(alpha, alpha) mixup; 0 disables it

    # Save paths
    save_frequency = 1  # presumably epochs between checkpoint saves -- TODO confirm where it is used

    # Logging
    # NOTE(review): wandb is not imported in the imports cell; confirm where this toggle is honoured.
    use_wandb = False
    project_name = "deepfake-detection"
    experiment_name = "deit-ffpp-celebdf"

config = Config()

# Make sure directories exist (runs when this cell executes; needs Drive mounted)
os.makedirs(config.checkpoint_dir, exist_ok=True)
os.makedirs(config.results_dir, exist_ok=True)

In [None]:
# Cell 5: Video Processing Utilities
class VideoProcessor:
    """Video processing utilities for deepfake detection.

    Samples frames from a video with PyAV and crops faces from those frames
    with a facenet-pytorch MTCNN detector.
    """
    def __init__(self, config, device='cuda'):
        """
        Args:
            config: object exposing ``min_face_size``, ``frames_per_video`` and
                ``face_margin``.
            device: torch device string the MTCNN detector runs on.
        """
        self.config = config
        self.device = device

        # Initialize face detector; keep_all=True returns every face in a frame.
        self.face_detector = MTCNN(
            keep_all=True,
            post_process=False,
            min_face_size=self.config.min_face_size,
            device=self.device
        )

    @staticmethod
    def _estimate_frame_count(container, video_stream):
        """Best-effort frame count; returns 0 when neither header nor duration gives one."""
        n_frames = video_stream.frames or 0
        if n_frames > 0:
            return n_frames

        rate = video_stream.average_rate
        if rate is None:
            return 0

        if video_stream.duration is not None and video_stream.time_base is not None:
            duration = float(video_stream.duration * video_stream.time_base)
        elif container.duration is not None:
            # Container duration is expressed in av.time_base (microsecond) units.
            duration = container.duration / av.time_base
        else:
            return 0
        return int(duration * float(rate))

    def extract_frames(self, video_path, output_dir, max_frames=None):
        """Extract up to ``max_frames`` evenly spaced frames from a video file.

        Frames are written as ``frame_XXXX.jpg`` into ``output_dir``.

        Args:
            video_path: path of the video to decode.
            output_dir: directory receiving the JPEG frames (created if missing).
            max_frames: frame budget; defaults to ``config.frames_per_video``.

        Returns:
            Number of frames saved (0 on any decoding error, which is logged).
        """
        if max_frames is None:
            max_frames = self.config.frames_per_video
        if max_frames <= 0:
            return 0

        os.makedirs(output_dir, exist_ok=True)

        try:
            # The context manager releases the demuxer/file handle even on error.
            with av.open(video_path) as container:
                if not container.streams.video:
                    print(f"Warning: No video stream found in {video_path}")
                    return 0
                video_stream = container.streams.video[0]

                n_frames = self._estimate_frame_count(container, video_stream)
                if n_frames <= 0:
                    print(f"Warning: Could not determine frame count for {video_path}")
                    n_frames = 1000  # Assume a reasonable number

                # Stride between saved frames so the samples span the whole clip.
                interval = max(1, n_frames // max_frames)

                saved_count = 0
                for frame_count, frame in enumerate(container.decode(video_stream)):
                    if frame_count % interval != 0:
                        continue
                    img_path = os.path.join(output_dir, f"frame_{saved_count:04d}.jpg")
                    frame.to_image().save(img_path)
                    saved_count += 1
                    if saved_count >= max_frames:
                        break

                return saved_count

        except Exception as e:
            print(f"Error processing video {video_path}: {e}")
            return 0

    def detect_faces(self, image_path, output_path=None, return_image=False, min_confidence=0.9):
        """Detect faces in an image and optionally save / return the crops.

        Args:
            image_path: file path (read with OpenCV) or an RGB ndarray (H, W, 3).
            output_path: where to write the crop. With several detections, each
                crop is saved as ``<stem>_<i><ext>`` instead.
            return_image: when True, the RGB crops are returned.
            min_confidence: detections below this MTCNN probability are ignored.

        Returns:
            List of RGB face crops (empty unless ``return_image``, or on error).
        """
        try:
            # Load the image
            if isinstance(image_path, str):
                img = cv2.imread(image_path)
                if img is None:
                    print(f"Error detecting faces in {image_path}: could not read image")
                    return []
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            else:
                img = image_path

            # Get face detections
            boxes, probs = self.face_detector.detect(img)
            if boxes is None or len(boxes) == 0:
                return []

            h, w = img.shape[:2]
            margin = self.config.face_margin
            multi_face = len(boxes) > 1
            face_images = []

            for i, (box, prob) in enumerate(zip(boxes, probs)):
                if prob is None or prob < min_confidence:
                    continue

                x1, y1, x2, y2 = box.tolist()

                # Expand the box by `margin` of its size on every side, clamped to the image.
                dx = (x2 - x1) * margin
                dy = (y2 - y1) * margin
                x1 = max(0, int(x1 - dx))
                y1 = max(0, int(y1 - dy))
                x2 = min(w, int(x2 + dx))
                y2 = min(h, int(y2 + dy))

                face_img = img[y1:y2, x1:x2]
                # Degenerate boxes yield empty crops; cv2.imwrite would raise on them.
                if face_img.size == 0:
                    continue

                if output_path:
                    out_dir = os.path.dirname(output_path)
                    if out_dir:
                        os.makedirs(out_dir, exist_ok=True)
                    if multi_face:
                        basename, ext = os.path.splitext(output_path)
                        face_path = f"{basename}_{i}{ext}"
                    else:
                        face_path = output_path
                    cv2.imwrite(face_path, cv2.cvtColor(face_img, cv2.COLOR_RGB2BGR))

                if return_image:
                    face_images.append(face_img)

            return face_images

        except Exception as e:
            print(f"Error detecting faces in {image_path}: {e}")
            return []

    def process_video(self, video_path, output_dir, extract_faces=True):
        """Extract frames from a video and crop faces from each frame.

        Writes ``output_dir/frames/*.jpg`` and (when ``extract_faces``)
        ``output_dir/faces/*.jpg``.

        Returns:
            (num_frames, face_count) where ``face_count`` is the number of frames
            in which at least one face crop was produced.
        """
        frames_dir = os.path.join(output_dir, "frames")
        faces_dir = os.path.join(output_dir, "faces")

        os.makedirs(frames_dir, exist_ok=True)
        if extract_faces:
            os.makedirs(faces_dir, exist_ok=True)

        num_frames = self.extract_frames(video_path, frames_dir)

        face_count = 0
        if extract_faces and num_frames > 0:
            # Sorted for a deterministic processing order across filesystems.
            for frame_file in sorted(os.listdir(frames_dir)):
                if not frame_file.endswith(('.jpg', '.png', '.jpeg')):
                    continue

                frame_path = os.path.join(frames_dir, frame_file)
                face_path = os.path.join(faces_dir, frame_file)

                # Count from the returned crops: multi-face frames are saved as
                # "<stem>_<i>.jpg", so checking face_path on disk would miss them.
                faces = self.detect_faces(frame_path, face_path, return_image=True)
                if faces:
                    face_count += 1

        return num_frames, face_count

In [None]:
# Cell 6: Dataset Preparation
class DeepfakeDatasetPreparation:
    """Prepare FaceForensics++ and Celeb-DF datasets.

    Videos are sampled, decoded into frames and face crops by VideoProcessor,
    and written under ``<dataset root>/processed/<category>/<video>/{frames,faces}``.
    ``create_dataset_csv`` then indexes the face crops into train/val/test CSVs,
    splitting by *video* so frames of one clip never appear in two splits.
    """

    VIDEO_EXTENSIONS = (".mp4", ".avi")
    IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg")
    CSV_COLUMNS = ["dataset", "path", "label", "method", "video"]
    # Lower-cased category folder under ``processed`` -> (label, method).
    # Label convention: 0 = real, 1 = fake.
    CATEGORY_LABELS = {
        "original": (0, "Original"),
        "real": (0, "Original"),
        "deepfakes": (1, "Deepfakes"),
        "face2face": (1, "Face2Face"),
        "faceswap": (1, "FaceSwap"),
        "neuraltextures": (1, "NeuralTextures"),
        "fake": (1, "Celeb-DF"),
    }

    def __init__(self, config, device='cuda'):
        self.config = config
        self.device = device
        self.processor = VideoProcessor(config, device)

    # ------------------------------------------------------------------ helpers
    def _list_videos(self, directory):
        """Video files directly inside ``directory``, sorted so seeded shuffles are reproducible."""
        return [os.path.join(directory, f) for f in sorted(os.listdir(directory))
                if f.endswith(self.VIDEO_EXTENSIONS)]

    def _process_videos(self, video_paths, output_root):
        """Run VideoProcessor on each video, writing to ``output_root/<video file name>``."""
        for video_path in tqdm(video_paths):
            video_name = os.path.basename(video_path)
            self.processor.process_video(video_path, os.path.join(output_root, video_name))

    def _discover_celebdf_videos(self, celeb_path, output_base):
        """Split Celeb-DF videos into (real, fake) lists from their folder names."""
        real_videos, fake_videos = [], []

        for root, dirs, files in os.walk(celeb_path):
            # Deterministic traversal; never descend into our own output folder.
            dirs[:] = sorted(d for d in dirs if os.path.join(root, d) != output_base)
            # Classify on the path *relative* to celeb_path: the absolute path
            # contains "deepfake_detection", which would match "fake" everywhere.
            rel_root = os.path.relpath(root, celeb_path).lower()
            for file in sorted(files):
                if not file.endswith(self.VIDEO_EXTENSIONS):
                    continue
                file_path = os.path.join(root, file)
                if "real" in rel_root:
                    real_videos.append(file_path)
                elif "fake" in rel_root or "synthesis" in rel_root:
                    fake_videos.append(file_path)

        # If structured differently, fall back to the list file some releases ship.
        if not real_videos or not fake_videos:
            list_file = os.path.join(celeb_path, "List_of_testing_videos.txt")
            if os.path.exists(list_file):
                known = set(real_videos) | set(fake_videos)
                with open(list_file, "r") as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) < 2:
                            continue
                        video_path = os.path.join(celeb_path, parts[1])
                        # Skip videos already found by the walk (avoids duplicates).
                        if video_path in known or not os.path.exists(video_path):
                            continue
                        known.add(video_path)
                        if "real" in parts[1].lower():
                            real_videos.append(video_path)
                        else:
                            fake_videos.append(video_path)

        return real_videos, fake_videos

    def _label_and_method(self, category):
        """Map a category folder name to (label, method); unknown folders count as fake."""
        return self.CATEGORY_LABELS.get(category.lower(), (1, category))

    def _index_face_crops(self, dataset_name, processed_dir):
        """One record per face crop found at ``<processed_dir>/<category>/<video>/faces/*``."""
        records = []
        for root, dirs, files in os.walk(processed_dir):
            dirs.sort()
            rel_parts = os.path.relpath(root, processed_dir).split(os.sep)
            # Only the "faces" folders hold samples; "frames" holds the uncropped
            # source frames and would duplicate every sample.
            if len(rel_parts) != 3 or rel_parts[2] != "faces":
                continue
            category, video = rel_parts[0], rel_parts[1]
            label, method = self._label_and_method(category)
            for file in sorted(files):
                if file.lower().endswith(self.IMAGE_EXTENSIONS):
                    records.append({
                        "dataset": dataset_name,
                        "path": os.path.join(root, file),
                        "label": label,
                        "method": method,
                        "video": f"{category}/{video}",
                    })
        return records

    def _split_by_video(self, records):
        """Assign whole videos to train/val/test so no clip leaks across splits.

        NOTE(review): FF++ manipulated clips are presumably derived from the
        original clips' identities; a stricter split would also group by source
        identity -- TODO confirm whether that is required.
        """
        video_keys = sorted({(r["dataset"], r["video"]) for r in records})
        random.shuffle(video_keys)

        n_videos = len(video_keys)
        n_train = int(n_videos * self.config.train_ratio)
        n_val = int(n_videos * self.config.val_ratio)

        split_of = {}
        for i, key in enumerate(video_keys):
            if i < n_train:
                split_of[key] = "train"
            elif i < n_train + n_val:
                split_of[key] = "val"
            else:
                split_of[key] = "test"  # remainder, so test_ratio is implicit

        splits = {"train": [], "val": [], "test": []}
        for record in records:
            splits[split_of[(record["dataset"], record["video"])]].append(record)
        for rows in splits.values():
            random.shuffle(rows)
        return splits["train"], splits["val"], splits["test"]

    # --------------------------------------------------------------- public API
    def prepare_ff_dataset(self):
        """Prepare FaceForensics++: process up to ``ff_samples_per_method`` videos per category."""
        print("Preparing FaceForensics++ dataset...")

        real_dir = os.path.join(self.config.ff_path, "original_sequences", "youtube",
                                self.config.ff_compression, "videos")

        output_base = os.path.join(self.config.ff_path, "processed")
        os.makedirs(output_base, exist_ok=True)

        # Process real videos
        if os.path.exists(real_dir):
            real_videos = self._list_videos(real_dir)
            random.shuffle(real_videos)
            real_videos = real_videos[:self.config.ff_samples_per_method]

            print(f"Processing {len(real_videos)} real videos...")
            self._process_videos(real_videos, os.path.join(output_base, "original"))
        else:
            print(f"Directory not found: {real_dir}")

        # Process fake videos for each method
        for method in self.config.ff_methods:
            fake_dir = os.path.join(self.config.ff_path, "manipulated_sequences",
                                    method, self.config.ff_compression, "videos")

            if not os.path.exists(fake_dir):
                print(f"Directory not found: {fake_dir}")
                continue

            fake_videos = self._list_videos(fake_dir)
            random.shuffle(fake_videos)
            fake_videos = fake_videos[:self.config.ff_samples_per_method]

            print(f"Processing {len(fake_videos)} {method} fake videos...")
            self._process_videos(fake_videos, os.path.join(output_base, method))

        print("FaceForensics++ dataset preparation complete.")

    def prepare_celebdf_dataset(self, version="v2"):
        """Prepare Celeb-DF: process up to ``celebdf_samples // 2`` real and fake videos each."""
        print(f"Preparing Celeb-DF {version} dataset...")

        if version == "v2":
            celeb_path = self.config.celebdf_v2_path
        else:
            celeb_path = self.config.celebdf_path

        if not os.path.exists(celeb_path):
            print(f"Celeb-DF {version} path not found: {celeb_path}")
            return

        output_base = os.path.join(celeb_path, "processed")
        os.makedirs(output_base, exist_ok=True)

        real_videos, fake_videos = self._discover_celebdf_videos(celeb_path, output_base)

        per_class = self.config.celebdf_samples // 2
        random.shuffle(real_videos)
        random.shuffle(fake_videos)
        real_videos = real_videos[:per_class]
        fake_videos = fake_videos[:per_class]

        print(f"Processing {len(real_videos)} real Celeb-DF videos...")
        self._process_videos(real_videos, os.path.join(output_base, "real"))

        print(f"Processing {len(fake_videos)} fake Celeb-DF videos...")
        self._process_videos(fake_videos, os.path.join(output_base, "fake"))

        print(f"Celeb-DF {version} dataset preparation complete.")

    def create_dataset_csv(self):
        """Create train/val/test CSV files indexing the processed face crops.

        CSVs are written to ``config.base_path`` as ``{train,val,test}_data.csv``.

        Returns:
            (train_df, val_df, test_df) with columns dataset, path,
            label (0 = real, 1 = fake), method and video.
        """
        print("Creating dataset CSV files...")

        datasets = [
            {"name": "ff++", "path": os.path.join(self.config.ff_path, "processed")},
            {"name": "celebdf", "path": os.path.join(self.config.celebdf_v2_path, "processed")}
        ]

        all_data = []
        for dataset in datasets:
            if not os.path.exists(dataset["path"]):
                print(f"Dataset path not found: {dataset['path']}")
                continue
            all_data.extend(self._index_face_crops(dataset["name"], dataset["path"]))

        train_data, val_data, test_data = self._split_by_video(all_data)

        # Explicit columns keep the CSV schema stable even when a split is empty.
        train_df = pd.DataFrame(train_data, columns=self.CSV_COLUMNS)
        val_df = pd.DataFrame(val_data, columns=self.CSV_COLUMNS)
        test_df = pd.DataFrame(test_data, columns=self.CSV_COLUMNS)

        train_df.to_csv(os.path.join(self.config.base_path, "train_data.csv"), index=False)
        val_df.to_csv(os.path.join(self.config.base_path, "val_data.csv"), index=False)
        test_df.to_csv(os.path.join(self.config.base_path, "test_data.csv"), index=False)

        print(f"Dataset CSV files created: {len(train_data)} train, {len(val_data)} validation, {len(test_data)} test samples")

        return train_df, val_df, test_df

In [None]:

# Cell 7: Dataset and Dataloader
class DeepfakeDataset(Dataset):
    """Dataset of face-crop images for deepfake detection.

    Args:
        data_df: DataFrame with ``path``, ``label``, ``method`` and ``dataset`` columns.
        transform: an albumentations ``A.Compose`` (called with ``image=<ndarray>``)
            or any torchvision-style callable that takes a PIL image.
        image_size: side length of the placeholder image/tensor substituted when a
            sample cannot be loaded or transformed. It should match the transforms'
            output size so placeholders collate with real samples (default 224).
    """
    def __init__(self, data_df, transform=None, image_size=224):
        self.data_df = data_df
        self.transform = transform
        self.image_size = image_size

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, idx):
        row = self.data_df.iloc[idx]

        # Load image; the context manager closes the file handle deterministically.
        try:
            with Image.open(row['path']) as img:
                image = img.convert('RGB')
        except Exception as e:
            # Best-effort: log and substitute a black image so one bad file does not stop the epoch.
            print(f"Error loading image {row['path']}: {e}")
            image = Image.new('RGB', (self.image_size, self.image_size), color=0)

        # Apply transforms
        if self.transform:
            try:
                if isinstance(self.transform, A.Compose):
                    # Albumentations expects an HWC ndarray and returns a dict.
                    transformed = self.transform(image=np.array(image))
                    image = transformed['image']
                else:
                    # torchvision-style transforms take the PIL image directly.
                    image = self.transform(image)
            except Exception as e:
                print(f"Error applying transform: {e}")
                # Return a default tensor in case of error
                image = torch.zeros((3, self.image_size, self.image_size))

        return {
            'image': image,
            'label': int(row['label']),
            'method': row['method'],
            'dataset': row['dataset'],
            'path': row['path']
        }

def get_data_transforms(config):
    """Build the albumentations pipelines used for training and for evaluation.

    Args:
        config: object exposing ``image_size``.

    Returns:
        (train_transform, val_transform). Both end with ImageNet normalisation
        followed by ``ToTensorV2``, so they emit CHW tensors.

    NOTE(review): positional (height, width) for RandomResizedCrop and
    quality_lower/quality_upper for ImageCompression follow the albumentations
    1.3.x API pinned in Cell 1; newer releases renamed these arguments --
    confirm against the installed version.
    """
    side = config.image_size
    # ImageNet statistics, matching the pretrained backbone's preprocessing.
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    crop = A.RandomResizedCrop(side, side, scale=(0.8, 1.0))
    flip = A.HorizontalFlip(p=0.5)
    colour_ops = A.OneOf(
        [
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
            A.CLAHE(p=0.5),
        ],
        p=0.5,
    )
    degradation_ops = A.OneOf(
        [
            A.GaussNoise(p=0.5),
            A.GaussianBlur(blur_limit=3, p=0.5),
            A.ImageCompression(quality_lower=50, quality_upper=100, p=0.5),
        ],
        p=0.5,
    )
    texture_ops = A.OneOf(
        [
            A.Sharpen(p=0.5),
            A.Emboss(p=0.5),
            A.MotionBlur(p=0.5),
        ],
        p=0.3,
    )

    train_steps = [
        crop,
        flip,
        colour_ops,
        degradation_ops,
        texture_ops,
        A.RandomBrightnessContrast(p=0.3),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    eval_steps = [
        A.Resize(side, side),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]

    return A.Compose(train_steps), A.Compose(eval_steps)

def create_dataloaders(config, train_df, val_df, test_df):
    """Wrap the three split DataFrames in DataLoaders.

    The training loader shuffles and drops the final partial batch; the
    validation and test loaders keep the original order and every sample.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    train_transform, eval_transform = get_data_transforms(config)

    def _make_loader(frame, transform, training):
        # Shared loader settings; only shuffling / drop_last depend on the split.
        return DataLoader(
            DeepfakeDataset(frame, transform=transform),
            batch_size=config.batch_size,
            shuffle=training,
            num_workers=config.num_workers,
            pin_memory=True,
            drop_last=training,
        )

    loaders = (
        _make_loader(train_df, train_transform, True),
        _make_loader(val_df, eval_transform, False),
        _make_loader(test_df, eval_transform, False),
    )
    return loaders


In [None]:
# Cell 8: Model Architecture
class DeiTForDeepfakeDetection(nn.Module):
    """DeiT backbone + learned attention pooling over tokens + MLP classifier.

    The backbone's token sequence (from ``forward_features``) is pooled with a
    small attention MLP (softmax over tokens) into one ``feature_dim`` vector,
    which the classifier maps to ``config.num_classes`` logits.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Load base model; num_classes=0 removes timm's classification head.
        self.base_model = create_model(
            config.model_name,
            pretrained=config.pretrained,
            num_classes=0  # Remove classifier
        )

        # Get feature dimension
        self.feature_dim = self.base_model.num_features

        # Attention pooling: one scalar score per token.
        self.attention = nn.Sequential(
            nn.Linear(self.feature_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 1)
        )

        # Add classifier
        self.classifier = nn.Sequential(
            nn.LayerNorm(self.feature_dim),
            nn.Dropout(0.3),
            nn.Linear(self.feature_dim, 512),
            nn.GELU(),
            nn.LayerNorm(512),
            nn.Dropout(0.2),
            nn.Linear(512, config.num_classes)
        )

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Initialize newly added layers (backbone weights are left untouched)."""
        for module in [self.attention, self.classifier]:
            for m in module.modules():
                if isinstance(m, nn.Linear):
                    nn.init.trunc_normal_(m.weight, std=0.02)
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.LayerNorm):
                    nn.init.constant_(m.bias, 0)
                    nn.init.constant_(m.weight, 1.0)

    def _token_features(self, x):
        """Return backbone features as a (B, T, D) token sequence."""
        features = self.base_model.forward_features(x)
        if isinstance(features, (tuple, list)):
            # Some backbones return several heads' features; use the first.
            features = features[0]
        if features.dim() == 4:
            # CNN-style (B, C, H, W) map -> (B, H*W, C) tokens.
            features = features.flatten(2).transpose(1, 2)
        elif features.dim() == 2:
            # Already pooled (B, D) -> a single token.
            features = features.unsqueeze(1)
        return features

    def forward(self, x, return_features=False):
        """
        Args:
            x: image batch accepted by the backbone.
            return_features: also return the pooled embedding and attention weights.

        Returns:
            logits (B, num_classes), or (logits, pooled (B, D), weights (B, T, 1)).
        """
        tokens = self._token_features(x)

        # Pool over the token axis. Pooling the already-pooled (B, D) vector
        # would softmax over a size-1 axis and collapse the feature dimension.
        attention_weights = F.softmax(self.attention(tokens), dim=1)   # (B, T, 1)
        attended_features = (tokens * attention_weights).sum(dim=1)    # (B, D)

        logits = self.classifier(attended_features)

        if return_features:
            return logits, attended_features, attention_weights
        return logits

In [None]:
# Cell 9: Training and Evaluation Functions
class DeepfakeTrainer:
    """Trainer for the deepfake detection model"""
    def __init__(self, model, train_loader, val_loader, test_loader, config, device):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.config = config
        self.device = device

        # Move model to device
        self.model.to(self.device)

        # Optimizer
        self.optimizer = optim.AdamW(
            self.model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        # Learning rate scheduler
        self.scheduler = self._get_scheduler()

        # Loss function
        self.criterion = nn.CrossEntropyLoss(label_smoothing=config.label_smoothing)

        # Gradient scaler for mixed precision training
        self.scaler = GradScaler() if config.mixed_precision else None

        # Best metrics for model saving
        self.best_val_auc = 0.0
        self.best_epoch = 0

    def _get_scheduler(self):
        """Get learning rate scheduler"""
        if self.config.warmup_epochs > 0:
            # LinearWarmupCosineAnnealingLR
            return optim.lr_scheduler.OneCycleLR(
                self.optimizer,
                max_lr=self.config.learning_rate,
                total_steps=self.config.num_epochs * len(self.train_loader),
                pct_start=self.config.warmup_epochs / self.config.num_epochs,
                div_factor=25.0,
                final_div_factor=1000.0,
                anneal_strategy='cos'
            )
        else:
            # CosineAnnealingLR
            return optim.lr_scheduler.CosineAnnealingLR(
                self.optimizer,
                T_max=self.config.num_epochs
            )

    def mixup_data(self, x, y, alpha=0.2):
        """Apply mixup augmentation"""
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = x.size()[0]
        index = torch.randperm(batch_size).to(self.device)

        mixed_x = lam * x + (1 - lam) * x[index, :]
        y_a, y_b = y, y[index]

        return mixed_x, y_a, y_b, lam

    def train_epoch(self, epoch):
        """Train for one epoch"""
        self.model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, batch in enumerate(tqdm(self.train_loader, desc=f"Epoch {epoch+1} Training")):
            x = batch['image'].to(self.device)
            y = batch['label'].to(self.device)

            # Apply mixup if configured
            do_mixup = self.config.mixup_alpha > 0 and np.random.random() < 0.5
            if do_mixup:
                x, y_a, y_b, lam = self.mixup_data(x, y, self.config.mixup_alpha)
                y_a, y_b = y_a.long(), y_b.long()

            # Gradient accumulation steps
            self.optimizer.zero_grad()

            if self.config.mixed_precision:
                with autocast():
                    outputs = self.model(x)

                    if do_mixup:
                        loss = lam * self.criterion(outputs, y_a) + (1 - lam) * self.criterion(outputs, y_b)
                    else:
                        loss = self.criterion(outputs, y)

                # Scale the loss and backpropagate
                self.scaler.scale(loss).backward()

                # Clip gradients if configured
                if self.config.gradient_clip_val > 0:
                    self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_val)

                # Update weights
                self.scaler.step(self.optimizer)
                self.scaler.update()
            else:
                # Standard training without mixed precision
                outputs = self.model(x)

                if do_mixup:
                    loss = lam * self.criterion(outputs, y_a) + (1 - lam) * self.criterion(outputs, y_b)
                else:
                    loss = self.criterion(outputs, y)

                # Backward pass
                loss.backward()

                # Clip gradients if configured
                if self.config.gradient_clip_val > 0:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.gradient_clip_val)

                # Update weights
                self.optimizer.step()

            # Update learning rate if step-based scheduler
            if isinstance(self.scheduler, optim.lr_scheduler.OneCycleLR):
                self.scheduler.step()

            # Update metrics
            train_loss += loss.item()

            if not do_mixup:
                _, predicted = outputs.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()

        # Calculate epoch metrics
        train_loss = train_loss / len(self.train_loader)
        train_acc = correct / total if total > 0 else 0

        return train_loss, train_acc

    def validate(self, loader, is_test=False):
        """Validate or test the model"""
        self.model.eval()
        val_loss = 0.0

        all_preds = []
        all_probs = []
        all_labels = []
        all_methods = []
        all_datasets = []

        with torch.no_grad():
            for batch in tqdm(loader, desc="Testing" if is_test else "Validation"):
                x = batch['image'].to(self.device)
                y = batch['label'].to(self.device)
                methods = batch['method']
                datasets = batch['dataset']

                # Forward pass
                outputs = self.model(x)
                loss = self.criterion(outputs, y)

                # Calculate probabilities
                probs = F.softmax(outputs, dim=1)

                # Get predictions
                _, preds = torch.max(outputs, 1)

                # Update metrics
                val_loss += loss.item()

                # Store predictions and labels
                all_preds.extend(preds.cpu().numpy())
                all_probs.extend(probs[:, 1].cpu().numpy())  # Probability of fake class
                all_labels.extend(y.cpu().numpy())
                all_methods.extend(methods)
                all_datasets.extend(datasets)

        # Calculate metrics
        val_loss = val_loss / len(loader)
        metrics = self.calculate_metrics(all_labels, all_preds, all_probs)

        # Calculate per-method and per-dataset metrics
        method_metrics = self.calculate_group_metrics(all_labels, all_preds, all_probs, all_methods)
        dataset_metrics = self.calculate_group_metrics(all_labels, all_preds, all_probs, all_datasets)

        return val_loss, metrics, method_metrics, dataset_metrics

    def calculate_metrics(self, y_true, y_pred, y_prob):
        """Compute binary-classification metrics for one set of predictions.

        Args:
            y_true: ground-truth labels (0 = real, 1 = fake).
            y_pred: predicted labels.
            y_prob: predicted probability of the fake class (label 1).

        Returns:
            dict with 'accuracy', 'precision', 'recall', 'f1', 'auc' (floats) and
            confusion-matrix counts 'tn', 'fp', 'fn', 'tp' (ints). 'auc' falls back
            to 0.5 when ``y_true`` contains a single class (ROC AUC is undefined).
            Every value is a plain Python number, so the dict can be passed to
            ``json.dump`` and stored in checkpoints without NumPy scalar types.
        """
        has_both_classes = len(set(y_true)) > 1

        # sklearn returns NumPy scalars; cast to builtins so downstream json.dump
        # (see test()) does not fail with "int64 is not JSON serializable".
        metrics = {
            'accuracy': float(accuracy_score(y_true, y_pred)),
            'precision': float(precision_score(y_true, y_pred, zero_division=0)),
            'recall': float(recall_score(y_true, y_pred, zero_division=0)),
            'f1': float(f1_score(y_true, y_pred, zero_division=0)),
            'auc': float(roc_auc_score(y_true, y_prob)) if has_both_classes else 0.5
        }

        # labels=[0, 1] keeps the matrix 2x2 even if a class is absent from this subset
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
        metrics['tn'] = int(tn)
        metrics['fp'] = int(fp)
        metrics['fn'] = int(fn)
        metrics['tp'] = int(tp)

        return metrics

    def calculate_group_metrics(self, y_true, y_pred, y_prob, groups):
        """Compute ``calculate_metrics`` separately for each distinct group label.

        Args:
            y_true, y_pred, y_prob: per-sample labels, predictions and fake-class
                probabilities (parallel sequences).
            groups: per-sample group label (e.g. manipulation method or dataset name).

        Returns:
            dict mapping group -> metrics dict plus 'count'. A group whose metric
            computation raises is recorded as ``{'count': n, 'error': message}``
            instead (best-effort, so one bad group does not abort validation).
            Groups appear in order of first occurrence in ``groups``.
        """
        # Bucket sample indices in a single pass (the old per-group scan was O(n * groups)).
        # A dict keeps first-seen order; iterating set(groups) gave a hash-randomized,
        # run-to-run varying order for string labels.
        buckets = {}
        for idx, group in enumerate(groups):
            buckets.setdefault(group, []).append(idx)

        group_metrics = {}
        for group, indices in buckets.items():
            group_true = [y_true[i] for i in indices]
            group_pred = [y_pred[i] for i in indices]
            group_prob = [y_prob[i] for i in indices]

            try:
                metrics = self.calculate_metrics(group_true, group_pred, group_prob)
                metrics['count'] = len(indices)
                group_metrics[group] = metrics
            except Exception as e:
                print(f"Error calculating metrics for group {group}: {e}")
                group_metrics[group] = {'count': len(indices), 'error': str(e)}

        return group_metrics

    def train(self):
        """Run the training loop and return the per-epoch history.

        Each epoch runs ``self.train_epoch(epoch)`` and then ``self.validate(self.val_loader)``.
        It steps ``self.scheduler`` only when it is a ``CosineAnnealingLR``; other
        schedulers are not stepped here. It then writes checkpoints to
        ``config.checkpoint_dir``:

        * ``best_model.pth`` whenever validation AUC beats ``self.best_val_auc``;
        * ``checkpoint_epoch_{N}.pth`` every ``config.save_frequency`` epochs
          (a missing or non-positive frequency disables these);
        * ``final_model.pth`` after the last epoch.

        Returns:
            dict of per-epoch lists: 'train_loss', 'train_acc', 'val_loss',
            'val_metrics', 'method_metrics', 'dataset_metrics'.
        """
        print(f"Starting training for {self.config.num_epochs} epochs...")

        # Make sure torch.save has a directory to write into.
        os.makedirs(self.config.checkpoint_dir, exist_ok=True)

        def save_checkpoint(filename, epoch, val_metrics):
            """Write model/optimizer state plus metadata to checkpoint_dir/filename."""
            torch.save({
                'epoch': epoch,
                'model_state_dict': self.model.state_dict(),
                'optimizer_state_dict': self.optimizer.state_dict(),
                'val_metrics': val_metrics,
                'config': self.config.__dict__
            }, os.path.join(self.config.checkpoint_dir, filename))

        history = {
            'train_loss': [], 'train_acc': [],
            'val_loss': [], 'val_metrics': [],
            'method_metrics': [], 'dataset_metrics': []
        }
        save_frequency = self.config.save_frequency

        for epoch in range(self.config.num_epochs):
            train_loss, train_acc = self.train_epoch(epoch)
            val_loss, val_metrics, method_metrics, dataset_metrics = self.validate(self.val_loader)

            # Epoch-based scheduler; other scheduler types are not stepped in this loop.
            if isinstance(self.scheduler, optim.lr_scheduler.CosineAnnealingLR):
                self.scheduler.step()

            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_metrics'].append(val_metrics)
            history['method_metrics'].append(method_metrics)
            history['dataset_metrics'].append(dataset_metrics)

            print(f"Epoch {epoch+1}/{self.config.num_epochs}:")
            print(f"  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
            print(f"  Val Loss: {val_loss:.4f}, Val AUC: {val_metrics['auc']:.4f}")

            if val_metrics['auc'] > self.best_val_auc:
                self.best_val_auc = val_metrics['auc']
                self.best_epoch = epoch
                save_checkpoint("best_model.pth", epoch, val_metrics)
                print(f"  New best model saved with AUC: {self.best_val_auc:.4f}")

            # Guard against 0/None, which would otherwise raise in the modulo.
            if save_frequency and save_frequency > 0 and (epoch + 1) % save_frequency == 0:
                save_checkpoint(f"checkpoint_epoch_{epoch+1}.pth", epoch, val_metrics)

        # With zero epochs there are no validation metrics; record None instead of IndexError.
        final_metrics = history['val_metrics'][-1] if history['val_metrics'] else None
        save_checkpoint("final_model.pth", self.config.num_epochs - 1, final_metrics)

        print(f"Training completed. Best model at epoch {self.best_epoch+1} with AUC: {self.best_val_auc:.4f}")

        return history

    def test(self, model_path=None):
        """Test the model on the test set"""
        if model_path:
            # Load the model
            checkpoint = torch.load(model_path, map_location=self.device)
            self.model.load_state_dict(checkpoint['model_state_dict'])

        # Run validation on test set
        test_loss, test_metrics, method_metrics, dataset_metrics = self.validate(self.test_loader, is_test=True)

        print(f"Test Results:")
        print(f"  Loss: {test_loss:.4f}")
        print(f"  Accuracy: {test_metrics['accuracy']:.4f}")
        print(f"  AUC: {test_metrics['auc']:.4f}")
        print(f"  F1 Score: {test_metrics['f1']:.4f}")

        # Save results
        results = {
            'test_loss': test_loss,
            'test_metrics': test_metrics,
            'method_metrics': method_metrics,
            'dataset_metrics': dataset_metrics
        }

        results_path = os.path.join(self.config.results_dir, "test_results.json")
        with open(results_path, 'w') as f:
            json.dump(results, f, indent=2)

        return results

In [None]:
# Cell 10: Visualization Functions
def plot_training_history(history):
    """Draw loss, accuracy and AUC curves from the dict returned by ``train()``.

    Three side-by-side panels are produced: train/val loss, train/val accuracy,
    and validation AUC. ``history['val_metrics']`` is a list of per-epoch metric
    dicts providing 'accuracy' and 'auc'.
    """
    _, (loss_ax, acc_ax, auc_ax) = plt.subplots(1, 3, figsize=(15, 5))

    per_epoch = history['val_metrics']
    # (axes, [(series, legend label), ...], y-label, title)
    panels = [
        (loss_ax,
         [(history['train_loss'], 'Train Loss'), (history['val_loss'], 'Val Loss')],
         'Loss', 'Training and Validation Loss'),
        (acc_ax,
         [(history['train_acc'], 'Train Acc'), ([m['accuracy'] for m in per_epoch], 'Val Acc')],
         'Accuracy', 'Training and Validation Accuracy'),
        (auc_ax,
         [([m['auc'] for m in per_epoch], 'Val AUC')],
         'AUC', 'Validation AUC'),
    ]

    for ax, series, y_label, title in panels:
        for values, label in series:
            ax.plot(values, label=label)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.legend()

    plt.tight_layout()
    plt.show()

def plot_confusion_matrix(y_true, y_pred, classes=['Real', 'Fake']):
    """Render the 2x2 confusion matrix as an annotated heatmap.

    Rows are actual labels and columns are predicted labels, both in order
    [0, 1]. ``classes`` supplies the tick names for label 0 and label 1.
    """
    counts = confusion_matrix(y_true, y_pred, labels=[0, 1])

    _, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(counts, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title('Confusion Matrix')
    plt.show()

def plot_roc_curve(y_true, y_score):
    """Plot the ROC curve for binary labels and positive-class scores.

    The legend reports the area under the curve. A dashed diagonal marks
    chance-level performance.
    """
    false_pos, true_pos, _ = roc_curve(y_true, y_score)
    area = auc(false_pos, true_pos)

    _, ax = plt.subplots(figsize=(8, 6))
    ax.plot(false_pos, true_pos, color='darkorange', lw=2, label=f'ROC curve (AUC = {area:.3f})')
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance diagonal
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    plt.show()

def plot_method_performance(method_metrics):
    """Bar-plot accuracy and AUC for each manipulation method (or any group).

    Args:
        method_metrics: mapping group name -> metrics dict, as returned by
            ``calculate_group_metrics``. Entries without 'accuracy'/'auc' are
            skipped with a message, e.g. the ``{'count', 'error'}`` placeholders
            written when a group's metrics failed. Indexing them directly used to
            raise KeyError.

    Note:
        ``calculate_metrics`` reports AUC as 0.5 for a group whose samples all
        share one label. Such bars are a placeholder, not a measured AUC.
    """
    scored = {name: m for name, m in method_metrics.items() if 'accuracy' in m and 'auc' in m}
    skipped = [name for name in method_metrics if name not in scored]
    if skipped:
        print(f"Skipping groups without metrics: {skipped}")
    if not scored:
        print("No per-method metrics to plot")
        return

    methods = list(scored)
    accs = [scored[m]['accuracy'] for m in methods]
    aucs = [scored[m]['auc'] for m in methods]

    plt.figure(figsize=(12, 6))

    plt.subplot(1, 2, 1)
    plt.bar(methods, accs, color='skyblue')
    plt.ylim([0, 1])
    plt.xlabel('Method')
    plt.ylabel('Accuracy')
    plt.title('Accuracy by Manipulation Method')
    plt.xticks(rotation=45)

    plt.subplot(1, 2, 2)
    plt.bar(methods, aucs, color='lightgreen')
    plt.ylim([0, 1])
    plt.xlabel('Method')
    plt.ylabel('AUC')
    plt.title('AUC by Manipulation Method')
    plt.xticks(rotation=45)

    plt.tight_layout()
    plt.show()

def _normalized_attention_grid(attention):
    """Min-max scale attention scores to [0, 1] and fold them into a square patch grid.

    A constant input maps to all zeros instead of dividing by zero (which
    produced NaNs). Raises ValueError when the number of scores is not a
    perfect square, for instance if non-patch tokens are still included.
    """
    scores = np.asarray(attention, dtype=np.float32).reshape(-1)
    side = int(round(float(np.sqrt(scores.size))))
    if scores.size == 0 or side * side != scores.size:
        raise ValueError(
            f"Cannot reshape {scores.size} attention scores into a square patch grid; "
            "expected one score per patch (e.g. 196 for a 14x14 grid)"
        )
    low, high = scores.min(), scores.max()
    if high > low:
        scores = (scores - low) / (high - low)
    else:
        scores = np.zeros_like(scores)
    return scores.reshape(side, side)


def visualize_attention(model, image_path, transform, device):
    """Show an image, the model's attention map, and the two overlaid.

    Assumes ``model(tensor, return_features=True)`` returns
    ``(logits, features, attention_weights)``, where ``attention_weights`` holds
    one score per image patch for the single input image. TODO: confirm against
    the model definition. The grid side is inferred from the number of scores
    instead of being hard-coded to 14x14. Images are shown at the transformed
    tensor's spatial size (224x224 for the default transforms).
    """
    image = Image.open(image_path).convert('RGB')
    rgb = np.array(image)

    # Albumentations pipelines take a keyword numpy image; other transforms get the PIL image.
    if isinstance(transform, A.Compose):
        tensor = transform(image=rgb)['image']
    else:
        tensor = transform(image)
    tensor = tensor.unsqueeze(0).to(device)
    out_h, out_w = (int(s) for s in tensor.shape[-2:])

    model.eval()
    with torch.no_grad():
        logits, _, attention_weights = model(tensor, return_features=True)
        probs = F.softmax(logits, dim=1)
        pred_class = torch.argmax(probs, dim=1).item()
        pred_prob = probs[0, pred_class].item()
        attention = attention_weights.squeeze().cpu().numpy()

    # cv2.resize takes (width, height); bilinear interpolation keeps values in [0, 1].
    attention_map = cv2.resize(_normalized_attention_grid(attention), (out_w, out_h))
    resized = cv2.resize(rgb, (out_w, out_h))

    heatmap = cv2.applyColorMap(np.uint8(255 * attention_map), cv2.COLORMAP_JET)
    heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    alpha = 0.5
    overlay = alpha * heatmap + (1 - alpha) * resized

    plt.figure(figsize=(15, 5))

    plt.subplot(1, 3, 1)
    plt.imshow(resized)
    plt.title("Original Image")
    plt.axis('off')

    plt.subplot(1, 3, 2)
    plt.imshow(attention_map, cmap='hot')
    plt.title("Attention Map")
    plt.axis('off')

    plt.subplot(1, 3, 3)
    plt.imshow(overlay.astype(np.uint8))
    plt.title(f"Overlay (Pred: {'Fake' if pred_class == 1 else 'Real'}, {pred_prob:.2f})")
    plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Cell 11: Model Deployment Utils
def export_to_onnx(model, save_path, input_shape=(1, 3, 224, 224), device='cuda'):
    """Export ``model`` to an ONNX file at ``save_path`` using a random dummy batch.

    The graph is exported with opset 13, constant folding and embedded weights.
    Tensors are named 'input'/'output', and their first dimension ('batch_size')
    is dynamic.
    """
    model.eval()
    example_batch = torch.randn(input_shape).to(device)

    export_options = dict(
        export_params=True,
        opset_version=13,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}},
    )
    torch.onnx.export(model, example_batch, save_path, **export_options)

    print(f"Model exported to ONNX format: {save_path}")

def inference_pipeline(image_path, model, transform, device='cuda'):
    """Classify a single image file as 'Real' or 'Fake'.

    Returns:
        dict with 'path', 'prediction', 'confidence' (probability of the predicted
        class) and 'fake_probability' (probability of class 1).
    """
    rgb = Image.open(image_path).convert('RGB')

    # Albumentations pipelines take a keyword numpy image; other transforms get the PIL image.
    if isinstance(transform, A.Compose):
        batch = transform(image=np.array(rgb))['image']
    else:
        batch = transform(rgb)
    batch = batch.unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        class_probs = F.softmax(model(batch), dim=1)[0]
    label_idx = torch.argmax(class_probs).item()

    return {
        'path': image_path,
        'prediction': 'Fake' if label_idx == 1 else 'Real',
        'confidence': class_probs[label_idx].item(),
        'fake_probability': class_probs[1].item()
    }

In [None]:
# Cell 12: Main Execution Script
def main():
    """Run the end-to-end pipeline.

    Steps: load the CSV splits, build dataloaders and the model, train,
    evaluate the best checkpoint on the test split, plot results, and export
    the model to ONNX.

    Returns:
        (model, history, test_results): the model with the best checkpoint's
        weights loaded, the per-epoch history from ``trainer.train()``, and
        the dict returned by ``trainer.test()``.
    """
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Configuration
    config = Config()

    # Step 1: Dataset Preparation (optional, can be commented out if already done)
    # NOTE: the block below is a bare string literal, so it never executes.
    # Move it out of the quotes to re-run preprocessing.
    """
    print("Step 1: Preparing datasets...")
    preparation = DeepfakeDatasetPreparation(config, device)

    # Process FaceForensics++
    preparation.prepare_ff_dataset()

    # Process Celeb-DF
    preparation.prepare_celebdf_dataset(version="v2")

    # Create dataset CSV files
    train_df, val_df, test_df = preparation.create_dataset_csv()
    """

    # Step 2: Load dataset CSVs
    # Expects train/val/test CSVs under config.base_path (presumably the files
    # written by create_dataset_csv in Step 1; confirm).
    print("Step 2: Loading dataset CSVs...")
    train_df = pd.read_csv(os.path.join(config.base_path, "train_data.csv"))
    val_df = pd.read_csv(os.path.join(config.base_path, "val_data.csv"))
    test_df = pd.read_csv(os.path.join(config.base_path, "test_data.csv"))

    print(f"Dataset loaded: {len(train_df)} train, {len(val_df)} validation, {len(test_df)} test samples")

    # Step 3: Create dataloaders
    print("Step 3: Creating dataloaders...")
    train_loader, val_loader, test_loader = create_dataloaders(config, train_df, val_df, test_df)

    # Step 4: Create model
    print("Step 4: Creating model...")
    model = DeiTForDeepfakeDetection(config).to(device)

    # Print model summary
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model created with {total_params:,} total parameters, {trainable_params:,} trainable")

    # Step 5: Train model
    print("Step 5: Training model...")
    trainer = DeepfakeTrainer(model, train_loader, val_loader, test_loader, config, device)
    history = trainer.train()

    # Step 6: Evaluate on test set
    # NOTE(review): train() only writes best_model.pth when validation AUC beats
    # trainer.best_val_auc. If that never happens, this file will not exist.
    print("Step 6: Evaluating on test set...")
    best_model_path = os.path.join(config.checkpoint_dir, "best_model.pth")
    test_results = trainer.test(best_model_path)

    # Step 7: Visualize results
    print("Step 7: Visualizing results...")
    plot_training_history(history)

    # Get test predictions
    # NOTE(review): checkpoints also pickle 'val_metrics' and 'config'. Recent
    # PyTorch versions default torch.load to weights_only=True, which may reject
    # non-tensor entries such as NumPy scalars. Confirm for the installed version.
    checkpoint = torch.load(best_model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Re-run inference to collect raw labels/predictions for the plots
    # (trainer.test returns only aggregated metrics).
    all_preds = []
    all_probs = []
    all_labels = []

    model.eval()
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Getting test predictions"):
            x = batch['image'].to(device)
            y = batch['label']

            outputs = model(x)
            probs = F.softmax(outputs, dim=1)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            all_probs.extend(probs[:, 1].cpu().numpy())  # probability of class 1 (fake)
            all_labels.extend(y.numpy())

    # Plot confusion matrix
    plot_confusion_matrix(all_labels, all_preds)

    # Plot ROC curve
    plot_roc_curve(all_labels, all_probs)

    # Plot method performance
    plot_method_performance(test_results['method_metrics'])

    # Step 8: Export model for deployment
    print("Step 8: Exporting model for deployment...")
    onnx_path = os.path.join(config.results_dir, "deepfake_detector.onnx")
    export_to_onnx(model, onnx_path, device=device)

    print("Complete pipeline execution finished successfully!")
    return model, history, test_results

# In a notebook kernel __name__ is "__main__", so running this cell executes the
# whole pipeline, including training.
if __name__ == "__main__":
    main()

In [None]:
# Cell 13: Additional Testing and Visualization
def test_on_sample_images(model, config, device='cuda'):
    """Test model on sample images and visualize results"""
    # Load the model if not already loaded
    if isinstance(model, str):
        model = DeiTForDeepfakeDetection(config).to(device)
        checkpoint = torch.load(model, map_location=device)
        model.load_state_dict(checkpoint['model_state_dict'])

    # Get transform
    _, transform = get_data_transforms(config)

    # Test directories
    test_dirs = [
        # FaceForensics++ samples
        os.path.join(config.ff_path, "processed/original"),
        os.path.join(config.ff_path, "processed/Deepfakes"),
        os.path.join(config.ff_path, "processed/Face2Face"),
        os.path.join(config.ff_path, "processed/FaceSwap"),
        os.path.join(config.ff_path, "processed/NeuralTextures"),
        # Celeb-DF samples
        os.path.join(config.celebdf_v2_path, "processed/real"),
        os.path.join(config.celebdf_v2_path, "processed/fake")
    ]

    # Results storage
    results = []

    # Test on random samples from each directory
    for test_dir in test_dirs:
        if not os.path.exists(test_dir):
            continue

        # Find all image files
        image_files = []
        for root, _, files in os.walk(test_dir):
            for file in files:
                if file.endswith((".jpg", ".png", ".jpeg")):
                    image_files.append(os.path.join(root, file))

        if len(image_files) == 0:
            continue

        # Select random samples
        samples = random.sample(image_files, min(5, len(image_files)))

        for image_path in samples:
            # Run inference
            result = inference_pipeline(image_path, model, transform, device)

            # Determine true label from path
            true_label = "Fake"
            if "original" in image_path.lower() or "real" in image_path.lower():
                true_label = "Real"

            # Add to results
            result['true_label'] = true_label
            result['dir'] = os.path.basename(test_dir)
            results.append(result)

            # Visualize attention
            visualize_attention(model, image_path, transform, device)

            # Print result
            print(f"Image: {os.path.basename(image_path)}")
            print(f"Directory: {result['dir']}")
            print(f"True label: {result['true_label']}")
            print(f"Prediction: {result['prediction']} (Confidence: {result['confidence']:.4f})")
            print(f"Fake probability: {result['fake_probability']:.4f}")
            print("-" * 50)

    # Calculate overall accuracy
    correct = sum(1 for r in results if r['prediction'] == r['true_label'])
    accuracy = correct / len(results)

    print(f"Overall accuracy on sample images: {accuracy:.4f} ({correct}/{len(results)})")

    return results

In [None]:
# Cell 14: Cross-Dataset Evaluation
def cross_dataset_evaluation(model, config, device='cuda'):
    """Evaluate a trained detector separately on each source dataset of the test split.

    Reads ``<config.base_path>/test_data.csv`` and splits it by its 'dataset'
    column ('ff++' and 'celebdf'). It evaluates each non-empty subset with
    ``DeepfakeTrainer.validate`` and writes the results to
    ``<config.results_dir>/cross_dataset_results.json``.

    Args:
        model: a loaded model, or a checkpoint path (a ``str``). A path is loaded
            into a fresh ``DeiTForDeepfakeDetection(config)``.
        config: needs base_path, results_dir, batch_size and num_workers.
        device: device to evaluate on.

    Returns:
        dict keyed by 'ff++' and 'celebdf'. Each value holds 'loss', 'metrics' and
        'method_metrics', or is ``{}`` when that dataset has no test rows.
    """
    if isinstance(model, str):
        # Keep the path in its own name. Overwriting `model` first made
        # torch.load receive the module instead of the checkpoint path.
        checkpoint_path = model
        model = DeiTForDeepfakeDetection(config).to(device)
        # weights_only=False: training checkpoints also pickle metadata; only load trusted files.
        checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
        model.load_state_dict(checkpoint['model_state_dict'])

    _, transform = get_data_transforms(config)

    all_data = pd.read_csv(os.path.join(config.base_path, "test_data.csv"))

    # (value of the CSV 'dataset' column, human-readable name for logging)
    targets = [('ff++', 'FaceForensics++'), ('celebdf', 'Celeb-DF')]

    results = {}
    for key, display_name in targets:
        print(f"Evaluating on {display_name}...")
        subset = all_data[all_data['dataset'] == key]
        results[key] = {}
        if len(subset) == 0:
            continue

        loader = DataLoader(
            DeepfakeDataset(subset, transform=transform),
            batch_size=config.batch_size,
            shuffle=False,
            num_workers=config.num_workers,
            pin_memory=True
        )
        trainer = DeepfakeTrainer(model, None, None, loader, config, device)
        loss, metrics, method_metrics, _ = trainer.validate(loader, is_test=True)
        results[key] = {
            'loss': loss,
            'metrics': metrics,
            'method_metrics': method_metrics
        }
        print(f"  Loss: {loss:.4f}, AUC: {metrics['auc']:.4f}")

    def to_json_builtin(obj):
        """json.dump fallback: convert NumPy/PyTorch scalars and arrays to Python values."""
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, (np.ndarray, torch.Tensor)):
            return obj.tolist()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    # Metrics may contain NumPy scalars (e.g. confusion-matrix counts) that json rejects.
    os.makedirs(config.results_dir, exist_ok=True)
    results_path = os.path.join(config.results_dir, "cross_dataset_results.json")
    with open(results_path, 'w') as f:
        json.dump(results, f, indent=2, default=to_json_builtin)

    return results

In [None]:
"""
Utility functions for deepfake detection project
"""

import os
import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
from facenet_pytorch import MTCNN
import albumentations as A
from albumentations.pytorch import ToTensorV2

def setup_environment(seed=42, install_packages=True):
    """Prepare the runtime: mount Google Drive (on Colab), install pinned packages, seed RNGs.

    Args:
        seed: seed for ``random``, NumPy and PyTorch (CPU and all CUDA devices).
            Defaults to 42, the value that used to be hard-coded.
        install_packages: pip-install the pinned package list below. Pass False
            to skip when the environment is already provisioned.

    Returns:
        torch.device: 'cuda' if available, else 'cpu'.
    """
    import random
    import subprocess
    import sys

    # google.colab exists only on Colab; elsewhere skip the mount instead of crashing.
    try:
        from google.colab import drive
    except ImportError:
        print("google.colab not available; skipping Google Drive mount")
    else:
        drive.mount('/content/drive')

    packages = [
        "timm==0.6.12",
        "opencv-python==4.8.0.76",
        "albumentations==1.3.1",
        "facenet-pytorch==2.5.3",
        "wandb==0.15.12",
        "onnx==1.14.1",
        "av==10.0.0",
        "ffmpeg-python==0.2.0",
        "scikit-learn==1.3.0",
        "tqdm==4.66.1"
    ]

    if install_packages:
        for package in packages:
            # Run pip through the kernel's own interpreter. Unlike the IPython-only
            # `!pip` syntax, this is valid Python, so the cell still works when saved
            # as a .py module, and it installs into the kernel's environment.
            completed = subprocess.run([sys.executable, "-m", "pip", "install", package])
            if completed.returncode != 0:
                print(f"pip install {package} failed with exit code {completed.returncode}")
        # NOTE: modules already imported in this session (e.g. albumentations and
        # facenet_pytorch at the top of this cell) keep their old version until the
        # runtime is restarted.

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # covers every visible GPU, including the current one
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    print("Environment setup completed")

    return torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def detect_and_crop_face(image_path, face_detector=None, output_path=None, margin=0.3):
    """Detect one face in an image and return the crop, expanded by ``margin``.

    Args:
        image_path: path to an image file (read with OpenCV and converted to RGB),
            or an already-loaded RGB image. Array-like inputs without ``.shape``
            (e.g. a PIL image) are converted with ``np.asarray``.
        face_detector: object with ``detect(image) -> (boxes, probs)``, such as
            facenet-pytorch's ``MTCNN``. A default MTCNN is created when None.
        output_path: if given, the crop is also written there (converted RGB -> BGR
            for OpenCV). Missing parent directories are created.
        margin: fraction of the box width/height added on each side before
            clamping to the image bounds.

    Returns:
        The cropped RGB face, or None if any of these hold:
        the file is missing or unreadable; no face is found; the expanded box
        lies entirely outside the image; or detection raises.
    """
    if face_detector is None:
        face_detector = MTCNN(
            keep_all=False,
            post_process=False,
            min_face_size=48,
            device='cuda' if torch.cuda.is_available() else 'cpu'
        )

    # Short label for log messages; avoids printing a whole array for in-memory input.
    source = image_path if isinstance(image_path, str) else "<in-memory image>"

    if isinstance(image_path, str):
        if not os.path.exists(image_path):
            print(f"Image not found: {image_path}")
            return None

        bgr = cv2.imread(image_path)
        if bgr is None:
            # cv2.imread signals unreadable/corrupt files by returning None, not by raising.
            print(f"Could not read image: {image_path}")
            return None
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    else:
        image = image_path if hasattr(image_path, 'shape') else np.asarray(image_path)

    try:
        boxes, _ = face_detector.detect(image)

        if boxes is None or len(boxes) == 0:
            print(f"No face detected in {source}")
            return None

        # Take the first box in the detector's output order. NOTE(review): with
        # facenet-pytorch MTCNN's default select_largest=True this is the largest
        # face, not necessarily the most confident one.
        x1, y1, x2, y2 = boxes[0].tolist()

        h, w = image.shape[:2]
        dx = (x2 - x1) * margin
        dy = (y2 - y1) * margin

        x1 = max(0, int(x1 - dx))
        y1 = max(0, int(y1 - dy))
        x2 = min(w, int(x2 + dx))
        y2 = min(h, int(y2 + dy))

        if x2 <= x1 or y2 <= y1:
            # Box lies outside the frame; slicing would yield an empty array.
            print(f"Face box lies outside the image in {source}")
            return None

        face_img = image[y1:y2, x1:x2]

        if output_path:
            out_dir = os.path.dirname(output_path)
            if out_dir:  # os.makedirs('') raises, so only create a directory when there is one
                os.makedirs(out_dir, exist_ok=True)
            cv2.imwrite(output_path, cv2.cvtColor(face_img, cv2.COLOR_RGB2BGR))

        return face_img

    except Exception as e:
        print(f"Error detecting face in {source}: {e}")
        return None

def get_default_transforms(image_size=224):
    """Build the albumentations pipelines used for training and for evaluation.

    Training uses these steps:
    random resized crop (80-100% area) -> horizontal flip ->
    one of {color jitter, CLAHE} -> one of {Gaussian noise, Gaussian blur,
    JPEG compression} -> ImageNet normalization -> tensor.

    Evaluation uses a plain resize -> ImageNet normalization -> tensor.

    Returns:
        (train_transform, val_transform)
    """
    def imagenet_normalize():
        # A fresh Normalize per pipeline (ImageNet channel mean/std).
        return A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    color_augment = A.OneOf([
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
        A.CLAHE(p=0.5),
    ], p=0.5)

    # Simulated degradation: sensor noise, blur, or recompression artifacts.
    quality_augment = A.OneOf([
        A.GaussNoise(p=0.5),
        A.GaussianBlur(blur_limit=3, p=0.5),
        A.ImageCompression(quality_lower=50, quality_upper=100, p=0.5),
    ], p=0.5)

    train_steps = [
        A.RandomResizedCrop(image_size, image_size, scale=(0.8, 1.0)),
        A.HorizontalFlip(p=0.5),
        color_augment,
        quality_augment,
        imagenet_normalize(),
        ToTensorV2(),
    ]
    eval_steps = [
        A.Resize(image_size, image_size),
        imagenet_normalize(),
        ToTensorV2(),
    ]

    return A.Compose(train_steps), A.Compose(eval_steps)

def predict_image(model, image_path, transform=None, device=None, face_detector=None):
    """Classify one image as 'Real' or 'Fake'.

    Args:
        model: a loaded model, or a checkpoint path (a ``str``). A path is
            loaded into ``DeiTForDeepfakeDetection(num_classes=2)`` from a
            ``model`` module. NOTE(review): this differs from the notebook's
            ``DeiTForDeepfakeDetection(config)``; confirm which applies.
        image_path: image file to classify.
        transform: albumentations-style callable (``transform(image=arr)['image']``).
            Defaults to the evaluation pipeline from ``get_default_transforms``.
        device: defaults to CUDA when available, else CPU.
        face_detector: when given, the face is cropped with ``detect_and_crop_face``
            first. Otherwise the whole image is used.

    Returns:
        On success: {'prediction', 'confidence', 'probabilities': [p_real, p_fake]}.
        On failure: 'error' (message), 'prediction' None, 'confidence' 0.0
        and 'probabilities' [0.0, 0.0].
    """
    def failure(message):
        # Uniform payload for every failure path; callers check result.get('error').
        return {
            'error': message,
            'prediction': None,
            'confidence': 0.0,
            'probabilities': [0.0, 0.0]
        }

    device = device if device is not None else torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if transform is None:
        _, transform = get_default_transforms()

    if isinstance(model, str):
        from model import DeiTForDeepfakeDetection
        network = DeiTForDeepfakeDetection(num_classes=2).to(device)
        state = torch.load(model, map_location=device)
        network.load_state_dict(state['model_state_dict'])
        model = network

    model.eval()

    if face_detector is None:
        try:
            face_img = np.array(Image.open(image_path).convert('RGB'))
        except Exception as e:
            return failure(f'Error loading image: {e}')
    else:
        face_img = detect_and_crop_face(image_path, face_detector)
        if face_img is None:
            return failure('No face detected')

    try:
        tensor = transform(image=face_img)['image'].unsqueeze(0).to(device)
    except Exception as e:
        return failure(f'Error applying transform: {e}')

    with torch.no_grad():
        probs = F.softmax(model(tensor), dim=1)
        pred_class = torch.argmax(probs, dim=1).item()
        pred_prob = probs[0, pred_class].item()

    return {
        'prediction': 'Fake' if pred_class == 1 else 'Real',
        'confidence': pred_prob,
        'probabilities': probs[0].cpu().numpy().tolist()
    }

def visualize_prediction(image_path, result, face_detector=None):
    """Show an image (or its detected face crop) titled with a ``predict_image`` result.

    Error results get a red title with the error message. Otherwise the title
    shows the predicted class and confidence (green for Real, red for Fake), and
    the x-label shows both class probabilities.
    """
    try:
        bgr = cv2.imread(image_path)
        if bgr is None:
            # cv2.imread returns None for missing/unreadable files instead of raising.
            raise ValueError(f"could not read {image_path}")
        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    except Exception as e:
        print(f"Error loading image: {e}")
        return

    if face_detector is not None:
        face_img = detect_and_crop_face(image_path, face_detector)
        if face_img is not None:
            image = face_img

    plt.figure(figsize=(10, 6))
    plt.imshow(image)
    ax = plt.gca()

    if 'error' in result and result['error']:
        plt.title(f"Error: {result['error']}", color='red', fontsize=14)
    else:
        pred = result['prediction']
        conf = result['confidence']
        color = 'green' if pred == 'Real' else 'red'
        plt.title(f"Prediction: {pred} (Confidence: {conf:.4f})", color=color, fontsize=14)

        if 'probabilities' in result:
            real_prob = result['probabilities'][0]
            fake_prob = result['probabilities'][1]
            ax.set_xlabel(f"Real: {real_prob:.4f}, Fake: {fake_prob:.4f}", fontsize=12)

    # Hide ticks and frame but keep the axis on: plt.axis('off') also hides axis
    # labels, so the probability xlabel above was never displayed.
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

    plt.tight_layout()
    plt.show()

def extract_frames_from_video(video_path, output_dir, num_frames=20):
    """Save up to ``num_frames`` evenly spaced frames of a video as JPEGs.

    Frames are written to ``output_dir`` as ``frame_0000.jpg``,
    ``frame_0001.jpg``, and so on. Every ``max(1, total_frames // num_frames)``-th
    decoded frame is kept. When the stream does not report a frame count, the
    count is estimated from duration x average frame rate. The stream duration
    is used first, then the container duration, and 1000 frames is the last
    resort.

    Returns:
        True if decoding finished (even if fewer than ``num_frames`` frames were
        saved). False if the file is missing or opening/decoding failed.
    """
    if not os.path.exists(video_path):
        print(f"Video not found: {video_path}")
        return False

    import av  # lazy import: only needed for video input

    os.makedirs(output_dir, exist_ok=True)

    try:
        # The context manager closes the container (file handle/demuxer) even if decoding fails.
        with av.open(video_path) as container:
            video_stream = next(s for s in container.streams if s.type == 'video')

            n_frames = video_stream.frames
            if n_frames <= 0:
                # Estimate from duration x frame rate. stream.duration can be None,
                # which previously raised TypeError and aborted extraction.
                n_frames = 0
                rate = video_stream.average_rate
                if rate:
                    if video_stream.duration is not None and video_stream.time_base is not None:
                        duration = float(video_stream.duration * video_stream.time_base)
                    elif container.duration is not None:
                        # container.duration is expressed in av.time_base units
                        duration = container.duration / av.time_base
                    else:
                        duration = 0.0
                    n_frames = int(duration * rate)

            if n_frames <= 0:
                print(f"Could not determine number of frames for {video_path}")
                n_frames = 1000  # Assume a reasonable number

            interval = max(1, n_frames // num_frames)

            frame_count = 0
            saved_count = 0
            for frame in container.decode(video_stream):
                if frame_count % interval == 0:
                    img_path = os.path.join(output_dir, f"frame_{saved_count:04d}.jpg")
                    frame.to_image().save(img_path)

                    saved_count += 1
                    if saved_count >= num_frames:
                        break

                frame_count += 1

        print(f"Extracted {saved_count} frames from {video_path}")
        return True

    except Exception as e:
        print(f"Error extracting frames from {video_path}: {e}")
        return False

def batch_process_frames(model, frames_dir, transform=None, device=None, face_detector=None):
    """Run ``predict_image`` on every image file in ``frames_dir``, in filename order.

    Args:
        model: passed through to ``predict_image``. NOTE: if this is a checkpoint
            path, ``predict_image`` reloads the checkpoint for every frame.
            Pass a loaded model for speed.
        frames_dir: directory containing .jpg/.jpeg/.png frames (case-insensitive).
        transform: defaults to the evaluation pipeline from ``get_default_transforms``.
        device: defaults to CUDA when available, else CPU.
        face_detector: defaults to an MTCNN built on ``device``.

    Returns:
        list of ``predict_image`` result dicts, each with an added 'path'. The list
        is sorted by filename, so ``frame_0000.jpg, frame_0001.jpg, ...`` stay in
        temporal order. It is empty if the directory contains no images.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if face_detector is None:
        face_detector = MTCNN(
            keep_all=False,
            post_process=False,
            min_face_size=48,
            device=device
        )

    if transform is None:
        _, transform = get_default_transforms()

    # os.listdir order is arbitrary and filesystem-dependent; sort so per-frame
    # results (and any frame-by-frame plot) follow the video's frame order.
    image_files = [
        os.path.join(frames_dir, name)
        for name in sorted(os.listdir(frames_dir))
        if name.lower().endswith(('.jpg', '.jpeg', '.png'))
    ]

    if not image_files:
        print(f"No image files found in {frames_dir}")
        return []

    results = []
    for img_path in image_files:
        result = predict_image(model, img_path, transform, device, face_detector)
        result['path'] = img_path
        results.append(result)

    return results

def analyze_video_results(results):
    """Aggregate per-frame predictions into a single video-level verdict.

    Frames whose result carries a truthy 'error' are ignored. The verdict is the
    class with the higher mean probability over the remaining frames; a tie
    goes to 'Real'.

    Args:
        results: per-frame dicts from ``predict_image``, each with 'prediction'
            and 'probabilities' ([p_real, p_fake]) when error-free.

    Returns:
        A summary dict with 'prediction', 'confidence', 'avg_real_prob',
        'avg_fake_prob', 'real_count', 'fake_count', 'total_frames' and
        'detailed_results'. When nothing usable remains, it is instead
        {'prediction': 'Unknown', 'confidence': 0.0, 'error': ...}.
    """
    if not results:
        return {'prediction': 'Unknown', 'confidence': 0.0, 'error': 'No results provided'}

    usable = [frame for frame in results if not frame.get('error')]
    if not usable:
        return {'prediction': 'Unknown', 'confidence': 0.0, 'error': 'No valid frame results'}

    n_usable = len(usable)
    real_votes = sum(frame['prediction'] == 'Real' for frame in usable)
    mean_real = sum(frame['probabilities'][0] for frame in usable) / n_usable
    mean_fake = sum(frame['probabilities'][1] for frame in usable) / n_usable

    verdict, confidence = ('Fake', mean_fake) if mean_fake > mean_real else ('Real', mean_real)

    return {
        'prediction': verdict,
        'confidence': confidence,
        'avg_real_prob': mean_real,
        'avg_fake_prob': mean_fake,
        'real_count': real_votes,
        'fake_count': n_usable - real_votes,
        'total_frames': n_usable,
        'detailed_results': usable
    }

def predict_video(model, video_path, output_dir=None, device=None, num_frames=20):
    """Predict whether a video is real or fake by classifying sampled frames.

    Frames are extracted with ``extract_frames_from_video`` into
    ``<output_dir>/frames``. Every image in that directory is then classified
    by ``batch_process_frames`` using an MTCNN face detector. Finally, the
    per-frame results are aggregated by ``analyze_video_results``.

    Args:
        model: Classifier passed through to ``batch_process_frames``.
        video_path: Path to the input video file.
        output_dir: Directory in which a ``frames`` subdirectory is created.
            Image files already present in ``<output_dir>/frames`` are
            removed before extraction, so they cannot contaminate this
            video's prediction. If None, a temporary directory is used and
            deleted again before returning.
        device: torch device. Defaults to CUDA when available, else CPU.
        num_frames: Number of frames requested from
            ``extract_frames_from_video``.

    Returns:
        dict: Either an ``'Unknown'`` result with an ``'error'`` message if
        frame extraction failed, or the summary from
        ``analyze_video_results``. Both forms include a ``'video_path'`` key.
        When ``output_dir`` is None, the ``'path'`` entries inside
        ``'detailed_results'`` point to files that have already been deleted.
    """
    import shutil
    import tempfile

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Only a directory this function created is removed afterwards; a
    # caller-supplied output_dir is always left in place.
    owns_output_dir = output_dir is None
    if owns_output_dir:
        output_dir = tempfile.mkdtemp()
    else:
        os.makedirs(output_dir, exist_ok=True)

    try:
        frames_dir = os.path.join(output_dir, "frames")
        os.makedirs(frames_dir, exist_ok=True)

        # batch_process_frames classifies *every* .jpg/.jpeg/.png in
        # frames_dir. Frames left by an earlier call with the same output_dir
        # would otherwise be mixed into this video's verdict.
        for name in os.listdir(frames_dir):
            stale_path = os.path.join(frames_dir, name)
            if name.lower().endswith(('.jpg', '.jpeg', '.png')) and os.path.isfile(stale_path):
                os.remove(stale_path)

        # Extract frames
        success = extract_frames_from_video(video_path, frames_dir, num_frames)
        if not success:
            return {
                'prediction': 'Unknown',
                'confidence': 0.0,
                'error': 'Failed to extract frames',
                'video_path': video_path,
            }

        # keep_all=False: presumably only the most confident face per frame
        # is used (see facenet-pytorch MTCNN docs).
        face_detector = MTCNN(
            keep_all=False,
            post_process=False,
            min_face_size=48,
            device=device
        )

        # Only the second transform returned by get_default_transforms is
        # used here.
        _, transform = get_default_transforms()

        results = batch_process_frames(model, frames_dir, transform, device, face_detector)

        summary = analyze_video_results(results)
        summary['video_path'] = video_path
        return summary
    finally:
        # Runs on success, failed extraction, and exceptions alike, so the
        # temporary frames never leak.
        if owns_output_dir:
            shutil.rmtree(output_dir, ignore_errors=True)

def visualize_video_results(video_path, results, output_dir=None):
    """Plot per-frame fake probabilities for a video summary and show the figure.

    The figure's super-title shows either the error message (in red) or the
    video-level prediction and confidence (green for 'Real', red otherwise).
    When per-frame results are present, a line plot of each frame's fake
    probability is drawn with a 0.5 reference line and a stats subtitle.

    Args:
        video_path: Path of the analysed video. Its basename without
            extension names the saved PNG.
        results: Summary dict as produced by ``predict_video`` /
            ``analyze_video_results``.
        output_dir: If given, the figure is also saved there as
            ``<video_stem>_results.png`` at 150 dpi.
    """
    fig = plt.figure(figsize=(12, 8))

    if results.get('error'):
        fig.suptitle(f"Error: {results['error']}", color='red', fontsize=16)
    else:
        verdict = results['prediction']
        score = results['confidence']
        fig.suptitle(
            f"Video Prediction: {verdict} (Confidence: {score:.4f})",
            color='green' if verdict == 'Real' else 'red',
            fontsize=16,
        )

    per_frame = results.get('detailed_results')
    if per_frame:
        # Axes are only created when there is something to plot, matching a
        # bare figure with just a super-title for error / empty summaries.
        ax = fig.gca()
        fake_curve = [frame['probabilities'][1] for frame in per_frame]

        ax.plot(range(len(per_frame)), fake_curve, 'r-', marker='o', label='Fake Probability')
        ax.axhline(y=0.5, color='gray', linestyle='--', alpha=0.5)  # decision threshold guide
        ax.set_ylim(0, 1.05)
        ax.set_xlabel('Frame Index')
        ax.set_ylabel('Fake Probability')
        ax.grid(True, alpha=0.3)
        ax.legend()

        subtitle = (
            f"Real Frames: {results['real_count']}, "
            f"Fake Frames: {results['fake_count']}, "
            f"Total: {results['total_frames']}\n"
            f"Avg Real Prob: {results['avg_real_prob']:.4f}, "
            f"Avg Fake Prob: {results['avg_fake_prob']:.4f}"
        )
        ax.set_title(subtitle, fontsize=10)

    fig.tight_layout()

    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
        stem = os.path.splitext(os.path.basename(video_path))[0]
        output_file = os.path.join(output_dir, f"{stem}_results.png")
        fig.savefig(output_file, dpi=150)
        print(f"Results visualization saved to {output_file}")

    plt.show()

def main():
    """Smoke-test image and video prediction end to end.

    Builds ``DeiTForDeepfakeDetection`` on the device returned by
    ``setup_environment``, loads weights from the checkpoint path below if it
    exists, and switches the model to eval mode. It then runs the sample
    image and sample video through the prediction/visualization helpers when
    those files exist.
    """
    device = setup_environment()
    print(f"Using device: {device}")

    # NOTE(review): this expects an importable `model` module; in a pure
    # notebook run the class may instead be defined in an earlier cell — confirm.
    from model import DeiTForDeepfakeDetection

    model = DeiTForDeepfakeDetection(num_classes=2).to(device)

    model_path = "/content/drive/MyDrive/deepfake_detection/checkpoints/best_model.pth"
    if os.path.exists(model_path):
        # torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=device)
        # Accept both a training checkpoint ({'model_state_dict': ...}) and a
        # bare state_dict saved via torch.save(model.state_dict(), ...).
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        model.load_state_dict(state_dict)
        print(f"Loaded model from {model_path}")
    else:
        print(f"Model not found at {model_path}, using initialized model")

    # Inference mode (nn.Module.eval): disables training-only behaviour such
    # as dropout. Harmless if predict_image also calls eval().
    model.eval()

    example_image = "/content/drive/MyDrive/deepfake_detection/sample_images/test.jpg"
    if os.path.exists(example_image):
        result = predict_image(model, example_image, device=device)
        visualize_prediction(example_image, result)

    example_video = "/content/drive/MyDrive/deepfake_detection/sample_videos/test.mp4"
    if os.path.exists(example_video):
        results = predict_video(model, example_video, device=device)
        visualize_video_results(example_video, results)

# In a Jupyter/IPython kernel __name__ is "__main__", so executing this cell
# runs main() immediately.
if __name__ == "__main__":
    main()