In [5]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import pairwise_distances
from nltk.metrics.distance import edit_distance

# Load evaluation data
evaluation_data = pd.read_csv('results.csv')
evaluation_data = evaluation_data[evaluation_data['TestType'] == 'intelligibility']

# Load reference text
with open('../datasets/test_sentences.txt', 'r') as f:
    reference_texts = f.readlines()
reference_texts = [text.strip() for text in reference_texts]  # Remove newlines

print(len(evaluation_data['AudioID'].unique()))
print(len(reference_texts))

# Ensure the number of audio samples matches the number of reference texts
#assert len(evaluation_data['AudioID'].unique()) == len(reference_texts), "Mismatch between audio samples and reference texts"

# Map AudioID to reference text
audio_to_reference = {i+1: text for i, text in enumerate(reference_texts)}

# Function to calculate Word Error Rate (WER)
def calculate_wer(reference, hypothesis):
    ref_words = reference.split()
    hyp_words = hypothesis.split()
    # Calculate edit distance at word level
    distance = edit_distance(ref_words, hyp_words)
    return distance / max(len(ref_words), 1)  # Avoid division by zero

# Calculate WER for each participant's transcription
evaluation_data['Reference'] = evaluation_data['AudioID'].map(audio_to_reference)
evaluation_data['WER'] = evaluation_data.apply(
    lambda row: calculate_wer(row['Reference'], row['Response']), axis=1
)

# Group by ModelID to calculate average WER
model_wer = evaluation_data.groupby('ModelID')['WER'].mean().reset_index()

# Plot 1: Bar plot for average WER by model
plt.figure(figsize=(10, 6))
sns.barplot(x='ModelID', y='WER', data=model_wer)
plt.title('Average Word Error Rate (WER) by Model')
plt.ylabel('WER (Lower is Better)')
plt.show()

# Plot 2: Box plot for WER distribution by model
plt.figure(figsize=(10, 6))
sns.boxplot(x='ModelID', y='WER', data=evaluation_data)
plt.title('Distribution of Word Error Rate (WER) by Model')
plt.ylabel('WER (Lower is Better)')
plt.show()

20
100


AttributeError: 'float' object has no attribute 'split'