# Experiment: Distribution Shift Analysis

**Goal:** Measure how different system prompts change the output distribution for the SAME user queries.

**Metrics:** KL divergence, JS divergence, entropy, top-k overlap

In [None]:
import sys, os
# Environment bootstrap: put a directory on sys.path so the project imports
# in the following cells can resolve (presumably the repository root that
# holds the `src` package — verify against the repo layout).
if 'google.colab' in sys.modules:
    # Running under Colab: clone the repository only if it is not already there.
    if not os.path.exists('/content/LLM-Instruction-Understanding'):
        !git clone https://github.com/maralkh/LLM-Instruction-Understanding.git
    # Change directory first: the pip command below reads requirements.txt
    # relative to the current working directory.
    os.chdir('/content/LLM-Instruction-Understanding')
    !pip install -q -r requirements.txt
    sys.path.insert(0, '/content/LLM-Instruction-Understanding')
else:
    # Outside Colab: add the parent of the current working directory.
    sys.path.insert(0, os.path.abspath('..'))

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

from src.model_utils import load_model
from src.metrics import DistributionMetrics, compute_all_metrics
from src.test_configs import get_all_test_prompts, get_system_prompts, build_chat_prompt

# Apply the 'seaborn-v0_8-whitegrid' matplotlib style sheet to every figure
# drawn in this notebook.
# NOTE(review): the 'seaborn-v0_8-*' style names are believed to exist only
# from matplotlib 3.6 onward — confirm against the pinned matplotlib version.
plt.style.use('seaborn-v0_8-whitegrid')

In [None]:
# Load the model under study, then fetch the user queries and the system-prompt
# variants that will be crossed against each other.
model = load_model("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
test_prompts, system_prompts = get_all_test_prompts(), get_system_prompts()

# Report the size of the (query × system prompt) grid before running it.
print(f"Testing {len(test_prompts)} prompts × {len(system_prompts)} system prompts")

## 1. Collect Distributions

In [None]:
def collect_distributions(model, test_prompts, system_prompts, max_tests=20):
    """Gather next-token distributions for each (test prompt, system prompt) pair.

    Only the first ``max_tests`` entries of ``test_prompts`` are evaluated. For
    every pairing, the chat prompt is assembled with ``build_chat_prompt`` and
    handed to ``model.get_next_token_distribution`` with ``top_k=100``.

    Args:
        model: Object exposing ``tokenizer`` and ``get_next_token_distribution``.
        test_prompts: Sequence of dicts read via the keys 'id', 'category'
            and 'prompt'.
        system_prompts: Mapping of system-prompt name to a dict with a 'text' key.
        max_tests: Upper bound on how many test prompts are processed.

    Returns:
        A list with one dict per pairing, holding 'test_id', 'category',
        'system_prompt', 'entropy', 'top_token', 'top_prob', 'top_5_tokens'
        and 'full_probs'.
    """
    records = []

    for case in tqdm(test_prompts[:max_tests], desc="Collecting"):
        for name, spec in system_prompts.items():
            prompt_text = build_chat_prompt(spec['text'], case['prompt'], model.tokenizer)
            dist = model.get_next_token_distribution(prompt_text, top_k=100)

            # Each 'top_tokens' entry is a (token, probability) tuple; split
            # them into two parallel lists.
            tokens = [pair[0] for pair in dist['top_tokens']]
            probs = [pair[1] for pair in dist['top_tokens']]

            record = {
                'test_id': case['id'],
                'category': case['category'],
                'system_prompt': name,
                'entropy': dist['entropy'],
                # Fall back to neutral values when no top tokens were returned.
                'top_token': tokens[0] if tokens else '',
                'top_prob': probs[0] if probs else 0,
                'top_5_tokens': tokens[:5],
                'full_probs': dist['full_probs'],
            }
            records.append(record)

    return records

# Run the sweep with the default cap (max_tests=20 test prompts); each returned
# record also carries the distribution stored under 'full_probs'.
all_distributions = collect_distributions(model, test_prompts, system_prompts)

## 2. Calculate Divergences from Baseline

In [None]:
def calculate_divergences(distributions, baseline='none'):
    """Measure how each system prompt's distribution differs from the baseline's.

    Records are grouped by ``test_id``; within each test, every non-baseline
    system prompt is compared against the record whose ``system_prompt`` equals
    ``baseline``. Tests that have no baseline record are skipped.

    Args:
        distributions: Iterable of record dicts providing the keys read here:
            'test_id', 'system_prompt', 'category', 'entropy', 'top_token',
            'top_5_tokens' and 'full_probs'.
        baseline: Name of the system prompt used as the reference point.

    Returns:
        A ``pandas.DataFrame`` with one row per (test, non-baseline system
        prompt) and the columns 'test_id', 'category', 'system_prompt',
        'js_divergence', 'kl_divergence', 'top_5_overlap', 'entropy_change'
        and 'top_token_changed'. The columns are present even when there are
        no rows, so column lookups on an empty result do not raise KeyError.

    Warns:
        UserWarning: if records were supplied but none of the tests contains
            the baseline system prompt (every comparison was skipped).
    """
    import warnings

    columns = [
        'test_id',
        'category',
        'system_prompt',
        'js_divergence',
        'kl_divergence',
        'top_5_overlap',
        'entropy_change',
        'top_token_changed',
    ]

    # Group records as {test_id: {system_prompt_name: record}}.
    by_test = {}
    for d in distributions:
        by_test.setdefault(d['test_id'], {})[d['system_prompt']] = d

    divergences = []
    for test_id, sys_dists in by_test.items():
        if baseline not in sys_dists:
            continue
        base = sys_dists[baseline]

        for sys_name, sys_dist in sys_dists.items():
            if sys_name == baseline:
                continue

            # Divergences use the 'full_probs' distributions; the overlap is
            # computed on the stored top-5 token lists only.
            js = DistributionMetrics.jensen_shannon(base['full_probs'], sys_dist['full_probs'])
            kl = DistributionMetrics.kl_divergence(base['full_probs'], sys_dist['full_probs'])
            overlap = DistributionMetrics.top_k_overlap(base['top_5_tokens'], sys_dist['top_5_tokens'], k=5)

            divergences.append({
                'test_id': test_id,
                'category': base['category'],
                'system_prompt': sys_name,
                'js_divergence': js,
                'kl_divergence': kl,
                'top_5_overlap': overlap,
                # Positive means the system prompt raised entropy vs. baseline.
                'entropy_change': sys_dist['entropy'] - base['entropy'],
                'top_token_changed': base['top_token'] != sys_dist['top_token']
            })

    # Make a misconfigured baseline name visible instead of silently
    # returning nothing.
    if by_test and not any(baseline in sys_dists for sys_dists in by_test.values()):
        warnings.warn(
            f"Baseline system prompt {baseline!r} was not found for any test; "
            "no divergences were computed.",
            stacklevel=2,
        )

    # Passing the column list keeps the schema stable for an empty result.
    return pd.DataFrame(divergences, columns=columns)

# Build the divergence table, one row per (test, non-baseline system prompt),
# using the function's default baseline name 'none'.
divergence_df = calculate_divergences(all_distributions)
print(f"Calculated {len(divergence_df)} divergences")
# Sanity checks before aggregating: count NaN values, then show min/max/mean
# of each divergence so numerical problems are visible early.
print(f"\nNaN check: JS={divergence_df['js_divergence'].isna().sum()}, KL={divergence_df['kl_divergence'].isna().sum()}")
print(f"\nJS stats: min={divergence_df['js_divergence'].min():.6f}, max={divergence_df['js_divergence'].max():.6f}, mean={divergence_df['js_divergence'].mean():.6f}")
print(f"KL stats: min={divergence_df['kl_divergence'].min():.6f}, max={divergence_df['kl_divergence'].max():.6f}, mean={divergence_df['kl_divergence'].mean():.6f}")

## 3. System Prompt Impact Analysis

In [None]:
# Keep only rows whose JS divergence is a finite number (drops NaN and ±inf).
valid_df = divergence_df[divergence_df['js_divergence'].map(np.isfinite)]

# Per-system-prompt summary: mean/std of JS divergence, mean entropy change and
# the fraction of tests whose top token changed; rounded to 4 decimals and
# ordered from the largest mean JS divergence to the smallest.
sys_impact = (
    valid_df.groupby('system_prompt')
    .agg(
        js_mean=('js_divergence', 'mean'),
        js_std=('js_divergence', 'std'),
        entropy_change=('entropy_change', 'mean'),
        top_change_rate=('top_token_changed', 'mean'),
    )
    .round(4)
    .sort_values('js_mean', ascending=False)
)

print("=== System Prompt Impact ===")
print(sys_impact)

In [None]:
import os
# Make sure the output directory for the saved figure exists.
os.makedirs('../results', exist_ok=True)

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Ascending sort, so the largest mean JS divergence is the last bar drawn.
s = sys_impact.sort_values('js_mean')
positions = range(len(s))

# Green where entropy went up relative to baseline, red otherwise.
colors = ['green' if x > 0 else 'red' for x in s['entropy_change']]

# One entry per panel: (bar values, extra barh kwargs, x-axis label, title).
panels = [
    (
        s['js_mean'],
        {'xerr': s['js_std'], 'capsize': 3},
        'JS Divergence',
        'Distribution Shift from Baseline',
    ),
    (
        s['entropy_change'],
        {'color': colors},
        'Entropy Change',
        'Uncertainty Change\n(+ve = more uncertain)',
    ),
    (
        s['top_change_rate'] * 100,
        {},
        '% Top Token Changed',
        'How Often Top Prediction Changes',
    ),
]

# All three panels share the same y axis: one row per system prompt.
for ax, (values, bar_kwargs, xlabel, title) in zip(axes, panels):
    ax.barh(positions, values, alpha=0.7, **bar_kwargs)
    ax.set_yticks(positions)
    ax.set_yticklabels(s.index)
    ax.set_xlabel(xlabel)
    ax.set_title(title)

fig.tight_layout()
fig.savefig('../results/distribution_shift.png', dpi=150)
plt.show()

## 4. Category Sensitivity

In [None]:
# Mean JS divergence for every (system prompt, category) pair: rows are system
# prompts, columns are categories.
pivot = valid_df.pivot_table(
    index='system_prompt',
    columns='category',
    values='js_divergence',
    aggfunc='mean',
)

fig, ax = plt.subplots(figsize=(12, 10))
sns.heatmap(data=pivot, ax=ax, cmap='YlOrRd', annot=True, fmt='.3f')
ax.set_title('JS Divergence: System Prompt × Category')

fig.tight_layout()
fig.savefig('../results/distribution_heatmap.png', dpi=150)
plt.show()

## 5. Key Findings

In [None]:
print("="*60)
print("DISTRIBUTION SHIFT ANALYSIS SUMMARY")
print("="*60)

# sys_impact was sorted by js_mean in descending order when it was built, so
# head() yields the largest shifts and tail() the smallest.
# The loop variable is deliberately NOT called `sys`: at notebook top level
# that would rebind the imported `sys` module to a string.
print("\n1. HIGHEST IMPACT System Prompts:")
for prompt_name in sys_impact.head(3).index:
    print(f"   • {prompt_name}: JS={sys_impact.loc[prompt_name, 'js_mean']:.4f}")

print("\n2. LOWEST IMPACT System Prompts:")
for prompt_name in sys_impact.tail(3).index:
    print(f"   • {prompt_name}: JS={sys_impact.loc[prompt_name, 'js_mean']:.4f}")

# Save
import json
# Create the output directory here too, so this cell does not depend on the
# plotting cell having been run first.
os.makedirs('../results', exist_ok=True)
with open('../results/distribution_shift.json', 'w') as f:
    # default=float converts any value json cannot serialize natively.
    json.dump({'system_impact': sys_impact.to_dict()}, f, indent=2, default=float)
print("\nResults saved.")