# Experiment 7: Learning Dynamics - Training Distribution Hypothesis

**Goal:** Determine whether prompt effectiveness relates to patterns in the model's training distribution.

**Hypothesis:** Effective prompts "activate" patterns the model learned from its training data (StackOverflow, textbooks, Wikipedia, etc.).

**Key Questions:**
- Do prompts mimicking web text formats perform better?
- Do "unnatural" prompt formats fail predictably?
- Can we identify which training sources a prompt is activating?

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from src.model_utils import load_model
from src.metrics import ExperimentResults, SequenceMetrics
from src.visualization import set_style

set_style()

In [None]:
# Load the TinyLlama chat checkpoint used for every measurement below.
model = load_model(
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)

## 1. Define Format Templates from Known Sources

We create prompts that mimic text formats the model likely saw during training.

In [None]:
# Templates mimicking known web/training data formats.
# Each "template" is filled with str.format, so literal braces must be doubled.
FORMAT_TEMPLATES = {
    # StackOverflow-style
    "stackoverflow": {
        "template": """**Question:**
{question}

**Answer:**""",
        "description": "StackOverflow Q&A format"
    },
    
    # Wikipedia-style
    "wikipedia": {
        "template": """{question}

According to available sources,""",
        "description": "Wikipedia expository style"
    },
    
    # Textbook-style
    "textbook": {
        "template": """**Example Problem:**
{question}

**Solution:**
To solve this problem,""",
        "description": "Textbook problem/solution format"
    },
    
    # Reddit-style
    "reddit": {
        "template": """{question}

Edit: Figured it out!""",
        "description": "Reddit discussion style"
    },
    
    # Documentation-style
    "documentation": {
        "template": """## Overview

{question}

## Answer

The answer is""",
        "description": "Technical documentation format"
    },
    
    # Chat/conversation style
    "chat": {
        "template": """User: {question}
Assistant:""",
        "description": "Chat conversation format"
    },
    
    # Academic paper style
    "academic": {
        "template": """Abstract: This paper addresses the question: {question}

Results: Our analysis shows that""",
        "description": "Academic paper style"
    },
    
    # JSON-style: the prompt stops right after the opening quote of the
    # "answer" value, so the continuation being scored is the answer string.
    # Written as a single-quoted literal with explicit \n escapes: a
    # triple-quoted string cannot end in a bare `"` without closing early.
    "json": {
        "template": '{{\n  "question": "{question}",\n  "answer": "',
        "description": "JSON data format"
    },
    
    # Plain/minimal
    "plain": {
        "template": "{question}",
        "description": "No formatting, just the question"
    },
    
    # Unnatural/adversarial formats (placeholders filled by create_format_variants)
    "reversed": {
        "template": "{question_reversed}",
        "description": "Reversed text (unnatural)"
    },
    
    "leetspeak": {
        "template": "{question_leet}",
        "description": "Leetspeak substitution (semi-natural)"
    },
    
    "uppercase": {
        "template": "{question_upper}",
        "description": "All uppercase (semi-natural)"
    }
}

# Test questions: (question, expected answer) pairs with short, unambiguous answers.
TEST_QUESTIONS = [
    dict(q=question, a=answer)
    for question, answer in (
        ("What is the sum of 45 and 67?", "112"),
        ("What is the capital of Germany?", "Berlin"),
        ("What color do you get by mixing red and blue?", "purple"),
    )
]

## 2. Test Format Effectiveness

In [None]:
def create_format_variants(question, templates=None):
    """Render every format template for a single question.

    Every template is filled through ``str.format`` with the same four
    placeholders, so the "unnatural" templates (``{question_reversed}``,
    ``{question_leet}``, ``{question_upper}``) are driven by their template
    text exactly like the web-style ones.

    Args:
        question: Raw question text.
        templates: Mapping of format name -> dict with a "template" string.
            Defaults to the module-level ``FORMAT_TEMPLATES``.

    Returns:
        dict mapping format name -> rendered prompt, in template order.
    """
    if templates is None:
        templates = FORMAT_TEMPLATES

    # Case-insensitive character substitution; unmapped characters pass through.
    leet_map = {'a': '4', 'e': '3', 'i': '1', 'o': '0', 's': '5', 't': '7'}
    placeholders = {
        "question": question,
        "question_reversed": question[::-1],
        "question_leet": ''.join(leet_map.get(c.lower(), c) for c in question),
        "question_upper": question.upper(),
    }

    # str.format ignores keyword arguments a template does not reference.
    return {name: fmt["template"].format(**placeholders) for name, fmt in templates.items()}


def test_formats(model, question, expected):
    """Score every format variant of ``question`` against the gold answer.

    For each rendered prompt this records the model's log-probability of
    ``" " + expected`` as the continuation, plus the entropy and top tokens
    of the next-token distribution right after the prompt.

    Returns:
        list of dicts, one per format, with keys: format, description,
        prompt, target_log_prob, entropy, top_5, prompt_length (characters).
    """
    # Leading space presumably so the answer is scored as a separate word.
    target = " " + expected

    def _score(name, prompt):
        next_tok = model.get_next_token_distribution(prompt)
        seq = model.get_sequence_log_probs(prompt, target)
        return {
            "format": name,
            "description": FORMAT_TEMPLATES.get(name, {}).get("description", name),
            "prompt": prompt,
            "target_log_prob": seq["total_log_prob"],
            "entropy": next_tok["entropy"],
            "top_5": next_tok["top_tokens"][:5],
            "prompt_length": len(prompt),
        }

    return [_score(name, prompt) for name, prompt in create_format_variants(question).items()]

In [None]:
# Run format testing: score every format for every test question, tagging
# each result row with the question and its expected answer.
all_format_results = []

for item in tqdm(TEST_QUESTIONS, desc="Testing questions"):
    all_format_results.extend(
        {**row, "question": item["q"], "expected": item["a"]}
        for row in test_formats(model, item["q"], item["a"])
    )

df = pd.DataFrame(all_format_results)

In [None]:
# Analyze format effectiveness: mean/std of the target log-prob per format,
# best (highest mean) first.
print("=== Format Effectiveness Ranking ===")

format_perf = (
    df.groupby('format')['target_log_prob']
    .agg(['mean', 'std'])
    .sort_values('mean', ascending=False)
)

print("\nRanked by target log-probability:")
for fmt_name, mean_lp, std_lp in zip(format_perf.index, format_perf['mean'], format_perf['std']):
    label = FORMAT_TEMPLATES.get(fmt_name, {}).get("description", fmt_name)
    print(f"  {fmt_name:15s} ({label:30s}): {mean_lp:.4f} ± {std_lp:.4f}")

In [None]:
# Visualize format effectiveness as a horizontal bar chart with ±1 std error bars.
import os
os.makedirs('../results', exist_ok=True)

fig, ax = plt.subplots(figsize=(12, 8))

formats = format_perf.index.tolist()
means = format_perf['mean'].values
stds = format_perf['std'].values

# Color by category: all web-text templates are green, the no-formatting
# baseline is orange, and the text-transformed variants
# (reversed / leetspeak / uppercase) are red.
natural_formats = {'stackoverflow', 'wikipedia', 'textbook', 'reddit',
                   'documentation', 'chat', 'academic', 'json'}
colors = ['green' if f in natural_formats else 'orange' if f == 'plain' else 'red' for f in formats]

ax.barh(range(len(formats)), means, xerr=stds, color=colors, alpha=0.7, capsize=3)
ax.set_yticks(range(len(formats)))
ax.set_yticklabels(formats)
# format_perf is sorted best-first; invert so the best format sits on top,
# matching the printed ranking.
ax.invert_yaxis()
ax.set_xlabel('Target Log Probability')
ax.set_title('Format Effectiveness: Training Distribution Hypothesis\n'
             '(Green=Natural web formats, Orange=Plain, Red=Unnatural)')

plt.tight_layout()
plt.savefig('../results/exp7_format_effectiveness.png', dpi=150, bbox_inches='tight')
plt.show()

## 3. Naturalness Score Analysis

In [None]:
def compute_prompt_perplexity(model, prompt):
    """
    Compute the perplexity of the prompt itself under the model.

    Lower perplexity = the text is more "natural" (more predictable) to the model.

    Args:
        model: Project wrapper exposing ``tokenizer``, ``model`` (presumably the
            underlying HF causal LM) and ``config.device``.
        prompt: Text to score.

    Returns:
        float: exp(mean next-token cross-entropy) over the prompt's tokens, or
        ``float('inf')`` when the prompt tokenizes to fewer than 2 tokens
        (nothing to predict).
    """
    import torch
    import torch.nn.functional as F

    # Tokenize once; reuse the ids for both the length check and the forward pass.
    inputs = model.tokenizer(prompt, return_tensors="pt").to(model.config.device)
    input_ids = inputs.input_ids
    if input_ids.shape[-1] < 2:
        return float('inf')

    with torch.no_grad():
        logits = model.model(**inputs).logits

    # Position t predicts token t+1.
    shift_logits = logits[..., :-1, :]
    shift_labels = input_ids[..., 1:]

    # Upcast to float32 so log-softmax stays numerically stable even when the
    # model runs in half precision.
    loss = F.cross_entropy(
        shift_logits.reshape(-1, shift_logits.size(-1)).float(),
        shift_labels.reshape(-1),
    )
    return torch.exp(loss).item()

In [None]:
# Compute perplexity for each format, averaged over ALL test questions so it is
# measured on the same questions as the effectiveness scores it is correlated with.
per_format_ppls = {}  # format name -> one perplexity per test question

print("=== Format Perplexity (Lower = More Natural) ===")
for test in tqdm(TEST_QUESTIONS, desc="Computing perplexities"):
    for name, prompt in create_format_variants(test["q"]).items():
        per_format_ppls.setdefault(name, []).append(compute_prompt_perplexity(model, prompt))

format_perplexities = {name: float(np.mean(ppls)) for name, ppls in per_format_ppls.items()}

# Sort by perplexity
sorted_ppl = sorted(format_perplexities.items(), key=lambda item: item[1])

print("\nRanked by naturalness (lower perplexity = more natural):")
for name, ppl in sorted_ppl:
    print(f"  {name:15s}: mean perplexity = {ppl:.2f}")

In [None]:
# Correlate perplexity with effectiveness (Pearson r across formats).
effectiveness = dict(zip(format_perf.index, format_perf['mean']))

# Match formats present in both measurements. Skip non-finite perplexities
# (inf for prompts too short to score), which would make r undefined.
matched_data = [
    {"format": fmt, "perplexity": ppl, "effectiveness": effectiveness[fmt]}
    for fmt, ppl in format_perplexities.items()
    if fmt in effectiveness and np.isfinite(ppl)
]

# Explicit columns keep the frame well-formed even if nothing matched.
matched_df = pd.DataFrame(matched_data, columns=["format", "perplexity", "effectiveness"])

# Calculate correlation (NaN when fewer than two formats matched).
corr = matched_df["perplexity"].corr(matched_df["effectiveness"])
print(f"\nCorrelation between perplexity and effectiveness: {corr:.4f}")

# Use the same ±0.3 cutoffs as the summary cell, so a tiny negative r is not
# reported as support for the hypothesis.
if corr < -0.3:
    interpretation = 'Lower perplexity (more natural) → better performance'
elif corr > 0.3:
    interpretation = 'Higher perplexity (less natural) → better performance'
else:
    interpretation = 'Perplexity not strongly predictive'
print(f"Interpretation: {interpretation}")

In [None]:
# Scatter plot: one labelled point per format, perplexity vs. effectiveness.
fig, ax = plt.subplots(figsize=(10, 6))

x_vals = matched_df["perplexity"]
y_vals = matched_df["effectiveness"]
ax.scatter(x_vals, y_vals, s=100, alpha=0.7)

for label, x, y in zip(matched_df["format"], x_vals, y_vals):
    ax.annotate(label, (x, y), textcoords="offset points", xytext=(5, 5), fontsize=8)

ax.set_xlabel("Prompt Perplexity (Lower = More Natural)")
ax.set_ylabel("Target Log Probability (Higher = Better)")
ax.set_title(f"Naturalness vs Effectiveness (r={corr:.3f})")

plt.tight_layout()
plt.savefig('../results/exp7_naturalness_vs_effectiveness.png', dpi=150, bbox_inches='tight')
plt.show()

## 4. Source-Specific Triggers

In [None]:
# Source-specific trigger phrases: each is prepended (followed by a newline)
# to the question to test whether it shifts the answer probability.
SOURCE_TRIGGERS = dict(
    stackoverflow=["**Question:**", "**Answer:**", "Edit:", "Update:", "SOLVED:"],
    wikipedia=["According to", "is defined as", "refers to", "[citation needed]"],
    academic=["Abstract:", "Introduction:", "Methods:", "Results:", "et al."],
    documentation=["## Overview", "### Parameters", "Returns:", "Example:", "Usage:"],
    code=["```python", "def ", "return ", "# ", "import "],
)

def test_trigger_phrases(model, question, expected, triggers):
    """Test effectiveness of adding source-specific triggers."""
    baseline = model.get_sequence_log_probs(question, " " + expected)["total_log_prob"]
    
    results = [{"source": "baseline", "trigger": "none", "improvement": 0, "log_prob": baseline}]
    
    for source, trigger_list in triggers.items():
        for trigger in trigger_list:
            prompt = f"{trigger}\n{question}"
            log_prob = model.get_sequence_log_probs(prompt, " " + expected)["total_log_prob"]
            
            results.append({
                "source": source,
                "trigger": trigger,
                "log_prob": log_prob,
                "improvement": log_prob - baseline
            })
    
    return results

In [None]:
# Test triggers: score every trigger phrase on every test question, tagging
# each row with its question.
trigger_results = [
    {**row, "question": item["q"]}
    for item in TEST_QUESTIONS
    for row in test_trigger_phrases(model, item["q"], item["a"], SOURCE_TRIGGERS)
]

trigger_df = pd.DataFrame(trigger_results)

# Analyze by source: mean improvement over baseline, excluding the baseline rows.
print("=== Source-Specific Trigger Effectiveness ===")
non_baseline = trigger_df[trigger_df['source'] != 'baseline']
source_perf = non_baseline.groupby('source')['improvement'].mean().sort_values(ascending=False)

for src, gain in source_perf.items():
    print(f"  {src:15s}: avg improvement = {gain:+.4f}")

## 5. Key Findings

In [None]:
# Print a compact summary of all three analyses.
banner = "=" * 60
print(banner)
print("EXPERIMENT 7 SUMMARY: Training Distribution Hypothesis")
print(banner)

print("\n1. Format Effectiveness Ranking (top 5):")
top5 = format_perf.head(5)
for rank, (fmt_name, mean_lp) in enumerate(zip(top5.index, top5['mean']), start=1):
    print(f"   {rank}. {fmt_name}: {mean_lp:.4f}")

print("\n2. Naturalness-Effectiveness Correlation:")
print(f"   r = {corr:.4f}")
if corr < -0.3:
    verdict = "   → Strong support: More natural formats work better"
elif corr > 0.3:
    verdict = "   → Counter-intuitive: Less natural formats work better"
else:
    verdict = "   → Weak relationship: Naturalness alone doesn't predict effectiveness"
print(verdict)

print("\n3. Best Source-Specific Triggers:")
for src, gain in source_perf.head(3).items():
    print(f"   {src}: {gain:+.4f}")

print("\n4. Key Insight:")
print("   [Fill after running: Do training-like formats help?]")

In [None]:
# Save results as strict JSON.
import json
import math


def _json_safe(value):
    """Recursively convert numpy scalars to Python numbers and inf/NaN to None.

    json.dump's default (allow_nan=True) would write Infinity/NaN, which is not
    valid JSON; inf arises from unscorable prompts and NaN from an undefined r.
    """
    if isinstance(value, dict):
        return {str(k): _json_safe(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_json_safe(v) for v in value]
    if isinstance(value, (float, np.floating)):
        value = float(value)
        return value if math.isfinite(value) else None
    if isinstance(value, np.integer):
        return int(value)
    return value


save_data = _json_safe({
    "format_effectiveness": format_perf.to_dict(),
    "format_perplexities": format_perplexities,
    "naturalness_correlation": float(corr),
    "source_trigger_effectiveness": source_perf.to_dict()
})

with open('../results/exp7_training_distribution_results.json', 'w', encoding='utf-8') as f:
    # allow_nan=False guarantees the file is standards-compliant JSON.
    json.dump(save_data, f, indent=2, default=float, allow_nan=False)

print("Results saved.")