# AVFF (Audio-Visual Feature Fusion) Documentation

This documentation provides a comprehensive guide to the AVFF deepfake detection system, including explanations, code examples, and visualizations.

## Table of Contents
1. [Preprocessing](#preprocessing)
   - [Image Preprocessing](#image-preprocessing)
   - [Audio Preprocessing](#audio-preprocessing)
     - [Mel Spectrograms](#mel-spectrograms-explained)
2. [Model Architecture](#model-architecture)
   - [Audio Encoder](#audio-encoder)
   - [Visual Encoder](#visual-encoder)
   - [Cross-Modal Fusion](#cross-modal-fusion)
3. [Training Pipeline](#training-pipeline)
4. [Inference and Evaluation](#inference-and-evaluation)
5. [References](#references)

## Preprocessing

### Image Preprocessing

The image preprocessing pipeline includes several key steps:

1. **Frame Extraction**: Extract frames from video at uniform intervals
2. **Resizing**: Resize frames to a standard size (224x224)
3. **Normalization**: Normalize pixel values using ImageNet statistics

#### Normalization Explained

The normalization step uses ImageNet statistics:
```python
transforms.Normalize(
    mean=[0.485, 0.456, 0.406],  # RGB means
    std=[0.229, 0.224, 0.225]    # RGB standard deviations
)
```

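Why these specific values?
- They are the per-channel mean and standard deviation of ImageNet training images, with pixel values scaled to [0, 1]
- Many ImageNet-pretrained models (ResNet and VGG among them) were trained on inputs normalized this way, so reusing the same statistics keeps inputs in the range the pretrained weights expect
- Centring each channel near zero with roughly unit variance also tends to make fine-tuning better conditioned and faster to converge

Not every checkpoint uses these statistics. timm's default `vit_base_patch16_224` weights, for example, are configured with mean = std = 0.5 per channel, so check the model's pretrained data config (timm exposes it through `timm.data.resolve_data_config`) before hard-coding values.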
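For a concrete feel of the arithmetic, this NumPy sketch applies the same per-channel transform that `ToTensor` followed by `Normalize` performs: divide by 255, subtract the mean, divide by the standard deviation. The helper `normalize_pixel` is illustrative only and not part of the pipeline.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])


def normalize_pixel(rgb_uint8):
    """Map an 8-bit RGB pixel to the normalized values a pretrained model sees."""
    x = np.asarray(rgb_uint8, dtype=np.float64) / 255.0  # ToTensor: [0, 255] -> [0, 1]
    return (x - IMAGENET_MEAN) / IMAGENET_STD            # Normalize: per-channel z-score


white = normalize_pixel([255, 255, 255])
black = normalize_pixel([0, 0, 0])
print(white, black)
```

Pure white maps to roughly (2.25, 2.43, 2.64) and pure black to roughly (-2.12, -2.04, -1.80), so every normalized channel stays between about -2.12 and 2.64.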

#### Code Example
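```python
import cv2
import numpy as np
import torch
from torchvision import transforms


def extract_frames(video_path, frame_count=16):
    """Return `frame_count` uniformly spaced RGB frames as a (T, 3, 224, 224) tensor."""
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        raise IOError(f"Could not open video: {video_path}")
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    if total_frames <= 0:
        cap.release()
        raise ValueError(f"No frames reported for: {video_path}")

    # Sample frame indices uniformly from the first to the last frame
    frame_indices = np.linspace(0, total_frames - 1, frame_count, dtype=int)

    transform = transforms.Compose([
        transforms.ToTensor(),          # HWC uint8 [0, 255] -> CHW float [0, 1]
        transforms.Resize((224, 224)),  # resize the tensor to the ViT input size
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        )
    ])

    frames = []
    for idx in frame_indices:
        cap.set(cv2.CAP_PROP_POS_FRAMES, int(idx))
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
            frames.append(transform(frame))
    cap.release()

    if not frames:
        raise ValueError(f"Could not decode any frames from: {video_path}")
    # Seeking can fail near the end of some files; pad with the last decoded
    # frame so every clip has exactly `frame_count` frames and can be batched
    while len(frames) < frame_count:
        frames.append(frames[-1])
    return torch.stack(frames)
```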
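The index arithmetic in `extract_frames` can be checked without a video file. This NumPy sketch reproduces the `np.linspace(..., dtype=int)` call (the name `uniform_frame_indices` is ours) and shows a side effect worth knowing: a clip with fewer frames than `frame_count` yields repeated indices, so the same frame is decoded more than once.

```python
import numpy as np


def uniform_frame_indices(total_frames, frame_count=16):
    """Indices of `frame_count` frames spread evenly from the first to the last frame."""
    return np.linspace(0, total_frames - 1, frame_count, dtype=int)


long_clip = uniform_frame_indices(300)  # 300-frame video
short_clip = uniform_frame_indices(5)   # only 5 frames available

print(long_clip[0], long_clip[-1])      # first and last sampled indices
print(np.unique(short_clip))            # every available frame, several repeated
```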

### Audio Preprocessing

The audio preprocessing pipeline includes:

1. **Audio Extraction**: Extract audio from video
2. **Resampling**: Convert to standard sampling rate (16kHz)
3. **Mel Spectrogram**: Convert to mel spectrogram representation

#### Mel Spectrograms Explained

A mel spectrogram is a time-frequency representation of a sound: its short-time power spectrum, with the frequency axis mapped onto the mel scale. Here's why and how we use them:

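1. **What is a Mel Scale?**
   - A perceptual pitch scale, roughly linear below about 1 kHz and logarithmic above it
   - Based on how listeners judge equal steps in pitch
   - Fine resolution at low frequencies, where listeners distinguish small pitch changes
   - Coarse resolution at high frequencies, where they do not

2. **Why Use Mel Spectrograms?**
   - Allocates frequency resolution roughly the way human hearing does
   - Captures important speech characteristics, such as formant structure
   - Reduces dimensionality (from 513 STFT bins to 64 mel bands with the settings below) while preserving the spectral envelope
   - Standard input format for many audio models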

3. **How They're Created**
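   ```python
   import torchaudio

   # waveform: mono audio tensor of shape (1, num_samples) sampled at 16 kHz
   mel_spec = torchaudio.transforms.MelSpectrogram(
       sample_rate=16000,  # audio sampling rate (Hz)
       n_fft=1024,         # FFT window size in samples (64 ms at 16 kHz)
       hop_length=512,     # samples between successive frames (32 ms)
       n_mels=64           # number of mel bands
   )(waveform)             # power mel spectrogram, shape (1, 64, num_frames)

   # MelSpectrogram does not take the log; convert power to decibels explicitly
   mel_db = torchaudio.transforms.AmplitudeToDB(stype="power")(mel_spec)
   ```

   The process involves:
   1. A short-time Fourier transform (STFT), giving the power in each frequency bin for each frame
   2. Projecting those bins onto the mel scale with overlapping triangular filters
   3. Converting to a log (decibel) scale, which `MelSpectrogram` leaves to a separate step such as `AmplitudeToDB`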

4. **Visualization Example**
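   ```python
   import matplotlib.pyplot as plt

   def plot_mel_spectrogram(mel_spec_db):
       """Plot a log-scaled (dB) mel spectrogram of shape (1, n_mels, frames) or (n_mels, frames)."""
       plt.figure(figsize=(10, 4))
       plt.imshow(mel_spec_db.squeeze().detach().cpu().numpy(),
                  aspect='auto',
                  origin='lower')         # low frequencies at the bottom
       plt.colorbar(format='%+2.0f dB')  # the dB label assumes log-scaled input
       plt.title('Mel Spectrogram')
       plt.xlabel('Time (frames)')
       plt.ylabel('Mel band')
       plt.show()
   ```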
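The mel scale itself is a simple formula. `torchaudio`'s `MelSpectrogram` uses the HTK variant by default (`mel_scale="htk"`): m = 2595 · log10(1 + f / 700). The NumPy sketch below spaces 66 points evenly on that scale between 0 Hz and the 8 kHz Nyquist frequency (64 triangular filters, each spanning three consecutive points) and converts them back to hertz. It approximates how the filterbank is laid out; it is not torchaudio's own implementation.

```python
import numpy as np


def hz_to_mel(f_hz):
    """HTK mel scale: m = 2595 * log10(1 + f / 700)."""
    return 2595.0 * np.log10(1.0 + np.asarray(f_hz, dtype=np.float64) / 700.0)


def mel_to_hz(m):
    """Inverse of hz_to_mel."""
    return 700.0 * (10.0 ** (np.asarray(m, dtype=np.float64) / 2595.0) - 1.0)


# 64 bands need 66 edge points (each triangular filter spans three consecutive points)
mel_points = np.linspace(hz_to_mel(0.0), hz_to_mel(8000.0), 64 + 2)
hz_points = mel_to_hz(mel_points)
spacing = np.diff(hz_points)  # width between neighbouring points, in Hz

print(f"lowest spacing:  {spacing[0]:.1f} Hz")
print(f"highest spacing: {spacing[-1]:.1f} Hz")
```

Equal steps in mel become wider steps in hertz as frequency rises: roughly 28 Hz between the lowest points and roughly 330 Hz between the highest, which is the "fine at low, coarse at high" behaviour described above.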

#### Resampling Explained

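Why resample to 16 kHz?
- Source videos use different audio sampling rates (commonly 44.1 kHz or 48 kHz), so a fixed rate gives every clip the same time resolution
- The pretrained Wav2Vec2 audio encoder expects 16 kHz input
- 16 kHz is a standard rate for speech processing: its Nyquist frequency of 8 kHz covers the telephone band (roughly 300 Hz to 3.4 kHz) and most of the remaining speech energy
- Fewer samples per second than 44.1 or 48 kHz means less memory and compute, with little loss for speech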
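The 8 kHz Nyquist limit is also why downsampling must low-pass filter first (torchaudio's `Resample` applies a windowed-sinc low-pass internally). This NumPy sketch is an illustration, not part of the pipeline: it samples a 9 kHz tone at 16 kHz with no filtering, and because 9 kHz is above the Nyquist frequency, its energy appears folded back to 16 - 9 = 7 kHz.

```python
import numpy as np

fs = 16_000        # sampling rate (Hz)
tone_hz = 9_000    # above the Nyquist frequency fs / 2 = 8 kHz
n = np.arange(fs)  # one second of sample indices
samples = np.sin(2 * np.pi * tone_hz * n / fs)

spectrum = np.abs(np.fft.rfft(samples))
freqs = np.fft.rfftfreq(len(samples), d=1 / fs)  # 1 Hz resolution over 0 to 8 kHz
apparent_hz = freqs[np.argmax(spectrum)]

print(f"{tone_hz} Hz tone appears at {apparent_hz:.0f} Hz")
```

Once aliased, the false 7 kHz component cannot be separated from genuine 7 kHz content, so the filtering has to happen before or during resampling.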

#### Code Example
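```python
import os
import subprocess
import tempfile

import torch
import torchaudio


def extract_audio(video_path, sample_rate=16000, duration=1.0):
    """Return a (1, 64, frames) power mel spectrogram of a `duration`-second clip."""
    with tempfile.TemporaryDirectory() as tmp_dir:
        wav_path = os.path.join(tmp_dir, "audio.wav")
        # Extract the audio track with ffmpeg as mono WAV at `sample_rate`
        subprocess.run(
            ["ffmpeg", "-y", "-i", video_path, "-vn", "-ar", str(sample_rate),
             "-ac", "1", "-f", "wav", wav_path],
            check=True, capture_output=True,
        )
        waveform, sr = torchaudio.load(wav_path)

    # ffmpeg already resampled and downmixed; these guards cover other sources
    if sr != sample_rate:
        waveform = torchaudio.transforms.Resample(sr, sample_rate)(waveform)
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0, keepdim=True)

    # Randomly crop (or zero-pad) to exactly `duration` seconds.
    # The crop is random, so repeated calls differ, and it is not aligned with
    # the frames that extract_frames samples across the whole video.
    target_length = int(duration * sample_rate)
    if waveform.shape[1] > target_length:
        start = torch.randint(0, waveform.shape[1] - target_length + 1, (1,)).item()
        waveform = waveform[:, start:start + target_length]
    else:
        pad_length = target_length - waveform.shape[1]
        waveform = torch.nn.functional.pad(waveform, (0, pad_length))

    # Power mel spectrogram; apply AmplitudeToDB afterwards if a log scale is needed
    mel_spec = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=1024,
        hop_length=512,
        n_mels=64
    )(waveform)

    return mel_spec
```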
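The output shape follows from the STFT parameters. Assuming torchaudio's default `center=True` (the signal is padded so frames are centred on hop positions), a clip of `num_samples` samples yields `1 + num_samples // hop_length` frames. The helper below is our own shape calculator, not a torchaudio function.

```python
def mel_spec_shape(duration=1.0, sample_rate=16000, hop_length=512, n_fft=1024, n_mels=64):
    """Expected (channels, n_mels, frames) of a centred MelSpectrogram, plus the linear bin count."""
    num_samples = int(duration * sample_rate)
    num_frames = 1 + num_samples // hop_length  # center=True framing
    linear_bins = n_fft // 2 + 1                # STFT bins before the mel projection
    return (1, n_mels, num_frames), linear_bins


shape, bins = mel_spec_shape()
print(shape, bins)
```

With the defaults, each 1-second clip becomes a (1, 64, 32) tensor with a 32 ms hop, projected down from 513 linear-frequency bins.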

## Model Architecture

The AVFF model consists of several key components:

1. **Audio Encoder**: Processes audio features
2. **Visual Encoder**: Processes visual features
3. **Cross-Modal Fusion**: Combines audio and visual features
4. **Decoder**: Reconstructs input features
5. **Classifier**: Makes final deepfake prediction

### Audio Encoder

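Uses a pretrained Wav2Vec2 model (`facebook/wav2vec2-base`) for audio feature extraction. Wav2Vec2 consumes the raw 16 kHz waveform, not a mel spectrogram; its 768-dimensional hidden states are averaged over time and projected to 512 dimensions:
```python
import torch
import torch.nn as nn
from transformers import AutoModel


class AudioEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = AutoModel.from_pretrained('facebook/wav2vec2-base')
        self.projection = nn.Linear(768, 512)  # wav2vec2-base hidden size is 768

    def forward(self, x):
        # x: raw waveform, shape (batch, num_samples) or (batch, 1, num_samples)
        if x.dim() == 3:
            x = x.squeeze(1)
        features = self.encoder(x).last_hidden_state  # (batch, time_steps, 768)
        features = torch.mean(features, dim=1)        # average over time
        return self.projection(features)              # (batch, 512)
```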
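The `torch.mean(..., dim=1)` call averages over Wav2Vec2's output time steps, and their number is set by the model's convolutional feature encoder. This sketch assumes the default `Wav2Vec2Config` values used by `wav2vec2-base` (kernel sizes 10, 3, 3, 3, 3, 2, 2; strides 5, 2, 2, 2, 2, 2, 2; no padding) and applies the standard unpadded convolution length formula layer by layer.

```python
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)  # assumed wav2vec2-base conv_kernel
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)   # assumed wav2vec2-base conv_stride


def wav2vec2_num_frames(num_samples):
    """Output time steps of the conv feature encoder for an unpadded input."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length


print(wav2vec2_num_frames(16000))  # one second at 16 kHz
```

For one second of 16 kHz audio this gives 49 time steps, one every 320 samples (20 ms, the product of the strides), which the encoder then collapses into a single vector.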

### Visual Encoder

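Uses a ViT (Vision Transformer) to embed each frame; the per-frame embeddings are averaged over time and projected to 512 dimensions:
```python
import timm
import torch
import torch.nn as nn


class VisualEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        # num_classes=0 removes the ImageNet classification head, so the model
        # returns pooled 768-d features instead of 1000 class logits
        self.encoder = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.projection = nn.Linear(768, 512)

    def forward(self, x):
        # x: (batch, frames, 3, 224, 224)
        batch_size, frames = x.shape[:2]
        x = x.reshape(-1, *x.shape[2:])          # fold frames into the batch
        features = self.encoder(x)               # (batch * frames, 768)
        features = features.reshape(batch_size, frames, -1)
        features = torch.mean(features, dim=1)   # average over frames
        return self.projection(features)         # (batch, 512)
```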
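Folding frames into the batch dimension is only correct if the reshape back restores each video's frames in order. This NumPy check substitutes a hypothetical stand-in "encoder" (per-image channel means) for the ViT and confirms that fold, encode, unfold and average gives the same result as encoding each video separately.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, frames = 2, 4
clips = rng.normal(size=(batch, frames, 3, 8, 8))  # small stand-in for (B, T, 3, 224, 224)


def fake_encoder(images):
    """Hypothetical stand-in for the ViT: one 3-d feature (channel means) per image."""
    return images.mean(axis=(2, 3))                 # (N, 3, H, W) -> (N, 3)


# Vectorized path, mirroring VisualEncoder.forward
folded = clips.reshape(-1, *clips.shape[2:])        # (B * T, 3, 8, 8)
features = fake_encoder(folded).reshape(batch, frames, -1)
pooled = features.mean(axis=1)                      # (B, 3)

# Reference path: encode each video's frames on their own
reference = np.stack([fake_encoder(clip).mean(axis=0) for clip in clips])

print(np.allclose(pooled, reference))
```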

### Cross-Modal Fusion

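Combines audio and visual features with cross-attention (visual features as queries, audio features as keys and values), then merges the attended audio with the visual features through an MLP:
```python
import torch
import torch.nn as nn


class CrossModalFusion(nn.Module):
    def __init__(self, dim=512):
        super().__init__()
        # batch_first defaults to False, so inputs are (seq_len, batch, dim)
        self.attention = nn.MultiheadAttention(dim, num_heads=8)
        self.fusion_mlp = nn.Sequential(
            nn.Linear(dim * 2, dim),
            nn.ReLU(),
            nn.Linear(dim, dim)
        )

    def forward(self, visual_features, audio_features):
        # Both inputs: (batch, dim); unsqueeze(0) makes sequences of length 1.
        # With a single key the attention weight is always 1, so attn_out is a
        # learned linear projection of the audio features.
        attn_out, _ = self.attention(
            visual_features.unsqueeze(0),  # query
            audio_features.unsqueeze(0),   # key
            audio_features.unsqueeze(0)    # value
        )
        # Concatenate visual and attended audio: (batch, 2 * dim) -> (batch, dim)
        fused = self.fusion_mlp(torch.cat([visual_features, attn_out.squeeze(0)], dim=1))
        return fused
```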
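Because both inputs are pooled to one vector each, every query attends to exactly one key, and a softmax over a single score is always 1. The PyTorch check below (random weights, no checkpoint needed) confirms that the attention output then equals the value projection followed by the output projection. If richer interaction were wanted, one option, which is an assumption and not part of this design, would be to keep per-frame and per-time-step features so the attention has several keys to weigh.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, batch = 512, 3
attention = nn.MultiheadAttention(dim, num_heads=8).eval()
visual = torch.randn(batch, dim)
audio = torch.randn(batch, dim)

with torch.no_grad():
    attn_out, weights = attention(visual.unsqueeze(0), audio.unsqueeze(0), audio.unsqueeze(0))

    # in_proj_weight stacks the query, key and value projections along dim 0
    w_v = attention.in_proj_weight[2 * dim:]
    b_v = attention.in_proj_bias[2 * dim:]
    expected = attention.out_proj(audio @ w_v.T + b_v)

print(weights.shape)  # (batch, target_len=1, source_len=1)
print(torch.allclose(weights, torch.ones_like(weights)))
print(torch.allclose(attn_out.squeeze(0), expected, atol=1e-5))
```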

## Training Pipeline

The training process consists of two stages:

1. **Self-Supervised Learning**:
   - Train encoder-decoder architecture
   - Reconstruct input features
   - Learn meaningful representations

2. **Supervised Learning**:
   - Fine-tune for deepfake detection
   - Use labeled data
   - Optimize for classification
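The two stages can be sketched with toy modules. Everything here is hypothetical: random 512-d vectors stand in for input features, a linear encoder/decoder pair stands in for the real encoder-decoder, stage 1 minimises an MSE reconstruction loss, and stage 2 fine-tunes the encoder with a linear classifier under binary cross-entropy on made-up labels. It shows the structure of the loop, not the AVFF objectives or hyperparameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim, hidden, n = 512, 128, 64
features = torch.randn(n, dim)              # stand-in for input features
labels = torch.randint(0, 2, (n,)).float()  # stand-in labels: 0 = real, 1 = fake

encoder = nn.Linear(dim, hidden)            # hypothetical encoder
decoder = nn.Linear(hidden, dim)            # hypothetical decoder
classifier = nn.Linear(hidden, 1)           # hypothetical classifier head

# Stage 1 (self-supervised): reconstruct the inputs from their encoding
opt = torch.optim.Adam([*encoder.parameters(), *decoder.parameters()], lr=1e-3)
recon_losses = []
for _ in range(100):
    loss = nn.functional.mse_loss(decoder(encoder(features)), features)
    opt.zero_grad()
    loss.backward()
    opt.step()
    recon_losses.append(loss.item())

# Stage 2 (supervised): fine-tune the encoder and train the classifier on labels
opt = torch.optim.Adam([*encoder.parameters(), *classifier.parameters()], lr=1e-3)
cls_losses = []
for _ in range(100):
    logits = classifier(encoder(features)).squeeze(1)
    loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
    opt.zero_grad()
    loss.backward()
    opt.step()
    cls_losses.append(loss.item())

print(f"reconstruction loss: {recon_losses[0]:.3f} -> {recon_losses[-1]:.3f}")
print(f"classification loss: {cls_losses[0]:.3f} -> {cls_losses[-1]:.3f}")
```

The decoder is dropped after stage 1; only the encoder weights carry over into the supervised stage.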

## Inference and Evaluation

The inference pipeline:
1. Preprocess input video
2. Extract features using encoders
3. Fuse features
4. Make prediction

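```python
import torch


def process_video(video_path):
    # Assumes audio_encoder, visual_encoder, fusion, classifier and device are
    # already defined, loaded with trained weights and switched to eval() mode
    frames = extract_frames(video_path)   # (frames, 3, 224, 224)
    mel_spec = extract_audio(video_path)  # (1, 64, time), kept for inspection

    # The Wav2Vec2-based AudioEncoder needs the raw 16 kHz waveform, not the mel
    # spectrogram. load_waveform is a hypothetical helper returning a mono
    # waveform of shape (1, num_samples), e.g. the crop taken inside extract_audio
    waveform = load_waveform(video_path)

    # Add a batch dimension where needed and move to device
    frames = frames.unsqueeze(0).to(device)  # (1, frames, 3, 224, 224)
    waveform = waveform.to(device)           # (1, num_samples): a batch of one

    with torch.no_grad():
        audio_features = audio_encoder(waveform)
        visual_features = visual_encoder(frames)
        fused_features = fusion(visual_features, audio_features)
        prediction = classifier(fused_features)  # assumed: one logit per video

    return {
        'prediction': prediction.item(),
        'mel_spec': mel_spec.numpy(),
        'audio_features': audio_features.cpu().numpy(),
        'visual_features': visual_features.cpu().numpy(),
        'fused_features': fused_features.cpu().numpy()
    }
```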
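`prediction` is whatever the classifier returns. If it is a single raw logit where larger means "fake" (an assumption, since the classifier is not defined in this document), a sigmoid turns it into a probability and a threshold turns that into a label. The helper name `to_label` and the 0.5 threshold are illustrative; in practice the threshold would be chosen on validation data.

```python
import math


def to_label(logit, threshold=0.5):
    """Convert a single 'fake' logit into (probability_fake, label)."""
    # Numerically stable sigmoid: avoid exp() of a large positive number
    if logit >= 0:
        prob_fake = 1.0 / (1.0 + math.exp(-logit))
    else:
        e = math.exp(logit)
        prob_fake = e / (1.0 + e)
    label = "fake" if prob_fake >= threshold else "real"
    return prob_fake, label


print(to_label(2.0))
print(to_label(-1.5))
```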

## References

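1. AVFF: Oorloff et al., "AVFF: Audio-Visual Feature Fusion for Video Deepfake Detection", CVPR 2024
2. Wav2Vec2: [facebook/wav2vec2-base](https://huggingface.co/facebook/wav2vec2-base)
3. ViT (`vit_base_patch16_224` via timm): [huggingface/pytorch-image-models](https://github.com/huggingface/pytorch-image-models)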