# Mount Drive

In [None]:
# Mount Google Drive at /content/drive; `outputdir` in the configuration cell
# points to a folder inside this mount.
from google.colab import drive
drive.mount('/content/drive')

# Installations

In [None]:
%%capture
!pip install datasets==3.6.0
!pip install torchmetrics
!pip install accelerate
!pip install transformers

In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import WhisperModel, WhisperFeatureExtractor
from datasets import load_dataset
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, classification_report, precision_score, recall_score
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall, BinaryF1Score, BinaryEER
from dataclasses import dataclass
from typing import Dict, List, Union
import csv
from typing import Dict, List, Optional, Union, Any
from tqdm.notebook import tqdm

In [None]:
# Interactive Hugging Face Hub login.
# NOTE(review): presumably required because the dataset loaded below is
# private or gated — confirm.
from huggingface_hub import notebook_login
notebook_login()

In [None]:
# Configuration: checkpoint id, optimisation hyper-parameters, device, output folder
MODEL_NAME = "openai/whisper-base"      # passed to WhisperModel / WhisperFeatureExtractor
BATCH_SIZE = 8                          # used by dataset.map and by every DataLoader
LEARNING_RATE = 1e-4                    # AdamW learning rate in train_model
NUM_EPOCHS = 5                          # default epoch count for train_model

# Use the GPU when one is visible, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
DEVICE = torch.device(_device_name)

# Folder (on the mounted Drive) where result summaries are written.
outputdir = "/content/drive/MyDrive/Colab Notebooks/fakevoices"

# Dataset and Data Collator

In [None]:
def prepare_dataset(dataset, feature_extractor, split="train"):
    """Convert a raw audio dataset into model-ready features and integer labels.

    Args:
        dataset: object exposing ``map`` and ``column_names`` (a Hugging Face
            ``Dataset`` split) whose rows carry an ``"audio"`` dict with
            ``"array"`` / ``"sampling_rate"`` and a string ``"label"``.
        feature_extractor: callable invoked once per clip as
            ``feature_extractor(audio, sampling_rate=..., return_tensors="pt")``;
            its ``input_features`` attribute is squeezed on dim 0.
        split: unused; kept so existing callers keep working.

    Returns:
        The mapped dataset containing only ``"input_features"`` and
        ``"labels"`` (1 for ``"real"``, 0 for every other label value).
    """

    def preprocess_function(batch):
        input_features_list = []
        # Each clip is extracted with ITS OWN sampling rate. Previously the
        # rate of the first clip in the batch was applied to all of them,
        # which silently mis-declared the rate of any clip that differed.
        for item in batch["audio"]:
            inputs = feature_extractor(
                item["array"],
                sampling_rate=item["sampling_rate"],
                return_tensors="pt"
            )
            # Drop the leading batch dimension added by the extractor.
            input_features_list.append(inputs.input_features.squeeze(0))

        # Binary target: "real" -> 1, anything else -> 0.
        labels = [1 if lbl == "real" else 0 for lbl in batch["label"]]

        return {
            "input_features": input_features_list,
            "labels": labels
        }

    processed_dataset = dataset.map(
        preprocess_function,
        batched=True,
        batch_size=BATCH_SIZE,
        remove_columns=dataset.column_names  # keep only the two new columns
    )

    return processed_dataset


@dataclass
class WhisperDataCollator:
    """
    Collate preprocessed examples into one batch of tensors.

    Every example must provide ``"input_features"`` (a tensor or a nested
    list) and ``"labels"``. Features are stacked along a new leading
    dimension, so they must all share one shape; no padding is applied here.
    """

    # Stored on the instance; not used by __call__.
    feature_extractor: WhisperFeatureExtractor

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        feature_tensors = []
        targets = []
        for example in features:
            feat = example["input_features"]
            # Nested lists (as stored by the dataset) are converted to tensors first.
            feature_tensors.append(feat if isinstance(feat, torch.Tensor) else torch.tensor(feat))
            targets.append(example["labels"])

        return {
            "input_features": torch.stack(feature_tensors),
            "labels": torch.tensor(targets, dtype=torch.long),
        }

# Whisper Model

In [None]:
class WhisperAudioClassifier(nn.Module):
    """
    Whisper encoder followed by a small MLP head that emits two logits
    (used for fake audio detection).

    Args:
        model_name: checkpoint passed to ``WhisperModel.from_pretrained``.
        freeze_encoder: when True, every encoder parameter is frozen.
        freeze_feature_extractor: when True, only the encoder's ``conv1`` /
            ``conv2`` modules (if present) are frozen.
        token: forwarded to ``from_pretrained``.
    """

    def __init__(self, model_name=MODEL_NAME, freeze_encoder=True, freeze_feature_extractor=False, token=None):
        super().__init__()

        self.whisper = WhisperModel.from_pretrained(model_name, token=token)
        encoder = self.whisper.encoder

        if freeze_feature_extractor:
            # Freeze the convolutional front end only.
            for conv_name in ('conv1', 'conv2'):
                if hasattr(encoder, conv_name):
                    for weight in getattr(encoder, conv_name).parameters():
                        weight.requires_grad = False

        if freeze_encoder:
            # Freeze the whole encoder (this also covers conv1/conv2).
            for weight in encoder.parameters():
                weight.requires_grad = False

        width = self.whisper.config.d_model

        # hidden -> 256 -> 128 -> 2, with ReLU + dropout between the linear layers
        head_layers = [
            nn.Linear(width, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 2),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, input_features):
        """Encode ``input_features``, mean-pool over dim 1 and return the head's logits."""
        frames = self.whisper.encoder(input_features, return_dict=True).last_hidden_state
        pooled = frames.mean(dim=1)
        return self.classifier(pooled)

# Train

In [None]:
def _binary_epoch_metrics(preds, labels, probs):
    """Return ``(accuracy, f1, eer)`` for one epoch.

    Args:
        preds: 1-D tensor of hard predictions (argmax of the logits).
        labels: 1-D tensor of ground-truth class indices.
        probs: 1-D tensor with the softmax probability of class 1.
    """
    preds = preds.to(DEVICE)
    labels = labels.to(DEVICE)
    probs = probs.to(DEVICE)

    acc = BinaryAccuracy().to(DEVICE)
    f1 = BinaryF1Score().to(DEVICE)
    # The keyword is `thresholds` (plural); the validation metric used to be
    # built with `threshold=None`, which is not a BinaryEER argument.
    eer = BinaryEER(thresholds=None).to(DEVICE)

    acc.update(preds, labels)
    f1.update(preds, labels)
    eer.update(probs, labels)

    return acc.compute().item(), f1.compute().item(), eer.compute().item()


def train_model(model, train_loader, val_loader, num_epochs=NUM_EPOCHS,
                checkpoint_path="best_fake_audio_detector.pt"):
    """Train ``model`` with cross-entropy and keep the checkpoint with the lowest validation EER.

    Args:
        model: module mapping ``batch["input_features"]`` to 2-class logits.
        train_loader: iterable of batches with ``"input_features"`` / ``"labels"``.
        val_loader: same batch format, evaluated after every epoch.
        num_epochs: number of passes over ``train_loader``.
        checkpoint_path: where the best ``state_dict`` is written; the default
            keeps the historical file name in the working directory.

    Returns:
        The model in its state after the LAST epoch. The best-EER weights are
        only available from ``checkpoint_path``.
    """

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

    best_val_eer = float('inf')  # Lower EER is better

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        train_loss = 0.0
        train_preds, train_labels, train_probs = [], [], []

        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
        for batch in pbar:
            input_features = batch["input_features"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            optimizer.zero_grad()

            logits = model(input_features)
            loss = criterion(logits, labels)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Detach before storing so the bookkeeping tensors do not keep a
            # reference to the autograd graph for the whole epoch.
            detached_logits = logits.detach()
            train_probs.append(torch.softmax(detached_logits, dim=1)[:, 1].cpu())
            train_preds.append(torch.argmax(detached_logits, dim=1).cpu())
            train_labels.append(labels.cpu())

            pbar.set_postfix({"loss": f"{loss.item():.4f}"})

        train_acc_val, train_f1_val, train_eer_val = _binary_epoch_metrics(
            torch.cat(train_preds), torch.cat(train_labels), torch.cat(train_probs)
        )
        avg_train_loss = train_loss / len(train_loader)

        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        val_preds, val_labels, val_probs = [], [], []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]"):
                input_features = batch["input_features"].to(DEVICE)
                labels = batch["labels"].to(DEVICE)

                logits = model(input_features)
                loss = criterion(logits, labels)

                val_loss += loss.item()

                val_probs.append(torch.softmax(logits, dim=1)[:, 1].cpu())
                val_preds.append(torch.argmax(logits, dim=1).cpu())
                val_labels.append(labels.cpu())

        # metrics
        val_acc_val, val_f1_val, val_eer_val = _binary_epoch_metrics(
            torch.cat(val_preds), torch.cat(val_labels), torch.cat(val_probs)
        )
        avg_val_loss = val_loss / len(val_loader)

        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print(f"Train - Loss: {avg_train_loss:.4f}, Acc: {train_acc_val:.4f}, F1: {train_f1_val:.4f}, EER: {train_eer_val:.4f}")
        print(f"Val   - Loss: {avg_val_loss:.4f}, Acc: {val_acc_val:.4f}, F1: {val_f1_val:.4f}, EER: {val_eer_val:.4f}")

        # save best model based on validation EER (lower is better)
        if val_eer_val < best_val_eer:
            best_val_eer = val_eer_val
            torch.save(model.state_dict(), checkpoint_path)
            print(f"✓ Saved best model with validation EER: {val_eer_val:.4f}")

    print(f"\nTraining complete! Best validation EER: {best_val_eer:.4f}")
    return model

# Load data and initialize the model

In [None]:
print(f"Using device: {DEVICE}")

# The feature extractor is used by prepare_dataset; the collator batches the
# preprocessed rows for every DataLoader below.
feature_extractor = WhisperFeatureExtractor.from_pretrained(MODEL_NAME)
data_collator = WhisperDataCollator(feature_extractor=feature_extractor)

# NOTE(review): the dataset id is an empty string here, so this call cannot
# succeed as written — fill in the Hugging Face repo id before running.
# Later cells read the splits "train", "validation", "test", "fishaudio",
# "xtts", "mms" and "speecht5" from the returned object.
dataset = load_dataset("", trust_remote_code=True) # HF dataset

In [None]:
train_dataset = prepare_dataset(dataset["train"], feature_extractor)
val_dataset = prepare_dataset(dataset["validation"], feature_extractor)
# The held-out set must come from the "test" split; it used to be built from
# "train", so the reported test metrics were computed on training data.
# (The variable spelling is kept because the DataLoader cell refers to it.)
test_datset = prepare_dataset(dataset["test"], feature_extractor)

In [None]:
# All loaders share the batch size and collator; only the training loader shuffles.
_loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=data_collator)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, **_loader_kwargs)
test_loader = DataLoader(test_datset, **_loader_kwargs)

In [None]:
# Initialize model (default arguments: frozen encoder, trainable head)
print("\nInitializing model...")
model = WhisperAudioClassifier().to(DEVICE)

# Only parameters that still require gradients are counted.
trainable_param_count = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
print(f"Model parameters: {trainable_param_count:,}")

# Run training

In [None]:
# Train for NUM_EPOCHS; train_model prints per-epoch metrics and saves the
# best-EER checkpoint as a side effect.
model = train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=NUM_EPOCHS,
)

# Evaluation

In [None]:
def evaluate_model(model, test_loader, file_name):
    """Evaluate ``model`` on ``test_loader``, print the metrics and save a summary.

    Class index 1 is the positive class; ``prepare_dataset`` encodes
    ``"real"`` as 1 and everything else as 0.

    Args:
        model: module mapping ``batch["input_features"]`` to 2-class logits.
        test_loader: iterable of batches with ``"input_features"`` / ``"labels"``.
        file_name: summary file name, created inside ``outputdir``.

    Returns:
        dict with ``predictions``, ``labels``, ``probabilities`` (per sample)
        and the scalars ``accuracy``, ``precision``, ``recall``, ``f1``, ``eer``.
    """

    binary_acc = BinaryAccuracy().to(DEVICE)
    binary_precision = BinaryPrecision().to(DEVICE)
    binary_recall = BinaryRecall().to(DEVICE)
    binary_f1 = BinaryF1Score().to(DEVICE)
    # The keyword is `thresholds` (plural), matching the training-metric call.
    binary_eer = BinaryEER(thresholds=None).to(DEVICE)

    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            input_features = batch["input_features"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            logits = model(input_features)
            probs = torch.softmax(logits, dim=1)[:, 1]  # probability of class 1
            preds = torch.argmax(logits, dim=1)

            binary_acc.update(preds, labels)
            binary_precision.update(preds, labels)
            binary_recall.update(preds, labels)
            binary_f1.update(preds, labels)
            binary_eer.update(probs, labels)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # metrics
    accuracy = binary_acc.compute().item()
    precision = binary_precision.compute().item()
    recall = binary_recall.compute().item()
    f1 = binary_f1.compute().item()
    eer = binary_eer.compute().item()

    # print
    print("\n" + "="*60)
    print("Test Set Results - Comprehensive Metrics")
    print("="*60)
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")
    print(f"EER:       {eer:.4f}")
    print("="*60)

    # detailed classification report
    # Names follow the label encoding (0 = fake, 1 = real); they used to be
    # swapped. `labels=[0, 1]` keeps the report working when a loader holds
    # only one class.
    print("\nDetailed Classification Report:")
    print(classification_report(all_labels, all_preds,
                                labels=[0, 1],
                                target_names=["Fake", "Real"],
                                digits=4,
                                zero_division=0))

    res = {
        'predictions': all_preds,
        'labels': all_labels,
        'probabilities': all_probs,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'eer': eer
    }
    save_results_summary(res, os.path.join(outputdir, file_name))
    return res


def save_results_summary(results, filepath):
    """Write a compact plain-text report of ``results`` to ``filepath``.

    ``results`` must provide the scalars ``accuracy``, ``precision``,
    ``recall``, ``f1`` and ``eer`` plus the per-sample sequences
    ``predictions`` and ``labels``. Label 1 is reported as "Real" and
    label 0 as "Fake" in the confusion matrix.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60
    metric_rows = (
        ('Accuracy', 'accuracy'),
        ('Precision', 'precision'),
        ('Recall', 'recall'),
        ('F1-Score', 'f1'),
        ('EER', 'eer'),
    )

    with open(filepath, 'w') as out:
        out.write(heavy_rule + "\n")
        out.write("Test Results Summary\n")
        out.write(heavy_rule + "\n\n")

        # Metrics table
        out.write(f"{'Metric':<15} {'Value':>10}\n")
        out.write(light_rule + "\n")
        for title, key in metric_rows:
            out.write(f"{title:<15} {results[key]:>10.4f}\n")
        out.write("\n")

        # confusion matrix, tallied in one pass over the (prediction, label) pairs
        predicted = results['predictions']
        actual = results['labels']

        tp = tn = fp = fn = 0
        for p, l in zip(predicted, actual):
            if p == 1 and l == 1:
                tp += 1  # True Positive
            elif p == 0 and l == 0:
                tn += 1  # True Negative
            elif p == 1 and l == 0:
                fp += 1  # False Positive
            elif p == 0 and l == 1:
                fn += 1  # False Negative

        out.write("Confusion Matrix:\n")
        out.write(light_rule + "\n")
        out.write("                 Predicted Fake    Predicted Real\n")
        out.write(f"Actual Fake      {tn:>14d}    {fp:>14d}\n")
        out.write(f"Actual Real      {fn:>14d}    {tp:>14d}\n")
        out.write("\n")
        out.write(f"Total Samples: {len(predicted)}\n")
        out.write(heavy_rule + "\n")

# Evaluation: AFAD test split

In [None]:
print("\nStarting eval...")
# Load best model.
# train_model writes its best checkpoint to "best_fake_audio_detector.pt" in the
# working directory, while this cell used to read only "WHISPER.pt" from Drive,
# a file nothing in the notebook creates. Prefer the Drive copy when it exists
# (previous behaviour) and otherwise fall back to the file training just wrote.
drive_checkpoint = os.path.join(outputdir, "WHISPER.pt")
local_checkpoint = "best_fake_audio_detector.pt"
checkpoint_path = drive_checkpoint if os.path.exists(drive_checkpoint) else local_checkpoint
print(f"Loading checkpoint: {checkpoint_path}")
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only runtime.
model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

# Assigned so the full results dict (every prediction/probability) is not
# echoed as cell output.
afad_test_results = evaluate_model(model, test_loader, "WHISPER_test_summary.txt")

# Evaluation: per proprietary TTS model

In [None]:
def _tts_test_subset(tts_name):
    """Test rows whose 'tts' field is `tts_name` or 'none', limited to labels 'real'/'fake'."""
    return dataset['test'].filter(
        lambda example: example['tts'] in [tts_name, 'none'] and example['label'] in ['real', 'fake']
    )


# One subset per TTS system.
# NOTE(review): rows tagged 'none' are presumably the bona fide recordings — confirm.
test_tts_eleven = _tts_test_subset('eleven_multilingual_v2')
test_tts_openai = _tts_test_subset('gpt-4o-mini-tts')
test_tts_minimax = _tts_test_subset('speech-2.5-hd-preview')
# NOTE(review): 'reseamble-AI' is kept verbatim; verify it matches the dataset's spelling.
test_tts_resemble = _tts_test_subset('reseamble-AI')

In [None]:
# prepare datasets: the same preprocessing is applied to each per-TTS subset
test_tts_dataset_e, test_tts_dataset_o, test_tts_dataset_m, test_tts_dataset_r = (
    prepare_dataset(subset, feature_extractor)
    for subset in (test_tts_eleven, test_tts_openai, test_tts_minimax, test_tts_resemble)
)

In [None]:
# Dataloaders: one un-shuffled loader per prepared TTS subset
eleven_loader, openai_loader, minimax_loader, resemble_loader = (
    DataLoader(subset, batch_size=BATCH_SIZE, collate_fn=data_collator)
    for subset in (test_tts_dataset_e, test_tts_dataset_o, test_tts_dataset_m, test_tts_dataset_r)
)

In [None]:
# Save Results
# Results are collected in a dict keyed by summary file name. Previously the
# last call's return value (a dict holding every prediction and probability)
# was echoed as the cell output.
tts_results = {
    summary_name: evaluate_model(model, loader, summary_name)
    for loader, summary_name in [
        (eleven_loader, "elevenlabs_whisper_summary.txt"),
        (openai_loader, "openai_whisper_summary.txt"),
        (minimax_loader, "minimax_whisper_summary.txt"),
        (resemble_loader, "resemble_whisper_summary.txt"),
    ]
}

# Evaluation: the ad-hoc datasets

In [None]:
# Pull the four ad-hoc splits out of the dataset by name.
fishaudio, xtts, mms, t5 = (
    dataset[split_name]
    for split_name in ('fishaudio', 'xtts', 'mms', 'speecht5')
)

In [None]:
# prepare datasets: same preprocessing for every ad-hoc split
prepare_fishaudio, prepare_xtts, prepare_mms, prepare_t5 = (
    prepare_dataset(split_data, feature_extractor)
    for split_data in (fishaudio, xtts, mms, t5)
)

In [None]:
# Dataloaders: one un-shuffled loader per prepared ad-hoc split
fishaudio_loader, xtts_loader, mms_loader, t5_loader = (
    DataLoader(prepared, batch_size=BATCH_SIZE, collate_fn=data_collator)
    for prepared in (prepare_fishaudio, prepare_xtts, prepare_mms, prepare_t5)
)

In [None]:
# Save Results
# Results are collected in a dict keyed by summary file name instead of
# echoing the last call's full results dict as cell output.
adhoc_results = {
    summary_name: evaluate_model(model, loader, summary_name)
    for loader, summary_name in [
        (fishaudio_loader, "fishaudio_whisper_summary.txt"),
        (xtts_loader, "xtts_whisper_summary.txt"),
        (mms_loader, "mms_whisper_summary.txt"),
        (t5_loader, "t5_whisper_summary.txt"),
    ]
}

# ROC/EER

In [None]:
from torchmetrics.classification import BinaryROC, BinaryAUROC
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
import numpy as np
import pandas as pd

def evaluate_model2(model, test_loader, file_name):
    """Evaluate ``model`` and additionally report ROC/AUROC and metrics at the EER threshold.

    Class index 1 is the positive class; ``prepare_dataset`` encodes
    ``"real"`` as 1 and everything else as 0.

    Args:
        model: module mapping ``batch["input_features"]`` to 2-class logits.
        test_loader: iterable of batches with ``"input_features"`` / ``"labels"``.
        file_name: summary file name, created inside ``outputdir``.

    Returns:
        dict with per-sample ``predictions``, ``predictions_at_eer``,
        ``labels`` and ``probabilities``, the scalar metrics at the default
        and EER thresholds, ``auroc``, and the ROC curve as the lists
        ``fpr`` / ``tpr`` / ``thresholds``.
    """

    binary_acc = BinaryAccuracy().to(DEVICE)
    binary_precision = BinaryPrecision().to(DEVICE)
    binary_recall = BinaryRecall().to(DEVICE)
    binary_f1 = BinaryF1Score().to(DEVICE)
    binary_eer = BinaryEER().to(DEVICE)
    binary_roc = BinaryROC().to(DEVICE)
    binary_auroc = BinaryAUROC().to(DEVICE)

    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Evaluating"):
            input_features = batch["input_features"].to(DEVICE)
            labels = batch["labels"].to(DEVICE)

            logits = model(input_features)
            probs = torch.softmax(logits, dim=1)[:, 1]  # probability of class 1
            preds = torch.argmax(logits, dim=1)

            binary_acc.update(preds, labels)
            binary_precision.update(preds, labels)
            binary_recall.update(preds, labels)
            binary_f1.update(preds, labels)
            binary_eer.update(probs, labels)
            binary_roc.update(probs, labels)
            binary_auroc.update(probs, labels)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # metrics
    accuracy = binary_acc.compute().item()
    precision = binary_precision.compute().item()
    recall = binary_recall.compute().item()
    f1 = binary_f1.compute().item()
    eer = binary_eer.compute().item()
    fpr, tpr, thresholds = binary_roc.compute()
    auroc = binary_auroc.compute().item()

    # Calculate predictions at EER threshold: the ROC point where the false
    # positive rate is closest to the false negative rate (1 - TPR).
    eer_threshold = thresholds[torch.argmin(torch.abs(fpr - (1 - tpr)))].item()
    preds_at_eer = (np.array(all_probs) >= eer_threshold).astype(int)

    # Metrics at EER threshold
    acc_at_eer = accuracy_score(all_labels, preds_at_eer)
    prec_at_eer = precision_score(all_labels, preds_at_eer, zero_division=0)
    rec_at_eer = recall_score(all_labels, preds_at_eer, zero_division=0)
    f1_at_eer = f1_score(all_labels, preds_at_eer, zero_division=0)

    # print
    print("\n" + "="*60)
    print("Test Set Results - Comprehensive Metrics")
    print("="*60)
    print(f"Decision Threshold: 0.5 (default argmax)")
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1 Score:  {f1:.4f}")
    print("-"*60)
    print(f"EER Threshold: {eer_threshold:.4f}")
    print(f"EER:       {eer:.4f}")
    print(f"Accuracy @ EER:  {acc_at_eer:.4f}")
    print(f"Precision @ EER: {prec_at_eer:.4f}")
    print(f"Recall @ EER:    {rec_at_eer:.4f}")
    print(f"F1 Score @ EER:  {f1_at_eer:.4f}")
    print("-"*60)
    print(f"AUROC:     {auroc:.4f}")
    print("="*60)

    # detailed classification report
    # Names follow the label encoding (0 = fake, 1 = real); they used to be
    # swapped. `labels=[0, 1]` keeps the report working when only one class
    # is present in the labels or predictions.
    print("\nDetailed Classification Report (at 0.5 threshold):")
    print(classification_report(all_labels, all_preds,
                                labels=[0, 1],
                                target_names=["Fake", "Real"],
                                digits=4,
                                zero_division=0))

    print("\nDetailed Classification Report (at EER threshold):")
    print(classification_report(all_labels, preds_at_eer,
                                labels=[0, 1],
                                target_names=["Fake", "Real"],
                                digits=4,
                                zero_division=0))

    res = {
        'predictions': all_preds,
        'predictions_at_eer': preds_at_eer.tolist(),
        'labels': all_labels,
        'probabilities': all_probs,
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'eer': eer,
        'eer_threshold': eer_threshold,
        'accuracy_at_eer': acc_at_eer,
        'precision_at_eer': prec_at_eer,
        'recall_at_eer': rec_at_eer,
        'f1_at_eer': f1_at_eer,
        'auroc': auroc,
        'fpr': fpr.cpu().numpy().tolist(),
        'tpr': tpr.cpu().numpy().tolist(),
        'thresholds': thresholds.cpu().numpy().tolist()
    }
    save_results_summary(res, os.path.join(outputdir, file_name))
    return res


def _confusion_counts(preds, labels):
    """Return ``(tp, tn, fp, fn)`` treating label 1 as the positive class."""
    tp = tn = fp = fn = 0
    for p, l in zip(preds, labels):
        if p == 1 and l == 1:
            tp += 1
        elif p == 0 and l == 0:
            tn += 1
        elif p == 1 and l == 0:
            fp += 1
        elif p == 0 and l == 1:
            fn += 1
    return tp, tn, fp, fn


def _write_confusion_matrix(f, title, preds, labels):
    """Write one titled 2x2 confusion matrix (rows = actual, columns = predicted)."""
    tp, tn, fp, fn = _confusion_counts(preds, labels)
    f.write(title)
    f.write("-"*60 + "\n")
    f.write(f"{'':20} Predicted Fake    Predicted Real\n")
    f.write(f"{'Actual Fake':<20} {tn:>14d}    {fp:>14d}\n")
    f.write(f"{'Actual Real':<20} {fn:>14d}    {tp:>14d}\n")
    f.write("\n")


def save_results_summary(results, filepath):
    """Save a compact summary of results.

    Works with both result layouts produced in this notebook:

    * the extended dict (EER-threshold metrics, ``auroc`` and the ROC lists
      ``fpr`` / ``tpr`` / ``thresholds``), for which the full report and a
      ``*_roc_curve.csv`` file next to ``filepath`` are written;
    * the basic dict (``accuracy``, ``precision``, ``recall``, ``f1``,
      ``eer``, ``predictions``, ``labels``), for which the sections needing
      the extra keys are skipped instead of raising ``KeyError``. This
      matters because this definition replaces the earlier function of the
      same name once the cell has run.

    Args:
        results: results dictionary as described above.
        filepath: path of the text summary to create or overwrite.
    """
    has_eer_section = 'eer_threshold' in results
    has_roc = all(key in results for key in ('fpr', 'tpr', 'thresholds', 'auroc'))

    with open(filepath, 'w') as f:
        f.write("="*60 + "\n")
        f.write("Test Results Summary\n")
        f.write("="*60 + "\n\n")

        # metrics table - Default threshold
        f.write("Metrics at Default Threshold (0.5):\n")
        f.write(f"{'Metric':<15} {'Value':>10}\n")
        f.write("-"*60 + "\n")
        f.write(f"{'Accuracy':<15} {results['accuracy']:>10.4f}\n")
        f.write(f"{'Precision':<15} {results['precision']:>10.4f}\n")
        f.write(f"{'Recall':<15} {results['recall']:>10.4f}\n")
        f.write(f"{'F1-Score':<15} {results['f1']:>10.4f}\n")
        f.write("\n")

        if has_eer_section:
            # metrics at EER threshold
            f.write(f"Metrics at EER Threshold ({results['eer_threshold']:.4f}):\n")
            f.write(f"{'Metric':<15} {'Value':>10}\n")
            f.write("-"*60 + "\n")
            f.write(f"{'EER':<15} {results['eer']:>10.4f}\n")
            f.write(f"{'Accuracy':<15} {results['accuracy_at_eer']:>10.4f}\n")
            f.write(f"{'Precision':<15} {results['precision_at_eer']:>10.4f}\n")
            f.write(f"{'Recall':<15} {results['recall_at_eer']:>10.4f}\n")
            f.write(f"{'F1-Score':<15} {results['f1_at_eer']:>10.4f}\n")
            f.write("\n")
        elif 'eer' in results:
            # Basic layout: only the EER value itself is available.
            f.write(f"{'EER':<15} {results['eer']:>10.4f}\n")
            f.write("\n")

        # AUROC
        if 'auroc' in results:
            f.write(f"{'AUROC':<15} {results['auroc']:>10.4f}\n")
            f.write("\n")

        # confusion matrix at default threshold
        preds = results['predictions']
        labels = results['labels']
        _write_confusion_matrix(f, "Confusion Matrix (at 0.5 threshold):\n", preds, labels)

        # confusion matrix at EER threshold
        if has_eer_section and 'predictions_at_eer' in results:
            _write_confusion_matrix(
                f,
                f"Confusion Matrix (at EER threshold {results['eer_threshold']:.4f}):\n",
                results['predictions_at_eer'],
                labels,
            )

        f.write(f"Total Samples: {len(preds)}\n")

        if not has_roc:
            f.write("="*60 + "\n")
        else:
            f.write("="*60 + "\n\n")

            # ROC Curve Data Summary
            f.write("ROC Curve Summary:\n")
            f.write("-"*60 + "\n")
            fpr_vals = results['fpr']
            tpr_vals = results['tpr']
            thresh_vals = results['thresholds']

            f.write(f"Total points on ROC curve: {len(fpr_vals)}\n")
            f.write(f"AUROC: {results['auroc']:.4f}\n\n")

            # Sample key points from ROC curve
            f.write("Key ROC Points:\n")
            f.write(f"{'Threshold':<12} {'FPR':<12} {'TPR':<12}\n")
            f.write("-"*60 + "\n")

            # Show a sample of points (first, quartiles, last)
            indices = [0, len(thresh_vals)//4, len(thresh_vals)//2,
                       3*len(thresh_vals)//4, len(thresh_vals)-1]

            for idx in indices:
                f.write(f"{thresh_vals[idx]:<12.4f} {fpr_vals[idx]:<12.4f} {tpr_vals[idx]:<12.4f}\n")

            # Add EER point: where FPR is closest to FNR (1 - TPR)
            eer_idx = np.argmin(np.abs(np.array(fpr_vals) - (1 - np.array(tpr_vals))))
            f.write(f"\nEER Point:\n")
            f.write(f"{thresh_vals[eer_idx]:<12.4f} {fpr_vals[eer_idx]:<12.4f} {tpr_vals[eer_idx]:<12.4f}\n")

            f.write("\n")
            f.write("Note: Full ROC curve data (FPR, TPR, thresholds) saved in results dictionary\n")
            f.write("="*60 + "\n")

    if has_roc:
        # Save full ROC curve to CSV, after the summary file has been closed.
        # The CSV name is derived from the path stem: with str.replace a
        # filepath lacking '.txt' made the CSV overwrite the summary itself.
        roc_df = pd.DataFrame({
            'threshold': results['thresholds'],
            'fpr': results['fpr'],
            'tpr': results['tpr']
        })
        stem, _ext = os.path.splitext(filepath)
        roc_csv_path = stem + '_roc_curve.csv'
        roc_df.to_csv(roc_csv_path, index=False)
        print(f"ROC curve data saved to: {roc_csv_path}")

In [None]:
# ROC/EER evaluation of the ad-hoc splits. Results are collected in a dict
# keyed by summary file name instead of echoing the last call's full results
# dict (per-sample outputs plus the whole ROC curve) as cell output.
adhoc_roc_results = {
    summary_name: evaluate_model2(model, loader, summary_name)
    for loader, summary_name in [
        (fishaudio_loader, "fishaudio_whisper_summary2.txt"),
        (xtts_loader, "xtts_whisper_summary2.txt"),
        (mms_loader, "mms_whisper_summary2.txt"),
        (t5_loader, "t5_whisper_summary2.txt"),
    ]
}