# 🔍 AVFF Deepfake Detection Demo

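This notebook demonstrates the AVFF (Audio-Visual Feature Fusion) pipeline for video deepfake detection using pre-trained audio and visual encoders. The fusion, decoder, and classifier modules defined below are randomly initialized, so the scores produced here illustrate the data flow rather than the output of a trained detector.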

## Pipeline Overview
1. 🔧 Setup & Dependencies
2. 📥 Load & Preprocess Sample Video
3. 🎛️ Build Audio & Visual Encoders
4. 🔁 Cross-Modal Feature Fusion
5. 📦 Reconstruction via Decoder
6. 🧠 Deepfake Classification
7. 🧪 Run Inference on a Video
8. 📈 Visualize & Interpret Results

## 1. 🔧 Setup & Dependencies

In [None]:
!pip install torch torchvision torchaudio transformers timm einops librosa opencv-python matplotlib scikit-learn

In [None]:
import torch
import torch.nn as nn
import torchaudio
import torchvision.transforms as transforms
import cv2
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import librosa
from transformers import AutoModel, AutoFeatureExtractor
import timm
from einops import rearrange

# Pick the compute device: CUDA when a GPU is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 2. 📥 Load & Preprocess Sample Video

In [None]:
def extract_frames(video_path, frame_count=16):
    """Sample frames uniformly from a video and preprocess them.

    Args:
        video_path: Path to the video file (``str`` or ``pathlib.Path``).
        frame_count: Number of frame positions to sample, evenly spaced
            between the first and the last frame of the video.

    Returns:
        Float tensor of shape ``[n, 3, 224, 224]`` with RGB frames that were
        resized and normalized with the ImageNet mean/std. ``n`` equals
        ``frame_count`` unless some of the sampled frames could not be decoded.

    Raises:
        ValueError: If ``frame_count`` is smaller than 1.
        RuntimeError: If the video cannot be opened or no frame can be decoded.
    """
    if frame_count < 1:
        raise ValueError(f'frame_count must be >= 1, got {frame_count}')

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # str() lets callers pass pathlib.Path objects as well as plain strings
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            raise RuntimeError(f'Could not open video: {video_path}')

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Some containers report a frame count of 0; clamp so the sampled
        # indices never go negative
        last_index = max(total_frames - 1, 0)

        # Sample frames uniformly
        frame_indices = np.linspace(0, last_index, frame_count, dtype=int)
        frames = []
        for idx in frame_indices:
            cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
            ret, frame = cap.read()
            if not ret:
                # Skip positions the decoder cannot deliver
                continue
            # OpenCV decodes to BGR; the normalization constants assume RGB
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frames.append(transform(frame))
    finally:
        # Release the capture even when decoding or preprocessing fails
        cap.release()

    if not frames:
        raise RuntimeError(f'No frames could be decoded from video: {video_path}')
    return torch.stack(frames)

def extract_audio(video_path, sample_rate=16000, duration=1.0, return_waveform=False):
    """Extract a fixed-length audio segment from a video and compute its mel spectrogram.

    The audio track is decoded with the ``ffmpeg`` command-line tool into a
    temporary mono WAV file that is deleted again before the function returns.

    Args:
        video_path: Path to the video file (``str`` or ``pathlib.Path``).
        sample_rate: Sampling rate, in Hz, the audio is converted to.
        duration: Length of the returned segment in seconds. Longer audio is
            cropped to a randomly positioned window; shorter audio is
            zero-padded at the end.
        return_waveform: When True, also return the cropped/padded waveform,
            for models that consume raw audio instead of a spectrogram.

    Returns:
        The mel spectrogram tensor of shape ``[1, 64, time_frames]``, or the
        tuple ``(mel_spec, waveform)`` with ``waveform`` of shape
        ``[1, int(duration * sample_rate)]`` when ``return_waveform`` is True.

    Raises:
        FileNotFoundError: If the ``ffmpeg`` executable is not on PATH.
        RuntimeError: If ffmpeg fails, e.g. because the video has no audio track.
    """
    import subprocess
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = str(Path(tmp_dir) / 'audio.wav')
        # An argument list (no shell) keeps paths with spaces or shell
        # metacharacters intact
        command = [
            'ffmpeg', '-y', '-i', str(video_path),
            '-vn', '-ar', str(sample_rate), '-ac', '1', '-f', 'wav', wav_path,
        ]
        result = subprocess.run(command, capture_output=True, text=True, errors='replace')
        if result.returncode != 0:
            # Fail loudly instead of loading a missing or stale audio file
            raise RuntimeError(
                f'ffmpeg failed to extract audio from {video_path}:\n{result.stderr}'
            )
        waveform, sr = torchaudio.load(wav_path)

    if sr != sample_rate:
        resampler = torchaudio.transforms.Resample(sr, sample_rate)
        waveform = resampler(waveform)

    # Convert to mono if stereo
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Extract segment of specified length
    target_length = int(duration * sample_rate)
    if waveform.shape[1] > target_length:
        # randint's upper bound is exclusive; +1 keeps the last window reachable
        max_start = waveform.shape[1] - target_length
        start = int(torch.randint(0, max_start + 1, (1,)).item())
        waveform = waveform[:, start:start + target_length]
    else:
        pad_length = target_length - waveform.shape[1]
        waveform = torch.nn.functional.pad(waveform, (0, pad_length))

    # Compute mel spectrogram
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=1024,
        hop_length=512,
        n_mels=64
    )(waveform)

    if return_waveform:
        return mel_spec, waveform
    return mel_spec

# Test with a sample video
video_path = 'sample.mp4'  # Replace with your video path
frames = extract_frames(video_path)
mel_spec = extract_audio(video_path)

print(f'Frames shape: {frames.shape}')
print(f'Mel spectrogram shape: {mel_spec.shape}')

# Undo the ImageNet normalization applied in extract_frames so the frame is
# back in [0, 1]; imshow would otherwise clip it and distort the colors
imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
display_frame = (frames[0] * imagenet_std + imagenet_mean).clamp(0, 1)

# Visualize
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.imshow(display_frame.permute(1, 2, 0).numpy())
plt.title('Sample Frame')
plt.subplot(1, 2, 2)
plt.imshow(mel_spec.squeeze().numpy(), aspect='auto', origin='lower')
plt.title('Mel Spectrogram')
plt.show()

## 3. 🎛️ Build Audio & Visual Encoders

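We'll use pre-trained models for both modalities: Wav2Vec2 (`facebook/wav2vec2-base`) for audio and a Vision Transformer (`vit_base_patch16_224`) for video frames. Each encoder's output is projected to a shared 512-dimensional space: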

In [None]:
class AudioEncoder(nn.Module):
    """Encode raw audio into a 512-d embedding with a pre-trained Wav2Vec2.

    The Wav2Vec2 hidden states are mean-pooled over time and linearly
    projected to the dimension shared with the visual branch.
    """

    def __init__(self):
        super().__init__()
        # Use Wav2Vec2 for audio encoding
        self.encoder = AutoModel.from_pretrained('facebook/wav2vec2-base')
        self.projection = nn.Linear(768, 512)  # Project to common dimension

    def forward(self, x):
        """Embed a batch of waveforms.

        Args:
            x: Raw waveform of shape ``[batch_size, 1, time]`` or
                ``[batch_size, time]``.

        Returns:
            Tensor of shape ``[batch_size, 512]``.

        Raises:
            ValueError: If ``x`` is not a batch of mono waveforms, e.g. when a
                spectrogram is passed instead of raw audio.
        """
        # Wav2Vec2 expects [batch_size, time]; drop the singleton channel axis
        if x.dim() == 3 and x.shape[1] == 1:
            x = x.squeeze(1)
        if x.dim() != 2:
            raise ValueError(
                'AudioEncoder expects a raw mono waveform of shape '
                '[batch_size, 1, time] or [batch_size, time], '
                f'got {tuple(x.shape)}'
            )
        features = self.encoder(x).last_hidden_state
        # Pool over time dimension
        features = torch.mean(features, dim=1)
        return self.projection(features)

class VisualEncoder(nn.Module):
    """Encode a clip of frames into a 512-d embedding with a pre-trained ViT.

    Every frame is embedded independently, the per-frame features are
    averaged over the clip, and the result is projected to the dimension
    shared with the audio branch.
    """

    def __init__(self):
        super().__init__()
        # Use ViT for visual encoding. num_classes=0 removes the ImageNet
        # classification head so the model returns pooled 768-d features
        # instead of 1000 class logits, matching the projection layer below.
        self.encoder = timm.create_model(
            'vit_base_patch16_224', pretrained=True, num_classes=0
        )
        self.projection = nn.Linear(768, 512)  # Project to common dimension

    def forward(self, x):
        """Embed a batch of clips.

        Args:
            x: Tensor of shape ``[batch_size, frames, channels, height, width]``.

        Returns:
            Tensor of shape ``[batch_size, 512]``.
        """
        batch_size, frames = x.shape[:2]
        # Combine batch and frame dimensions; reshape (unlike view) also
        # accepts non-contiguous inputs
        x = x.reshape(batch_size * frames, *x.shape[2:])
        features = self.encoder(x)
        features = features.reshape(batch_size, frames, -1)
        # Pool over frames
        features = torch.mean(features, dim=1)
        return self.projection(features)

# Initialize encoders
# .eval() switches off dropout and other training-only behavior; torch.no_grad()
# alone only disables gradient tracking
audio_encoder = AudioEncoder().to(device).eval()
visual_encoder = VisualEncoder().to(device).eval()

# Test encoders
# NOTE(review): AudioEncoder wraps Wav2Vec2 and documents its input as a
# waveform of shape [batch_size, 1, time], but mel_spec is a spectrogram —
# confirm which audio representation is intended here.
with torch.no_grad():
    audio_features = audio_encoder(mel_spec.unsqueeze(0).to(device))
    visual_features = visual_encoder(frames.unsqueeze(0).to(device))
    
print(f'Audio features shape: {audio_features.shape}')
print(f'Visual features shape: {visual_features.shape}')

## 4. 🔁 Cross-Modal Feature Fusion

In [None]:
class CrossModalFusion(nn.Module):
    """Fuse visual and audio embeddings via cross-attention plus an MLP.

    The visual embedding acts as the attention query and the audio embedding
    as key and value. The attended result is concatenated with the visual
    embedding and mapped back to ``dim`` features.
    """

    def __init__(self, dim=512):
        super().__init__()
        self.attention = nn.MultiheadAttention(embed_dim=dim, num_heads=8)
        self.fusion_mlp = nn.Sequential(
            nn.Linear(2 * dim, dim),
            nn.ReLU(),
            nn.Linear(dim, dim),
        )

    def forward(self, visual_features, audio_features):
        """Return the fused ``[batch, dim]`` embedding of two ``[batch, dim]`` inputs."""
        # MultiheadAttention defaults to (seq, batch, embed) layout, so each
        # embedding becomes a length-1 sequence
        query = visual_features.unsqueeze(0)
        key = audio_features.unsqueeze(0)
        value = audio_features.unsqueeze(0)

        # Cross-attention: the visual query attends to the audio embedding
        attended, _ = self.attention(query=query, key=key, value=value)
        attended = attended.squeeze(0)

        # Fusion: concatenate along the feature axis, then project
        combined = torch.cat((visual_features, attended), dim=1)
        return self.fusion_mlp(combined)

# Instantiate the fusion module and place it on the active device
fusion = CrossModalFusion()
fusion = fusion.to(device)

# Smoke-test: fuse the encoder outputs computed earlier, without tracking gradients
with torch.no_grad():
    fused_features = fusion(visual_features, audio_features)
print(f'Fused features shape: {fused_features.shape}')

## 5. 📦 Reconstruction via Decoder

In [None]:
class Decoder(nn.Module):
    """MLP that maps a fused embedding to a flattened reconstruction target.

    Args:
        input_dim: Size of the fused feature vector.
        hidden_dim: Width of the first hidden layer; the second hidden layer
            is twice as wide.
        output_dim: Number of output values. Defaults to ``64 * 64``; set it
            to ``n_mels * time_frames`` of the spectrogram being reconstructed.
    """

    def __init__(self, input_dim=512, hidden_dim=256, output_dim=64 * 64):
        super().__init__()
        self.decoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, output_dim)  # Reconstruct mel spectrogram
        )

    def forward(self, x):
        """Map ``[batch, input_dim]`` features to ``[batch, output_dim]`` values."""
        return self.decoder(x)

# Build the decoder on the active device
decoder = Decoder()
decoder = decoder.to(device)

# Decode the fused embedding and fold the flat output into a 64x64 map
with torch.no_grad():
    reconstructed = decoder(fused_features).view(-1, 64, 64)

# Show the input spectrogram next to the decoder output
panels = [
    (mel_spec.squeeze().numpy(), 'Original Mel Spectrogram'),
    (reconstructed.squeeze().cpu().numpy(), 'Reconstructed Mel Spectrogram'),
]
plt.figure(figsize=(12, 4))
for slot, (image, title) in enumerate(panels, start=1):
    plt.subplot(1, 2, slot)
    plt.imshow(image, aspect='auto', origin='lower')
    plt.title(title)
plt.show()

## 6. 🧠 Deepfake Classification

In [None]:
class Classifier(nn.Module):
    """Binary classification head: fused features in, one sigmoid score out."""

    def __init__(self, input_dim=512):
        super().__init__()
        layers = [
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Return a ``[batch, 1]`` tensor of scores in ``[0, 1]``."""
        scores = self.classifier(x)
        return scores

# Initialize classifier
# .eval() disables the Dropout(0.5) layer; torch.no_grad() alone leaves it
# active, which would make the prediction change randomly between runs
classifier = Classifier().to(device).eval()

# Test classification
with torch.no_grad():
    prediction = classifier(fused_features)
print(f'Prediction: {prediction.item():.4f}')

## 7. 🧪 Run Inference on a Video

In [None]:
def process_video(video_path):
    """Run the full audio-visual pipeline on one video.

    Uses the module-level ``audio_encoder``, ``visual_encoder``, ``fusion``
    and ``classifier`` and switches them to evaluation mode as a side effect.

    Args:
        video_path: Path to the video file to analyze.

    Returns:
        Dict with the scalar ``'prediction'`` returned by ``classifier`` and
        the intermediate ``'audio_features'``, ``'visual_features'`` and
        ``'fused_features'`` as NumPy arrays.
    """
    # Extract frames and audio
    frames = extract_frames(video_path)
    mel_spec = extract_audio(video_path)

    # Add a batch dimension and move to device
    frames = frames.unsqueeze(0).to(device)
    mel_spec = mel_spec.unsqueeze(0).to(device)

    # torch.no_grad() does not disable dropout; without eval mode the
    # prediction would differ randomly from call to call
    for module in (audio_encoder, visual_encoder, fusion, classifier):
        module.eval()

    # Get features
    # NOTE(review): audio_encoder wraps Wav2Vec2, which is documented in this
    # notebook as taking a [batch_size, 1, time] waveform, while mel_spec is a
    # spectrogram — confirm the intended audio input.
    with torch.no_grad():
        audio_features = audio_encoder(mel_spec)
        visual_features = visual_encoder(frames)
        fused_features = fusion(visual_features, audio_features)
        prediction = classifier(fused_features)

    return {
        'prediction': prediction.item(),
        'audio_features': audio_features.cpu().numpy(),
        'visual_features': visual_features.cpu().numpy(),
        'fused_features': fused_features.cpu().numpy()
    }

# End-to-end check: score a single clip with the assembled pipeline
video_path = 'sample.mp4'  # Replace with your video path
results = process_video(video_path=video_path)
print(f'Deepfake probability: {results["prediction"]:.4f}')

## 8. 📈 Visualize & Interpret Results

In [None]:
def visualize_features(results):
    """Draw side-by-side histograms of the audio, visual and fused features.

    Args:
        results: Mapping with the array entries ``'audio_features'``,
            ``'visual_features'`` and ``'fused_features'``.
    """
    # (results key, panel title) for each histogram, in left-to-right order
    panels = (
        ('audio_features', 'Audio Features Distribution'),
        ('visual_features', 'Visual Features Distribution'),
        ('fused_features', 'Fused Features Distribution'),
    )

    plt.figure(figsize=(15, 5))
    for position, (key, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.hist(results[key].flatten(), bins=50)
        plt.title(title)

    plt.tight_layout()
    plt.show()

# Plot histograms of the feature arrays returned by process_video
visualize_features(results)