# Protocol Perturbations Figure

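Generates the protocol-instability figure. Panel A compares verdict flip rates (and Unclear rates) across prompting protocols, broken down by the original verdict. Panel B shows what happens to nuanced verdicts (All At Fault / No One At Fault) under content perturbations versus protocol perturbations.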

In [None]:
import os
import warnings

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
warnings.filterwarnings('ignore')

# =============================================================================
# FIGURE: Protocol Instability + Nuanced Verdict Fate (Content vs Protocol)
# Panel A: Protocol instability by base judgment (separated by protocol)
# Panel B: Nuanced verdict fate - Content (top) vs Protocol (bottom)
# =============================================================================

# Load data: protocol-perturbation results and content-perturbation results.
protocol_df = pd.read_parquet('../data/protocol_tests_combined.parquet')
content_df = pd.read_parquet('../data/master_final_model_own_baseline.parquet')

# Attach to every perturbed content row the unperturbed ('none') judgment for
# the same id / model / run_number, as column `baseline_verdict`. Rows with no
# matching baseline keep NaN there (left join).
join_keys = ['id', 'model', 'run_number']
is_baseline = content_df['perturbation_type'] == 'none'
baseline_content = content_df.loc[is_baseline, join_keys + ['standardized_judgment']].rename(
    columns={'standardized_judgment': 'baseline_verdict'}
)
perturbed_content = content_df.loc[~is_baseline].copy()
content_merged = perturbed_content.merge(baseline_content, on=join_keys, how='left')

# Display names for the standardized verdict codes; 'Unclear' is included so
# Panel A can label its extra Unclear-rate group.
verdict_labels = dict(
    Other_At_Fault='Other At Fault',
    Self_At_Fault='Self At Fault',
    All_At_Fault='All At Fault',
    No_One_At_Fault='No One At Fault',
    Unclear='Unclear',
)
# Substantive verdicts whose flip rates are measured: every code except Unclear,
# in the same order as above.
verdict_categories = [code for code in verdict_labels if code != 'Unclear']

# Protocols compared in Panel A (two structured variants plus unstructured).
# Dict insertion order fixes the bar order inside each group.
protocol_display = dict(
    explanation_first='Explanation First',
    system_prompt='System Prompt',
    unstructured='Unstructured',
)
protocols = list(protocol_display)

# Calculate flip rates for all protocols (for standard verdict categories),
# plus the rate at which each protocol produces an Unclear outcome. Every rate
# carries a 95% Wilson score interval.
def _wilson_interval(events, n, z=1.96):
    """Return the Wilson score interval for a binomial proportion.

    Args:
        events: Number of observed events (e.g. flipped verdicts).
        n: Number of trials; must be positive.
        z: Standard-normal quantile; 1.96 gives a ~95% interval.

    Returns:
        ``(lower, upper, half_width)`` as proportions in [0, 1]. The interval
        is centred on the Wilson midpoint, not on ``events / n``, so it is
        asymmetric around the observed rate (most visibly near 0% or 100%).
    """
    p = events / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half_width = z * np.sqrt((p * (1 - p) + z**2 / (4 * n)) / n) / denom
    # Clamp against floating-point residue just outside [0, 1].
    return max(0.0, center - half_width), min(1.0, center + half_width), half_width


def _rate_record(protocol, verdict, events, n):
    """Return one Panel A row: the observed rate and its Wilson bounds, all in %.

    ``ci`` keeps its earlier meaning (Wilson half-width, in %) for backward
    compatibility. The plot uses the asymmetric ``ci_low`` / ``ci_high`` bounds.
    """
    low, high, half_width = _wilson_interval(events, n)
    return {
        'protocol': protocol,
        'verdict': verdict,
        'flip_rate': events / n * 100,
        'ci': half_width * 100,
        'ci_low': low * 100,
        'ci_high': high * 100,
    }


protocol_flip_data = []
for proto in protocols:
    proto_df = protocol_df[protocol_df['protocol'] == proto]
    for base in verdict_categories:
        subset = proto_df[proto_df['main_study_verdict'] == base]
        if len(subset) > 0:
            # A flip is any standardized_judgment differing from main_study_verdict
            # (a missing judgment therefore counts as a flip).
            flipped = (subset['standardized_judgment'] != subset['main_study_verdict']).sum()
            protocol_flip_data.append(_rate_record(proto, base, flipped, len(subset)))

# Add "Unclear" category: rate at which each protocol produces Unclear/No_Verdict outcomes
UNCLEAR_OUTCOMES = ['Unclear', 'No_Verdict', 'unclear', 'no_verdict']
for proto in protocols:
    proto_df = protocol_df[protocol_df['protocol'] == proto]
    if len(proto_df) > 0:
        unclear_count = proto_df['standardized_judgment'].isin(UNCLEAR_OUTCOMES).sum()
        protocol_flip_data.append(_rate_record(proto, 'Unclear', unclear_count, len(proto_df)))

# Explicit columns keep the filtering below working even if no rows were produced.
protocol_flip_df = pd.DataFrame(
    protocol_flip_data,
    columns=['protocol', 'verdict', 'flip_rate', 'ci', 'ci_low', 'ci_high'],
)

# Nuanced verdicts are the two non-binary outcomes.
NUANCED_VERDICTS = ['All_At_Fault', 'No_One_At_Fault']


def _fate_percentages(judgments):
    """Return the share (%) of formerly-nuanced cases ending in each outcome.

    Args:
        judgments: Series of standardized_judgment values for cases whose
            reference verdict was nuanced.

    Returns:
        Dict with 'Retained' (still All_At_Fault or No_One_At_Fault),
        'Other At Fault' and 'Self At Fault' percentages. Values are NaN when
        ``judgments`` is empty, instead of inf/NaN from dividing by zero.
    """
    total = len(judgments)
    if total == 0:
        return {'Retained': np.nan, 'Other At Fault': np.nan, 'Self At Fault': np.nan}
    return {
        'Retained': judgments.isin(NUANCED_VERDICTS).sum() / total * 100,
        'Other At Fault': (judgments == 'Other_At_Fault').sum() / total * 100,
        'Self At Fault': (judgments == 'Self_At_Fault').sum() / total * 100,
    }


# Nuanced fate - Protocol (structured only)
structured_protocols = ['explanation_first', 'system_prompt']
structured_df = protocol_df[protocol_df['protocol'].isin(structured_protocols)].copy()
nuanced_protocol = structured_df[structured_df['main_study_verdict'].isin(NUANCED_VERDICTS)]
total_nuanced_proto = len(nuanced_protocol)
proto_pcts = _fate_percentages(nuanced_protocol['standardized_judgment'])

# Nuanced fate - Content
nuanced_content = content_merged[content_merged['baseline_verdict'].isin(NUANCED_VERDICTS)]
total_nuanced_content = len(nuanced_content)
content_pcts = _fate_percentages(nuanced_content['standardized_judgment'])

# === CREATE FIGURE ===
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5), dpi=150)

# --- Panel A ---
# Order verdicts by mean flip rate across protocols (highest first); Unclear goes last.
avg_flip = (
    protocol_flip_df[protocol_flip_df['verdict'] != 'Unclear']
    .groupby('verdict')['flip_rate']
    .mean()
    .sort_values(ascending=False)
)
verdict_order = avg_flip.index.tolist() + ['Unclear']

x = np.arange(len(verdict_order))
width = 0.25  # Narrow bars so the protocol group fits inside one tick slot
colors = {'explanation_first': '#444444', 'system_prompt': '#999999', 'unstructured': '#cccccc'}
hatches = {'explanation_first': '', 'system_prompt': '', 'unstructured': '//'}

for i, proto in enumerate(protocols):
    # Centre the group of bars on each tick (offsets -w, 0, +w for three protocols).
    offset = (i - (len(protocols) - 1) / 2) * width
    rates, err_below, err_above = [], [], []
    for v in verdict_order:
        match = protocol_flip_df[(protocol_flip_df['verdict'] == v) & (protocol_flip_df['protocol'] == proto)]
        if len(match) > 0:
            row = match.iloc[0]
            rates.append(row['flip_rate'])
            err_below.append(row['flip_rate'] - row['ci_low'])
            err_above.append(row['ci_high'] - row['flip_rate'])
        else:
            # No data for this protocol/verdict pair: leave an empty slot.
            rates.append(0)
            err_below.append(0)
            err_above.append(0)
    rates = np.array(rates)
    # Asymmetric (2, N) Wilson error bars; the clip removes tiny negative float residue.
    yerr = np.clip(np.array([err_below, err_above], dtype=float), 0, None)

    ax1.bar(x + offset, rates, width, label=protocol_display[proto],
            color=colors[proto], edgecolor='#222222', linewidth=0.5,
            hatch=hatches[proto])
    ax1.errorbar(x + offset, rates, yerr=yerr, fmt='none', ecolor='#222222', capsize=3, lw=1)

# Add vertical separator before Unclear
ax1.axvline(x=len(verdict_order) - 1.5, color='#888888', linestyle='--', linewidth=1, alpha=0.5)

ax1.set_xlabel('Main Study Verdict', fontweight='bold')
ax1.set_ylabel('Flip Rate (%) / Unclear Rate (%)', fontweight='bold')
ax1.set_title('(A) Protocol Instability by Base Judgment', fontweight='bold', loc='left')
ax1.set_xticks(x)
ax1.set_xticklabels([verdict_labels.get(v, v) for v in verdict_order], rotation=15, ha='right')
ax1.set_ylim(0, 100)  # Full range so the unstructured protocol's higher rates fit
ax1.legend(loc='upper right', framealpha=0.9)
ax1.grid(axis='y', alpha=0.3, linestyle='--')
ax1.spines['top'].set_visible(False)
ax1.spines['right'].set_visible(False)

# --- Panel B: stacked horizontal bars. The legend explains segment colours;
# only percentages are printed inside the bars ---
bar_height = 0.25
y_positions = [0.55, 0.15]
row_labels = ['Content Perturbations', 'Protocol Perturbations']
row_data = [content_pcts, proto_pcts]

seg_colors = {'Retained': '#bbbbbb', 'Other At Fault': '#333333', 'Self At Fault': '#666666'}
segment_order = ['Retained', 'Other At Fault', 'Self At Fault']

# Outcomes outside segment_order (e.g. Unclear / No_Verdict judgments) would
# otherwise leave an unlabeled gap before 100%, so draw them as an explicit segment.
remainder_label = 'Unclear / Other'
remainder_color = '#f2f2f2'
remainder_min = 0.05  # percentage points; ignores floating-point residue
remainder_drawn = False

for y_pos, row_label, data in zip(y_positions, row_labels, row_data):
    segments = [(seg, data[seg], seg_colors[seg]) for seg in segment_order]
    remainder = 100 - sum(val for _, val, _ in segments)
    if remainder > remainder_min:  # False for NaN rows, so empty data adds nothing
        segments.append((remainder_label, remainder, remainder_color))
        remainder_drawn = True

    left = 0
    for seg, val, color in segments:
        ax2.barh(y_pos, val, left=left, color=color, edgecolor='#222222',
                 linewidth=0.5, height=bar_height)

        # Only label segments wide enough to hold the percentage text.
        if val > 4:
            text_color = 'white' if seg in ['Other At Fault', 'Self At Fault'] else 'black'
            ax2.text(left + val / 2, y_pos, f'{val:.0f}%', ha='center', va='center',
                     fontsize=9, fontweight='bold', color=text_color)

        left += val

    # Row label below the bar
    ax2.text(50, y_pos - 0.17, row_label, ha='center', va='top', fontsize=9, fontstyle='italic')

# Create legend (the remainder entry appears only when that segment was drawn)
legend_handles = [mpatches.Patch(facecolor=seg_colors[seg], edgecolor='#222222', label=seg)
                  for seg in segment_order]
if remainder_drawn:
    legend_handles.append(mpatches.Patch(facecolor=remainder_color, edgecolor='#222222',
                                         label=remainder_label))
ax2.legend(handles=legend_handles, loc='upper right', framealpha=0.9, fontsize=8)

ax2.set_xlim(0, 100)
ax2.set_ylim(-0.15, 0.85)
ax2.set_xlabel('Percentage', fontweight='bold')
ax2.set_title('(B) Nuanced Verdict Fate (All At Fault + No One At Fault)', fontweight='bold', loc='left')
ax2.set_yticks([])
ax2.spines['top'].set_visible(False)
ax2.spines['right'].set_visible(False)
ax2.spines['left'].set_visible(False)

# Write the figure as vector (PDF) and raster (PNG). Create the output
# directory first so a fresh checkout does not fail with FileNotFoundError.
figure_dir = '../figures'
figure_stem = 'fig_protocol_instability_with_fate'
os.makedirs(figure_dir, exist_ok=True)

plt.tight_layout()
for ext in ('pdf', 'png'):
    plt.savefig(os.path.join(figure_dir, f'{figure_stem}.{ext}'), bbox_inches='tight', dpi=300)
plt.show()

print(f"Saved: {figure_stem}.pdf/png")