# SSML Evaluation Analysis

This notebook analyzes the SSML prediction results saved as JSON files from the model.py script.

It includes:
1. Visualizations of tag usage metrics
2. Error metrics analysis (MAE/MSE for prosody parameters)
3. Break precision/recall/F1 analysis
4. Model comparisons

In [None]:
import os

# set cuda visible devices to 2
os.environ["CUDA_VISIBLE_DEVICES"] = "2"

import json
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, List, Any, Optional
import glob
from tqdm.notebook import tqdm

In [None]:
# Folder for saved figures; the comparison cells further down build their
# save paths from it, so it is created up front (no error if it already exists).
viz_dir = "visualizations"
os.makedirs(viz_dir, exist_ok=True)

## Load Results Files

In [None]:
def load_results(results_dir="results", model_name=None):
    """Load the JSON result files found directly inside ``results_dir``.

    Args:
        results_dir: Directory that holds the ``*.json`` result files.
        model_name: When given, only files named ``{model_name}_*.json`` are
            loaded (e.g. ``mistral_zero_shot.json``). The name is matched
            literally, so glob metacharacters such as ``[`` or ``*`` in a
            model name cannot widen or break the match.

    Returns:
        Dict mapping each file's base name to its parsed JSON content, in
        sorted path order. Files that cannot be read or parsed are reported
        on stdout and skipped.
    """
    # Escape the caller-supplied parts so that only our own "*" is a wildcard.
    if model_name:
        pattern = f"{glob.escape(model_name)}_*.json"
    else:
        pattern = "*.json"
    search_path = os.path.join(glob.escape(results_dir), pattern)

    results = {}
    # Sorted so the load order (and therefore dict/plot order) is reproducible.
    for file_path in sorted(glob.glob(search_path)):
        file_name = os.path.basename(file_path)
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                results[file_name] = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers both malformed JSON and undecodable bytes.
            print(f"Error loading {file_path}: {e}")
            continue
        print(f"Loaded {file_name}")

    return results

In [None]:
# Load results for a specific model
# `results` maps each loaded file name to its parsed JSON and is read by the
# "Analyze Specific Results" cells below.
# NOTE(review): with model_name set, load_results only picks up files named
# "<model_name>_*.json", so the later `"experiment_results.json" in results`
# checks cannot match - confirm whether an unfiltered load was intended.
model_name = "mistral"  # Change to the model you want to analyze
results = load_results(model_name=model_name)

## Basic Metrics Visualization

Visualize tag counts, prosody parameters, etc.

In [None]:
from sklearn.metrics import r2_score


def compute_r2(evaluation_results: dict, key: str) -> float:
    """
    Return the R² score of the stored predictions against the stored truth for ``key``.

    The paired value lists are read from ``evaluation_results["metrics"]`` under
    ``true_<key>_values`` / ``pred_<key>_values``. NaN is returned (with a
    message on stdout) when the key is unknown, the lists differ in length,
    fewer than two pairs exist, the true values have zero variance, or the
    score computation raises.
    """
    nan = float('nan')
    # Metric entries holding the (true, predicted) value lists per parameter.
    value_keys = {
        "break_time": ("true_break_time_values", "pred_break_time_values"),
        "pitch": ("true_pitch_values", "pred_pitch_values"),
        "rate": ("true_rate_values", "pred_rate_values"),
        "volume": ("true_volume_values", "pred_volume_values"),
    }

    metrics = evaluation_results.get("metrics", {})
    if key not in value_keys:
        print(f"Unknown key: {key}")
        return nan

    true_name, pred_name = value_keys[key]
    true_list = metrics.get(true_name, [])
    pred_list = metrics.get(pred_name, [])
    n_true, n_pred = len(true_list), len(pred_list)

    print(f"R² calculation for {key}: {n_true} true values, {n_pred} predicted values")

    # The lists are expected to be paired element by element.
    if n_true != n_pred:
        print(f"ERROR: Mismatched lengths for {key}: true={n_true}, pred={n_pred}")
        return nan
    if n_true < 2:
        print(f"Not enough paired data for {key}: {n_true} pairs")
        return nan

    true_arr = np.array(true_list)
    pred_arr = np.array(pred_list)

    # R² is undefined when the reference values never vary.
    var_true = np.var(true_arr)
    if var_true == 0:
        print(f"No variance in true {key} values (all values are the same)")
        return nan

    # Debug info
    print(f"True values range: {np.min(true_arr):.2f} to {np.max(true_arr):.2f}, var={var_true:.3f}")
    print(f"Predicted values range: {np.min(pred_arr):.2f} to {np.max(pred_arr):.2f}, var={np.var(pred_arr):.3f}")

    try:
        score = r2_score(true_arr, pred_arr)
        print(f"R² for {key}: {score:.3f}")
        if score < -1:
            # Far below zero: predictions are much worse than predicting the mean.
            print(f"WARNING: Very negative R² ({score:.3f}) suggests poor model performance")
        return score
    except Exception as e:
        print(f"Error calculating R² for {key}: {e}")
        return nan

In [None]:
def plot_metrics(evaluation_results: Dict[str, Any], save_path: Optional[str] = None):
    """Draw a 2x2 overview of tag counts and prosody statistics for one evaluation.

    Values are read from ``evaluation_results["metrics"]`` (missing entries
    count as 0) and the titles carry ``evaluation_results["model_name"]``.
    With ``save_path`` the figure is written to that path and closed;
    otherwise it is shown.
    """
    metrics = evaluation_results.get("metrics", {})
    name = evaluation_results.get("model_name", "model")

    tag_counts = {
        "Total Tags": metrics.get("total_tags_mean", 0),
        "Prosody Tags": metrics.get("prosody_count_mean", 0),
        "Break Tags": metrics.get("break_count_mean", 0)
    }
    prosody_means = {
        "Pitch": metrics.get("pitch_adjustments_mean_mean", 0),
        "Rate": metrics.get("rate_adjustments_mean_mean", 0),
        "Volume": metrics.get("volume_adjustments_mean_mean", 0)
    }
    prosody_spread = {
        "Pitch": metrics.get("pitch_adjustments_std_mean", 0),
        "Rate": metrics.get("rate_adjustments_std_mean", 0),
        "Volume": metrics.get("volume_adjustments_std_mean", 0)
    }

    # One entry per panel, in grid order: (bar labels, bar heights, title, y label).
    panels = [
        (tag_counts.keys(), tag_counts.values(),
         f"Tag Distribution - {name}", "Avg Count"),
        (prosody_means.keys(), prosody_means.values(),
         f"Avg Prosody Parameters - {name}", "Mean (%)"),
        (prosody_spread.keys(), prosody_spread.values(),
         f"Prosody Variability - {name}", "Std Dev (%)"),
        (["Tags/Sent"], [metrics.get("tags_per_sentence_mean", 0)],
         f"Tag Density - {name}", "Avg Tags/Sentence"),
    ]

    plt.figure(figsize=(12, 8))
    for position, (bar_labels, bar_heights, title, ylabel) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.bar(bar_labels, bar_heights)
        plt.title(title)
        plt.ylabel(ylabel)

    plt.tight_layout()
    if not save_path:
        plt.show()
        return
    plt.savefig(save_path, dpi=300)
    plt.close()


## Error Metrics Visualization

Visualize MAE/MSE, R², and break Precision/Recall/F1 metrics.

In [None]:
def plot_error_and_r2(evaluation_results: Dict[str, Any], save_path: Optional[str] = None):
    """Plot MAE/MSE, break precision/recall/F1, and R² for one evaluation.

    Three stacked panels are created:
      1. MAE and MSE for every parameter (pitch, rate, volume, break_time)
         that has both ``<param>_mae`` and ``<param>_mse`` in the metrics.
      2. Break precision/recall/F1. When any of the three metrics is missing,
         this panel is hidden instead of being left as a blank frame.
      3. R² per parameter (via ``compute_r2``). Bars are clamped to [-1, 1]
         for display and NaN is drawn as 0; the text label keeps the real value.

    Args:
        evaluation_results: Dict with a "metrics" dict and optional "model_name".
        save_path: When given, the figure is saved there and closed;
            otherwise it is shown.
    """
    m = evaluation_results.get("metrics", {})
    model_label = evaluation_results.get('model_name', '')
    params = ["pitch", "rate", "volume", "break_time"]

    labels, mae_vals, mse_vals, r2_vals = [], [], [], []
    for p in params:
        mae_k, mse_k = f"{p}_mae", f"{p}_mse"
        if mae_k not in m or mse_k not in m:
            continue  # a parameter without both error metrics is left out of every panel
        labels.append(p.capitalize())
        mae_vals.append(m[mae_k])
        mse_vals.append(m[mse_k])
        r2_val = compute_r2(evaluation_results, p)
        r2_vals.append(r2_val)
        print(f"Added R² value for {p}: {r2_val}")

    x = np.arange(len(labels))
    width = 0.25
    fig, axes = plt.subplots(3, 1, figsize=(10, 18))
    ax1, ax2, ax3 = axes

    # MAE/MSE
    ax1.bar(x - width/2, mae_vals, width, label="MAE", alpha=0.7)
    ax1.bar(x + width/2, mse_vals, width, label="MSE", alpha=0.7)
    ax1.set_xticks(x)
    ax1.set_xticklabels(labels)
    ax1.set_title(f"MAE vs MSE - {model_label}")
    ax1.grid(axis='y', linestyle='--', alpha=0.7)
    ax1.legend()

    # Precision/Recall/F1 if present
    if all(k in m for k in ["break_precision", "break_recall", "break_f1"]):
        prf_labels = ["Breaks"]
        x2 = np.arange(len(prf_labels))
        ax2.bar(x2 - width, [m["break_precision"]], width, label="Precision", alpha=0.7)
        ax2.bar(x2, [m["break_recall"]], width, label="Recall", alpha=0.7)
        ax2.bar(x2 + width, [m["break_f1"]], width, label="F1", alpha=0.7)
        ax2.set_xticks(x2)
        ax2.set_xticklabels(prf_labels)
        ax2.set_title(f"Break P/R/F1 - {model_label}")
        ax2.set_ylim(0, 1.1)
        ax2.grid(axis='y', linestyle='--', alpha=0.7)
        ax2.legend()
    else:
        # Nothing to draw here: hide the panel rather than show an empty frame.
        ax2.set_visible(False)

    # R²: NaN is drawn as 0 and anything below -1 is clamped so the bar stays in view.
    r2_display_vals = []
    for val in r2_vals:
        if np.isnan(val):
            r2_display_vals.append(0)
        elif val < -1:
            r2_display_vals.append(-1)
        else:
            r2_display_vals.append(val)

    print(f"Original R² values: {r2_vals}")
    print(f"R² display values: {r2_display_vals}")

    bars = ax3.bar(x, r2_display_vals, width, label="R²", alpha=0.7)

    # Colour by the sign of the real (unclamped) value; NaN bars keep the default colour.
    for bar, val in zip(bars, r2_vals):
        if not np.isnan(val):
            bar.set_color('salmon' if val < 0 else 'skyblue')

    ax3.set_xticks(x)
    ax3.set_xticklabels(labels)
    ax3.set_title(f"Coefficient of Determination (R²) - {model_label}")
    ax3.set_ylim(-1.1, 1.1)  # R² can be negative
    ax3.grid(axis='y', linestyle='--', alpha=0.7)
    ax3.axhline(y=0, color='red', linestyle='--', alpha=0.5)  # zero reference line
    ax3.legend()

    # Value labels show the real R², even where the bar itself was clamped.
    for bar, val, display_val in zip(bars, r2_vals, r2_display_vals):
        if np.isnan(val):
            continue
        # One decimal is enough (and narrower) for very negative values.
        label_text = f"{val:.1f}" if val < -10 else f"{val:.2f}"
        # Positive bars: just above the bar top. Negative bars: just under the zero line.
        y_pos = display_val + 0.1 if display_val >= 0 else -0.05
        ax3.text(bar.get_x() + bar.get_width()/2, y_pos,
                 label_text, ha='center', va='bottom',
                 color='black', fontweight='bold', fontsize=9)

    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()

In [None]:
def load_all_model_results(models=None, results_dir="results"):
    """
    Load the result files of several models.

    Args:
        models: Iterable of model names to load, processed in the order given.
            When None, names are derived from the ``*.json`` file names in
            ``results_dir``: everything before ``_zero_shot`` / ``_few_shot``
            (so model names may themselves contain underscores), otherwise
            the text before the first underscore. Discovered names are
            processed in sorted order so repeated runs give the same order.
        results_dir: Directory containing the result files.

    Returns:
        Dict mapping each model name to the dict returned by ``load_results``
        for that model (file name -> parsed JSON).
    """
    if models is None:
        discovered = set()
        for item in os.listdir(results_dir):
            if not item.endswith('.json'):
                continue
            # Prefer the approach marker, e.g. "my_model_zero_shot.json" -> "my_model".
            for marker in ("_zero_shot", "_few_shot"):
                prefix, found, _ = item.partition(marker)
                if found and prefix:
                    discovered.add(prefix)
                    break
            else:
                # Other files, e.g. "mistral_notes.json" -> "mistral".
                parts = item.split('_')
                if len(parts) >= 2 and parts[0]:
                    discovered.add(parts[0])
        # A set has no stable order; sort so plots and tables are reproducible.
        models = sorted(discovered)

    print(f"Loading results for models: {models}")

    all_results = {}
    for model in models:
        # Find all files for this model
        all_results[model] = load_results(results_dir=results_dir, model_name=model)

    return all_results

In [None]:
# Simple SSML similarity using sentence transformers
def calculate_ssml_similarity(all_results, save_path=None):
    """
    Compute predicted-vs-gold SSML cosine similarity with sentence transformers.

    The raw SSML strings are embedded as-is with ``all-MiniLM-L6-v2`` and each
    prediction is compared with its own gold SSML.

    Args:
        all_results: Dict of model name -> {file name -> result data}, as
            returned by ``load_all_model_results``. Only files whose name
            contains ``_zero_shot.json`` or ``_few_shot.json`` and that have a
            "results" list are used.
        save_path: Accepted for signature compatibility; this function does
            not plot or write anything.

    Returns:
        Dict of model name -> approach ("zero_shot"/"few_shot") -> dict with
        "cosine_similarities", "mean_similarity", "std_similarity" and
        "n_pairs". An empty dict when sentence-transformers is not installed.
    """
    try:
        from sentence_transformers import SentenceTransformer
        from sklearn.metrics.pairwise import cosine_similarity
        model = SentenceTransformer('all-MiniLM-L6-v2')
        print("Using sentence transformers for SSML similarity")
    except ImportError:
        print("sentence-transformers not available, skipping SSML similarity analysis")
        return {}

    model_similarities = {}

    for model_name, results in all_results.items():
        model_similarities[model_name] = {}

        for file_name, result_data in results.items():
            # Extract approach (zero_shot, few_shot, etc.)
            if "_zero_shot.json" in file_name:
                approach = "zero_shot"
            elif "_few_shot.json" in file_name:
                approach = "few_shot"
            else:
                continue

            if "results" not in result_data:
                continue

            pred_ssmls = []
            gold_ssmls = []
            for result in result_data["results"]:
                # `or ""` also covers fields that are present but null.
                pred_ssml = result.get("predicted_ssml") or result.get("ssml") or ""
                gold_ssml = result.get("gold_ssml") or ""

                # Keep only complete pairs so the two lists stay aligned.
                if pred_ssml.strip() and gold_ssml.strip():
                    pred_ssmls.append(pred_ssml)
                    gold_ssmls.append(gold_ssml)

            if not pred_ssmls:
                print(f"No valid SSML pairs found for {model_name} {approach}")
                continue

            # Calculate embeddings - just encode the SSML directly
            pred_embeddings = model.encode(pred_ssmls)
            gold_embeddings = model.encode(gold_ssmls)

            # Similarity of each prediction with its own gold SSML.
            similarities = [
                cosine_similarity([pred_emb], [gold_emb])[0, 0]
                for pred_emb, gold_emb in zip(pred_embeddings, gold_embeddings)
            ]

            mean_similarity = np.mean(similarities)
            std_similarity = np.std(similarities)
            model_similarities[model_name][approach] = {
                "cosine_similarities": similarities,
                "mean_similarity": mean_similarity,
                "std_similarity": std_similarity,
                "n_pairs": len(similarities)
            }

            print(f"{model_name} {approach}: {mean_similarity:.3f} ± {std_similarity:.3f} (n={len(similarities)})")

    return model_similarities

## Model Comparison Visualization

In [None]:
# Update your compare_models function to actually plot SSML similarities
def compare_models(comparison_results: Dict[str, Any], save_path: Optional[str] = None):
    """
    Plot a grid of bar charts comparing several evaluations, including SSML cosine similarity.

    Rows: tag counts (gold vs prediction), MAE, MSE, RMSE, R² (each for pitch,
    rate and break time), then break P/R/F1, tags per sentence and volume MAE.
    When at least one evaluation yields a positive SSML similarity, a seventh
    row is added for it, so it no longer overwrites the "Break Time MAE" panel.

    Args:
        comparison_results: Dict with an "evaluations" list. Every evaluation
            needs "model_name" and "metrics"; an optional "results" list of
            per-sample dicts ("predicted_ssml" or "ssml", plus "gold_ssml")
            feeds the SSML similarity panel.
        save_path: When given, the figure is saved there and closed;
            otherwise it is shown.
    """
    evaluations = comparison_results["evaluations"]
    model_names = [eval_result["model_name"] for eval_result in evaluations]
    x = np.arange(len(model_names))

    def metric(key):
        """Collect one metric across all evaluations (0 when missing)."""
        return [e["metrics"].get(key, 0) for e in evaluations]

    total_tags = metric("total_tags_mean")
    prosody_tags = metric("prosody_count_mean")
    break_tags = metric("break_count_mean")

    gold_total_tags = metric("gold_total_tags_mean")
    gold_prosody_tags = metric("gold_prosody_count_mean")
    gold_break_tags = metric("gold_break_count_mean")

    pitch_mae = metric("pitch_mae")
    rate_mae = metric("rate_mae")
    break_time_mae = metric("break_time_mae")
    volume_mae = metric("volume_mae")

    pitch_mse = metric("pitch_mse")
    rate_mse = metric("rate_mse")
    break_time_mse = metric("break_time_mse")

    # Compute RMSE from MSE
    pitch_rmse = np.sqrt(pitch_mse)
    rate_rmse = np.sqrt(rate_mse)
    break_time_rmse = np.sqrt(break_time_mse)

    # R² per evaluation for the three parameters that are plotted.
    pitch_r2, rate_r2, break_time_r2 = [], [], []
    for eval_result in evaluations:
        pitch_r2.append(compute_r2(eval_result, "pitch"))
        rate_r2.append(compute_r2(eval_result, "rate"))
        break_time_r2.append(compute_r2(eval_result, "break_time"))

    precision = metric("break_precision")
    recall = metric("break_recall")
    f1 = metric("break_f1")

    tags_per_sent = metric("tags_per_sentence_mean")

    ssml_similarities = _mean_ssml_similarities(evaluations)
    has_similarity = any(s > 0 for s in ssml_similarities)

    # Six metric rows, plus one for SSML similarity when there is something to show.
    n_rows = 7 if has_similarity else 6
    fig, axes = plt.subplots(n_rows, 3, figsize=(22, 32 * n_rows / 6))
    plt.subplots_adjust(bottom=0.15, hspace=0.4, wspace=0.3)

    def set_xticks_labels(ax):
        """Label the x axis with the model names."""
        ax.set_xticks(x)
        ax.set_xticklabels(model_names, rotation=30, ha='right', fontsize=11)

    def plot_row(row_axes, series, titles, ylabel, color):
        """Draw one single-series bar chart per axis in a grid row."""
        for ax, vals, title in zip(row_axes, series, titles):
            ax.bar(x, vals, color=color)
            ax.set_title(title)
            set_xticks_labels(ax)
            ax.set_ylabel(ylabel)

    # Row 1: tag counts, gold standard next to prediction
    width_bar = 0.35
    tag_panels = zip(axes[0],
                     [total_tags, prosody_tags, break_tags],
                     [gold_total_tags, gold_prosody_tags, gold_break_tags],
                     ["Total Tags", "Prosody Tags", "Break Tags"])
    for i, (ax, vals, gold_vals, title) in enumerate(tag_panels):
        ax.bar(x - width_bar/2, gold_vals, width_bar, color='goldenrod', label='Gold Standard')
        ax.bar(x + width_bar/2, vals, width_bar, color='skyblue', label='Prediction')
        ax.set_title(title)
        set_xticks_labels(ax)
        ax.set_ylabel('Count')
        if i == 0:
            ax.legend()

    # Rows 2-5: MAE, MSE, RMSE, R²
    plot_row(axes[1], [pitch_mae, rate_mae, break_time_mae],
             ["Pitch MAE", "Rate MAE", "Break Time MAE"], 'MAE', 'lightgreen')
    plot_row(axes[2], [pitch_mse, rate_mse, break_time_mse],
             ["Pitch MSE", "Rate MSE", "Break Time MSE"], 'MSE', 'orange')
    plot_row(axes[3], [pitch_rmse, rate_rmse, break_time_rmse],
             ["Pitch RMSE", "Rate RMSE", "Break Time RMSE"], 'RMSE', 'violet')
    plot_row(axes[4], [pitch_r2, rate_r2, break_time_r2],
             ["Pitch R²", "Rate R²", "Break Time R²"], 'R²', 'dodgerblue')
    for ax in axes[4]:
        ax.set_ylim(-1.1, 1.05)
        ax.axhline(y=0, color='red', linestyle='--', alpha=0.5)

    # Row 6: Break P/R/F1, Tags/Sentence, Volume MAE
    ax_prf, ax_tps, ax_vol = axes[5]

    width_prf = 0.25
    ax_prf.bar(x - width_prf, precision, width_prf, label="Precision", alpha=0.7, color='royalblue')
    ax_prf.bar(x, recall, width_prf, label="Recall", alpha=0.7, color='lightblue')
    ax_prf.bar(x + width_prf, f1, width_prf, label="F1", alpha=0.7, color='navy')
    ax_prf.set_title("Break Precision/Recall/F1")
    set_xticks_labels(ax_prf)
    ax_prf.set_ylim(0, 1.1)
    ax_prf.legend()
    ax_prf.set_ylabel('Score')

    ax_tps.bar(x, tags_per_sent, color='teal')
    ax_tps.set_title("Tags per Sentence")
    set_xticks_labels(ax_tps)
    ax_tps.set_ylabel('Tags/Sentence')

    ax_vol.bar(x, volume_mae, color='coral')
    ax_vol.set_title("Volume MAE")
    set_xticks_labels(ax_vol)
    ax_vol.set_ylabel('MAE')

    # Row 7 (only when available): SSML similarity in its own panel
    if has_similarity:
        ax_sim, *unused_axes = axes[6]
        bars = ax_sim.bar(x, ssml_similarities, color='mediumpurple')
        ax_sim.set_title('SSML Cosine Similarity')
        set_xticks_labels(ax_sim)
        ax_sim.set_ylabel('Similarity')
        ax_sim.set_ylim(0, 1)

        # Add value labels on bars
        for bar, similarity in zip(bars, ssml_similarities):
            if similarity > 0:
                ax_sim.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01,
                            f'{similarity:.3f}', ha='center', va='bottom', fontweight='bold')

        for ax in unused_axes:
            ax.set_visible(False)

    fig.suptitle('Model Comparison Metrics (including SSML Similarity)', fontsize=18, fontweight='bold')

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
    else:
        plt.show()


def _mean_ssml_similarities(evaluations):
    """Return the mean predicted-vs-gold SSML cosine similarity per evaluation.

    An evaluation scores 0.0 when it has no "results", no complete SSML pair,
    or when sentence-transformers is not installed. The embedding model is
    loaded once and shared by all evaluations.
    """
    # Skip the (slow) model load entirely when no evaluation carries samples.
    if not any("results" in eval_result for eval_result in evaluations):
        return [0.0] * len(evaluations)

    try:
        from sentence_transformers import SentenceTransformer
        from sklearn.metrics.pairwise import cosine_similarity
        encoder = SentenceTransformer('all-MiniLM-L6-v2')
    except ImportError:
        print("sentence-transformers not available, skipping SSML similarity")
        return [0.0] * len(evaluations)

    scores = []
    for eval_result in tqdm(evaluations, desc="Calculating SSML similarities"):
        pred_ssmls = []
        gold_ssmls = []
        for result in eval_result.get("results", []):
            # `or ""` also covers fields that are present but null.
            pred_ssml = result.get("predicted_ssml") or result.get("ssml") or ""
            gold_ssml = result.get("gold_ssml") or ""
            if pred_ssml.strip() and gold_ssml.strip():
                pred_ssmls.append(pred_ssml)
                gold_ssmls.append(gold_ssml)

        if not pred_ssmls:
            scores.append(0.0)
            continue

        pred_embeddings = encoder.encode(pred_ssmls)
        gold_embeddings = encoder.encode(gold_ssmls)
        similarities = [
            cosine_similarity([pred_emb], [gold_emb])[0, 0]
            for pred_emb, gold_emb in zip(pred_embeddings, gold_embeddings)
        ]

        mean_similarity = np.mean(similarities)
        scores.append(mean_similarity)
        print(f"SSML similarity for {eval_result.get('model_name', 'unknown')}: {mean_similarity:.3f}")

    return scores

In [None]:
def compare_model_architectures(all_results, mode="zero_shot", save_path=None):
    """
    Gather one evaluation per model and approach, relabel them, and plot them with compare_models.

    Args:
        all_results: Dict of model name -> {file name -> result data}, as
            loaded with load_all_model_results
        mode: "zero_shot", "few_shot", or "both" (zero-shot entries come first)
        save_path: Optional path passed on to compare_models

    Returns:
        The ``{"evaluations": [...]}`` dict that was plotted, or None when no
        ``{model}_zero_shot.json`` / ``{model}_few_shot.json`` entry matched.
    """
    evaluations = []
    model_mappings = []  # (model, approach) for the evaluation at the same index

    # (approach, wording used in the progress message)
    for approach, wording in (("zero_shot", "zero-shot"), ("few_shot", "few-shot")):
        if mode != approach and mode != "both":
            continue
        for model, results in all_results.items():
            expected_file = f"{model}_{approach}.json"
            if expected_file not in results:
                continue
            evaluations.append(results[expected_file])
            model_mappings.append((model, approach))
            print(f"Added {wording} results for {model}")

    if not evaluations:
        print("No matching results found to compare")
        return

    # Display name combines approach and architecture, e.g. "Zero-Shot-Mistral".
    # This overwrites "model_name" on the evaluation dicts held in all_results.
    for eval_result, (model_name, approach) in zip(evaluations, model_mappings):
        approach_label = approach.replace('_', '-').title()
        eval_result["model_name"] = f"{approach_label}-{model_name.capitalize()}"

    comparison = {"evaluations": evaluations}
    compare_models(comparison, save_path=save_path)

    return comparison

In [None]:
def compare_models_grouped(all_results, save_path=None):
    """
    Plot zero-shot and few-shot results side by side for every model.

    Two figures are produced: a 3x3 grid with one grouped bar chart per metric
    (models sorted per chart by their zero-shot value, falling back to the
    few-shot value), and a single chart of break precision/recall/F1 sorted
    by each model's best F1.

    Args:
        all_results: Dict of model name -> {file name -> result data}, as
            loaded with load_all_model_results
        save_path: Optional path prefix; the figures are written to
            ``{save_path}_metrics.png`` and ``{save_path}_prf.png`` and then
            closed. Without it both figures are shown.
    """
    models = list(all_results.keys())

    # Per model: the zero-shot and few-shot result dicts, or None when absent.
    model_data = {}
    for model in models:
        results = all_results[model]
        model_data[model] = {
            'zero_shot': results.get(f"{model}_zero_shot.json"),
            'few_shot': results.get(f"{model}_few_shot.json"),
        }

    metrics_to_plot = [
        ('pitch_mae', 'Pitch MAE', True),  # (metric_key, title, lower_is_better)
        ('rate_mae', 'Rate MAE', True),
        ('break_time_mae', 'Break Time MAE', True),
        ('pitch_mse', 'Pitch MSE', True),
        ('rate_mse', 'Rate MSE', True),
        ('break_time_mse', 'Break Time MSE', True),
        ('total_tags_mean', 'Total Tags', False),
        ('prosody_count_mean', 'Prosody Tags', False),
        ('break_count_mean', 'Break Tags', False)
    ]

    fig, axes = plt.subplots(3, 3, figsize=(18, 15))
    axes = axes.flatten()

    for ax, (metric_key, title, lower_is_better) in zip(axes, metrics_to_plot):
        # (model, zero-shot value, few-shot value, sort value)
        model_values = []
        for model in models:
            zero_entry = model_data[model]['zero_shot']
            few_entry = model_data[model]['few_shot']
            zero_shot_val = zero_entry['metrics'].get(metric_key, 0) if zero_entry else None
            few_shot_val = few_entry['metrics'].get(metric_key, 0) if few_entry else None

            # Models with no value for either approach are left out of this chart.
            if zero_shot_val is None and few_shot_val is None:
                continue
            sort_val = zero_shot_val if zero_shot_val is not None else few_shot_val
            model_values.append((model, zero_shot_val, few_shot_val, sort_val))

        # Ascending when lower is better, descending otherwise.
        model_values.sort(key=lambda item: item[3], reverse=not lower_is_better)

        sorted_models = [item[0] for item in model_values]
        zero_shot_values = [item[1] if item[1] is not None else 0 for item in model_values]
        few_shot_values = [item[2] if item[2] is not None else 0 for item in model_values]

        x = np.arange(len(sorted_models))
        width = 0.35
        ax.bar(x - width/2, zero_shot_values, width, label='Zero-Shot', color='skyblue')
        ax.bar(x + width/2, few_shot_values, width, label='Few-Shot', color='lightcoral')

        ax.set_title(title)
        ax.set_xticks(x)
        ax.set_xticklabels(sorted_models, rotation=45, ha='right')
        ax.legend()
        ax.grid(axis='y', linestyle='--', alpha=0.7)

    # Separate figure for break precision/recall/F1
    fig2, ax_prf = plt.subplots(figsize=(10, 6))

    def prf_scores(entry):
        """Return (precision, recall, f1) for one result dict; zeros when absent or null."""
        if not entry:
            return (0, 0, 0)
        m = entry['metrics']
        return (
            m.get('break_precision', 0) or 0,
            m.get('break_recall', 0) or 0,
            m.get('break_f1', 0) or 0,
        )

    # (model, zero P, zero R, zero F1, few P, few R, few F1, sort value)
    prf_data = []
    for model in models:
        zero_scores = prf_scores(model_data[model]['zero_shot'])
        few_scores = prf_scores(model_data[model]['few_shot'])
        best_f1 = max(zero_scores[2], few_scores[2])
        prf_data.append((model, *zero_scores, *few_scores, best_f1))

    # Highest F1 first
    prf_data.sort(key=lambda item: item[7], reverse=True)

    if prf_data:
        prf_models = [item[0] for item in prf_data]
        x = np.arange(len(prf_models))
        width = 0.15

        # (column in prf_data, offset in bar widths, legend label, colour)
        bar_specs = [
            (1, -1.5, 'Zero-Shot Precision', 'royalblue'),
            (2, -0.5, 'Zero-Shot Recall', 'lightblue'),
            (3, 0.5, 'Zero-Shot F1', 'darkblue'),
            (4, 1.5, 'Few-Shot Precision', 'darkred'),
            (5, 2.5, 'Few-Shot Recall', 'lightcoral'),
            (6, 3.5, 'Few-Shot F1', 'maroon'),
        ]
        for column, offset, label, color in bar_specs:
            ax_prf.bar(x + width * offset, [item[column] or 0 for item in prf_data],
                       width, label=label, color=color)

        ax_prf.set_title('Break Precision/Recall/F1 by Model')
        ax_prf.set_xticks(x + width)  # centre of the six-bar group
        ax_prf.set_xticklabels(prf_models, rotation=45, ha='right')
        ax_prf.set_ylim(0, 1.1)
        ax_prf.legend()
        ax_prf.grid(axis='y', linestyle='--', alpha=0.7)

    # plt.tight_layout() would only touch the most recent figure, so lay out each explicitly.
    fig.tight_layout()
    fig2.tight_layout()

    if save_path:
        fig.savefig(f"{save_path}_metrics.png", dpi=300, bbox_inches='tight')
        if prf_data:  # Only save if we have P/R/F1 data
            fig2.savefig(f"{save_path}_prf.png", dpi=300, bbox_inches='tight')
        # Close both figures, including an empty P/R/F1 one, so none is left open.
        plt.close(fig)
        plt.close(fig2)
    else:
        plt.show()

## Analyze Specific Results

Let's analyze some specific result files

In [None]:
# Analyze zero-shot results if available
# Result files are named "<model>_zero_shot.json", so match on the suffix;
# the prefix check keeps files named "zero_shot*.json" working too.
zero_shot_file = next(
    (f for f in results if f.endswith("_zero_shot.json") or f.startswith("zero_shot")),
    None,
)
if zero_shot_file:
    print(f"Analyzing zero-shot results from {zero_shot_file}")
    zero_shot_results = results[zero_shot_file]
    
    # Plot basic metrics
    plot_metrics(zero_shot_results)
    
    # Plot error metrics
    plot_error_and_r2(zero_shot_results)

In [None]:
# Analyze experiment results
if "experiment_results.json" in results:
    experiment_results = results["experiment_results.json"]
    
    # Check if there are few-shot voice-specific results
    if "few_shot" in experiment_results and "by_voice" in experiment_results["few_shot"]:
        for voice, voice_results in experiment_results["few_shot"]["by_voice"].items():
            print(f"Analyzing few-shot results for voice: {voice}")
            
            # Plot basic metrics
            plot_metrics(voice_results)
            
            # Plot error metrics for this voice (not the zero-shot run analyzed above)
            plot_error_and_r2(voice_results)

## Model Comparison

Compare zero-shot vs few-shot performance

In [None]:
# Build a zero-shot vs few-shot comparison from the experiment results file.
if "experiment_results.json" in results:
    experiment_results = results["experiment_results.json"]

    # Both approaches must be present, and each must carry an "all_voices" entry.
    has_both_approaches = all(
        approach in experiment_results for approach in ("zero_shot", "few_shot")
    )
    if has_both_approaches and all(
        "all_voices" in experiment_results[approach] for approach in ("zero_shot", "few_shot")
    ):
        # Zero-shot first, then few-shot.
        comparison = {
            "evaluations": [
                experiment_results[approach]["all_voices"]
                for approach in ("zero_shot", "few_shot")
            ]
        }

        # Plot comparison
        compare_models(comparison)

In [None]:
# Load results from all model directories
# Each name must match the prefix of its result files in results_100/
# ("<name>_zero_shot.json" / "<name>_few_shot.json").
models_to_compare = ["mistral", "llama3", "qwen3:8b", "granite3.3", "deepseek-r1:32b", "qwen3:32b", "qwen2.5:7b"]
all_model_results = load_all_model_results(models=models_to_compare, results_dir="results_100")

# Compare zero-shot across different models
# (each call below saves its figure into viz_dir and returns the plotted
# comparison dict, or None when no matching result files were loaded)
print("Comparing zero-shot performance across model architectures:")
zero_shot_comparison = compare_model_architectures(
    all_model_results, 
    mode="zero_shot",
    save_path=os.path.join(viz_dir, "zero_shot_model_comparison.png")
)

# Compare few-shot across different models
print("Comparing few-shot performance across model architectures:")
few_shot_comparison = compare_model_architectures(
    all_model_results, 
    mode="few_shot",
    save_path=os.path.join(viz_dir, "few_shot_model_comparison.png")
)

# Compare all approaches across different models
print("Comparing all approaches across model architectures:")
all_comparison = compare_model_architectures(
    all_model_results, 
    mode="both",
    save_path=os.path.join(viz_dir, "all_models_comparison.png")
)

# save_path is a prefix here: compare_models_grouped appends "_metrics.png"
# and "_prf.png" to it.
print("Creating grouped model comparison (zero-shot and few-shot side by side):")
compare_models_grouped(
    all_model_results,
    save_path=os.path.join(viz_dir, "grouped_model_comparison")
)

### Table for paper

In [None]:
def _load_ssml_similarity_backend():
    """Load the optional embedding backend used for the SSML similarity column.

    Returns:
        A ``(encoder, cosine_similarity)`` tuple, or ``None`` (after emitting a
        warning) when ``sentence_transformers`` / ``sklearn`` cannot be imported.
    """
    import warnings

    try:
        from sentence_transformers import SentenceTransformer
        from sklearn.metrics.pairwise import cosine_similarity
        encoder = SentenceTransformer('all-MiniLM-L6-v2')
    except ImportError:
        # Best effort, as before, but no longer silent: a column of 0.00 would
        # otherwise look like a real (terrible) score.
        warnings.warn(
            "sentence-transformers is not available; 'SSML Similarity' is reported as 0.0"
        )
        return None
    return encoder, cosine_similarity


def _mean_ssml_similarity(samples, get_backend):
    """Mean cosine similarity between predicted and gold SSML embeddings.

    Args:
        samples: Iterable of per-sample result dicts. The prediction is read from
            "predicted_ssml" (falling back to "ssml"), the reference from "gold_ssml".
        get_backend: Zero-argument callable returning the backend tuple from
            ``_load_ssml_similarity_backend`` (or ``None``). It is only invoked
            when there is at least one usable pair.

    Returns:
        The mean similarity, or 0.0 when no pair is usable or the backend is missing.
    """
    pairs = []
    for sample in samples:
        # `or ""` also covers keys that are present but hold None
        predicted = sample.get("predicted_ssml") or sample.get("ssml") or ""
        gold = sample.get("gold_ssml") or ""
        if predicted.strip() and gold.strip():
            pairs.append((predicted, gold))

    if not pairs:
        return 0.0

    backend = get_backend()
    if backend is None:
        return 0.0

    encoder, cosine_similarity = backend
    pred_embeddings = encoder.encode([predicted for predicted, _ in pairs])
    gold_embeddings = encoder.encode([gold for _, gold in pairs])
    scores = [
        cosine_similarity([pred_emb], [gold_emb])[0, 0]
        for pred_emb, gold_emb in zip(pred_embeddings, gold_embeddings)
    ]
    return np.mean(scores)


def _build_table_row(model_name, approach, ssml_similarity, metrics):
    """Assemble one table row; metrics missing from ``metrics`` default to 0."""
    row = {
        "Model": model_name,
        "Approach": approach,
        "SSML Similarity": ssml_similarity,
    }
    for label, key in (("Break", "break_time"), ("Pitch", "pitch"),
                       ("Rate", "rate"), ("Volume", "volume")):
        row[f"{label} MAE"] = metrics.get(f"{key}_mae", 0)
        row[f"{label} MSE"] = metrics.get(f"{key}_mse", 0)
    return row


def generate_scientific_table(all_model_results, output_format="latex"):
    """
    Generate a clean, scientific table from model results.

    For every model, the entries "<model>_zero_shot.json" and "<model>_few_shot.json"
    (when present) each contribute one row holding the MAE/MSE metrics and the mean
    embedding similarity between predicted and gold SSML. Rows are sorted by
    similarity (descending), then model name, then approach.

    Args:
        all_model_results: Dictionary of results loaded with load_all_model_results
            (model name -> {file name -> result dict with "metrics" and,
            optionally, "results"}).
        output_format: "latex", "markdown", or "csv"

    Returns:
        Formatted table string

    Raises:
        ValueError: If ``output_format`` is not supported. This is checked up
            front, before any embedding work is done.
    """
    if output_format not in ("latex", "markdown", "csv"):
        raise ValueError(f"Unsupported output format: {output_format}")

    # The embedding model is expensive to construct: load it lazily and at most
    # once per call instead of once per result file.
    backend_cache = []

    def get_backend():
        if not backend_cache:
            backend_cache.append(_load_ssml_similarity_backend())
        return backend_cache[0]

    table_data = []
    for model_name, results in all_model_results.items():
        for suffix, approach in (("zero_shot", "Zero-shot"), ("few_shot", "Few-shot")):
            file_name = f"{model_name}_{suffix}.json"
            if file_name not in results:
                continue

            run = results[file_name]
            metrics = run["metrics"]

            ssml_similarity = 0.0
            if "results" in run:
                ssml_similarity = _mean_ssml_similarity(run["results"], get_backend)

            table_data.append(_build_table_row(model_name, approach, ssml_similarity, metrics))

    # Sort by SSML Similarity descending
    table_data.sort(key=lambda x: (-x["SSML Similarity"], x["Model"], x["Approach"]))

    # Generate formatted table based on output format
    if output_format == "latex":
        return _generate_latex_table(table_data)
    elif output_format == "markdown":
        return _generate_markdown_table(table_data)
    return _generate_csv_table(table_data)

def _generate_latex_table(table_data):
    """Generate LaTeX table format with highlighted best values and nicely formatted model names."""
    # Note: This requires \usepackage{makecell} and \usepackage{threeparttable} in your LaTeX preamble
    header = (
        "\\begin{table*}[htbp!]\n"
        "\\centering\n"
        "\\begin{threeparttable}\n"
        "\\caption{SSML generation performance across models and prompting strategies. Qwen2.5 (7B) offers the best balance of accuracy and efficiency.}\n"
        "\\label{tab:ssml_performance}\n"
        "\\begin{tabular}{l@{\\hspace{8pt}}c@{\\hspace{8pt}}c@{\\hspace{8pt}}c@{\\hspace{8pt}}c@{\\hspace{8pt}}c@{\\hspace{8pt}}c}\n"
        "\\toprule\n"
        "Model & \\makecell{SSML \\\\ Sim. $\\uparrow$} & \\makecell{Pitch \\\\ MAE/RMSE $\\downarrow$} & "
        "\\makecell{Volume \\\\ MAE/RMSE $\\downarrow$} & \\makecell{Rate \\\\ MAE/RMSE $\\downarrow$} & \\makecell{Break Time \\\\ MAE/RMSE $\\downarrow$} \\\\ \n"
        "\\midrule\n"
    )

    # Find best values for MAE (we'll use MAE for highlighting since it's the primary metric)
    best = {
        "SSML Similarity": max(table_data, key=lambda x: x["SSML Similarity"])["SSML Similarity"] if table_data else None,
        "Pitch MAE": min(table_data, key=lambda x: x["Pitch MAE"])["Pitch MAE"] if table_data else None,
        "Volume MAE": min(table_data, key=lambda x: x["Volume MAE"])["Volume MAE"] if table_data else None,
        "Rate MAE": min(table_data, key=lambda x: x["Rate MAE"])["Rate MAE"] if table_data else None,
        "Break MAE": min(table_data, key=lambda x: x["Break MAE"])["Break MAE"] if table_data else None,
    }

    def highlight(val, best_val, is_max):
        tol = 1e-6
        if (is_max and abs(val - best_val) < tol) or (not is_max and abs(val - best_val) < tol):
            return "\\cellcolor[gray]{0.9}"
        return ""
    
    def format_model_name(name):
        """Format model names in a more readable way."""
        # Handle Qwen models with size indicators
        if "qwen" in name.lower():
            if ":" in name:
                base_name, size = name.split(":")
                # Capitalize first letter of base name
                base_name = base_name[0].upper() + base_name[1:]
                # Make size uppercase and add parentheses
                size = size.upper()
                return f"{base_name} ({size})"
            else:
                # Just capitalize if no size indicator
                return name[0].upper() + name[1:]
        
        # Handle DeepSeek
        elif "deepseek" in name.lower():
            if ":" in name:
                base_name, size = name.split(":")
                base_name = base_name.replace("-", "-")  # Keep dash
                # Capitalize first letter and after dash
                parts = base_name.split("-")
                base_name = "-".join([p.capitalize() for p in parts])
                size = size.upper()
                return f"{base_name} ({size})"
            else:
                parts = name.split("-")
                return "-".join([p.capitalize() for p in parts])
        
        # Handle Granite with version
        elif "granite" in name.lower():
            if "." in name:
                base_name, version = name.split(".")
                version = "." + version  # Keep the dot in version
                return f"{base_name.capitalize()} {version}"
            else:
                return name.capitalize()
        
        # For simple names like "mistral" or "llama3"
        else:
            return name.capitalize()
    
    def format_approach(approach):
        """Convert approach to short form."""
        if approach == "Zero-shot":
            return "ZS"
        elif approach == "Few-shot":
            return "FS"
        else:
            return approach

    rows = []
    for row in table_data:
        raw_model_name = row["Model"]
        formatted_model_name = format_model_name(raw_model_name)
        approach_short = format_approach(row["Approach"])
        
        # Combine model name and approach
        model_with_approach = f"{formatted_model_name} ({approach_short})"

        ssml_sim = f"{row['SSML Similarity']:.2f}"
        # Calculate RMSE from MSE
        pitch_rmse = (row['Pitch MSE'] ** 0.5) if row['Pitch MSE'] > 0 else 0
        volume_rmse = (row['Volume MSE'] ** 0.5) if row['Volume MSE'] > 0 else 0
        rate_rmse = (row['Rate MSE'] ** 0.5) if row['Rate MSE'] > 0 else 0
        break_rmse = (row['Break MSE'] ** 0.5) if row['Break MSE'] > 0 else 0
        
        pitch_mae_rmse = f"{row['Pitch MAE']:.2f}/{pitch_rmse:.2f}"
        volume_mae_rmse = f"{row['Volume MAE']:.2f}/{volume_rmse:.2f}"
        rate_mae_rmse = f"{row['Rate MAE']:.2f}/{rate_rmse:.2f}"
        break_mae_rmse = f"{row['Break MAE']:.2f}/{break_rmse:.2f}"

        ssml_sim_cell = highlight(row["SSML Similarity"], best["SSML Similarity"], is_max=True)
        pitch_mae_cell = highlight(row["Pitch MAE"], best["Pitch MAE"], is_max=False)
        volume_mae_cell = highlight(row["Volume MAE"], best["Volume MAE"], is_max=False)
        rate_mae_cell = highlight(row["Rate MAE"], best["Rate MAE"], is_max=False)
        break_mae_cell = highlight(row["Break MAE"], best["Break MAE"], is_max=False)

        # Reorder columns: SSML, Pitch, Volume, Rate, Break Time
        rows.append(
            f"{model_with_approach} & {ssml_sim_cell}{ssml_sim} & {pitch_mae_cell}{pitch_mae_rmse} & {volume_mae_cell}{volume_mae_rmse} & {rate_mae_cell}{rate_mae_rmse} & {break_mae_cell}{break_mae_rmse} \\\\"
        )

    footer = (
        "\\bottomrule\n"
        "\\end{tabular}\n"
        "\\begin{tablenotes}\\footnotesize\n"
        "\\item $\\uparrow$: higher is better, $\\downarrow$: lower is better. ZS: Zero-Shot, FS: Few-Shot.\n"
        "\\end{tablenotes}\n"
        "\\end{threeparttable}\n"
        "\\end{table*}\n"
    )

    return header + "\n".join(rows) + "\n" + footer

def _generate_markdown_table(table_data):
    """Generate Markdown table format."""
    header = "| Model | Approach | SSML Sim. ↑ | Pitch MAE ↓ | Volume MAE ↓ | Rate MAE ↓ | Break Time MAE ↓ |\n"
    header += "|-------|----------|-------------|-------------|-------------|------------|------------------|\n"
    
    rows = []
    for row in table_data:
        model_name = row["Model"]
        approach = row["Approach"]
        ssml_sim = f"{row['SSML Similarity']:.3f}"
        pitch_mae = f"{row['Pitch MAE']:.2f}"
        volume_mae = f"{row['Volume MAE']:.2f}"
        rate_mae = f"{row['Rate MAE']:.2f}"
        break_mae = f"{row['Break MAE']:.2f}"
        
        # Reorder columns: SSML, Pitch, Volume, Rate, Break Time
        rows.append(f"| {model_name} | {approach} | {ssml_sim} | {pitch_mae} | {volume_mae} | {rate_mae} | {break_mae} |")
    
    return header + "\n".join(rows)

def _generate_csv_table(table_data):
    """Generate CSV table format."""
    header = "Model,Approach,SSML Similarity,Pitch MAE,Volume MAE,Rate MAE,Break Time MAE\n"
    
    rows = []
    for row in table_data:
        model_name = row["Model"]
        approach = row["Approach"]
        ssml_sim = f"{row['SSML Similarity']:.3f}"
        pitch_mae = f"{row['Pitch MAE']:.2f}"
        volume_mae = f"{row['Volume MAE']:.2f}"
        rate_mae = f"{row['Rate MAE']:.2f}"
        break_mae = f"{row['Break MAE']:.2f}"
        
        # Reorder columns: SSML, Pitch, Volume, Rate, Break Time
        rows.append(f"{model_name},{approach},{ssml_sim},{pitch_mae},{volume_mae},{rate_mae},{break_mae}")
    
    return header + "\n".join(rows)

In [None]:
# Build the cross-model results table as LaTeX source and print it so it can be
# copied into the paper.
latex_table = generate_scientific_table(all_model_results, output_format="latex")
print(latex_table)

### Plot for paper

In [None]:
def generate_scientific_tag_plots(all_model_results, save_path=None, figsize=(10, 3), dpi=500):
    """
    Generate two bar charts (prosody tags, break tags) of predicted vs. gold tag counts.

    For each tag type, every model gets a zero-shot and a few-shot bar showing its
    mean tag count; a dashed line marks the gold-standard mean averaged over the
    collected per-model gold values. Models are ordered by the *best* (lowest
    absolute error to their own gold value) of zero-shot or few-shot for that tag.

    Note: this updates the global matplotlib rcParams (fonts and font sizes),
    which affects figures created afterwards as well.

    Args:
        all_model_results: Mapping model name -> {file name -> result dict}. The
            entries "<model>_zero_shot.json" and "<model>_few_shot.json" are read;
            their "metrics" dict supplies the count means.
        save_path: Optional path prefix. When given, four files are written:
            save_path + "_prosody.pdf"/".png" and save_path + "_break.pdf"/".png".
        figsize: Base (width, height); each figure uses
            (width // 2 + 1, height + 1).
        dpi: Resolution used when saving.

    Returns:
        (fig1, fig2): the prosody figure and the break figure.
    """
    # Set scientific style with larger fonts
    plt.rcParams.update({
        'font.family': 'Times New Roman',
        'font.serif': ['Computer Modern Roman'],
        'font.size': 14,
        'axes.labelsize': 13,
        'axes.titlesize': 14,
        'xtick.labelsize': 11,
        'ytick.labelsize': 11,
        'legend.fontsize': 11,
        'figure.titlesize': 15
    })

    def format_model_name(name):
        """Format a model name as a (possibly two-line) tick label."""
        lowered = name.lower()
        # partition() never fails, unlike unpacking split(":") on names with several colons
        base, _, size = name.partition(":")

        # Qwen models: size indicator on a second line
        if "qwen" in lowered:
            pretty = base[:1].upper() + base[1:]
            return f"{pretty}\n({size.upper()})" if size else pretty

        # DeepSeek
        if "deepseek" in lowered:
            if size:
                return f"DeepSeek-R1\n({size.upper()})"
            return "-".join(part.capitalize() for part in name.split("-"))

        # Parameter counts for models whose name carries no size tag
        # NOTE(review): these sizes are hard-coded per family — confirm they match the evaluated checkpoints.
        if "granite" in lowered:
            return name.capitalize() + "\n(8B)"
        if "mistral" in lowered:
            return name.capitalize() + "\n(7B)"
        if "llama3" in lowered:
            return name.capitalize() + "\n(8B)"
        return name.capitalize()

    # --- Extract per-model counts ---
    models = list(all_model_results.keys())
    model_data = {}

    # Gold-standard values collected across models (averaged for the reference line)
    gold_prosody_tags = []
    gold_break_tags = []

    for model in models:
        results = all_model_results[model]
        # A model with neither result file keeps these zeros and is still plotted
        entry = {
            'zero_shot': {'prosody': 0, 'break': 0},
            'few_shot': {'prosody': 0, 'break': 0},
            'gold_prosody': 0,
            'gold_break': 0
        }
        model_data[model] = entry

        # Zero-shot results also provide the primary gold values
        zero_shot_file = f"{model}_zero_shot.json"
        if zero_shot_file in results:
            metrics = results[zero_shot_file]["metrics"]
            entry['zero_shot']['prosody'] = metrics.get("prosody_count_mean", 0)
            entry['zero_shot']['break'] = metrics.get("break_count_mean", 0)
            entry['gold_prosody'] = metrics.get("gold_prosody_count_mean", 0)
            entry['gold_break'] = metrics.get("gold_break_count_mean", 0)
            gold_prosody_tags.append(entry['gold_prosody'])
            gold_break_tags.append(entry['gold_break'])

        # Few-shot results; their gold values are only used while the model's gold is still 0
        few_shot_file = f"{model}_few_shot.json"
        if few_shot_file in results:
            metrics = results[few_shot_file]["metrics"]
            entry['few_shot']['prosody'] = metrics.get("prosody_count_mean", 0)
            entry['few_shot']['break'] = metrics.get("break_count_mean", 0)
            if entry['gold_prosody'] == 0:
                entry['gold_prosody'] = metrics.get("gold_prosody_count_mean", 0)
                gold_prosody_tags.append(entry['gold_prosody'])
            if entry['gold_break'] == 0:
                entry['gold_break'] = metrics.get("gold_break_count_mean", 0)
                gold_break_tags.append(entry['gold_break'])

    # Average gold standard values (reference lines)
    avg_gold_prosody = np.mean(gold_prosody_tags) if gold_prosody_tags else 0
    avg_gold_break = np.mean(gold_break_tags) if gold_break_tags else 0

    def order_by_best_error(tag):
        """Models sorted by the smaller of their zero-/few-shot absolute errors for ``tag``."""
        def best_error(model):
            entry = model_data[model]
            gold = entry[f'gold_{tag}']
            return min(abs(entry['zero_shot'][tag] - gold), abs(entry['few_shot'][tag] - gold))
        # sorted() is stable, so ties keep the input order
        return sorted(models, key=best_error)

    def draw_tag_chart(ordered_models, tag, gold_level):
        """Draw the grouped zero-/few-shot bar chart for one tag type and return its figure."""
        fig, ax = plt.subplots(figsize=(figsize[0]//2 + 1, figsize[1] + 1))
        x = np.arange(len(ordered_models))
        zero_shot_counts = [model_data[model]['zero_shot'][tag] for model in ordered_models]
        few_shot_counts = [model_data[model]['few_shot'][tag] for model in ordered_models]
        ax.bar(x - 0.175, zero_shot_counts, 0.35, label='Zero-Shot', color='#4477AA', alpha=0.8)
        ax.bar(x + 0.175, few_shot_counts, 0.35, label='Few-Shot', color='#EE6677', alpha=0.8)
        ax.axhline(y=gold_level, color='black', linestyle='--', linewidth=1.5, label='Gold Standard')
        ax.set_ylabel('Average Count')
        ax.set_xticks(x)
        ax.set_xticklabels([format_model_name(name) for name in ordered_models], rotation=45, ha='right')
        ax.grid(axis='y', linestyle='--', alpha=0.3)
        ax.legend(frameon=True, fancybox=False, edgecolor='black')
        # Lay out this figure explicitly instead of relying on pyplot's "current figure"
        fig.tight_layout(rect=[0, 0, 1, 0.95])
        return fig

    # --- Plot 1: Prosody Tags / Plot 2: Break Tags ---
    fig1 = draw_tag_chart(order_by_best_error('prosody'), 'prosody', avg_gold_prosody)
    fig2 = draw_tag_chart(order_by_best_error('break'), 'break', avg_gold_break)

    # Save if path provided
    if save_path:
        for fig, tag in ((fig1, "prosody"), (fig2, "break")):
            for ext in ("pdf", "png"):
                fig.savefig(f"{save_path}_{tag}.{ext}", format=ext, bbox_inches='tight', dpi=dpi, pad_inches=0)

    return fig1, fig2

In [None]:
import matplotlib.font_manager as fm

# Register the user-installed Times New Roman font file with matplotlib so the
# 'Times New Roman' family requested by the paper plots can be resolved.
# (`addfont` registers one font file; it does not rebuild the font cache.)
# The path is resolved relative to the current user's home directory instead of
# being hard-coded to one account.
font_dir = os.path.expanduser("~/.local/share/fonts")
font_path = os.path.join(font_dir, "truetype", "msttcorefonts", "times.ttf")
if os.path.exists(font_path):
    fm.fontManager.addfont(font_path)
else:
    # Keep the notebook runnable on machines without the font installed
    print(f"Font file not found, matplotlib will fall back to its default font: {font_path}")

# List the font files found under the user font directory to verify the install
for f in fm.findSystemFonts(fontpaths=[font_dir]):
    print(f)

In [None]:
# Build both figures once. The call writes the PDF/PNG files and returns the figure
# objects, so calling the function a second time only duplicated every figure.
fig1, fig2 = generate_scientific_tag_plots(all_model_results, save_path="ssml_tag_usage_comparison")

# Display the generated figures
plt.show()




## Additional Analysis

Here you can add custom analysis code to examine specific aspects of the results.

In [None]:
def analyze_break_position_accuracy(results_data):
    """Placeholder: study how the break position threshold affects accuracy.

    Not implemented yet. ``results_data`` is ignored and ``None`` is returned.
    A future version could evaluate the break metrics under several thresholds.
    """
    return None

In [None]:
def analyze_break_errors(results_data):
    """Print the predicted break elements of every sample, for manual inspection.

    Meant to help understand why breaks might be missed or falsely detected.
    Only samples that carry ``params["parsed_sequence"]`` are reported; sample
    numbers are 1-based positions in ``results_data["results"]``. For each such
    sample the first 50 characters of its input text are printed, followed by
    the sequence index and ``time`` value ('?' when absent) of every element
    whose type is "break".

    Prints a notice and returns when ``results_data`` has no "results" key.
    """
    if "results" not in results_data:
        print("No detailed results found")
        return

    for sample_no, sample in enumerate(results_data["results"], start=1):
        has_sequence = "params" in sample and "parsed_sequence" in sample["params"]
        if not has_sequence:
            continue

        print(f"Sample {sample_no}: {sample['input_text'][:50]}...")

        # Collect one description per break element, keyed by its sequence index
        break_lines = [
            f"Position {index}: {element.get('time', '?')}"
            for index, element in enumerate(sample["params"]["parsed_sequence"])
            if element.get("type") == "break"
        ]

        print(f"Found {len(break_lines)} breaks:")
        for line in break_lines:
            print(f"  - {line}")
        print("\n")

In [None]:
# Example: Analyze break errors in zero-shot results.
# Result files follow the "<model>_zero_shot.json" naming used by load_results and
# the table/plot functions; the previous key "zero_shot_mistral.json" could never
# match, so this cell silently did nothing.
if "mistral_zero_shot.json" in results:
    analyze_break_errors(results["mistral_zero_shot.json"])

## Save Updated Visualizations

Save any visualizations to disk if needed

In [None]:
# Example: Save zero-shot visualizations into `viz_dir`.
# Result files follow the "<model>_zero_shot.json" naming used by load_results; the
# previous key "zero_shot_mistral.json" could never match, so nothing was saved.
if "mistral_zero_shot.json" in results:
    zero_shot_results = results["mistral_zero_shot.json"]
    plot_metrics(zero_shot_results, save_path=os.path.join(viz_dir, "zero_shot_metrics.png"))
    plot_error_metrics(zero_shot_results, save_path=os.path.join(viz_dir, "zero_shot_errors.png"))
    print(f"Visualizations saved to {viz_dir}")
else:
    # Only claim success when something was actually written
    print("No zero-shot results for mistral loaded; no visualizations were saved")