In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
def plot_roc_curves(roc_files, model_names, save_path='x.png', title=None):
    """Overlay ROC curves from several CSV files on one figure and save it.

    Each CSV must contain ``fpr``, ``tpr`` and ``threshold`` columns. For every
    curve the legend reports the AUROC (trapezoidal area under the curve) and
    the equal error rate (EER). A marker is drawn at the ROC point closest to
    the EER line FPR = 1 - TPR.

    Parameters
    ----------
    roc_files : sequence of str or path-like
        Paths of the ROC-curve CSV files, one per model.
    model_names : sequence of str
        Legend names, aligned one-to-one with ``roc_files``.
    save_path : str or path-like or None, default 'x.png'
        Where to write the figure (600 dpi). Missing parent directories are
        created. Pass ``None`` to skip saving.
    title : str, optional
        Figure title; no title is drawn when ``None``.

    Returns
    -------
    pandas.DataFrame
        One row per model with columns ``model``, ``auroc``, ``eer`` and
        ``eer_threshold`` (the CSV threshold at the point used for the EER).

    Raises
    ------
    ValueError
        If ``roc_files`` and ``model_names`` differ in length, or a CSV has
        no rows.
    """
    roc_files = list(roc_files)
    model_names = list(model_names)
    # zip() would silently drop the unmatched tail and hide a missing curve.
    if len(roc_files) != len(model_names):
        raise ValueError(
            f"got {len(roc_files)} ROC files but {len(model_names)} model names"
        )

    # np.trapz was renamed to np.trapezoid in NumPy 2.0 (old name deprecated).
    trapezoid = getattr(np, 'trapezoid', None) or np.trapz

    colors = ['darkorange', 'blue', 'green', 'red', 'purple', 'brown']
    fig, ax = plt.subplots(figsize=(12, 9))
    metrics = []

    for idx, (file_path, model_name) in enumerate(zip(roc_files, model_names)):
        # Sort by FPR (TPR breaks ties) so the integral is positive and the
        # curve is drawn left to right regardless of the CSV's row order.
        roc = pd.read_csv(file_path).sort_values(['fpr', 'tpr'])
        if roc.empty:
            raise ValueError(f"ROC file {file_path} contains no rows")
        fpr = roc['fpr'].to_numpy()
        tpr = roc['tpr'].to_numpy()
        thresholds = roc['threshold'].to_numpy()

        auroc = trapezoid(tpr, fpr)

        # EER is where FPR equals FNR (= 1 - TPR). On a discrete curve the two
        # rarely coincide, so take the closest point and average both rates
        # rather than reporting only the FPR side.
        fnr = 1 - tpr
        eer_idx = np.argmin(np.abs(fpr - fnr))
        eer_value = (fpr[eer_idx] + fnr[eer_idx]) / 2
        eer_threshold = thresholds[eer_idx]
        metrics.append({
            'model': model_name,
            'auroc': auroc,
            'eer': eer_value,
            'eer_threshold': eer_threshold,
        })

        color = colors[idx % len(colors)]
        ax.plot(fpr, tpr, color=color, lw=2.5,
                label=f'{model_name} (AUROC={auroc:.4f}, EER={eer_value:.4f})')
        # Mark the actual ROC point the EER estimate was taken from.
        ax.plot(fpr[eer_idx], tpr[eer_idx], 'o', color=color,
                markersize=8, markeredgecolor='black', markeredgewidth=1)

    ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
            label='Random Classifier (AUROC=0.5)', alpha=0.6)
    ax.plot([0, 1], [1, 0], color='gray', linestyle=':', alpha=0.4,
            label='EER line (FPR = 1-TPR)')

    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate (FPR)', fontsize=14, fontweight='bold')
    ax.set_ylabel('True Positive Rate (TPR)', fontsize=14, fontweight='bold')
    if title is not None:
        ax.set_title(title, fontsize=16, fontweight='bold')
    ax.legend(loc="lower right", fontsize=11, framealpha=0.9)
    ax.grid(True, alpha=0.3, linestyle='--')

    fig.tight_layout()
    if save_path is not None:
        # savefig does not create directories (e.g. '../plots').
        Path(save_path).parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=600, bbox_inches='tight')
        print(f"Combined ROC curve saved to: {save_path}")
    plt.show()

    return pd.DataFrame(metrics, columns=['model', 'auroc', 'eer', 'eer_threshold'])


In [None]:
# ROC CSVs are named "<tts-prefix>_<detector>_summary2_roc_curve.csv".
# Each TTS system's file prefix is paired with its legend name here, so the
# file lists and model_names always stay aligned and in the same order.
_TTS_SYSTEMS = {
    'fishaudio': 'FishSpeech',
    'xtts': 'XTTS',
    't5': 'SpeechT5',
    'mms': 'MMS',
}

# Whisper-based detector results
whisper_files = [f'{prefix}_whisper_summary2_roc_curve.csv' for prefix in _TTS_SYSTEMS]

# Wav2vec2-based detector results
wav2vec_files = [f'{prefix}_wav2vec2_summary2_roc_curve.csv' for prefix in _TTS_SYSTEMS]

model_names = list(_TTS_SYSTEMS.values())

In [None]:
# Whisper-based detector: one ROC curve per TTS system.
plot_roc_curves(
    whisper_files,
    model_names,
    save_path='../plots/whisper_roc_600.png',
)

In [None]:
# Wav2vec2-based detector: one ROC curve per TTS system.
plot_roc_curves(
    wav2vec_files,
    model_names,
    save_path='../plots/wav2vec2_roc_600.png',
)