In [None]:
# ============================================================================
# Import all required libraries
# ============================================================================

# Add parent directory to path to import utils
import sys
sys.path.append('..')

# Data manipulation and analysis
import pandas as pd
import numpy as np

# Visualization
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap


# Custom utilities
from utils import (
    normalize_text,
    categorize_language,
    group_language_by_family,
    read_transcription_data,
)

# ============================================================================
# Bootstrap Resampling Utilities
# ============================================================================
def bootstrap_mean(data, n_bootstrap=1000, confidence_level=0.95, random_state=42):
    """
    Compute a percentile-bootstrap confidence interval for the mean.

    Parameters:
    -----------
    data : array-like
        One-dimensional sample, e.g. binary data (0s and 1s for accuracy)
    n_bootstrap : int
        Number of bootstrap samples
    confidence_level : float
        Confidence level (e.g., 0.95 for 95% CI)
    random_state : int
        Seed for a private random generator, so results are reproducible
        without touching NumPy's global random state

    Returns:
    --------
    mean : float
        Point estimate (mean of the original data)
    ci_lower : float
        Lower bound of confidence interval
    ci_upper : float
        Upper bound of confidence interval

    All three values are NaN when ``data`` is empty.
    """
    values = np.asarray(data)
    n = len(values)

    if n == 0:
        return np.nan, np.nan, np.nan

    # A private legacy RandomState produces the same draws as
    # np.random.seed(random_state) followed by np.random.choice, but leaves the
    # caller's global RNG stream untouched.
    rng = np.random.RandomState(random_state)

    # Resample with replacement and record the mean of every resample
    bootstrap_means = [
        np.mean(rng.choice(values, size=n, replace=True))
        for _ in range(n_bootstrap)
    ]

    # Percentile interval: cut alpha/2 off each tail of the bootstrap distribution
    alpha = 1 - confidence_level
    ci_lower = np.percentile(bootstrap_means, 100 * alpha / 2)
    ci_upper = np.percentile(bootstrap_means, 100 * (1 - alpha / 2))
    mean = np.mean(values)

    return mean, ci_lower, ci_upper

# ============================================================================
# Configure visualization settings
# ============================================================================
# Font defaults: DejaVu Sans is commonly available on Linux systems
matplotlib.rcParams.update({
    'font.family': 'DejaVu Sans',
    'font.size': 18,
})

# Seaborn theme: white grid background, notebook context with enlarged fonts
sns.set_style("whitegrid")
sns.set_context("notebook", font_scale=1.4)

# Cap the number of rows pandas prints for a DataFrame
pd.set_option('display.max_rows', 20)


In [None]:
# ============================================================================
# CONFIGURATION: All five model sizes analyzed simultaneously
# ============================================================================

# Each entry maps: baseline model name -> finetuned model name.
# Entries are listed from smallest to largest model: dict order determines the
# order of BASELINE_MODELS, and therefore the row/subplot order of every
# per-size table and figure built from it.
MODEL_PAIRS = {
    "tiny": "whisper_tiny_im_on_all_17863_20260130_131358",
    "base": "whisper_base_im_on_all_17862_20260130_131358",
    "small": "whisper_small_im_on_all_18126_20260206_163258",
    "medium": "whisper_medium_im_on_all_17864_20260130_131403",
    "large": "whisper_large_im_on_all_17861_20260130_131353",
}

# Per-language finetuned models, keyed by model name; each value is the
# language code associated with that model.
LANG_MODEL_TO_CODE = {
    "whisper_base_im_on_ar_18089_20260206_131541": "ar",
    "whisper_base_im_on_cs_18089_20260206_131543": "cs",
    "whisper_base_im_on_de_18089_20260206_131542": "de",
    "whisper_base_im_on_es_18089_20260206_131543": "es",
    "whisper_base_im_on_fr_18089_20260206_131652": "fr",
    "whisper_base_im_on_hi_18089_20260206_131641": "hi",
    "whisper_base_im_on_hu_18089_20260206_131656": "hu",
    "whisper_base_im_on_it_18108_20260206_135235": "it",
    "whisper_base_im_on_ja_18118_20260206_142832": "ja",
    "whisper_base_im_on_ko_18089_20260206_132907": "ko",
    "whisper_base_im_on_nl_18089_20260206_132915": "nl",
    "whisper_base_im_on_pl_18089_20260206_132927": "pl",
    "whisper_base_im_on_pt_18089_20260206_132927": "pt",
    "whisper_base_im_on_ru_18089_20260206_133021": "ru",
    "whisper_base_im_on_tr_18089_20260206_133039": "tr",
    "whisper_base_im_on_zh-cn_18108_20260206_135234": "zh-cn",
}

# Size lookup for the per-language models: every one of them is "base"
LANG_MODEL_PAIRS = dict.fromkeys(LANG_MODEL_TO_CODE, "base")

# Helper used to tag each row with its model size label
def get_model_size(model_name):
    """
    Return the size label for a baseline, finetuned, or per-language model name.

    Baseline names (the keys of MODEL_PAIRS) are returned unchanged, finetuned
    names resolve to the key they are paired with, per-language names resolve
    through LANG_MODEL_PAIRS, and anything else yields 'unknown'.
    """
    if model_name in MODEL_PAIRS:
        return model_name

    # First baseline whose finetuned counterpart matches, if any
    paired_size = next(
        (size for size, finetuned in MODEL_PAIRS.items() if finetuned == model_name),
        None,
    )
    if paired_size is not None:
        return paired_size

    return LANG_MODEL_PAIRS.get(model_name, 'unknown')

# Flat views of MODEL_PAIRS: its keys, its values, and both concatenated
BASELINE_MODELS = [*MODEL_PAIRS]
FINETUNED_MODELS = [*MODEL_PAIRS.values()]
ALL_MODELS = [*BASELINE_MODELS, *FINETUNED_MODELS]

# Flag: compare against the baseline models (True) or analyze finetuned models only
COMPARE_WITH_OTHER_MODELS = True


In [None]:
# Echo the active configuration into the notebook output
print(f"Model pairs ({len(MODEL_PAIRS)}):")
for baseline in MODEL_PAIRS:
    finetuned = MODEL_PAIRS[baseline]
    print(f"  {baseline} -> {finetuned}")
print(f"\nComparing with baseline models: {COMPARE_WITH_OTHER_MODELS}")


In [None]:
# Load the transcription results through the project helper and keep only the
# rows whose 'prompt' column equals 'No prompt'.
all_data = read_transcription_data("all")
all_data = all_data[all_data['prompt']=='No prompt']
# NOTE(review): the file has a .tsv extension but is read with read_csv's default
# separator — confirm the file is actually comma-delimited.
street_origin = pd.read_csv("../street_names.tsv")
# Lower-case the street names before they are used as the join key below.
street_origin['name'] = street_origin['name'].str.lower()

# Left join on answer == name: every transcription row is kept, and the columns of
# street_origin are appended where a matching name exists.
all_data = all_data.set_index("answer").join(street_origin.set_index("name"), how='left').reset_index()

# Column labels are overwritten positionally with these 20 names.
# NOTE(review): this depends on the exact number and order of columns produced by
# the join above — verify it whenever either input gains or loses a column.
all_data.columns = ['answer', 'participant_id', 'index', 'model', 'prompt', 'original_text',
       'transcription', 'transcription_og', 'Status', 'Primary language',
       'Age', 'Sex', 'Language', 'english_only', 'multilingual', 'not_english',
       'levenshtein_distance', 'is_correct', 'age_decade', 'origin']
# Exact string match between transcription and answer; stored separately from the
# pre-existing 'is_correct' column.
all_data['correct'] = all_data['transcription'] == all_data['answer']
# Derive a language family from 'Primary language' via the utils helper.
all_data['language_family'] = all_data['Primary language'].apply(group_language_by_family)



# Tag every row with its model size label
all_data['model_size'] = all_data['model'].map(get_model_size)

# Tag every row as 'baseline' or 'finetuned': a model is finetuned when it is a
# MODEL_PAIRS value or one of the per-language models
all_finetuned = [*MODEL_PAIRS.values(), *LANG_MODEL_PAIRS]
all_data['model_type'] = all_data['model'].map(
    lambda m: 'baseline' if m not in all_finetuned else 'finetuned'
)

# --- data: MODEL_PAIRS models only (baselines + their finetuned counterparts) ---
in_model_pairs = all_data['model'].isin(ALL_MODELS)
data = all_data.loc[in_model_pairs].copy()
print(f"data: Filtered to {len(ALL_MODELS)} models ({len(data)} rows)")
print(f"  Models: {ALL_MODELS}")

# Per-size row counts, split by model type
print(f"\nModels loaded per size:")
for size in BASELINE_MODELS:
    subset = data[data['model_size'] == size]
    n_baseline = int((subset['model_type'] == 'baseline').sum())
    n_finetuned = int((subset['model_type'] == 'finetuned').sum())
    print(f"  {size}: {n_baseline} baseline samples, {n_finetuned} finetuned samples")

# --- data_lang: LANG_MODEL_PAIRS models + base baseline ---
lang_models_to_keep = ['base', *LANG_MODEL_PAIRS]
data_lang = all_data.loc[all_data['model'].isin(lang_models_to_keep)].copy()

# lang_code: language code for per-language finetuned rows, NaN for the baseline
data_lang['lang_code'] = data_lang['model'].map(LANG_MODEL_TO_CODE)

is_base_row = data_lang['model'] == 'base'
print(f"\ndata_lang: {len(lang_models_to_keep)} models ({len(data_lang)} rows)")
print(f"  Languages: {sorted(data_lang['lang_code'].dropna().unique())}")
print(f"  Baseline (base) samples: {int(is_base_row.sum())}")
print(f"  Per-language finetuned samples: {int((~is_base_row).sum())}")

In [None]:
# Accuracy overview: one table per model, then one table per model size
banner = "=" * 60
print(banner)
print("OVERALL ACCURACY COMPARISON (All Model Sizes)")
print(banner)

# Per-model accuracy, sample count and number of correct rows, best model first
print("\nOverall Accuracy by Model:")
accuracy_by_model = (
    data.groupby('model')['is_correct']
    .agg(['mean', 'count', 'sum'])
    .rename(columns={'mean': 'Accuracy', 'count': 'Total Samples', 'sum': 'Correct'})
    .sort_values('Accuracy', ascending=False)
)
display(accuracy_by_model.round(3))

# Baseline vs finetuned accuracy for every model size
print("\n" + banner)
print("ACCURACY BY MODEL SIZE (Baseline vs Finetuned)")
print(banner)

summary_rows = []
for size in BASELINE_MODELS:
    of_size = data['model_size'] == size
    baseline_data = data[of_size & (data['model_type'] == 'baseline')]
    finetuned_data = data[of_size & (data['model_type'] == 'finetuned')]

    # NaN stands in for the accuracy of a side that has no rows
    b_acc = float('nan') if len(baseline_data) == 0 else baseline_data['is_correct'].mean()
    f_acc = float('nan') if len(finetuned_data) == 0 else finetuned_data['is_correct'].mean()

    # The delta is only defined when both sides have an accuracy
    if np.isnan(b_acc) or np.isnan(f_acc):
        delta = float('nan')
    else:
        delta = f_acc - b_acc

    row = {
        'Model Size': size,
        'Baseline Acc': round(b_acc, 3),
        'Baseline N': len(baseline_data),
        'Finetuned Acc': round(f_acc, 3),
        'Finetuned N': len(finetuned_data),
        'Delta': round(delta, 3),
    }
    summary_rows.append(row)

summary_df = pd.DataFrame(summary_rows)
display(summary_df)


In [None]:
# display(data[data['model'].str.contains("im_on")].groupby("language_group").mean(numeric_only=True)[['is_correct']])
# display(data[~data['model'].str.contains("im_on")].groupby("language_group").mean(numeric_only=True)[['is_correct']])

In [None]:
# Scratch calculation: relative change (new - old) / old between two hard-coded values.
# NOTE(review): the numbers look like accuracies copied from an earlier cell's
# output — confirm their source before quoting the result.
(0.823103 - 0.648621) / 0.648621

In [None]:
# Scratch calculation: relative change (new - old) / old between two hard-coded values.
# NOTE(review): the numbers look like accuracies copied from an earlier cell's
# output — confirm their source before quoting the result.
(0.779310 - 0.455172)/0.455172

In [None]:
# ============================================================================
# Figure 1: Accuracy by Language Group with Bootstrap Confidence Intervals
# ============================================================================

# Create language group column based on english_only, multilingual, not_english
def get_language_group(row):
    """
    Return the language-group label for one row.

    The flags are checked in the order english_only, multilingual, not_english;
    the label of the first truthy flag is returned, or 'Unknown' when none is set.
    """
    labels_by_flag = (
        ('english_only', 'English Only'),
        ('multilingual', 'Multilingual'),
        ('not_english', 'Non-English'),
    )
    for flag, label in labels_by_flag:
        if row[flag]:
            return label
    return 'Unknown'

# Row-wise labelling: every row of `data` is passed to get_language_group and the
# returned label is stored in a new 'language_group' column.
data['language_group'] = data.apply(get_language_group, axis=1)


def plot_accuracy_by_language_group(data, model_sizes, model_types, title, save_path,
                                    order_by_type='finetuned', colors=None):
    """
    Plot accuracy by language group with bootstrap confidence intervals.

    One subplot is drawn per model size; within each subplot the bars of the
    requested model types are placed side by side for every language group.
    Combinations without any rows are left blank instead of being drawn as NaN bars.

    Parameters
    ----------
    data : DataFrame
        Must contain columns: language_group, model_size, model_type, is_correct
    model_sizes : list[str]
        Model sizes to create subplots for (e.g. ['tiny', 'base', ...])
    model_types : list[str]
        Model types to compare (e.g. ['baseline'] or ['baseline', 'finetuned'])
    title : str
        Figure suptitle
    save_path : str
        Path to save the figure; missing parent directories are created
    order_by_type : str
        Which model_type to use for ordering language groups (descending accuracy,
        measured on the first model size). Falls back to the first available type
        if not found.
    colors : dict, optional
        Mapping of model_type -> color. Defaults provided for baseline/finetuned.

    Returns
    -------
    DataFrame
        One row per (language_group, model_size, model_type) combination that has
        data, with columns accuracy, ci_lower, ci_upper and n_samples.

    Raises
    ------
    ValueError
        If model_sizes or model_types is empty.
    """
    from pathlib import Path  # local import: only needed to prepare the output folder

    model_sizes = list(model_sizes)
    model_types = list(model_types)
    if not model_sizes or not model_types:
        raise ValueError("model_sizes and model_types must each contain at least one entry")

    if colors is None:
        colors = {'baseline': '#3498db', 'finetuned': '#e74c3c'}

    lang_groups = sorted(data['language_group'].dropna().unique())
    results_df = _bootstrap_accuracy_table(data, lang_groups, model_sizes, model_types)

    order_candidates = [order_by_type] + [t for t in model_types if t != order_by_type]
    lang_group_order = _language_group_order(results_df, lang_groups,
                                             model_sizes[0], order_candidates)

    # Create subplots: one per model size (squeeze=False always yields a 2D array)
    n_types = len(model_types)
    n_sizes = len(model_sizes)
    fig, axes = plt.subplots(1, n_sizes, figsize=(8 * n_sizes, 9), sharey=True, squeeze=False)
    axes = axes[0]

    x = np.arange(len(lang_group_order))
    # Keep the established widths for one or two types; shrink for more so the
    # bars of one language group never overlap.
    if n_types == 1:
        width = 0.5
    elif n_types == 2:
        width = 0.35
    else:
        width = 0.8 / n_types

    for ax, size in zip(axes, model_sizes):
        for i, model_type in enumerate(model_types):
            model_data = (results_df[(results_df['model_size'] == size) &
                                     (results_df['model_type'] == model_type)]
                          .set_index('language_group').reindex(lang_group_order))

            # reindex pads missing language groups with NaN, so emptiness must be
            # judged on the values, not on the number of rows
            has_bar = model_data['accuracy'].notna().values
            if not has_bar.any():
                continue

            accuracies = model_data['accuracy'].values[has_bar].astype(float)
            ci_lowers = model_data['ci_lower'].values[has_bar].astype(float)
            ci_uppers = model_data['ci_upper'].values[has_bar].astype(float)
            n_samples = model_data['n_samples'].values[has_bar]
            # Error-bar lengths must be non-negative for matplotlib
            errors = np.clip(np.array([accuracies - ci_lowers, ci_uppers - accuracies]), 0, None)

            # Centre the group of bars on the tick: offsets are symmetric around 0
            positions = x[has_bar] + (i - (n_types - 1) / 2) * width
            color = colors.get(model_type, f'C{i}')
            label = f'{model_type.capitalize()} ({size})'

            bars = ax.bar(positions, accuracies, width, label=label,
                          color=color, alpha=0.8, edgecolor='white', linewidth=1.5)

            # Error bars (95% CI)
            ax.errorbar(positions, accuracies, yerr=errors, fmt='none',
                        ecolor='black', capsize=4, capthick=1.5, alpha=0.7, linewidth=1.5)

            # Value labels on bars
            for bar, acc, n_samp in zip(bars, accuracies, n_samples):
                ax.annotate(f'{acc:.2f}\n(n={int(n_samp)})',
                            xy=(bar.get_x() + bar.get_width() / 2, bar.get_height()),
                            xytext=(0, 5), textcoords="offset points",
                            ha='center', va='bottom', fontsize=14, fontweight='bold')

        # Styling per subplot
        ax.set_xlabel('Language Group', fontsize=18, fontweight='bold')
        if ax is axes[0]:
            ax.set_ylabel('Accuracy', fontsize=18, fontweight='bold')
        ax.set_title(f'whisper-{size}', fontsize=20, fontweight='bold', pad=12)
        ax.set_xticks(x)
        ax.set_xticklabels(lang_group_order, rotation=0, ha='center', fontsize=16)
        # Only build a legend when at least one bar series was drawn
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='lower right', fontsize=14, frameon=True, shadow=True)
        ax.set_ylim(0, 1.15)
        ax.yaxis.grid(True, alpha=0.3, linestyle='--')
        ax.set_axisbelow(True)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.tick_params(axis='both', labelsize=14)

    fig.suptitle(title, fontsize=22, fontweight='bold', y=1.02)
    plt.tight_layout()
    Path(save_path).parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    return results_df


def _bootstrap_accuracy_table(data, lang_groups, model_sizes, model_types):
    """Bootstrap the accuracy and its CI for every (group, size, type) combination with rows."""
    columns = ['language_group', 'model_size', 'model_type',
               'accuracy', 'ci_lower', 'ci_upper', 'n_samples']
    rows = []
    for lang_group in lang_groups:
        in_group = data['language_group'] == lang_group
        for size in model_sizes:
            in_group_and_size = in_group & (data['model_size'] == size)
            for model_type in model_types:
                subset = data[in_group_and_size & (data['model_type'] == model_type)]
                if len(subset) == 0:
                    continue
                mean, ci_lower, ci_upper = bootstrap_mean(subset['is_correct'].values,
                                                          n_bootstrap=1000)
                rows.append({
                    'language_group': lang_group,
                    'model_size': size,
                    'model_type': model_type,
                    'accuracy': mean,
                    'ci_lower': ci_lower,
                    'ci_upper': ci_upper,
                    'n_samples': len(subset),
                })
    # Explicit columns keep the frame filterable even when no combination had rows
    return pd.DataFrame(rows, columns=columns)


def _language_group_order(results_df, lang_groups, first_size, order_candidates):
    """Order language groups by descending accuracy of the first candidate type with data."""
    for otype in order_candidates:
        order_data = results_df[(results_df['model_size'] == first_size) &
                                (results_df['model_type'] == otype)]
        if len(order_data) > 0:
            ordered = (order_data.set_index('language_group')['accuracy']
                       .sort_values(ascending=False).index.tolist())
            # Groups absent from the ordering slice are appended so they still get plotted
            return ordered + [g for g in lang_groups if g not in ordered]
    return list(lang_groups)


# --- Figure 1a: baseline and finetuned bars, one subplot per model size ---
plot_accuracy_by_language_group(
    data=data,
    model_sizes=BASELINE_MODELS,
    model_types=['baseline', 'finetuned'],
    title='Accuracy by Language Group: Finetuned vs Baseline (All Model Sizes)\n(with 95% Bootstrap Confidence Intervals)',
    save_path='figures/accuracy_finetuned_vs_baseline_by_language_group.png',
    order_by_type='finetuned',
)

# --- Figure 1b: the same comparison restricted to whisper-base ---
plot_accuracy_by_language_group(
    data=data,
    model_sizes=['base'],
    model_types=['baseline', 'finetuned'],
    title='Accuracy by Language Group: whisper-base (Baseline vs Finetuned)\n(with 95% Bootstrap Confidence Intervals)',
    save_path='figures/accuracy_base_model_by_language_group.png',
    order_by_type='finetuned',
)



In [None]:
# Create binary columns for each individual language.
# 'Primary language' may list several comma-separated languages; every distinct
# language gets its own 0/1 column named lang_<language>.

primary_language = data['Primary language']

# Per-row set of normalised (stripped, lower-cased) language tokens; empty for
# missing values. Matching on whole tokens avoids the false positives of substring
# matching, where a language whose name is contained in another one would also be
# flagged for that other language.
language_tokens = [
    {part.strip().lower() for part in value.split(',') if part.strip()}
    if pd.notna(value) else set()
    for value in primary_language
]

# Extract all individual languages from Primary language column (blank entries
# produced by stray commas are ignored)
all_languages = set()
for lang_str in primary_language.dropna():
    for lang in lang_str.split(','):
        if lang.strip():
            all_languages.add(lang.strip())

# Create a binary column for each individual language
for lang in sorted(all_languages):
    col_name = f'lang_{lang.replace(" ", "_")}'
    lang_key = lang.lower()
    data[col_name] = [int(lang_key in tokens) for tokens in language_tokens]

# Create English-only speaker column (monolingual English)
data['english_only_speaker'] = data['Primary language'].apply(
    lambda x: 1 if pd.notna(x) and x.strip().lower() == 'english' else 0
)

# Show the result
lang_cols = [c for c in data.columns if c.startswith('lang_')]
print(f"Created {len(lang_cols)} language columns: {lang_cols}")
# Count distinct participants rather than dividing the row count by an assumed
# number of models: `data` holds one row per participant, item and model.
n_english_only = data.loc[data['english_only_speaker'] == 1, 'participant_id'].nunique()
print(f"English-only speakers: {n_english_only}")

# Participant-level view: one row per participant, 1 where the participant lists the language
participant_langs = data.groupby('participant_id')[lang_cols + ['english_only_speaker']].max()
print(f"\nParticipants x Languages ({participant_langs.shape[0]} participants):")

In [None]:
# ============================================================================
# Figure 2: Accuracy by Language with Bootstrap Confidence Intervals
# ============================================================================

# Languages for which synthetic training data existed (based on voice cloning)
LANGUAGES_WITH_TRAINING_DATA = {
    'Arabic', 'Chinese', 'Czech', 'Dutch',
    'French', 'German', 'Hindi', 'Hungarian',
    'Italian', 'Japanese', 'Korean', 'Polish',
    'Portuguese', 'Russian', 'Spanish', 'Turkish',
}

# Indicator columns and the display name derived from each: drop the 'lang_'
# prefix and apostrophes, turn underscores back into spaces
lang_cols = [c for c in data.columns if c.startswith('lang_')]
lang_names = []
for c in lang_cols:
    display_name = c.replace('lang_', '').replace("'", "").replace("_", " ")
    lang_names.append(display_name)


def has_training_data(lang):
    """Return True when any training language occurs (case-insensitively) inside `lang`."""
    needle = lang.lower()
    for train_lang in LANGUAGES_WITH_TRAINING_DATA:
        if train_lang.lower() in needle:
            return True
    return False

# Bootstrap the accuracy of the base-size baseline and finetuned models for every
# language indicator column
results = []
is_base_size = data['model_size'] == 'base'
for col, lang in zip(lang_cols, lang_names):
    speaks_lang = data[col] == 1
    for model_type in ('baseline', 'finetuned'):
        subset = data[speaks_lang & is_base_size & (data['model_type'] == model_type)]
        if len(subset) == 0:
            continue
        mean, ci_lower, ci_upper = bootstrap_mean(subset['is_correct'].values, n_bootstrap=1000)
        results.append({
            'language': lang,
            'model_type': model_type,
            'accuracy': mean,
            'ci_lower': ci_lower,
            'ci_upper': ci_upper,
            'n_samples': len(subset),
            'has_training': has_training_data(lang),
        })

results_df = pd.DataFrame(results)

# Language order for the x-axis: languages WITH training data first, and within
# each of the two groups by descending finetuned accuracy
ft_results = results_df[results_df['model_type'] == 'finetuned']
if len(ft_results) == 0:
    lang_order = lang_names
else:
    ft_by_language = ft_results.set_index('language')
    finetuned_acc = ft_by_language['accuracy']
    finetuned_training = ft_by_language['has_training']
    lang_order = sorted(
        finetuned_acc.index,
        key=lambda x: (not finetuned_training.get(x, False), -finetuned_acc.get(x, 0)),
    )

# Create single plot for base model
# Patch is used further down to build legend entries by hand.
from matplotlib.patches import Patch

fig, ax = plt.subplots(figsize=(16, 8))

# One tick per language in lang_order; the two model types sit left and right of it
x = np.arange(len(lang_order))
width = 0.35

for i, model_type in enumerate(['baseline', 'finetuned']):
    # Align this model type's results with lang_order; languages without a result
    # for this type become NaN rows
    model_data = results_df[results_df['model_type'] == model_type].set_index('language').reindex(lang_order)
    
    if len(model_data) == 0:
        continue
    
    accuracies = model_data['accuracy'].values
    ci_lowers = model_data['ci_lower'].values
    ci_uppers = model_data['ci_upper'].values
    has_training = model_data['has_training'].values
    # Error-bar lengths below and above the point estimate
    errors = np.array([accuracies - ci_lowers, ci_uppers - accuracies])
    
    # First type (baseline) goes left of the tick, second (finetuned) right
    offset = -width/2 if i == 0 else width/2
    
    # Color based on training data availability: the first (darker) shade is used
    # when has_training is truthy, the second (paler) shade otherwise
    if i == 0:  # Baseline
        colors = ['#3498db' if ht else '#a9cce3' for ht in has_training]
    else:  # Finetuned
        colors = ['#e74c3c' if ht else '#f5b7b1' for ht in has_training]
    
    # Plot bars one at a time so each bar gets its own color and NaN entries can be skipped
    for j, (acc, err_lower, err_upper, color, n_samples) in enumerate(zip(
        accuracies, errors[0], errors[1], colors, model_data['n_samples'].values
    )):
        if np.isnan(acc):
            continue
        bar = ax.bar(x[j] + offset, acc, width, color=color, alpha=0.8, 
                     edgecolor='white', linewidth=1.5)
        
        # Add error bar
        ax.errorbar(x[j] + offset, acc, yerr=[[err_lower], [err_upper]], 
                   fmt='none', ecolor='black', capsize=3, capthick=1.2, 
                   alpha=0.7, linewidth=1.2)
        
        # Add value label
        ax.annotate(f'{acc:.2f}', 
                   xy=(x[j] + offset, acc),
                   xytext=(0, 3), textcoords="offset points", 
                   ha='center', va='bottom', fontsize=12, fontweight='bold')

# Add legend with both colors: the bars above carry no labels, so the legend is
# assembled from Patch objects that repeat the four bar colors
legend_elements = [
    Patch(facecolor='#3498db', alpha=0.8, label='Baseline (w/ training)'),
    Patch(facecolor='#a9cce3', alpha=0.8, label='Baseline (no training)'),
    Patch(facecolor='#e74c3c', alpha=0.8, label='Finetuned (w/ training)'),
    Patch(facecolor='#f5b7b1', alpha=0.8, label='Finetuned (no training)'),
]

# Axis labels, title, ticks and general styling
ax.set_xlabel('Language', fontsize=18, fontweight='bold')
ax.set_ylabel('Accuracy', fontsize=18, fontweight='bold')
ax.set_title('whisper-base: Accuracy by Language (Finetuned vs Baseline)\n(with 95% Bootstrap Confidence Intervals)',
             fontsize=20, fontweight='bold', pad=12)
ax.set_xticks(x)
ax.set_xticklabels(lang_order, rotation=45, ha='right', fontsize=16)
ax.legend(handles=legend_elements, loc='lower left', fontsize=14, frameon=True, shadow=True)
# Upper limit above 1.0 leaves room for the value labels drawn over the bars
ax.set_ylim(0, 1.15)
ax.yaxis.grid(True, alpha=0.3, linestyle='--')
ax.set_axisbelow(True)
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

# Save the figure (the 'figures' directory must already exist), then show it
plt.tight_layout()
plt.savefig('figures/accuracy_finetuned_vs_baseline_by_language.png', dpi=300, bbox_inches='tight')
plt.show()


In [None]:
# Preview only: fillna returns a new Series and the result is not assigned, so
# data_lang itself is left unchanged by this cell.
data_lang['lang_code'].fillna('base')

In [None]:
# ============================================================================
# Join each per-language finetuned model with the base baseline on (participant_id, answer)
# ============================================================================

# Rows without a language code (the base baseline) are labelled "base"
data_lang['lang_code'] = data_lang['lang_code'].fillna("base")

is_base_model = data_lang['model'] == 'base'

# Baseline side of the join: only the key columns plus the baseline outcome,
# renamed so they do not collide with the finetuned columns
base_rows = (
    data_lang.loc[is_base_model, ['participant_id', 'answer', 'transcription', 'is_correct']]
    .copy()
    .rename(columns={
        'transcription': 'base_transcription',
        'is_correct': 'base_is_correct',
    })
)

# Finetuned side of the join: every row that is not the base baseline
lang_rows = data_lang[~is_base_model].copy()

# Attach the baseline result for the same (participant_id, answer) to each finetuned row
data_lang_merged = pd.merge(lang_rows, base_rows, on=['participant_id', 'answer'], how='left')

# Pairwise outcome columns comparing the finetuned result with the baseline result
ft_correct = data_lang_merged['is_correct']
base_correct = data_lang_merged['base_is_correct']
data_lang_merged['both_correct'] = ft_correct & base_correct
data_lang_merged['finetuned_improved'] = ft_correct & ~base_correct
data_lang_merged['finetuned_regressed'] = ~ft_correct & base_correct
# +1 when only the finetuned model is correct, -1 when only the baseline is, else 0
data_lang_merged['overall_improvement'] = ft_correct.astype(int) - base_correct.astype(int)

print(f"\nOverall: base accuracy = {data_lang_merged['base_is_correct'].mean():.3f}, "
      f"finetuned accuracy = {data_lang_merged['is_correct'].mean():.3f}")
print(f"Improved: {data_lang_merged['finetuned_improved'].sum()}, "
      f"Regressed: {data_lang_merged['finetuned_regressed'].sum()}")
data_lang_merged.head()

In [None]:
# Mean improvement per finetuning language, largest improvement first
(
    data_lang_merged
    .groupby('lang_code')[['overall_improvement']]
    .mean()
    .sort_values(by='overall_improvement', ascending=False)
)

# 0.289 (scratch value kept from the original cell; its meaning is not recorded)

In [None]:
# ============================================================================
# Heatmap: finetuned improvement rate by lang_code × Primary language
# ============================================================================

# Pivot the grouped means into a 2D matrix (rows: lang_code, columns: Primary
# language); combinations without rows stay NaN
heatmap_data = (data_lang_merged
    .groupby(['lang_code', 'Primary language'])['overall_improvement']
    .mean()
    .unstack(fill_value=np.nan))

# Add whisper_base_im_on_all (finetuned on all languages) as an extra row.
# Its rows are paired with the base baseline on (participant_id, answer), the
# same way data_lang_merged was built.
base_all_model = MODEL_PAIRS['base']  # whisper_base_im_on_all
base_all_rows = all_data[all_data['model'] == base_all_model].copy()
base_baseline_rows = all_data[all_data['model'] == 'base'][['participant_id', 'answer', 'is_correct']].copy()
base_baseline_rows = base_baseline_rows.rename(columns={'is_correct': 'base_is_correct'})
base_all_merged = base_all_rows.merge(base_baseline_rows, on=['participant_id', 'answer'], how='left')
base_all_merged['overall_improvement'] = base_all_merged['is_correct'].astype(int) - base_all_merged['base_is_correct'].astype(int)
base_all_improvement = base_all_merged.groupby('Primary language')['overall_improvement'].mean()
# Assignment aligns the Series index with the existing columns of heatmap_data
heatmap_data.loc['all'] = base_all_improvement

# Sort rows by overall average improvement (descending)
row_order = heatmap_data.mean(axis=1).sort_values(ascending=False).index
heatmap_data = heatmap_data.loc[row_order]

# Sort columns by overall average improvement (descending)
col_order = heatmap_data.mean(axis=0).sort_values(ascending=False).index
heatmap_data = heatmap_data[col_order]

# Map language codes to full language names
LANG_CODE_TO_NAME = {
    'ar': 'Arabic', 'cs': 'Czech', 'de': 'German', 'es': 'Spanish',
    'fr': 'French', 'hi': 'Hindi', 'hu': 'Hungarian', 'it': 'Italian',
    'ja': 'Japanese', 'ko': 'Korean', 'nl': 'Dutch', 'pl': 'Polish',
    'pt': 'Portuguese', 'ru': 'Russian', 'tr': 'Turkish', 'zh-cn': 'Chinese',
    'all': 'All Languages',
}
# Strip the listed "_<number>" suffixes before the lookup; codes without a known
# name are kept as they are.
# NOTE(review): the codes coming from LANG_MODEL_TO_CODE carry no such suffixes,
# so the replace calls look like leftovers — confirm before removing.
heatmap_data.index = [LANG_CODE_TO_NAME.get(code.replace("_18089", "").replace("_18108", "").replace("_18118", ""), code) for code in heatmap_data.index]


# Five-stop colormap: the first two stops are dark red-brown/orange, the middle
# stop is white, the last two are blues
custom_cmap = LinearSegmentedColormap.from_list('BrWtBl', ['#8B2500', '#E8742A', '#FFFFFF', '#7FB5D3', '#1F3B6E'])

# Tighter limits for more extreme colors (vmin/vmax below), centered on 0

fig, ax = plt.subplots(figsize=(18, 10))
sns.heatmap(
    heatmap_data,
    annot=True,
    fmt='.2f',
    cmap=custom_cmap,
    center=0,
    vmin=-0.10,
    vmax=0.6,
    linewidths=0.5,
    linecolor='white',
    cbar_kws={'label': 'Improvement Rate'},
    ax=ax,
)

ax.set_xlabel('Primary Language', fontsize=18, fontweight='bold')
ax.set_ylabel('Finetuning Language', fontsize=18, fontweight='bold')
ax.set_title('Finetuned Improvement Rate by Training Language × Primary Language',
             fontsize=20, fontweight='bold', pad=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
# The 'figures' directory must already exist for savefig to succeed
plt.savefig('figures/finetuned_improvement_heatmap_lang_x_primary_language.png', dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# ============================================================================
# Heatmap: finetuned improvement rate by lang_code × street name (answer)
# ============================================================================

# Pivot the grouped means into a 2D matrix (rows: lang_code, columns: answer,
# i.e. the street name); combinations without rows stay NaN
heatmap_data = (data_lang_merged
    .groupby(['lang_code', 'answer'])['overall_improvement']
    .mean()
    .unstack(fill_value=np.nan))

# Add whisper_base_im_on_all (finetuned on all languages) as an extra row.
# Its rows are paired with the base baseline on (participant_id, answer).
base_all_model = MODEL_PAIRS['base']  # whisper_base_im_on_all
base_all_rows = all_data[all_data['model'] == base_all_model].copy()
base_baseline_rows = all_data[all_data['model'] == 'base'][['participant_id', 'answer', 'is_correct']].copy()
base_baseline_rows = base_baseline_rows.rename(columns={'is_correct': 'base_is_correct'})
base_all_merged = base_all_rows.merge(base_baseline_rows, on=['participant_id', 'answer'], how='left')
base_all_merged['overall_improvement'] = base_all_merged['is_correct'].astype(int) - base_all_merged['base_is_correct'].astype(int)
base_all_improvement = base_all_merged.groupby('answer')['overall_improvement'].mean()
# Assignment aligns the Series index with the existing columns of heatmap_data
heatmap_data.loc['all'] = base_all_improvement

# Round before sorting so the ordering matches the two-decimal annotations
heatmap_data = heatmap_data.round(2)

# Sort rows by overall average improvement (descending)
row_order = heatmap_data.mean(axis=1).sort_values(ascending=False).index
heatmap_data = heatmap_data.loc[row_order]

# Sort columns by overall average improvement (descending)
col_order = heatmap_data.mean(axis=0).sort_values(ascending=False).index
heatmap_data = heatmap_data[col_order]

# Map language codes to full language names
LANG_CODE_TO_NAME = {
    'ar': 'Arabic', 'cs': 'Czech', 'de': 'German', 'es': 'Spanish',
    'fr': 'French', 'hi': 'Hindi', 'hu': 'Hungarian', 'it': 'Italian',
    'ja': 'Japanese', 'ko': 'Korean', 'nl': 'Dutch', 'pl': 'Polish',
    'pt': 'Portuguese', 'ru': 'Russian', 'tr': 'Turkish', 'zh-cn': 'Chinese',
    'all': 'All Languages',
}
# Strip the listed "_<number>" suffixes before the lookup; codes without a known
# name are kept as they are
heatmap_data.index = [LANG_CODE_TO_NAME.get(code.replace("_18089", "").replace("_18108", "").replace("_18118", ""), code) for code in heatmap_data.index]


# Five-stop colormap with white as the middle stop
custom_cmap = LinearSegmentedColormap.from_list('BrWtBl', ['#8B2500', '#E8742A', '#FFFFFF', '#7FB5D3', '#1F3B6E'])

# Tighter limits for more extreme colors (vmin/vmax below), centered on 0

fig, ax = plt.subplots(figsize=(16, 8))
sns.heatmap(
    heatmap_data,
    annot=True,
    fmt='.2f',
    cmap=custom_cmap,
    center=0,
    vmin=-0.10,
    vmax=0.6,
    linewidths=0.5,
    linecolor='white',
    cbar_kws={'label': 'Improvement Rate'},
    ax=ax,
)

# The columns are street names (the 'answer' values), so the axis label and title
# say so; they previously claimed "Speaker Language Family".
ax.set_xlabel('Street Name', fontsize=18, fontweight='bold')
ax.set_ylabel('Finetuning Language', fontsize=18, fontweight='bold')
ax.set_title('Finetuned Improvement Rate by Training Language × Street Name',
             fontsize=20, fontweight='bold', pad=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
# The 'figures' directory must already exist for savefig to succeed
plt.savefig('figures/finetuned_improvement_heatmap_lang_x_streetname.png', dpi=300, bbox_inches='tight')
plt.show()