# SLSforASVspoof Deepfake Detection Notebook

This notebook demonstrates how to use the SLS (Sensitive Layer Selection) model for detecting deepfake audio samples, based on the original implementation from the [SLSforASVspoof-2021-DF](https://github.com/QiShanZhang/SLSforASVspoof-2021-DF) repository.

**Steps:**
1. Import dependencies and define utility functions
2. Load audio files from a directory
3. Define preprocessing functions for audio data
4. Load the pretrained SLS model
5. Run batch inference on audio files
6. Visualize detection scores

In [None]:
import os
import torch
import librosa
import numpy as np
import pandas as pd
import pickle
import sys
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from torch import nn

# Critical: Add the repository root to path so we can import modules correctly
# Make sure this points to the directory containing the model.py file
sys.path.append('..')  # Add parent directory to path

# Import the actual model from the repository
from model import Model

def get_audio_files_recursive(folder_path, formats=('wav', 'flac'), exclude_prefix=('.', '__')):
    """
    Recursively find audio files in a directory with specified extensions.

    Directories and file names are visited in sorted order, so the returned
    list is stable across runs and filesystems (os.walk itself guarantees no
    particular order).

    Args:
        folder_path (str): Root directory to search. A missing directory
            yields an empty list, because os.walk ignores errors by default.
        formats (str | tuple): File extension(s) to include. Matching is
            case-insensitive and a leading dot is optional, so 'wav', '.WAV'
            and ('wav', 'flac') are all accepted.
        exclude_prefix (str | tuple): Skip directories whose name starts with
            any of these prefixes (hidden/special directories by default).

    Returns:
        list: Full paths to audio files, in sorted traversal order.
    """
    if isinstance(formats, str):
        # A bare string would otherwise be iterated/searched character by character
        formats = (formats,)
    wanted_exts = {fmt.lower().lstrip('.') for fmt in formats}
    # str.startswith accepts a str or a tuple, but not a list
    prefixes = exclude_prefix if isinstance(exclude_prefix, str) else tuple(exclude_prefix)

    valid_files = []
    for root, dirs, files in os.walk(folder_path):
        # Prune in place so os.walk skips excluded directories; sorting also
        # fixes the order in which the remaining ones are descended into.
        dirs[:] = sorted(d for d in dirs if not d.startswith(prefixes))
        for file_name in sorted(files):
            ext = os.path.splitext(file_name)[1][1:].lower()  # Extension without the dot
            if ext in wanted_exts:
                valid_files.append(os.path.join(root, file_name))
    return valid_files

# Set path to your audio files directory
input_folder = '/data/audio_files'  # Update this to your actual audio directory
file_paths = get_audio_files_recursive(input_folder, formats=('wav', 'flac'))
# Sorted so the summary line is identical from run to run
found_exts = sorted({os.path.splitext(f)[1] for f in file_paths})
print(f"Found {len(file_paths)} audio files ({', '.join(found_exts)} formats)")
if not file_paths:
    # Fail loudly here rather than producing an empty histogram later on
    print(f"Warning: no audio files found under {input_folder!r}; update input_folder before running inference.")

In [None]:
# Plain attribute container standing in for the argparse namespace that main.py
# builds from the command line; an instance is passed to Model(args, device).
class Args:
    """Default command-line options of main.py, exposed as instance attributes."""

    def __init__(self):
        option_defaults = {
            # General options, matching the default arguments from main.py
            'loss': 'weighted_CCE',
            'track': 'DF',
            'seed': 1234,
            'cudnn_deterministic_toggle': True,
            'cudnn_benchmark_toggle': False,
            'algo': 0,  # No Rawboost augmentation for inference
            # RawBoost augmentation parameters (not used in inference)
            'nBands': 5,
            'minF': 20,
            'maxF': 8000,
            'minBW': 100,
            'maxBW': 1000,
            'minCoeff': 10,
            'maxCoeff': 100,
            'minG': 0,
            'maxG': 0,
            'minBiasLinNonLin': 5,
            'maxBiasLinNonLin': 20,
            'N_f': 5,
            'P': 10,
            'g_sd': 2,
            'SNRmin': 10,
            'SNRmax': 40,
        }
        for option_name, option_value in option_defaults.items():
            setattr(self, option_name, option_value)


# One shared set of default options for the model constructor
args = Args()

In [None]:
# Preprocess audio files using a similar approach to the original code
def preprocess_audio(audio_path, sr=16000):
    """
    Load and preprocess audio file for the SLS model based on original repository's approach
    """
    try:
        # Use librosa to load the audio file
        audio, sample_rate = librosa.load(audio_path, sr=sr, mono=True)
        
        # Normalize audio to match the preprocessing in the original code
        audio = librosa.util.normalize(audio)
        
        # Handle length - original repo expects 64600 samples (approx 4 seconds at 16kHz)
        target_len = 64600
        if len(audio) < target_len:
            # Pad if shorter
            audio = np.pad(audio, (0, target_len - len(audio)))
        elif len(audio) > target_len:
            # Take the first part if longer
            audio = audio[:target_len]
            
        return audio
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None

In [None]:
# Initialize the model following the approach in main.py
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create model instance using the same approach as in main.py
model = Model(args, device)

# Count parameters (optional); numel() is each tensor's element count
nb_params = sum(param.numel() for param in model.parameters())
print(f"Model has {nb_params:,} parameters")

# Wrap with DataParallel as in original code. The wrapper keeps the network
# under `.module`, so its state-dict keys carry a 'module.' prefix.
model = nn.DataParallel(model).to(device)

# Load pretrained model weights (adjust path if needed)
model_path = "pretrained_models/asvdf_sls_best.pth"
state_dict = torch.load(model_path, map_location=device)
# Accept checkpoints saved from a bare model too: add the 'module.' prefix
# where it is missing (keys that already have it are left untouched).
state_dict = {
    key if key.startswith('module.') else f'module.{key}': value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model.eval()

print("SLS model loaded successfully")

In [None]:
import torch.utils.data

class AudioDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(waveform_tensor, file_path)`` pairs.

    Waveforms come from ``preprocess_audio``; a file it cannot load is replaced
    by 64600 zeros, the fixed length ``preprocess_audio`` produces.
    """

    def __init__(self, file_list, sample_rate=16000):
        self.file_list = file_list  # audio paths, looked up by index
        self.sample_rate = sample_rate  # forwarded to preprocess_audio as sr

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        path = self.file_list[idx]
        waveform = preprocess_audio(path, sr=self.sample_rate)
        if waveform is None:
            # Silent placeholder of the expected length keeps batch shapes uniform
            waveform = np.zeros(64600)
        return torch.FloatTensor(waveform), path

# Create dataset and dataloader
batch_size = 16
num_workers = 4
# Cache of per-file spoof probabilities. Deliberately not 'sls_scores.pkl',
# which held raw class-1 outputs, so a cache with the old score meaning is
# never silently reused.
pickle_file = 'sls_spoof_probs.pkl'
force_overwrite = False  # Set to True to force recalculation

all_scores = {}
if os.path.exists(pickle_file) and not force_overwrite:
    # NOTE: pickle can execute arbitrary code; only load caches this notebook wrote
    with open(pickle_file, 'rb') as f:
        all_scores = pickle.load(f)
    print(f"Loaded {len(all_scores)} cached scores from {pickle_file}.")

# Score only files missing from the cache, so an interrupted run (the cache is
# saved after every batch) resumes instead of being mistaken for a finished one.
pending_paths = [p for p in file_paths if p not in all_scores]

if not pending_paths:
    print(f"No files left to score ({len(all_scores)} cached). Skipping inference.")
else:
    dataset = AudioDataset(pending_paths)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)

    # Run inference on batches
    for batch_features, batch_paths in tqdm(dataloader, desc="Processing batches"):
        batch_features = batch_features.to(device)

        with torch.no_grad():
            batch_out = model(batch_features)
            # The repository's produce_evaluation_file writes batch_out[:, 1] as
            # the score. In its training labels class 1 is bona fide (main.py
            # weights the CE loss [0.1, 0.9] toward the minority bona fide class),
            # so that score rises for REAL audio. It is also not a probability.
            # NOTE(review): verify if using another checkpoint. softmax over the
            # outputs (valid for logits or log-probabilities alike) yields class
            # probabilities; column 0 is P(spoof): higher = more likely fake.
            batch_spoof_prob = torch.softmax(batch_out, dim=1)[:, 0].cpu().numpy()

        # Store scores by filepath
        for file_path, score in zip(batch_paths, batch_spoof_prob):
            all_scores[file_path] = float(score)

        # Save progress after each batch so a crash loses at most one batch
        with open(pickle_file, 'wb') as f:
            pickle.dump(all_scores, f)

    print(f"Processed {len(pending_paths)} files ({len(all_scores)} total). Results saved to {pickle_file}")

In [None]:
# Visualize score distribution
# Assumes all_scores holds spoof-likelihood scores (higher = more likely fake);
# confirm against the inference cell before trusting the threshold split.
scores = list(all_scores.values())

# Decision threshold. NOTE(review): 0.5 is a conventional cut-off for a
# probability, not an EER operating point taken from the paper; calibrate it
# on labelled data if exact error rates matter. The name is kept because the
# results-table cell reads eer_threshold.
eer_threshold = 0.5
n_above = sum(s > eer_threshold for s in scores)
# np.mean of an empty list is nan (with a warning); show n/a instead
mean_text = f'{np.mean(scores):.4f}' if scores else 'n/a'

plt.figure(figsize=(10, 6))
plt.hist(scores, bins=50, alpha=0.7, color='green', edgecolor='black')

# Add vertical line for threshold
plt.axvline(eer_threshold, color='red', linestyle='--',
            linewidth=2, label=f'Decision threshold: {eer_threshold:.4f}')

plt.title('DF Detection Scores - SLSforASVspoof', fontsize=14)
plt.xlabel('Score Value (higher = more likely fake)', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Add text annotations (axes-relative coordinates, top-left corner)
axes = plt.gca()
plt.text(0.05, 0.95, f'Total files: {len(scores)}',
         transform=axes.transAxes, verticalalignment='top')
plt.text(0.05, 0.90, f'Mean score: {mean_text}',
         transform=axes.transAxes, verticalalignment='top')
plt.text(0.05, 0.85, f'Files with score > threshold: {n_above}',
         transform=axes.transAxes, verticalalignment='top')

plt.tight_layout()
plt.show()

In [None]:
# Optional: Save results to a CSV file for further analysis
csv_path = 'sls_results.csv'
results_df = pd.DataFrame({
    'file_path': list(all_scores.keys()),
    'score': list(all_scores.values()),
    # Same rule as the histogram annotation: strictly above the threshold = spoof
    'prediction': ['spoof' if s > eer_threshold else 'bonafide' for s in all_scores.values()]
})

# Persist the table so it can be analysed outside the notebook
results_df.to_csv(csv_path, index=False)
print(f"Saved {len(results_df)} rows to {csv_path}")

# Display the first few results
results_df.head()

# SLSforASVspoof Deepfake Detection Notebook

This notebook demonstrates how to use the SLS (Sensitive Layer Selection) model for detecting deepfake audio samples.

**Steps:**
1. Import dependencies and define utility functions
2. Load audio files from a directory
3. Define preprocessing functions for audio data
4. Load the pretrained SLS model
5. Run batch inference on audio files
6. Visualize detection scores

In [None]:
import os
import torch
import librosa
import numpy as np
import pandas as pd
import pickle
from tqdm.auto import tqdm
import matplotlib.pyplot as plt

def get_audio_files_recursive(folder_path, formats=('wav', 'flac'), exclude_prefix=('.', '__')):
    """
    Recursively find audio files in a directory with specified extensions.

    Directories and file names are visited in sorted order, so the returned
    list is stable across runs and filesystems (os.walk itself guarantees no
    particular order).

    Args:
        folder_path (str): Root directory to search. A missing directory
            yields an empty list, because os.walk ignores errors by default.
        formats (str | tuple): File extension(s) to include. Matching is
            case-insensitive and a leading dot is optional, so 'wav', '.WAV'
            and ('wav', 'flac') are all accepted.
        exclude_prefix (str | tuple): Skip directories whose name starts with
            any of these prefixes (hidden/special directories by default).

    Returns:
        list: Full paths to audio files, in sorted traversal order.
    """
    if isinstance(formats, str):
        # A bare string would otherwise be iterated/searched character by character
        formats = (formats,)
    wanted_exts = {fmt.lower().lstrip('.') for fmt in formats}
    # str.startswith accepts a str or a tuple, but not a list
    prefixes = exclude_prefix if isinstance(exclude_prefix, str) else tuple(exclude_prefix)

    valid_files = []
    for root, dirs, files in os.walk(folder_path):
        # Prune in place so os.walk skips excluded directories; sorting also
        # fixes the order in which the remaining ones are descended into.
        dirs[:] = sorted(d for d in dirs if not d.startswith(prefixes))
        for file_name in sorted(files):
            ext = os.path.splitext(file_name)[1][1:].lower()  # Extension without the dot
            if ext in wanted_exts:
                valid_files.append(os.path.join(root, file_name))
    return valid_files

# Set path to your audio files directory
input_folder = '/data/audio_files'
file_paths = get_audio_files_recursive(input_folder, formats=('wav', 'flac'))
# Sorted so the summary line is identical from run to run
found_exts = sorted({os.path.splitext(f)[1] for f in file_paths})
print(f"Found {len(file_paths)} audio files ({', '.join(found_exts)} formats)")
if not file_paths:
    # Fail loudly here rather than producing an empty histogram later on
    print(f"Warning: no audio files found under {input_folder!r}; update input_folder before running inference.")

In [None]:
# Import model architecture
import sys
sys.path.append('.')
from models.SLS_RawNet2 import MainModel

# Function to extract features from audio
def extract_features(audio_path, sr=16000):
    """Load and preprocess audio file for SLS model"""
    # Load audio with librosa
    try:
        audio, sample_rate = librosa.load(audio_path, sr=sr, mono=True)
        
        # Check audio length and pad/trim if necessary
        target_len = sr * 4  # 4 seconds
        if len(audio) < target_len:
            # Pad audio if shorter than target length
            audio = np.pad(audio, (0, target_len - len(audio)))
        elif len(audio) > target_len:
            # Take the first 4 seconds if longer
            audio = audio[:target_len]
            
        return audio
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return None

In [None]:
# Initialize the model
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Create the SLS model instance
model = MainModel()

# Load pretrained model weights
model_path = "pretrained_models/asvdf_sls_best.pth"
state_dict = torch.load(model_path, map_location=device)
# This checkpoint path is also loaded through nn.DataParallel elsewhere in the
# notebook, so its keys may carry the wrapper's 'module.' prefix, which a bare
# module rejects. Strip it when present; other keys are left unchanged.
state_dict = {
    key[len('module.'):] if key.startswith('module.') else key: value
    for key, value in state_dict.items()
}
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()

print("SLS model loaded successfully")

In [None]:
import torch.utils.data

class AudioDataset(torch.utils.data.Dataset):
    """Map-style dataset yielding ``(waveform_tensor, file_path)`` pairs.

    Waveforms come from ``extract_features``; a file it cannot load is replaced
    by ``sample_rate * 4`` zeros (four seconds of silence).
    """

    def __init__(self, file_list, sample_rate=16000):
        self.file_list = file_list  # audio paths, looked up by index
        self.sample_rate = sample_rate  # forwarded to extract_features as sr

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        path = self.file_list[idx]
        waveform = extract_features(path, sr=self.sample_rate)
        if waveform is None:
            # Silent placeholder of the default clip length keeps batch shapes uniform
            waveform = np.zeros(self.sample_rate * 4)
        return torch.FloatTensor(waveform), path

# Create dataset and dataloader
batch_size = 16
num_workers = 4
# Cache of per-file spoof probabilities. Deliberately not 'sls_scores.pkl',
# which held probs[:, 1], so a cache with the old column is never reused.
pickle_file = 'sls_spoof_probs.pkl'
force_overwrite = False  # Set to True to force recalculation

all_scores = {}
if os.path.exists(pickle_file) and not force_overwrite:
    # NOTE: pickle can execute arbitrary code; only load caches this notebook wrote
    with open(pickle_file, 'rb') as f:
        all_scores = pickle.load(f)
    print(f"Loaded {len(all_scores)} cached scores from {pickle_file}.")

# Score only files missing from the cache, so an interrupted run (the cache is
# saved after every batch) resumes instead of being mistaken for a finished one.
pending_paths = [p for p in file_paths if p not in all_scores]

if not pending_paths:
    print(f"No files left to score ({len(all_scores)} cached). Skipping inference.")
else:
    dataset = AudioDataset(pending_paths)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                             shuffle=False, num_workers=num_workers)

    # Run inference on batches
    for batch_features, batch_paths in tqdm(dataloader, desc="Processing batches"):
        batch_features = batch_features.to(device)

        with torch.no_grad():
            outputs = model(batch_features)
            # In the SLSforASVspoof training labels class 1 is bona fide and
            # class 0 is spoof (main.py weights the CE loss [0.1, 0.9] toward the
            # minority bona fide class). Column 1 is therefore P(real), not P(fake).
            # NOTE(review): verify if MainModel uses a different label order.
            probs = torch.softmax(outputs, dim=1)
            batch_spoof_prob = probs[:, 0].cpu().numpy()  # P(spoof): higher = more likely fake

        # Store scores by filepath
        for file_path, score in zip(batch_paths, batch_spoof_prob):
            all_scores[file_path] = float(score)

        # Save progress after each batch so a crash loses at most one batch
        with open(pickle_file, 'wb') as f:
            pickle.dump(all_scores, f)

    print(f"Processed {len(pending_paths)} files ({len(all_scores)} total). Results saved to {pickle_file}")

In [None]:
# Visualize score distribution
# Assumes all_scores holds spoof-likelihood scores (higher = more likely fake);
# confirm against the inference cell before trusting the threshold split.
scores = list(all_scores.values())

# Decision threshold. NOTE(review): 0.5 is a conventional cut-off for a
# probability, not an EER operating point taken from the paper; calibrate it
# on labelled data if exact error rates matter.
eer_threshold = 0.5
n_above = sum(s > eer_threshold for s in scores)
# np.mean of an empty list is nan (with a warning); show n/a instead
mean_text = f'{np.mean(scores):.4f}' if scores else 'n/a'

plt.figure(figsize=(10, 6))
plt.hist(scores, bins=50, alpha=0.7, color='green', edgecolor='black')

# Add vertical line for threshold
plt.axvline(eer_threshold, color='red', linestyle='--',
            linewidth=2, label=f'Decision threshold: {eer_threshold:.4f}')

plt.title('DF Detection Scores - SLSforASVspoof', fontsize=14)
plt.xlabel('Score Value (higher = more likely fake)', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.legend()
plt.grid(axis='y', linestyle='--', alpha=0.7)

# Add text annotations (axes-relative coordinates, top-left corner)
axes = plt.gca()
plt.text(0.05, 0.95, f'Total files: {len(scores)}',
         transform=axes.transAxes, verticalalignment='top')
plt.text(0.05, 0.90, f'Mean score: {mean_text}',
         transform=axes.transAxes, verticalalignment='top')
plt.text(0.05, 0.85, f'Files with score > threshold: {n_above}',
         transform=axes.transAxes, verticalalignment='top')

plt.tight_layout()
plt.show()