# Hierarchical Bayesian Modeling for Creator Experiments

## The Problem

We run A/B experiments on individual creators. Many creators have **small sample sizes** (a few hundred users or fewer), making individual estimates noisy and unreliable.

**Example**: A creator with 80 users might show a treatment effect of +$1.50 ± $3.00. The interval runs from -$1.50 to +$4.50, so we cannot even tell the sign of the effect.
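
To see why small samples give such wide intervals, here is a minimal sketch of the 95% half-width for a difference in means with an even treatment/control split. The per-user standard deviation of 6.85 is an assumption, chosen only so that n = 80 lands near the ± $3.00 in the example; it is not taken from the data.

```python
import math

def ci_half_width(n_total: int, sd_per_user: float, z: float = 1.96) -> float:
    """95% half-width of a difference in means, assuming equal-sized arms."""
    n_arm = n_total / 2
    # Standard error of (treated mean - control mean) with a common per-user SD
    se = sd_per_user * math.sqrt(1 / n_arm + 1 / n_arm)
    return z * se

for n in (80, 500, 5000):
    print(f"n={n:>5}: +/- ${ci_half_width(n, sd_per_user=6.85):.2f}")
```

The half-width shrinks only with the square root of the sample size, so a creator needs four times as many users to halve its interval.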

## The Solution: Hierarchical Bayesian Modeling (HBM)

We can **borrow strength** from similar creators (grouped by genre) to improve small-sample estimates through **partial pooling**.

This notebook demonstrates:
1. Why standard ("no pooling") estimates fail for small creators
2. How HBM produces better estimates through partial pooling
3. Quantitative validation that HBM recovers ground truth

---

In [None]:
# Setup
import sys
import warnings
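# Silence only deprecation noise; keep sampler warnings (e.g. divergences) visible
warnings.filterwarnings('ignore', category=FutureWarning)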

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Add src to path
sys.path.append('../')

from src.data_generation import generate_experiment_data, summarize_data
from src.frequentist import no_pooling_estimates, complete_pooling_estimates
from src.hierarchical_model import (
    prepare_creator_summaries,
    fit_hierarchical_model,
    extract_hbm_estimates,
    extract_genre_estimates,
    check_mcmc_diagnostics
)
from src.validation import (
    compare_all_methods,
    stratified_comparison,
    compute_shrinkage_metrics,
    validate_genre_recovery
)
from src.visualization import (
    plot_shrinkage,
    plot_mse_comparison,
    plot_coverage_vs_width,
    plot_individual_creators,
    plot_genre_recovery,
    plot_posterior_distributions,
    plot_trace
)

# Set display options
pd.set_option('display.max_columns', None)
pd.set_option('display.precision', 3)

# Set random seed for reproducibility
np.random.seed(42)

print("✓ All imports successful!")

## 1. Generate Synthetic Data

We use synthetic data with **known ground truth** so we can definitively measure which method works best.

### Data Generating Process

```
Hierarchy:
  Platform
    ↓
  Genres (5 types: comedy, music, gaming, etc.)
    ↓
  Creators (100 per genre = 500 total)
    ↓
  Users (highly variable: 30-5000 per creator)
```

Each creator has a **true treatment effect** that varies around their genre's mean.
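
The hierarchy above can be sketched in a few lines of NumPy. This is an illustrative stand-in, not the actual `generate_experiment_data` implementation: the function name, the platform mean of 0.5, the two standard deviations, and the log-uniform sample-size distribution are all assumptions.

```python
import numpy as np

def simulate_creator_effects(n_genres=5, creators_per_genre=100,
                             sigma_genre=0.5, sigma_creator=0.3, seed=0):
    """Draw genre means around a platform mean, then creator effects around genre means."""
    rng = np.random.default_rng(seed)
    mu_global = 0.5  # assumed platform-level effect
    mu_genre = rng.normal(mu_global, sigma_genre, size=n_genres)
    # Genre index for every creator: 0,0,...,1,1,...,4,4
    creator_genre = np.repeat(np.arange(n_genres), creators_per_genre)
    creator_effects = rng.normal(mu_genre[creator_genre], sigma_creator)
    # Log-uniform sizes between 30 and 5000: many small creators, a few large ones
    log_sizes = rng.uniform(np.log(30), np.log(5000), size=creator_genre.size)
    n_users = np.round(np.exp(log_sizes)).astype(int)
    return mu_genre, creator_genre, creator_effects, n_users

mu_genre_sim, genre_idx_sim, effects_sim, n_users_sim = simulate_creator_effects()
print(f"{effects_sim.size} creators, {np.mean(n_users_sim < 100):.0%} with fewer than 100 users")
```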

In [None]:
# Generate synthetic data
df, truth = generate_experiment_data(seed=42)

# Print summary
summarize_data(df, truth)

In [None]:
# Visualize sample size distribution
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Sample size histogram
creator_sizes = df.groupby('creator_id').size()
axes[0].hist(creator_sizes, bins=50, edgecolor='black', alpha=0.7)
axes[0].set_xlabel('Sample Size per Creator', fontweight='bold')
axes[0].set_ylabel('Number of Creators', fontweight='bold')
axes[0].set_title('Sample Size Distribution\n(Many small creators!)', fontweight='bold')
axes[0].axvline(100, color='red', linestyle='--', linewidth=2, label='n=100')
axes[0].axvline(500, color='orange', linestyle='--', linewidth=2, label='n=500')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# True creator effects by genre
effect_by_genre = pd.DataFrame({
    'creator_id': range(len(truth['creator_effects'])),
    'genre_idx': truth['creator_genre'],
    'true_effect': truth['creator_effects']
})
effect_by_genre['genre'] = effect_by_genre['genre_idx'].map(
    {i: name for i, name in enumerate(truth['genre_names'])}
)

sns.violinplot(data=effect_by_genre, x='genre', y='true_effect', ax=axes[1])
axes[1].set_xlabel('Genre', fontweight='bold')
axes[1].set_ylabel('True Treatment Effect', fontweight='bold')
axes[1].set_title('True Effects Vary by Genre\n(This is what we want to recover)', fontweight='bold')
axes[1].tick_params(axis='x', rotation=45)
axes[1].grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

## 2. Frequentist Baselines

### No Pooling (Standard Approach)
Analyze each creator independently. **Unbiased but high variance** for small creators.

### Complete Pooling
Assign every creator its genre-level average effect. **Low variance but high bias** (ignores individual variation).
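
As a rough illustration of the two baselines, the sketch below computes a difference in means per creator (no pooling) and per genre (complete pooling) on a toy table. The column names `treatment` and `outcome` and the helper `diff_in_means` are hypothetical; the real logic lives in `src.frequentist` and may differ.

```python
import pandas as pd

def diff_in_means(group: pd.DataFrame) -> float:
    """Mean outcome of treated users minus mean outcome of control users."""
    treated = group.loc[group["treatment"] == 1, "outcome"].mean()
    control = group.loc[group["treatment"] == 0, "outcome"].mean()
    return treated - control

toy = pd.DataFrame({
    "creator_id": ["a", "a", "a", "a", "b", "b"],
    "genre":      ["comedy"] * 6,
    "treatment":  [1, 1, 0, 0, 1, 0],
    "outcome":    [3.0, 5.0, 2.0, 2.0, 1.0, 1.0],
})

# No pooling: one estimate per creator, using only that creator's users
no_pool_toy = {cid: diff_in_means(g) for cid, g in toy.groupby("creator_id")}
# Complete pooling: one estimate per genre, shared by every creator in it
complete_pool_toy = {genre: diff_in_means(g) for genre, g in toy.groupby("genre")}

print(no_pool_toy)
print(complete_pool_toy)
```

The two creators get very different no-pooling estimates, while complete pooling gives both the same genre-wide number.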

In [None]:
# Compute frequentist estimates
print("Computing no-pooling estimates...")
no_pool = no_pooling_estimates(df)

print("Computing complete-pooling estimates...")
complete_pool = complete_pooling_estimates(df)

print("\n✓ Baseline estimates computed!")

# Quick look at no-pooling for small creators
small_creators = no_pool[no_pool['n_total'] < 100].head(10)
print("\nNo-pooling estimates for small creators:")
print(small_creators[['creator_id', 'n_total', 'effect_hat', 'se', 'ci_lower', 'ci_upper']])
print("\nNotice: Huge uncertainty (wide CIs) for small creators!")

## 3. Fit Hierarchical Bayesian Model

The HBM implements **partial pooling**: estimates are shrunk toward their genre mean, with more shrinkage for noisier (smaller-n) estimates.
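
When the genre mean and the within-genre standard deviation are treated as known, the shrinkage has a closed form: a precision-weighted average of the creator's own estimate and the genre mean. The sketch below shows that conditional case only; the full model also estimates those hyperparameters, so its posterior will not match these numbers exactly. The function name and the example values are illustrative.

```python
import numpy as np

def partial_pool(effect_hat, se, genre_mean, sigma_creator):
    """Posterior mean/SD of a creator effect given known genre_mean and sigma_creator."""
    effect_hat = np.asarray(effect_hat, dtype=float)
    se = np.asarray(se, dtype=float)
    # Weight on the creator's own data: near 1 when SE is small, near 0 when SE is large
    weight = sigma_creator**2 / (sigma_creator**2 + se**2)
    post_mean = weight * effect_hat + (1 - weight) * genre_mean
    post_sd = np.sqrt(weight) * se
    return post_mean, post_sd, weight

# Same raw estimate, one noisy (small creator) and one precise (large creator)
pp_mean, pp_sd, pp_weight = partial_pool(
    effect_hat=[1.5, 1.5], se=[1.5, 0.1], genre_mean=0.5, sigma_creator=0.5
)
print(pp_mean, pp_sd, pp_weight)
```

The noisy estimate is pulled most of the way to the genre mean, while the precise one barely moves.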

### Model Specification

```
mu_global ~ Normal(0, 1)                      # Platform-level mean
sigma_genre ~ HalfNormal(1)                   # Between-genre standard deviation

mu_genre[g] ~ Normal(mu_global, sigma_genre)  # Genre effects

sigma_creator ~ HalfNormal(1)                 # Within-genre standard deviation
tau[i] ~ Normal(mu_genre[g_i], sigma_creator) # Creator effects

observed_effect[i] ~ Normal(tau[i], SE[i])    # Likelihood (SE[i] treated as known)
```

We use **PyMC** with the NUTS sampler (No-U-Turn Sampler, a variant of Hamiltonian Monte Carlo).

In [None]:
# Prepare summary statistics for HBM
print("Preparing creator summary statistics...")
creator_summaries = prepare_creator_summaries(df)
print(f"✓ Prepared summaries for {len(creator_summaries)} creators")

# Fit the hierarchical model
print("\nFitting hierarchical Bayesian model...")
print("(This may take 2-5 minutes)\n")

idata = fit_hierarchical_model(
    creator_summaries,
    n_genres=truth['n_genres'],
    draws=2000,
    tune=1000,
    chains=4,
    random_seed=42
)

In [None]:
# Check MCMC diagnostics
# CRITICAL: Must verify sampling worked before trusting results!
diagnostics = check_mcmc_diagnostics(idata, verbose=True)

In [None]:
# Extract estimates from posterior
print("Extracting HBM estimates...")
hbm = extract_hbm_estimates(idata, creator_summaries)
genre_est = extract_genre_estimates(idata, truth['n_genres'])

print(f"✓ Extracted estimates for {len(hbm)} creators")

# Look at HBM estimates for the same small creators
small_creator_ids = no_pool[no_pool['n_total'] < 100].head(10)['creator_id'].values
small_hbm = hbm[hbm['creator_id'].isin(small_creator_ids)]

print("\nHBM estimates for small creators:")
print(small_hbm[['creator_id', 'n_total', 'effect_hat', 'se', 'ci_lower', 'ci_upper']])
print("\nNotice: Much narrower CIs than no-pooling!")

## 4. The Key Visualization: Shrinkage Plot

This plot shows **how HBM works**:
- x-axis: No-pooling (frequentist) estimate
- y-axis: HBM estimate
- Point size: sample size (larger = more data)
- Color: genre

**What to look for**:
- Small creators (small points) are pulled away from the y=x line toward their genre mean
- Large creators (large points) stay near the y=x line (data dominates the prior)
- This is **partial pooling** in action!

In [None]:
fig = plot_shrinkage(no_pool, hbm, truth, save_path='../outputs/shrinkage_plot.png')
plt.show()

## 5. Quantitative Comparison

Since we have **ground truth**, we can definitively measure which method works best.

### Metrics:
1. **MSE (Mean Squared Error)**: How close are estimates to the truth?
2. **Coverage**: Do 95% intervals contain the true value 95% of the time?
3. **Interval Width**: Narrower is better (if coverage is maintained)
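
The three metrics are simple to compute once estimates, interval bounds, and true effects are aligned per creator. The helper below is a hypothetical sketch, not the code in `src.validation`.

```python
import numpy as np

def evaluate(estimate, ci_lower, ci_upper, true_effect):
    """Return MSE, empirical interval coverage, and mean interval width."""
    estimate, ci_lower, ci_upper, true_effect = (
        np.asarray(x, dtype=float) for x in (estimate, ci_lower, ci_upper, true_effect)
    )
    mse = np.mean((estimate - true_effect) ** 2)
    # Fraction of intervals that contain the true effect
    coverage = np.mean((ci_lower <= true_effect) & (true_effect <= ci_upper))
    mean_width = np.mean(ci_upper - ci_lower)
    return {"mse": float(mse), "coverage": float(coverage), "mean_width": float(mean_width)}

est = np.array([1.0, 2.0, 0.0, -1.0])
true = np.array([1.5, 2.0, 0.5, 1.0])
metrics_toy = evaluate(est, est - 1.0, est + 1.0, true)
print(metrics_toy)
```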

In [None]:
# Overall comparison
overall_results = compare_all_methods(no_pool, complete_pool, hbm, truth, verbose=True)

In [None]:
# Stratified comparison (by sample size)
stratified_results = stratified_comparison(
    no_pool, complete_pool, hbm, truth,
    bins=[0, 100, 500, np.inf],
    verbose=True
)

In [None]:
# MSE comparison plot
fig = plot_mse_comparison(no_pool, complete_pool, hbm, truth,
                         save_path='../outputs/mse_comparison.png')
plt.show()

In [None]:
# Coverage vs. width trade-off
fig = plot_coverage_vs_width(no_pool, complete_pool, hbm, truth,
                            save_path='../outputs/coverage_vs_width.png')
plt.show()

## 6. Individual Creator Examples

Let's look at specific examples to make the improvement tangible.

In [None]:
fig = plot_individual_creators(no_pool, hbm, truth, n_examples=12,
                              save_path='../outputs/individual_creators.png')
plt.show()

## 7. Genre-Level Recovery

Does HBM correctly recover the genre-level structure?

In [None]:
# Validate genre recovery
genre_validation = validate_genre_recovery(genre_est, truth, verbose=True)

In [None]:
# Plot genre recovery
fig = plot_genre_recovery(genre_est, truth, save_path='../outputs/genre_recovery.png')
plt.show()

## 8. Hyperparameter Recovery

Can the model recover the scale parameters (`sigma_genre` and `sigma_creator`)?

In [None]:
fig = plot_posterior_distributions(idata, truth, save_path='../outputs/posteriors.png')
plt.show()

## 9. MCMC Diagnostics

Trace plots to verify the sampler is working correctly.
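
Alongside the visual check, R-hat gives a numeric summary of whether the chains agree. The sketch below implements the classic Gelman-Rubin version for a single scalar parameter, to show the idea. Diagnostic libraries typically report more robust split-chain or rank-normalized variants, so treat this as illustration rather than a replacement for `check_mcmc_diagnostics`.

```python
import numpy as np

def basic_rhat(draws):
    """Classic Gelman-Rubin R-hat; draws has shape (n_chains, n_draws)."""
    draws = np.asarray(draws, dtype=float)
    n_draws = draws.shape[1]
    within = draws.var(axis=1, ddof=1).mean()            # average within-chain variance
    between = n_draws * draws.mean(axis=1).var(ddof=1)   # variance of chain means, scaled
    var_hat = (n_draws - 1) / n_draws * within + between / n_draws
    return float(np.sqrt(var_hat / within))

rng = np.random.default_rng(0)
mixed = rng.normal(0.0, 1.0, size=(4, 1000))             # four chains that agree
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])   # one chain stuck elsewhere

print(f"well-mixed chains: R-hat = {basic_rhat(mixed):.3f}")
print(f"one stuck chain:   R-hat = {basic_rhat(stuck):.3f}")
```

Values close to 1 indicate that between-chain and within-chain variability agree; the stuck chain pushes the statistic well above that.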

In [None]:
fig = plot_trace(idata, save_path='../outputs/trace_plots.png')
plt.show()

## 10. Shrinkage Analysis

Quantify how much HBM shrinks estimates, and verify it shrinks small creators more.
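
One plausible definition of the shrinkage percentage is the share of the no-pooling deviation from the genre mean that the HBM estimate removes. This is an assumption about what `compute_shrinkage_metrics` reports; the function below is a hypothetical sketch of that definition.

```python
import numpy as np

def shrinkage_pct(no_pool_est, hbm_est, genre_mean):
    """Percent of the no-pooling deviation from the genre mean removed by the HBM."""
    no_pool_est, hbm_est, genre_mean = (
        np.asarray(x, dtype=float) for x in (no_pool_est, hbm_est, genre_mean)
    )
    raw_dev = no_pool_est - genre_mean
    hbm_dev = hbm_est - genre_mean
    with np.errstate(divide="ignore", invalid="ignore"):
        pct = 100 * (1 - hbm_dev / raw_dev)
    # Undefined when the raw estimate already sits exactly on the genre mean
    return np.where(raw_dev == 0, np.nan, pct)

# 0% means no shrinkage, 100% means the estimate collapsed onto the genre mean
pct_toy = shrinkage_pct(no_pool_est=[2.0, 1.0], hbm_est=[0.8, 0.95], genre_mean=0.5)
print(pct_toy)
```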

In [None]:
shrinkage = compute_shrinkage_metrics(no_pool, hbm, truth)

# Plot shrinkage vs. sample size
fig, ax = plt.subplots(figsize=(10, 6))

ax.scatter(shrinkage['n_total'], shrinkage['shrinkage_pct'],
          alpha=0.5, s=30, edgecolors='white', linewidth=0.5)
ax.set_xlabel('Sample Size', fontweight='bold')
ax.set_ylabel('Shrinkage (as % of deviation from genre mean)', fontweight='bold')
ax.set_title('Shrinkage vs. Sample Size\nSmall creators are shrunk more', fontweight='bold', fontsize=14)
ax.set_xscale('log')
ax.axhline(0, color='black', linestyle='-', linewidth=1, alpha=0.3)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('../outputs/shrinkage_vs_size.png', dpi=300, bbox_inches='tight')
plt.show()

# Summary statistics
print("\nShrinkage by size bin:")
shrinkage['size_bin'] = pd.cut(shrinkage['n_total'], bins=[0, 100, 500, np.inf],
                               labels=['Small (<100)', 'Medium (100-500)', 'Large (>500)'])
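# observed=False keeps empty size bins in the output and makes the categorical grouping explicit
print(shrinkage.groupby('size_bin', observed=False)['shrinkage_pct'].agg(['mean', 'median']))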

## 11. Summary & Conclusions

### Key Findings:

1. **HBM achieves lower MSE than both baselines**, especially for small creators
2. **HBM maintains ~95% coverage** (well-calibrated intervals)
3. **HBM has narrower intervals** than no-pooling while maintaining coverage
4. **Shrinkage is adaptive**: more for small creators, less for large creators
5. **Genre-level structure is recovered** accurately

### When to use HBM:
- ✅ Many entities with variable sample sizes
- ✅ Natural grouping structure (genres, segments, etc.)
- ✅ Need stable estimates for small entities
- ✅ Willing to assume entities within groups are "exchangeable"

### When NOT to use HBM:
- ❌ All entities have large sample sizes (no pooling works fine)
- ❌ No meaningful grouping structure
- ❌ Groups have completely different mechanisms (not exchangeable)

### Practical Considerations:
- **Computation**: ~2-5 minutes for 500 creators (fast enough for batch processing)
- **Validation**: Always check MCMC diagnostics (R-hat, ESS, divergences)
- **Communication**: Explain shrinkage to stakeholders ("we borrow strength from similar creators")

---

## Next Steps

Potential extensions:
1. **Non-normal outcomes**: Revenue is skewed → try log-transform or hurdle model
2. **Multiple groupings**: Crossed effects (genre × audience segment)
3. **Time-varying effects**: Account for temporal trends
4. **Production deployment**: Use variational inference for faster (approximate) model fitting
5. **Informative priors**: Use historical experiment data to set priors

---

## Appendix: Export Results

In [None]:
# Save comparison table
comparison_df = pd.DataFrame({
    'creator_id': no_pool['creator_id'],
    'genre': no_pool['genre'],
    'n_total': no_pool['n_total'],
    'true_effect': truth['creator_effects'],
    'no_pool_est': no_pool['effect_hat'],
    'no_pool_ci_width': no_pool['ci_upper'] - no_pool['ci_lower'],
    'complete_pool_est': complete_pool['effect_hat'],
    'hbm_est': hbm['effect_hat'],
    'hbm_ci_width': hbm['ci_upper'] - hbm['ci_lower']
})

comparison_df['no_pool_error'] = comparison_df['no_pool_est'] - comparison_df['true_effect']
comparison_df['hbm_error'] = comparison_df['hbm_est'] - comparison_df['true_effect']

comparison_df.to_csv('../outputs/comparison_results.csv', index=False)
print("✓ Results saved to outputs/comparison_results.csv")

# Save overall metrics
overall_results.to_csv('../outputs/overall_metrics.csv', index=False)
print("✓ Overall metrics saved to outputs/overall_metrics.csv")