# Lightweight Deepfake Detection Training

Fast training setup for audio and video deepfake detection with lightweight models.

## Features:
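- Lightweight CNN for audio deepfake detection (frozen VGGish embeddings + logistic regression)
- Lightweight CNN-LSTM for video deepfake detection (planned; the video model in this notebook is currently a placeholder)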
- Fast training with reasonable accuracy
- Dataset downloading and preprocessing

In [None]:
# Install required packages
!pip install torch torchvision torchaudio librosa opencv-python tqdm requests

In [None]:
# Import required libraries
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
import numpy as np
from sklearn.metrics import roc_auc_score, accuracy_score
import librosa
import cv2
from PIL import Image
import torchaudio
import torchvision.transforms as transforms
from tqdm import tqdm
import requests
import zipfile
import tarfile

In [None]:
# VGGish + Logistic Regression Model for Fast Audio Deepfake Detection

# Install required packages
# !pip install torch torchaudio torchvggish scikit-learn soundfile

import os
import numpy as np
import soundfile as sf
import torch
from torchvggish import vggish, vggish_input
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
import joblib

# 1) Load frozen VGGish (embeddings are 128-D)
# The model is created once at module level and reused by every embed_wav() call.
vggish_model = vggish()
# Switch to evaluation mode (presumably an nn.Module — TODO confirm); the
# visible code only runs it under torch.no_grad() and never trains it.
vggish_model.eval()

def embed_wav(path):
    """Extract one VGGish embedding vector from an audio file.

    The file is read with ``soundfile``, down-mixed to mono, converted to
    VGGish input examples by ``vggish_input.waveform_to_examples`` and passed
    through the module-level ``vggish_model``. The per-segment embeddings are
    then averaged into a single vector.

    Args:
        path: Path to an audio file readable by ``soundfile``.

    Returns:
        1-D NumPy array holding the mean of the per-segment embeddings
        (shape ``(128,)`` for the VGGish model loaded above).

    Raises:
        ValueError: If the clip is too short to produce any VGGish segment.
            Averaging zero segments would otherwise return an all-NaN vector
            that only fails later inside the classifier.
    """
    wav, sr = sf.read(path)     # mono or stereo
    if wav.ndim > 1:
        wav = wav.mean(axis=1)  # Convert to mono
    # VGGish expects 16k; vggish_input handles resample+log-mel framing
    examples_batch = vggish_input.waveform_to_examples(wav, sample_rate=sr)
    if len(examples_batch) == 0:
        raise ValueError(
            f"Audio file {path!r} is too short to produce a VGGish segment"
        )
    with torch.no_grad():
        # as_tensor avoids the copy (and warning) that torch.tensor() makes
        # when the input is already a tensor; calling the module instead of
        # .forward() keeps any registered hooks working.
        emb = vggish_model(torch.as_tensor(examples_batch).float())
    # .cpu() keeps this working if the model/output lives on an accelerator.
    return emb.detach().cpu().numpy().mean(axis=0)  # average over segments -> (128,)

# 2) Load data from CSV: "path,label" where label in {0,1}
def load_csv(csv_path):
    """Build an embedding matrix and label vector from a "path,label" CSV.

    Each line is split on commas; the first field is an audio path and the
    second an integer label. Lines with fewer than two fields, and lines
    whose audio path does not exist on disk, are skipped. A missing CSV file
    yields empty results.

    Args:
        csv_path: Path to the CSV file.

    Returns:
        Tuple ``(X, y)``: ``X`` stacks one ``embed_wav`` vector per kept row
        (or is an empty array when no row was kept) and ``y`` holds the
        matching integer labels.
    """
    features = []
    targets = []

    # Guard clause: nothing to read when the CSV itself is absent.
    if not os.path.exists(csv_path):
        return np.array([]), np.array(targets)

    with open(csv_path) as handle:
        for row in handle:
            fields = row.strip().split(",")
            if len(fields) < 2:
                continue
            wav_path, raw_label = fields[0], fields[1]
            if not os.path.exists(wav_path):
                continue
            features.append(embed_wav(wav_path))
            targets.append(int(raw_label))

    if features:
        return np.stack(features), np.array(targets)
    return np.array([]), np.array(targets)

# 3) Train logistic regression on VGGish embeddings
def train_vggish_model(train_csv="audio_train.csv", val_csv="audio_val.csv"):
    """Train a logistic-regression classifier on VGGish embeddings.

    Embeddings are loaded through ``load_csv``, standardized with the
    training-set mean/std, and used to fit a class-balanced
    ``LogisticRegression``. The fitted classifier is written to
    ``vggish_linear.joblib`` and the normalization statistics (row 0 = mean,
    row 1 = std) to ``feature_mean_std.npy``.

    Args:
        train_csv: "path,label" CSV with the training files.
        val_csv: "path,label" CSV with the validation files.

    Returns:
        The fitted classifier, or ``None`` when either CSV yields no samples.
    """
    print("Loading training data...")
    Xtr, ytr = load_csv(train_csv)
    if len(Xtr) == 0:
        print(f"No training data found in {train_csv}")
        return None

    print("Loading validation data...")
    Xva, yva = load_csv(val_csv)
    if len(Xva) == 0:
        print(f"No validation data found in {val_csv}")
        return None

    print(f"Training on {len(Xtr)} samples, validating on {len(Xva)} samples")

    # Standardize with training statistics. The inference path
    # (FastDeepFakeDetector.predict_audio) applies exactly this transform
    # using the saved stats, so the classifier must be fitted on standardized
    # features as well; fitting on raw features would make training and
    # inference disagree.
    feat_mean = Xtr.mean(0)
    feat_std = Xtr.std(0)
    Xtr_norm = (Xtr - feat_mean) / (feat_std + 1e-9)
    Xva_norm = (Xva - feat_mean) / (feat_std + 1e-9)

    # Fit logistic regression
    clf = LogisticRegression(max_iter=1000, n_jobs=-1, class_weight="balanced")
    clf.fit(Xtr_norm, ytr)

    # Evaluate. ROC AUC is undefined when only one class is present, and
    # roc_auc_score raises in that case, so report it instead of crashing
    # after a successful fit.
    probs = clf.predict_proba(Xva_norm)[:, 1]
    if len(np.unique(yva)) < 2:
        print("Validation AUC: undefined (validation set contains a single class)")
    else:
        auc = roc_auc_score(yva, probs)
        print(f"Validation AUC: {auc:.4f}")

    # Save model and normalization stats
    joblib.dump(clf, "vggish_linear.joblib")
    np.save("feature_mean_std.npy", np.stack([feat_mean, feat_std]))
    print("Model saved as vggish_linear.joblib")

    return clf

# Simple Video Model (placeholder - implement based on your video dataset)
class SimpleVideoModel:
    """Placeholder video deepfake detector with a scikit-learn-like API.

    No real training or feature extraction is implemented yet; the class
    exists so that the surrounding pipeline has an object exposing ``fit`` and
    ``predict_proba``.
    """

    def __init__(self):
        # Flipped to True by fit(); nothing else is learned.
        self.trained = False

    def fit(self, X, y):
        """Pretend to train; only marks the model as trained.

        Args:
            X: Training samples (ignored).
            y: Training labels (ignored).

        Returns:
            ``self``, following the scikit-learn estimator convention.
        """
        # Placeholder - implement video feature extraction and training
        print("Video model training not implemented yet")
        self.trained = True
        return self

    def predict_proba(self, X):
        """Return random placeholder class probabilities.

        Args:
            X: Sequence of samples; only its length is used.

        Returns:
            Array of shape ``(len(X), 2)`` with columns ``[P(real), P(fake)]``.
            Each row sums to 1 so the output is a valid probability
            distribution, unlike two independent uniform draws.
        """
        # Placeholder: one random fake-probability per sample.
        p_fake = np.random.rand(len(X))
        return np.column_stack([1.0 - p_fake, p_fake])

class FastDeepFakeDetector:
    """Audio + video deepfake detector that averages per-modality scores."""

    def __init__(self):
        # Audio classifier and its normalization stats stay unset until
        # load_audio_model() finds the corresponding files.
        self.audio_model = None
        self.video_model = SimpleVideoModel()
        self.feature_stats = None

    def load_audio_model(self, model_path="vggish_linear.joblib", stats_path="feature_mean_std.npy"):
        """Load the audio classifier and feature statistics if they exist.

        Each file is loaded independently; a missing file leaves the
        corresponding attribute unchanged.
        """
        have_model = os.path.exists(model_path)
        if have_model:
            self.audio_model = joblib.load(model_path)

        have_stats = os.path.exists(stats_path)
        if have_stats:
            self.feature_stats = np.load(stats_path)

    def predict_audio(self, audio_path):
        """Return the fake-probability for one audio file.

        Falls back to the neutral score 0.5 when no audio model is loaded.
        """
        if self.audio_model is None:
            return 0.5

        features = embed_wav(audio_path)
        stats = self.feature_stats
        if stats is not None:
            # Row 0 is the mean, row 1 the std; the epsilon avoids dividing
            # by zero.
            centered = features - stats[0]
            features = centered / (stats[1] + 1e-9)

        return self.audio_model.predict_proba([features])[0, 1]

    def predict(self, audio_path=None, video_path=None):
        """Average the available modality scores into one verdict.

        Returns:
            Dict with ``deepfake_score`` (mean of the collected scores, or
            0.5 when no input is given), ``prediction`` ("fake" when the
            score exceeds 0.5, otherwise "real") and ``confidence`` (distance
            of the score from 0.5, scaled to [0, 1]).
        """
        collected = []

        if audio_path:
            collected.append(self.predict_audio(audio_path))

        if video_path:
            # Placeholder for video prediction
            collected.append(0.5)

        final_score = np.mean(collected) if collected else 0.5

        verdict = "real"
        if final_score > 0.5:
            verdict = "fake"

        return {
            "deepfake_score": final_score,
            "prediction": verdict,
            "confidence": abs(final_score - 0.5) * 2,
        }

In [None]:
# Dataset Download Functions

def download_file(url, dest_path, timeout=60):
    """Download a file with a progress bar.

    The payload is streamed to ``dest_path + ".part"`` and moved into place
    only after the transfer finished, so an interrupted download never leaves
    a truncated file that a later "already downloaded" check would accept.

    Args:
        url: Address of the file to fetch.
        dest_path: Final location of the downloaded file.
        timeout: Seconds to wait for the connection and for each read before
            giving up (passed to ``requests.get``).

    Raises:
        requests.HTTPError: If the server answers with an error status
            (instead of silently saving the error page).
        requests.RequestException: On connection problems or timeouts.
    """
    tmp_path = dest_path + ".part"
    try:
        # The context manager releases the connection even on failure.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            # Missing/zero content-length -> None so tqdm shows an open-ended bar.
            total_size = int(response.headers.get('content-length', 0)) or None

            with open(tmp_path, 'wb') as file, tqdm(
                desc=os.path.basename(dest_path),
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                # 64 KiB chunks: far fewer Python-level iterations than 1 KiB.
                for data in response.iter_content(chunk_size=64 * 1024):
                    size = file.write(data)
                    bar.update(size)

        os.replace(tmp_path, dest_path)
    finally:
        # Only present if the download failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

def extract_archive(archive_path, extract_to):
    """Extract a ``.zip``, ``.tar.gz`` or ``.tgz`` archive.

    Args:
        archive_path: Path to the archive; the format is chosen from the
            file extension (case-insensitive).
        extract_to: Directory that receives the extracted files.

    Raises:
        ValueError: If the extension is not a supported archive format.
            Previously such files were silently ignored and callers went on
            to report the dataset as ready.
    """
    lowered = str(archive_path).lower()
    if lowered.endswith('.zip'):
        with zipfile.ZipFile(archive_path, 'r') as zip_ref:
            zip_ref.extractall(extract_to)
    elif lowered.endswith(('.tar.gz', '.tgz')):
        with tarfile.open(archive_path, 'r:gz') as tar_ref:
            # Downloaded archives are untrusted input: use the "data"
            # extraction filter where the running Python provides it, which
            # rejects members that would escape the target directory.
            if hasattr(tarfile, 'data_filter'):
                tar_ref.extractall(extract_to, filter='data')
            else:
                tar_ref.extractall(extract_to)
    else:
        raise ValueError(f"Unsupported archive format: {archive_path}")

def download_audio_dataset():
    """Fetch and unpack the LA audio dataset into ``data/audio``.

    The archive is downloaded to ``LA.zip`` only if that file is absent, and
    extracted only if ``data/audio/LA`` does not exist yet.
    """
    archive_name = "LA.zip"
    target_dir = "data/audio"
    source_url = "https://datashare.ed.ac.uk/bitstream/handle/10283/3336/LA.zip?sequence=3&isAllowed=y"

    print("Downloading LA audio dataset...")
    os.makedirs(target_dir, exist_ok=True)

    # Step 1: download, unless the archive is already on disk.
    archive_present = os.path.exists(archive_name)
    if not archive_present:
        download_file(source_url, archive_name)

    # Step 2: extract, unless the LA folder is already in place.
    extracted_marker = os.path.join(target_dir, "LA")
    if not os.path.exists(extracted_marker):
        print("Extracting LA dataset...")
        extract_archive(archive_name, target_dir)

    print("LA audio dataset ready.")

def download_video_dataset():
    """Fetch and unpack the FF++ video subset into ``data/video``.

    The archive is downloaded to ``faceforensics_data.zip`` only if that file
    is absent, and extracted only if ``data/video/faceforensics_data`` does
    not exist yet.
    """
    archive_name = "faceforensics_data.zip"
    target_dir = "data/video"
    # Using a smaller subset for faster training
    # NOTE(review): confirm this release URL actually resolves before relying on it.
    source_url = "https://github.com/ondyari/FaceForensics/releases/download/v1.0/faceforensics_data.zip"

    print("Downloading FF++ video dataset subset...")
    os.makedirs(target_dir, exist_ok=True)

    # Step 1: download, unless the archive is already on disk.
    archive_present = os.path.exists(archive_name)
    if not archive_present:
        download_file(source_url, archive_name)

    # Step 2: extract, unless the dataset folder is already in place.
    extracted_marker = os.path.join(target_dir, "faceforensics_data")
    if not os.path.exists(extracted_marker):
        print("Extracting FF++ dataset...")
        extract_archive(archive_name, target_dir)

    print("FF++ video dataset ready.")

# Download datasets
# These calls run when the cell executes: they create the data folder and then
# download/extract both archives (each step is skipped if its output exists).
os.makedirs("data", exist_ok=True)  # parent of data/audio and data/video
download_audio_dataset()
download_video_dataset()

In [None]:
# Data Preprocessing and Dataset Classes

class AudioDataset(Dataset):
    """Dataset of fixed-length mono waveforms with binary labels.

    Args:
        audio_paths: Sequence of audio file paths.
        labels: Sequence of labels, one per path.
        clip_seconds: Length every clip is trimmed/zero-padded to, in seconds
            (may be fractional).
        sample_rate: Rate, in Hz, that audio is resampled to on load.

    Raises:
        ValueError: If ``audio_paths`` and ``labels`` differ in length, which
            would otherwise surface as an IndexError at some arbitrary index
            in the middle of an epoch.
    """

    def __init__(self, audio_paths, labels, clip_seconds=3, sample_rate=16000):
        if len(audio_paths) != len(labels):
            raise ValueError(
                f"audio_paths ({len(audio_paths)}) and labels ({len(labels)}) "
                "must have the same length"
            )
        self.audio_paths = audio_paths
        self.labels = labels
        self.clip_seconds = clip_seconds
        self.sample_rate = sample_rate

    def __len__(self):
        return len(self.audio_paths)

    def __getitem__(self, idx):
        """Return ``(waveform, label)`` as float tensors for item ``idx``."""
        audio_path = self.audio_paths[idx]
        label = self.labels[idx]

        # Load audio as mono at the configured sample rate
        y, _ = librosa.load(audio_path, sr=self.sample_rate, mono=True)

        # Trim/pad to fixed length; int() keeps the slice index valid when
        # clip_seconds is fractional.
        target_length = int(self.sample_rate * self.clip_seconds)
        if len(y) > target_length:
            y = y[:target_length]
        else:
            y = np.pad(y, (0, target_length - len(y)))

        return torch.from_numpy(y).float(), torch.tensor(label, dtype=torch.float)

class VideoDataset(Dataset):
    """Dataset yielding a fixed number of evenly spaced frames per video.

    Args:
        video_paths: Sequence of video file paths.
        labels: Sequence of labels, one per path.
        num_frames: Number of frames sampled from each video.
        frame_size: Size passed to ``transforms.Resize``.
    """

    def __init__(self, video_paths, labels, num_frames=8, frame_size=(224, 224)):
        self.video_paths = video_paths
        self.labels = labels
        self.num_frames = num_frames
        self.frame_size = frame_size
        self.transform = transforms.Compose([
            transforms.Resize(frame_size),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        return len(self.video_paths)

    def __getitem__(self, idx):
        """Return ``(video_tensor, label)`` for item ``idx``.

        ``video_tensor`` stacks ``num_frames`` transformed frames as
        ``[frames, C, H, W]``. Frames that could not be decoded are replaced
        by zero tensors; if nothing could be decoded the whole clip is zeros
        and a warning is emitted.
        """
        video_path = self.video_paths[idx]
        label = self.labels[idx]

        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Sample frames evenly. Clamp the last index at 0 so a reported
            # frame count of 0 cannot produce a negative seek position.
            last_index = max(total_frames - 1, 0)
            frame_indices = np.linspace(0, last_index, self.num_frames, dtype=int)

            for i in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
                ret, frame = cap.read()
                if ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = Image.fromarray(frame)
                    frame = self.transform(frame)
                    frames.append(frame)
        finally:
            # Release the capture even if decoding or a transform raised.
            cap.release()

        if frames:
            blank = torch.zeros_like(frames[0])
        else:
            # Nothing decoded: the original code crashed on frames[0] here.
            import warnings

            warnings.warn(f"No frames could be read from {video_path}; returning zeros")
            if isinstance(self.frame_size, int):
                height = width = self.frame_size
            else:
                height, width = self.frame_size
            blank = torch.zeros(3, height, width)

        # Pad if not enough frames
        while len(frames) < self.num_frames:
            frames.append(blank)

        # Stack: [frames, C, H, W]
        video_tensor = torch.stack(frames)

        return video_tensor, torch.tensor(label, dtype=torch.float)

# Generate sample data paths (you'll need to implement proper data loading)
def get_sample_data():
    """Return the audio/video file lists and labels (currently all empty).

    Placeholder: fill this in according to your dataset structure.

    Returns:
        Tuple ``(audio_paths, audio_labels, video_paths, video_labels)``:
        audio file paths, their labels (0=real, 1=fake), video file paths and
        their labels. Each element is a new, empty list.
    """
    return [], [], [], []

# Create data loaders
audio_paths, audio_labels, video_paths, video_labels = get_sample_data()

# Default to None so that later cells can check `... is not None` instead of
# hitting a NameError when a modality has no data.
audio_dataset = audio_loader = None
video_dataset = video_loader = None

if audio_paths:
    audio_dataset = AudioDataset(audio_paths, audio_labels)
    audio_loader = DataLoader(audio_dataset, batch_size=8, shuffle=True)

if video_paths:
    video_dataset = VideoDataset(video_paths, video_labels)
    video_loader = DataLoader(video_dataset, batch_size=4, shuffle=True)

In [None]:
# Generate Dataset CSVs

# Run the CSV generation script
!python generate_dataset_csvs.py

# Train VGGish + Logistic Regression Model

# Install torchvggish and soundfile if not already installed
# !pip install torchvggish soundfile

audio_model = train_vggish_model("audio_train.csv", "audio_val.csv")

# Initialize multimodal detector
detector = FastDeepFakeDetector()
detector.load_audio_model()

print("Training complete! VGGish + Logistic Regression model ready.")

In [None]:
# Inference Example

def predict_deepfake(audio_path=None, video_path=None, model_path='light_deepfake_detector.pth'):
    """Score one audio and/or video file with a saved LightDeepFakeDetector.

    Audio is loaded as 16 kHz mono and trimmed/zero-padded to 3 seconds.
    Video is reduced to 8 frames sampled evenly across the file (the same
    strategy VideoDataset uses), resized to 224x224 and normalized.

    Args:
        audio_path: Optional path to an audio file.
        video_path: Optional path to a video file.
        model_path: Path to the saved model ``state_dict``.

    Returns:
        Dict with ``deepfake_score`` (sigmoid of the model output),
        ``prediction`` ("fake" when the score exceeds 0.5, otherwise "real")
        and ``confidence`` (distance of the score from 0.5, scaled to [0, 1]).

    Raises:
        ValueError: If ``video_path`` is given but no frame can be decoded.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    num_frames = 8
    sample_rate = 16000
    clip_seconds = 3

    # NOTE(review): LightDeepFakeDetector is not defined in the visible part
    # of this notebook — confirm it is defined before this cell runs.
    model = LightDeepFakeDetector()
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    audio_tensor = None
    video_tensor = None

    if audio_path:
        # Load and preprocess audio
        y, _ = librosa.load(audio_path, sr=sample_rate, mono=True)
        target_length = sample_rate * clip_seconds
        if len(y) > target_length:
            y = y[:target_length]
        else:
            y = np.pad(y, (0, target_length - len(y)))
        audio_tensor = torch.from_numpy(y).float().unsqueeze(0).to(device)

    if video_path:
        # Load and preprocess video
        transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

        frames = []
        cap = cv2.VideoCapture(video_path)
        try:
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            # Evenly spaced indices, matching VideoDataset; reading only the
            # first consecutive frames would feed the model a different
            # temporal sampling than it was trained on.
            last_index = max(total_frames - 1, 0)
            frame_indices = np.linspace(0, last_index, num_frames, dtype=int)

            for i in frame_indices:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(i))
                ret, frame = cap.read()
                if ret:
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = Image.fromarray(frame)
                    frame = transform(frame)
                    frames.append(frame)
        finally:
            # Release the capture even if decoding or a transform raised.
            cap.release()

        if not frames:
            # Previously this surfaced as an opaque IndexError on frames[0].
            raise ValueError(f"Could not decode any frames from {video_path}")

        # Pad with blank frames if some reads failed.
        while len(frames) < num_frames:
            frames.append(torch.zeros_like(frames[0]))

        video_tensor = torch.stack(frames).unsqueeze(0).to(device)  # [1, 8, C, H, W]

    with torch.no_grad():
        outputs = model(audio_input=audio_tensor, video_input=video_tensor)
        score = torch.sigmoid(outputs.squeeze()).item()
        prediction = "fake" if score > 0.5 else "real"

    return {
        "deepfake_score": score,
        "prediction": prediction,
        "confidence": abs(score - 0.5) * 2
    }

# Example usage:
# result = predict_deepfake(audio_path="path/to/audio.wav", video_path="path/to/video.mp4")
# print(result)