<a href="https://colab.research.google.com/github/msrishav-28/Swin-Model/blob/main/Deepfake_Detection_Swin.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Write the FaceForensics++ download helper (the script's own docstring points
# to https://github.com/ondyari/FaceForensics) to /content/dataset_setup so the
# download cell can run it with `!python`. Everything between the triple quotes
# below is file content written verbatim; it is not executed by this cell.
# NOTE: the written script blocks on input() until the user confirms the
# FaceForensics terms of use, and `--server` selects the mirror (EU/EU2/CA).
# Save the download script to a file
!mkdir -p /content/dataset_setup
with open('/content/dataset_setup/download_faceforensics.py', 'w') as f:
    f.write('''#!/usr/bin/env python
""" Downloads FaceForensics++ and Deep Fake Detection public data release
Example usage:
    see -h or https://github.com/ondyari/FaceForensics
"""
# -*- coding: utf-8 -*-
import argparse
import os
import urllib
import urllib.request
import tempfile
import time
import sys
import json
import random
from tqdm import tqdm
from os.path import join


# URLs and filenames
FILELIST_URL = 'misc/filelist.json'
DEEPFEAKES_DETECTION_URL = 'misc/deepfake_detection_filenames.json'
DEEPFAKES_MODEL_NAMES = ['decoder_A.h5', 'decoder_B.h5', 'encoder.h5',]

# Parameters
DATASETS = {
    'original_youtube_videos': 'misc/downloaded_youtube_videos.zip',
    'original_youtube_videos_info': 'misc/downloaded_youtube_videos_info.zip',
    'original': 'original_sequences/youtube',
    'DeepFakeDetection_original': 'original_sequences/actors',
    'Deepfakes': 'manipulated_sequences/Deepfakes',
    'DeepFakeDetection': 'manipulated_sequences/DeepFakeDetection',
    'Face2Face': 'manipulated_sequences/Face2Face',
    'FaceShifter': 'manipulated_sequences/FaceShifter',
    'FaceSwap': 'manipulated_sequences/FaceSwap',
    'NeuralTextures': 'manipulated_sequences/NeuralTextures'
    }
ALL_DATASETS = ['original', 'DeepFakeDetection_original', 'Deepfakes',
                'DeepFakeDetection', 'Face2Face', 'FaceShifter', 'FaceSwap',
                'NeuralTextures']
COMPRESSION = ['raw', 'c23', 'c40']
TYPE = ['videos', 'masks', 'models']
SERVERS = ['EU', 'EU2', 'CA']


def parse_args():
    parser = argparse.ArgumentParser(
        description='Downloads FaceForensics v2 public data release.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    parser.add_argument('output_path', type=str, help='Output directory.')
    parser.add_argument('-d', '--dataset', type=str, default='all',
                        help='Which dataset to download, either pristine or '
                             'manipulated data or the downloaded youtube '
                             'videos.',
                        choices=list(DATASETS.keys()) + ['all']
                        )
    parser.add_argument('-c', '--compression', type=str, default='raw',
                        help='Which compression degree. All videos '
                             'have been generated with h264 with a varying '
                             'codec. Raw (c0) videos are lossless compressed.',
                        choices=COMPRESSION
                        )
    parser.add_argument('-t', '--type', type=str, default='videos',
                        help='Which file type, i.e. videos, masks, for our '
                             'manipulation methods, models, for Deepfakes.',
                        choices=TYPE
                        )
    parser.add_argument('-n', '--num_videos', type=int, default=None,
                        help='Select a number of videos number to '
                             "download if you don't want to download the full"
                             ' dataset.')
    parser.add_argument('--server', type=str, default='EU',
                        help='Server to download the data from. If you '
                             'encounter a slow download speed, consider '
                             'changing the server.',
                        choices=SERVERS
                        )
    args = parser.parse_args()

    # URLs
    server = args.server
    if server == 'EU':
        server_url = 'http://canis.vc.in.tum.de:8100/'
    elif server == 'EU2':
        server_url = 'http://kaldir.vc.in.tum.de/faceforensics/'
    elif server == 'CA':
        server_url = 'http://falas.cmpt.sfu.ca:8100/'
    else:
        raise Exception('Wrong server name. Choices: {}'.format(str(SERVERS)))
    args.tos_url = server_url + 'webpage/FaceForensics_TOS.pdf'
    args.base_url = server_url + 'v3/'
    args.deepfakes_model_url = server_url + 'v3/manipulated_sequences/' + \\
                               'Deepfakes/models/'

    return args


def download_files(filenames, base_url, output_path, report_progress=True):
    os.makedirs(output_path, exist_ok=True)
    if report_progress:
        filenames = tqdm(filenames)
    for filename in filenames:
        download_file(base_url + filename, join(output_path, filename))


def reporthook(count, block_size, total_size):
    global start_time
    if count == 0:
        start_time = time.time()
        return
    duration = time.time() - start_time
    progress_size = int(count * block_size)
    speed = int(progress_size / (1024 * duration))
    percent = int(count * block_size * 100 / total_size)
    sys.stdout.write("\\rProgress: %d%%, %d MB, %d KB/s, %d seconds passed" %
                     (percent, progress_size / (1024 * 1024), speed, duration))
    sys.stdout.flush()


def download_file(url, out_file, report_progress=False):
    out_dir = os.path.dirname(out_file)
    if not os.path.isfile(out_file):
        fh, out_file_tmp = tempfile.mkstemp(dir=out_dir)
        f = os.fdopen(fh, 'w')
        f.close()
        if report_progress:
            urllib.request.urlretrieve(url, out_file_tmp,
                                       reporthook=reporthook)
        else:
            urllib.request.urlretrieve(url, out_file_tmp)
        os.rename(out_file_tmp, out_file)
    else:
        tqdm.write('WARNING: skipping download of existing file ' + out_file)


def main(args):
    # TOS
    print('By pressing any key to continue you confirm that you have agreed '\\
          'to the FaceForensics terms of use as described at:')
    print(args.tos_url)
    print('***')
    print('Press any key to continue, or CTRL-C to exit.')
    _ = input('')

    # Extract arguments
    c_datasets = [args.dataset] if args.dataset != 'all' else ALL_DATASETS
    c_type = args.type
    c_compression = args.compression
    num_videos = args.num_videos
    output_path = args.output_path
    os.makedirs(output_path, exist_ok=True)

    # Check for special dataset cases
    for dataset in c_datasets:
        dataset_path = DATASETS[dataset]
        # Special cases
        if 'original_youtube_videos' in dataset:
            # Here we download the original youtube videos zip file
            print('Downloading original youtube videos.')
            if not 'info' in dataset_path:
                print('Please be patient, this may take a while (~40gb)')
                suffix = ''
            else:
                suffix = 'info'
            download_file(args.base_url + '/' + dataset_path,
                          out_file=join(output_path,
                                        'downloaded_videos{}.zip'.format(
                                            suffix)),
                          report_progress=True)
            return

        # Else: regular datasets
        print('Downloading {} of dataset "{}"'.format(
            c_type, dataset_path
        ))

        # Get filelists and video lenghts list from server
        if 'DeepFakeDetection' in dataset_path or 'actors' in dataset_path:
            filepaths = json.loads(urllib.request.urlopen(args.base_url + '/' +
                DEEPFEAKES_DETECTION_URL).read().decode("utf-8"))
            if 'actors' in dataset_path:
                filelist = filepaths['actors']
            else:
                filelist = filepaths['DeepFakesDetection']
        elif 'original' in dataset_path:
            # Load filelist from server
            file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' +
                FILELIST_URL).read().decode("utf-8"))
            filelist = []
            for pair in file_pairs:
                filelist += pair
        else:
            # Load filelist from server
            file_pairs = json.loads(urllib.request.urlopen(args.base_url + '/' +
                FILELIST_URL).read().decode("utf-8"))
            # Get filelist
            filelist = []
            for pair in file_pairs:
                filelist.append('_'.join(pair))
                if c_type != 'models':
                    filelist.append('_'.join(pair[::-1]))
        # Maybe limit number of videos for download
        if num_videos is not None and num_videos > 0:
            print('Downloading the first {} videos'.format(num_videos))
            filelist = filelist[:num_videos]

        # Server and local paths
        dataset_videos_url = args.base_url + '{}/{}/{}/'.format(
            dataset_path, c_compression, c_type)
        dataset_mask_url = args.base_url + '{}/{}/videos/'.format(
            dataset_path, 'masks', c_type)

        if c_type == 'videos':
            dataset_output_path = join(output_path, dataset_path, c_compression,
                                       c_type)
            print('Output path: {}'.format(dataset_output_path))
            filelist = [filename + '.mp4' for filename in filelist]
            download_files(filelist, dataset_videos_url, dataset_output_path)
        elif c_type == 'masks':
            dataset_output_path = join(output_path, dataset_path, c_type,
                                       'videos')
            print('Output path: {}'.format(dataset_output_path))
            if 'original' in dataset:
                if args.dataset != 'all':
                    print('Only videos available for original data. Aborting.')
                    return
                else:
                    print('Only videos available for original data. '
                          'Skipping original.\\n')
                    continue
            if 'FaceShifter' in dataset:
                print('Masks not available for FaceShifter. Aborting.')
                return
            filelist = [filename + '.mp4' for filename in filelist]
            download_files(filelist, dataset_mask_url, dataset_output_path)

        # Else: models for deepfakes
        else:
            if dataset != 'Deepfakes' and c_type == 'models':
                print('Models only available for Deepfakes. Aborting')
                return
            dataset_output_path = join(output_path, dataset_path, c_type)
            print('Output path: {}'.format(dataset_output_path))

            # Get Deepfakes models
            for folder in tqdm(filelist):
                folder_filelist = DEEPFAKES_MODEL_NAMES

                # Folder paths
                folder_base_url = args.deepfakes_model_url + folder + '/'
                folder_dataset_output_path = join(dataset_output_path,
                                                  folder)
                download_files(folder_filelist, folder_base_url,
                               folder_dataset_output_path,
                               report_progress=False)   # already done


if __name__ == "__main__":
    args = parse_args()
    main(args)''')

# Only confirms the write above finished; it does not validate the script.
print("Download script created successfully!")

Download script created successfully!


In [None]:
# Install the third-party packages imported later in the notebook
# (timm, albumentations, pytorch-lightning, face_recognition, OpenCV headless,
# scikit-learn and the plotting libraries).
!pip install timm albumentations pytorch-lightning face_recognition opencv-python-headless scikit-learn matplotlib seaborn

# Check GPU
!nvidia-smi

Collecting pytorch-lightning
  Downloading pytorch_lightning-2.5.1.post0-py3-none-any.whl.metadata (20 kB)
Collecting torchmetrics>=0.7.0 (from pytorch-lightning)
  Downloading torchmetrics-1.7.1-py3-none-any.whl.metadata (21 kB)
Collecting lightning-utilities>=0.10.0 (from pytorch-lightning)
  Downloading lightning_utilities-0.14.3-py3-none-any.whl.metadata (5.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch->timm)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch->timm)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (

In [None]:
# Download FaceForensics++ Dataset with c40 compression (500 videos per method)
# This is optimized for T4 GPU in Colab (smaller file size)
# Flags of the helper script: -d dataset/method, -c compression level,
# -t file type, -n keep only the first N entries of the server file list,
# --server download mirror.
# NOTE: each invocation first waits on input() for confirmation of the
# FaceForensics terms of use, so every command needs a keypress.
!python /content/dataset_setup/download_faceforensics.py /content/datasets/faceforensics -d original -c c40 -t videos -n 500 --server EU2
!python /content/dataset_setup/download_faceforensics.py /content/datasets/faceforensics -d Deepfakes -c c40 -t videos -n 500 --server EU2
!python /content/dataset_setup/download_faceforensics.py /content/datasets/faceforensics -d NeuralTextures -c c40 -t videos -n 500 --server EU2
!python /content/dataset_setup/download_faceforensics.py /content/datasets/faceforensics -d Face2Face -c c40 -t videos -n 500 --server EU2

python3: can't open file '/content/dataset_setup/download_faceforensics.py': [Errno 2] No such file or directory
python3: can't open file '/content/dataset_setup/download_faceforensics.py': [Errno 2] No such file or directory
python3: can't open file '/content/dataset_setup/download_faceforensics.py': [Errno 2] No such file or directory
python3: can't open file '/content/dataset_setup/download_faceforensics.py': [Errno 2] No such file or directory


In [None]:
import os
import cv2
from tqdm import tqdm
import concurrent.futures
import multiprocessing

def extract_frames(video_path, output_dir, sample_rate=30):
    """
    Extract frames from a video at the specified sample rate

    Frames are written as JPEGs named ``{video_name}_{index:04d}.jpg`` where
    ``video_name`` is the video's file name without its extension.

    Args:
        video_path: Path to the video file
        output_dir: Directory to save extracted frames (created if missing)
        sample_rate: Extract 1 frame every 'sample_rate' frames (must be >= 1)

    Returns:
        Number of frames successfully written to ``output_dir``. A video that
        OpenCV cannot open yields 0.

    Raises:
        ValueError: if ``sample_rate`` is smaller than 1.
    """
    # Validate before touching the filesystem; a zero rate would otherwise only
    # fail (ZeroDivisionError) after the first frame had been decoded.
    if sample_rate < 1:
        raise ValueError(f"sample_rate must be >= 1, got {sample_rate}")

    os.makedirs(output_dir, exist_ok=True)

    # splitext drops only the final extension, so "a.b.mp4" -> "a.b" rather
    # than being truncated at the first dot.
    video_name = os.path.splitext(os.path.basename(video_path))[0]
    cap = cv2.VideoCapture(video_path)

    saved_count = 0
    try:
        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % sample_rate == 0:
                frame_path = os.path.join(output_dir, f"{video_name}_{saved_count:04d}.jpg")
                # imwrite reports failure through its return value instead of
                # raising, so only count frames that actually reached disk.
                if cv2.imwrite(frame_path, frame):
                    saved_count += 1

            frame_count += 1
    finally:
        # Release the decoder even if reading or writing raised.
        cap.release()

    return saved_count

def _video_stem(video_path):
    """Return the video's file name without its final extension."""
    return os.path.splitext(os.path.basename(video_path))[0]


def _collect_videos(video_paths):
    """Expand files/directories into a list of ``.mp4`` paths (directories are walked recursively)."""
    all_videos = []
    for video_path in video_paths:
        if os.path.isdir(video_path):
            for root, dirs, files in os.walk(video_path):
                dirs.sort()  # deterministic traversal order
                all_videos.extend(os.path.join(root, name) for name in sorted(files) if name.endswith('.mp4'))
        elif video_path.endswith('.mp4'):
            all_videos.append(video_path)
    return all_videos


def process_videos(video_paths, output_dir, sample_rate=30):
    """Process multiple videos with multiprocessing

    Each video's frames go to ``output_dir/<video file name without extension>``.

    Args:
        video_paths: Iterable of ``.mp4`` files and/or directories to search.
        output_dir: Root directory for the per-video frame folders.
        sample_rate: Passed to ``extract_frames`` (keep 1 frame every N).

    Returns:
        Total number of frames extracted across all videos. Failures of
        individual videos are printed and contribute 0.
    """
    all_videos = _collect_videos(video_paths)
    print(f"Found {len(all_videos)} videos to process")

    total_frames = 0
    if not all_videos:
        # Nothing to do: skip spinning up a process pool.
        print(f"Extracted {total_frames} frames in total")
        return total_frames

    # Process videos with parallel workers
    num_workers = min(multiprocessing.cpu_count(), 4)  # Limit to 4 workers to avoid memory issues

    with concurrent.futures.ProcessPoolExecutor(max_workers=num_workers) as executor:
        # Submit the module-level extract_frames itself: the pool pickles the
        # callable for each job, and a function nested inside process_videos
        # cannot be pickled, so every job would fail. Output paths are computed
        # here in the parent instead.
        # NOTE(review): when run from a notebook this relies on the fork start
        # method (Linux default) so children can resolve __main__.extract_frames.
        future_to_video = {
            executor.submit(
                extract_frames,
                video,
                os.path.join(output_dir, _video_stem(video)),
                sample_rate,
            ): video
            for video in all_videos
        }

        # Process as they complete
        for future in tqdm(concurrent.futures.as_completed(future_to_video), total=len(all_videos)):
            video = future_to_video[future]
            try:
                total_frames += future.result()
            except Exception as e:
                print(f"Error processing {video}: {e}")

    print(f"Extracted {total_frames} frames in total")
    return total_frames

# Define paths for videos and frame extraction
ff_original_path = '/content/datasets/faceforensics/original_sequences/youtube/c40/videos'
ff_deepfakes_path = '/content/datasets/faceforensics/manipulated_sequences/Deepfakes/c40/videos'
ff_neural_path = '/content/datasets/faceforensics/manipulated_sequences/NeuralTextures/c40/videos'
ff_face2face_path = '/content/datasets/faceforensics/manipulated_sequences/Face2Face/c40/videos'

# Create output directories for extracted frames
ff_original_frames = '/content/datasets/frames/faceforensics/original'
ff_deepfakes_frames = '/content/datasets/frames/faceforensics/deepfakes'
ff_neural_frames = '/content/datasets/frames/faceforensics/neuraltextures'
ff_face2face_frames = '/content/datasets/frames/faceforensics/face2face'

# Set higher sample rate for T4 GPU (extract fewer frames to manage memory)
sample_rate = 60  # Extract 1 frame every 60 frames

# (video directory, frame directory, label used in progress messages) for
# every FaceForensics++ method defined above, including Face2Face, whose paths
# were previously defined but never processed.
ff_extraction_jobs = [
    (ff_original_path, ff_original_frames, "original"),
    (ff_deepfakes_path, ff_deepfakes_frames, "deepfakes"),
    (ff_neural_path, ff_neural_frames, "neural textures"),
    (ff_face2face_path, ff_face2face_frames, "face2face"),
]

# Process videos if they exist; report missing inputs instead of skipping
# them silently, so an empty run is visible in the output.
for video_dir, frames_dir, method in ff_extraction_jobs:
    if not os.path.exists(video_dir):
        print(f"Skipping FaceForensics++ {method} videos: {video_dir} not found.")
        continue
    print(f"Extracting frames from FaceForensics++ {method} videos...")
    process_videos([video_dir], frames_dir, sample_rate)

print("Frame extraction complete!")

Frame extraction complete!


In [None]:
# Create CelebDF directories
# NOTE(review): these are only empty placeholders; nothing in this cell
# downloads Celeb-DF, so the videos must be placed here manually.
!mkdir -p /content/datasets/celebdf/Celeb-real
!mkdir -p /content/datasets/celebdf/Celeb-synthesis

# Create directories for extracted frames
!mkdir -p /content/datasets/frames/celebdf/real
!mkdir -p /content/datasets/frames/celebdf/fake

In [None]:
!pip install torch  # NOTE(review): torch is already pulled in as a timm dependency by the combined pip install; likely redundant — confirm.



In [None]:
!pip install face_recognition  # NOTE(review): also listed in the combined pip install cell; looks redundant — confirm before removing.


Collecting face_recognition
  Downloading face_recognition-1.3.0-py2.py3-none-any.whl.metadata (21 kB)
Collecting face-recognition-models>=0.3.0 (from face_recognition)
  Downloading face_recognition_models-0.3.0.tar.gz (100.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m100.1/100.1 MB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Downloading face_recognition-1.3.0-py2.py3-none-any.whl (15 kB)
Building wheels for collected packages: face-recognition-models
  Building wheel for face-recognition-models (setup.py) ... [?25l[?25hdone
  Created wheel for face-recognition-models: filename=face_recognition_models-0.3.0-py2.py3-none-any.whl size=100566166 sha256=f963212130733eee9097f41c7e715d49967a288e080c08d9d4ac89a3bb03c6f1
  Stored in directory: /root/.cache/pip/wheels/04/52/ec/9355da79c29f160b038a20c784db2803c2f9fa2c8a462c176a
Successfully built face-recognition-models
Installing collected packages: face-recogn

In [None]:
import os
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score, confusion_matrix
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
import face_recognition
from PIL import Image
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
import gc

def _print_torch_environment():
    """Print the PyTorch version, whether CUDA is usable and, if so, GPU 0's name and memory (GB = 1e9 bytes)."""
    print(f"PyTorch version: {torch.__version__}")
    has_cuda = torch.cuda.is_available()
    print(f"CUDA available: {has_cuda}")
    if not has_cuda:
        return
    device_index = 0
    print(f"GPU: {torch.cuda.get_device_name(device_index)}")
    total_gb = torch.cuda.get_device_properties(device_index).total_memory / 1e9
    print(f"GPU Memory: {total_gb:.2f} GB")


# Check PyTorch version and GPU
_print_torch_environment()

PyTorch version: 2.6.0+cu124
CUDA available: True
GPU: Tesla T4
GPU Memory: 15.83 GB


In [None]:
def set_seed(seed=42):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch (CPU and all CUDA devices).

    Also forces deterministic cuDNN kernels and disables cuDNN benchmarking so
    algorithm selection does not vary between runs.
    """
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic, cudnn.benchmark = True, False


set_seed(42)

In [None]:
class FaceExtractor:
    """Crop the first detected face from an image and resize it to a square.

    No landmark-based alignment is performed: the detector's box is padded by
    20% of the box height on every side, clipped to the image bounds, and
    resized to ``output_size`` x ``output_size``. If no face is found (or face
    processing fails) the whole image is resized instead.
    """

    def __init__(self, output_size=224):
        # Side length in pixels of the square image returned by extract_face.
        self.output_size = output_size

    def _resize(self, image):
        """Resize *image* to ``output_size`` x ``output_size``."""
        return cv2.resize(image, (self.output_size, self.output_size))

    @staticmethod
    def _read_rgb_with_cv2(image_path):
        """Fallback loader: read *image_path* with OpenCV and convert BGR to RGB.

        Raises:
            FileNotFoundError: if OpenCV cannot read/decode the file.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    def extract_face(self, image_path):
        """Return a square RGB face crop (or the resized whole image).

        Args:
            image_path: A file path, or an already-loaded image array
                (assumed RGB, as produced by ``face_recognition.load_image_file``).

        Returns:
            The resized face crop, or the resized whole image when no face is
            detected or face processing raises.

        Raises:
            FileNotFoundError: if *image_path* is a path that neither
                face_recognition nor OpenCV can load.
        """
        image = None
        try:
            # Check if input is a path or an image array
            if isinstance(image_path, str):
                image = face_recognition.load_image_file(image_path)
            else:
                image = image_path

            face_locations = face_recognition.face_locations(image)
            if not face_locations:
                # No face: reuse the already-loaded image instead of reading
                # the file from disk a second time.
                return self._resize(image)

            # Get the first face; face_recognition boxes are (top, right, bottom, left)
            top, right, bottom, left = face_locations[0]

            # Extract face with some margin, clipped to the image bounds
            margin = int((bottom - top) * 0.2)
            top = max(0, top - margin)
            left = max(0, left - margin)
            bottom = min(image.shape[0], bottom + margin)
            right = min(image.shape[1], right + margin)

            return self._resize(image[top:bottom, left:right])

        except Exception as e:
            print(f"Error processing image: {str(e)}")
            if image is None:
                # Loading itself failed (or no image was given): try OpenCV for
                # paths, otherwise there is nothing to fall back to.
                if not isinstance(image_path, str):
                    raise
                image = self._read_rgb_with_cv2(image_path)
            # Return the whole image if face extraction fails
            return self._resize(image)

In [None]:
class DeepfakeFrameDataset(Dataset):
    """Frame-level real/fake image dataset assembled from folders or single image files."""

    # File suffixes treated as frames (case-sensitive, as before).
    IMAGE_EXTENSIONS = ('.jpg', '.png')

    def __init__(self, frame_dirs, real_label=0, transform=None, max_samples_per_folder=None):
        """
        Dataset for frame-based deepfake detection

        Args:
            frame_dirs: List of [path, label] pairs. ``path`` is either a
                directory (searched recursively for frames) or a single image file.
            real_label: Label value for real images (0 or 1)
            transform: Albumentations transforms
            max_samples_per_folder: Maximum samples to use per entry (for balancing);
                sampling uses NumPy's global RNG.
        """
        self.transform = transform
        self.image_paths = []
        self.labels = []
        self.dataset_sources = []

        # Collect all image paths and labels
        for dir_path, label in frame_dirs:
            if not os.path.exists(dir_path):
                print(f"Warning: {dir_path} does not exist, skipping.")
                continue

            source = self._source_of(dir_path)
            frames_in_folder = self._list_frames(dir_path)

            # Sample frames if needed
            if max_samples_per_folder and len(frames_in_folder) > max_samples_per_folder:
                frames_in_folder = np.random.choice(frames_in_folder, max_samples_per_folder, replace=False).tolist()

            # Add to dataset
            self.image_paths.extend(frames_in_folder)
            self.labels.extend([label] * len(frames_in_folder))
            self.dataset_sources.extend([source] * len(frames_in_folder))

        print(f"Loaded {len(self.image_paths)} images total")
        print(f"Real images: {self.labels.count(real_label)}")
        print(f"Fake images: {self.labels.count(1 - real_label)}")

    @staticmethod
    def _source_of(path):
        """Infer the source dataset name from substrings of *path*."""
        lowered = path.lower()
        if 'faceforensics' in lowered:
            return 'faceforensics'
        if 'celebdf' in lowered:
            return 'celebdf'
        return 'unknown'

    @classmethod
    def _list_frames(cls, path):
        """Return frame paths under *path* in a deterministic (sorted) order.

        A single image file is returned as-is: ``os.walk`` yields nothing for a
        non-directory, so file entries were previously dropped silently.
        Sorting matters because ``os.walk`` order is filesystem-dependent and
        seeded sampling is only reproducible on a stable ordering.
        """
        if os.path.isfile(path):
            return [path] if path.endswith(cls.IMAGE_EXTENSIONS) else []

        frames = []
        for root, dirs, files in os.walk(path):
            dirs.sort()  # in-place sort fixes the order os.walk descends
            frames.extend(
                os.path.join(root, name)
                for name in sorted(files)
                if name.endswith(cls.IMAGE_EXTENSIONS)
            )
        return frames

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label)``; the image is RGB, transformed if a transform is set.

        Raises:
            FileNotFoundError: if the image file cannot be read by OpenCV.
        """
        image_path = self.image_paths[idx]
        label = self.labels[idx]

        # Load image; imread returns None instead of raising on failure
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Apply transformations
        if self.transform:
            augmented = self.transform(image=image)
            image = augmented['image']

        return image, label

In [None]:
def get_train_transforms(image_size=224):
    """Build the training augmentation pipeline.

    Random resized crop (80-100% of the area, near-square aspect ratio),
    horizontal flip, colour jitter, Gaussian blur, then normalisation with the
    ImageNet mean/std and conversion to a tensor via ToTensorV2.

    Args:
        image_size: Side length of the square output crop.
    """
    crop_kwargs = dict(scale=(0.8, 1.0), ratio=(0.9, 1.1))
    try:
        # Newer albumentations releases take the output size as size=(h, w);
        # the 2.x series no longer accepts separate height/width arguments.
        random_crop = A.RandomResizedCrop(size=(image_size, image_size), **crop_kwargs)
    except (TypeError, ValueError):
        # Older releases without ``size`` reject it (unexpected keyword or
        # argument-validation error); fall back to height/width.
        random_crop = A.RandomResizedCrop(height=image_size, width=image_size, **crop_kwargs)

    return A.Compose([
        random_crop,
        A.HorizontalFlip(p=0.5),
        A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1, p=0.5),
        A.GaussianBlur(blur_limit=(3, 7), p=0.3),
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])

def get_val_transforms(image_size=224):
    """Deterministic evaluation pipeline: resize, normalise with ImageNet statistics, convert with ToTensorV2."""
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    steps = [
        A.Resize(height=image_size, width=image_size),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)

In [None]:
class SwinDeepfakeDetector(pl.LightningModule):
    """Frame-level real/fake classifier on a pretrained timm Swin Transformer.

    The backbone's classifier is re-initialised to ``num_classes`` outputs (a
    single logit by default) and trained with ``BCEWithLogitsLoss``.
    Validation/test metrics use a 0.5 threshold on the sigmoid probability and
    treat label 1 as the positive class.

    Args:
        model_name: timm model identifier for the backbone.
        num_classes: Number of output logits (1 for the binary BCE setup).
        learning_rate: Learning rate of the classifier head; the backbone
            uses one tenth of it.
        weight_decay: AdamW weight decay.
        max_epochs: Stored in the hyperparameters; not read by this class.
    """

    def __init__(self,
                 model_name='swin_base_patch4_window7_224',
                 num_classes=1,
                 learning_rate=2e-5,  # Lower learning rate for T4
                 weight_decay=1e-5,
                 max_epochs=30):
        super().__init__()
        self.save_hyperparameters()

        # Load pretrained Swin Transformer
        self.backbone = timm.create_model(model_name, pretrained=True)

        # Re-initialise the classifier via timm's own API. In recent timm
        # releases Swin's ``head`` also performs global pooling over the spatial
        # feature map, so swapping it for a bare nn.Linear would produce
        # per-location logits instead of one logit per image.
        self.backbone.reset_classifier(num_classes)

        # For binary classification
        self.criterion = nn.BCEWithLogitsLoss()

        # Per-batch outputs collected for epoch-end metrics
        self.val_outputs = []
        self.test_outputs = []

        # Mixed precision - important for T4 GPU
        # NOTE(review): not read in this class; precision is configured on the Trainer.
        self.use_amp = True

    def forward(self, x):
        """Return raw logits from the backbone."""
        return self.backbone(x)

    def _flat_logits(self, images):
        """Forward *images* and return one logit per sample as a 1-D tensor.

        ``reshape(-1)`` instead of ``squeeze()``: squeezing a batch of one
        gives a 0-d tensor, which mismatches the 1-D labels in the loss and
        cannot be concatenated at epoch end.
        """
        return self(images).reshape(-1)

    def configure_optimizers(self):
        """AdamW with a 10x lower LR for the pretrained backbone, plus cosine
        annealing with warm restarts stepped once per epoch."""
        # Split by module path: only the classifier lives under backbone.head,
        # and a prefix match cannot catch unrelated names containing "head".
        encoder_params = []
        classifier_params = []
        for name, param in self.named_parameters():
            if name.startswith('backbone.head.'):
                classifier_params.append(param)
            else:
                encoder_params.append(param)

        optimizer = torch.optim.AdamW([
            {'params': encoder_params, 'lr': self.hparams.learning_rate / 10},  # Lower LR for pretrained parts
            {'params': classifier_params, 'lr': self.hparams.learning_rate}
        ], weight_decay=self.hparams.weight_decay)

        # Cosine annealing with warm restarts
        scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=10,
            T_mult=2,
            eta_min=1e-6
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'epoch',
                'frequency': 1
            }
        }

    def training_step(self, batch, batch_idx):
        images, labels = batch
        outputs = self._flat_logits(images)
        loss = self.criterion(outputs, labels.float())

        # Accuracy at a 0.5 probability threshold
        preds = torch.sigmoid(outputs) > 0.5
        acc = (preds == labels.bool()).float().mean()

        # Log metrics
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_acc', acc, prog_bar=True)

        return loss

    def _collect_step(self, batch, store):
        """Compute the loss for *batch* and append detached loss/probabilities/labels to *store*."""
        images, labels = batch
        outputs = self._flat_logits(images)
        loss = self.criterion(outputs, labels.float())
        store.append({
            'loss': loss.detach(),
            'preds': torch.sigmoid(outputs).detach(),
            'labels': labels.detach()
        })
        return loss

    def _log_epoch_metrics(self, outputs, prefix):
        """Aggregate collected batch outputs, log ``{prefix}_*`` metrics and clear *outputs*."""
        if not outputs:
            return

        avg_loss = torch.stack([x['loss'] for x in outputs]).float().mean()
        probs = torch.cat([x['preds'] for x in outputs]).float().cpu().numpy()
        labels_np = torch.cat([x['labels'] for x in outputs]).cpu().numpy()
        preds_binary = (probs > 0.5).astype(int)

        acc = accuracy_score(labels_np, preds_binary)
        precision = precision_score(labels_np, preds_binary, zero_division=0)
        recall = recall_score(labels_np, preds_binary, zero_division=0)
        f1 = f1_score(labels_np, preds_binary, zero_division=0)

        try:
            auc_score = roc_auc_score(labels_np, probs)
        except ValueError:
            # roc_auc_score raises ValueError when only one class is present
            # (e.g. a small sanity-check batch).
            auc_score = 0.0

        self.log(f'{prefix}_loss', avg_loss, prog_bar=True)
        self.log(f'{prefix}_acc', float(acc), prog_bar=True)
        self.log(f'{prefix}_precision', float(precision))
        self.log(f'{prefix}_recall', float(recall))
        self.log(f'{prefix}_f1', float(f1))
        self.log(f'{prefix}_auc', float(auc_score))

        # Clear outputs so the next epoch starts fresh
        outputs.clear()

    def validation_step(self, batch, batch_idx):
        return self._collect_step(batch, self.val_outputs)

    def on_validation_epoch_end(self):
        self._log_epoch_metrics(self.val_outputs, 'val')

    def test_step(self, batch, batch_idx):
        return self._collect_step(batch, self.test_outputs)

    def on_test_epoch_end(self):
        self._log_epoch_metrics(self.test_outputs, 'test')

In [None]:
def prepare_datasets(max_samples_per_folder=2000):
    """
    Prepare datasets for training, validation, and testing
    Balanced to work with T4 GPU memory constraints

    Frames from all source folders are pooled, then split 80/10/10 with
    stratification on the real/fake label (``random_state=42``).

    Args:
        max_samples_per_folder: Maximum samples to use per folder

    Returns:
        ``(train_dataset, val_dataset, test_dataset, test_sources)`` where
        ``test_sources`` gives the source dataset name of each test image.

    Raises:
        ValueError: if no frames were found in any source folder.
    """
    import copy

    # Define dataset directories with labels (0=real, 1=fake)
    frame_dirs = [
        ['/content/datasets/frames/faceforensics/original', 0],  # FaceForensics++ original (real)
        ['/content/datasets/frames/faceforensics/deepfakes', 1],  # FaceForensics++ deepfakes (fake)
        ['/content/datasets/frames/faceforensics/neuraltextures', 1],  # FaceForensics++ neural (fake)
        ['/content/datasets/frames/celebdf/real', 0],  # CelebDF real
        ['/content/datasets/frames/celebdf/fake', 1],  # CelebDF fake
    ]

    # Create combined dataset - already limiting samples per folder
    dataset = DeepfakeFrameDataset(
        frame_dirs,
        real_label=0,
        transform=None,  # No transforms here, each split gets its own below
        max_samples_per_folder=max_samples_per_folder
    )

    if len(dataset) == 0:
        raise ValueError(
            "No frames found in any dataset directory; run frame extraction before preparing datasets."
        )

    # Split dataset maintaining class balance.
    # NOTE(review): the split is per frame, so frames from one video can land
    # in both train and test; a video-level (grouped) split would avoid leakage.
    train_indices, temp_indices = train_test_split(
        list(range(len(dataset))),
        test_size=0.2,
        stratify=dataset.labels,
        random_state=42
    )

    val_indices, test_indices = train_test_split(
        temp_indices,
        test_size=0.5,
        stratify=[dataset.labels[i] for i in temp_indices],
        random_state=42
    )

    def _subset(indices, transform):
        """Shallow copy of ``dataset`` restricted to *indices*, with its own transform.

        The split already holds individual image paths, so they are assigned
        directly rather than re-scanned through the constructor; rebuilding from
        per-file entries produced empty splits.
        """
        subset = copy.copy(dataset)
        subset.image_paths = [dataset.image_paths[i] for i in indices]
        subset.labels = [dataset.labels[i] for i in indices]
        subset.dataset_sources = [dataset.dataset_sources[i] for i in indices]
        subset.transform = transform
        return subset

    # Create datasets with appropriate transforms
    train_dataset = _subset(train_indices, get_train_transforms(224))
    val_dataset = _subset(val_indices, get_val_transforms(224))
    test_dataset = _subset(test_indices, get_val_transforms(224))

    print(f"Train set: {len(train_dataset)} images")
    print(f"Validation set: {len(val_dataset)} images")
    print(f"Test set: {len(test_dataset)} images")

    return train_dataset, val_dataset, test_dataset, test_dataset.dataset_sources

def train_model(batch_size=16, max_epochs=30, accumulate_grad_batches=2):
    """
    Train the Swin Transformer model and evaluate it on the test split.

    Args:
        batch_size: Batch size for the train, validation and test loaders
        max_epochs: Maximum number of training epochs
        accumulate_grad_batches: Number of batches to accumulate gradients (helps with small batch sizes)

    Returns:
        (model, test_dataset, test_sources). When a checkpoint was saved during
        training, the returned model holds the weights of the best-val_f1
        checkpoint (the same weights used for the test run), not the weights
        of the final epoch.
    """
    use_gpu = torch.cuda.is_available()

    # Prepare datasets
    train_dataset, val_dataset, test_dataset, test_sources = prepare_datasets(max_samples_per_folder=2000)

    # Shared loader settings. T4 GPU has ~16GB memory, so we need to be careful
    # with batch size; pinned host memory only helps when copying to a GPU.
    loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=use_gpu)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    # Initialize model
    model = SwinDeepfakeDetector(
        model_name='swin_base_patch4_window7_224',
        num_classes=1,
        learning_rate=2e-5,  # Lower for T4 GPU
        weight_decay=1e-5,
        max_epochs=max_epochs
    )

    # Keep the 3 best checkpoints by validation F1
    checkpoint_callback = ModelCheckpoint(
        monitor='val_f1',
        dirpath='checkpoints',
        filename='deepfake-detector-{epoch:02d}-{val_f1:.2f}',
        save_top_k=3,
        mode='max'
    )

    # Stop when validation F1 has not improved for 5 epochs
    early_stopping = EarlyStopping(
        monitor='val_f1',
        patience=5,
        mode='max'
    )

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator='gpu' if use_gpu else 'cpu',
        devices=1,
        callbacks=[checkpoint_callback, early_stopping],
        # fp16 mixed precision is critical for memory on the T4, but Lightning
        # does not support fp16 AMP on CPU, so fall back to full precision there.
        precision=16 if use_gpu else 32,
        gradient_clip_val=1.0,
        accumulate_grad_batches=accumulate_grad_batches,  # Accumulate gradients to simulate larger batch sizes
        deterministic=True,
        log_every_n_steps=50,  # Reduce logging frequency to save time
        enable_progress_bar=True,
        enable_model_summary=True,
    )

    # Train model
    print("Starting training...")
    trainer.fit(model, train_loader, val_loader)

    # Test the best checkpoint rather than the last epoch (which, when early
    # stopping fires, is several epochs past the best one). With
    # ckpt_path='best' Lightning loads that checkpoint's weights into `model`
    # in place; if nothing was saved, test the current weights.
    print("Evaluating on test set...")
    ckpt_path = 'best' if checkpoint_callback.best_model_path else None
    trainer.test(model, test_loader, ckpt_path=ckpt_path)

    return model, test_dataset, test_sources

In [None]:
def _predict_subset(model, dataset, indices, batch_size, device, desc):
    """Run `model` over dataset[indices]; return (sigmoid probabilities, labels) as flat numpy arrays."""
    subset = torch.utils.data.Subset(dataset, indices)
    loader = DataLoader(subset, batch_size=batch_size, shuffle=False, num_workers=2)

    all_preds = []
    all_labels = []
    with torch.no_grad():
        for images, labels in tqdm(loader, desc=desc):
            outputs = model(images.to(device))
            # reshape(-1) instead of squeeze(): a final batch holding a single
            # sample would otherwise collapse to a 0-d array, which
            # list.extend cannot iterate.
            all_preds.extend(torch.sigmoid(outputs).reshape(-1).cpu().numpy())
            all_labels.extend(labels.reshape(-1).numpy())

    return np.array(all_preds), np.array(all_labels)


def _binary_metrics(labels, preds, threshold=0.5):
    """Compute accuracy/precision/recall/F1 at `threshold` plus ROC AUC (0.0 if only one class is present)."""
    preds_binary = (preds > threshold).astype(int)
    return {
        'accuracy': accuracy_score(labels, preds_binary),
        'precision': precision_score(labels, preds_binary, zero_division=0),
        'recall': recall_score(labels, preds_binary, zero_division=0),
        'f1': f1_score(labels, preds_binary, zero_division=0),
        'auc': roc_auc_score(labels, preds) if len(np.unique(labels)) > 1 else 0.0
    }


def evaluate_model_by_dataset(model, test_dataset, test_sources, batch_size=16):
    """
    Evaluate model performance on each dataset separately

    Args:
        model: Trained model; called as model(images) and its outputs passed through sigmoid
        test_dataset: Test dataset yielding (image, label) pairs
        test_sources: Source dataset for each test sample, aligned index-for-index with test_dataset
        batch_size: Batch size for evaluation

    Returns:
        dict keyed by source name ('faceforensics', 'celebdf'); each value holds
        'metrics' (accuracy, precision, recall, f1, auc), 'preds' (probabilities)
        and 'labels' as numpy arrays. Sources with no test samples are omitted.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # (source key in test_sources, progress-bar description), in reporting order
    subsets = [
        ('faceforensics', "Evaluating FaceForensics++"),
        ('celebdf', "Evaluating CelebDF"),
    ]

    results = {}
    for source_name, desc in subsets:
        indices = [i for i, source in enumerate(test_sources) if source == source_name]
        if not indices:
            continue

        preds, labels = _predict_subset(model, test_dataset, indices, batch_size, device, desc)
        results[source_name] = {
            'metrics': _binary_metrics(labels, preds),
            'preds': preds,
            'labels': labels
        }

    # Print results
    print("\nCross-Dataset Evaluation Results:")
    for dataset, result in results.items():
        print(f"\n{dataset.upper()}:")
        for metric, value in result['metrics'].items():
            print(f"  {metric}: {value:.4f}")

    return results

def plot_confusion_matrices(results):
    """Plot one confusion-matrix heatmap per dataset side by side and save it as confusion_matrices.png.

    Args:
        results: dict as returned by evaluate_model_by_dataset; each value needs
            'labels' (0 = real, 1 = fake) and 'preds' (probabilities, thresholded at 0.5).
    """
    if not results:
        print("No results to plot.")
        return

    num_datasets = len(results)
    # squeeze=False always yields a 2-D array of axes, so no special case for one subplot.
    fig, axes = plt.subplots(1, num_datasets, figsize=(6 * num_datasets, 5), squeeze=False)

    for ax, (dataset, result) in zip(axes[0], results.items()):
        labels = result['labels']
        preds = (result['preds'] > 0.5).astype(int)
        # Pin the label set so the matrix is always 2x2 and matches the
        # Real/Fake tick labels, even if a subset contains only one class.
        cm = confusion_matrix(labels, preds, labels=[0, 1])

        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Real', 'Fake'],
                    yticklabels=['Real', 'Fake'], ax=ax)
        ax.set_title(f'{dataset.upper()} Confusion Matrix')
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')

    plt.tight_layout()
    plt.savefig('confusion_matrices.png')
    plt.show()

def plot_roc_curves(results):
    """Plot ROC curves for every dataset on one figure and save it as roc_curves.png.

    Args:
        results: dict as returned by evaluate_model_by_dataset; each value needs
            'labels' and 'preds' (fake probabilities).
    """
    plt.figure(figsize=(10, 6))

    colors = ['darkorange', 'green', 'blue']
    for i, (dataset, result) in enumerate(results.items()):
        labels = result['labels']
        preds = result['preds']

        # A ROC curve needs both classes; evaluate_model_by_dataset reports
        # auc=0.0 for such subsets, so skip them here instead of plotting NaNs.
        if len(np.unique(labels)) < 2:
            print(f"Skipping ROC curve for {dataset}: only one class present")
            continue

        fpr, tpr, _ = roc_curve(labels, preds)
        roc_auc = auc(fpr, tpr)

        # Cycle the palette so more than len(colors) datasets still plot.
        plt.plot(fpr, tpr, color=colors[i % len(colors)], lw=2,
                 label=f'{dataset.upper()} (AUC = {roc_auc:.2f})')

    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curves')
    plt.legend(loc="lower right")
    plt.savefig('roc_curves.png')
    plt.show()

In [None]:
# Create directories
!mkdir -p checkpoints

# Clear memory
gc.collect()
torch.cuda.empty_cache()

# Check available memory on GPU
if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9  # GB
    print(f"Total GPU memory: {gpu_memory:.2f} GB")
    print(f"Available GPU memory: {torch.cuda.memory_reserved(0) / 1e9:.2f} GB reserved")

    # Determine batch size based on available memory for T4 GPU
    batch_size = 16  # Start with this, will be adjusted if needed
    accumulate_grad_batches = 2
else:
    batch_size = 4
    accumulate_grad_batches = 4

print(f"Using batch size: {batch_size} with gradient accumulation: {accumulate_grad_batches}")

# Train model
model, test_dataset, test_sources = train_model(
    batch_size=batch_size,
    max_epochs=30,
    accumulate_grad_batches=accumulate_grad_batches
)

Total GPU memory: 15.83 GB
Available GPU memory: 0.00 GB reserved
Using batch size: 16 with gradient accumulation: 2
Loaded 0 images total
Real images: 0
Fake images: 0


ValueError: With n_samples=0, test_size=0.2 and train_size=None, the resulting train set will be empty. Adjust any of the aforementioned parameters.

In [None]:
# Evaluate model by dataset
results = evaluate_model_by_dataset(model, test_dataset, test_sources, batch_size=batch_size)

# Visualize results
plot_confusion_matrices(results)
plot_roc_curves(results)

# Save the model
torch.save(model.state_dict(), 'deepfake_detector_swin_base.pth')
print("Model saved successfully!")

# Save to Google Drive
from google.colab import drive
drive.mount('/content/drive', force_remount=True)
!cp deepfake_detector_swin_base.pth /content/drive/MyDrive/
!cp confusion_matrices.png /content/drive/MyDrive/
!cp roc_curves.png /content/drive/MyDrive/
print("Results copied to Google Drive!")

In [None]:
import random

def test_with_random_sample():
    """Test the model with a random sample from the test dataset.

    Uses the notebook globals `test_dataset` and `model`, shows the image next
    to a Real/Fake probability bar chart, and prints the verdict. Returns None.
    """
    if 'test_dataset' not in globals():
        print("Test dataset not available. Train the model first.")
        return
    if len(test_dataset) == 0:
        # random.randint(0, -1) would raise ValueError on an empty dataset.
        print("Test dataset is empty; nothing to sample.")
        return
    if 'model' not in globals():
        print("Model not available. Train the model first.")
        return

    # Select a random image from test set
    idx = random.randrange(len(test_dataset))
    image, label = test_dataset[idx]

    # Undo normalization for display; the constants below are the ImageNet
    # mean/std — assumes get_val_transforms normalized with them (TODO confirm).
    image_np = image.permute(1, 2, 0).cpu().numpy()
    image_np = image_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    image_np = np.clip(image_np, 0, 1)

    # Make prediction
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    with torch.no_grad():
        output = model(image.unsqueeze(0).to(device)).squeeze()
        prob = torch.sigmoid(output).item()

    is_fake = prob > 0.5
    confidence = prob if is_fake else 1 - prob
    # int() accepts plain ints as well as 0-d/1-element label tensors.
    true_label = "FAKE" if int(label) == 1 else "REAL"
    pred_label = "FAKE" if is_fake else "REAL"

    # Display result
    plt.figure(figsize=(10, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(image_np)
    plt.title(f"True Label: {true_label}")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.bar(['Real', 'Fake'], [1-prob, prob], color=['green', 'red'])
    plt.ylim(0, 1)
    plt.title(f"Prediction: {pred_label}\nConfidence: {confidence:.2%}")

    plt.tight_layout()
    plt.show()

    print(f"True label: {true_label}")
    print(f"Prediction: {pred_label}")
    print(f"Confidence: {confidence:.2%}")
    print(f"Correct prediction: {true_label == pred_label}")

# Run test
test_with_random_sample()

In [None]:
from google.colab import files
import io

def test_with_uploaded_image():
    """Test the model with images uploaded through the Colab file picker.

    For each uploaded file: convert to RGB, extract the face with
    FaceExtractor, apply the validation transforms, and show/print the
    model's fake probability. Uses the notebook global `model`.
    """
    print("Please upload an image...")
    uploaded = files.upload()
    if not uploaded:
        print("No image uploaded.")
        return

    # One-time setup shared by every uploaded file
    face_extractor = FaceExtractor(output_size=224)
    transform = get_val_transforms(224)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    for filename, content in uploaded.items():
        # convert('RGB') normalizes RGBA, palette, grayscale and CMYK images to
        # 3 channels; checking shape[-1] broke on 2-D grayscale arrays.
        image_np = np.array(Image.open(io.BytesIO(content)).convert('RGB'))

        # Extract face
        face = face_extractor.extract_face(image_np)
        # NOTE(review): extract_face's failure behavior is not visible here;
        # guard in case it returns None when no face is found.
        if face is None:
            print(f"No face found in {filename}")
            continue

        # Preprocess image
        image_tensor = transform(image=face)['image'].unsqueeze(0)

        # Make prediction
        with torch.no_grad():
            output = model(image_tensor.to(device)).squeeze()
            prob = torch.sigmoid(output).item()

        is_fake = prob > 0.5
        confidence = prob if is_fake else 1 - prob

        # Display result
        plt.figure(figsize=(10, 4))
        plt.subplot(1, 2, 1)
        plt.imshow(face)
        plt.title("Input Image")
        plt.axis('off')

        plt.subplot(1, 2, 2)
        plt.bar(['Real', 'Fake'], [1-prob, prob], color=['green', 'red'])
        plt.ylim(0, 1)
        plt.title(f"Prediction: {'FAKE' if is_fake else 'REAL'}\nConfidence: {confidence:.2%}")

        plt.tight_layout()
        plt.show()

        print(f"Prediction: {'FAKE' if is_fake else 'REAL'}")
        print(f"Confidence: {confidence:.2%}")
        print(f"Fake probability: {prob:.2%}")

# Run test with uploaded image
# test_with_uploaded_image()

In [None]:
def analyze_video(video_path, output_path=None, frame_skip=10):
    """Analyze a video for deepfake detection.

    Every `frame_skip`-th frame goes through FaceExtractor and the model; the
    per-frame fake probabilities are averaged into a video-level verdict and
    shown as a histogram (saved to video_analysis.png).

    Args:
        video_path: Path of the video to analyze.
        output_path: Optional path for an annotated copy (mp4v). The most recent
            prediction is drawn on every written frame, not only sampled ones.
        frame_skip: Analyze one frame out of every `frame_skip`; must be >= 1.

    Returns:
        dict with 'is_fake', 'average_probability' and 'frame_predictions', or
        None if the video is missing/unreadable or no frame could be scored.

    Raises:
        ValueError: if frame_skip is less than 1.
    """
    if frame_skip < 1:
        # frame_count % 0 would raise ZeroDivisionError mid-loop.
        raise ValueError(f"frame_skip must be >= 1, got {frame_skip}")

    if not os.path.exists(video_path):
        print(f"Video not found at {video_path}")
        return

    # Face extractor and transforms
    face_extractor = FaceExtractor(output_size=224)
    transform = get_val_transforms(224)

    # Prepare model
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)
    model.eval()

    # Open video
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        print(f"Could not open video at {video_path}")
        return None
    # Keep fractional rates such as 29.97; int() truncation made output drift.
    fps = cap.get(cv2.CAP_PROP_FPS)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

    # Prepare output video if path is provided
    out = None
    if output_path:
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))

    frame_count = 0
    predictions = []
    overlay = None  # (text, BGR color) of the latest prediction

    try:
        with torch.no_grad():
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break

                if frame_count % frame_skip == 0:
                    # Extract face and make prediction (best effort per frame)
                    try:
                        # NOTE(review): cv2 frames are BGR, while
                        # test_with_uploaded_image feeds RGB (PIL) images to the
                        # same extractor — confirm which order training used.
                        face = face_extractor.extract_face(frame)
                        face_tensor = transform(image=face)['image'].unsqueeze(0).to(device)

                        output = model(face_tensor).squeeze()
                        prob = torch.sigmoid(output).item()

                        is_fake = prob > 0.5
                        confidence = prob if is_fake else 1 - prob
                        overlay = (f"{'FAKE' if is_fake else 'REAL'}: {confidence:.2%}",
                                   (0, 0, 255) if is_fake else (0, 255, 0))

                        predictions.append(prob)
                    except Exception as e:
                        print(f"Error processing frame {frame_count}: {e}")

                if out is not None:
                    # Redraw the latest verdict on every frame so the label does
                    # not flash for a single frame out of every frame_skip.
                    if overlay is not None:
                        cv2.putText(frame, overlay[0], (10, 30), cv2.FONT_HERSHEY_SIMPLEX,
                                    1, overlay[1], 2)
                    out.write(frame)

                frame_count += 1
    finally:
        # Release capture/writer even if inference raises unexpectedly.
        cap.release()
        if out is not None:
            out.release()

    # Calculate overall video prediction
    if predictions:
        avg_fake_prob = np.mean(predictions)
        is_video_fake = avg_fake_prob > 0.5

        print("Video Analysis Complete")
        print(f"Total frames analyzed: {len(predictions)}/{frame_count}")
        print(f"Average fake probability: {avg_fake_prob:.2%}")
        print(f"Video verdict: {'FAKE' if is_video_fake else 'REAL'}")

        # Plot prediction histogram
        plt.figure(figsize=(10, 4))
        plt.hist(predictions, bins=20, alpha=0.7, color='blue')
        plt.axvline(x=0.5, color='red', linestyle='--')
        plt.axvline(x=avg_fake_prob, color='green', linestyle='-')
        plt.title(f"Frame Predictions Histogram\nAverage: {avg_fake_prob:.2%}")
        plt.xlabel("Fake Probability")
        plt.ylabel("Number of Frames")
        plt.savefig('video_analysis.png')
        plt.show()

        return {
            'is_fake': is_video_fake,
            'average_probability': avg_fake_prob,
            'frame_predictions': predictions
        }
    else:
        print("No faces detected in video")
        return None

# Example usage - uncomment to run
# video_results = analyze_video('/path/to/video.mp4', '/path/to/output.mp4', frame_skip=30)

In [None]:
# Free unreferenced Python objects and cached CUDA blocks before measuring
gc.collect()
torch.cuda.empty_cache()

# GPU memory summary, all figures in GB
if torch.cuda.is_available():
    gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
    reserved = torch.cuda.memory_reserved(0) / 1e9
    allocated = torch.cuda.memory_allocated(0) / 1e9
    memory_rows = (
        ("Total", gpu_memory),
        ("Reserved", reserved),
        ("Allocated", allocated),
        ("Free", gpu_memory - reserved),
    )
    for row_name, row_gb in memory_rows:
        print(f"{row_name} GPU memory: {row_gb:.2f} GB")

# Parameter count, only when a model exists in this session
if 'model' in globals():
    model_size = sum(tensor.numel() for tensor in model.parameters()) / 1e6
    print(f"Model size: {model_size:.2f} million parameters")

In [None]:
def generate_model_report():
    """Generate a Markdown report of the model setup and per-dataset metrics.

    Reads the notebook globals `results`, `model`, `batch_size` and
    `accumulate_grad_batches`. Writes model_report.md to the working
    directory and copies it to /content/drive/MyDrive/ when that folder
    exists (i.e. Google Drive has been mounted).

    Returns:
        The report text, or None when training/evaluation state is missing.
    """
    import os
    import shutil

    required = ('results', 'model', 'batch_size', 'accumulate_grad_batches')
    if any(name not in globals() for name in required):
        print("Results not available. Train the model first.")
        return

    # Every line starts at column 0: Markdown renders lines indented by four
    # spaces as a code block, which is what an indented triple-quoted string
    # produced (no headings or bullet lists would render).
    lines = [
        "# Deepfake Detection Model Report",
        "",
        "## Model Architecture",
        "- Backbone: Swin Transformer Base",
        "- Pretrained: ImageNet",
        "- Input Size: 224x224",
        "",
        "## Training Details",
        f"- Batch Size: {batch_size}",
        f"- Gradient Accumulation: {accumulate_grad_batches}",
        "- Learning Rate: 2e-5 (head), 2e-6 (backbone)",
        "- Mixed Precision: FP16",
        "- Early Stopping: Yes (F1 score)",
        "",
        "## Performance Metrics",
    ]

    # Add metrics for each dataset
    for dataset, result in results.items():
        lines.append("")
        lines.append(f"### {dataset.upper()} Dataset")
        lines.extend(f"- {metric.capitalize()}: {value:.4f}"
                     for metric, value in result['metrics'].items())

    report = "\n".join(lines) + "\n"

    # Save report
    with open('model_report.md', 'w', encoding='utf-8') as f:
        f.write(report)

    # Copy to Drive only if it is mounted, so the message reflects what happened
    drive_dir = '/content/drive/MyDrive'
    if os.path.isdir(drive_dir):
        shutil.copy('model_report.md', drive_dir)
        print("Report generated and saved to Google Drive!")
    else:
        print("Report generated and saved to model_report.md (Google Drive is not mounted).")
    return report

# Generate report
# generate_model_report()