# 02 - Model Fitting

Bayesian Marketing Mix Model fitting using PyMC-Marketing.

- 12 monthly observations (Jul 2015 - Jun 2016)
- 7 media channels, 3 control variables
- Informative priors (critical with only 12 data points)

## Workflow
1. Configure model
2. Load & preprocess data
3. Initialize model
4. Prior predictive check
5. Fit (MCMC sampling)
6. Convergence diagnostics (gate)
7. Posterior predictive check
8. Channel decomposition & contribution analysis

In [None]:
import warnings

import arviz as az
import matplotlib.pyplot as plt
import seaborn as sns

warnings.filterwarnings("ignore")
sns.set_theme(style="whitegrid")
%matplotlib inline

from mmm_test.config import OUTPUTS_DIR, ModelConfig
from mmm_test.data import load_mmm_data
from mmm_test.diagnostics import check_convergence, validate_before_interpretation
from mmm_test.model import (
    build_model,
    fit_model,
    sample_posterior_predictive,
    sample_prior_predictive,
)
from mmm_test.plotting import plot_channel_contributions, plot_trace

---
## Step 1: Configure Model

In [None]:
config = ModelConfig()

# Assemble the whole configuration report first, then emit it in one call.
# The printed output is the same line-for-line as printing each entry.
report_lines = [
    f"Date column: {config.date_column}",
    f"Target: {config.target_column}",
    f"Channels ({len(config.channel_columns)}): {config.channel_columns}",
    f"Controls: {config.control_columns}",
    f"Adstock max lag: {config.adstock_max_lag}",
    f"Target accept: {config.target_accept}",
    f"Chains: {config.chains}, Draws: {config.draws}, Tune: {config.tune}",
    "\nPriors (model_config):",
]
# One indented line per prior entry returned by the config.
report_lines.extend(
    f"  {k}: {v}" for k, v in config.get_model_config().items()
)
print("\n".join(report_lines))

---
## Step 2: Load & Preprocess Data

In [None]:
df = load_mmm_data()

# Use the configured column names rather than literals, so this preview shows
# exactly the columns that X and y are built from below.
date_col = config.date_column
target_col = config.target_column

print(f"Shape: {df.shape}")
print(f"Date range: {df[date_col].min()} to {df[date_col].max()}")
print(
    f"Target ({target_col}) range: {df[target_col].min():.0f} to {df[target_col].max():.0f}"
)
df[[date_col, target_col, *config.channel_columns, *config.control_columns]]

In [None]:
# Channel spend summary
print("Channel spend summary (min / mean / max):")
for ch in config.channel_columns:
    print(f"  {ch}: {df[ch].min():.0f} / {df[ch].mean():.0f} / {df[ch].max():.0f}")

# Controls come from the config, like the channels above, so the summary
# always matches the control variables actually passed to the model.
print("\nControls:")
for ctrl in config.control_columns:
    print(f"  {ctrl}: {df[ctrl].min():.1f} - {df[ctrl].max():.1f}")

In [None]:
# Prepare X (features) and y (target).
# Feature order follows the config: date column, then media channels, then
# control variables.
feature_cols = [config.date_column]
feature_cols.extend(config.channel_columns)
feature_cols.extend(config.control_columns)

X = df[feature_cols]
y = df[config.target_column]

for shape_line in (f"X shape: {X.shape}", f"y shape: {y.shape}"):
    print(shape_line)
print(f"X columns: {X.columns.tolist()}")

---
## Step 3: Initialize Model

In [None]:
mmm = build_model(config)

# Echo the structure of the model that was just built.
model_facts = [
    f"Model type: {type(mmm).__name__}",
    f"Date column: {mmm.date_column}",
    f"Channel columns: {mmm.channel_columns}",
    f"Control columns: {mmm.control_columns}",
    f"Adstock: {mmm.adstock} (l_max={mmm.adstock.l_max})",
    f"Saturation: {mmm.saturation}",
]
print(*model_facts, sep="\n")

---
## Step 4: Prior Predictive Check

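Before fitting, verify that the priors alone produce plausible total_gmv values (roughly the same order of magnitude as the observed data).
With only 12 data points, the priors will strongly shape the posterior, so implausible priors cannot be corrected by the data.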

In [None]:
# Draw 500 prior predictive samples before any data has been fitted.
sample_prior_predictive(mmm, X, y, samples=500)

fig = mmm.plot_prior_predictive(original_scale=True)
# The plotting helper may not return a Figure, so only title it when possible.
add_title = getattr(fig, "suptitle", None)
if add_title is not None:
    add_title("Prior Predictive Check", fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# Assess prior predictive range: put what the priors imply next to the data.
prior_pred = mmm.idata.prior_predictive
print(f"Observed target range: {y.min():.0f} to {y.max():.0f}")
print(f"Observed target mean: {y.mean():.0f}")

# NOTE(review): the prior predictive is stored on the model's internal target
# scale (the plot above needs original_scale=True to undo it). PyMC-Marketing
# presumably max-abs scales the target by default - confirm against the
# model's target transformer - so the observed range is also shown divided by
# max|y| for a like-for-like comparison.
y_abs_max = y.abs().max()
print(
    f"Observed target / max|y|: {y.min() / y_abs_max:.3f} to {y.max() / y_abs_max:.3f}"
)
for var_name, draws in prior_pred.data_vars.items():
    # Central 94% interval over all chains, draws and dates.
    low, high = draws.quantile([0.03, 0.97]).values
    print(f"Prior predictive '{var_name}' 94% interval: {low:.3g} to {high:.3g}")
print()
print(
    "If prior predictive range is unrealistic, adjust priors in config.get_model_config()"
)

---
## Step 5: Fit Model (MCMC Sampling)

Sampling takes several minutes. With the settings printed in Step 1 (4 chains x 1000 draws, each after 1000 tuning steps),
target_accept=0.95 makes the sampler take smaller steps, which reduces divergences.

In [None]:
%%time
idata = fit_model(mmm, X, y, config)
print("\nSampling complete.")
print(f"Posterior dims: {dict(idata.posterior.dims)}")

---
## Step 6: Convergence Diagnostics

**Gate check** - all must pass before interpreting results:
- R-hat < 1.01 for all parameters
- ESS (bulk) > 400 for all parameters
- Zero divergences

In [None]:
diag = check_convergence(idata)

# One line per gate criterion, followed by the checker's own summary text.
diag_lines = (
    f"PASSED: {diag.passed}",
    f"R-hat OK: {diag.rhat_ok} (max: {diag.max_rhat:.4f})",
    f"ESS OK: {diag.ess_ok} (min: {diag.min_ess:.0f})",
    f"Divergences: {diag.divergences}",
    f"\nSummary: {diag.summary}",
)
for line in diag_lines:
    print(line)

In [None]:
# Trace plots for visual inspection; the helper saves the figure and returns
# the path it was written to.
trace_path = plot_trace(idata)
print("Trace plot saved to:", trace_path)

In [None]:
# Full parameter summary table from ArviZ; the final bare expression is what
# the notebook renders as the cell output.
summary = az.summary(data=idata)
summary

In [None]:
# Gate: raise if diagnostics fail.
# validate_before_interpretation is expected to raise when the convergence
# checks fail; the exception stops "Run All" before any later cell interprets
# the posterior. The print below is only reached when the gate passes.
validate_before_interpretation(idata)
print("Convergence gate PASSED - safe to interpret results.")

---
## Step 7: Posterior Predictive Check

Verify the fitted model can reproduce the observed data.

In [None]:
# Draw posterior predictive samples from the fitted model.
sample_posterior_predictive(mmm)

fig = mmm.plot_posterior_predictive(original_scale=True)
# The plotting helper may not return a Figure, so only title it when possible.
add_title = getattr(fig, "suptitle", None)
if add_title is not None:
    add_title("Posterior Predictive Check", fontsize=14)
plt.tight_layout()
plt.show()

---
## Step 8: Channel Decomposition & Contributions

In [None]:
# Channel contributions plot; the helper saves the figure and returns the path
# it was written to.
contrib_path = plot_channel_contributions(mmm)
print("Channel contributions plot saved to:", contrib_path)

In [None]:
# Waterfall decomposition of the target into model components, drawn on the
# original target scale.
waterfall_options = {"original_scale": True, "figsize": (14, 8)}
fig = mmm.plot_waterfall_components_decomposition(**waterfall_options)
plt.show()

In [None]:
# Channel contribution share with HDI (credible intervals).
CONTRIBUTION_HDI_PROB = 0.94  # probability mass of each highest-density interval
fig = mmm.plot_channel_contribution_share_hdi(hdi_prob=CONTRIBUTION_HDI_PROB)
plt.show()

In [None]:
# Grouped contribution breakdown over time.
# Only the offline channels are named explicitly. Every other configured
# channel is grouped as online, and the controls come straight from the config,
# so the groups stay in sync with the columns X was built from.
offline_channels = [ch for ch in ("TV", "Sponsorship") if ch in config.channel_columns]
online_channels = [ch for ch in config.channel_columns if ch not in offline_channels]
stack_groups = {
    "Baseline": ["intercept"],
    "Offline": offline_channels,
    "Online": online_channels,
    "Controls": list(config.control_columns),
}
fig = mmm.plot_grouped_contribution_breakdown_over_time(
    # Drop empty groups (e.g. when no offline channel is configured).
    stack_groups={name: cols for name, cols in stack_groups.items() if cols},
    original_scale=True,
    figsize=(14, 6),
)
plt.show()

### Adstock & Saturation Curves

In [None]:
# Prior vs posterior for adstock decay, one panel per channel (sorted by name).
fig = mmm.plot_prior_vs_posterior("adstock_alpha", alphabetical_sort=True)
active_fig = plt.gcf()  # the figure the plot above drew into
active_fig.suptitle("Adstock Decay: Prior vs Posterior", fontsize=14)
active_fig.tight_layout()
plt.show()

In [None]:
# Prior vs posterior for the saturation lambda parameter (sorted by name).
fig = mmm.plot_prior_vs_posterior("saturation_lam", alphabetical_sort=True)
active_fig = plt.gcf()  # the figure the plot above drew into
active_fig.suptitle("Saturation Lambda: Prior vs Posterior", fontsize=14)
active_fig.tight_layout()
plt.show()

In [None]:
# Prior vs posterior for the saturation beta (channel effect size), sorted by name.
fig = mmm.plot_prior_vs_posterior("saturation_beta", alphabetical_sort=True)
active_fig = plt.gcf()  # the figure the plot above drew into
active_fig.suptitle("Saturation Beta (Channel Effect): Prior vs Posterior", fontsize=14)
active_fig.tight_layout()
plt.show()

### Mean Contributions

In [None]:
# Mean contribution of each model component per period; original_scale=True
# requests values on the target's original scale rather than the model's
# internal one.
contributions = mmm.compute_mean_contributions_over_time(
    original_scale=True,
)
print("Mean contributions over time:")
contributions  # rendered as the cell output

---
## Save Model

In [None]:
from datetime import datetime

# Save the fitted model under outputs/models. Use the first unused "_vN" suffix
# for today's date, so re-running the notebook on the same day (e.g. for a
# prior-sensitivity rerun) never silently overwrites an earlier fit.
date_str = datetime.now().strftime("%Y-%m-%d")
models_dir = OUTPUTS_DIR / "models"
models_dir.mkdir(parents=True, exist_ok=True)

version = 1
while (models_dir / f"mmm_fit_{date_str}_v{version}.nc").exists():
    version += 1
model_path = models_dir / f"mmm_fit_{date_str}_v{version}.nc"

mmm.save(str(model_path))
print(f"Model saved to: {model_path}")

---
## Summary

### Results
- Model fitted with 4 chains x 1000 draws
- Convergence diagnostics: see the Step 6 gate output above; do not interpret results unless the gate passed
- Channel contributions decomposed with uncertainty (HDI)
- Adstock and saturation parameters estimated per channel

### Key Caveats
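- Only 12 data points, fewer than the model's parameters (adstock decay, saturation lambda and saturation beta alone are estimated for each of the 7 channels) - priors heavily influence posterior estimates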
- Wide credible intervals expected due to small sample size
- Channel estimates may be unreliable where multicollinearity is high

### Next Steps
- Notebook 03: Budget optimization using the fitted model
- Sensitivity analysis on priors (how much do results change with different priors?)
- Consider aggregating the underlying daily data to weekly instead of monthly granularity (~52 observations instead of 12)