<a href="https://colab.research.google.com/github/msrishav-28/ViT_model/blob/main/Deepfake_Detection_ViT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Create a fresh environment with compatible libraries for T4 GPU
# First, clean up any potential conflicts
!pip uninstall -y torch torchvision timm facenet-pytorch albumentations opencv-python pytorch-lightning wandb
!pip cache purge

# Install specific compatible versions optimized for T4 GPU
!pip install -q torch==2.0.1 torchvision==0.15.2 --extra-index-url https://download.pytorch.org/whl/cu118
!pip install -q timm==0.9.7 einops==0.6.1 scikit-learn==1.2.2 matplotlib==3.7.1 seaborn==0.12.2 tqdm==4.65.0
!pip install -q albumentations==1.3.1 facenet-pytorch==2.5.3 opencv-python==4.8.0.76 pytorchcv==0.0.67
!pip install -q gdown==4.7.1 pytorch-lightning==2.0.9 wandb==0.15.11

import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torch.cuda.amp import autocast, GradScaler
import timm
import numpy as np
import pandas as pd
from pathlib import Path
import cv2
import json
import random
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_curve, auc, confusion_matrix, classification_report
from tqdm.auto import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import gc
import warnings
warnings.filterwarnings('ignore')

# Backend switches chosen for speed over bit-exact reproducibility.
# NOTE(review): TF32 matmuls need Ampere-or-newer GPUs according to the PyTorch CUDA
# notes; the T4 is a Turing part, so this flag is presumably a no-op there - confirm.
torch.backends.cuda.matmul.allow_tf32 = True  # Allow TensorFloat-32 matmuls where the hardware supports them
torch.backends.cudnn.benchmark = True  # Let cuDNN auto-tune its kernels for the shapes it sees
torch.backends.cudnn.deterministic = False  # Allow non-deterministic algorithms for speed

# Select the compute device; `device` is a notebook-wide global used by later cells.
if torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"🎉 GPU Detected: {torch.cuda.get_device_name(0)}")
    # The value printed is the device's total memory, not the currently free amount.
    print(f"Memory Available: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    # Smoke test: time one 1000x1000 matmul on the GPU with CUDA events.
    test_tensor = torch.randn(1000, 1000, device=device)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    start.record()
    result = torch.matmul(test_tensor, test_tensor)
    end.record()

    # Wait for the queued GPU work to finish before reading the elapsed time.
    torch.cuda.synchronize()
    print(f"CUDA test operation completed in {start.elapsed_time(end):.2f} ms")
else:
    device = torch.device("cpu")
    print("⚠️ No GPU detected! Training will be extremely slow.")

# Release cached GPU blocks and run a Python garbage-collection pass.
torch.cuda.empty_cache()
gc.collect()

Found existing installation: torch 2.6.0+cu124
Uninstalling torch-2.6.0+cu124:
  Successfully uninstalled torch-2.6.0+cu124
Found existing installation: torchvision 0.21.0+cu124
Uninstalling torchvision-0.21.0+cu124:
  Successfully uninstalled torchvision-0.21.0+cu124
Found existing installation: timm 1.0.15
Uninstalling timm-1.0.15:
  Successfully uninstalled timm-1.0.15
[0mFound existing installation: albumentations 2.0.6
Uninstalling albumentations-2.0.6:
  Successfully uninstalled albumentations-2.0.6
Found existing installation: opencv-python 4.11.0.86
Uninstalling opencv-python-4.11.0.86:
  Successfully uninstalled opencv-python-4.11.0.86
[0mFound existing installation: wandb 0.19.10
Uninstalling wandb-0.19.10:
  Successfully uninstalled wandb-0.19.10
[0mFiles removed: 0
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.3/2.3 GB[0m [31m432.7 kB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.1/6.1 MB[0m [31m88.2 


A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.0.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/usr/local/lib/python3.11/dist-packages/colab_kernel_launcher.py", line 37, in <module>
    ColabKernelApp.launch_instance()
  File "/usr/local/lib/python3.11/dist-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/usr/local/lib/python3.11/dist-packages/ipykernel/kernelapp.py", line 712, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.11/dist-package

AttributeError: `np.float_` was removed in the NumPy 2.0 release. Use `np.float64` instead.

In [None]:
# First, check if anything exists in the target directory
!ls -la /content/drive

# If there are files, remove them
!rm -rf /content/drive/*

# Try mounting again
from google.colab import drive
drive.mount('/content/drive')

# Create directories for our project
!mkdir -p /content/drive/MyDrive/deepfake_detection/checkpoints
!mkdir -p /content/drive/MyDrive/deepfake_detection/logs
!mkdir -p /content/drive/MyDrive/datasets/faceforensics
!mkdir -p /content/drive/MyDrive/datasets/celebdf

print("✅ Environment setup complete! T4 GPU optimizations enabled.")

ls: cannot access '/content/drive': No such file or directory
Mounted at /content/drive
✅ Environment setup complete! T4 GPU optimizations enabled.


In [None]:
# Download the official FaceForensics++ script
!wget -O /content/download.py https://github.com/ondyari/FaceForensics/raw/master/dataset/download.py

print("✅ Download script ready!")

--2025-05-05 17:42:39--  https://github.com/ondyari/FaceForensics/raw/master/dataset/download.py
Resolving github.com (github.com)... 140.82.113.4
Connecting to github.com (github.com)|140.82.113.4|:443... connected.
HTTP request sent, awaiting response... 404 Not Found
2025-05-05 17:42:39 ERROR 404: Not Found.

✅ Download script ready!


In [None]:
# Download 500 original videos
!python /content/download.py /content/drive/MyDrive/datasets/faceforensics --dataset original --compression c40 --type videos --server EU2 --num_videos 500

# Download 500 Deepfakes videos
!python /content/download.py /content/drive/MyDrive/datasets/faceforensics --dataset Deepfakes --compression c40 --type videos --server EU2 --num_videos 500

# Download 500 Face2Face videos
!python /content/download.py /content/drive/MyDrive/datasets/faceforensics --dataset Face2Face --compression c40 --type videos --server EU2 --num_videos 500

# Download 500 NeuralTextures videos
!python /content/download.py /content/drive/MyDrive/datasets/faceforensics --dataset NeuralTextures --compression c40 --type videos --server EU2 --num_videos 500

print("✅ FaceForensics++ subset downloaded!")

✅ FaceForensics++ subset downloaded!


In [None]:
# CelebDF download
print("Downloading CelebDF dataset...")
!gdown "1iLx76wsbi9ztw6oAjk65MTA2mopR6y-8" -O /content/drive/MyDrive/datasets/celebdf/celebdf.zip
!unzip -q /content/drive/MyDrive/datasets/celebdf/celebdf.zip -d /content/drive/MyDrive/datasets/celebdf/
!rm /content/drive/MyDrive/datasets/celebdf/celebdf.zip
print("✅ CelebDF download complete!")

Downloading CelebDF dataset...
Failed to retrieve file url:

	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses.
	Check FAQ in https://github.com/wkentaro/gdown?tab=readme-ov-file#faq.

You may still be able to access the file from the browser:

	https://drive.google.com/uc?id=1iLx76wsbi9ztw6oAjk65MTA2mopR6y-8

but Gdown can't. Please check connections and permissions.
unzip:  cannot find or open /content/drive/MyDrive/datasets/celebdf/celebdf.zip, /content/drive/MyDrive/datasets/celebdf/celebdf.zip.zip or /content/drive/MyDrive/datasets/celebdf/celebdf.zip.ZIP.
rm: cannot remove '/content/drive/MyDrive/datasets/celebdf/celebdf.zip': No such file or directory
✅ CelebDF download complete!


In [None]:
# MTCNN face detector from facenet-pytorch
from facenet_pytorch import MTCNN

# Notebook-wide detector instance, placed on the device selected during setup.
face_detector = MTCNN(
    device=device,
    # detection behaviour
    keep_all=True,
    select_largest=False,
    # crop settings
    image_size=224,
    margin=40,
    post_process=True,
)

def extract_faces_from_frame(frame, min_face_size=100, confidence_threshold=0.95, margin=0):
    """Detect faces in one frame and return them as 224x224 crops.

    Args:
        frame: Image array indexed as [row, col, ...]; passed to
            ``face_detector.detect`` and sliced with the returned boxes.
        min_face_size: Minimum width AND height (pixels) of the detected box;
            smaller detections are skipped.
        confidence_threshold: Detections with a probability below this are skipped.
        margin: Extra pixels added on every side of the detected box before
            cropping (clamped to the frame). Defaults to 0, i.e. the raw box.
            NOTE(review): the module-level detector is built with ``margin=40``,
            but this function only calls ``detect()``, which presumably returns
            raw boxes without that margin - confirm against facenet-pytorch.

    Returns:
        List of face crops resized to 224x224. Empty when nothing is detected or
        when any error occurs (the error is printed, not raised).
    """
    try:
        boxes, probs = face_detector.detect(frame, landmarks=False)

        # detect() yields None when no face was found
        if boxes is None:
            return []

        frame_h, frame_w = frame.shape[0], frame.shape[1]
        extracted_faces = []

        for box, prob in zip(boxes, probs):
            # Filter by confidence
            if prob < confidence_threshold:
                continue

            x1, y1, x2, y2 = box.astype(int)

            # Filter small faces (measured on the raw box, before any margin)
            if (x2 - x1) < min_face_size or (y2 - y1) < min_face_size:
                continue

            # Expand by the requested margin, then clamp to the frame bounds
            top, bottom = max(0, y1 - margin), min(frame_h, y2 + margin)
            left, right = max(0, x1 - margin), min(frame_w, x2 + margin)
            face = frame[top:bottom, left:right]

            # A box entirely outside the frame gives an empty slice; skip it
            if face.size > 0:
                extracted_faces.append(cv2.resize(face, (224, 224)))

        return extracted_faces

    except Exception as e:
        # Best-effort: one bad frame must not abort processing of a whole video
        print(f"Error in face extraction: {e}")
        return []

def extract_faces_from_video(video_path, max_frames=20, frame_interval=10):
    """Sample frames from a video at a fixed stride and collect the faces found.

    Args:
        video_path: Path of the video file, passed to ``cv2.VideoCapture``.
        max_frames: Upper bound on both the number of frames sampled and the
            number of faces returned.
        frame_interval: Distance (in frames) between sampled frames.

    Returns:
        List of at most ``max_frames`` face crops (RGB, as produced by
        ``extract_faces_from_frame``). Empty when the video cannot be opened,
        reports fewer than 10 frames, or an error occurs (the error is printed).
    """
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            print(f"Could not open video: {video_path}")
            return []

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Skip very short videos
        if total_frames < 10:
            return []

        # Sample frames 0, frame_interval, 2*frame_interval, ...
        num_frames = min(max_frames, total_frames // frame_interval)

        all_faces = []
        for sample_no in range(num_frames):
            cap.set(cv2.CAP_PROP_POS_FRAMES, sample_no * frame_interval)
            ret, frame = cap.read()
            if not ret:
                continue

            # OpenCV decodes to BGR; the detector is fed RGB
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            all_faces.extend(extract_faces_from_frame(frame_rgb))

            # Stop early once enough faces were collected
            if len(all_faces) >= max_frames:
                break

        return all_faces[:max_frames]

    except Exception as e:
        print(f"Error processing video {video_path}: {e}")
        return []

    finally:
        # Always release the capture handle, including on the not-opened and
        # exception paths, which previously leaked it.
        if cap is not None:
            cap.release()

print("✅ Face extraction utilities ready!")

Downloading CelebDF dataset...
Access denied with the following error:

 	Cannot retrieve the public link of the file. You may need to change
	the permission to 'Anyone with the link', or have had many accesses. 

You may still be able to access the file from the browser:

	 https://drive.google.com/uc?id=1iLx76wsbi9ztw6oAjk65MTA2mopR6y-8 

unzip:  cannot find or open /content/drive/MyDrive/datasets/celebdf/celebdf.zip, /content/drive/MyDrive/datasets/celebdf/celebdf.zip.zip or /content/drive/MyDrive/datasets/celebdf/celebdf.zip.ZIP.
rm: cannot remove '/content/drive/MyDrive/datasets/celebdf/celebdf.zip': No such file or directory
✅ CelebDF download complete!


In [None]:
class FaceForensicsDataset(Dataset):
    """FaceForensics++ face-crop dataset.

    Construction is eager: every selected video is run through
    ``extract_faces_from_video`` and the resulting crops are written as JPEGs
    under ``<root_dir>/extracted_faces/<original|method>/`` (existing files are
    not overwritten). Items are ``(image, label, method_idx)`` where label is
    0 = real / 1 = fake and method_idx is 0 for real videos or the 1-based
    position of the manipulation method in ``methods``.
    """

    def __init__(self, root_dir, methods=['Deepfakes', 'Face2Face', 'NeuralTextures'],
                 compression='c40', transform=None, max_faces_per_video=10,
                 max_videos=None, use_extracted_faces=True):
        """
        Args:
            root_dir: Root directory of the FaceForensics++ dataset
            methods: List of manipulation methods to include
            compression: Compression level (c40 recommended for Colab)
            transform: Albumentations transforms (called as ``transform(image=...)``)
            max_faces_per_video: Maximum number of face images to extract per video
            max_videos: Maximum number of videos to use per class
            use_extracted_faces: Stored on the instance; not consulted by this class
        """
        self.root_dir = Path(root_dir)
        # Copy so the shared default list (or the caller's list) is never aliased
        self.methods = list(methods)
        self.compression = compression
        self.transform = transform
        self.max_faces_per_video = max_faces_per_video
        self.max_videos = max_videos
        self.use_extracted_faces = use_extracted_faces

        # Entries are (face_img_path, label, method_idx)
        self.samples = []

        # Original (real) videos: label 0, method_idx 0
        original_dir = self.root_dir / 'original_sequences/youtube' / compression / 'videos'
        self._add_videos(original_dir, 'original', 0, 0, "Processing original videos")

        # Manipulated videos: label 1, method_idx starts at 1 (0 is reserved for real)
        for method_idx, method in enumerate(self.methods, 1):
            fake_dir = self.root_dir / f'manipulated_sequences/{method}' / compression / 'videos'
            self._add_videos(fake_dir, method, 1, method_idx, f"Processing {method} videos")

    def _add_videos(self, video_dir, faces_subdir, label, method_idx, desc):
        """Extract faces from every .mp4 in ``video_dir`` and register them as samples."""
        if not video_dir.exists():
            return

        # Sorted so that the max_videos subset is the same on every run
        # (glob order is filesystem dependent).
        videos = sorted(video_dir.glob('*.mp4'))
        if self.max_videos:
            videos = videos[:self.max_videos]

        faces_dir = self.root_dir / 'extracted_faces' / faces_subdir
        for video_path in tqdm(videos, desc=desc):
            faces = extract_faces_from_video(str(video_path), max_frames=self.max_faces_per_video)
            for i, face in enumerate(faces):
                face_path = faces_dir / f"{video_path.stem}_face{i}.jpg"
                os.makedirs(faces_dir, exist_ok=True)

                # Crops are RGB in memory; OpenCV writes BGR. Keep existing files.
                if not face_path.exists():
                    cv2.imwrite(str(face_path), cv2.cvtColor(face, cv2.COLOR_RGB2BGR))

                self.samples.append((str(face_path), label, method_idx))

    def __len__(self):
        """Number of registered face crops."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one crop as RGB, apply ``transform`` if set, return (image, label, method_idx).

        Raises:
            FileNotFoundError: if the stored JPEG cannot be read.
        """
        face_path, label, method_idx = self.samples[idx]

        image = cv2.imread(face_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read face image: {face_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, label, method_idx

class CelebDFDataset(Dataset):
    """CelebDF face-crop dataset.

    Construction is eager: every selected video under ``Celeb-real`` and
    ``Celeb-synthesis`` is run through ``extract_faces_from_video`` and the crops
    are written as JPEGs under ``<root_dir>/extracted_faces/<real|fake>/``
    (existing files are not overwritten). Items are ``(image, label, -1)`` with
    label 0 = real / 1 = fake; -1 is a placeholder for the manipulation index.
    """

    def __init__(self, root_dir, transform=None, max_faces_per_video=10, max_videos=None):
        """
        Args:
            root_dir: Root directory of the CelebDF dataset
            transform: Albumentations transforms (called as ``transform(image=...)``)
            max_faces_per_video: Maximum number of face images to extract per video
            max_videos: Maximum number of videos to use per class
        """
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.max_faces_per_video = max_faces_per_video
        self.max_videos = max_videos

        # Entries are (face_img_path, label)
        self.samples = []

        # Celeb-real videos (real) -> label 0
        self._add_videos(self.root_dir / 'Celeb-real', 'real', 0,
                         "Processing CelebDF real videos")
        # Celeb-synthesis videos (fake) -> label 1
        self._add_videos(self.root_dir / 'Celeb-synthesis', 'fake', 1,
                         "Processing CelebDF fake videos")

    def _add_videos(self, video_dir, faces_subdir, label, desc):
        """Extract faces from every .mp4 in ``video_dir`` and register them as samples."""
        if not video_dir.exists():
            return

        # Sorted so that the max_videos subset is the same on every run
        # (glob order is filesystem dependent).
        videos = sorted(video_dir.glob('*.mp4'))
        if self.max_videos:
            videos = videos[:self.max_videos]

        faces_dir = self.root_dir / 'extracted_faces' / faces_subdir
        for video_path in tqdm(videos, desc=desc):
            faces = extract_faces_from_video(str(video_path), max_frames=self.max_faces_per_video)
            for i, face in enumerate(faces):
                face_path = faces_dir / f"{video_path.stem}_face{i}.jpg"
                os.makedirs(faces_dir, exist_ok=True)

                # Crops are RGB in memory; OpenCV writes BGR. Keep existing files.
                if not face_path.exists():
                    cv2.imwrite(str(face_path), cv2.cvtColor(face, cv2.COLOR_RGB2BGR))

                self.samples.append((str(face_path), label))

    def __len__(self):
        """Number of registered face crops."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load one crop as RGB, apply ``transform`` if set, return (image, label, -1).

        Raises:
            FileNotFoundError: if the stored JPEG cannot be read.
        """
        face_path, label = self.samples[idx]

        image = cv2.imread(face_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"Could not read face image: {face_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, label, -1  # -1 as method_idx placeholder (not used for CelebDF)

class CombinedDataset(Dataset):
    """Concatenation of a FaceForensics++ dataset and a CelebDF dataset.

    ``self.samples`` holds ``(source, index)`` pairs with source ``'ff'`` or
    ``'celebdf'``; ``__getitem__`` forwards to the matching wrapped dataset.
    """

    def __init__(self, ff_dataset=None, celebdf_dataset=None, balance=True):
        """
        Args:
            ff_dataset: FaceForensicsDataset
            celebdf_dataset: CelebDFDataset
            balance: If True, balance classes by undersampling. Only applied
                when BOTH datasets are given.

        Raises:
            ValueError: if neither dataset is provided.
        """
        self.ff_dataset = ff_dataset
        self.celebdf_dataset = celebdf_dataset

        if ff_dataset is None and celebdf_dataset is None:
            raise ValueError("At least one dataset must be provided")

        # FaceForensics++ entries first, then CelebDF entries
        self.samples = []
        if ff_dataset is not None:
            self.samples.extend(('ff', i) for i in range(len(ff_dataset)))
        if celebdf_dataset is not None:
            self.samples.extend(('celebdf', i) for i in range(len(celebdf_dataset)))

        if balance and ff_dataset is not None and celebdf_dataset is not None:
            self._balance_samples()

    @staticmethod
    def _label_of(dataset, idx):
        """Return the label of item ``idx`` without decoding its image when possible."""
        records = getattr(dataset, 'samples', None)
        if records is not None:
            # Both dataset classes in this notebook store the label at position 1
            return records[idx][1]
        # Fallback for datasets that expose no `samples` list: load the item
        return dataset[idx][1]

    def _balance_samples(self):
        """Undersample the larger class so real and fake counts are equal, then shuffle."""
        real_samples = []
        fake_samples = []

        # Labels are read from the datasets' metadata; indexing the datasets here
        # would read and decode every image from disk just to obtain its label.
        for sample in self.samples:
            source, idx = sample
            dataset = self.ff_dataset if source == 'ff' else self.celebdf_dataset
            if self._label_of(dataset, idx) == 0:
                real_samples.append(sample)
            else:
                fake_samples.append(sample)

        # Undersample both classes to the size of the smaller one
        target_per_class = min(len(real_samples), len(fake_samples))
        real_samples = random.sample(real_samples, target_per_class)
        fake_samples = random.sample(fake_samples, target_per_class)

        self.samples = real_samples + fake_samples
        random.shuffle(self.samples)

    def __len__(self):
        """Number of (source, index) entries after optional balancing."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the item of the wrapped dataset that entry ``idx`` points to."""
        dataset, sample_idx = self.samples[idx]

        if dataset == 'ff':
            return self.ff_dataset[sample_idx]
        else:
            return self.celebdf_dataset[sample_idx]

print("✅ Dataset classes defined!")

In [None]:
def get_augmentation_pipeline(is_train=True):
    """Build the albumentations pipeline for training or for evaluation.

    Args:
        is_train: True for the randomized training pipeline, False for the
            deterministic resize-and-normalize pipeline.

    Returns:
        An ``A.Compose`` ending in Normalize (ImageNet mean/std) and ToTensorV2.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    # Evaluation: no randomness, only resize + normalize + tensor conversion
    if not is_train:
        return A.Compose([
            A.Resize(224, 224),
            A.Normalize(mean=imagenet_mean, std=imagenet_std),
            ToTensorV2()
        ])

    # Geometry: mild random crop and mirror
    geometric = [
        A.RandomResizedCrop(224, 224, scale=(0.8, 1.0), ratio=(0.9, 1.1)),
        A.HorizontalFlip(p=0.5),
    ]

    # Colour: jitter or grayscale, to cope with differing video qualities
    colour = A.OneOf([
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.7),
        A.ToGray(p=0.3),
    ], p=0.8)

    # Degradation: simulated compression artifacts or blur
    degradation = A.OneOf([
        A.ImageCompression(quality_lower=60, quality_upper=100, p=0.5),
        A.GaussianBlur(blur_limit=(3, 7), p=0.5),
    ], p=0.3)

    # Noise: Gaussian or sensor-style noise
    noise = A.OneOf([
        A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),
        A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.5),
    ], p=0.3)

    # Always finish with normalization and tensor conversion
    finalize = [
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]

    return A.Compose(geometric + [colour, degradation, noise] + finalize)

print("✅ Augmentation pipeline defined!")

In [None]:
class DeepfakeDetector(nn.Module):
    """timm backbone with two MLP heads: real/fake and manipulation type."""

    def __init__(self, num_classes=2, num_manipulation_types=4,
                 pretrained=True, dropout=0.3, model_name='vit_base_patch16_224'):
        """
        Args:
            num_classes: Output size of the binary (real/fake) head.
            num_manipulation_types: Output size of the manipulation-type head.
            pretrained: Forwarded to ``timm.create_model``.
            dropout: Dropout probability on the backbone features and inside each head.
            model_name: timm model identifier for the backbone.
        """
        super().__init__()

        # num_classes=0 asks timm for the backbone without its classification head
        self.backbone = timm.create_model(model_name,
                                          pretrained=pretrained,
                                          num_classes=0)
        feature_dim = self.backbone.num_features

        def build_head(out_features):
            # feature_dim -> 512 -> out_features with ReLU and dropout in between
            return nn.Sequential(
                nn.Linear(feature_dim, 512),
                nn.ReLU(),
                nn.Dropout(dropout),
                nn.Linear(512, out_features)
            )

        # Dropout applied to the shared features before both heads
        self.dropout = nn.Dropout(dropout)

        # Heads are created in this order: binary first, manipulation second
        self.binary_classifier = build_head(num_classes)
        self.manipulation_classifier = build_head(num_manipulation_types)

    def forward(self, x, return_features=False):
        """Return (binary_logits, manipulation_logits[, features]).

        ``features`` are the backbone outputs after the shared dropout and are
        appended only when ``return_features`` is True.
        """
        shared = self.dropout(self.backbone(x))

        binary_logits = self.binary_classifier(shared)
        manipulation_logits = self.manipulation_classifier(shared)

        if not return_features:
            return binary_logits, manipulation_logits
        return binary_logits, manipulation_logits, shared

print("✅ Model architecture defined!")

In [None]:
class EarlyStopping:
    def __init__(self, patience=7, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, val_loss):
        if self.best_loss is None:
            self.best_loss = val_loss
        elif val_loss > self.best_loss - self.min_delta:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_loss = val_loss
            self.counter = 0

def train_epoch(model, train_loader, optimizer, criterion_binary,
                criterion_manipulation, scaler, device):
    """Run one training pass over ``train_loader`` under autocast with a grad scaler.

    Each batch is ``(images, labels, manipulation_types)``. The loss is the
    binary loss plus 0.5 x the manipulation loss, the latter computed only on
    items whose manipulation index is >= 0.

    Returns:
        (epoch_loss, epoch_acc_binary): mean loss per batch and binary accuracy
        in percent.
    """
    model.train()
    loss_total = 0.0
    hits = 0
    seen = 0

    progress = tqdm(train_loader, desc='Training', leave=False)
    for images, labels, manipulation_types in progress:
        images = images.to(device)
        labels = labels.to(device)
        manipulation_types = manipulation_types.to(device)

        optimizer.zero_grad()

        # Forward pass and loss under mixed precision
        with autocast():
            binary_logits, manipulation_logits = model(images)
            loss = criterion_binary(binary_logits, labels)

            # Items with a negative manipulation index carry no method label
            has_method = manipulation_types >= 0
            if has_method.sum() > 0:
                method_loss = criterion_manipulation(
                    manipulation_logits[has_method],
                    manipulation_types[has_method]
                )
                loss = loss + 0.5 * method_loss

        # Scaled backward pass, optimizer step, then scale-factor update
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Running statistics
        loss_total += loss.item()
        predicted = binary_logits.max(1)[1]
        seen += labels.size(0)
        hits += predicted.eq(labels).sum().item()

        progress.set_postfix({
            'loss': f'{loss_total/(progress.n+1):.4f}',
            'acc': f'{100.*hits/seen:.2f}%',
        })

    return loss_total / len(train_loader), 100. * hits / seen

def validate(model, val_loader, criterion_binary, criterion_manipulation, device):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    The loss matches training: binary loss plus 0.5 x manipulation loss on items
    whose manipulation index is >= 0.

    Returns:
        (val_loss, val_acc, roc_auc, all_labels, all_predictions, all_probabilities)
        where val_loss is the mean loss per batch, val_acc is in percent, and
        the probabilities are the softmax score of class 1.
    """
    model.eval()
    loss_total = 0.0
    all_labels, all_predictions, all_probabilities = [], [], []

    with torch.no_grad():
        batches = tqdm(val_loader, desc='Validation', leave=False)
        for images, labels, manipulation_types in batches:
            images = images.to(device)
            labels = labels.to(device)
            manipulation_types = manipulation_types.to(device)

            binary_logits, manipulation_logits = model(images)
            loss = criterion_binary(binary_logits, labels)

            # Items with a negative manipulation index carry no method label
            has_method = manipulation_types >= 0
            if has_method.sum() > 0:
                method_loss = criterion_manipulation(
                    manipulation_logits[has_method],
                    manipulation_types[has_method]
                )
                loss = loss + 0.5 * method_loss

            loss_total += loss.item()

            # Keep per-item outputs for the metrics below
            fake_probability = F.softmax(binary_logits, dim=1)[:, 1]
            predicted = binary_logits.max(1)[1]

            all_labels.extend(labels.cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
            all_probabilities.extend(fake_probability.cpu().numpy())

    # Aggregate metrics
    val_loss = loss_total / len(val_loader)
    matches = np.array(all_labels) == np.array(all_predictions)
    val_acc = 100. * np.mean(matches)

    # Area under the ROC curve from the class-1 probabilities
    fpr, tpr, _ = roc_curve(all_labels, all_probabilities)
    roc_auc = auc(fpr, tpr)

    return val_loss, val_acc, roc_auc, all_labels, all_predictions, all_probabilities

print("✅ Training utilities defined!")

In [None]:
def plot_confusion_matrix(labels, predictions, classes=['Real', 'Fake']):
    """Draw the confusion matrix as a heatmap, save it to Drive as a PNG and show it.

    Args:
        labels: Ground-truth class ids.
        predictions: Predicted class ids.
        classes: Tick labels for both axes.
    """
    matrix = confusion_matrix(labels, predictions)

    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(matrix, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes, ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('Actual')
    ax.set_title('Confusion Matrix')

    fig.savefig('/content/drive/MyDrive/deepfake_detection/confusion_matrix.png')
    plt.show()

def plot_roc_curve(labels, probabilities):
    """Draw the ROC curve with its AUC in the legend, save it to Drive as a PNG and show it.

    Args:
        labels: Ground-truth class ids.
        probabilities: Scores for the positive class.
    """
    fpr, tpr, _ = roc_curve(labels, probabilities)
    area = auc(fpr, tpr)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.plot(fpr, tpr, color='darkorange', lw=2,
            label=f'ROC curve (AUC = {area:.2f})')
    # Diagonal reference line
    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    ax.set_title('Receiver Operating Characteristic (ROC) Curve')
    ax.legend(loc="lower right")
    ax.grid(True)

    fig.savefig('/content/drive/MyDrive/deepfake_detection/roc_curve.png')
    plt.show()

print("✅ Visualization functions defined!")

In [None]:
def main():
    """Train and evaluate the deepfake detector end to end.

    Steps:
      1. Load the FaceForensics++ and CelebDF datasets and split each one
         80/20 (``train_val_split``) into training and validation samples.
      2. Train ``DeepfakeDetector`` with AdamW and cosine-annealed learning
         rate, validating after every epoch.
      3. Save a checkpoint whenever the validation AUC improves; stop early
         when ``EarlyStopping`` (fed the validation loss) triggers.
      4. Plot the training curves, then the confusion matrix, ROC curve and
         classification report for the epoch that produced the saved
         checkpoint (falling back to the last epoch if none was saved).

    Uses names defined elsewhere in the notebook: ``device``,
    ``get_augmentation_pipeline``, the dataset classes, ``DeepfakeDetector``,
    ``EarlyStopping``, ``train_epoch``, ``validate`` and the plotting helpers.
    """
    # Configuration
    config = {
        'batch_size': 32,  # Adjust based on GPU memory
        'num_epochs': 30,
        'learning_rate': 1e-4,
        'weight_decay': 0.01,
        'num_workers': 2,  # Colab works best with 2 workers
        'patience': 7,
        'mixed_precision': True,
        # NOTE(review): this key is not read anywhere in this function and is
        # not passed to train_epoch; unless train_epoch picks it up some other
        # way, the effective batch size is just 'batch_size' -- confirm.
        'gradient_accumulation_steps': 2,  # For larger effective batch size
        'max_videos_per_class': 500,  # Set to 500 for all methods
        'max_faces_per_video': 10,
        'train_val_split': 0.8,  # 80% training, 20% validation
    }

    # Locations on the mounted Google Drive
    ff_root = '/content/drive/MyDrive/datasets/faceforensics'
    celebdf_root = '/content/drive/MyDrive/datasets/celebdf'
    checkpoint_path = '/content/drive/MyDrive/deepfake_detection/checkpoints/best_model.pth'

    # Create augmentation pipelines
    train_transform = get_augmentation_pipeline(is_train=True)
    val_transform = get_augmentation_pipeline(is_train=False)

    # Create datasets
    print("Loading FaceForensics++ dataset...")
    ff_dataset = FaceForensicsDataset(
        root_dir=ff_root,
        methods=['Deepfakes', 'Face2Face', 'NeuralTextures'],  # Include Face2Face
        compression='c40',
        transform=None,  # Will be applied later
        max_faces_per_video=config['max_faces_per_video'],
        max_videos=config['max_videos_per_class']
    )

    print("Loading CelebDF dataset...")
    celebdf_dataset = CelebDFDataset(
        root_dir=celebdf_root,
        transform=None,  # Will be applied later
        max_faces_per_video=config['max_faces_per_video'],
        max_videos=config['max_videos_per_class']
    )

    # Split datasets into train and validation.
    # NOTE(review): the split is done over individual `samples` entries. If
    # each entry is a single face crop, faces from the same video can land in
    # both splits and inflate validation metrics -- verify against the
    # dataset classes and consider splitting by video instead.
    ff_samples = ff_dataset.samples
    celebdf_samples = celebdf_dataset.samples

    random.shuffle(ff_samples)
    random.shuffle(celebdf_samples)

    ff_cut = int(len(ff_samples) * config['train_val_split'])
    train_ff_samples = ff_samples[:ff_cut]
    val_ff_samples = ff_samples[ff_cut:]

    celebdf_cut = int(len(celebdf_samples) * config['train_val_split'])
    train_celebdf_samples = celebdf_samples[:celebdf_cut]
    val_celebdf_samples = celebdf_samples[celebdf_cut:]

    # Create train and validation datasets with proper transforms.
    # Fresh dataset objects are built with constructor defaults and then have
    # their samples/transform overwritten.
    # NOTE(review): if the constructors scan the disk, this repeats that work
    # four times -- confirm against the dataset classes.
    train_ff_dataset = FaceForensicsDataset(root_dir=ff_root)
    train_ff_dataset.samples = train_ff_samples
    train_ff_dataset.transform = train_transform

    val_ff_dataset = FaceForensicsDataset(root_dir=ff_root)
    val_ff_dataset.samples = val_ff_samples
    val_ff_dataset.transform = val_transform

    train_celebdf_dataset = CelebDFDataset(root_dir=celebdf_root)
    train_celebdf_dataset.samples = train_celebdf_samples
    train_celebdf_dataset.transform = train_transform

    val_celebdf_dataset = CelebDFDataset(root_dir=celebdf_root)
    val_celebdf_dataset.samples = val_celebdf_samples
    val_celebdf_dataset.transform = val_transform

    # Combine datasets
    train_dataset = CombinedDataset(train_ff_dataset, train_celebdf_dataset, balance=True)
    val_dataset = CombinedDataset(val_ff_dataset, val_celebdf_dataset, balance=True)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config['batch_size'],
        shuffle=True,
        num_workers=config['num_workers'],
        pin_memory=True,
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=config['batch_size'],
        shuffle=False,
        num_workers=config['num_workers'],
        pin_memory=True
    )

    # Initialize model
    model = DeepfakeDetector(
        num_classes=2,  # Binary: real/fake
        num_manipulation_types=4,  # Real, Deepfakes, Face2Face, NeuralTextures
        pretrained=True,
        dropout=0.3
    ).to(device)

    # Loss functions
    criterion_binary = nn.CrossEntropyLoss()
    criterion_manipulation = nn.CrossEntropyLoss()

    # Optimizer with weight decay
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config['learning_rate'],
        weight_decay=config['weight_decay']
    )

    # Learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=config['num_epochs']
    )

    # Mixed precision scaler
    scaler = GradScaler() if config['mixed_precision'] else None

    # Early stopping
    early_stopping = EarlyStopping(patience=config['patience'])

    # torch.save does not create missing parent directories; make sure the
    # checkpoint folder exists so the first "best model" save cannot fail
    # after an epoch of training has already been spent.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    # Training metrics
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    val_aucs = []
    best_val_auc = 0.0

    # Validation outputs kept for the final report: (labels, predictions,
    # probabilities) plus the epoch index they came from.
    best_eval, best_eval_epoch = None, None
    last_eval, last_eval_epoch = None, None

    # Training loop
    for epoch in range(config['num_epochs']):
        print(f"\nEpoch {epoch+1}/{config['num_epochs']}")

        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, optimizer, criterion_binary,
            criterion_manipulation, scaler, device
        )

        # Validate
        val_loss, val_acc, val_auc, val_labels, val_predictions, val_probabilities = validate(
            model, val_loader, criterion_binary, criterion_manipulation, device
        )

        # Step scheduler
        scheduler.step()

        # Save metrics
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_accs.append(train_acc)
        val_accs.append(val_acc)
        val_aucs.append(val_auc)

        # Copy the per-sample outputs so later epochs cannot alias them.
        epoch_eval = (list(val_labels), list(val_predictions), list(val_probabilities))
        last_eval, last_eval_epoch = epoch_eval, epoch

        # Print epoch results
        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, Val AUC: {val_auc:.4f}")

        # Save best model (selected by validation AUC)
        if val_auc > best_val_auc:
            best_val_auc = val_auc
            best_eval, best_eval_epoch = epoch_eval, epoch
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_auc': val_auc,
                'val_acc': val_acc,
            }, checkpoint_path)
            print(f"Saved best model with AUC: {val_auc:.4f}")

        # Early stopping check (driven by validation loss, not AUC)
        early_stopping(val_loss)
        if early_stopping.early_stop:
            print("Early stopping triggered!")
            break

    if last_eval is None:
        # num_epochs was 0: there are no metrics or predictions to report.
        print("No epochs were run; nothing to evaluate.")
        return

    # Plot training curves
    plot_training_curves(train_losses, val_losses, train_accs, val_accs, val_aucs)

    # Report on the epoch that produced the saved checkpoint, so the figures
    # describe the model on disk rather than whatever the last epoch was.
    # Fall back to the last epoch if no checkpoint was ever saved
    # (e.g. the AUC never exceeded 0 or was NaN).
    if best_eval is not None:
        report_eval, report_epoch = best_eval, best_eval_epoch
    else:
        report_eval, report_epoch = last_eval, last_eval_epoch
    eval_labels, eval_predictions, eval_probabilities = report_eval
    print(f"\nFinal evaluation uses validation outputs from epoch {report_epoch + 1}")

    # Plot final evaluation metrics
    plot_confusion_matrix(eval_labels, eval_predictions)
    plot_roc_curve(eval_labels, eval_probabilities)

    # Print classification report
    print("\nClassification Report:")
    print(classification_report(eval_labels, eval_predictions,
                               target_names=['Real', 'Fake']))

    print(f"\nTraining completed! Best validation AUC: {best_val_auc:.4f}")

print("✅ Main training function defined!")

In [None]:
def inference(model_path, image_path, device):
    """Classify the first detected face in a single image as real or fake.

    NOTE(review): this notebook defines ``inference`` again in a later cell;
    after Run All the later definition is the one in effect.

    Args:
        model_path: path to a checkpoint containing a ``'model_state_dict'``
            entry (the format written by ``main()``).
        image_path: path of the image file to analyse.
        device: torch device the model and input tensor are moved to.

    Returns:
        A dict with keys ``'is_fake'`` (bool), ``'confidence'`` (probability
        of the predicted class), ``'manipulation_type'`` (``'N/A'`` when the
        image is judged real) and ``'manipulation_confidence'``; or ``None``
        when the image cannot be read or no face is detected.
    """
    # Class names for the manipulation head, in index order. The list length
    # must equal num_manipulation_types below.
    manipulation_types = ['Real', 'Deepfakes', 'Face2Face', 'NeuralTextures']

    # Load model
    model = DeepfakeDetector(num_classes=2,
                             num_manipulation_types=len(manipulation_types))
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Load and preprocess image
    transform = get_augmentation_pipeline(is_train=False)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; without this check cvtColor fails obscurely.
        print(f"Could not read image: {image_path}")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Extract face if needed
    faces = extract_faces_from_frame(image)
    if len(faces) == 0:
        print("No face detected in the image!")
        return None

    # Use the first face
    face = faces[0]

    # Apply transformations and add a batch dimension
    augmented = transform(image=face)
    image_tensor = augmented['image'].unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        binary_logits, manipulation_logits = model(image_tensor)
        binary_probs = F.softmax(binary_logits, dim=1)
        manipulation_probs = F.softmax(manipulation_logits, dim=1)

        # Index 1 of the binary head is the "fake" class. Convert to a plain
        # Python bool so the result dict holds no (possibly GPU) tensors.
        fake_probability = binary_probs[0][1].item()
        is_fake = fake_probability > 0.5
        confidence = fake_probability if is_fake else binary_probs[0][0].item()

        manipulation_idx = manipulation_probs[0].argmax().item()
        manipulation_type = manipulation_types[manipulation_idx]
        manipulation_confidence = manipulation_probs[0][manipulation_idx].item()

    result = {
        'is_fake': is_fake,
        'confidence': confidence,
        'manipulation_type': manipulation_type if is_fake else 'N/A',
        'manipulation_confidence': manipulation_confidence
    }

    return result

def visualize_result(image_path, result,
                     save_path='/content/drive/MyDrive/deepfake_detection/result.png'):
    """Show the detected face next to a gauge of the prediction and save it.

    Both panels are drawn on a single figure so that the saved image contains
    the face as well as the gauge.

    Args:
        image_path: path of the image that was analysed.
        result: dict as returned by ``inference()``; the keys read here are
            ``'is_fake'``, ``'confidence'`` and ``'manipulation_type'``.
        save_path: file the figure is written to. Missing parent directories
            are created.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None for a missing/unreadable file.
        print(f"Could not read image: {image_path}")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Extract face
    faces = extract_faces_from_frame(image)
    if len(faces) == 0:
        print("No face detected in the image!")
        return

    face = faces[0]

    # 'confidence' is the probability of the predicted class; turn it back
    # into the probability of "fake" for the needle position.
    fake_prob = result['confidence'] if result['is_fake'] else 1 - result['confidence']

    fig, (ax_face, ax_gauge) = plt.subplots(1, 2, figsize=(12, 6))

    # Left panel: the face crop that was classified
    ax_face.imshow(face)
    ax_face.set_title("Input Face")
    ax_face.axis('off')

    # Right panel: a simple gauge
    gauge = plt.Circle((0.5, 0.5), 0.4, color='lightgray', fill=True)
    ax_gauge.add_artist(gauge)

    # Needle: angle pi (pointing left, REAL) at fake_prob=0, sweeping to
    # angle 0 (pointing right, FAKE) at fake_prob=1.
    angle = (1 - fake_prob) * np.pi
    x = 0.5 + 0.35 * np.cos(angle)
    y = 0.5 + 0.35 * np.sin(angle)
    ax_gauge.plot([0.5, x], [0.5, y], color='red', linewidth=3)

    # Gauge labels
    ax_gauge.text(0.15, 0.5, "REAL", fontsize=14, ha='center', va='center', color='green')
    ax_gauge.text(0.85, 0.5, "FAKE", fontsize=14, ha='center', va='center', color='red')

    # Confidence text
    if result['is_fake']:
        text = f"FAKE ({fake_prob:.1%} confidence)\nType: {result['manipulation_type']}"
        color = 'red'
    else:
        text = f"REAL ({(1-fake_prob):.1%} confidence)"
        color = 'green'

    ax_gauge.text(0.5, 0.25, text, fontsize=16, ha='center', va='center',
                  color=color, weight='bold')

    ax_gauge.set_xlim(0, 1)
    ax_gauge.set_ylim(0, 1)
    ax_gauge.set_aspect('equal')  # keep the gauge circular
    ax_gauge.axis('off')

    fig.tight_layout()

    # savefig does not create missing directories.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path)

    plt.show()
    plt.close(fig)

print("✅ Inference functions defined!")

In [None]:
def inference(model_path, image_path, device):
    """Classify the first detected face in a single image as real or fake.

    The model is built with four manipulation types, matching the
    ``DeepfakeDetector(num_manipulation_types=4)`` that ``main()`` trains and
    checkpoints. Building it with a different number would presumably change
    the manipulation head's shape and make ``load_state_dict`` reject the
    checkpoint -- confirm against ``DeepfakeDetector``.

    Args:
        model_path: path to a checkpoint containing a ``'model_state_dict'``
            entry (the format written by ``main()``).
        image_path: path of the image file to analyse.
        device: torch device the model and input tensor are moved to.

    Returns:
        A dict with keys ``'is_fake'`` (bool), ``'confidence'`` (probability
        of the predicted class), ``'manipulation_type'`` (``'N/A'`` when the
        image is judged real) and ``'manipulation_confidence'``; or ``None``
        when the image cannot be read or no face is detected.
    """
    # Class names for the manipulation head, in index order. The list length
    # must equal num_manipulation_types below.
    manipulation_types = ['Real', 'Deepfakes', 'Face2Face', 'NeuralTextures']

    # Load model
    model = DeepfakeDetector(num_classes=2,
                             num_manipulation_types=len(manipulation_types))
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()

    # Load and preprocess image
    transform = get_augmentation_pipeline(is_train=False)
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals a missing/unreadable file by returning None
        # instead of raising; without this check cvtColor fails obscurely.
        print(f"Could not read image: {image_path}")
        return None
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Extract face if needed
    faces = extract_faces_from_frame(image)
    if len(faces) == 0:
        print("No face detected in the image!")
        return None

    # Use the first face
    face = faces[0]

    # Apply transformations and add a batch dimension
    augmented = transform(image=face)
    image_tensor = augmented['image'].unsqueeze(0).to(device)

    # Inference
    with torch.no_grad():
        binary_logits, manipulation_logits = model(image_tensor)
        binary_probs = F.softmax(binary_logits, dim=1)
        manipulation_probs = F.softmax(manipulation_logits, dim=1)

        # Index 1 of the binary head is the "fake" class. Convert to a plain
        # Python bool so the result dict holds no (possibly GPU) tensors.
        fake_probability = binary_probs[0][1].item()
        is_fake = fake_probability > 0.5
        confidence = fake_probability if is_fake else binary_probs[0][0].item()

        manipulation_idx = manipulation_probs[0].argmax().item()
        manipulation_type = manipulation_types[manipulation_idx]
        manipulation_confidence = manipulation_probs[0][manipulation_idx].item()

    result = {
        'is_fake': is_fake,
        'confidence': confidence,
        'manipulation_type': manipulation_type if is_fake else 'N/A',
        'manipulation_confidence': manipulation_confidence
    }

    return result

def visualize_result(image_path, result,
                     save_path='/content/drive/MyDrive/deepfake_detection/result.png'):
    """Show the detected face next to a gauge of the prediction and save it.

    Both panels are drawn on a single figure so that the saved image contains
    the face as well as the gauge.

    Args:
        image_path: path of the image that was analysed.
        result: dict as returned by ``inference()``; the keys read here are
            ``'is_fake'``, ``'confidence'`` and ``'manipulation_type'``.
        save_path: file the figure is written to. Missing parent directories
            are created.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread returns None for a missing/unreadable file.
        print(f"Could not read image: {image_path}")
        return
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Extract face
    faces = extract_faces_from_frame(image)
    if len(faces) == 0:
        print("No face detected in the image!")
        return

    face = faces[0]

    # 'confidence' is the probability of the predicted class; turn it back
    # into the probability of "fake" for the needle position.
    fake_prob = result['confidence'] if result['is_fake'] else 1 - result['confidence']

    fig, (ax_face, ax_gauge) = plt.subplots(1, 2, figsize=(12, 6))

    # Left panel: the face crop that was classified
    ax_face.imshow(face)
    ax_face.set_title("Input Face")
    ax_face.axis('off')

    # Right panel: a simple gauge
    gauge = plt.Circle((0.5, 0.5), 0.4, color='lightgray', fill=True)
    ax_gauge.add_artist(gauge)

    # Needle: angle pi (pointing left, REAL) at fake_prob=0, sweeping to
    # angle 0 (pointing right, FAKE) at fake_prob=1.
    angle = (1 - fake_prob) * np.pi
    x = 0.5 + 0.35 * np.cos(angle)
    y = 0.5 + 0.35 * np.sin(angle)
    ax_gauge.plot([0.5, x], [0.5, y], color='red', linewidth=3)

    # Gauge labels
    ax_gauge.text(0.15, 0.5, "REAL", fontsize=14, ha='center', va='center', color='green')
    ax_gauge.text(0.85, 0.5, "FAKE", fontsize=14, ha='center', va='center', color='red')

    # Confidence text
    if result['is_fake']:
        text = f"FAKE ({fake_prob:.1%} confidence)\nType: {result['manipulation_type']}"
        color = 'red'
    else:
        text = f"REAL ({(1-fake_prob):.1%} confidence)"
        color = 'green'

    ax_gauge.text(0.5, 0.25, text, fontsize=16, ha='center', va='center',
                  color=color, weight='bold')

    ax_gauge.set_xlim(0, 1)
    ax_gauge.set_ylim(0, 1)
    ax_gauge.set_aspect('equal')  # keep the gauge circular
    ax_gauge.axis('off')

    fig.tight_layout()

    # savefig does not create missing directories.
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    fig.savefig(save_path)

    plt.show()
    plt.close(fig)

print("✅ Inference functions defined!")

In [None]:
if __name__ == "__main__":
    # Free memory left over from earlier cells before training starts.
    # Run the garbage collector first: tensors that are only kept alive by
    # unreachable Python objects still occupy their GPU blocks, and
    # empty_cache() can only release blocks that are no longer occupied.
    gc.collect()
    torch.cuda.empty_cache()

    # Start training
    main()

In [None]:
# Example inference on a test image
def test_on_image(image_path):
    """Test the model on a sample image"""
    model_path = '/content/drive/MyDrive/deepfake_detection/checkpoints/best_model.pth'

    # Check if model exists
    if not os.path.exists(model_path):
        print("Model not found! Please train the model first.")
        return

    # Run inference
    result = inference(model_path, image_path, device)

    if result is None:
        print("Could not analyze the image.")
        return

    # Print result
    print("=" * 50)
    print(f"Result: {'FAKE' if result['is_fake'] else 'REAL'}")
    print(f"Confidence: {result['confidence']:.2%}")

    if result['is_fake']:
        print(f"Manipulation type: {result['manipulation_type']}")
        print(f"Manipulation confidence: {result['manipulation_confidence']:.2%}")
    print("=" * 50)

    # Visualize
    visualize_result(image_path, result)

# To test on your own image, uncomment and provide a path:
# test_on_image('/content/drive/MyDrive/your_test_image.jpg')