# Deepfake Audio Detection

This notebook demonstrates how to run inference on a fine-tuned Wav2Vec2 model for deepfake audio detection available on Hugging Face:

[`garystafford/wav2vec2-deepfake-voice-detector`](https://huggingface.co/garystafford/wav2vec2-deepfake-voice-detector)

**Select a T4 GPU runtime for hardware acceleration.**

## Step 1: Environment Setup

Confirm that the required dependencies are already installed in Colab. These include PyTorch with CUDA support, Librosa, Transformers, and ipywidgets.
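
Where `grep` is not available (for example, in a local Jupyter session on Windows), a similar check can be sketched with the standard library's `importlib.metadata`. The helper name `installed_versions` and the package list are illustrative assumptions based on the pattern used in this step, not part of the original notebook.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


# Package names assumed from the grep pattern used in this step.
for name, ver in installed_versions(
    ["torch", "transformers", "librosa", "ipywidgets"]
).items():
    print(f"{name}: {ver or 'NOT INSTALLED'}")
```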

In [None]:
%pip list | grep 'librosa\|transformers\|torch\|ipywidgets'

## Step 2: Load Model onto GPU

Load the Hugging Face model onto the notebook's GPU, falling back to the CPU if no GPU is available.

In [None]:
import torch
from transformers import AutoFeatureExtractor, AutoModelForAudioClassification

model_name = "garystafford/wav2vec2-deepfake-voice-detector"

print("Loading model from Hugging Face...")
model = AutoModelForAudioClassification.from_pretrained(model_name)
feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device).eval()
print(f"Model loaded on: {device}\n")

print(f"Model:\n{model}")
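
The inference code in Step 3 reads class index 0 as "real" and class index 1 as "fake". That ordering is an assumption about how this checkpoint was trained, and one rough way to confirm it is to inspect the `id2label` mapping stored in the model config. The helper `find_label_index` below is hypothetical, and the example mapping is made up for illustration rather than read from the model.

```python
def find_label_index(id2label, keyword):
    """Return the class index whose label contains `keyword` (case-insensitive)."""
    matches = [
        int(idx)
        for idx, label in id2label.items()
        if keyword.lower() in str(label).lower()
    ]
    if len(matches) != 1:
        raise ValueError(
            f"Expected one label matching {keyword!r}, found {len(matches)}"
        )
    return matches[0]


# Illustrative mapping only; in the notebook, pass model.config.id2label instead.
example_id2label = {0: "real", 1: "fake"}
print(find_label_index(example_id2label, "real"))
print(find_label_index(example_id2label, "fake"))
```

If the config only carries generic names such as `LABEL_0`, the helper raises a `ValueError`, and the model card is the place to check the ordering.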

## Step 3: Inference

Run inference on your own audio files using the fine-tuned model loaded in Step 2.

**Requirements:**
- Audio file (mp3, wav, or flac format; the upload widget also accepts ogg and m4a)
- 16 kHz sample rate recommended (other rates are resampled automatically)
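
To check a file's sample rate before uploading, the standard-library `wave` module is enough for uncompressed WAV files (it does not read mp3 or flac). The sketch below writes a one-second silent test file to a temporary directory and reads its rate back. The file name and the `wav_sample_rate` helper are illustrative assumptions.

```python
import os
import tempfile
import wave


def wav_sample_rate(path):
    """Return the sample rate (Hz) stored in a PCM WAV file's header."""
    with wave.open(path, "rb") as wav:
        return wav.getframerate()


# Create a 1-second, mono, 16-bit silent WAV at 44.1 kHz to demonstrate.
with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.wav")
    with wave.open(demo_path, "wb") as wav:
        wav.setnchannels(1)
        wav.setsampwidth(2)  # 2 bytes per sample = 16-bit PCM
        wav.setframerate(44100)
        wav.writeframes(b"\x00\x00" * 44100)

    rate = wav_sample_rate(demo_path)
    print(f"Sample rate: {rate} Hz")
    if rate != 16000:
        print("This file would be resampled to 16000 Hz at load time.")
```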

**Quick start:**
1. Run the cells
2. Select an audio file
3. Get prediction: real or fake audio

**Note**: The first inference call will be slower than subsequent calls.

In [None]:
import os
import tempfile
import librosa
import torch


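def run_inference_on_audio_bytes(audio_bytes: bytes, threshold: float = 0.4) -> dict:
    """Classify raw audio bytes; `threshold` is the minimum fake probability for a "fake" label."""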
    # Write bytes to a temporary file so librosa.load can use a path
    with tempfile.NamedTemporaryFile(suffix=".flac", delete=False) as tmp:
        tmp.write(audio_bytes)
        tmp_path = tmp.name

    try:
        # Load and resample to 16kHz
        audio, sr = librosa.load(tmp_path, sr=16000, mono=True)

        # Extract features
        inputs = feature_extractor(
            audio, sampling_rate=16000, return_tensors="pt", padding=True
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

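        # Convert to probabilities (assumes class index 0 = real, 1 = fake)
        probs = torch.nn.functional.softmax(logits, dim=-1)
        prob_real = probs[0][0].item()
        prob_fake = probs[0][1].item()

        # Make prediction
        prediction = "fake" if prob_fake >= threshold else "real"

        return {
            "prediction": prediction,
            # Report the probability of the predicted label; with a threshold
            # below 0.5 this can differ from max(prob_real, prob_fake)
            "confidence": prob_fake if prediction == "fake" else prob_real,
            "probabilities": {"real": prob_real, "fake": prob_fake},
        }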
    finally:
        # Clean up temp file
        try:
            os.remove(tmp_path)
        except OSError:
            pass
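
The default `threshold=0.4` means a clip is labeled fake once its fake probability reaches 0.4, even when the real probability is higher, which trades extra false alarms for fewer missed fakes. The sketch below reproduces the softmax and threshold steps in plain Python. The logit values are invented for illustration, and `decide` is a hypothetical helper rather than part of the notebook.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decide(logits, threshold=0.4):
    """Mirror the notebook's rule: index 0 = real, index 1 = fake (assumed order)."""
    prob_real, prob_fake = softmax(logits)
    label = "fake" if prob_fake >= threshold else "real"
    return label, prob_real, prob_fake


# Made-up logits where "real" is slightly more likely than "fake".
label, prob_real, prob_fake = decide([0.2, 0.0])
print(label, round(prob_real, 3), round(prob_fake, 3))

# With a neutral 0.5 threshold the same logits are labeled "real".
print(decide([0.2, 0.0], threshold=0.5)[0])
```

Raising the threshold toward 0.5 makes the decision follow the more probable class, while lowering it flags more clips as fake.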

In [None]:
import json
import ipywidgets as widgets
import IPython.display as ipd

upload_widget = widgets.FileUpload(accept=".wav,.mp3,.flac,.ogg,.m4a", multiple=False)
status_label = widgets.Label(value="Upload an audio file to run deepfake detection.")
output_area = widgets.Output()


def on_upload_change(change):
    with output_area:
        output_area.clear_output()

        if not upload_widget.value:
            return

        v = upload_widget.value

        if isinstance(v, dict):
            # ipywidgets 7.x: dict of filename -> {"metadata": {...}, "content": bytes}
            file_info = next(iter(v.values()))
            file_name = file_info["metadata"]["name"]
            content = file_info["content"]  # already bytes in many 7.x builds
        else:
            # ipywidgets 8.x: list of dicts with top-level keys; content is usually memoryview
            file_info = v[0]
            file_name = file_info["name"]
            content = file_info["content"]
            if hasattr(content, "tobytes"):
                content = content.tobytes()  # convert memoryview to bytes

        audio_bytes = bytes(content)  # normalize bytes or memoryview to bytes

        print(f"File: {file_name}")
        print(f"Size: {len(audio_bytes)} bytes")

        # Audio widget: pass bytes so no rate is required
        try:
            ipd.display(ipd.Audio(data=audio_bytes))  # bytes, not ndarray
        except Exception as e:
            print(f"Could not render audio widget: {e}")

        # Run inference (the bytes are written to a temp file internally)
        try:
            status_label.value = f"Running inference on {file_name} ..."
            result = run_inference_on_audio_bytes(audio_bytes)
            print("\nInference result:")
            print(json.dumps(result, indent=2))
            status_label.value = f"Inference completed for {file_name}."
        except Exception as e:
            print("\nError during inference:")
            print(e)
            status_label.value = f"Error during inference for {file_name}."


upload_widget.observe(on_upload_change, names="value")

ui = widgets.VBox([upload_widget, status_label, output_area])
ui
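
Outside an interactive session, the upload widget can be skipped by reading a file from disk and passing its bytes to the inference function directly. The sketch below assumes a hypothetical helper, `classify_file`, that takes the inference function as an argument so it can be exercised without the model. The stub function and the file it writes are placeholders.

```python
import tempfile
from pathlib import Path


def classify_file(path, infer):
    """Read an audio file from disk and run `infer` on its raw bytes."""
    audio_path = Path(path)
    result = infer(audio_path.read_bytes())
    return {"file": audio_path.name, **result}


def stub_infer(audio_bytes):
    """Placeholder standing in for run_inference_on_audio_bytes."""
    return {"prediction": "real", "num_bytes": len(audio_bytes)}


with tempfile.TemporaryDirectory() as tmp_dir:
    sample = Path(tmp_dir) / "sample.flac"
    sample.write_bytes(b"not real audio")  # placeholder bytes, not a valid FLAC
    report = classify_file(sample, stub_infer)
    print(report)
```

In the notebook, passing `run_inference_on_audio_bytes` as `infer` along with the path to a real audio file would return the inference dictionary with the file name added.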