# ASVspoof DF Evaluation Notebook

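This notebook scores your own audio files with a pretrained deepfake (DF) detection model.  
The list of files to score is either built by scanning a dataset directory (optional Cell 1, WAV and FLAC) or read from `yourtts_wav_files.txt`, which contains one full path per line.  
The `Model` class is imported from the original SSL_Anti-spoofing repository without modifying any of its files; the data-loading utilities are defined in this notebook.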

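**Steps:**
1. (Optional) Build the audio file list by scanning a dataset directory.
2. Import dependencies and define utility functions.
3. Define a custom dataset to load the audio files.
4. Load the file list from `yourtts_wav_files.txt` (skipped if step 1 was run).
5. Create a DataLoader.
6. Load the pretrained model and set it to evaluation mode.
7. Run inference and collect scores (saved to `./outputs/df_detection_scores.pkl`).
8. Plot the score distribution against the decision threshold.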

In [None]:
# Optional Cell: Generate WAV file list from directory
import os

def get_audio_files_recursive(
    folder_path,
    formats=('wav', 'flac'),
    exclude_prefix=('.', '__'),
    required_path_segment=None
):
    """
    Recursively find audio files under a directory, optionally filtered by path.

    Args:
        folder_path (str): Root directory to search.
        formats (tuple or str): File extensions to include. Matching is
            case-insensitive and the leading dot is optional, so 'wav',
            '.WAV' and ('wav', 'flac') are all accepted.
        exclude_prefix (tuple, list or str): Subdirectories whose name starts
            with any of these prefixes are not descended into. The root
            directory itself is never skipped.
        required_path_segment (str, optional):
            Filter on each file's path.
            - If it contains a path separator ('/' or os.sep), it is treated
              as a directory segment: it is normalized to "<sep>name<sep>" and
              a file is kept only if that segment occurs in its normalized
              path (also when it is the first component of a relative path).
            - Otherwise it is matched as a plain substring of the path.
            Examples: '/test/', os.sep + 'test' + os.sep, or 'test_file'.

    Returns:
        list: Paths (os.path.join(root, name), as produced by os.walk) of the
        matching audio files, in os.walk order.
    """
    # Accept a single extension string as well as a collection, and compare
    # case-insensitively without the leading dot.
    if isinstance(formats, str):
        formats = (formats,)
    wanted_exts = {fmt.lower().lstrip('.') for fmt in formats}

    # str.startswith only accepts a str or a tuple, so coerce lists/sets.
    if not isinstance(exclude_prefix, str):
        exclude_prefix = tuple(exclude_prefix)

    dir_segment = None  # "<sep>name<sep>" when filtering on a directory
    substring = None    # raw text when filtering on a plain substring
    if required_path_segment:
        separators = {'/', os.sep}
        if os.altsep:
            separators.add(os.altsep)
        if any(sep in required_path_segment for sep in separators):
            # normpath converts '/' to os.sep on Windows and collapses
            # duplicate separators; strip and re-add them so the result is
            # always exactly "<sep>name<sep>".
            core = os.path.normpath(required_path_segment).strip(os.sep)
            if core:
                dir_segment = os.sep + core + os.sep
        else:
            substring = required_path_segment

    valid_files = []
    for root, dirs, files in os.walk(folder_path):
        # Prune in place so os.walk does not descend into excluded dirs.
        dirs[:] = [d for d in dirs if not d.startswith(exclude_prefix)]

        for file_name in files:
            ext = os.path.splitext(file_name)[1][1:].lower()  # ext without dot
            if ext not in wanted_exts:
                continue

            full_path = os.path.join(root, file_name)
            norm_path = os.path.normpath(full_path)

            if dir_segment is not None:
                # Prefix a separator so a segment that is the first component
                # of a relative path (e.g. 'test/x.wav') still matches.
                if dir_segment not in os.sep + norm_path:
                    continue
            elif substring is not None and substring not in norm_path:
                continue

            valid_files.append(full_path)

    return valid_files

# Root directory of the synthetic datasets to scan for audio files.
input_folder = './data/dataset_sintetico/'

# --- Path filter configuration ---
# A value containing os.sep (e.g. '/test/' on POSIX) is treated by
# get_audio_files_recursive as a directory segment: only files with a
# directory of that name somewhere in their path are kept. A value without a
# separator (e.g. 'test_file') is matched as a plain substring of the path.
# None disables filtering. Alternatives:
#   required_segment_filter = os.sep + 'test' + os.sep
#   required_segment_filter = None
required_segment_filter = '/test/'
# --- End configuration ---

file_paths = get_audio_files_recursive(
    input_folder,
    formats=('wav', 'flac'),
    required_path_segment=required_segment_filter,
)

# Summarise which extensions were actually found (e.g. ".flac, .wav").
formats_str = (
    ', '.join(sorted({os.path.splitext(path)[1] for path in file_paths}))
    if file_paths
    else 'N/A'
)

if not required_segment_filter:
    print(
        f"Found {len(file_paths)} audio files "
        f"(no path segment filter applied) ({formats_str} formats)"
    )
else:
    # Show '/test/' as 'test' for readability.
    display_filter = required_segment_filter
    if os.sep in display_filter:
        display_filter = display_filter.strip(os.sep)
    print(
        f"Found {len(file_paths)} audio files with paths containing "
        f"'{display_filter}' ({formats_str} formats)"
    )

In [None]:
# Cell 2: Imports and utility functions

import os
import torch
import librosa
import numpy as np
from torch.utils.data import Dataset, DataLoader

# Utility function to pad or truncate an audio signal to a fixed length.
def pad(x, desired_length):
    """Return `x` trimmed or right-padded with zeros to `desired_length` samples.

    Signals at least `desired_length` long are truncated, keeping their start.
    Shorter ones get trailing zeros via np.pad (mode='constant').

    NOTE(review): confirm this zero-padding matches how the model's training
    data was padded; a different scheme could shift scores for short clips.
    """
    shortfall = desired_length - len(x)
    if shortfall > 0:
        return np.pad(x, (0, shortfall), mode='constant')
    return x[:desired_length]

In [None]:
# Cell 3: Define a custom dataset to load WAV files

class CustomWavDataset(Dataset):
    """Map-style dataset yielding fixed-length waveforms and string IDs.

    Each item is ``(signal_tensor, utt_id)``:
    - ``signal_tensor`` is a float32 tensor built from the loaded signal after
      ``pad(signal, cut)``. This gives exactly ``cut`` samples, assuming
      librosa returns a 1-D (mono) array — TODO confirm.
    - ``utt_id`` is the file's path relative to ``base_folder``, or its
      basename when the file is outside ``base_folder``.
    """

    def __init__(self, file_list, base_folder='./data/dataset_sintetico/', sr=16000, cut=64600):
        """
        Args:
            file_list (list): Paths to the audio files to load.
            base_folder (str): Folder stripped from each path to build its ID.
            sr (int): Target sampling rate passed to librosa.load.
            cut (int): Fixed length (in samples) of each returned waveform.
        """
        self.file_list = file_list
        self.base_folder = os.path.normpath(base_folder)
        self.sr = sr
        self.cut = cut

    def __len__(self):
        return len(self.file_list)

    def _relative_id(self, file_path):
        """Return `file_path` relative to `self.base_folder`, else its basename.

        Both paths are made absolute first, so relative and absolute inputs
        compare correctly. Containment is checked per path component, so
        'data/set2/x.wav' is not considered inside 'data/set'.
        """
        abs_file = os.path.abspath(file_path)
        abs_base = os.path.abspath(self.base_folder)
        try:
            inside = os.path.commonpath([abs_file, abs_base]) == abs_base
        except ValueError:
            # Paths on different drives (Windows) share no common path.
            inside = False
        if inside and abs_file != abs_base:
            return os.path.relpath(abs_file, abs_base)
        # Fallback: basenames from different folders may collide.
        return os.path.basename(file_path)

    def __getitem__(self, idx):
        file_path = self.file_list[idx]
        # Load the audio, resampling to self.sr (16 kHz by default).
        signal, _ = librosa.load(file_path, sr=self.sr)
        signal_tensor = torch.tensor(pad(signal, self.cut), dtype=torch.float32)
        return signal_tensor, self._relative_id(file_path)

In [None]:
# Cell 4: Load the file list from a text file (skipped if Cell 1 already ran)
# If Cell 1 was run, do not override file_paths.
if "file_paths" not in globals():
    # One path per line; blank lines are ignored.
    file_list_path = "/root/rafaello/datasets/yourtts_wav_files.txt"  # Update path if needed.
    # Decode explicitly as UTF-8 so non-ASCII paths are read the same way
    # regardless of the system locale.
    with open(file_list_path, "r", encoding="utf-8") as f:
        file_paths = [line.strip() for line in f if line.strip()]
    print(f"Loaded {len(file_paths)} wav file paths.")
else:
    print(f"Using dynamically detected {len(file_paths)} audio files from Cell 1.")

In [None]:
# Cell 5: Create the dataset and dataloader

# `input_folder` only exists if Cell 1 ran. Fall back to the dataset's
# default base folder so this cell also works when Cell 4 loaded the list.
base_folder = globals().get("input_folder", "./data/dataset_sintetico/")
dataset = CustomWavDataset(file_paths, base_folder=base_folder)
# Adjust the batch size as needed. shuffle=False keeps file order stable.
dataloader = DataLoader(dataset, batch_size=2, shuffle=False, drop_last=False)

In [None]:
import os
import sys
import torch
from argparse import Namespace

# Select the GPU. This must run before anything in this kernel initializes
# CUDA; otherwise the variable is ignored.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # Change if needed
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Minimal stand-in for the repository's command-line args.
# NOTE(review): only these fields are provided; confirm Model reads no others.
args = Namespace(
    track='DF',
    model_path='/app/SSL_Anti-spoofing/Best_LA_model_for_DF.pth',
    protocols_path='/app/SSL_Anti-spoofing/protocols/',
    database_path='/app/SSL_Anti-spoofing/database/',
    loss='WCE'
)

# `model` is presumably model.py from the SSL_Anti-spoofing repository. It must
# be importable, e.g. by running the notebook from the repo root or by adding
# the repo to sys.path.
from model import Model
model = Model(args, device)
# DataParallel prefixes parameter names with "module.", so the checkpoint is
# expected to have been saved from a DataParallel-wrapped model as well.
model = torch.nn.DataParallel(model).to(device)

checkpoint = torch.load(args.model_path, map_location=device)
# strict=False tolerates key mismatches. Unmatched parameters would silently
# keep their initial values, so report anything that did not line up.
load_result = model.load_state_dict(checkpoint, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"WARNING: checkpoint/model key mismatch: "
        f"{len(load_result.missing_keys)} missing, "
        f"{len(load_result.unexpected_keys)} unexpected."
    )
    print("  missing (first 5):", load_result.missing_keys[:5])
    print("  unexpected (first 5):", load_result.unexpected_keys[:5])

model.eval()
print("Model loaded in evaluation mode")

In [None]:
# Cell 7: Run inference with continuous save
from tqdm.autonotebook import tqdm
import os
import torch
import pickle

# Configuration
load_scores_file = "./outputs/df_detection_scores.pkl"  # File to load existing scores
save_scores_file = "./outputs/df_detection_scores.pkl"  # File to save scores
force_recalculate = True        # Set to True to ignore saved scores and recompute


def _dump_scores_atomically(scores, path):
    """Pickle `scores` to `path` without ever leaving a half-written file.

    Writes to a sibling temp file, then moves it over `path` with os.replace.
    An interrupted write therefore leaves the previous file intact.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(scores, f)
    os.replace(tmp_path, path)


# Create the output directory if needed. os.path.dirname returns "" for a
# bare filename, which os.makedirs rejects, hence the "." fallback.
os.makedirs(os.path.dirname(save_scores_file) or ".", exist_ok=True)

if not force_recalculate and os.path.exists(load_scores_file):
    print(f"Loading existing scores from {load_scores_file}...")
    # NOTE: unpickling can execute code; only load score files you produced.
    with open(load_scores_file, "rb") as f:
        all_scores = pickle.load(f)
    n_missing = len(dataset) - len(all_scores)
    if n_missing <= 0:
        print("Scores loaded. Skipping inference.")
    else:
        # A previous run was interrupted: resume rather than treat the
        # partial checkpoint as complete.
        print(f"Scores loaded. Resuming inference for about {n_missing} remaining files...")
else:
    print("No saved scores found or recalculation forced. Running inference...")
    all_scores = {}

if len(all_scores) < len(dataset):
    with torch.no_grad():
        for batch_x, utt_ids in tqdm(dataloader,
                                     desc="Processing batches",
                                     unit="batch",
                                     dynamic_ncols=True):
            # Skip batches whose files were all scored by a previous run.
            if all(uid in all_scores for uid in utt_ids):
                continue

            batch_out = model(batch_x.to(device))
            # Column 1 of the output is used as the detection score
            # (presumably the bona fide logit — confirm against model.py).
            batch_scores = batch_out[:, 1].cpu().numpy()
            all_scores.update(zip(utt_ids, batch_scores))

            # Checkpoint after every batch so an interrupted run can resume.
            _dump_scores_atomically(all_scores, save_scores_file)

    print(f"\nInference complete. Processed {len(all_scores)} files.")
    print(f"Scores saved to: {save_scores_file}")

In [None]:
# Cell 8: Visualize score distribution with threshold
import matplotlib.pyplot as plt
import numpy as np  # Required for np.mean
import pickle
import os

# Ensure scores are loaded (e.g. when this cell runs in a fresh kernel).
if 'all_scores' not in globals():
    # Prefer Cell 7's save path if it ran; otherwise use its default location.
    scores_file = globals().get("save_scores_file", "./outputs/df_detection_scores.pkl")
    if os.path.exists(scores_file):
        print(f"Loading scores from {scores_file}...")
        with open(scores_file, "rb") as f:
            all_scores = pickle.load(f)
    else:
        raise RuntimeError("Scores file not found, and inference was not run.")

# Convert scores to a list
scores = list(all_scores.values())

# Update to match the dataset that was evaluated.
plot_title = 'DF Detection Scores - F5-TTS on SSL-Anti-Spoofing'

# Create histogram
plt.figure(figsize=(10, 6))
n, bins, patches = plt.hist(scores, bins=100, alpha=0.7,
                            color='blue', edgecolor='black')

# Decision threshold on the score axis. Its origin is not shown here
# (presumably taken from a reference evaluation — TODO confirm).
threshold = -3.5324
plt.axvline(threshold, color='red', linestyle='--',
            linewidth=2, label=f'Threshold: {threshold:.4f}')

# Add labels and title
plt.title(plot_title, fontsize=14)
plt.xlabel('Score Value', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Add count annotations. np.mean of an empty list would be NaN with a warning.
mean_text = f'Mean score: {np.mean(scores):.4f}' if scores else 'Mean score: N/A'
plt.text(0.05, 0.95, f'Total files: {len(scores)}',
         transform=plt.gca().transAxes, verticalalignment='top')
plt.text(0.05, 0.90, mean_text,
         transform=plt.gca().transAxes, verticalalignment='top')

plt.tight_layout()
plt.show()