# Experiment 2: Ablation Studies on Prompt Components

**Goal:** Identify which parts of a prompt carry the most causal weight.

**Key Questions:**
- Is prompt effectiveness driven by semantic content or structural cues?
- Which specific words/phrases are most impactful?
- Do ablation effects reflect the loss of helpful information, or the removal of noise that was hurting the model?

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import re

from src.model_utils import load_model
from src.prompt_utils import (
    ablate_prompt_component, shuffle_sentences, 
    remove_punctuation, lowercase_prompt
)
from src.metrics import ExperimentResults, SequenceMetrics
from src.visualization import set_style, plot_ablation_results

set_style()

In [None]:
# Load the model once; the cells below reuse this instance through
# model.get_sequence_log_probs(...) and model.get_next_token_distribution(...).
# NOTE(review): load_model comes from src.model_utils; the argument is
# presumably a Hugging Face model id -- confirm against that helper.
model = load_model("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

## 1. Define High-Performing Prompts to Ablate

In [None]:
# Prompts under study, keyed by prompt style. Each entry holds:
#   "full"       - the complete prompt text given to the model;
#   "components" - named substrings of "full"; create_ablations() passes each
#                  one individually to ablate_prompt_component();
#   "expected"   - the target continuation whose log-probability is scored
#                  by run_ablation_experiment().
# The string contents (including whitespace) are part of the experiment, so
# edit them deliberately.
PROMPTS_TO_ABLATE = {
    # Persona + instruction + "think step by step" trigger around an
    # arithmetic question.
    "chain_of_thought": {
        "full": """You are an expert mathematician. 
Answer the following question step by step, showing your reasoning.

Question: If a train travels at 60 mph for 2 hours, then 40 mph for 1 hour, what is the total distance?

Let's think step by step.""",
        "components": {
            "persona": "You are an expert mathematician.",
            "instruction": "Answer the following question step by step, showing your reasoning.",
            "cot_trigger": "Let's think step by step.",
            "question_label": "Question:"
        },
        # 60 mph * 2 h + 40 mph * 1 h = 160 miles
        "expected": "160"
    },
    
    # Instruction followed by two labelled examples and one unlabelled query.
    "few_shot": {
        "full": """Classify the sentiment as positive or negative.

Text: \"This movie was fantastic!\"
Sentiment: positive

Text: \"I hated every minute of it.\"
Sentiment: negative

Text: \"The food was absolutely delicious!\"
Sentiment:""",
        "components": {
            "instruction": "Classify the sentiment as positive or negative.",
            "example_1": 'Text: \"This movie was fantastic!\"\nSentiment: positive',
            "example_2": 'Text: \"I hated every minute of it.\"\nSentiment: negative'
        },
        # Leading space: the prompt ends immediately after "Sentiment:".
        "expected": " positive"
    },
    
    # Markdown-style sections: task, input text, output template, output slot.
    "structured": {
        "full": """### Task
Extract the key information from the following text.

### Input
John Smith, age 35, works as a software engineer at Google in Mountain View.

### Output Format
Name: [name]
Age: [age]
Job: [job]
Company: [company]

### Output""",
        "components": {
            "task_section": "### Task\nExtract the key information from the following text.",
            "format_section": "### Output Format\nName: [name]\nAge: [age]\nJob: [job]\nCompany: [company]"
        },
        # Leading newline: the prompt ends on the "### Output" header line.
        "expected": "\nName: John Smith"
    }
}

## 2. Ablation Functions

In [None]:
def create_ablations(prompt_data):
    """Build every ablation variant of a single prompt entry.

    Args:
        prompt_data: Mapping with a "full" prompt string and a "components"
            mapping of component name -> component text.

    Returns:
        Dict of variant name -> prompt text, in this order: "full", one
        "remove_<component>" per component, then the surface-level variants
        "lowercase", "no_punctuation", "shuffled" and "no_newlines".
    """
    base = prompt_data["full"]
    variants = {"full": base}

    # One variant per component, each with that single component ablated.
    variants.update(
        (f"remove_{label}", ablate_prompt_component(base, [snippet]))
        for label, snippet in prompt_data["components"].items()
    )

    # Surface-level variants that leave the wording in place.
    variants.update({
        "lowercase": lowercase_prompt(base),
        "no_punctuation": remove_punctuation(base),
        "shuffled": shuffle_sentences(base),
        # split()/join collapses every whitespace run (newlines included)
        # into a single space.
        "no_newlines": " ".join(base.split()),
    })

    return variants


def run_ablation_experiment(model, prompt_data):
    """Score every ablation variant of one prompt against its expected target.

    Args:
        model: Object providing ``get_sequence_log_probs(prompt, target)`` and
            ``get_next_token_distribution(prompt)``.
        prompt_data: Prompt entry with "full", "components" and "expected".

    Returns:
        Dict of variant name -> dict with keys "prompt", "target_log_prob",
        "entropy", "top_5" and "perplexity".
    """
    variants = create_ablations(prompt_data)
    target = prompt_data["expected"]

    def _measure(text):
        # Sequence score for the expected target, then the next-token
        # distribution for the same prompt.
        sequence = model.get_sequence_log_probs(text, target)
        next_token = model.get_next_token_distribution(text)
        return {
            "prompt": text,
            "target_log_prob": sequence["total_log_prob"],
            "entropy": next_token["entropy"],
            "top_5": next_token["top_tokens"][:5],
            "perplexity": SequenceMetrics.perplexity(sequence["log_probs"]),
        }

    progress = tqdm(variants.items(), desc="Running ablations")
    return {name: _measure(text) for name, text in progress}

## 3. Run Ablation Experiments

In [None]:
# Run the ablation experiment for every prompt and print a ranked report.
all_ablation_results = {}
rule = "=" * 60

for prompt_name, prompt_data in PROMPTS_TO_ABLATE.items():
    # Section banner for this prompt
    print(f"\n{rule}")
    print(f"Ablating: {prompt_name}")
    print(rule)

    results = run_ablation_experiment(model, prompt_data)
    all_ablation_results[prompt_name] = results

    baseline = results["full"]["target_log_prob"]
    print(f"\nBaseline log-prob: {baseline:.4f}")
    print("\nAblation impacts:")

    # Highest target log-prob first; the unablated prompt is not reported.
    ranked = sorted(
        results.items(),
        key=lambda entry: entry[1]["target_log_prob"],
        reverse=True,
    )
    for abl_name, abl_results in ranked:
        if abl_name != "full":
            diff = abl_results["target_log_prob"] - baseline
            # Relative change versus the baseline; guarded against a zero baseline.
            if baseline != 0:
                pct = (diff / abs(baseline)) * 100
            else:
                pct = 0
            print(f"  {abl_name:25s}: {abl_results['target_log_prob']:.4f} ({pct:+.1f}%)")

## 4. Visualize Ablation Results

In [None]:
import os

# Figures are written under ../results; create it on first use.
os.makedirs('../results', exist_ok=True)

for prompt_name, results in all_ablation_results.items():
    baseline_score = results["full"]["target_log_prob"]

    # Every variant except the unablated prompt.
    ablation_scores = {}
    for variant, scores in results.items():
        if variant != "full":
            ablation_scores[variant] = scores["target_log_prob"]

    fig = plot_ablation_results(
        baseline_score,
        ablation_scores,
        metric_name="Target Log Probability",
        title=f"Ablation Study: {prompt_name}",
    )
    out_path = f'../results/exp2_ablation_{prompt_name}.png'
    plt.savefig(out_path, dpi=150, bbox_inches='tight')
    plt.show()

## 5. Word-Level Ablation

In [None]:
def _remove_word(prompt, word):
    """Delete every whole-word occurrence of ``word`` from ``prompt``.

    Only the whitespace directly next to a removed word is touched, so the
    rest of the prompt (newlines, spacing, punctuation) stays as it was.
    """
    if not word:
        return prompt
    # Lookarounds instead of \b so words containing non-word characters
    # (e.g. "Let's") and words followed by punctuation ("reasoning.") match.
    token = rf"(?<!\w){re.escape(word)}(?!\w)"
    # At the start of a line drop the spaces *after* the word; elsewhere drop
    # the spaces *before* it. Either way no doubled or dangling space is left.
    pattern = re.compile(rf"^{token}[ \t]*|[ \t]*{token}", re.MULTILINE)
    return pattern.sub("", prompt)


def word_level_ablation(model, prompt, expected, test_words):
    """Measure the impact of removing individual words from a prompt.

    Each test word is removed as a whole word from the original prompt text,
    wherever it occurs and regardless of attached punctuation. The prompt's
    line structure is preserved, so the reported impact reflects the missing
    word rather than a change of layout. A word that does not occur leaves
    the prompt unchanged.

    Args:
        model: Object providing ``get_sequence_log_probs(prompt, target)``
            that returns a mapping with a "total_log_prob" entry.
        prompt: Full prompt text.
        expected: Target continuation to score.
        test_words: Iterable of words to remove, one at a time.

    Returns:
        Tuple ``(results, baseline)``: ``results`` maps each word to
        ``{"log_prob": ..., "impact": ...}`` where impact is the ablated
        log-prob minus the baseline; ``baseline`` is the log-prob of
        ``expected`` under the unmodified prompt.
    """
    baseline = model.get_sequence_log_probs(prompt, expected)["total_log_prob"]

    results = {}
    for word in tqdm(test_words, desc="Testing words"):
        ablated = _remove_word(prompt, word)
        log_prob = model.get_sequence_log_probs(ablated, expected)["total_log_prob"]
        results[word] = {"log_prob": log_prob, "impact": log_prob - baseline}

    return results, baseline

In [None]:
# Word-level ablation on the chain-of-thought prompt
cot_prompt = PROMPTS_TO_ABLATE["chain_of_thought"]
test_words = [
    "expert", "mathematician", "step", "reasoning",
    "Question", "Let's", "think",
]

word_results, baseline = word_level_ablation(
    model,
    prompt=cot_prompt["full"],
    expected=cot_prompt["expected"],
    test_words=test_words,
)

print(f"\nBaseline: {baseline:.4f}")
# Most negative impact (largest drop when the word is removed) first.
ranked_words = sorted(word_results, key=lambda w: word_results[w]["impact"])
for word in ranked_words:
    data = word_results[word]
    print(f"  '{word}': impact = {data['impact']:.4f}")

In [None]:
# Horizontal bar chart of per-word impact, most negative first
fig, ax = plt.subplots(figsize=(10, 6))
sorted_items = sorted(word_results.items(), key=lambda kv: kv[1]["impact"])

words = []
impacts = []
colors = []
for token, stats in sorted_items:
    words.append(token)
    impacts.append(stats["impact"])
    # Green when removing the word raised the log-prob, red otherwise.
    colors.append('green' if stats["impact"] > 0 else 'red')

positions = range(len(words))
ax.barh(positions, impacts, color=colors, alpha=0.7)
ax.axvline(x=0, color='gray', linestyle='-', linewidth=0.5)
ax.set_yticks(positions)
ax.set_yticklabels([f'"{label}"' for label in words])
ax.set_xlabel('Impact on Log Probability (when removed)')
ax.set_title('Word-Level Ablation')

plt.tight_layout()
plt.savefig('../results/exp2_word_level.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Key Findings

In [None]:
# Summary: for each prompt, the ablation that hurt most and the one that helped most
rule = "=" * 60
print(rule)
print("EXPERIMENT 2 SUMMARY: Ablation Studies")
print(rule)

for prompt_name, results in all_ablation_results.items():
    baseline = results["full"]["target_log_prob"]

    # (variant name, change in target log-prob relative to the full prompt)
    impacts = []
    for k, v in results.items():
        if k != "full":
            impacts.append((k, v["target_log_prob"] - baseline))

    most_harmful = min(impacts, key=lambda pair: pair[1])
    most_helpful = max(impacts, key=lambda pair: pair[1])
    harmful_name, harmful_delta = most_harmful
    helpful_name, helpful_delta = most_helpful

    print(f"\n{prompt_name}:")
    print(f"  Most harmful ablation: {harmful_name} ({harmful_delta:.4f})")
    print(f"  Most helpful ablation: {helpful_name} ({helpful_delta:.4f})")

In [None]:
# Save results
import json
import os

# Keep only the scalar metrics and coerce them to built-in floats. The json
# encoder rejects non-builtin numeric types; the model helpers are not visible
# here and may return NumPy or torch scalars, so float() makes the values
# serializable either way.
save_data = {}
for prompt_name, results in all_ablation_results.items():
    save_data[prompt_name] = {
        k: {
            "target_log_prob": float(v["target_log_prob"]),
            "entropy": float(v["entropy"]),
        }
        for k, v in results.items()
    }

# Do not rely on an earlier cell having created the output directory.
os.makedirs('../results', exist_ok=True)

with open('../results/exp2_ablation_results.json', 'w', encoding='utf-8') as f:
    json.dump(save_data, f, indent=2)

print("Results saved.")