# Experiment 9: Prompt-Task Interaction Matrix

**Goal:** Understand whether prompt strategies are task-specific or general.

**Key Questions:**
- Do some strategies work universally across tasks?
- Are there task-specific strategies?
- Can certain strategies hurt performance on some tasks?

In [None]:
# Extend the module search path with the parent directory before importing
# from `src` (presumably the notebook sits one level below the project
# root -- confirm against the repository layout).
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Project-local helpers: model loading and plot styling.
from src.model_utils import load_model
from src.visualization import set_style

# Called once, before any figure is created (presumably configures the
# shared matplotlib/seaborn look -- see src.visualization).
set_style()

In [None]:
# Model under evaluation; the matrix-building cell passes this object to
# evaluate_strategy_on_task for every (strategy, task) pair.
model = load_model("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

## 1. Define Task Types and Strategies

In [None]:
# (question, reference answer) pairs grouped by task type. Insertion order
# is the order in which tasks are iterated by the cells that loop over TASKS.
_QA_PAIRS = {
    "arithmetic": [
        ("What is 23 + 45?", "68"),
        ("What is 156 - 89?", "67"),
        ("What is 12 * 7?", "84"),
    ],
    "factual": [
        ("What is the capital of France?", "Paris"),
        ("What planet is known as the Red Planet?", "Mars"),
        ("What is the chemical symbol for water?", "H2O"),
    ],
    "reasoning": [
        ("If all cats are animals, and all animals need food, do cats need food?", "yes"),
        ("John is taller than Mary. Mary is taller than Sue. Is John taller than Sue?", "yes"),
        ("If it's raining, the ground is wet. The ground is wet. Is it definitely raining?", "no"),
    ],
    "classification": [
        ("Is 'happy' a positive or negative word?", "positive"),
        ("Is 'terrible' a positive or negative word?", "negative"),
        ("Is 'excellent' a positive or negative word?", "positive"),
    ],
    "extraction": [
        ("Extract the name: 'John Smith is a doctor.'", "John Smith"),
        ("Extract the number: 'The price is $45.'", "45"),
        ("Extract the city: 'She lives in Tokyo.'", "Tokyo"),
    ],
    "completion": [
        ("Complete: The early bird catches the", "worm"),
        ("Complete: An apple a day keeps the doctor", "away"),
        ("Complete: Rome wasn't built in a", "day"),
    ],
}

# Task name -> list of {"q": question, "a": reference answer} records.
TASKS = {
    task_name: [{"q": question, "a": answer} for question, answer in pairs]
    for task_name, pairs in _QA_PAIRS.items()
}

# Strategy name -> prompt template; each template is filled in with
# str.format(question=...).
STRATEGIES = dict(
    plain="{question}",
    cot="{question}\n\nLet's think step by step.",
    expert="You are an expert. {question}",
    careful="Be careful and accurate. {question}",
    structured="Question: {question}\nAnswer:",
    teacher="You are a helpful teacher. {question}",
    concise="{question} Answer concisely:",
    confident="{question} I'm sure the answer is",
)

## 2. Build the Interaction Matrix

In [None]:
def evaluate_strategy_on_task(model, strategy_template, task_questions):
    """Score one prompt strategy on one task.

    Each question is substituted into ``strategy_template`` and the model is
    asked for the log-probability of the reference answer, prefixed with a
    single space, as the continuation of that prompt.

    Args:
        model: Object exposing ``get_sequence_log_probs(prompt, continuation)``
            whose result has a ``"total_log_prob"`` entry.
        strategy_template: Format string containing a ``{question}`` field.
        task_questions: Iterable of dicts with ``"q"`` and ``"a"`` keys.

    Returns:
        Dict with the ``"mean"`` and ``"std"`` (NumPy population std) of the
        per-question total log-probabilities.
    """
    totals = [
        model.get_sequence_log_probs(
            strategy_template.format(question=entry["q"]),
            " " + entry["a"],
        )["total_log_prob"]
        for entry in task_questions
    ]
    return {"mean": np.mean(totals), "std": np.std(totals)}

In [None]:
# One row per (strategy, task) pair, filled in by the sweep below.
results_matrix = []

# The progress bar is used as a context manager so it is closed even when a
# model call raises part-way through the sweep (a manual close() after the
# loops would be skipped in that case).
with tqdm(total=len(STRATEGIES) * len(TASKS), desc="Building matrix") as pbar:
    for strategy_name, template in STRATEGIES.items():
        for task_name, questions in TASKS.items():
            result = evaluate_strategy_on_task(model, template, questions)
            results_matrix.append({
                "strategy": strategy_name,
                "task": task_name,
                "mean_log_prob": result["mean"],
                "std_log_prob": result["std"]
            })
            pbar.update(1)

# Long format -> wide format: rows are strategies, columns are tasks, and
# each cell holds the mean log-probability of the reference answers.
matrix_df = pd.DataFrame(results_matrix)
pivot_matrix = matrix_df.pivot(index='strategy', columns='task', values='mean_log_prob')
print(pivot_matrix.round(3))

## 3. Visualize the Matrix

In [None]:
import os

# Figures are saved under ../results; create the directory if it is missing.
os.makedirs('../results', exist_ok=True)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))
raw_ax, scaled_ax = axes

# Left panel: the raw mean log-probabilities.
sns.heatmap(
    pivot_matrix,
    annot=True,
    fmt='.2f',
    cmap='RdYlGn',
    ax=raw_ax,
    cbar_kws={'label': 'Log Probability'},
)
raw_ax.set_title('Absolute Performance')

# Right panel: min-max scale every task column so strategies can be compared
# within a task. The 1e-10 added to the range keeps the division defined
# when all strategies score the same on a task.
column_floor = pivot_matrix.min()
column_span = pivot_matrix.max() - column_floor + 1e-10
normalized = (pivot_matrix - column_floor) / column_span
sns.heatmap(
    normalized,
    annot=True,
    fmt='.2f',
    cmap='RdYlGn',
    ax=scaled_ax,
    cbar_kws={'label': 'Normalized (0-1)'},
)
scaled_ax.set_title('Normalized per Task (0=Worst, 1=Best)')

plt.tight_layout()
plt.savefig('../results/exp9_interaction_matrix.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Identify Universal vs Task-Specific Strategies

In [None]:
# Strategy analysis
# For every strategy, flag the tasks on which it attains the per-task best /
# worst mean log-probability. The comparison is made against the raw matrix:
# the min-max normalization divides by (range + 1e-10), so the best entry of
# a column ends up marginally below 1.0 and an exact `normalized == 1.0`
# test would never match, leaving best_count at zero for every strategy.
is_task_best = pivot_matrix.eq(pivot_matrix.max(axis=0), axis="columns")
is_task_worst = pivot_matrix.eq(pivot_matrix.min(axis=0), axis="columns")

# One row per strategy, ranked by its average score over all tasks.
strategy_stats = pd.DataFrame({
    "mean_across_tasks": pivot_matrix.mean(axis=1),
    "std_across_tasks": pivot_matrix.std(axis=1),
    "best_count": is_task_best.sum(axis=1),
    "worst_count": is_task_worst.sum(axis=1)
}).sort_values("mean_across_tasks", ascending=False)

print("=== Strategy Analysis ===")
print("\nRanked by average performance:")
print(strategy_stats.round(3))

In [None]:
# Split strategies at the median of each statistic: "universal" means an
# above-median mean with a below-median spread across tasks; "specialized"
# means an above-median spread.
print("\n=== Universal vs Specialized ===")

mean_threshold = strategy_stats["mean_across_tasks"].median()
std_threshold = strategy_stats["std_across_tasks"].median()

high_mean = strategy_stats["mean_across_tasks"] > mean_threshold
low_spread = strategy_stats["std_across_tasks"] < std_threshold
high_spread = strategy_stats["std_across_tasks"] > std_threshold

print("\nUniversal (good everywhere, low variance):")
universal = strategy_stats[high_mean & low_spread]
for name in universal.index:
    print(f"  {name}")

print("\nSpecialized (high variance):")
specialized = strategy_stats[high_spread]
for name in specialized.index:
    # Row of the matrix for this strategy: one score per task.
    per_task = pivot_matrix.loc[name]
    print(f"  {name}: best on '{per_task.idxmax()}', worst on '{per_task.idxmin()}'")

In [None]:
# For each task (in TASKS order), report the strategy with the highest
# mean log-probability and that score.
print("\n=== Best Strategy per Task ===")
for task in TASKS:
    task_scores = pivot_matrix[task]
    winner, top_score = task_scores.idxmax(), task_scores.max()
    print(f"  {task:15s}: {winner} (score={top_score:.3f})")

## 5. Interaction Effects

In [None]:
# Interaction check: does a strategy's effect depend on the task?
# Baseline is the additive model E[Y] = strategy_effect + task_effect,
# i.e. row mean + column mean - grand mean for every cell.

strategy_means = pivot_matrix.mean(axis=1)
task_means = pivot_matrix.mean(axis=0)
grand_mean = pivot_matrix.values.mean()

# Additive prediction, laid out in STRATEGIES x TASKS declaration order.
strategy_order = list(STRATEGIES)
task_order = list(TASKS)
row_part = strategy_means.loc[strategy_order].to_numpy()
col_part = task_means.loc[task_order].to_numpy()
additive_pred = pd.DataFrame(
    np.add.outer(row_part, col_part) - grand_mean,
    index=strategy_order,
    columns=task_order,
)

# Residuals (interaction effects); the subtraction aligns on row and column
# labels, so the differing row/column order of the two frames does not matter.
residuals = pivot_matrix - additive_pred

print("=== Interaction Effects (Residuals from Additive Model) ===")
print("Positive = strategy works better than expected for this task")
print("Negative = strategy works worse than expected for this task\n")
print(residuals.round(3))

In [None]:
# Visualize interactions
fig, ax = plt.subplots(figsize=(12, 6))

# Diverging colormap centred on zero, so the sign of each residual is
# visible from its colour.
sns.heatmap(residuals, annot=True, fmt='.2f', cmap='RdBu_r', center=0, ax=ax,
            cbar_kws={'label': 'Interaction Effect'})
# The title uses the multiplication sign, written as an escape (U+00D7) so it
# survives encoding round-trips; the previous literal was a mis-decoded
# two-character sequence.
ax.set_title('Strategy \u00d7 Task Interactions\n(Deviation from Additive Model)')

plt.tight_layout()
plt.savefig('../results/exp9_interactions.png', dpi=150, bbox_inches='tight')
plt.show()

## 6. Key Findings

In [None]:
# Text summary of the experiment: top strategies, per-task winners, and the
# largest positive / negative interaction residuals.
banner = "=" * 60
print(banner)
print("EXPERIMENT 9 SUMMARY: Prompt-Task Interaction Matrix")
print(banner)

print("\n1. Best Overall Strategies:")
# strategy_stats is already sorted by mean, so the first three rows are the top three.
for name, avg in strategy_stats["mean_across_tasks"].head(3).items():
    print(f"   {name}: mean={avg:.3f}")

print("\n2. Task-Specific Findings:")
for task in TASKS:
    print(f"   {task}: best with '{pivot_matrix[task].idxmax()}'")

print("\n3. Strongest Interactions:")
# DataFrame.unstack() returns a Series keyed by (column, row) labels, i.e.
# (task, strategy) here; sorting ascending puts the worst mismatches first.
flat_residuals = residuals.unstack().sort_values()
extremes = (
    ("   Worst mismatches:", flat_residuals.head(3)),
    ("   Best synergies:", flat_residuals.tail(3)),
)
for heading, entries in extremes:
    print(heading)
    for (task, strategy), val in entries.items():
        print(f"     {strategy} on {task}: {val:+.3f}")

In [None]:
# Persist the three result tables as nested dicts in a single JSON file.
import json

save_data = dict(
    performance_matrix=pivot_matrix.to_dict(),
    strategy_stats=strategy_stats.to_dict(),
    interactions=residuals.to_dict(),
)

# default=float converts any value the json encoder does not handle natively
# (e.g. NumPy scalars) into a plain float.
with open('../results/exp9_matrix_results.json', 'w') as f:
    json.dump(save_data, f, indent=2, default=float)

print("Results saved.")