# MLAAD

In [4]:
import os
import sys
from typing import Dict, Optional, List, Union, Tuple
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, recall_score, precision_score
from pathlib import Path

# Constants
METADATA_PATH = "/nvme1_hungdx/Lightning-hydra/data/MLAAD/protocol.txt"
META_CSV_PATH = "/nvme1_hungdx/Lightning-hydra/data/MLAAD/meta.csv"
PREDICTION_FILE = "/home/hungdx/code/Lightning-hydra/logs/results/benchmark_kd/Distil_W2V_5_Conf-TCM_wo_norm_ws_cv_emb_c_m_stage2_op_fullset/MLAAD_distil_distil_wav2vec2_n_trans_layers_conformertcm_Distil_W2V_5_Conf-TCM_wo_norm_ws_cv_emb_c_m_stage2_op_fullset.txt"
# PREDICTION_FILE="/nvme1/hungdx/Lightning-hydra/logs/results/cnsl_benchmark/ToP_April/MLAAD_cnsl_xlsr_vib_large_corpus_ToP_April.txt"

class MetricsCalculator:
    @staticmethod
    def calculate_metrics(df: pd.DataFrame, group_column: Optional[str] = None) -> Dict[str, Union[float, Dict[str, Union[float, int]]]]:
        """Calculate various metrics for the given DataFrame.
        
        Args:
            df: DataFrame containing predictions and ground truth
            group_column: Optional column name to group results by
            
        Returns:
            Dictionary containing calculated metrics and sample counts
        """
        results = {
            'overall': accuracy_score(df["label"], df["pred"]) * 100,
            'f1': f1_score(df['label'], df['pred'], pos_label='bonafide'),
            'recall': recall_score(df['label'], df['pred'], pos_label='bonafide'),
            'precision': precision_score(df['label'], df['pred'], pos_label='bonafide'),
            'total_samples': len(df)
        }
        
        if group_column and group_column in df.columns:
            group_metrics = {}
            for group, group_df in df.groupby(group_column):
                group_metrics[group] = {
                    'accuracy': accuracy_score(group_df["label"], group_df["pred"]) * 100,
                    'samples': len(group_df)
                }
            results['groups'] = group_metrics
        
        return results

def load_metadata() -> pd.DataFrame:
    """Load and process metadata files.
    
    Returns:
        DataFrame containing merged metadata
    """
    try:
        metadata = pd.read_csv(METADATA_PATH, sep=" ", header=None)
        metadata.columns = ["path", "subset", "label"]
        
        meta_csv = pd.read_csv(META_CSV_PATH, sep="|")
        
        metadata = metadata.merge(meta_csv, on='path', how='left')
        metadata.rename(columns={
            'subset_x': 'subset',
            'label_y': 'label'
        }, inplace=True)
        
        return metadata
    except Exception as e:
        raise RuntimeError(f"Failed to load metadata: {str(e)}")

def process_prediction_file(score_file: str, metadata_df: pd.DataFrame) -> pd.DataFrame:
    """Process a single prediction file and return results DataFrame.
    
    Args:
        score_file: Path to the prediction file
        metadata_df: DataFrame containing metadata
        
    Returns:
        DataFrame containing processed predictions
    """
    try:
        pred_df = pd.read_csv(score_file, sep=" ", header=None)
        pred_df.columns = ["path", "spoof", "score"]
        pred_df = pred_df.drop_duplicates(subset=['path'])
        
        merged_df = pred_df.merge(metadata_df, on='path', how='left')
        merged_df['pred'] = merged_df.apply(
            lambda x: 'bonafide' if x['spoof'] < x['score'] else 'spoof', axis=1)
        
        return merged_df[merged_df['subset'] == 'eval'].copy()
    except Exception as e:
        raise RuntimeError(f"Failed to process prediction file {score_file}: {str(e)}")

def print_results(model_results: Dict, model_name: str, original_results: Optional[Dict] = None) -> None:
    """Print evaluation results in a formatted way.
    
    Args:
        model_results: Dictionary containing model metrics
        model_name: Name of the model
        original_results: Optional dictionary containing original model results for comparison
    """
    print(f"\n{'-'*70}")
    print(f"Model: {model_name}")
    
    print(f"\nTotal Samples: {model_results['total_samples']}")
    print(f"Overall Accuracy: {model_results['overall']:.2f}%")
    
    if 'groups' in model_results:
        print("\nAccuracy by group:")
        for group, metrics in model_results['groups'].items():
            print(f"  {group}:")
            print(f"    Accuracy: {metrics['accuracy']:.2f}%")
            print(f"    Samples: {metrics['samples']}")
    
    print("\nAdditional metrics:")
    print(f"  F1 Score: {model_results['f1']:.4f}")
    print(f"  Recall: {model_results['recall']:.4f}")
    print(f"  Precision: {model_results['precision']:.4f}")
    
    print(f"{'-'*70}")

def main() -> None:
    """Main function to run the evaluation pipeline."""
    try:
        print("Loading metadata...")
        metadata_df = load_metadata()
        
        prediction_files = [PREDICTION_FILE]
        prediction_files = sorted(prediction_files)
        
        all_results = {}
        
        for score_file in prediction_files:
            model_name = Path(score_file).name
            print(f"\nProcessing {model_name}...")
            
            results_df = process_prediction_file(score_file, metadata_df)
            metrics = MetricsCalculator.calculate_metrics(results_df, group_column='architecture')
            all_results[model_name] = metrics
        
        for model_name, metrics in all_results.items():
            print_results(metrics, model_name)
            
    except Exception as e:
        print(f"Error in main execution: {str(e)}")
        sys.exit(1)

if __name__ == "__main__":
    main()

Loading metadata...


  meta_csv = pd.read_csv(META_CSV_PATH, sep="|")



Processing MLAAD_distil_distil_wav2vec2_n_trans_layers_conformertcm_Distil_W2V_5_Conf-TCM_wo_norm_ws_cv_emb_c_m_stage2_op_fullset.txt...

----------------------------------------------------------------------
Model: MLAAD_distil_distil_wav2vec2_n_trans_layers_conformertcm_Distil_W2V_5_Conf-TCM_wo_norm_ws_cv_emb_c_m_stage2_op_fullset.txt

Total Samples: 418059
Overall Accuracy: 89.88%

Accuracy by group:
  -:
    Accuracy: 95.16%
    Samples: 243059
  FishTTS:
    Accuracy: 97.50%
    Samples: 3000
  Mars5:
    Accuracy: 82.50%
    Samples: 1000
  MatchaTTS:
    Accuracy: 99.20%
    Samples: 1000
  MegaTTS3:
    Accuracy: 66.60%
    Samples: 2000
  MeloTTS:
    Accuracy: 98.80%
    Samples: 1000
  Metavoice-1B:
    Accuracy: 63.60%
    Samples: 1000
  Nari Dia-1.6B:
    Accuracy: 57.80%
    Samples: 1000
  OpenVoiceV2:
    Accuracy: 99.62%
    Samples: 4000
  Resemble.ai (April 12th, 2025):
    Accuracy: 97.32%
    Samples: 5000
  Spark-TTS-0.5B:
    Accuracy: 69.20%
    Samples: 1000
