<a href="https://colab.research.google.com/github/NaveenSanjaya/Deepfake-Face-Detection-In-The-Wild/blob/main/deepfake-gan-detectoripynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from torchvision.io import read_image
import os
import pandas as pd
import numpy as np
from PIL import Image

# Custom Dataset class for handling multiple deepfake datasets
class MultiSourceDeepfakeDataset(Dataset):
    """Dataset that pools real/fake images from several dataset folders.

    Each dataset directory must contain a ``real`` and a ``fake``
    sub-directory. Files under ``real`` are labelled 1, files under ``fake``
    are labelled 0.
    """

    def __init__(self, root_dirs, transform=None):
        """
        Args:
            root_dirs (dict): Dictionary containing paths to different datasets
                (only the values, i.e. the paths, are used)
            transform: Image transformations to be applied

        Raises:
            FileNotFoundError: If a dataset lacks its ``real`` or ``fake``
                sub-directory (raised by ``os.listdir``).
        """
        self.transform = transform
        self.image_paths = []
        self.labels = []

        # Process each dataset: real images first (label 1), then fake (label 0)
        for dataset_path in root_dirs.values():
            for class_dir, label in (('real', 1), ('fake', 0)):
                class_path = os.path.join(dataset_path, class_dir)

                # os.listdir returns entries in arbitrary order; sort so that
                # index -> sample is reproducible across runs and machines.
                for img_name in sorted(os.listdir(class_path)):
                    img_path = os.path.join(class_path, img_name)

                    # Nested directories cannot be opened as images; skip them
                    if not os.path.isfile(img_path):
                        continue

                    self.image_paths.append(img_path)
                    self.labels.append(label)

    def __len__(self):
        """Return the total number of collected image files."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the image at ``idx`` as RGB and return ``(image, label)``."""
        # convert() returns an independent in-memory copy, so the underlying
        # file handle can be released right away instead of waiting for GC.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

# Generator Network
class Generator(nn.Module):
    """Maps a latent noise vector to an image through four upsampling stages."""

    def __init__(self, latent_dim=100, channels=3):
        """
        Build the generator layers.
        Args:
            latent_dim: Size of the input noise vector
            channels: Number of output image channels (3 for RGB)
        """
        super().__init__()

        # Projects the noise vector onto a 512-channel 4x4 feature map
        self.fc = nn.Linear(latent_dim, 512 * 4 * 4)

        # Every transposed convolution shares the same geometry, which doubles
        # the spatial resolution: 4 -> 8 -> 16 -> 32 -> 64
        upsample_cfg = dict(kernel_size=4, stride=2, padding=1)
        self.deconv1 = nn.ConvTranspose2d(512, 256, **upsample_cfg)
        self.deconv2 = nn.ConvTranspose2d(256, 128, **upsample_cfg)
        self.deconv3 = nn.ConvTranspose2d(128, 64, **upsample_cfg)
        self.deconv4 = nn.ConvTranspose2d(64, channels, **upsample_cfg)

        # One batch norm per hidden upsampling stage (none on the output stage)
        self.bn1 = nn.BatchNorm2d(256)
        self.bn2 = nn.BatchNorm2d(128)
        self.bn3 = nn.BatchNorm2d(64)

    def forward(self, x):
        """Turn a batch of noise vectors into images with values in [-1, 1]."""
        # Project and fold into (N, 512, 4, 4)
        feature_map = self.fc(x).view(-1, 512, 4, 4)

        # Hidden stages: transposed conv -> batch norm -> ReLU
        hidden_stages = (
            (self.deconv1, self.bn1),
            (self.deconv2, self.bn2),
            (self.deconv3, self.bn3),
        )
        for upconv, norm in hidden_stages:
            feature_map = F.relu(norm(upconv(feature_map)))

        # Output stage: tanh squashes pixel values into [-1, 1]
        return torch.tanh(self.deconv4(feature_map))

# Discriminator Network
class Discriminator(nn.Module):
    """Convolutional classifier that scores images as real (1) or fake (0)."""

    def __init__(self, channels=3):
        """
        Discriminator network that classifies images as real or fake
        Args:
            channels: Number of input image channels (3 for RGB)
        """
        super(Discriminator, self).__init__()

        # Convolutional layers; each one halves the spatial resolution
        self.conv1 = nn.Conv2d(channels, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)

        # Batch normalization layers (applied after conv2, conv3 and conv4)
        self.bn1 = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(512)

        # Final classification layer; expects a 512 x 4 x 4 feature map
        self.fc = nn.Linear(512 * 4 * 4, 1)

        # Dropout for regularization
        self.dropout = nn.Dropout(0.3)

    def forward(self, x):
        """
        Score a batch of images.
        Args:
            x: Image batch of shape (N, channels, H, W) whose spatial size
               reduces to 4x4 after the four stride-2 convolutions
               (e.g. 64x64 input)
        Returns:
            Tensor of shape (N, 1) with sigmoid probabilities
        Raises:
            ValueError: If the convolutional feature map is not 512 x 4 x 4
        """
        # Apply convolutions with leaky ReLU and batch normalization
        x = F.leaky_relu(self.conv1(x), 0.2)
        x = F.leaky_relu(self.bn1(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.bn2(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.bn3(self.conv4(x)), 0.2)

        # A view(-1, 512 * 4 * 4) would silently fold extra spatial positions
        # into the batch dimension for larger inputs; fail loudly instead.
        if tuple(x.shape[1:]) != (512, 4, 4):
            raise ValueError(
                'Discriminator expects inputs that reduce to a 512x4x4 feature '
                f'map (e.g. 64x64 images), got feature map {tuple(x.shape[1:])}'
            )

        # Flatten per sample (batch dimension preserved) and apply dropout
        x = x.view(x.size(0), -1)
        x = self.dropout(x)

        # Final classification
        x = torch.sigmoid(self.fc(x))
        return x

# Training function
def train_gan(generator, discriminator, dataloader, num_epochs, device,
              latent_dim=100, lr=0.0002, betas=(0.5, 0.999)):
    """
    Training loop for the GAN
    Args:
        generator: Generator network
        discriminator: Discriminator network
        dataloader: DataLoader containing the training data; it must yield
            (images, labels) pairs. The labels are ignored: every image in a
            batch is scored against the "real" target.
        num_epochs: Number of training epochs
        device: Device to train on (CPU/GPU)
        latent_dim: Size of the noise vector fed to the generator; must match
            the generator's input size (default 100)
        lr: Learning rate for both Adam optimizers
        betas: Adam beta coefficients for both optimizers

    Progress (discriminator and generator loss) is printed every 100 steps.
    """
    # Loss function and optimizers
    criterion = nn.BCELoss()
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr, betas=betas)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

    for epoch in range(num_epochs):
        for i, (real_images, _) in enumerate(dataloader):
            batch_size = real_images.size(0)

            # Create target labels directly on the training device
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ---- Train Discriminator ----
            real_images = real_images.to(device)
            d_optimizer.zero_grad()

            # Loss on real images
            d_output_real = discriminator(real_images)
            d_loss_real = criterion(d_output_real, real_labels)

            # Loss on fake images; detach() keeps this backward pass from
            # reaching the generator's parameters
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            d_output_fake = discriminator(fake_images.detach())
            d_loss_fake = criterion(d_output_fake, fake_labels)

            # Total discriminator loss
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()

            # ---- Train Generator ----
            # The generator is rewarded when the (just updated) discriminator
            # labels its samples as real.
            g_optimizer.zero_grad()
            g_output = discriminator(fake_images)
            g_loss = criterion(g_output, real_labels)
            g_loss.backward()
            g_optimizer.step()

            if i % 100 == 0:
                print(f'Epoch [{epoch}/{num_epochs}], Step [{i}/{len(dataloader)}], '
                      f'd_loss: {d_loss.item():.4f}, g_loss: {g_loss.item():.4f}')

# Main execution
def main():
    """Build the data pipeline and both networks, train the GAN, save weights."""
    # Prefer the GPU whenever CUDA is usable
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    # Root folder of every source dataset (placeholder paths)
    source_roots = {
        'celeb_df_v1': 'path/to/celeb_df_v1',
        'celeb_df_v2': 'path/to/celeb_df_v2',
        'faceforensics': 'path/to/faceforensics',
        'dfdc': 'path/to/dfdc'
    }

    # Resize to 64x64, convert to tensor, then shift/scale each channel
    # with mean 0.5 and std 0.5
    preprocessing_steps = [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
    preprocess = transforms.Compose(preprocessing_steps)

    # Shuffled mini-batches of 64 images, loaded by 4 worker processes
    train_set = MultiSourceDeepfakeDataset(source_roots, transform=preprocess)
    train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)

    # Networks with their default hyper-parameters, moved to the device
    gen_net = Generator().to(device)
    disc_net = Discriminator().to(device)

    # Adversarial training for 100 epochs
    train_gan(gen_net, disc_net, train_loader, num_epochs=100, device=device)

    # Persist the learned weights of both networks
    for network, checkpoint_file in ((gen_net, 'generator.pth'),
                                     (disc_net, 'discriminator.pth')):
        torch.save(network.state_dict(), checkpoint_file)

# Run the training pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
