# Experiment 1: Distribution Shift Analysis

**Goal:** Quantify how prompt variations change the output probability distribution.

**Key Questions:**
- How much does the next-token distribution change with different prompts?
- Which prompt modifications cause the largest distribution shifts?
- Is there a correlation between distribution change and task performance?

In [None]:
import sys
sys.path.append('..')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

from src.model_utils import load_model, ModelConfig
from src.prompt_utils import PromptVariantGenerator, INSTRUCTION_SPECIFICITY, FORMATTING_STYLES, PERSONAS, THINKING_STYLES
from src.metrics import DistributionMetrics, ExperimentResults, compute_all_metrics
from src.visualization import set_style, plot_distribution_comparison, plot_entropy_comparison, plot_dimension_heatmap

# Apply the shared plotting style from src.visualization before any figure is
# drawn (presumably matplotlib settings — confirm in src/visualization.py).
set_style()

## 1. Load Model

We'll use TinyLlama-1.1B-Chat-v1.0 (the chat-tuned TinyLlama checkpoint) as our model. It's small enough for single-GPU experiments while being capable enough to show meaningful prompt effects.

In [None]:
# Load model
# Choose the device at runtime so the notebook also runs on machines without
# a GPU, rather than relying on a hard-coded "cuda" that must be edited by hand.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = load_model(
    model_name="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    device=device
)
print(f"Model loaded on {model.config.device}")

## 2. Define Test Questions

We'll test across different question types to see if prompt effects are consistent.

In [None]:
# One representative question per task type; the keys name the task type.
TEST_QUESTIONS = dict(
    factual="What is the capital of France?",
    reasoning="If a train travels 60 miles in 1 hour, how far will it travel in 2.5 hours?",
    classification="Is the following sentence positive or negative: 'I absolutely loved this movie!'",
    open_ended="What are the main benefits of renewable energy?",
    coding="Write a Python function to calculate the factorial of a number.",
)

# Completion whose probability is measured for each task type (same keys as
# TEST_QUESTIONS). None means there is no single correct answer to score.
EXPECTED_COMPLETIONS = dict(
    factual="Paris",
    reasoning="150",
    classification="positive",
    open_ended=None,  # open-ended question: nothing to score against
    coding="def",  # a function definition is expected to start with "def"
)

## 3. Generate Prompt Variants

We'll vary prompts along multiple dimensions to understand what causes distribution shifts.

In [None]:
# Start with a single question before scaling up to all of them
question = TEST_QUESTIONS["factual"]

# Request variants along the specificity, format and persona dimensions
dimension_names = ['specificity', 'format', 'persona']
variants = PromptVariantGenerator.create_variants(
    question=question,
    dimensions=dimension_names,
)

print(f"Generated {len(variants)} prompt variants")
print("\nExample variants:")
# Preview the first three variants; prompts are cut to 200 characters
for example in variants[:3]:
    print(f"\nConfig: {example['config']}")
    print(f"Prompt: {example['prompt'][:200]}...")

## 4. Measure Distribution Shifts

For each prompt variant, we'll:
1. Get the next-token probability distribution
2. Calculate entropy
3. Compare to a baseline (the raw question with no added prompt text)

In [None]:
def analyze_prompt_variants(model, question, variants, expected_completion=None):
    """Compare each prompt variant's next-token distribution to the raw question's.

    The baseline is ``model.get_next_token_distribution(question)``. For every
    variant the same call is made on the variant prompt, and the two
    distributions are passed to ``compute_all_metrics``.

    Args:
        model: Object providing ``get_next_token_distribution(prompt)`` and
            ``get_sequence_log_probs(prompt, completion)``.
        question: The unmodified question used as the baseline prompt.
        variants: Iterable of dicts, each with a ``'prompt'`` and a
            ``'config'`` entry.
        expected_completion: Optional target text. When truthy, the log
            probability of ``" " + expected_completion`` after each variant
            prompt is stored as ``target_log_prob`` and
            ``target_avg_log_prob``.

    Returns:
        A ``(results, baseline_dist)`` tuple: the populated
        ``ExperimentResults`` and the baseline distribution dict.
    """
    collected = ExperimentResults()

    # Reference distribution: the question with no prompt engineering at all
    reference = model.get_next_token_distribution(question)
    reference_probs = reference['full_probs']

    for entry in tqdm(variants, desc="Analyzing variants"):
        text = entry['prompt']
        settings = entry['config']

        current = model.get_next_token_distribution(text)

        # Divergence/entropy metrics of this variant relative to the baseline
        comparison = compute_all_metrics(
            reference_probs,
            current['full_probs'],
            reference['top_tokens'],
            current['top_tokens'],
        )

        # Score the target completion only when one was supplied
        if expected_completion:
            target = model.get_sequence_log_probs(text, " " + expected_completion)
            comparison['target_log_prob'] = target['total_log_prob']
            comparison['target_avg_log_prob'] = target['avg_log_prob']

        # Metric keys are merged last, so they win on any name clash
        record = {
            'prompt': text,
            'config': settings,
            'top_5_tokens': current['top_tokens'][:5],
        }
        record.update(comparison)
        collected.add_result(record)

    return collected, reference

In [None]:
# Analyze the factual question against the variants generated for it
results, baseline = analyze_prompt_variants(
    model=model,
    question=TEST_QUESTIONS["factual"],
    variants=variants,
    expected_completion=EXPECTED_COMPLETIONS["factual"],
)

analyzed_count = len(results.results)
print(f"\nAnalyzed {analyzed_count} variants")

## 5. Analyze Results

In [None]:
# Tabular view of the collected results
df = results.to_dataframe()

# Print mean / std / range for each headline metric that was recorded
print("=== Key Metrics Summary ===")
headline_metrics = ('kl_divergence', 'jensen_shannon', 'variant_entropy', 'target_log_prob')
for name in headline_metrics:
    # A metric may be missing (e.g. no target completion was scored)
    if name not in df.columns:
        continue
    summary = results.summary_statistics(name)
    print(f"\n{name}:")
    print(f"  Mean: {summary['mean']:.4f}, Std: {summary['std']:.4f}")
    print(f"  Range: [{summary['min']:.4f}, {summary['max']:.4f}]")

In [None]:
# Rank the values of each prompt dimension by their mean KL divergence
print("\n=== Distribution Shift by Dimension ===")

dimension_columns = ('config_specificity', 'config_format', 'config_persona')
for column in dimension_columns:
    # Skip dimensions that were not varied in this run
    if column not in df.columns:
        continue
    mean_kl = df.groupby(column)['kl_divergence'].mean()
    ranked = mean_kl.sort_values(ascending=False)
    print(f"\n{column}:")
    for level, kl_value in ranked.items():
        print(f"  {level}: KL={kl_value:.4f}")

In [None]:
# List the ten best and ten worst configurations by target log-probability
if 'target_log_prob' in df.columns:
    top_10 = df.nlargest(10, 'target_log_prob')
    bottom_10 = df.nsmallest(10, 'target_log_prob')

    sections = (
        ("\n=== Top 10 Prompts for Target Completion ===", top_10),
        ("\n=== Bottom 10 Prompts for Target Completion ===", bottom_10),
    )
    for heading, frame in sections:
        print(heading)
        for _, row in frame.iterrows():
            # Collect the config_* columns, dropping the 'config_' marker
            config = {}
            for column in row.index:
                if column.startswith('config_'):
                    config[column.replace('config_', '')] = row[column]
            print(f"Log-prob: {row['target_log_prob']:.3f} | Config: {config}")

## 6. Visualizations

In [None]:
# Compare distributions for best vs worst prompts
import os

if 'target_log_prob' in df.columns:
    # savefig does not create missing directories, so make sure the output
    # folder exists before writing the figure.
    os.makedirs('../results', exist_ok=True)

    # idxmax/idxmin return index labels; they are used as list positions
    # below, which assumes df keeps the default 0..n-1 index from
    # to_dataframe() — TODO confirm.
    best_idx = df['target_log_prob'].idxmax()
    worst_idx = df['target_log_prob'].idxmin()

    best_result = results.results[best_idx]
    worst_result = results.results[worst_idx]

    # Re-query the model for the two prompts being plotted
    best_dist = model.get_next_token_distribution(best_result['prompt'])
    worst_dist = model.get_next_token_distribution(worst_result['prompt'])

    fig = plot_distribution_comparison(
        [baseline, best_dist, worst_dist],
        ['Baseline (raw question)', 'Best Prompt', 'Worst Prompt'],
        title=f'Distribution Comparison for: "{TEST_QUESTIONS["factual"]}"'
    )
    plt.savefig('../results/exp1_distribution_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Heatmap: Specificity × Format
import os

if 'config_specificity' in df.columns and 'config_format' in df.columns:
    # savefig does not create missing directories, so make sure the output
    # folder exists before writing the figure.
    os.makedirs('../results', exist_ok=True)

    fig = plot_dimension_heatmap(
        df, 'config_specificity', 'config_format', 'kl_divergence',
        title='KL Divergence from Baseline: Specificity × Format'
    )
    plt.savefig('../results/exp1_heatmap_specificity_format.png', dpi=150, bbox_inches='tight')
    plt.show()

In [None]:
# Entropy by configuration
import os

# savefig does not create missing directories, so make sure the output
# folder exists before writing the figure.
os.makedirs('../results', exist_ok=True)

fig = plot_entropy_comparison(
    results.results[:20],  # First 20 variants (in analysis order) for readability
    title='Next-Token Entropy Across Prompt Variants'
)
plt.savefig('../results/exp1_entropy_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Cross-Question Analysis

Let's see if the same prompt variations have consistent effects across different questions.

In [None]:
# Analyze across all question types
all_results = {}

# Loop-local names (q_text, q_variants) are used on purpose: the notebook-level
# `question` and `variants` belong to the factual-question cells, and
# rebinding them here would make those cells analyze mismatched data when
# re-run after this one.
for q_type, q_text in TEST_QUESTIONS.items():
    print(f"\nAnalyzing: {q_type}")

    q_variants = PromptVariantGenerator.create_variants(
        question=q_text,
        dimensions=['specificity', 'format']
    )

    # .get() yields None for question types without an expected completion,
    # which disables target scoring for them.
    results_q, _ = analyze_prompt_variants(
        model, q_text, q_variants,
        expected_completion=EXPECTED_COMPLETIONS.get(q_type)
    )

    all_results[q_type] = results_q

In [None]:
# For each question type, report the variant with the lowest next-token entropy
print("\n=== Best Configuration by Question Type ===")

for q_type, results_q in all_results.items():
    df_q = results_q.to_dataframe()

    # Lowest entropy is treated here as the most confident prediction
    lowest_idx = df_q['variant_entropy'].idxmin()
    best_entropy = df_q.loc[lowest_idx]

    # Collect the config_* fields, dropping the 'config_' marker
    config = {}
    for column in best_entropy.index:
        if column.startswith('config_'):
            config[column.replace('config_', '')] = best_entropy[column]

    print(f"\n{q_type} - Lowest Entropy: {best_entropy['variant_entropy']:.3f}")
    print(f"  Config: {config}")

## 8. Key Findings

Summarize observations from this experiment.

In [None]:
# Print the experiment-wide summary
banner = "=" * 60
print(banner)
print("EXPERIMENT 1 SUMMARY: Distribution Shift Analysis")
print(banner)

# Pool KL divergences and entropy deltas (variant minus baseline) over all
# question types
all_kl = []
all_entropy_changes = []

for results_q in all_results.values():
    df_q = results_q.to_dataframe()
    entropy_delta = df_q['variant_entropy'] - df_q['baseline_entropy']
    all_kl += df_q['kl_divergence'].tolist()
    all_entropy_changes += entropy_delta.tolist()

increased = sum(change > 0 for change in all_entropy_changes)
decreased = sum(change < 0 for change in all_entropy_changes)

print("\n1. Distribution Shift Magnitude:")
print(f"   - Average KL divergence from baseline: {np.mean(all_kl):.4f}")
print(f"   - Max KL divergence observed: {np.max(all_kl):.4f}")

print("\n2. Entropy Changes:")
print(f"   - Average entropy change: {np.mean(all_entropy_changes):.4f}")
print(f"   - Prompts that increase entropy: {increased}")
print(f"   - Prompts that decrease entropy: {decreased}")

# Placeholders to be replaced by hand once the experiment has been run
print("\n3. Key Observations:")
print("   - [Fill in after running experiments]")
print("   - [Which dimensions matter most?]")
print("   - [Are effects consistent across question types?]")

In [None]:
# Save all results
import json
import os

# Ensure the output directory exists, then save each question type's results
# under its own file name
os.makedirs('../results', exist_ok=True)

for question_type in all_results:
    all_results[question_type].save(f'../results/exp1_results_{question_type}.json')

print("Results saved to ../results/")