# Deepfake Detection Model Training

This notebook trains a deepfake detector with separate audio and video branches that can be used individually or fused into a single multimodal prediction.

## Features:
- Audio branch using Wav2Vec2
- Video branch using Timesformer
- Late fusion (averaged branch logits) for multimodal detection
- LoRA adapters for efficient fine-tuning
- Dataset download (ASVspoof 2019 via kagglehub; DFDC must be downloaded manually) and preprocessing

In [None]:
# Install required packages
# IPython shell escape ("!"): this line only runs inside a notebook kernel such as Colab.
!pip install torch torchvision torchaudio transformers datasets kagglehub scikit-learn peft librosa opencv-python

In [None]:
# Mount Google Drive for saving models
# Colab-only: the checkpoint paths used later in this notebook live under /content/drive/MyDrive/models.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Clone the repository or upload files
# NOTE(review): "your-repo" is a placeholder - point this at the real repository before running.
# `!` and `%cd` are IPython syntax; %cd switches the kernel's working directory to the repo's ml/ folder.
!git clone https://github.com/your-repo/lj-hackathon.git
%cd lj-hackathon/ml

In [None]:
# Import required libraries
import sys
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score
import kagglehub
import zipfile
from pathlib import Path
import torchaudio
import torchvision.transforms as transforms
from PIL import Image
import cv2
import librosa
from transformers import Wav2Vec2Model, TimesformerModel
from peft import LoraConfig, get_peft_model

In [None]:
# Model definition
class AudioBranch(nn.Module):
    """Wav2Vec2-base encoder followed by a linear head giving one logit per example.

    Args:
        lora_config: Optional PEFT config; when truthy the pretrained encoder is
            wrapped with ``get_peft_model`` before the head is attached.
    """

    def __init__(self, lora_config=None):
        super().__init__()
        encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        if lora_config:
            encoder = get_peft_model(encoder, lora_config)
        self.backbone = encoder
        # Head width comes from the (possibly PEFT-wrapped) backbone's config.
        self.classifier = nn.Linear(encoder.config.hidden_size, 1)

    def forward(self, x):
        """Encode ``x``, average the hidden states over axis 1, and score them.

        NOTE(review): Wav2Vec2Model documents its input as raw audio shaped
        ``(batch, samples)``; confirm the audio pipeline feeds that format.
        """
        hidden_states = self.backbone(x).last_hidden_state
        pooled = hidden_states.mean(dim=1)  # average over the sequence axis
        return self.classifier(pooled)

class VideoBranch(nn.Module):
    """Timesformer video encoder followed by a linear head giving one logit per clip.

    Args:
        lora_config: Optional PEFT config; when truthy the pretrained encoder is
            wrapped with ``get_peft_model`` before the head is attached.
        model_name: Hugging Face Hub checkpoint to load. The former hard-coded
            ``"facebook/timesformer-base"`` is not a published checkpoint, so the
            default is the Kinetics-400 fine-tuned base model.
    """

    def __init__(self, lora_config=None, model_name="facebook/timesformer-base-finetuned-k400"):
        super().__init__()
        self.backbone = TimesformerModel.from_pretrained(model_name)
        if lora_config:
            self.backbone = get_peft_model(self.backbone, lora_config)
        self.classifier = nn.Linear(self.backbone.config.hidden_size, 1)

    def forward(self, x):
        """Score a batch of clips.

        ``x`` is passed straight to Timesformer, whose ``pixel_values`` are
        documented as ``(batch, num_frames, channels, height, width)``.
        """
        outputs = self.backbone(x)
        # Position 0 of the sequence is used as the clip embedding (CLS token).
        return self.classifier(outputs.last_hidden_state[:, 0])

class DeepFakeDetector(nn.Module):
    """Audio + video deepfake detector with late fusion (mean of branch logits).

    Args:
        use_lora: If True, each backbone is wrapped with its own LoRA adapter
            (r=16, alpha=32). PEFT has no default LoRA target modules for the
            wav2vec2 / timesformer model types, so ``get_peft_model`` raises
            unless ``target_modules`` is given explicitly. Each branch also gets
            a separate ``LoraConfig`` because ``get_peft_model`` mutates the
            config it receives.
    """

    # Attention projection names in the Hugging Face implementations:
    # Wav2Vec2 uses separate q/k/v projections, Timesformer a fused ``qkv`` Linear.
    AUDIO_LORA_TARGETS = ("q_proj", "v_proj")
    VIDEO_LORA_TARGETS = ("qkv",)

    def __init__(self, use_lora=True):
        super().__init__()
        audio_lora = video_lora = None
        if use_lora:
            audio_lora = LoraConfig(r=16, lora_alpha=32, target_modules=list(self.AUDIO_LORA_TARGETS))
            video_lora = LoraConfig(r=16, lora_alpha=32, target_modules=list(self.VIDEO_LORA_TARGETS))
        self.audio_branch = AudioBranch(audio_lora)
        self.video_branch = VideoBranch(video_lora)

    def forward(self, audio_input=None, video_input=None):
        """Return logits from whichever modalities are provided.

        Both inputs give the average of the two branch logits; a single input
        gives that branch's logits unchanged.

        Raises:
            ValueError: if neither input is given.
        """
        if audio_input is None and video_input is None:
            raise ValueError("At least one of audio_input or video_input must be provided")
        if video_input is None:
            return self.audio_branch(audio_input)
        if audio_input is None:
            return self.video_branch(video_input)
        audio_logits = self.audio_branch(audio_input)
        video_logits = self.video_branch(video_input)
        return (audio_logits + video_logits) / 2

In [None]:
# Data preprocessing functions
class AudioDataset(torch.utils.data.Dataset):
    """Audio clips as normalized, fixed-length mono waveforms for Wav2Vec2.

    ``AudioBranch`` wraps ``Wav2Vec2Model``, which consumes raw audio shaped
    ``(batch, samples)`` rather than spectrograms. Each item is therefore:
    1. loaded and down-mixed to mono,
    2. resampled to ``target_sample_rate``,
    3. normalized to zero mean / unit variance,
    4. cropped or zero-padded to ``max_samples`` so the default collate can
       batch clips of different durations,
    5. passed through ``transform`` if one is given.

    Args:
        audio_paths: Sequence of paths loadable by ``torchaudio.load``.
        labels: Sequence of labels aligned with ``audio_paths``.
        transform: Optional callable applied to the final 1-D waveform tensor.
        target_sample_rate: Sample rate to resample to (Wav2Vec2-base uses 16 kHz).
        max_samples: Output length in samples; ``None`` keeps the full clip.
    """

    def __init__(self, audio_paths, labels, transform=None, target_sample_rate=16000, max_samples=64000):
        self.audio_paths = audio_paths
        self.labels = labels
        self.transform = transform
        self.target_sample_rate = target_sample_rate
        self.max_samples = max_samples

    def __len__(self):
        return len(self.audio_paths)

    @staticmethod
    def _to_mono(waveform):
        """Average the channel axis of a ``(channels, samples)`` tensor into ``(samples,)``."""
        return waveform.mean(dim=0)

    @staticmethod
    def _normalize(waveform):
        """Zero-mean / unit-variance normalization; an empty waveform is returned unchanged."""
        if waveform.numel() == 0:
            return waveform
        # Population std so a 1-sample clip does not produce NaN.
        return (waveform - waveform.mean()) / (waveform.std(unbiased=False) + 1e-9)

    @staticmethod
    def _fix_length(waveform, max_samples):
        """Crop or right-pad with zeros (the post-normalization mean) to ``max_samples``."""
        if max_samples is None:
            return waveform
        length = waveform.shape[-1]
        if length >= max_samples:
            return waveform[..., :max_samples]
        return torch.nn.functional.pad(waveform, (0, max_samples - length))

    def __getitem__(self, idx):
        waveform, sample_rate = torchaudio.load(self.audio_paths[idx])
        waveform = self._to_mono(waveform)
        if sample_rate != self.target_sample_rate:
            waveform = torchaudio.functional.resample(waveform, sample_rate, self.target_sample_rate)
        waveform = self._fix_length(self._normalize(waveform), self.max_samples)
        if self.transform is not None:
            waveform = self.transform(waveform)

        return {
            'input_values': waveform,
            'labels': torch.tensor(self.labels[idx], dtype=torch.float)
        }

class VideoDataset(torch.utils.data.Dataset):
    """Video clips as fixed-length frame stacks for ``TimesformerModel``.

    Each item does the following:
    1. samples ``num_frames`` evenly spaced frames;
    2. pads short reads by repeating the last frame, or uses a black 224x224
       frame when nothing could be read;
    3. applies ``transform`` to each frame;
    4. stacks the frames along a leading time axis.

    ``'pixel_values'`` is therefore ``(num_frames, C, H, W)``; after batching it
    becomes ``(batch, num_frames, C, H, W)``, the layout Timesformer documents
    for ``pixel_values``.

    Args:
        video_paths: Sequence of paths readable by ``cv2.VideoCapture``.
        labels: Sequence of labels aligned with ``video_paths``.
        transform: Per-frame callable mapping a PIL image to a ``(C, H, W)``
            tensor. When None, frames are resized to 224x224 and converted with
            ``ToTensor`` (float in [0, 1], no mean/std normalization).
        num_frames: Number of frames sampled per clip.
    """

    def __init__(self, video_paths, labels, transform=None, num_frames=16):
        self.video_paths = video_paths
        self.labels = labels
        self.transform = transform
        self.num_frames = num_frames

    def __len__(self):
        return len(self.video_paths)

    @staticmethod
    def _frame_indices(total_frames, num_frames):
        """Return ``num_frames`` evenly spaced indices in ``[0, total_frames)``, or none if there are no frames."""
        # An unopenable file reports 0 frames; linspace(0, -1, ...) would yield -1 indices.
        if total_frames <= 0:
            return np.empty(0, dtype=int)
        return np.linspace(0, total_frames - 1, num_frames, dtype=int)

    @staticmethod
    def _stack_frames(frames):
        """Stack per-frame ``(C, H, W)`` tensors into ``(T, C, H, W)`` (time first, as Timesformer expects)."""
        return torch.stack(frames, dim=0)

    def _read_frames(self, video_path):
        """Decode the sampled frames of ``video_path`` as RGB PIL images; failed reads are skipped."""
        cap = cv2.VideoCapture(video_path)
        frames = []
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            for i in self._frame_indices(total_frames, self.num_frames):
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
                ret, frame = cap.read()
                if ret:
                    frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
        finally:
            cap.release()  # release the decoder even if a read/convert raises
        return frames

    def __getitem__(self, idx):
        frames = self._read_frames(self.video_paths[idx])

        # If not enough frames, duplicate the last one (or use a black frame if none were read)
        while len(frames) < self.num_frames:
            frames.append(frames[-1] if frames else Image.new('RGB', (224, 224)))

        transform = self.transform
        if transform is None:
            # PIL images cannot be torch.stack-ed; default to a fixed resize + ToTensor.
            transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
        video_tensor = self._stack_frames([transform(frame) for frame in frames])  # [T, C, H, W]

        return {
            'pixel_values': video_tensor,
            'labels': torch.tensor(self.labels[idx], dtype=torch.float)
        }

In [None]:
# Download datasets
def download_datasets(data_dir):
    """Create ``data_dir`` and try to fetch ASVspoof 2019 through kagglehub.

    The download is best-effort. A failure is reported with its cause instead
    of aborting the notebook. A bare ``except:`` is no longer used, so Ctrl-C
    still interrupts. DFDC is never downloaded automatically.

    Args:
        data_dir: Local data directory, created if missing. NOTE: it is not
            passed to kagglehub, so the download lands wherever kagglehub
            stores it; use the returned path.

    Returns:
        The path reported by ``kagglehub.dataset_download``, or None on failure.
    """
    os.makedirs(data_dir, exist_ok=True)

    # ASVspoof 2019
    asv_path = None
    try:
        asv_path = kagglehub.dataset_download("asvspoof/asv-spoof-2019-dataset")
    except Exception as err:  # network/auth/quota errors: fall back to manual download
        print(f"ASVspoof download failed ({type(err).__name__}: {err}) - please download manually")
    else:
        print(f"ASVspoof downloaded to: {asv_path}")

    # For DFDC, you may need to download manually due to size
    print("DFDC dataset is large - please download manually from https://www.kaggle.com/c/deepfake-detection-challenge")
    return asv_path

# Fetch what can be fetched automatically. No DataLoaders are built here yet:
# load_asvspoof2019 / load_dfdc still have to be implemented for `data_dir`.
data_dir = '/content/data'
download_datasets(data_dir)

print("Please ensure datasets are downloaded and implement data loading functions")

In [None]:
# Training functions
def train_audio_model(model, train_loader, val_loader, epochs=10, save_path='/content/drive/MyDrive/models/audio_model.pth', device=None):
    """Fine-tune ``model.audio_branch`` and checkpoint it on the best validation ROC-AUC.

    Args:
        model: Detector exposing an ``audio_branch`` module; only that branch is optimized.
        train_loader: Iterable of dicts with ``'input_values'`` and ``'labels'`` tensors.
        val_loader: Same format as ``train_loader``; used for AUC/accuracy each epoch.
        epochs: Number of passes over ``train_loader``.
        save_path: Where the best audio-branch ``state_dict`` is written; its
            directory is created if missing.
        device: Device to train on. Defaults to CUDA when available, else CPU.
            The model is moved there and batches are moved to match.

    Returns:
        The same ``model`` object, now on ``device``.

    Raises:
        ValueError: from ``roc_auc_score`` if a validation pass contains only one class.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)
    model.to(device)

    criterion = nn.BCEWithLogitsLoss()
    # Only the audio branch's trainable params (LoRA adapters + head) receive gradients here.
    trainable = [p for p in model.audio_branch.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=3e-5)

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    best_auc = 0.0

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0

        for batch in train_loader:
            inputs = batch['input_values'].to(device)
            labels = batch['labels'].to(device)
            optimizer.zero_grad()
            # squeeze(-1), not squeeze(): a batch of 1 must stay shape (1,) to match labels.
            logits = model.audio_branch(augment_audio(inputs)).squeeze(-1)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        # Validation
        model.eval()
        val_preds = []
        val_labels = []

        with torch.no_grad():
            for batch in val_loader:
                logits = model.audio_branch(batch['input_values'].to(device)).squeeze(-1)
                val_preds.extend(torch.sigmoid(logits).cpu().numpy().tolist())
                val_labels.extend(batch['labels'].cpu().numpy().tolist())

        val_auc = roc_auc_score(val_labels, val_preds)
        val_acc = accuracy_score(val_labels, np.round(val_preds))
        avg_train_loss = train_loss / max(len(train_loader), 1)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val AUC: {val_auc:.4f}, Val Acc: {val_acc:.4f}")

        if val_auc > best_auc:
            best_auc = val_auc
            torch.save(model.audio_branch.state_dict(), save_path)
            print(f"Saved best audio model with AUC: {best_auc:.4f}")

    return model

def augment_audio(mel_spec, p=0.5, noise_std=0.1):
    """Randomly add Gaussian noise to a batch of audio inputs.

    One coin flip is made per call, so either the whole batch is perturbed or
    none of it is.

    Args:
        mel_spec: Float tensor to augment. The name is kept for backward
            compatibility; any tensor works.
        p: Probability of adding noise (default 0.5, as before).
        noise_std: Standard deviation of the added noise (default 0.1, as before).

    Returns:
        ``mel_spec + noise`` with probability ``p``, otherwise ``mel_spec`` itself.
    """
    if torch.rand(1).item() < p:
        return mel_spec + torch.randn_like(mel_spec) * noise_std
    return mel_spec

# Similar functions for video and fusion training...

In [None]:
# Build the detector with LoRA adapters on both backbones.
model = DeepFakeDetector(use_lora=True)

# TODO: implement the loaders, then uncomment:
# audio_train_loader, audio_val_loader = get_audio_loaders(data_dir)
# video_train_loader, video_val_loader = get_video_loaders(data_dir)
# model = train_audio_model(model, audio_train_loader, audio_val_loader)

print("Model initialized. Please implement data loading and training calls.")

In [None]:
# Persist the whole detector's state_dict (both branches) to Google Drive.
model_save_path = '/content/drive/MyDrive/models/deepfake_detector.pth'
Path(model_save_path).parent.mkdir(parents=True, exist_ok=True)  # create the models/ folder if needed
torch.save(model.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")

In [None]:
# Inference example
def _load_audio_for_inference(audio_path, target_sample_rate=16000):
    """Load ``audio_path`` as a ``(1, samples)`` mono, resampled, zero-mean/unit-variance waveform."""
    waveform, sample_rate = torchaudio.load(audio_path)
    waveform = waveform.mean(dim=0)  # down-mix (channels, samples) -> (samples,)
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_sample_rate)
    waveform = (waveform - waveform.mean()) / (waveform.std(unbiased=False) + 1e-9)
    return waveform.unsqueeze(0)


def _load_video_for_inference(video_path, num_frames=16, size=224):
    """Sample ``num_frames`` evenly spaced RGB frames as a ``(1, T, C, H, W)`` float tensor in [0, 1]."""
    cap = cv2.VideoCapture(video_path)
    frames = []
    try:
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        if total_frames > 0:
            for i in np.linspace(0, total_frames - 1, num_frames, dtype=int):
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
                ok, frame = cap.read()
                if ok:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frames.append(cv2.resize(frame, (size, size)))
    finally:
        cap.release()
    if not frames:
        raise ValueError(f"Could not read any frames from {video_path!r}")
    while len(frames) < num_frames:
        frames.append(frames[-1])  # pad short reads by repeating the last frame
    # (T, H, W, C) uint8 -> (T, C, H, W) float in [0, 1]
    clip = torch.from_numpy(np.stack(frames)).permute(0, 3, 1, 2).float() / 255.0
    return clip.unsqueeze(0)


def predict_deepfake(audio_path=None, video_path=None, model_path='/content/drive/MyDrive/models/deepfake_detector.pth', device=None):
    """Score an audio file, a video file, or both with a saved ``DeepFakeDetector``.

    The preprocessing has to match training:
    - audio: mono, resampled to 16 kHz, zero-mean/unit-variance;
    - video: 16 frames resized to 224x224, scaled to [0, 1], time axis first.

    NOTE(review): confirm this matches the transforms used in training.

    Args:
        audio_path: Optional path to an audio file.
        video_path: Optional path to a video file.
        model_path: Checkpoint produced by ``torch.save(model.state_dict(), ...)``.
        device: Inference device; defaults to CUDA when available, else CPU.

    Returns:
        A dict with the following keys:
        - ``deepfake_score``: sigmoid probability of "synthetic";
        - ``prediction``: ``"synthetic"`` if the score is above 0.5, else ``"real"``;
        - ``explanation``: a list of strings naming the modalities the score was
          computed from (no attribution method is implemented).

    Raises:
        ValueError: if neither path is given, or no frames can be read from the video.
    """
    if not audio_path and not video_path:
        raise ValueError("Provide audio_path, video_path, or both")

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)

    model = DeepFakeDetector(use_lora=True)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    with torch.no_grad():
        audio_tensor = _load_audio_for_inference(audio_path).to(device) if audio_path else None
        video_tensor = _load_video_for_inference(video_path).to(device) if video_path else None
        # With one input None, the model falls back to the single available branch.
        outputs = model(audio_input=audio_tensor, video_input=video_tensor)
        score = torch.sigmoid(outputs.squeeze()).item()

    prediction = "synthetic" if score > 0.5 else "real"
    used = [name for name, path in (("audio", audio_path), ("video", video_path)) if path]

    return {
        "deepfake_score": score,
        "prediction": prediction,
        # TODO: replace with real attributions (e.g. saliency); previously hard-coded claims were returned.
        "explanation": [f"score computed from {' + '.join(used)} input; no attribution implemented yet"],
    }

# Example usage:
# result = predict_deepfake(audio_path="path/to/audio.wav", video_path="path/to/video.mp4")
# print(result)