In [None]:
import os
import numpy as np
from transformers import ASTFeatureExtractor
from transformers.utils import is_speech_available
from transformers.audio_utils import mel_filter_bank, spectrogram, window_function

if is_speech_available():
    import torchaudio.compliance.kaldi as ta_kaldi

# Based on the following literature, which uses the Hamming window:
#   https://arxiv.org/pdf/2505.15136
#   https://arxiv.org/pdf/2409.05924
#   https://ieeexplore.ieee.org/stamp/stamp.jsp?arnumber=11007653
class ASTFeatureExtractorHamming(ASTFeatureExtractor):
    """
    Custom AST Feature Extractor that uses a Hamming window instead of Hann/Hanning.

    Both code paths are covered: the torchaudio Kaldi-compatible ``fbank`` call
    (window type passed per call) and the numpy ``spectrogram`` fallback used
    when torchaudio is not available (window precomputed in ``__init__``).
    """

    # Framing used by the numpy fallback, in samples. At a 16 kHz sampling rate
    # this is a 25 ms frame (16000 Hz * 0.025 = 400) with a 10 ms hop
    # (16000 Hz * 0.01 = 160), following the frame/hop lengths used in the paper.
    _FRAME_LENGTH = 400
    _HOP_LENGTH = 160
    _FFT_LENGTH = 512

    def __init__(self, **kwargs):
        """Build the parent extractor, then install Hamming-window state for the numpy path."""
        super().__init__(**kwargs)

        # Override the window for numpy-based processing (when torchaudio is not available).
        # The torchaudio path needs no precomputed state: it receives the window type per call.
        if not is_speech_available():
            # Recalculate the mel filters alongside the window so both are set explicitly here.
            self.mel_filters = mel_filter_bank(
                num_frequency_bins=self._FFT_LENGTH // 2 + 1,  # 257 bins for a 512-point FFT
                num_mel_filters=self.num_mel_bins,
                min_frequency=20,
                max_frequency=self.sampling_rate // 2,
                sampling_rate=self.sampling_rate,
                norm=None,
                mel_scale="kaldi",
                triangularize_in_mel_space=True,
            )
            # Use hamming window instead of hann
            self.window = window_function(self._FRAME_LENGTH, "hamming", periodic=False)

    def _extract_fbank_features(self, waveform: np.ndarray, max_length: int) -> np.ndarray:
        """
        Compute log-mel filter-bank features with a Hamming window.

        Args:
            waveform: Audio samples as a numpy array.
            max_length: Number of frames the result is padded (with zeros) or
                truncated to.

        Returns:
            A numpy array whose first dimension is exactly ``max_length`` frames.
        """
        # Imported here so the class does not rely on a later notebook cell
        # having already imported torch.
        import torch

        if is_speech_available():
            waveform = torch.from_numpy(waveform).unsqueeze(0)
            fbank = ta_kaldi.fbank(
                waveform,
                sample_frequency=self.sampling_rate,
                window_type="hamming",  # Changed from "hanning" to "hamming"
                num_mel_bins=self.num_mel_bins,
            )
        else:
            # Use numpy implementation with the hamming window built in __init__
            waveform = np.squeeze(waveform)
            fbank = spectrogram(
                waveform,
                self.window,
                frame_length=self._FRAME_LENGTH,
                hop_length=self._HOP_LENGTH,
                fft_length=self._FFT_LENGTH,
                power=2.0,
                center=False,
                preemphasis=0.97,
                mel_filters=self.mel_filters,
                log_mel="log",
                mel_floor=1.192092955078125e-07,
                remove_dc_offset=True,
            ).T  # transpose so that frames come first, as in the torchaudio branch
            fbank = torch.from_numpy(fbank)

        n_frames = fbank.shape[0]
        difference = max_length - n_frames

        # pad or truncate along the frame axis, depending on difference
        if difference > 0:
            # (left, right, top, bottom): append `difference` zero rows at the end
            pad_module = torch.nn.ZeroPad2d((0, 0, 0, difference))
            fbank = pad_module(fbank)
        elif difference < 0:
            fbank = fbank[0:max_length, :]

        fbank = fbank.numpy()
        return fbank

In [None]:
from transformers import ASTForAudioClassification
from evaluate import evaluator, combine
import os
import json
import torch
from transformers import ASTFeatureExtractor

# Number of mel bins requested from the feature extractor for every checkpoint.
NUM_MEL_BINS = 128 # based on https://arxiv.org/pdf/2409.05924
# Value passed as `max_sequence_length` to the feature extractor for every
# checkpoint that load_models() does not special-case with 1024.
MAX_SEQUENCE_LENGTH = 507

# checkpoints = [
#     # # 1st iteration - test run (one dataset only)
#     # "./runs2/ast_classifier/checkpoint-158700",
    
#     # # 2nd iteration - first official run
#     # "./eval/ast_classifier/checkpoint-174570", # epoch 11

#     # # 3rd iteration / checking if epoch 13 is really peak (test patchout)
#     # "./runs2/ast_classifier/checkpoint-63480", # epoch 9
#     # "./runs2/ast_classifier/checkpoint-71415", # epoch 10
#     # "./runs2/ast_classifier/checkpoint-79350", # epoch 11
#     # "./runs2/ast_classifier/checkpoint-87285", # epoch 12
#     # "./runs2/ast_classifier/checkpoint-95220", # epoch 13

#     # 4th iteration - cleaner dataset + using librespeech. Tuned on test set so use val set for testing
#     # "./runs2/ast_classifier/checkpoint-31740", # epoch 1
#     # "./runs2/ast_classifier/checkpoint-63480", # epoch 2
#     # "./runs2/ast_classifier/checkpoint-95220", # epoch 3
#     # "./runs2/ast_classifier/checkpoint-126960", # epoch 4
#     # "./runs2/ast_classifier/checkpoint-158700", # epoch 5
#     # "./runs2/ast_classifier/checkpoint-190440", # epoch 6
#     # "./runs2/ast_classifier/checkpoint-222180", # epoch 7

#     # 5th iteration - cleaner dataset + using librespeech. Tuned on test set so use val set for testing. 15% oversampled
#     # "./epoch1-9/ast_classifier/checkpoint-34567", # epoch 1
#     # "./epoch1-9/ast_classifier/checkpoint-69134", # epoch 2
#     # "./epoch1-9/ast_classifier/checkpoint-103701", # epoch 3
#     # "./epoch1-9/ast_classifier/checkpoint-138268", # epoch 4
#     # "./epoch1-9/ast_classifier/checkpoint-172835", # epoch 5
#     # "./epoch1-9/ast_classifier/checkpoint-207402", # epoch 6
#     # "./epoch1-9/ast_classifier/checkpoint-241969", # epoch 7    
#     # "./epoch1-9/ast_classifier/checkpoint-276536", # epoch 8    
#     # "./epoch1-9/ast_classifier/checkpoint-311103", # epoch 9
# ]

# Collect every sub-directory of the checkpoint root whose name contains
# "iteration", in sorted (lexicographic) name order.
_checkpoint_root = os.path.join("eval-checkpoints", "ast_classifier")
checkpoints = []
for _entry in sorted(os.listdir(_checkpoint_root)):
    _candidate = os.path.join(_checkpoint_root, _entry)
    if "iteration" in _entry and os.path.isdir(_candidate):
        checkpoints.append(_candidate)

print("Found checkpoints:", checkpoints)

def load_models(checkpoints:list):
    """
    Load a model and a Hamming-window feature extractor for each checkpoint.

    Args:
        checkpoints: Checkpoint directory paths, in evaluation order.

    Returns:
        A dict keyed by the checkpoint's index in `checkpoints`; each value is
        ``{"model": ..., "feature_extractor": ...}``. An empty list yields ``{}``.
    """
    # The last 2 models use a max sequence length of 1024; all earlier ones use
    # MAX_SEQUENCE_LENGTH. (The previous condition repeated `len - 1` twice, so
    # only the very last model was treated this way.)
    long_context_start = len(checkpoints) - 2

    models = {}
    for i, checkpoint in enumerate(checkpoints):
        print(checkpoint)
        model = ASTForAudioClassification.from_pretrained(checkpoint, local_files_only=True)

        uses_long_context = i >= long_context_start
        max_sequence_length = 1024 if uses_long_context else MAX_SEQUENCE_LENGTH
        # NOTE(review): the stock ASTFeatureExtractor documents its frame-count
        # argument as `max_length` - confirm that `max_sequence_length` is
        # actually honoured for these checkpoints.
        feature_extractor = ASTFeatureExtractorHamming.from_pretrained(
            checkpoint,
            num_mel_bins=NUM_MEL_BINS,
            max_sequence_length=max_sequence_length,
            local_files_only=True,
        )
        if uses_long_context:
            print(model)

        # Show the normalisation settings stored with the checkpoint, then
        # force normalisation on regardless of what was stored.
        print("do_normalize:", feature_extractor.do_normalize)
        print("mean:", getattr(feature_extractor, "mean", None))
        print("std:", getattr(feature_extractor, "std", None))
        feature_extractor.do_normalize = True
        print("do_normalize:", feature_extractor.do_normalize)

        models[i] = {
            "model": model,
            "feature_extractor": feature_extractor
        }
    return models

# Evaluator object from the `evaluate` library for the audio-classification task.
# NOTE(review): neither `task_evaluator` nor `results` is referenced in the
# cells visible here - confirm they are still needed.
task_evaluator = evaluator(task="audio-classification")
results = []
# Load every discovered checkpoint; the result is keyed by index into `checkpoints`.
models = load_models(checkpoints)

In [None]:
# Configuration
import os
import numpy as np
import pandas as pd
import evaluate
import torch
from tqdm.auto import tqdm
from transformers import pipeline
from transformers.pipelines.pt_utils import KeyDataset
from sklearn.metrics import (
    confusion_matrix, ConfusionMatrixDisplay, classification_report,
    roc_curve, auc, precision_recall_curve, average_precision_score,
    f1_score, precision_score, recall_score
)
from datasets import load_dataset
import matplotlib.pyplot as plt

# --- Configuration ---
# List of dataset paths to evaluate on
# Root folder that holds one sub-folder per evaluation dataset. Set the
# EVAL_DATASETS_DIR environment variable to point somewhere else; the default
# is the original local location.
EVAL_DATASETS_DIR = os.environ.get(
    "EVAL_DATASETS_DIR",
    r"C:\Users\crumbz\Downloads\thesis-testing\eval-datasets",
)

# (display name, folder name under EVAL_DATASETS_DIR). The two differ in case
# for ASVspoof2019, so they are listed separately.
_DATASET_FOLDERS = [
    ("wavefake", "wavefake"),
    ("dataset-balanced-arrow", "dataset-balanced-arrow"),
    ("LibriSeVoc", "LibriSeVoc"),
    ("ASVspoof2019", "ASVSpoof2019"),
    ("release_in_the_wild", "release_in_the_wild"),
    ("SONAR_dataset", "SONAR_dataset"),
]

DATASET_PATHS = [
    {"name": name, "path": os.path.join(EVAL_DATASETS_DIR, folder)}
    for name, folder in _DATASET_FOLDERS
]

# Evaluation settings
OUTPUT_DIR = "evaluation-results"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- Helper Functions ---
def load_dataset_helper(dataset_path):
    """
    Load dataset from path, handling both arrow datasets and audio folders.

    A directory containing ``dataset_dict.json`` is loaded with
    ``load_from_disk``; anything else is loaded as an ``audiofolder`` dataset.

    Args:
        dataset_path: Directory of the dataset.

    Returns:
        Whatever ``load_from_disk`` / ``load_dataset`` returns for that path.

    Raises:
        Exception: Any loading error is printed and then re-raised unchanged.
    """
    try:
        # First try loading as an arrow dataset
        if os.path.exists(os.path.join(dataset_path, "dataset_dict.json")):
            from datasets import load_from_disk
            return load_from_disk(dataset_path)

        # Leave one core free, but never ask for fewer than one worker:
        # os.cpu_count() may return None, and a single-core host would give 0.
        worker_count = max(1, (os.cpu_count() or 1) - 1)

        # If not an arrow dataset, load as an audio folder
        return load_dataset("audiofolder", data_dir=dataset_path, num_proc=worker_count, drop_labels=False)
    except Exception as e:
        print(f"Error loading dataset from {dataset_path}: {str(e)}")
        raise

def _save_and_close(fig, path):
    """Lay out `fig`, write it to `path`, and always release the figure afterwards."""
    try:
        fig.tight_layout()
        fig.savefig(path, bbox_inches='tight')
    finally:
        # Close this specific figure so repeated calls do not accumulate open figures.
        plt.close(fig)


def save_plots(plot_data, model_name, dataset_name, output_dir):
    """
    Save evaluation plots and the classification report to disk.

    Args:
        plot_data: Mapping that provides "cm", "fpr", "tpr", "roc_auc",
            "thresholds", "f1_scores" and "classification_report".
        model_name: Model identifier or path; only its last path component is
            used in file names, the full value appears in the confusion-matrix title.
        dataset_name: Dataset name; its first 10 characters prefix the file names.
        output_dir: Directory the files are written to (created if missing).

    Returns:
        The common path prefix of the written files; the files themselves are
        ``<prefix>_cm.png``, ``<prefix>_roc.png``, ``<prefix>_f1.png`` and
        ``<prefix>_report.txt``.
    """
    # Create a safe model name by taking only the last part of the path
    safe_model_name = model_name.split('/')[-1].split('\\')[-1]  # Handles both / and \ separators
    base_path = os.path.join(output_dir, f"{dataset_name[:10]}_{safe_model_name[:30]}")

    # Ensure the directory exists
    os.makedirs(output_dir, exist_ok=True)

    # Confusion matrix. The axes are passed explicitly: without `ax`,
    # ConfusionMatrixDisplay.plot() creates its own figure, which would leave a
    # separately created figure unused and never closed.
    fig, ax = plt.subplots(figsize=(5, 5))
    disp = ConfusionMatrixDisplay(confusion_matrix=plot_data["cm"],
                                  display_labels=["fake (0)", "real (1)"])
    disp.plot(ax=ax, cmap="Blues", values_format="d", colorbar=False)
    ax.set_title(f"Confusion Matrix\n{model_name}")
    _save_and_close(fig, f"{base_path}_cm.png")

    # ROC curve with the chance diagonal for reference
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(plot_data["fpr"], plot_data["tpr"],
            label=f"AUC = {plot_data['roc_auc']:.4f}")
    ax.plot([0, 1], [0, 1], "--", color="gray", alpha=0.7)
    ax.set_xlabel("False Positive Rate")
    ax.set_ylabel("True Positive Rate")
    ax.set_title("ROC Curve")
    ax.legend(loc="lower right")
    _save_and_close(fig, f"{base_path}_roc.png")

    # F1-score vs threshold, with the default 0.5 decision threshold marked
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.plot(plot_data["thresholds"], plot_data["f1_scores"], label="F1-score", color="green")
    ax.axvline(x=0.5, color="r", linestyle="--", label="Default Threshold (0.5)")
    ax.set_xlabel("Threshold")
    ax.set_ylabel("F1-Score")
    ax.set_title("F1-Score vs Threshold")
    ax.legend()
    _save_and_close(fig, f"{base_path}_f1.png")

    # Save classification report (explicit encoding so the output does not
    # depend on the platform default)
    with open(f"{base_path}_report.txt", "w", encoding="utf-8") as f:
        f.write(plot_data["classification_report"])

    return base_path  # Return the base path for reference

# --- Main Evaluation Function ---
def evaluate_model_on_dataset(model_info, dataset, dataset_name, output_dir):
    """
    Evaluate a single model on the "test" split of a single dataset.

    Label convention: fake = 0, real = 1. The "real" score is the positive-class
    score, and a sample is predicted real when that score is >= 0.5.

    Args:
        model_info: Mapping with "model" and "feature_extractor" entries.
        dataset: Mapping of splits; only ``dataset["test"]`` is used, and it
            must expose "audio" and "label" columns.
        dataset_name: Name shown in the progress bar.
        output_dir: Not used by this function; kept so existing callers keep working.

    Returns:
        ``(metrics, plot_data)``. `metrics` holds accuracy, precision, recall,
        f1, roc_auc and average_precision. `plot_data` holds the confusion
        matrix, ROC and precision-recall curve points, the text classification
        report, and F1 scores over 101 thresholds in [0, 1].

    Raises:
        ValueError: If a prediction has no "real" label, i.e. the model's label
            names do not match the fake/real convention used here.
    """
    device_index = 0 if torch.cuda.is_available() else -1
    pipe = pipeline(
        "audio-classification",
        model=model_info["model"],
        feature_extractor=model_info["feature_extractor"],
        device=device_index,
    )

    # Prepare data: labels may be the strings "fake"/"real" or already integer-like
    label_mapping = {"fake": 0, "real": 1}
    test_ds = dataset["test"]
    references = [label_mapping[x] if x in label_mapping else int(x)
                  for x in test_ds["label"]]

    # Get predictions; top_k=None asks the pipeline for the score of every label
    scores_pos = []
    for pred in tqdm(
        pipe(KeyDataset(test_ds, "audio"), top_k=None),
        total=len(test_ds),
        desc=f"Evaluating on {dataset_name}",
    ):
        score_map = {p["label"]: p["score"] for p in pred}
        if "real" not in score_map:
            # Falling back to 0.0 here would silently label every sample as
            # fake and produce meaningless metrics.
            raise ValueError(
                f"Model output has no 'real' label (got {sorted(score_map)}); "
                "check the id2label mapping of the checkpoint."
            )
        scores_pos.append(score_map["real"])

    scores = np.asarray(scores_pos, dtype=float)
    y_pred = (scores >= 0.5).astype(int).tolist()

    # Calculate metrics
    metrics = {
        "accuracy": evaluate.load("accuracy").compute(
            predictions=y_pred, references=references
        )["accuracy"],
        "precision": precision_score(references, y_pred, zero_division=0),
        "recall": recall_score(references, y_pred, zero_division=0),
        "f1": f1_score(references, y_pred, zero_division=0),
    }

    # Calculate ROC and PR curves
    fpr, tpr, _ = roc_curve(references, scores_pos)
    prec, rec, _ = precision_recall_curve(references, scores_pos)
    metrics.update({
        "roc_auc": auc(fpr, tpr),
        "average_precision": average_precision_score(references, scores_pos)
    })

    # Prepare plot data. `labels=[0, 1]` is given to both the confusion matrix
    # and the report so they still work when only one class occurs in the data.
    plot_data = {
        "cm": confusion_matrix(references, y_pred, labels=[0, 1]),
        "fpr": fpr,
        "tpr": tpr,
        "prec": prec,
        "rec": rec,
        "roc_auc": metrics["roc_auc"],
        "classification_report": classification_report(
            references, y_pred, labels=[0, 1], target_names=["fake", "real"],
            digits=4, zero_division=0
        )
    }

    # Calculate F1-scores for different thresholds
    thresholds = np.linspace(0, 1, 101)
    f1_scores = [f1_score(references, (scores >= t).astype(int), zero_division=0)
                 for t in thresholds]

    # Update plot_data with thresholds and f1_scores
    plot_data.update({
        "thresholds": thresholds,
        "f1_scores": f1_scores
    })

    return metrics, plot_data

# --- Main Execution ---
import traceback  # used only by this driver, to report full error context

# One row per (dataset, model) pair, accumulated across the whole run.
all_results = []

for dataset_info in DATASET_PATHS:
    dataset_name = dataset_info["name"]
    dataset_path = dataset_info["path"]

    print(f"\n{'='*50}")
    print(f"Evaluating on dataset: {dataset_name}")
    print(f"{'='*50}")

    try:
        # Load dataset
        print(f"Loading dataset from {dataset_path}...")
        dataset = load_dataset_helper(dataset_path)

        # Create output directory for this dataset
        dataset_output_dir = os.path.join(OUTPUT_DIR, dataset_name)
        os.makedirs(dataset_output_dir, exist_ok=True)

        # Evaluate each model; `i` indexes both `models` and `checkpoints`
        for i, model_info in models.items():
            print(f"\nEvaluating model {i+1}/{len(models)}")

            # Run evaluation
            metrics, plot_data = evaluate_model_on_dataset(
                model_info=model_info,
                dataset=dataset,
                dataset_name=dataset_name,
                output_dir=dataset_output_dir
            )

            # Write the plots and the classification report for this pair
            base_path = save_plots(
                plot_data=plot_data,
                model_name=checkpoints[i].split("/")[-1].split("\\")[-1],  # Get just the checkpoint name
                dataset_name=dataset_name,
                output_dir=dataset_output_dir
            )

            # Store results
            result = {
                "dataset": dataset_name,
                "model": checkpoints[i],
                **metrics
            }
            all_results.append(result)

            # Save intermediate results so a later failure does not lose finished work
            pd.DataFrame(all_results).to_csv(
                os.path.join(OUTPUT_DIR, "all_results.csv"),
                index=False
            )

    except Exception as e:
        # Best effort: report the failure with its traceback and move on to
        # the next dataset.
        print(f"Error evaluating {dataset_name}: {str(e)}")
        traceback.print_exc()
        continue

# Display final results
if all_results:
    results_df = pd.DataFrame(all_results)
    print("\nEvaluation completed! Summary of results:")
    display(results_df.groupby(['dataset', 'model'])[['accuracy', 'f1', 'roc_auc']].mean().round(4))
    results_df.to_csv(os.path.join(OUTPUT_DIR, "final_results.csv"), index=False)
    print(f"\nResults saved to {os.path.join(OUTPUT_DIR, 'final_results.csv')}")
else:
    print("No results were generated. Please check for errors.")