# Phase 4: Model Validation and Diagnostics

This notebook validates the MMM models using:
1. MCMC Convergence Diagnostics (R-hat, ESS, trace plots)
2. Posterior Predictive Checks (MAPE, R², coverage)
3. Model Fit Quality Assessment
4. Parameter Interpretation

**Prerequisites:** Run notebooks 02 and 03 first to fit the models.

In [None]:
import sys
sys.path.append('..')

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import arviz as az
from scripts.mmm_optimized import UCM_MMM_Optimized
from scripts.bvar_optimized import BVAR_Optimized

print("="*70)
print("NOTE: This notebook assumes you have already run:")
print("  1. 02_Short_Term_Model.ipynb (fitted mmm object)")
print("  2. 03_Long_Term_Model.ipynb (fitted bvar object)")
print("="*70)
print("\nIf you haven't run those notebooks, run them first!")
print("This notebook will demonstrate validation using synthetic results.")
print("\nFor actual validation, you would load the fitted model objects")
print("from the previous notebooks using pickle or by re-running them.")
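
The pickle route mentioned above can be sketched as a simple round trip. This is a minimal illustration, not the project's actual persistence code: the dictionary stands in for a fitted model object, and the file name `mmm_fitted.pkl` is hypothetical. Whether `UCM_MMM_Optimized` or `BVAR_Optimized` instances pickle cleanly is not confirmed here; if only the trace is needed, ArviZ's `InferenceData.to_netcdf` and `az.from_netcdf` are a common alternative.

```python
import pickle
import tempfile
from pathlib import Path

# Hypothetical stand-in for whatever the fitting notebook would save
fitted = {"model": "UCM_MMM", "posterior_means": {"alpha": [0.42, 0.61]}}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = Path(tmp_dir) / "mmm_fitted.pkl"  # hypothetical file name

    # In the fitting notebook: persist the object
    with open(path, "wb") as f:
        pickle.dump(fitted, f)

    # In this notebook: restore it
    with open(path, "rb") as f:
        restored = pickle.load(f)

print(restored["posterior_means"])
```

Only unpickle files you created yourself, since loading a pickle can execute arbitrary code.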

## MCMC Convergence Diagnostics

In [None]:
# MCMC Convergence Diagnostics

print("="*70)
print("1. CONVERGENCE DIAGNOSTICS")
print("="*70)

# Assuming 'mmm' and 'bvar' objects are available from previous notebooks
# If not, this demonstrates the validation process

print("\n## UCM-MMM Model Convergence")
print("-" * 70)

# Get summary statistics
# summary_mmm = mmm.summary()

# Example of what to check:
print("""
Key metrics to check:

1. R-hat (Gelman-Rubin statistic):
   - R-hat < 1.01 = EXCELLENT convergence
   - R-hat < 1.05 = Good convergence
   - R-hat >= 1.05 = Poor convergence (increase draws or tuning steps)
   
   Example check:
   ```python
   summary = mmm.summary()
   rhat_max = summary['r_hat'].max()
   print(f"Max R-hat: {rhat_max:.4f}")
   
   if rhat_max < 1.01:
       print("✓ EXCELLENT convergence!")
   ```

2. Effective Sample Size (ESS):
   - ESS > 1000 = Good (for 500 draws × 4 chains = 2000 total)
   - ESS > 400 = Acceptable
   - ESS < 400 = Poor (increase draws)
   
   Example check:
   ```python
   ess_min = summary['ess_bulk'].min()
   print(f"Min ESS: {ess_min:.0f}")
   
   if ess_min > 1000:
       print("✓ Sufficient effective samples!")
   ```

3. Divergences:
   - 0 divergences = Ideal
   - < 1% divergences = Acceptable
   - 1% to 5% divergences = Borderline (investigate before trusting results)
   - > 5% divergences = Problematic
   
   Example check:
   ```python
   # Count divergent transitions recorded in the trace (cast to a plain int)
   n_divergences = int(mmm.trace.sample_stats['diverging'].sum())
   total_samples = len(mmm.trace.posterior.chain) * len(mmm.trace.posterior.draw)
   pct_divergent = n_divergences / total_samples * 100
   
   print(f"Divergences: {n_divergences} ({pct_divergent:.2f}%)")
   ```
""")
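
To make the R-hat thresholds above concrete, the sketch below computes the classic Gelman-Rubin statistic by hand on synthetic chains. It is an illustration under stated assumptions: the chains are random numbers, not output from `mmm`, and the helper `gelman_rubin` is hypothetical. ArviZ reports a rank-normalized split variant, so its values will differ slightly from this simpler formula.

```python
import numpy as np

def gelman_rubin(chains):
    """Classic (non-split) R-hat for an array shaped (n_chains, n_draws)."""
    chains = np.asarray(chains, dtype=float)
    n_chains, n_draws = chains.shape
    chain_means = chains.mean(axis=1)
    # Between-chain variance B and mean within-chain variance W
    B = n_draws * chain_means.var(ddof=1)
    W = chains.var(axis=1, ddof=1).mean()
    # Pooled variance estimate; R-hat compares it with W
    var_hat = (n_draws - 1) / n_draws * W + B / n_draws
    return float(np.sqrt(var_hat / W))

rng = np.random.default_rng(0)

# Four well-mixed chains sampling the same distribution
mixed = rng.normal(0.0, 1.0, size=(4, 500))

# Same draws, but the last chain is shifted to a different region
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])

rhat_mixed = gelman_rubin(mixed)
rhat_stuck = gelman_rubin(stuck)
print(f"well-mixed chains: R-hat = {rhat_mixed:.3f}")
print(f"one shifted chain: R-hat = {rhat_stuck:.3f}")
```

The shifted chain inflates the between-chain variance, which is what pushes R-hat well above 1.05.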

## Posterior Predictive Checks

In [None]:
# Posterior Predictive Checks

print("="*70)
print("2. POSTERIOR PREDICTIVE CHECKS")
print("="*70)

print("""
Posterior predictive checks validate model fit by comparing:
- Actual observed sales
- Model-predicted sales (from posterior)

Key metrics:

1. MAPE (Mean Absolute Percentage Error):
   - MAPE < 10% = Excellent fit
   - MAPE < 20% = Good fit
   - MAPE > 20% = Poor fit (model may need refinement)

2. R² (Coefficient of Determination):
   - R² > 0.9 = Excellent fit
   - R² > 0.7 = Good fit
   - R² < 0.5 = Poor fit

3. 95% CI Coverage:
   - Should be ~95% (actual sales within credible interval)
   - < 90% = Model underestimates uncertainty
   - > 98% = Model overestimates uncertainty

Example implementation:
""")

print("""
```python
import numpy as np
import pymc as pm

# Generate posterior predictive samples
# (assumes mmm.model is a PyMC model; ArviZ has no sampling function of its own)
with mmm.model:
    ppc = pm.sample_posterior_predictive(
        mmm.trace,
        var_names=['y_obs'],
        random_seed=42
    )

# Extract predictions
y_pred = ppc.posterior_predictive['y_obs'].values
y_pred_mean = y_pred.mean(axis=(0, 1))
y_pred_lower = np.percentile(y_pred, 2.5, axis=(0, 1))
y_pred_upper = np.percentile(y_pred, 97.5, axis=(0, 1))

# Calculate metrics
actual_sales = df['revenue'].values
residuals = actual_sales - y_pred_mean

# MAPE
mape = np.mean(np.abs(residuals / actual_sales)) * 100
print(f"MAPE: {mape:.2f}%")

# R²
r2 = 1 - np.sum(residuals**2) / np.sum((actual_sales - actual_sales.mean())**2)
print(f"R²: {r2:.3f}")

# Coverage
coverage = np.mean((actual_sales >= y_pred_lower) & (actual_sales <= y_pred_upper)) * 100
print(f"95% CI Coverage: {coverage:.1f}%")
```
""")

# Visualization example
print("\nVisualization:")
print("""
```python
fig, ax = plt.subplots(figsize=(14, 6))

# Plot actual sales
ax.plot(df['Date'], actual_sales, 'o-', color='black', 
        linewidth=2, markersize=3, label='Actual Sales', alpha=0.7)

# Plot predicted mean
ax.plot(df['Date'], y_pred_mean, '-', color='#2E86AB', 
        linewidth=2, label='Predicted Mean')

# Plot credible interval
ax.fill_between(df['Date'], y_pred_lower, y_pred_upper,
                 color='#2E86AB', alpha=0.2, label='95% CI')

ax.set_xlabel('Date', fontsize=12, fontweight='bold')
ax.set_ylabel('Sales Revenue ($)', fontsize=12, fontweight='bold')
ax.set_title('Posterior Predictive Check: Model Fit Quality', 
             fontsize=14, fontweight='bold', pad=20)
ax.legend(loc='upper left')
ax.grid(True, alpha=0.3)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
```
""")
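
The metric definitions above can be tried without a fitted model. The following sketch applies the same MAPE, R² and coverage formulas to synthetic data: the seasonal sales series and the predictive draws are generated from random numbers, and the helper `fit_metrics` is hypothetical. The draws are shaped `(chain, draw, observation)` to mirror the layout of `ppc.posterior_predictive['y_obs']`.

```python
import numpy as np

rng = np.random.default_rng(42)
n_weeks = 104

# Synthetic weekly sales: seasonal signal plus noise (stand-in for df['revenue'])
weeks = np.arange(n_weeks)
signal = 1000 + 200 * np.sin(2 * np.pi * weeks / 52)
actual_sales = signal + rng.normal(0, 50, size=n_weeks)

# Synthetic posterior predictive draws shaped (chain, draw, observation)
y_pred = signal + rng.normal(0, 50, size=(4, 500, n_weeks))

def fit_metrics(actual, draws):
    # Collapse chain and draw axes to get the predictive mean and 95% interval
    mean = draws.mean(axis=(0, 1))
    lower, upper = np.percentile(draws, [2.5, 97.5], axis=(0, 1))
    residuals = actual - mean
    # MAPE is undefined if any actual value is zero
    mape = np.mean(np.abs(residuals / actual)) * 100
    r2 = 1 - np.sum(residuals**2) / np.sum((actual - actual.mean())**2)
    coverage = np.mean((actual >= lower) & (actual <= upper)) * 100
    return {"mape": mape, "r2": r2, "coverage": coverage}

metrics = fit_metrics(actual_sales, y_pred)
print({name: round(float(value), 2) for name, value in metrics.items()})
```

Because the synthetic draws use the same noise level as the synthetic sales, coverage should land near 95%. Shrinking the noise in `y_pred` alone would reproduce the under-coverage case described above.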

## Parameter Interpretation

In [None]:
# Parameter Interpretation

print("="*70)
print("3. PARAMETER INTERPRETATION")
print("="*70)

print("""
Understanding model parameters helps validate business logic:

## UCM-MMM Parameters

1. **Adstock (α)**: Carryover rate
   ```python
   alpha = mmm.trace.posterior['alpha'].mean(dim=['chain', 'draw']).values
   
   for i, channel in enumerate(marketing_channels):
       print(f"{channel}: α = {alpha[i]:.3f}")
       
   # Interpretation:
   # α = 0.3 → Effects decay quickly (30% remains next week)
   # α = 0.5 → Moderate persistence (half-life of ~1 week)
   # α = 0.7 → Strong carryover (half-life of ~2 weeks)
   ```

2. **Saturation (λ, κ)**: Diminishing returns
   ```python
   lambda_vals = mmm.trace.posterior['lambda'].mean(dim=['chain', 'draw']).values
   kappa_vals = mmm.trace.posterior['kappa'].mean(dim=['chain', 'draw']).values
   
   for i, channel in enumerate(marketing_channels):
       print(f"{channel}:")
       print(f"  λ (half-saturation): ${lambda_vals[i]:,.0f}")
       print(f"  κ (shape): {kappa_vals[i]:.2f}")
       
   # Interpretation:
   # λ = Half-saturation point (spend at which the response reaches 50% of its maximum)
   # κ > 1 → S-shaped curve (slow start, rapid growth, plateau)
   # κ < 1 → Rapid initial returns, then diminishing
   ```

3. **Channel Effects (β)**: Direct impact
   ```python
   # After hierarchical effects
   beta = mmm.trace.posterior['beta_channel'].mean(dim=['chain', 'draw']).values
   
   for i, channel in enumerate(marketing_channels):
       print(f"{channel}: β = {beta[i]:.4f}")
       
   # Interpretation:
   # Higher β = stronger direct effect on sales
   # Can compare relative strength across channels
   ```
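
The adstock and saturation interpretations above can be checked with a few lines of arithmetic. This is a sketch under assumptions: it uses geometric adstock and a Hill-type saturation curve, which the parameter names suggest but the notebook does not confirm, and all three helper functions are hypothetical.

```python
import math

def adstock_half_life(alpha):
    # Weeks until a geometric carryover alpha**t falls to 0.5
    return math.log(0.5) / math.log(alpha)

def geometric_adstock(spend, alpha):
    # carried[t] = spend[t] + alpha * carried[t-1]
    carried, out = 0.0, []
    for x in spend:
        carried = x + alpha * carried
        out.append(carried)
    return out

def hill_saturation(x, lam, kappa):
    # Response between 0 and 1; equals 0.5 when x == lam
    return x**kappa / (lam**kappa + x**kappa)

# Half-lives for the three alpha values discussed above
half_lives = {a: round(adstock_half_life(a), 2) for a in (0.3, 0.5, 0.7)}
print(half_lives)

# A single 100-unit pulse of spend decaying with alpha = 0.5
pulse = geometric_adstock([100.0, 0.0, 0.0, 0.0], alpha=0.5)
print(pulse)

# Spending exactly lambda yields half of the maximum response
at_lambda = hill_saturation(5000.0, lam=5000.0, kappa=2.0)
print(at_lambda)
```

Under these assumptions, α = 0.5 gives a half-life of exactly one week and α = 0.7 gives just under two weeks, matching the comments in item 1.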

## BVAR Parameters

1. **VAR Coefficients (A)**: Lag effects
   ```python
   A_lag1 = bvar.trace.posterior['A_lag1'].mean(dim=['chain', 'draw']).values
   
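   # Shows how each variable's previous value affects itself (diagonal)
   # and the other variables (off-diagonal)
   # Diagonal values near 1 = strong persistence
   # Negative diagonal values = the variable tends to reverse direction each period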
   ```

2. **Exogenous Effects (B)**: Marketing impact
   ```python
   B = bvar.trace.posterior['B'].mean(dim=['chain', 'draw']).values
   
   # B[i, j] = effect of marketing channel j on outcome i
   # Positive values = marketing increases brand/sales
   ```
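
One way to turn the lag and marketing coefficients above into a single statement is a stability check followed by a long-run multiplier. The sketch assumes a single-lag VAR and uses made-up 2x2 and 2x1 matrices, not values from the fitted `bvar`. With more than one lag, the eigenvalue check would need the companion matrix instead of `A_lag1` alone.

```python
import numpy as np

# Hypothetical posterior-mean matrices for a 2-variable system (brand, sales)
A_lag1 = np.array([[0.6, 0.1],
                   [0.2, 0.5]])
B = np.array([[0.3],
              [0.8]])

# A VAR(1) is stable when every eigenvalue of A lies inside the unit circle,
# meaning shocks fade out instead of compounding
eigvals = np.linalg.eigvals(A_lag1)
is_stable = bool(np.all(np.abs(eigvals) < 1))
print("eigenvalue magnitudes:", np.round(np.abs(eigvals), 3))
print("stable:", is_stable)

# Long-run effect of a permanent one-unit increase in the marketing channel:
# solve (I - A) x = B, which is only meaningful for a stable system
long_run = np.linalg.solve(np.eye(2) - A_lag1, B)
print("long-run effect on (brand, sales):", np.round(long_run.ravel(), 3))
```

Here the long-run effects exceed the immediate effects in `B` because persistence in `A_lag1` carries each period's impact forward.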

## Validation Checklist:

✓ Parameters have reasonable magnitude
✓ Sign (positive/negative) makes business sense  
✓ Credible intervals don't include extreme values
✓ Hierarchical effects show expected grouping
✓ Adstock/saturation curves look plausible
""")
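
The sign check in the checklist above can be partly automated by asking whether each parameter's 95% credible interval excludes zero. This is a sketch on synthetic draws: the channel names and the draw values are hypothetical, and the array is shaped `(chain, draw, channel)` to mirror a posterior variable such as `beta_channel`.

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic stand-in for posterior draws shaped (chain, draw, channel)
channels = ["tv", "search", "display"]  # hypothetical channel names
draws = rng.normal(loc=[0.8, 0.05, -0.4], scale=0.1, size=(4, 500, 3))

# 95% equal-tailed credible interval per channel
lower, upper = np.percentile(draws, [2.5, 97.5], axis=(0, 1))

verdicts = {}
for name, lo, hi in zip(channels, lower, upper):
    if lo > 0:
        verdicts[name] = "positive"
    elif hi < 0:
        verdicts[name] = "negative"
    else:
        verdicts[name] = "includes zero"
    print(f"{name}: [{lo:.2f}, {hi:.2f}] -> {verdicts[name]}")
```

A channel flagged as negative is a prompt to revisit the data or priors, not proof that the model is wrong.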