# Experiment 6: System Prompt Robustness

**Goal:** Test how robust the effects of a system prompt are to paraphrasing, i.e. to expressing the same instruction in different words.

**Setup:**
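- A fixed set of test (user) prompts
- Each system-prompt intent expressed in several different wordings (paraphrases)
- Measure how consistent the model's next-token distribution is across paraphrases (pairwise Jensen–Shannon divergence, entropy spread, and top-token agreement)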

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

from src.model_utils import load_model
from src.metrics import DistributionMetrics
from src.visualization import set_style
from src.test_configs import ALL_TEST_PROMPTS, build_chat_prompt

# Apply the plotting style from src.visualization before any figures are created below.
set_style()

In [None]:
# Load the chat model under test. The returned object is used below through
# `.tokenizer` and `.get_next_token_distribution(prompt, top_k=...)`.
model = load_model("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

## 1. Define Paraphrase Sets

In [None]:
# Intent name -> several wordings of the same system-prompt instruction.
# Every wording of an intent is tried against the same test prompts, so any
# difference in the model's output is attributable to wording alone.
PARAPHRASE_SETS = dict(
    be_concise=[
        "Be concise.",
        "Keep it brief.",
        "Short answers only.",
        "Be as brief as possible.",
        "Respond concisely.",
        "Use few words.",
        "Keep responses short.",
    ],
    be_helpful=[
        "Be helpful.",
        "Help the user.",
        "Provide helpful responses.",
        "Be as helpful as possible.",
        "Assist the user effectively.",
        "Give useful answers.",
        "Be a helpful assistant.",
    ],
    be_accurate=[
        "Be accurate.",
        "Ensure accuracy.",
        "Provide accurate information.",
        "Be precise and correct.",
        "Give accurate answers.",
        "Accuracy is paramount.",
        "Make sure your answers are correct.",
    ],
    think_step_by_step=[
        "Think step by step.",
        "Reason through this step by step.",
        "Work through this systematically.",
        "Break this down into steps.",
        "Analyze this step by step.",
        "Take it one step at a time.",
        "Proceed methodically.",
    ],
)

## 2. Measure Paraphrase Consistency

In [None]:
# Only the first 8 test prompts are used (presumably to keep runtime down):
# section 2 makes one get_next_token_distribution call per
# (intent, paraphrase, test prompt), i.e. up to 4 x 7 x 8 = 224 calls.
TEST_SUBSET = ALL_TEST_PROMPTS[:8]

def measure_paraphrase_consistency(model, test_prompts, paraphrase_sets, top_k=50):
    """Measure how consistently the model reacts to paraphrases of each intent.

    For every (intent, test prompt) pair, each paraphrase is used as the
    system prompt, and the model's next-token distribution is computed. The
    distributions of all paraphrases are then compared pairwise.

    Args:
        model: Object exposing ``tokenizer`` and
            ``get_next_token_distribution(prompt, top_k=...)``, whose result
            has ``"full_probs"``, ``"entropy"`` and ``"top_tokens"`` entries.
        test_prompts: Iterable of dicts with ``"id"``, ``"category"`` and
            ``"prompt"`` keys.
        paraphrase_sets: Mapping of intent name -> list of system-prompt
            paraphrases.
        top_k: Forwarded to ``model.get_next_token_distribution``. Defaults
            to 50, the value previously hard-coded.

    Returns:
        DataFrame with one row per (intent, test prompt) and these columns:
        ``intent``, ``test_id``, ``category``;
        ``mean_js`` / ``max_js``: mean and max pairwise Jensen-Shannon
        divergence between paraphrases, NaN when fewer than two paraphrases
        are given;
        ``entropy_std``: std of entropy across paraphrases;
        ``top_token_consistent``: True when all paraphrases share one top
        token;
        ``n_unique_top_tokens``: number of distinct top tokens.
    """
    from itertools import combinations

    nan = float("nan")
    results = []

    for intent, paraphrases in tqdm(paraphrase_sets.items()):
        for test in test_prompts:
            # One next-token distribution per paraphrase of this intent.
            distributions = [
                model.get_next_token_distribution(
                    build_chat_prompt(para, test["prompt"], model.tokenizer),
                    top_k=top_k,
                )
                for para in paraphrases
            ]

            # Pairwise JS divergence over every unordered pair (i < j).
            js_values = [
                DistributionMetrics.jensen_shannon(a["full_probs"], b["full_probs"])
                for a, b in combinations(distributions, 2)
            ]
            entropies = [d["entropy"] for d in distributions]

            # NOTE(review): if top_tokens entries are (token, prob) pairs rather
            # than bare tokens, the set below also compares probabilities —
            # confirm the shape returned by get_next_token_distribution.
            n_unique = len({d["top_tokens"][0] for d in distributions})

            results.append({
                "intent": intent,
                "test_id": test["id"],
                "category": test["category"],
                # np.max([]) raises ValueError and np.mean([]) warns, so fall
                # back to NaN when there is no pair to compare.
                "mean_js": np.mean(js_values) if js_values else nan,
                "max_js": np.max(js_values) if js_values else nan,
                "entropy_std": np.std(entropies) if entropies else nan,
                "top_token_consistent": n_unique == 1,
                "n_unique_top_tokens": n_unique,
            })

    return pd.DataFrame(results)

# Run the sweep: one row per (intent, test prompt) pair.
consistency_df = measure_paraphrase_consistency(model, TEST_SUBSET, PARAPHRASE_SETS)

## 3. Analyze Consistency

In [None]:
# Consistency by intent: aggregate the per-(intent, test prompt) rows.
# Named aggregation yields flat column names directly. Renaming a MultiIndex
# by hand silently depends on the positional order of the agg spec.
intent_consistency = (
    consistency_df.groupby('intent')
    .agg(
        js_mean=('mean_js', 'mean'),
        js_std=('mean_js', 'std'),
        token_consistency=('top_token_consistent', 'mean'),
        n_unique_tokens=('n_unique_top_tokens', 'mean'),
    )
    .sort_values('js_mean')
)

print("=== Paraphrase Robustness by Intent ===")
print("Lower JS = more robust to paraphrasing")
# Round only for display. Keeping full precision avoids artificial ties in the
# idxmin/idxmax ranking, zeroed-out colours in the plot, and lossy saved results.
print(intent_consistency.round(4))

In [None]:
import os
os.makedirs('../results', exist_ok=True)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))
positions = range(len(intent_consistency))
labels = intent_consistency.index

# Left: mean pairwise JS divergence between paraphrases, with the std across
# test prompts as error bars.
ax = axes[0]
js_mean = intent_consistency['js_mean']
js_peak = js_mean.max()
# Scale to [0, 1] for the colormap. Dividing by a zero (or NaN) maximum would
# yield NaN, which the colormap renders as its "bad" colour.
color_scale = js_mean / js_peak if js_peak > 0 else js_mean * 0.0
colors = plt.cm.RdYlGn_r(color_scale.fillna(0.0).to_numpy())
ax.barh(positions, js_mean.values, xerr=intent_consistency['js_std'], color=colors, capsize=3)
ax.set_yticks(positions)
ax.set_yticklabels(labels)
ax.set_xlabel('Mean JS Divergence Between Paraphrases')
ax.set_title('Paraphrase Sensitivity\n(Lower = More Robust)')

# Right: share of test prompts where every paraphrase gives the same top token.
ax = axes[1]
consistency_pct = intent_consistency['token_consistency'] * 100
ax.barh(positions, consistency_pct.values, alpha=0.7)
ax.set_yticks(positions)
ax.set_yticklabels(labels)
ax.set_xlim(0, 100)  # a percentage; fixed axis keeps intents comparable across runs
ax.set_xlabel('% Prompts with Consistent Top Token')
ax.set_title('Top Token Consistency Across Paraphrases')

plt.tight_layout()
plt.savefig('../results/exp6_robustness.png', dpi=150)
plt.show()

## 4. Identify Sensitive Paraphrases

In [None]:
from itertools import combinations

# For each intent, find the pair of paraphrases whose next-token
# distributions diverge most. Section 2 does not keep its per-paraphrase
# distributions, so they are recomputed here, for a single probe prompt only.
# This is an illustration, not an aggregate over TEST_SUBSET.
print("=== Most Divergent Paraphrase Pairs ===")

probe = TEST_SUBSET[0]  # same prompt for every intent, so look it up once
print(f"(probe prompt: {probe['id']})")

for intent, paraphrases in PARAPHRASE_SETS.items():
    # (paraphrase text, full next-token probabilities) for each paraphrase.
    dists = [
        (
            para,
            model.get_next_token_distribution(
                build_chat_prompt(para, probe["prompt"], model.tokenizer), top_k=50
            )["full_probs"],
        )
        for para in paraphrases
    ]

    scored_pairs = [
        (DistributionMetrics.jensen_shannon(probs_a, probs_b), text_a, text_b)
        for (text_a, probs_a), (text_b, probs_b) in combinations(dists, 2)
    ]

    print(f"\n{intent}:")
    if not scored_pairs:
        print("  Max JS: n/a (fewer than two paraphrases)")
        continue
    # max() keeps the first pair on ties, matching a strict '>' scan.
    max_js, para_a, para_b = max(scored_pairs, key=lambda scored: scored[0])
    print(f"  Max JS: {max_js:.4f}")
    print(f"  '{para_a}' vs '{para_b}'")

In [None]:
print("\n=== Summary ===")
js_by_intent = intent_consistency['js_mean']
most_robust = js_by_intent.idxmin()
least_robust = js_by_intent.idxmax()
print(f"Most robust intent: {most_robust} (mean JS {js_by_intent[most_robust]:.4f})")
print(f"Least robust intent: {least_robust} (mean JS {js_by_intent[least_robust]:.4f})")
# The ranking is relative: some intent is always "least robust", even when all
# divergences are small, so report it as such alongside the JS values above.
print(f"\nImplication: '{least_robust}' instructions are the most sensitive to exact wording among the tested intents")

In [None]:
import json
from pathlib import Path

# Persist the per-intent summary table.
results_dir = Path('../results')
results_dir.mkdir(parents=True, exist_ok=True)  # don't rely on the plotting cell having run
# NaN (e.g. js_std for an intent with a single test prompt) becomes None, which
# is written as JSON null. json.dump would otherwise emit the non-standard token
# NaN, which strict JSON parsers reject. allow_nan=False makes any leftover
# NaN fail loudly.
summary = intent_consistency.astype(object).where(intent_consistency.notna(), None)
with open(results_dir / 'exp6_results.json', 'w', encoding='utf-8') as f:
    json.dump(
        {"intent_consistency": summary.to_dict()},
        f,
        indent=2,
        default=float,  # coerce any remaining numpy numeric scalars
        allow_nan=False,
    )
print("Saved.")