# AI-Generated Voice Detection (Multilingual) - Kaggle

This Kaggle notebook trains a robust AI-generated voice detector following **"Measuring the Robustness of Audio Deepfake Detectors" (Li et al., 2025)**.

**How to run on Kaggle:**
1. Create a new Notebook (or copy this one).
2. **Settings (right panel)**: Set **Accelerator** to **GPU** (P100 or T4). Enable **Internet** (needed to download dataset and models).
3. **Option A – Download dataset**: Run the download + unzip cells; the dataset will be saved under `/kaggle/working/dataset`.
4. **Option B – Use a Kaggle dataset**: Upload your dataset to Kaggle (Datasets), add it to this notebook (Add Data), then set `DATASET_ROOT` in the config cell to the path shown (e.g. `/kaggle/input/your-dataset-slug/dataset`). Skip the download and unzip cells.
5. Run all cells. The trained model is saved under `/kaggle/working/ai_voice_detector_w2vbert` and will appear in the notebook Output for download.

In [17]:
# Install required packages (Kaggle)
!pip install -q transformers datasets librosa pydub tqdm scikit-learn

In [18]:
# Check GPU availability
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
    DEVICE = torch.device("cuda")
else:
    print("WARNING: No GPU detected! Training will be slow.")
    DEVICE = torch.device("cpu")

print(f"\nUsing device: {DEVICE}")

PyTorch version: 2.8.0+cu126
CUDA available: True
GPU: Tesla T4
GPU Memory: 15.64 GB

Using device: cuda


In [19]:
# Download dataset to Kaggle working directory (writable)
DATASET_URL = "https://huggingface.co/datasets/kimnamjoon0007/AI_Detection/resolve/main/dataset.zip"

!wget -q --show-progress -O /kaggle/working/dataset.zip "{DATASET_URL}"
print("Download complete!")

Download complete!


In [20]:
# Unzip dataset into Kaggle working directory
import zipfile
import os
import shutil
import glob

WORKING_DATASET = "/kaggle/working/dataset"
ZIP_PATH = "/kaggle/working/dataset.zip"

if os.path.exists(WORKING_DATASET):
    shutil.rmtree(WORKING_DATASET)
os.makedirs(WORKING_DATASET, exist_ok=True)

print("Extracting dataset.zip into", WORKING_DATASET, "...")
with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
    zip_ref.extractall(WORKING_DATASET)

print("\nExtracted contents:")
for item in os.listdir(WORKING_DATASET):
    item_path = os.path.join(WORKING_DATASET, item)
    if os.path.isdir(item_path):
        file_count = len([f for f in os.listdir(item_path) if os.path.isfile(os.path.join(item_path, f))])
        print(f"  - {item}/ ({file_count} files)")
    else:
        print(f"  - {item}")

nested_dataset = os.path.join(WORKING_DATASET, 'dataset')
if os.path.exists(nested_dataset):
    print("\nDetected nested 'dataset' folder, reorganizing...")
    for item in os.listdir(nested_dataset):
        src = os.path.join(nested_dataset, item)
        dst = os.path.join(WORKING_DATASET, item)
        if os.path.exists(dst):
            shutil.rmtree(dst)
        shutil.move(src, dst)
    os.rmdir(nested_dataset)

print("\n" + "="*50)
print("FINAL DATASET STRUCTURE:")
print("="*50)
for root, dirs, files in os.walk(WORKING_DATASET):
    level = root.replace(WORKING_DATASET, '').count(os.sep)
    indent = '  ' * level
    folder_name = os.path.basename(root) if root != WORKING_DATASET else 'dataset'
    print(f"{indent}{folder_name}/")
    if level < 2:
        subindent = '  ' * (level + 1)
        audio_files = [f for f in files if f.endswith(('.wav', '.mp3', '.flac', '.ogg', '.m4a'))]
        for file in audio_files[:3]:
            print(f"{subindent}{file}")
        if len(audio_files) > 3:
            print(f"{subindent}... and {len(audio_files)-3} more audio files")

total_audio = 0
for ext in ['*.wav', '*.mp3', '*.flac', '*.ogg', '*.m4a']:
    total_audio += len(glob.glob(f"{WORKING_DATASET}/**/{ext}", recursive=True))
print(f"\nTotal audio files found: {total_audio}")

Extracting dataset.zip into /kaggle/working/dataset ...

Extracted contents:
  - telugu/ (0 files)
  - english/ (0 files)
  - tamil/ (0 files)
  - test/ (7 files)
  - fake/ (30 files)
  - real/ (30 files)
  - hindi/ (0 files)
  - malayalam/ (0 files)

FINAL DATASET STRUCTURE:
dataset/
  telugu/
    fake/
    real/
  english/
    fake/
    real/
  tamil/
    fake/
    real/
  test/
    Hasan.wav
    Test(Urdu Real Audio).wav
    Test(Urdu Clone).wav
    ... and 4 more audio files
  fake/
    Davis.wav
    Henri.wav
    Shakir.wav
    ... and 27 more audio files
  real/
    clip_28.wav
    clip_13.wav
    clip_2.wav
    ... and 27 more audio files
  hindi/
    fake/
    real/
  malayalam/
    fake/
    real/

Total audio files found: 6567


In [21]:
# Configuration
import os

# Root folder for the audio dataset (Kaggle: use /kaggle/working/dataset or /kaggle/input/...).
# Expected structure:
# dataset_root/
#   Tamil/
#     HUMAN/*.mp3
#     AI_GENERATED/*.mp3
#   English/
#     HUMAN/*.mp3
#     AI_GENERATED/*.mp3
#   Hindi/...
#   Malayalam/...
#   Telugu/...

DATASET_ROOT = "/kaggle/working/dataset"  # or /kaggle/input/your-dataset-slug/... if you added a Kaggle dataset

# Target languages (fixed by the problem statement)
TARGET_LANGUAGES = ["Tamil", "English", "Hindi", "Malayalam", "Telugu"]

# Labels
LABEL_TO_ID = {"HUMAN": 0, "AI_GENERATED": 1}
ID_TO_LABEL = {v: k for k, v in LABEL_TO_ID.items()}

# Audio settings
TARGET_SAMPLING_RATE = 16000  # 16 kHz, as used in the paper experiments
MAX_DURATION_SECONDS = 10.0   # clips longer than this will be randomly cropped during training

# Model and training hyperparameters
MODEL_NAME = "facebook/wav2vec2-large-xlsr-53"  # Wave2Vec2BERT backbone
BATCH_SIZE = 8
NUM_EPOCHS = 5
LEARNING_RATE = 1e-5
WARMUP_RATIO = 0.1
WEIGHT_DECAY = 0.01

OUTPUT_DIR = "/kaggle/working/ai_voice_detector_w2vbert"  # saved model will appear in Notebook Output
os.makedirs(OUTPUT_DIR, exist_ok=True)

print("Config loaded.")

Config loaded.


In [22]:
# Data discovery and index construction

import glob
from dataclasses import dataclass
from typing import List

@dataclass
class AudioSample:
    path: str
    language: str
    label: str  # "HUMAN" or "AI_GENERATED"


def discover_dataset(dataset_root: str) -> List[AudioSample]:
    """
    Scan the dataset directory and build an index of (path, language, label).

    Assumes structure like:
      dataset/
        english/
          human/ or real/
          ai/ or fake/ or ai_generated/
        tamil/
        ...

    Folder names can be lowercase; we normalize them here.
    """
    samples: List[AudioSample] = []

    # Possible on-disk names for each canonical label
    label_dir_map = {
        "HUMAN": ["HUMAN", "human", "real", "Real"],
        "AI_GENERATED": ["AI_GENERATED", "ai_generated", "ai", "fake", "Fake"],
    }

    for language in TARGET_LANGUAGES:
        # Most HF zips use lowercase language folder names
        lang_dir = os.path.join(dataset_root, language.lower())
        if not os.path.isdir(lang_dir):
            print(f"[WARN] Language folder not found: {lang_dir}")
            continue

        for canonical_label, dir_names in label_dir_map.items():
            label_found = False
            for dn in dir_names:
                label_dir = os.path.join(lang_dir, dn)
                if os.path.isdir(label_dir):
                    pattern = os.path.join(label_dir, "**", "*.mp3")
                    for path in glob.glob(pattern, recursive=True):
                        samples.append(
                            AudioSample(path=path, language=language, label=canonical_label)
                        )
                    label_found = True
            if not label_found:
                print(f"[WARN] No folder for label {canonical_label} under {lang_dir}")

    print(f"Discovered {len(samples)} audio files across {len(TARGET_LANGUAGES)} languages.")
    return samples


all_samples = discover_dataset(DATASET_ROOT)

# Optional: peek at a few samples
for s in all_samples[:5]:
    print(s)

Discovered 6500 audio files across 5 languages.
AudioSample(path='/kaggle/working/dataset/tamil/real/3645.mp3', language='Tamil', label='HUMAN')
AudioSample(path='/kaggle/working/dataset/tamil/real/26525.mp3', language='Tamil', label='HUMAN')
AudioSample(path='/kaggle/working/dataset/tamil/real/4603.mp3', language='Tamil', label='HUMAN')
AudioSample(path='/kaggle/working/dataset/tamil/real/15771.mp3', language='Tamil', label='HUMAN')
AudioSample(path='/kaggle/working/dataset/tamil/real/23214.mp3', language='Tamil', label='HUMAN')


In [23]:
# Patched load_audio that handles corrupted files
def load_audio(path: str, target_sr: int = TARGET_SAMPLING_RATE) -> torch.Tensor:
    """Load an audio file and resample to target_sr, mono."""
    try:
        audio_segment = AudioSegment.from_file(path)
    except Exception as e:
        print(f"[WARN] Corrupted file skipped: {path}")
        # Return 0.5s of silence as a placeholder
        return torch.zeros(1, int(target_sr * 0.5))
    
    samples = np.array(audio_segment.get_array_of_samples()).astype(np.float32)
    channels = audio_segment.channels
    if channels > 1:
        samples = samples.reshape(-1, channels).mean(axis=1)
    samples /= 32767.0
    sr = audio_segment.frame_rate
    if sr != target_sr:
        samples = librosa.resample(samples, orig_sr=sr, target_sr=target_sr)
    return torch.from_numpy(samples).unsqueeze(0)

print("Patched load_audio function - corrupted files will be replaced with silence.")

Patched load_audio function - corrupted files will be replaced with silence.


In [24]:
# Train/validation/test split

from sklearn.model_selection import train_test_split

if len(all_samples) == 0:
    raise RuntimeError("No audio files found. Please mount/upload your dataset and set DATASET_ROOT correctly.")

# Convert to simple lists for splitting
paths = [s.path for s in all_samples]
labels = [LABEL_TO_ID[s.label] for s in all_samples]
languages = [s.language for s in all_samples]

# Stratify by label to keep AI/HUMAN balance
train_paths, temp_paths, train_labels, temp_labels, train_langs, temp_langs = train_test_split(
    paths, labels, languages, test_size=0.3, random_state=42, stratify=labels
)

val_paths, test_paths, val_labels, test_labels, val_langs, test_langs = train_test_split(
    temp_paths, temp_labels, temp_langs, test_size=0.5, random_state=42, stratify=temp_labels
)

print(f"Train: {len(train_paths)} | Val: {len(val_paths)} | Test: {len(test_paths)}")

Train: 4550 | Val: 975 | Test: 975


In [25]:
# Robustness-oriented audio augmentations (following the paper)

import random
import io
import librosa
import numpy as np
import torch
from pydub import AudioSegment

def load_audio(path: str, target_sr: int = TARGET_SAMPLING_RATE) -> torch.Tensor:
    """Load an audio file and resample to target_sr, mono, using pydub + librosa.

    This avoids backend issues with mp3 decoding in soundfile/librosa and works
    reliably for the dataset's MP3 files.
    """
    audio_segment = AudioSegment.from_file(path)
    samples = np.array(audio_segment.get_array_of_samples()).astype(np.float32)
    channels = audio_segment.channels
    if channels > 1:
        samples = samples.reshape(-1, channels).mean(axis=1)
    samples /= 32767.0
    sr = audio_segment.frame_rate
    if sr != target_sr:
        samples = librosa.resample(samples, orig_sr=sr, target_sr=target_sr)
    return torch.from_numpy(samples).unsqueeze(0)


def random_crop(waveform: torch.Tensor, max_duration: float, sr: int) -> torch.Tensor:
    max_len = int(max_duration * sr)
    if waveform.shape[1] <= max_len:
        return waveform
    start = random.randint(0, waveform.shape[1] - max_len)
    return waveform[:, start:start + max_len]


# 1. Noise perturbation (Gaussian / background-like noise)

def add_gaussian_noise(waveform: torch.Tensor, snr_db: float = 20.0) -> torch.Tensor:
    """Add Gaussian noise at a target SNR, as in the paper's noise perturbations section."""
    signal_power = waveform.pow(2).mean()
    snr_linear = 10 ** (snr_db / 10)
    noise_power = signal_power / snr_linear
    noise = torch.randn_like(waveform) * torch.sqrt(noise_power)
    return waveform + noise


# 2. Pitch shifting (spectral modification)

def pitch_shift(waveform: torch.Tensor, sr: int, n_steps_range=(-2.0, 2.0)) -> torch.Tensor:
    """Apply pitch shifting using librosa, following the paper's challenging pitch-shift corruption."""
    y = waveform.squeeze(0).cpu().numpy()
    n_steps = random.uniform(*n_steps_range)
    y_shifted = librosa.effects.pitch_shift(y, sr=sr, n_steps=n_steps)
    return torch.from_numpy(y_shifted).unsqueeze(0)


# 3. Time stretching (temporal modification)

def time_stretch(waveform: torch.Tensor, sr: int, rate_range=(0.8, 1.2)) -> torch.Tensor:
    """Apply time stretching, another corruption highlighted as particularly harmful in the paper."""
    y = waveform.squeeze(0).cpu().numpy()
    rate = random.uniform(*rate_range)
    y_stretched = librosa.effects.time_stretch(y, rate=rate)
    return torch.from_numpy(y_stretched).unsqueeze(0)


# 4. MP3 compression (proxy for codec-based corruption like MP3/Opus/Encodec)

def mp3_compress_decompress(waveform: torch.Tensor, sr: int, bitrate: str = "32k") -> torch.Tensor:
    """Simulate lossy codec corruption by round-tripping through MP3 at a low bitrate.

    The paper shows that traditional and neural codecs (MP3, Opus, Encodec, DAC, FACodec, AudioDec)
    severely hurt detectors despite high perceptual quality. This augmentation mimics that effect.
    """
    # Convert tensor to 16-bit PCM bytes
    y = waveform.squeeze(0).cpu().numpy()
    audio_segment = AudioSegment(
        (y * 32767).astype(np.int16).tobytes(),
        frame_rate=sr,
        sample_width=2,
        channels=1,
    )

    # Export to MP3 in-memory
    buf = io.BytesIO()
    audio_segment.export(buf, format="mp3", bitrate=bitrate)
    buf.seek(0)

    # Read back
    compressed = AudioSegment.from_file(buf, format="mp3")
    samples = np.array(compressed.get_array_of_samples()).astype(np.float32) / 32767.0
    return torch.from_numpy(samples).unsqueeze(0)


def apply_random_corruption(waveform: torch.Tensor, sr: int) -> torch.Tensor:
    """Randomly apply one of the corruptions (or none) during training.

    This follows the paper's robustness enhancement strategy: mix clean and corrupted
    examples, with challenging corruptions like pitch shift, time stretch, and codec
    compression to improve generalization.
    """
    # Keep some samples clean
    if random.random() < 0.3:
        return waveform

    ops = [
        lambda x: add_gaussian_noise(x, snr_db=random.choice([10.0, 20.0, 30.0])),
        lambda x: pitch_shift(x, sr),
        lambda x: time_stretch(x, sr),
        lambda x: mp3_compress_decompress(x, sr, bitrate=random.choice(["32k", "64k", "96k"])),
    ]
    op = random.choice(ops)
    try:
        return op(waveform)
    except Exception as e:
        # Fallback to original if an augmentation fails (e.g., very short clips)
        print(f"[AUGMENT WARN] {e}")
        return waveform

In [26]:
# PyTorch Dataset and DataLoader

from torch.utils.data import Dataset, DataLoader


class AIDeepfakeDataset(Dataset):
    """Binary AI/HUMAN detector dataset with optional corruption-based augmentation."""

    def __init__(self, paths, labels, languages, augment: bool = False):
        self.paths = list(paths)
        self.labels = list(labels)
        self.languages = list(languages)
        self.augment = augment

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        path = self.paths[idx]
        label = self.labels[idx]
        language = self.languages[idx]

        waveform = load_audio(path, target_sr=TARGET_SAMPLING_RATE)
        waveform = random_crop(waveform, MAX_DURATION_SECONDS, TARGET_SAMPLING_RATE)

        if self.augment:
            waveform = apply_random_corruption(waveform, TARGET_SAMPLING_RATE)

        return {
            "input_values": waveform.squeeze(0),  # (T,)
            "label": torch.tensor(label, dtype=torch.long),
            "language": language,
        }


train_dataset = AIDeepfakeDataset(train_paths, train_labels, train_langs, augment=True)
val_dataset = AIDeepfakeDataset(val_paths, val_labels, val_langs, augment=False)
test_dataset = AIDeepfakeDataset(test_paths, test_labels, test_langs, augment=False)

print(f"Train samples: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

Train samples: 4550 | Val: 975 | Test: 975


In [27]:
from transformers import AutoFeatureExtractor, Wav2Vec2Model

feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)

backbone = Wav2Vec2Model.from_pretrained(MODEL_NAME)
backbone.to(DEVICE)

print("Backbone loaded.")

2026-01-29 16:40:50.650545: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1769704850.897562      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1769704850.986153      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1769704851.622741      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769704851.622801      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1769704851.622804      55 computation_placer.cc:177] computation placer alr

preprocessor_config.json:   0%|          | 0.00/212 [00:00<?, ?B/s]

config.json: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/1.27G [00:00<?, ?B/s]

Backbone loaded.


In [28]:
# Model head: utterance-level classifier on top of Wav2Vec2

import torch.nn as nn

class W2VBertDeepfakeDetector(nn.Module):
    """Wave2Vec2-based deepfake detector using a foundation speech model."""

    def __init__(self, backbone, num_labels: int = 2):
        super().__init__()
        self.backbone = backbone
        hidden_size = backbone.config.hidden_size
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, input_values, attention_mask=None, labels=None):
        # input_values is already a padded batch of raw waveforms: (batch, T)
        # from collate_fn. We feed it directly to Wav2Vec2Model.

        input_vals = input_values.to(DEVICE)  # (B, T)

        # Wav2Vec2Model expects input_values (float PCM in [-1, 1])
        outputs = self.backbone(input_values=input_vals, attention_mask=attention_mask)

        hidden_states = outputs.last_hidden_state  # (batch, seq_len, hidden)
        pooled = hidden_states.mean(dim=1)
        pooled = self.dropout(pooled)
        logits = self.classifier(pooled)

        loss = None
        if labels is not None:
            labels = labels.to(DEVICE)
            loss_fn = nn.CrossEntropyLoss()
            loss = loss_fn(logits, labels)

        return {"loss": loss, "logits": logits}


model = W2VBertDeepfakeDetector(backbone, num_labels=2).to(DEVICE)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
    print(f"Using {torch.cuda.device_count()} GPUs (DataParallel).")
print("Full model ready.")

Using 2 GPUs (DataParallel).
Full model ready.


In [29]:
# Data collator

from typing import Dict, Any


def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    input_values = [item["input_values"] for item in batch]
    labels = torch.stack([item["label"] for item in batch])
    languages = [item["language"] for item in batch]

    # We will pass raw waveforms into the model, which will internally use the
    # same feature extractor as the paper (Mel-like features for Wav2Vec2BERT).
    padded = nn.utils.rnn.pad_sequence(input_values, batch_first=True)

    return {
        "input_values": padded.to(DEVICE),
        "labels": labels.to(DEVICE),
        "languages": languages,
    }


train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

print("DataLoaders ready.")

DataLoaders ready.


In [30]:
# Optimizer, scheduler, and metrics (Accuracy, AUROC, EER as in the paper)

from transformers import get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score, roc_auc_score

num_training_steps = NUM_EPOCHS * len(train_loader)
num_warmup_steps = int(WARMUP_RATIO * num_training_steps)

optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)


def compute_eer(y_true, y_scores):
    """Compute Equal Error Rate (EER), matching the paper's main metric.

    We sweep thresholds over [0, 1] on the predicted probability of AI_GENERATED.
    """
    y_true = np.array(y_true)
    y_scores = np.array(y_scores)

    # thresholds between 0 and 1
    thresholds = np.linspace(0, 1, 200)
    fprs = []
    fnrs = []

    for th in thresholds:
        y_pred = (y_scores >= th).astype(int)
        tp = np.sum((y_true == 1) & (y_pred == 1))
        tn = np.sum((y_true == 0) & (y_pred == 0))
        fp = np.sum((y_true == 0) & (y_pred == 1))
        fn = np.sum((y_true == 1) & (y_pred == 0))

        fpr = fp / (fp + tn + 1e-8)
        fnr = fn / (fn + tp + 1e-8)
        fprs.append(fpr)
        fnrs.append(fnr)

    fprs = np.array(fprs)
    fnrs = np.array(fnrs)
    diffs = np.abs(fprs - fnrs)
    idx = np.argmin(diffs)
    eer = (fprs[idx] + fnrs[idx]) / 2.0
    return float(eer)


print("Optimizer, scheduler, and metrics ready.")

Optimizer, scheduler, and metrics ready.


In [32]:
# Training loop (with validation using Accuracy, AUROC, EER)

from tqdm.auto import tqdm


def evaluate(loader, desc="Val"):
    model.eval()
    all_labels = []
    all_probs_ai = []  # probability of AI_GENERATED (class 1)

    with torch.no_grad():
        for batch in tqdm(loader, desc=desc):
            outputs = model(input_values=batch["input_values"], labels=None)
            logits = outputs["logits"]
            probs = torch.softmax(logits, dim=-1)[:, 1]  # P(AI_GENERATED)

            all_labels.extend(batch["labels"].cpu().numpy().tolist())
            all_probs_ai.extend(probs.cpu().numpy().tolist())

    # Use 0.5 threshold for accuracy, as in typical binary classification
    preds = (np.array(all_probs_ai) >= 0.5).astype(int)
    acc = accuracy_score(all_labels, preds)

    try:
        auroc = roc_auc_score(all_labels, all_probs_ai)
    except ValueError:
        auroc = float("nan")  # e.g., only one class present

    eer = compute_eer(all_labels, all_probs_ai)
    return {"accuracy": acc, "auroc": auroc, "eer": eer}


best_val_eer = 1.0

for epoch in range(1, NUM_EPOCHS + 1):
    model.train()
    train_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}")
    for batch in pbar:
        optimizer.zero_grad()

        outputs = model(input_values=batch["input_values"], labels=batch["labels"])
        loss = outputs["loss"]

        # Fix for DataParallel: reduce loss to scalar if needed
        if loss.dim() > 0:
            loss = loss.mean()

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()

        train_loss += loss.item()
        pbar.set_postfix({"loss": loss.item()})

    train_loss /= max(1, len(train_loader))

    # Validation
    metrics = evaluate(val_loader, desc="Val")
    print(f"Epoch {epoch}: train_loss={train_loss:.4f}, val_acc={metrics['accuracy']:.4f}, "
          f"val_auroc={metrics['auroc']:.4f}, val_eer={metrics['eer']:.4f}")

    # Save best model by EER (lower is better), as in the paper
    if metrics["eer"] < best_val_eer:
        best_val_eer = metrics["eer"]
        state_to_save = model.module.state_dict() if hasattr(model, "module") else model.state_dict()
        torch.save(state_to_save, os.path.join(OUTPUT_DIR, "best_model.pt"))
        print("Saved new best model (by EER).")

print("Training complete.")

Epoch 1/5:   0%|          | 0/569 [00:00<?, ?it/s]



CouldntDecodeError: Decoding failed. ffmpeg returned error code: 1

Output from ffmpeg/avlib:

ffmpeg version 4.4.2-0ubuntu0.22.04.1 Copyright (c) 2000-2021 the FFmpeg developers
  built with gcc 11 (Ubuntu 11.2.0-19ubuntu1)
  configuration: --prefix=/usr --extra-version=0ubuntu0.22.04.1 --toolchain=hardened --libdir=/usr/lib/x86_64-linux-gnu --incdir=/usr/include/x86_64-linux-gnu --arch=amd64 --enable-gpl --disable-stripping --enable-gnutls --enable-ladspa --enable-libaom --enable-libass --enable-libbluray --enable-libbs2b --enable-libcaca --enable-libcdio --enable-libcodec2 --enable-libdav1d --enable-libflite --enable-libfontconfig --enable-libfreetype --enable-libfribidi --enable-libgme --enable-libgsm --enable-libjack --enable-libmp3lame --enable-libmysofa --enable-libopenjpeg --enable-libopenmpt --enable-libopus --enable-libpulse --enable-librabbitmq --enable-librubberband --enable-libshine --enable-libsnappy --enable-libsoxr --enable-libspeex --enable-libsrt --enable-libssh --enable-libtheora --enable-libtwolame --enable-libvidstab --enable-libvorbis --enable-libvpx --enable-libwebp --enable-libx265 --enable-libxml2 --enable-libxvid --enable-libzimg --enable-libzmq --enable-libzvbi --enable-lv2 --enable-omx --enable-openal --enable-opencl --enable-opengl --enable-sdl2 --enable-pocketsphinx --enable-librsvg --enable-libmfx --enable-libdc1394 --enable-libdrm --enable-libiec61883 --enable-chromaprint --enable-frei0r --enable-libx264 --enable-shared
  libavutil      56. 70.100 / 56. 70.100
  libavcodec     58.134.100 / 58.134.100
  libavformat    58. 76.100 / 58. 76.100
  libavdevice    58. 13.100 / 58. 13.100
  libavfilter     7.110.100 /  7.110.100
  libswscale      5.  9.100 /  5.  9.100
  libswresample   3.  9.100 /  3.  9.100
  libpostproc    55.  9.100 / 55.  9.100
[mp3 @ 0x584dde1b44c0] Format mp3 detected only with low score of 1, misdetection possible!
[mp3 @ 0x584dde1b44c0] Failed to read frame size: Could not seek to 1026.
/kaggle/working/dataset/tamil/fake/tamil_0330.mp3: Invalid argument


In [None]:
# Final evaluation on the held-out test set

# Load best model (by validation EER)
state = torch.load(os.path.join(OUTPUT_DIR, "best_model.pt"), map_location=DEVICE)
if hasattr(model, "module"):
    model.module.load_state_dict(state)
else:
    model.load_state_dict(state)
model.to(DEVICE)

metrics_test = evaluate(test_loader, desc="Test")
print("Test metrics:")
print(f"  Accuracy: {metrics_test['accuracy']:.4f}")
print(f"  AUROC:    {metrics_test['auroc']:.4f}")
print(f"  EER:      {metrics_test['eer']:.4f}")

In [None]:
# Save model + feature extractor for deployment

from transformers import Wav2Vec2Model

# Save backbone configuration and feature extractor alongside the fine-tuned head weights
backbone.save_pretrained(OUTPUT_DIR)
feature_extractor.save_pretrained(OUTPUT_DIR)

print(f"Saved backbone and feature extractor to {OUTPUT_DIR}")
print("For deployment, also keep 'best_model.pt' (classifier head + backbone weights).")

In [None]:
# Inference helper aligned with the problem statement (for API integration later)

import base64
import io
import numpy as np
import librosa
from pydub import AudioSegment


def load_model_for_inference(model_dir: str):
    """Reload the fine-tuned model + feature extractor for inference.

    This is what your API server should do on startup.
    """
    feat_extractor = AutoFeatureExtractor.from_pretrained(model_dir)
    backb = Wav2Vec2Model.from_pretrained(model_dir)
    det = W2VBertDeepfakeDetector(backb, num_labels=2)
    det.load_state_dict(torch.load(os.path.join(model_dir, "best_model.pt"), map_location=DEVICE))
    det.to(DEVICE)
    det.eval()
    return feat_extractor, det


def classify_base64_mp3(audio_base64: str, language: str, model_dir: str = OUTPUT_DIR):
    """Classify a single Base64-encoded MP3 as AI_GENERATED or HUMAN.

    This mirrors the problem statement:
    - Input: one Base64 MP3 and a language (Tamil/English/Hindi/Malayalam/Telugu)
    - Output: classification, confidence score, and a short explanation.
    """
    if language not in TARGET_LANGUAGES:
        raise ValueError(f"Unsupported language: {language}. Must be one of {TARGET_LANGUAGES}.")

    feature_extractor_inf, model_inf = load_model_for_inference(model_dir)

    # Decode base64 to raw bytes
    audio_bytes = base64.b64decode(audio_base64)
    buf = io.BytesIO(audio_bytes)

    # Decode MP3 bytes with pydub
    audio_segment = AudioSegment.from_file(buf, format="mp3")
    samples = np.array(audio_segment.get_array_of_samples()).astype(np.float32) / 32767.0
    sr = audio_segment.frame_rate

    # Resample to TARGET_SAMPLING_RATE if needed (librosa)
    if sr != TARGET_SAMPLING_RATE:
        samples = librosa.resample(samples, orig_sr=sr, target_sr=TARGET_SAMPLING_RATE)

    waveform = torch.from_numpy(samples).unsqueeze(0)
    waveform = random_crop(waveform, MAX_DURATION_SECONDS, TARGET_SAMPLING_RATE)

    # Run through model
    with torch.no_grad():
        features = feature_extractor_inf(
            waveform.squeeze(0),
            sampling_rate=TARGET_SAMPLING_RATE,
            return_tensors="pt",
            padding=True,
        )
        # For Wav2Vec2* models, the feature extractor outputs "input_values"
        input_vals = features["input_values"].to(DEVICE)
        attention_mask = features.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(DEVICE)

        outputs = model_inf.backbone(input_values=input_vals, attention_mask=attention_mask)
        hidden_states = outputs.last_hidden_state
        pooled = hidden_states.mean(dim=1)
        logits = model_inf.classifier(pooled)
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    prob_human = float(probs[0])
    prob_ai = float(probs[1])

    if prob_ai >= prob_human:
        classification = "AI_GENERATED"
        confidence = prob_ai
        explanation = (
            "Model detected spectral and temporal artifacts consistent with AI-synthesized speech, "
            "and these patterns remained distinguishable even under compression- and pitch/time-based "
            "corruptions similar to those studied in the robustness paper."
        )
    else:
        classification = "HUMAN"
        confidence = prob_human
        explanation = (
            "Signal characteristics (prosody, spectral detail, and temporal variation) align with "
            "human speech patterns and do not exhibit the codec- and modification-sensitive artifacts "
            "associated with AI deepfakes in the robustness study."
        )

    return {
        "status": "success",
        "language": language,
        "classification": classification,
        "confidenceScore": float(confidence),
        "explanation": explanation,
    }


print("Inference helpers defined. You can now wire this into your REST API.")