# 🏋️ Training on Google Colab
## Compression-Aware Video Deepfake Detection

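This notebook trains all three model variants (hybrid, spatial-only, frequency-only) directly from the face crops stored on Google Drive, then evaluates them per compression level and generates the plots and tables for the paper.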

**Prerequisite:** face crops (plus their `metadata.csv`) already extracted to `/content/drive/MyDrive/ffpp_faces/`.

📌 **Runtime → Change runtime type → GPU (T4)**

## 1️⃣ Setup

In [None]:
# Mount Google Drive at /content/drive. The face crops read below and the
# results directory written later both live under this mount point.
from google.colab import drive

drive.mount(mountpoint='/content/drive')

In [None]:
# Clone repo and install dependencies
# NOTE: `git clone` fails if /content/project already exists (e.g. when this
# cell is re-run). The cell still continues, so the existing checkout is then
# used as-is, without pulling new commits.
!git clone https://github.com/its-simran-ch/compression_aware_deepfake.git /content/project
# %cd changes the notebook's working directory. The later `!python src/...`
# calls rely on this for their relative paths.
%cd /content/project
!pip install -q -r requirements.txt

In [None]:
# Verify GPU + packages: report the PyTorch version and CUDA availability
# (plus the device name when a GPU is present), then import the two extra
# dependencies. Each import raises ImportError if the package is missing.
import torch

print(f'PyTorch: {torch.__version__}')
has_cuda = torch.cuda.is_available()
print(f'CUDA:    {has_cuda}')
if has_cuda:
    print(f'GPU:     {torch.cuda.get_device_name(0)}')

from facenet_pytorch import MTCNN  # noqa: F401  (importing is the check)
print('MTCNN:   OK')

import pywt  # noqa: F401  (importing is the check)
print('PyWavelets: OK')

## 2️⃣ Verify Dataset

In [None]:
import pandas as pd
import os

DATA_ROOT = '/content/drive/MyDrive/ffpp_faces'
METADATA_CSV = f'{DATA_ROOT}/metadata.csv'
OUTPUT_DIR = '/content/drive/MyDrive/deepfake_results'  # Save results to Drive!

df = pd.read_csv(METADATA_CSV)

# Fail early with a readable message instead of a bare KeyError / IndexError
# further down (df.iloc[0] raises on an empty frame).
required_columns = {'split', 'label', 'compression', 'frame_path'}
missing_columns = required_columns - set(df.columns)
if missing_columns:
    raise ValueError(f'{METADATA_CSV} is missing columns: {sorted(missing_columns)}')
if df.empty:
    raise ValueError(f'{METADATA_CSV} contains no rows - extract the face crops first.')

print(f'‚úÖ Total face crops: {len(df)}')
# Class / split / compression balance, one table per column.
for column in ('split', 'label', 'compression'):
    print(f'\nBy {column}:')
    print(df[column].value_counts())

# Verify a sample image (frame paths in the CSV are relative to DATA_ROOT)
sample = df.iloc[0]
sample_path = os.path.join(DATA_ROOT, sample['frame_path'])
print(f'\nSample image exists: {os.path.exists(sample_path)}')

## 3️⃣ Quick Sanity Check (2–3 min)

Run a tiny test to make sure training works before the full run.

In [None]:
# Quick test: 1 epoch, limited samples
!python src/training/train_ffpp.py \
    --metadata_csv {METADATA_CSV} \
    --data_root {DATA_ROOT} \
    --mode hybrid \
    --compressions c23 \
    --epochs 1 \
    --batch_size 8 \
    --max_train_samples 100 \
    --max_val_samples 50 \
    --output_dir /content/test_run \
    --experiment_name quick_test

print('\n‚úÖ Sanity check passed! Training pipeline works.')

---
## 4️⃣ Train Hybrid Model (Main Experiment)

⏱️ **~1.5–2 hours** on a T4 GPU

Results are saved to Google Drive so they persist even if Colab disconnects.

In [None]:
# Main experiment: hybrid model (fed both RGB and DWT inputs at evaluation time)
# trained on c23 + c40 crops for 15 epochs, batch size 16, learning rate 1e-4.
# Outputs go to OUTPUT_DIR on Drive. The Section 6 evaluation looks for
# checkpoints/best_hybrid_c23_c40.pth there (presumably named by
# train_ffpp.py from --mode/--compressions -- confirm).
# NOTE: a failing `!` command does not raise, so check the output (or
# `_exit_code`) before moving on.
!python src/training/train_ffpp.py \
    --metadata_csv {METADATA_CSV} \
    --data_root {DATA_ROOT} \
    --mode hybrid \
    --compressions c23 c40 \
    --epochs 15 \
    --batch_size 16 \
    --lr 1e-4 \
    --output_dir {OUTPUT_DIR} \
    --experiment_name hybrid_c23_c40

## 5️⃣ Train Baseline Models (Ablation)

In [None]:
# Spatial-only baseline: same data and hyper-parameters as the hybrid run.
# In Section 6 this model receives RGB input only (dwt_input=None).
# NOTE: a failing `!` command does not raise; check the output before moving on.
!python src/training/train_ffpp.py \
    --metadata_csv {METADATA_CSV} \
    --data_root {DATA_ROOT} \
    --mode spatial \
    --compressions c23 c40 \
    --epochs 15 \
    --batch_size 16 \
    --lr 1e-4 \
    --output_dir {OUTPUT_DIR} \
    --experiment_name spatial_c23_c40

In [None]:
# Frequency-only baseline: same data and hyper-parameters as the hybrid run.
# In Section 6 this model receives DWT input only (rgb_input=None).
# NOTE: a failing `!` command does not raise; check the output before moving on.
!python src/training/train_ffpp.py \
    --metadata_csv {METADATA_CSV} \
    --data_root {DATA_ROOT} \
    --mode frequency \
    --compressions c23 c40 \
    --epochs 15 \
    --batch_size 16 \
    --lr 1e-4 \
    --output_dir {OUTPUT_DIR} \
    --experiment_name frequency_c23_c40

## 6️⃣ Evaluate Per Compression Level

This evaluates each model on the c23 and c40 test sets separately — crucial for the paper's compression-robustness analysis.

In [None]:
import os
import sys
sys.path.insert(0, '/content/project')

import torch
import numpy as np
import pandas as pd
import csv
from torch.utils.data import DataLoader
from src.datasets.ffpp_dataset import FFPPFrameDataset
from src.models.fusion_classifier import HybridDeepfakeClassifier
from src.utils.metrics import compute_metrics

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _predict_probs(model, loader, mode):
    """Run ``model`` over ``loader`` and return flat (labels, probabilities).

    Only the inputs a mode consumes are moved to ``device``. The frequency-only
    model gets ``rgb_input=None`` and the spatial-only model gets
    ``dwt_input=None``. Probabilities are the sigmoid of the model's logits.
    Both returned arrays are 1-D and aligned sample-by-sample.
    """
    label_chunks, prob_chunks = [], []
    with torch.no_grad():
        for batch in loader:
            rgb = batch['rgb'].to(device) if mode != 'frequency' else None
            dwt = batch['dwt'].to(device) if mode != 'spatial' else None
            logits = model(rgb_input=rgb, dwt_input=dwt)
            # reshape(-1) makes (B,), (B, 1) and squeezed 0-d logits all give
            # a 1-D array. Extending a list with a 0-d array fails, and (B, 1)
            # rows would yield a column vector for the metric functions.
            prob_chunks.append(torch.sigmoid(logits).reshape(-1).cpu().numpy())
            label_chunks.append(batch['label'].float().reshape(-1).numpy())
    return np.concatenate(label_chunks), np.concatenate(prob_chunks)


eval_results = []

for mode in ['hybrid', 'spatial', 'frequency']:
    ckpt_path = f'{OUTPUT_DIR}/checkpoints/best_{mode}_c23_c40.pth'
    if not os.path.exists(ckpt_path):
        print(f'‚ö†Ô∏è Skipping {mode} ‚Äî checkpoint not found')
        continue

    # Load model. weights_only=False unpickles arbitrary objects, so only load
    # checkpoints produced by this project's own training runs.
    model = HybridDeepfakeClassifier(mode=mode, pretrained_spatial=False).to(device)
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
    model.load_state_dict(ckpt['model_state_dict'])
    model.eval()
    print(f'\n{"="*50}')
    # The stored metric is the validation AUC (key 'val_auc'), not a training AUC.
    print(f'Evaluating: {mode} (best epoch {ckpt["epoch"]}, val AUC={ckpt["val_auc"]:.4f})')

    # DWT features are only needed by the modes that consume them.
    include_dwt = mode in ('frequency', 'hybrid')
    for comp in ['c23', 'c40']:
        test_ds = FFPPFrameDataset(
            metadata_csv=METADATA_CSV,
            root_dir=DATA_ROOT,
            split='test',
            compressions=[comp],
            include_dwt=include_dwt,
        )
        if len(test_ds) == 0:
            print(f'  {comp}: no test samples found - skipped')
            continue
        test_loader = DataLoader(test_ds, batch_size=32, shuffle=False, num_workers=2)

        labels, probs = _predict_probs(model, test_loader, mode)
        preds = (probs >= 0.5).astype(int)
        metrics = compute_metrics(labels, preds, probs)

        eval_results.append({
            'mode': mode, 'compression': comp,
            'accuracy': metrics['accuracy'], 'f1': metrics['f1'],
            'auc': metrics['auc'],
        })
        print(f'  {comp}: Acc={metrics["accuracy"]:.4f}  F1={metrics["f1"]:.4f}  AUC={metrics["auc"]:.4f}')

# Save results. The csv/ folder is created here in case no training run created it.
os.makedirs(f'{OUTPUT_DIR}/csv', exist_ok=True)
df_eval = pd.DataFrame(eval_results)
df_eval.to_csv(f'{OUTPUT_DIR}/csv/compression_eval_summary.csv', index=False)
print(f'\n\nüìä Results Summary:')
display(df_eval)

## 7️⃣ Generate Paper Plots

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

os.makedirs(f'{OUTPUT_DIR}/plots', exist_ok=True)

# ── Compression Robustness Line Plot ──
# One AUC-vs-compression line per model, built from the Section 6 eval_results.
if len(eval_results) > 0:
    fig, ax = plt.subplots(figsize=(8, 5))
    colors = {'hybrid': '#2ec4b6', 'spatial': '#4361ee', 'frequency': '#f77f00'}

    for mode in ['hybrid', 'spatial', 'frequency']:
        mode_data = [r for r in eval_results if r['mode'] == mode]
        if mode_data:
            comps = [r['compression'] for r in mode_data]
            aucs = [r['auc'] for r in mode_data]
            ax.plot(comps, aucs, 'o-', label=mode.capitalize(),
                    color=colors[mode], linewidth=2.5, markersize=10)

    ax.set_xlabel('Compression Level', fontsize=12)
    ax.set_ylabel('AUC', fontsize=12)
    ax.set_title('Detection Performance Across Compression Levels', fontsize=14)
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)
    # Keep the usual 0.5 lower bound, but lower it when a model scores below
    # 0.5. A fixed bound would silently clip those points out of the figure.
    finite_aucs = [r['auc'] for r in eval_results if np.isfinite(r['auc'])]
    y_low = max(0.0, min([0.5] + [a - 0.02 for a in finite_aucs]))
    ax.set_ylim(y_low, 1.02)
    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/plots/compression_robustness.png', dpi=200, bbox_inches='tight')
    plt.show()
    print('Saved: compression_robustness.png')

In [None]:
# ── Ablation Bar Chart ──
def _annotated_bar_panel(ax, labels, values, bar_colors, ylabel, title):
    """Draw one bar panel with each bar's value printed just above the bar.

    The y-axis starts at 0.5 by default. It is lowered when a value falls below
    0.5, so that bar and its value label stay visible instead of being clipped.
    """
    bars = ax.bar(labels, values, color=bar_colors, alpha=0.85, width=0.5)
    for bar, val in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.005,
                f'{val:.4f}', ha='center', fontweight='bold')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    finite_values = [v for v in values if np.isfinite(v)]
    y_low = max(0.0, min([0.5] + [v - 0.05 for v in finite_values]))
    ax.set_ylim(y_low, 1.05)
    ax.grid(True, alpha=0.2, axis='y')


if len(eval_results) > 0:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))
    colors = {'hybrid': '#2ec4b6', 'spatial': '#4361ee', 'frequency': '#f77f00'}

    # Average AUC and F1 across compression levels, per model.
    avg_metrics = {}
    for mode in ['hybrid', 'spatial', 'frequency']:
        mode_data = [r for r in eval_results if r['mode'] == mode]
        if mode_data:
            avg_metrics[mode] = {
                'auc': np.mean([r['auc'] for r in mode_data]),
                'f1': np.mean([r['f1'] for r in mode_data]),
            }

    modes = list(avg_metrics)
    bar_colors = [colors[m] for m in modes]
    _annotated_bar_panel(ax1, modes, [avg_metrics[m]['auc'] for m in modes],
                         bar_colors, 'AUC', 'Average AUC by Model')
    _annotated_bar_panel(ax2, modes, [avg_metrics[m]['f1'] for m in modes],
                         bar_colors, 'F1 Score', 'Average F1 by Model')

    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/plots/ablation_comparison.png', dpi=200, bbox_inches='tight')
    plt.show()
    print('Saved: ablation_comparison.png')

In [None]:
# ── Training Curves ──
# Validation loss / AUC per epoch for every model whose training log exists.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
colors = {'hybrid': '#2ec4b6', 'spatial': '#4361ee', 'frequency': '#f77f00'}

plotted_modes = []
for mode in ['hybrid', 'spatial', 'frequency']:
    log_path = f'{OUTPUT_DIR}/csv/train_log_{mode}_c23_c40.csv'
    if not os.path.exists(log_path):
        continue
    df_log = pd.read_csv(log_path)
    color = colors[mode]

    axes[0].plot(df_log['epoch'], df_log['val_loss'], label=mode.capitalize(),
                 color=color, linewidth=2)
    axes[1].plot(df_log['epoch'], df_log['val_auc'], label=mode.capitalize(),
                 color=color, linewidth=2, marker='o', markersize=4)
    plotted_modes.append(mode)

if plotted_modes:
    axes[0].set_xlabel('Epoch'); axes[0].set_ylabel('Validation Loss')
    axes[0].set_title('Validation Loss'); axes[0].legend(); axes[0].grid(True, alpha=0.3)
    axes[1].set_xlabel('Epoch'); axes[1].set_ylabel('Validation AUC')
    axes[1].set_title('Validation AUC'); axes[1].legend(); axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    os.makedirs(f'{OUTPUT_DIR}/plots', exist_ok=True)
    plt.savefig(f'{OUTPUT_DIR}/plots/training_curves.png', dpi=200, bbox_inches='tight')
    plt.show()
    print('Saved: training_curves.png')
else:
    # Nothing to draw. Avoid writing an empty figure into the paper's plots
    # folder, and avoid the "No artists with labels" legend warnings.
    plt.close(fig)
    print(f'No training logs found in {OUTPUT_DIR}/csv - training_curves.png not saved')

## 8️⃣ Results for Paper

In [None]:
# Final results table: AUC and F1 pivoted as model (rows) x compression
# (columns), displayed rounded to 4 decimals, plus a LaTeX copy of the AUC table.
if eval_results:
    df_final = pd.DataFrame(eval_results)
    divider = '=' * 40

    pivot_auc = df_final.pivot_table(values='auc', index='mode', columns='compression')
    auc_rounded = pivot_auc.round(4)
    print('üìä AUC by Model √ó Compression:')
    print(divider)
    display(auc_rounded)

    pivot_f1 = df_final.pivot_table(values='f1', index='mode', columns='compression')
    print('\nüìä F1 by Model √ó Compression:')
    print(divider)
    display(pivot_f1.round(4))

    print('\nüìù LaTeX table (paste into paper):')
    print(auc_rounded.to_latex())

In [None]:
# List all saved files: walk OUTPUT_DIR and print each file's path relative to
# it, together with its size in MiB.
print('üì¶ All results saved to Google Drive:')
for dirpath, _subdirs, filenames in os.walk(OUTPUT_DIR):
    for name in filenames:
        path = os.path.join(dirpath, name)
        size_mb = os.path.getsize(path) / (1024 * 1024)
        print(f'  {os.path.relpath(path, OUTPUT_DIR)}  ({size_mb:.1f} MB)')