In [10]:
import os
import pandas as pd
import numpy as np
import librosa
import soundfile as sf
import torch
import torchaudio.transforms as T
import torch.nn.functional as F
import logging
import hashlib
from typing import List, Tuple, Optional
from pathlib import Path
import pyloudnorm as pyln
import random

# Configure notebook-wide logging: INFO level, each line prefixed with a
# timestamp and the level name.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

In [11]:

# --- Configuration Section ---
class DataConfig:
    """Central configuration for audio loading, feature extraction, augmentation
    and dataset locations.

    Every value is a class attribute read as ``DataConfig.<NAME>``. Several are
    also bound as function default arguments when later cells are defined, so
    edit them before running those cells.

    The two directory settings default to the original Windows locations but can
    be overridden without editing the notebook through the
    ``DEEPFAKE_DATASET_ROOT`` and ``DEEPFAKE_CACHE_DIR`` environment variables.
    """

    # General audio processing settings
    SEED = 42  # Random seed for reproducibility (random / numpy / torch)

    SR = 16000  # Sample rate (Hz) every clip is resampled to
    N_FFT = 2048  # FFT window size (samples)
    HOP_LENGTH = 512  # Hop length between STFT frames (samples)
    N_MELS = 128  # Number of Mel bands
    FMIN = 0.0  # Minimum Mel filter frequency (Hz)
    FMAX = 8000.0  # Maximum Mel filter frequency (Hz); equals the Nyquist limit SR / 2
    TARGET_SPEC_SIZE = (224, 224)  # (height, width) every spectrogram is resized to

    # Augmentation settings (SpecAugment is applied to the resized spectrogram)
    NUM_TIME_MASKS = 2  # Number of time masks for SpecAugment
    NUM_FREQ_MASKS = 2  # Number of frequency masks for SpecAugment
    TIME_MASK_MAX_WIDTH = 100  # Maximum width of a time mask
    FREQ_MASK_MAX_WIDTH = 40  # Maximum width of a frequency mask
    # dB fill value for the spectrogram of an empty waveform; it is NOT passed to
    # the SpecAugment masking transforms.
    MASK_REPLACEMENT_VALUE = -80.0
    NORM_EPSILON = 1e-6  # Small value to prevent division by zero
    LOUDNESS_LUFS = -23.0  # Target integrated loudness (LUFS) for normalization

    # Dataset and processing options
    # True: normalize with dataset-wide mean/std; False: per-spectrogram z-score.
    USE_GLOBAL_NORMALIZATION = False
    # Raw WAV tree, laid out as <DATASET_ROOT>/<split>/{real,fake}/*.wav
    DATASET_ROOT = Path(
        os.environ.get(
            "DEEPFAKE_DATASET_ROOT",
            "F:\\Deepfake-Audio-Detector\\datasets\\raw_final_dataset",
        )
    ).as_posix()
    # Output tree for the cached .npy spectrograms
    CACHE_DIR = Path(
        os.environ.get(
            "DEEPFAKE_CACHE_DIR",
            "F:\\Deepfake-Audio-Detector\\datasets\\final_dataset",
        )
    ).as_posix()


In [12]:
# Seed every RNG the pipeline draws from (Python, NumPy, PyTorch) so segment
# sampling and augmentations are reproducible between runs.
_seed_value = DataConfig.SEED
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(_seed_value)

_cuda_ready = torch.cuda.is_available()
if _cuda_ready:
    # Seed every GPU and pin cuDNN to deterministic kernels (no autotuning).
    torch.cuda.manual_seed_all(_seed_value)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    logging.info("CUDA available. Applied CUDA seeds.")

2025-06-07 04:14:26,288 - INFO - CUDA available. Applied CUDA seeds.


In [13]:
class DataLoaderConfig:
    """Options for building DataLoaders around a fixed clip length.

    All constructor arguments are stored unchanged as attributes. The class also
    derives ``max_frame_spec``: the number of ``DataConfig.HOP_LENGTH`` hops
    needed to span ``audio_length_seconds`` at ``DataConfig.SR``, rounded up.
    """

    def __init__(
        self,
        audio_length_seconds: float,
        batch_size: int,
        num_workers: int = 12,
        apply_augmentation_to_train: bool = True,
        apply_waveform_augmentation: bool = True,
        limit_files: Optional[int] = None,
        overlap_ratio: float = 0.0,
    ):
        # Caller-supplied options, kept as-is.
        self.audio_length_seconds = audio_length_seconds
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.apply_augmentation_to_train = apply_augmentation_to_train
        self.apply_waveform_augmentation = apply_waveform_augmentation
        self.limit_files = limit_files
        self.overlap_ratio = overlap_ratio

        # Derived: hop count covering the clip, rounded up to a whole frame.
        clip_samples = audio_length_seconds * DataConfig.SR
        hop_count = clip_samples / DataConfig.HOP_LENGTH
        self.max_frame_spec = int(np.ceil(hop_count))

        logging.info(
            "DataLoaderConfig initialized with "
            f"audio_length_seconds={audio_length_seconds}, "
            f"max_frame_spec={self.max_frame_spec} frames, "
            f"(SR={DataConfig.SR}, HOP_LENGTH={DataConfig.HOP_LENGTH})"
        )


In [14]:
# --- Helper Functions ---
def _load_and_segment_audio(
    file_path: Optional[str],
    sr: int = DataConfig.SR,
    segment_length: float = 5.0,
    overlap_ratio: float = 0.0,
) -> List[np.ndarray]:
    """Load a clip as mono, loudness-normalize it and cut it into fixed-length segments.

    Args:
        file_path: Path of the audio file. ``None`` yields an empty list.
        sr: Target sample rate; the file is resampled on load.
        segment_length: Segment duration in seconds.
        overlap_ratio: Fraction of a segment shared by consecutive segments. The hop
            is ``int(segment_samples * (1 - overlap_ratio))`` samples, at least 1.

    Returns:
        A list of 1-D arrays of exactly ``int(segment_length * sr)`` samples.
        A clip shorter than one segment is zero-padded into a single segment. For
        longer clips, trailing samples that do not fill a whole segment are dropped.
        The list is empty when the clip is silent, its loudness cannot be measured,
        or loading/normalization fails. Failures are logged, not raised.
    """
    if file_path is None:
        logging.error("Error processing None: no file path provided")
        return []

    segment_samples = int(segment_length * sr)
    if segment_samples <= 0:
        logging.error(
            f"Error processing {file_path}: non-positive segment length {segment_length}"
        )
        return []

    try:
        y, _ = librosa.load(str(file_path), sr=sr, mono=True)

        # Check for silence on the RAW signal, before normalization. Scaling a
        # (near-)silent clip up to the target loudness produces inf/NaN samples,
        # and NaN would slip past a post-normalization `< 1e-5` check.
        if y.size == 0 or np.abs(y).max() < 1e-5:
            logging.warning(f"Silent audio detected: {file_path}")
            return []

        meter = pyln.Meter(sr)
        loudness = meter.integrated_loudness(y)
        # Integrated loudness is -inf when every block falls below the BS.1770
        # absolute gate. The normalization gain would then be infinite, so treat
        # the clip as silent instead.
        if not np.isfinite(loudness):
            logging.warning(f"Silent audio detected: {file_path}")
            return []
        y = pyln.normalize.loudness(y, loudness, DataConfig.LOUDNESS_LUFS)

        if len(y) < segment_samples:
            return [np.pad(y, (0, segment_samples - len(y)), "constant")]

        step_size = max(1, int(segment_samples * (1 - overlap_ratio)))
        return [
            y[start : start + segment_samples]
            for start in range(0, len(y) - segment_samples + 1, step_size)
        ]
    except Exception as e:
        logging.error(f"Error processing {file_path}: {e}")
        return []


def _scale_mel_spectrogram(
    mel_spec: np.ndarray, target_size: Tuple[int, int] = DataConfig.TARGET_SPEC_SIZE
) -> np.ndarray:
    """Resize a 2-D spectrogram to ``target_size`` using bilinear interpolation.

    Args:
        mel_spec: 2-D array, (n_mels, frames).
        target_size: Output (height, width); defaults to DataConfig.TARGET_SPEC_SIZE.

    Returns:
        A float32 array of shape ``target_size``.
    """
    # F.interpolate wants (batch, channel, H, W): prepend two singleton axes.
    batched = torch.from_numpy(mel_spec).float()[None, None]
    resized = F.interpolate(
        batched, size=target_size, mode="bilinear", align_corners=False
    )
    # Strip the singleton batch and channel axes again.
    return resized[0, 0].numpy()


def _audio_to_mel_spectrogram(
    y: np.ndarray,
    sr: int = DataConfig.SR,
    n_fft: int = DataConfig.N_FFT,
    hop_length: int = DataConfig.HOP_LENGTH,
    n_mels: int = DataConfig.N_MELS,
    fmin: float = DataConfig.FMIN,
    fmax: float = DataConfig.FMAX,
) -> np.ndarray:
    """Turn a waveform into a dB-scaled Mel-spectrogram resized to TARGET_SPEC_SIZE.

    The power spectrogram is converted to dB relative to its own maximum
    (``ref=np.max``), then bilinearly resized by ``_scale_mel_spectrogram``.
    An empty or ``None`` waveform yields a constant array filled with
    ``DataConfig.MASK_REPLACEMENT_VALUE``.
    """
    expected_shape = DataConfig.TARGET_SPEC_SIZE

    if y is None or not len(y):
        return np.full(
            expected_shape, DataConfig.MASK_REPLACEMENT_VALUE, dtype=np.float32
        )

    power_mel = librosa.feature.melspectrogram(
        y=y,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        fmin=fmin,
        fmax=fmax,
    )
    db_mel = librosa.power_to_db(power_mel, ref=np.max)

    resized = _scale_mel_spectrogram(db_mel)
    assert resized.shape == expected_shape, (
        f"Unexpected shape {resized.shape}, expected {expected_shape}"
    )
    logging.debug(f"Scaled spectrogram to {expected_shape}")
    return resized


def _compute_global_stats(
    filepaths: List[Optional[str]], segment_length: float, max_frames_spec: int
) -> Tuple[float, float]:
    """Estimate the dataset-wide mean and standard deviation of dB Mel-spectrograms.

    Up to three randomly chosen segments per file (drawn with the global
    ``random`` module) go through ``_audio_to_mel_spectrogram``. Every resulting
    spectrogram has the same fixed size, so pooled statistics over all their
    values follow from per-segment means and mean squares.

    Args:
        filepaths: Audio file paths. Files that yield no segments are skipped.
        segment_length: Segment duration in seconds.
        max_frames_spec: Unused. Kept for interface compatibility, because
            spectrograms are resized to DataConfig.TARGET_SPEC_SIZE anyway.

    Returns:
        ``(mean, std + DataConfig.NORM_EPSILON)`` over all pooled values. If no
        segment could be processed, logs a warning and returns the identity
        normalization ``(0.0, 1.0)`` instead of NaNs.
    """
    seg_means, seg_sq_means = [], []
    for file_path in filepaths:
        segments = _load_and_segment_audio(file_path, segment_length=segment_length)
        if not segments:
            continue
        for seg in random.sample(segments, min(3, len(segments))):
            mel_spec = _audio_to_mel_spectrogram(seg).astype(np.float64)
            seg_means.append(mel_spec.mean())
            seg_sq_means.append(np.square(mel_spec).mean())

    if not seg_means:
        logging.warning(
            "No segments available to compute global stats; using mean=0.0, std=1.0."
        )
        return 0.0, 1.0

    global_mean = float(np.mean(seg_means))
    # Pooled variance Var = E[x^2] - E[x]^2. Averaging per-segment stds would
    # ignore the variance *between* segments and underestimate the global std.
    # Clamp tiny negative values caused by floating-point rounding.
    global_var = max(float(np.mean(seg_sq_means)) - global_mean**2, 0.0)
    return global_mean, float(np.sqrt(global_var) + DataConfig.NORM_EPSILON)


In [15]:
class SpecAugment(torch.nn.Module):
    """SpecAugment: random frequency masking followed by random time masking.

    Applies ``DataConfig.NUM_FREQ_MASKS`` frequency masks (width up to
    ``DataConfig.FREQ_MASK_MAX_WIDTH``), then ``DataConfig.NUM_TIME_MASKS`` time
    masks (width up to ``DataConfig.TIME_MASK_MAX_WIDTH``). It uses torchaudio's
    masking transforms with their default fill value.
    """

    def __init__(self):
        super().__init__()
        self.freq_mask = T.FrequencyMasking(
            freq_mask_param=DataConfig.FREQ_MASK_MAX_WIDTH
        )
        self.time_mask = T.TimeMasking(time_mask_param=DataConfig.TIME_MASK_MAX_WIDTH)
        self.num_freq_masks = DataConfig.NUM_FREQ_MASKS
        self.num_time_masks = DataConfig.NUM_TIME_MASKS

    def forward(self, spec: torch.Tensor) -> torch.Tensor:
        """Mask a spectrogram whose last two axes are (freq, time).

        Shape handling:
            * (F, T)       -> (1, F, T)    (a channel axis is added)
            * (1, F, T)    -> (1, F, T)
            * (B, F, T)    -> (B, F, T)
            * (B, 1, F, T) -> (B, 1, F, T) (the channel axis is preserved)
            * (B, C, F, T) -> (B, C, F, T)
        Any other rank is returned unchanged with a warning. The caller's tensor
        is never modified.
        """
        if spec.ndim not in (2, 3, 4):
            logging.warning(
                f"Unexpected spectrogram shape: {spec.shape}. Skipping augmentation."
            )
            return spec

        # Remember whether a (B, 1, F, T) channel axis is squeezed so it can be
        # restored on the way out.
        restore_channel = spec.ndim == 4 and spec.shape[1] == 1
        if restore_channel:
            spec = spec.squeeze(1)
        elif spec.ndim == 3 and spec.shape[0] == 1:
            spec = spec.squeeze(0)
        # Work on a copy in case the installed torchaudio masks in place.
        spec = spec.clone()

        # torchaudio's masking transforms operate on the last two axes for any
        # leading shape, so the registered transforms serve every supported rank.
        for _ in range(self.num_freq_masks):
            spec = self.freq_mask(spec)
        for _ in range(self.num_time_masks):
            spec = self.time_mask(spec)

        if restore_channel:
            return spec.unsqueeze(1)
        if spec.ndim == 2:
            return spec.unsqueeze(0)
        return spec


class WaveformAugment:
    """Random waveform-level augmentation: additive noise, pitch shift, time stretch.

    Each perturbation is applied independently with probability 0.3, always in
    that order. Coin flips and the stretch rate come from Python's ``random``
    module; the noise samples come from torch's global RNG.
    """

    def __init__(self):
        self.sr = DataConfig.SR
        # Built once and reused: shifts pitch by a fixed +2 semitones.
        self.pitch_shift = T.PitchShift(sample_rate=self.sr, n_steps=2)

    def apply(self, y: np.ndarray) -> np.ndarray:
        """Return a randomly augmented float32 copy of the 1-D waveform ``y``.

        The time-stretch step zero-pads or truncates its output back to ``len(y)``.
        """
        target_len = len(y)
        wave = torch.from_numpy(y).float()

        # Additive Gaussian noise with standard deviation 0.005.
        if random.random() < 0.3:
            wave = wave + torch.randn_like(wave) * 0.005

        # Pitch shift; the transform expects a leading channel axis.
        if random.random() < 0.3:
            wave = self.pitch_shift(wave[None])[0]

        # Time stretch by a random factor in [0.8, 1.2], then restore the length.
        if random.random() < 0.3:
            stretch_rate = random.uniform(0.8, 1.2)
            stretched = librosa.effects.time_stretch(
                wave.detach().numpy(), rate=stretch_rate
            )[:target_len]
            shortfall = target_len - len(stretched)
            if shortfall > 0:
                stretched = np.pad(stretched, (0, shortfall), "constant")
            wave = torch.from_numpy(stretched).float()

        return wave.detach().cpu().numpy()


In [16]:
class ModelConfig:
    """Per-model settings that control how a cached dataset is built.

    Holds the model name (used for the cache sub-directory), segment length,
    segment overlap and the two train-time augmentation switches. It also
    derives ``max_frames_spec``: the number of ``DataConfig.HOP_LENGTH`` hops
    covering one segment at ``DataConfig.SR``, rounded up.
    """

    def __init__(
        self,
        name: str,
        audio_length_seconds: float,
        overlap_ratio: float,
        apply_augmentation: bool = False,
        apply_waveform_augmentation: bool = False,
    ):
        self.name = name
        self.audio_length_seconds = audio_length_seconds
        self.overlap_ratio = overlap_ratio
        self.apply_augmentation = apply_augmentation
        self.apply_waveform_augmentation = apply_waveform_augmentation

        segment_samples = audio_length_seconds * DataConfig.SR
        self.max_frames_spec = int(np.ceil(segment_samples / DataConfig.HOP_LENGTH))

        summary = (
            f"ModelConfig for {name}: audio_length={audio_length_seconds}s, "
            f"max_frames={self.max_frames_spec}, overlap_ratio={overlap_ratio}"
        )
        logging.info(summary)


In [17]:
class DatasetCreator:
    """Builds per-model caches of normalized Mel-spectrogram segments.

    For every ``ModelConfig`` and every split (train/val/test), the creator:

    * reads, or first generates, a metadata CSV describing the raw WAV files
      under ``DataConfig.DATASET_ROOT/<split>/{real,fake}``;
    * cuts each file into segments and converts them to normalized spectrograms;
    * saves each spectrogram as ``.npy`` under
      ``DataConfig.CACHE_DIR/<model>_dataset/<split>/<label>/``;
    * writes a per-split ``metadata.csv`` index.

    Note: when enabled in the ModelConfig, waveform augmentation and SpecAugment
    are applied to the train split once, at caching time. Each cached train
    sample is therefore a single frozen random draw, not fresh per-epoch
    augmentation.
    """

    def __init__(self, model_configs: List[ModelConfig]):
        from datetime import datetime

        self.model_configs = model_configs
        self.label_mapping = {"real": 0, "fake": 1}
        self.spec_augmenter = SpecAugment()
        self.waveform_augmenter = WaveformAugment()
        # One timestamp stamped on every metadata row this instance produces.
        self.creation_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def generate_metadata(self, set_type: str) -> pd.DataFrame:
        """Scan ``DATASET_ROOT/<set_type>/{fake,real}/*.wav`` and save a metadata CSV.

        Files that cannot be read, have zero frames, or whose first second has a
        peak below 1e-5 are skipped with a log message. The CSV is written to
        ``DATASET_ROOT/<set_type>/combined_metadata_<set_type>.csv``.

        Returns:
            DataFrame with columns path (POSIX, relative to DATASET_ROOT), label,
            fake_level, creation_time, audio_duration and hash_value. It is empty
            when no valid file was found, in which case no CSV is written.
        """
        base_dir = Path(DataConfig.DATASET_ROOT) / set_type
        metadata_records = []

        for label_dir in ["fake", "real"]:
            label_path = base_dir / label_dir
            if not label_path.exists():
                logging.warning(f"Directory not found: {label_path}")
                continue

            # Sorted so row order, and with it the downstream RNG consumption,
            # does not depend on filesystem enumeration order.
            for wav_file in sorted(label_path.glob("*.wav")):
                relative_path = wav_file.relative_to(DataConfig.DATASET_ROOT).as_posix()
                try:
                    info = sf.info(wav_file.as_posix())
                    audio_duration = info.duration
                    if info.frames == 0:
                        logging.warning(f"Empty or silent audio detected: {wav_file}")
                        continue
                    # Cheap silence probe: only the first second is decoded.
                    first_second, _ = librosa.load(
                        wav_file.as_posix(), sr=DataConfig.SR, duration=1.0
                    )
                    if np.abs(first_second).max() < 1e-5:
                        logging.warning(f"Empty or silent audio detected: {wav_file}")
                        continue
                except Exception as e:
                    logging.error(f"Error processing {wav_file}: {e}")
                    continue

                file_hash = hashlib.md5(wav_file.as_posix().encode()).hexdigest()
                metadata_records.append(
                    {
                        "path": relative_path,
                        "label": label_dir,
                        # Assumption: fake_level is 0 for real and 1 for fake.
                        "fake_level": 0 if label_dir == "real" else 1,
                        "creation_time": self.creation_time,
                        "audio_duration": audio_duration,
                        "hash_value": file_hash,
                    }
                )

        if not metadata_records:
            logging.warning(f"No valid WAV files found in {base_dir}")
            return pd.DataFrame()

        metadata_df = pd.DataFrame(metadata_records)
        metadata_path = base_dir / f"combined_metadata_{set_type}.csv"
        metadata_df.to_csv(metadata_path.as_posix(), index=False)
        logging.info(
            f"Generated and saved metadata to {metadata_path} with {len(metadata_df)} samples"
        )
        return metadata_df

    def load_metadata(self, set_type: str) -> pd.DataFrame:
        """Load the metadata CSV for ``set_type`` (train/val/test).

        If the CSV does not exist yet, it is generated via ``generate_metadata``.
        Rows with a missing ``path`` are dropped.
        """
        metadata_path = (
            Path(DataConfig.DATASET_ROOT)
            / set_type
            / f"combined_metadata_{set_type}.csv"
        ).as_posix()
        if not os.path.exists(metadata_path):
            logging.info(
                f"Metadata file not found at {metadata_path}. Generating metadata..."
            )
            return self.generate_metadata(set_type)
        df = pd.read_csv(metadata_path).dropna(subset=["path"]).reset_index(drop=True)
        logging.info(f"Loaded {len(df)} samples from {metadata_path}")
        return df

    def validate_and_get_full_path(
        self, set_type: str, audio_path_relative: str
    ) -> Optional[str]:
        """Resolve ``audio_path_relative`` under DATASET_ROOT and check it is usable.

        Args:
            set_type: Unused; kept for interface compatibility.
            audio_path_relative: Path relative to DATASET_ROOT.

        Returns:
            The full POSIX path, or ``None`` (with a log message) when the file is
            missing, is not .wav/.flac, has zero frames, has a silent first second
            (peak < 1e-5), or cannot be read.
        """
        full_path = (Path(DataConfig.DATASET_ROOT) / audio_path_relative).as_posix()

        if not os.path.exists(full_path):
            logging.warning(f"Audio file not found: {full_path}")
            return None

        if not full_path.lower().endswith((".wav", ".flac")):
            logging.warning(f"Unsupported audio format: {full_path}")
            return None

        try:
            info = sf.info(full_path)
            if info.frames == 0:
                logging.warning(f"Empty audio file: {full_path}")
                return None
            y_check, _ = librosa.load(
                full_path, sr=DataConfig.SR, mono=True, duration=1.0
            )
            if np.abs(y_check).max() < 1e-5:
                logging.warning(f"Silent audio detected: {full_path}")
                return None
        except Exception as e:
            logging.error(f"Error validating {full_path}: {e}")
            return None

        return full_path

    def _valid_full_paths(self, set_type: str, metadata_df: pd.DataFrame) -> List[str]:
        """Return full paths of metadata rows that pass validation (each validated once)."""
        if metadata_df.empty:
            return []
        full_paths = []
        for audio_path in metadata_df["path"]:
            full_path = self.validate_and_get_full_path(set_type, audio_path)
            if full_path is not None:
                full_paths.append(full_path)
        return full_paths

    def _fit_global_stats(self, model_config: ModelConfig) -> Tuple[float, float]:
        """Fit global normalization statistics on the *train* split only.

        val/test are normalized with these same statistics. Fitting on held-out
        splits would leak information and make their inputs differ from what the
        model sees during training. Returns ``(0.0, 1.0)`` if train has no usable
        files.
        """
        filepaths = self._valid_full_paths("train", self.load_metadata("train"))
        if not filepaths:
            logging.warning("No valid audio files to compute global stats for train.")
            return 0.0, 1.0
        global_mean, global_std = _compute_global_stats(
            filepaths,
            model_config.audio_length_seconds,
            model_config.max_frames_spec,
        )
        logging.info(
            f"Global stats (train): Mean={global_mean:.4f}, Std={global_std:.4f}"
        )
        return global_mean, global_std

    @staticmethod
    def _normalize(
        spec: torch.Tensor, global_mean: float, global_std: float
    ) -> torch.Tensor:
        """Z-score a spectrogram with the global stats, or with its own stats when
        ``DataConfig.USE_GLOBAL_NORMALIZATION`` is False."""
        if DataConfig.USE_GLOBAL_NORMALIZATION:
            return (spec - global_mean) / global_std
        return (spec - spec.mean()) / (spec.std() + DataConfig.NORM_EPSILON)

    def _cache_split(
        self,
        model_config: ModelConfig,
        cache_root: Path,
        set_type: str,
        metadata_df: pd.DataFrame,
        global_mean: float,
        global_std: float,
    ) -> None:
        """Segment, featurize, normalize (and for train optionally augment) one
        split, then save the .npy files and the split's metadata.csv."""
        set_cache_dir = cache_root / set_type
        os.makedirs(set_cache_dir.as_posix(), exist_ok=True)
        is_train = set_type == "train"
        metadata_records = []

        for _, row in metadata_df.iterrows():
            audio_path_relative_in_csv = row["path"]
            label_str = row["label"]
            # Skip rather than crash the whole run on a malformed label.
            if label_str not in self.label_mapping:
                logging.warning(
                    f"Unknown label '{label_str}' for sample {audio_path_relative_in_csv}. Skipping."
                )
                continue

            try:
                fake_level = int(row.get("fake_level", 0))
            except (ValueError, TypeError):
                logging.warning(
                    f"Invalid fake_level '{row.get('fake_level', 'N/A')}' for sample {audio_path_relative_in_csv}. Using default 0."
                )
                fake_level = 0

            full_path = self.validate_and_get_full_path(
                set_type, audio_path_relative_in_csv
            )
            if not full_path:
                continue

            audio_duration = sf.info(full_path).duration

            segments = _load_and_segment_audio(
                full_path,
                segment_length=model_config.audio_length_seconds,
                overlap_ratio=model_config.overlap_ratio,
            )
            if not segments:
                continue

            # label_str is guaranteed to be "real" or "fake" here; it doubles as
            # the output sub-directory name.
            sample_cache_dir = set_cache_dir / label_str
            os.makedirs(sample_cache_dir.as_posix(), exist_ok=True)

            for seg_idx, seg in enumerate(segments):
                if model_config.apply_waveform_augmentation and is_train:
                    seg = self.waveform_augmenter.apply(seg)

                mel_spec_tensor = torch.from_numpy(_audio_to_mel_spectrogram(seg)).float()
                mel_spec_tensor = self._normalize(mel_spec_tensor, global_mean, global_std)

                if model_config.apply_augmentation and is_train:
                    # SpecAugment maps (1, F, T) -> (1, F, T); drop the channel axis.
                    mel_spec_tensor = self.spec_augmenter(
                        mel_spec_tensor.unsqueeze(0)
                    ).squeeze(0)

                file_hash = hashlib.md5(f"{full_path}_{seg_idx}".encode()).hexdigest()
                npy_path = sample_cache_dir / f"{file_hash}.npy"
                if not npy_path.exists():
                    np.save(npy_path.as_posix(), mel_spec_tensor.numpy())

                metadata_records.append(
                    {
                        "npy_path": npy_path.as_posix(),
                        "original_path": audio_path_relative_in_csv,
                        "label": self.label_mapping[label_str],
                        "fake_level": fake_level,
                        "segment_index": seg_idx,
                        "creation_time": self.creation_time,
                        "audio_duration": audio_duration,
                        "hash_value": file_hash,
                    }
                )

        if metadata_records:
            metadata_df_processed = pd.DataFrame(metadata_records)
            metadata_output_path = set_cache_dir / "metadata.csv"
            metadata_df_processed.to_csv(metadata_output_path.as_posix(), index=False)
            logging.info(
                f"Saved {len(metadata_df_processed)} samples to {metadata_output_path}"
            )
        else:
            logging.warning(
                f"No processed samples for {set_type} for {model_config.name}"
            )

    def create_cached_datasets(self):
        """Create the cached train/val/test datasets for every model configuration."""
        for model_config in self.model_configs:
            cache_root = Path(DataConfig.CACHE_DIR) / f"{model_config.name}_dataset"
            os.makedirs(cache_root.as_posix(), exist_ok=True)

            # Fit once per model (segment length affects the stats) and reuse for
            # every split.
            global_mean, global_std = 0.0, 1.0
            if DataConfig.USE_GLOBAL_NORMALIZATION:
                global_mean, global_std = self._fit_global_stats(model_config)

            for set_type in ["train", "val", "test"]:
                logging.info(f"Processing {set_type} set for {model_config.name}")
                metadata_df = self.load_metadata(set_type)
                if metadata_df.empty:
                    continue
                self._cache_split(
                    model_config,
                    cache_root,
                    set_type,
                    metadata_df,
                    global_mean,
                    global_std,
                )


In [18]:
logging.info("Starting dataset caching process...")

cache_dir = Path(DataConfig.CACHE_DIR)
if cache_dir.exists():
    import shutil

    # The whole cache is wiped on every run. Guard against a misconfigured
    # CACHE_DIR that is (or contains) the raw dataset, which would delete it.
    resolved_cache = cache_dir.resolve()
    resolved_raw = Path(DataConfig.DATASET_ROOT).resolve()
    if resolved_cache == resolved_raw or resolved_cache in resolved_raw.parents:
        raise RuntimeError(
            f"Refusing to delete {cache_dir}: it is or contains DATASET_ROOT "
            f"({DataConfig.DATASET_ROOT})."
        )
    shutil.rmtree(cache_dir.as_posix())
    logging.info(f"Removed existing {cache_dir} directory.")
os.makedirs(cache_dir.as_posix(), exist_ok=True)

# NOTE: both configs currently share identical settings; each is cached
# separately under <name>_dataset, so the caching work is done twice.
vit_config = ModelConfig(
    name="vit_3s",
    audio_length_seconds=3.0,
    overlap_ratio=0.5,
    apply_augmentation=True,
    apply_waveform_augmentation=True,
)

cnn_config = ModelConfig(
    name="cnn_3s",
    audio_length_seconds=3.0,
    overlap_ratio=0.5,
    apply_augmentation=True,
    apply_waveform_augmentation=True,
)

model_configurations = [
    vit_config,
    cnn_config,
]

creator = DatasetCreator(model_configurations)
creator.create_cached_datasets()

logging.info("Dataset caching process completed.")


2025-06-07 04:14:26,413 - INFO - Starting dataset caching process...
2025-06-07 04:14:26,414 - INFO - Removed existing F:\Deepfake-Audio-Detector\datasets\final_dataset directory.
2025-06-07 04:14:26,415 - INFO - ModelConfig for vit_3s: audio_length=3.0s, max_frames=94, overlap_ratio=0.5
2025-06-07 04:14:26,416 - INFO - ModelConfig for cnn_3s: audio_length=3.0s, max_frames=94, overlap_ratio=0.5
2025-06-07 04:14:26,417 - INFO - Processing train set for vit_3s
2025-06-07 04:14:26,418 - INFO - Metadata file not found at F:/Deepfake-Audio-Detector/datasets/raw_final_dataset/train/combined_metadata_train.csv. Generating metadata...
2025-06-07 04:15:54,174 - INFO - Generated and saved metadata to F:\Deepfake-Audio-Detector\datasets\raw_final_dataset\train\combined_metadata_train.csv with 102896 samples
2025-06-07 05:53:21,455 - INFO - Saved 102896 samples to F:\Deepfake-Audio-Detector\datasets\final_dataset\vit_3s_dataset\train\metadata.csv
2025-06-07 05:53:21,455 - INFO - Processing val set

In [19]:
# Count cached spectrogram files per dataset/split/label. Reuses the configured
# cache location instead of re-typing the path.
BASE_DIR = DataConfig.CACHE_DIR

for dataset in ["cnn_3s_dataset", "vit_3s_dataset"]:
    for set_type in ["train", "test", "val"]:
        split_dir = os.path.join(BASE_DIR, dataset, set_type)
        counts = {}
        for label in ("fake", "real"):
            label_dir = os.path.join(split_dir, label)
            # A missing label directory means nothing was cached for it.
            counts[label] = len(os.listdir(label_dir)) if os.path.isdir(label_dir) else 0
        print(f"{dataset.upper()} \\ {set_type} have {counts['real']} real files and {counts['fake']} fake files")

CNN_3S_DATASET \ train have 51448 real files and 51448 fake files
CNN_3S_DATASET \ test have 7033 real files and 7033 fake files
CNN_3S_DATASET \ val have 3498 real files and 3498 fake files
VIT_3S_DATASET \ train have 51448 real files and 51448 fake files
VIT_3S_DATASET \ test have 7033 real files and 7033 fake files
VIT_3S_DATASET \ val have 3498 real files and 3498 fake files
