# Experiment 2: System Prompt Ablation Studies

**Goal:** Identify which components of system prompts carry the most causal weight.

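**Setup:**
- Use a fixed set of test prompts
- Take a complex system prompt and systematically remove (or cumulatively add) its components
- Measure the impact of each change via the Jensen–Shannon divergence and entropy change of the model's next-token distribution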

**Key Questions:**
- Which parts of system prompts matter most?
- Are some components redundant?
- What's the minimal effective system prompt?

In [None]:
import sys
# Put the parent directory on the import path so the `src` package below resolves.
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# NOTE(review): seaborn is imported but not referenced in the visible cells.
import seaborn as sns
from tqdm.auto import tqdm

# Project-local helpers, importable thanks to the sys.path entry above.
from src.model_utils import load_model
from src.metrics import DistributionMetrics
from src.visualization import set_style
from src.test_configs import ALL_TEST_PROMPTS, build_chat_prompt

# Apply the project's plotting style (defined in src.visualization; its exact
# effect is not visible here).
set_style()

In [None]:
# Identifier of the chat model under study, loaded via src.model_utils.load_model.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = load_model(MODEL_NAME)

## 1. Define Component-Based System Prompts

In [None]:
# A complex system prompt broken into components. Dict order is the order in
# which components are concatenated into the full prompt.
COMPONENTS = {
    "role": "You are a helpful AI assistant.",
    "expertise": "You have expertise across many domains.",
    "accuracy": "Always provide accurate information.",
    "honesty": "Be honest about uncertainty.",
    "reasoning": "Think step by step before answering.",
    "format": "Be clear and concise in your responses.",
    "safety": "Avoid harmful or misleading content.",
}

def build_system_prompt(include_components, components=None):
    """Join the text of the selected components into one system prompt.

    Args:
        include_components: Iterable of component names; their sentences are
            joined in the order given. Names not present in ``components``
            are skipped.
        components: Mapping of component name -> sentence to draw text from.
            Defaults to the module-level ``COMPONENTS`` so existing callers
            are unaffected; pass a custom mapping to ablate other prompts.

    Returns:
        The selected sentences joined by single spaces ("" if none match).
    """
    if components is None:
        components = COMPONENTS
    return " ".join(
        components[name] for name in include_components if name in components
    )

# Full system prompt with all components
FULL_SYSTEM = build_system_prompt(COMPONENTS.keys())
print(f"Full system prompt ({len(COMPONENTS)} components):\n{FULL_SYSTEM}")

## 2. Single Component Ablation

In [None]:
def run_ablation_experiment(model, test_prompts, components, n_test=10, top_k=50):
    """Measure how much removing each system-prompt component shifts the output.

    For each of the first ``n_test`` test prompts, the next-token
    distribution under the full system prompt (all ``components`` joined in
    order) is the baseline. Each component is then dropped in turn
    (leave-one-out) and the resulting distribution is compared with that
    baseline.

    Args:
        model: Wrapper exposing ``tokenizer`` and
            ``get_next_token_distribution(prompt, top_k=...)``; the returned
            dict is read for ``"full_probs"``, ``"entropy"`` and
            ``"top_tokens"``.
        test_prompts: Sequence of dicts with ``"id"`` and ``"prompt"`` keys.
        components: Ordered mapping of component name -> sentence. The
            prompts are built from this mapping's own text.
        n_test: Number of leading test prompts to evaluate.
        top_k: Forwarded to ``get_next_token_distribution``.

    Returns:
        DataFrame with one row per (removed component, test prompt) and
        columns ``removed_component``, ``test_id``, ``js_divergence``,
        ``entropy_change`` (ablated minus baseline) and ``top_token_changed``.
    """
    test_subset = test_prompts[:n_test]
    names = list(components)

    def join(selected):
        # Build from ``components`` itself; build_system_prompt's default
        # reads the global COMPONENTS and would ignore a custom mapping.
        return " ".join(components[c] for c in selected)

    def next_token_dist(system_prompt, test):
        prompt = build_chat_prompt(system_prompt, test["prompt"], model.tokenizer)
        return model.get_next_token_distribution(prompt, top_k=top_k)

    # Baseline: full system prompt with every component.
    full_prompt = join(names)
    baseline = [next_token_dist(full_prompt, test) for test in test_subset]

    rows = []
    for removed in tqdm(names, desc="Ablating components"):
        ablated_prompt = join(c for c in names if c != removed)
        for test, base in zip(test_subset, baseline):
            dist = next_token_dist(ablated_prompt, test)
            rows.append({
                "removed_component": removed,
                "test_id": test["id"],
                "js_divergence": DistributionMetrics.jensen_shannon(
                    base["full_probs"], dist["full_probs"]
                ),
                "entropy_change": dist["entropy"] - base["entropy"],
                "top_token_changed": dist["top_tokens"][0] != base["top_tokens"][0],
            })

    # Explicit columns keep the schema even when there are no rows, so the
    # downstream groupby/agg still works.
    return pd.DataFrame(rows, columns=[
        "removed_component", "test_id", "js_divergence",
        "entropy_change", "top_token_changed",
    ])

In [None]:
ablation_df = run_ablation_experiment(model, ALL_TEST_PROMPTS, COMPONENTS, n_test=15)

# Per-component summary of the leave-one-out runs, ranked by mean JS divergence.
# Named aggregation produces the flat column names directly.
component_impact = (
    ablation_df.groupby('removed_component')
    .agg(
        js_mean=('js_divergence', 'mean'),
        js_std=('js_divergence', 'std'),
        entropy_change=('entropy_change', 'mean'),
        top_change_rate=('top_token_changed', 'mean'),
    )
    .round(4)
    .sort_values('js_mean', ascending=False)
)

print("=== Component Impact (removal effect) ===")
print("Higher JS = component has MORE impact")
print(component_impact)

In [None]:
import os
os.makedirs('../results', exist_ok=True)


def _label_preview(text, width=30):
    """Return ``text`` cut to ``width`` chars, adding "..." only when truncated."""
    return text if len(text) <= width else text[:width] + "..."


fig, ax = plt.subplots(figsize=(10, 6))

# Ascending sort so the highest-impact component is drawn at the top of barh.
c = component_impact.sort_values('js_mean')
colors = plt.cm.Reds(np.linspace(0.3, 0.9, len(c)))
ax.barh(range(len(c)), c['js_mean'], xerr=c['js_std'], color=colors, capsize=3)
ax.set_yticks(range(len(c)))
ax.set_yticklabels([f'{idx}\n"{_label_preview(COMPONENTS[idx])}"' for idx in c.index])
ax.set_xlabel('JS Divergence When Removed')
ax.set_title('Component Importance (Higher = More Important)')

plt.tight_layout()
plt.savefig('../results/exp2_component_importance.png', dpi=150)
plt.show()

## 3. Cumulative Build-Up (Add Components One at a Time)

In [None]:
def cumulative_ablation(model, test_prompts, components, order, n_test=10, top_k=50):
    """Add system-prompt components one at a time and track mean entropy.

    Despite the name, this is a cumulative *build-up*. It starts from an
    empty system prompt and appends the components in ``order``, recording
    the mean next-token entropy over the first ``n_test`` test prompts after
    each addition.

    Args:
        model: Wrapper exposing ``tokenizer`` and
            ``get_next_token_distribution(prompt, top_k=...)``; the returned
            dict is read for ``"entropy"``.
        test_prompts: Sequence of dicts with a ``"prompt"`` key.
        components: Mapping of component name -> sentence. The prompts are
            built from this mapping's own text.
        order: Iterable of component names, in the order they are added.
        n_test: Number of leading test prompts to evaluate.
        top_k: Forwarded to ``get_next_token_distribution``.

    Returns:
        DataFrame with one row per step and columns ``n_components``,
        ``components`` (names joined with "+", or "none" for the empty
        prompt), ``mean_entropy`` and ``last_added`` (None for the
        empty-prompt row).

    Raises:
        ValueError: If ``order`` names a component missing from ``components``.
    """
    order = list(order)  # materialize: validated first, then iterated
    unknown = [c for c in order if c not in components]
    if unknown:
        raise ValueError(f"Unknown components in order: {unknown}")

    test_subset = test_prompts[:n_test]

    def mean_entropy(system_prompt):
        entropies = []
        for test in test_subset:
            prompt = build_chat_prompt(system_prompt, test["prompt"], model.tokenizer)
            dist = model.get_next_token_distribution(prompt, top_k=top_k)
            entropies.append(dist["entropy"])
        return np.mean(entropies)

    # Step 0: empty system prompt.
    # NOTE(review): how build_chat_prompt renders "" is not visible here; confirm
    # it omits the system turn rather than emitting an empty one.
    results = [{
        "n_components": 0,
        "components": "none",
        "mean_entropy": mean_entropy(""),
        "last_added": None,
    }]

    current = []
    for comp in order:
        current.append(comp)
        sys_prompt = " ".join(components[c] for c in current)
        results.append({
            "n_components": len(current),
            "components": "+".join(current),
            "last_added": comp,
            "mean_entropy": mean_entropy(sys_prompt),
        })

    return pd.DataFrame(results)

In [None]:
# Add components in order of decreasing single-removal impact: component_impact
# is sorted by js_mean descending, so its index runs from most to least impactful.
importance_order = list(component_impact.index)

cumulative_df = cumulative_ablation(
    model, ALL_TEST_PROMPTS, COMPONENTS, importance_order, n_test=10
)
print(cumulative_df)

In [None]:
fig, ax = plt.subplots(figsize=(10, 5))

x_vals = cumulative_df['n_components']
y_vals = cumulative_df['mean_entropy']
ax.plot(x_vals, y_vals, 'o-', linewidth=2, markersize=8)
ax.set_xlabel('Number of Components')
ax.set_ylabel('Mean Entropy')
ax.set_title('Effect of Adding System Prompt Components')

# Label every point after the empty-prompt baseline (row 0) with the component it added.
for _, step in cumulative_df.iloc[1:].iterrows():
    ax.annotate(
        step.get('last_added', ''),
        (step['n_components'], step['mean_entropy']),
        textcoords="offset points",
        xytext=(5, 5),
        fontsize=8,
    )

plt.tight_layout()
plt.savefig('../results/exp2_cumulative.png', dpi=150)
plt.show()

## 4. Surface-Level Modifications

In [None]:
# Formatting-only perturbations of FULL_SYSTEM: the same words with different
# casing, punctuation or whitespace. "original" is the reference variant.
SURFACE_VARIANTS = dict(
    original=FULL_SYSTEM,
    lowercase=FULL_SYSTEM.lower(),
    uppercase=FULL_SYSTEM.upper(),
    no_periods=FULL_SYSTEM.replace(".", ""),
    extra_spaces=FULL_SYSTEM.replace(" ", "  "),
    newlines=FULL_SYSTEM.replace(". ", ".\n"),
)

In [None]:
# Compare each formatting variant with the unmodified full prompt on the first
# 10 test prompts, via JS divergence of the next-token distribution.
test_subset = ALL_TEST_PROMPTS[:10]


def _surface_dist(system_text, test):
    """Next-token distribution (top_k=50) for one test prompt under ``system_text``."""
    chat = build_chat_prompt(system_text, test["prompt"], model.tokenizer)
    return model.get_next_token_distribution(chat, top_k=50)


# Reference distributions under the "original" variant.
original_dists = [_surface_dist(SURFACE_VARIANTS["original"], test) for test in test_subset]

surface_results = []
for name, text in tqdm(SURFACE_VARIANTS.items()):
    if name == "original":
        continue  # the reference itself is not scored
    for test, reference in zip(test_subset, original_dists):
        variant_dist = _surface_dist(text, test)
        surface_results.append({
            "variant": name,
            "test_id": test["id"],
            "js_divergence": DistributionMetrics.jensen_shannon(
                reference["full_probs"], variant_dist["full_probs"]
            ),
        })

surface_df = pd.DataFrame(surface_results)
surface_impact = surface_df.groupby('variant')['js_divergence'].agg(['mean', 'std']).round(4)
print("=== Surface Modification Impact ===")
print(surface_impact.sort_values('mean', ascending=False))

## 5. Key Findings

In [None]:
# Mean JS divergence above which surface edits are reported as mattering.
# NOTE(review): 0.1 is an ad-hoc cut-off, not calibrated against a null baseline.
SURFACE_JS_THRESHOLD = 0.1


def _summary_preview(text, width=50):
    """Return ``text`` cut to ``width`` chars, adding "..." only when truncated."""
    return text if len(text) <= width else text[:width] + "..."


print("="*60)
print("EXPERIMENT 2 SUMMARY: System Prompt Ablation")
print("="*60)

# component_impact is sorted by js_mean descending. Its head holds the components
# whose removal moved the next-token distribution the most. JS divergence measures
# the size of the change, not whether the output got worse.
ranked = list(component_impact.index)
most = ranked[:3]
# Drop any overlap with `most` (only possible with fewer than six components).
least = [comp for comp in ranked[-3:] if comp not in most]

print("\n1. MOST IMPORTANT Components (removal shifts output most):")
for comp in most:
    print(f"   • {comp}: {_summary_preview(COMPONENTS[comp])}")

print("\n2. LEAST IMPORTANT Components:")
for comp in least:
    print(f"   • {comp}: {_summary_preview(COMPONENTS[comp])}")

print("\n3. Surface Modifications Matter?")
max_surface_js = surface_impact['mean'].max()
if max_surface_js > SURFACE_JS_THRESHOLD:
    print(f"   → Yes, formatting affects output (max mean JS = {max_surface_js:.4f})")
else:
    print(f"   → No, model appears robust to surface changes (max mean JS = {max_surface_js:.4f})")

In [None]:
import json
import os

# Persist the aggregated tables. Also create the output directory here: the
# cell that normally creates it is a plotting cell that may have been skipped.
os.makedirs('../results', exist_ok=True)
with open('../results/exp2_results.json', 'w', encoding='utf-8') as f:
    json.dump({
        "component_impact": component_impact.to_dict(),
        "surface_impact": surface_impact.to_dict()
    }, f, indent=2, default=float)  # fallback: pass values json can't encode through float()
print("Saved.")