# 06 — Data Science Report: MMM Technical Validation

This report provides a consolidated technical view of the fitted Marketing Mix Model for review by the data science team. It loads the most recently saved model trace from `outputs/models/` and reproduces all key diagnostics, convergence checks, and channel analysis without re-running MCMC sampling. The intended audience is data scientists and ML engineers who need to assess model quality, understand parameter posteriors, and evaluate the reliability of channel attribution results.

## Table of Contents

1. [Setup & Imports](#1-setup--imports)
2. [Model Specification](#2-model-specification)
3. [Convergence Diagnostics](#3-convergence-diagnostics)
4. [Trace Plots](#4-trace-plots)
5. [Posterior Predictive Check](#5-posterior-predictive-check)
6. [Prior vs Posterior](#6-prior-vs-posterior)
7. [Channel Decomposition](#7-channel-decomposition)
8. [Adstock Decay Curves](#8-adstock-decay-curves)
9. [Saturation Response Curves](#9-saturation-response-curves)
10. [Model Limitations](#10-model-limitations)

---
## 1. Setup & Imports

In [None]:
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.special import expit

warnings.filterwarnings("ignore")
sns.set_theme(style="whitegrid")
%matplotlib inline

from pymc_marketing.mmm import MMM

from mmm_demo.config import OUTPUTS_DIR, ModelConfig
from mmm_demo.data import load_mmm_weekly_data
from mmm_demo.diagnostics import ESS_THRESHOLD, RHAT_THRESHOLD, check_convergence
from mmm_demo.model import sample_posterior_predictive

print(f"ArviZ version:          {az.__version__}")
print(f"RHAT_THRESHOLD:         {RHAT_THRESHOLD}")
print(f"ESS_THRESHOLD:          {ESS_THRESHOLD}")

---
## 2. Model Specification

### Model Architecture

The model uses PyMC-Marketing's `MMM` class with two transformation layers applied to each channel's spend before it enters the linear predictor:

**GeometricAdstock** — models the lagged carryover (memory) effect of advertising:
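$$x^{\text{adstock}}_t = \sum_{l=0}^{L-1} \alpha^l \cdot x_{t-l}$$
where $\alpha \in (0, 1)$ is the per-channel decay parameter and $L$ is `adstock_max_lag` (the `l_max` of `GeometricAdstock`), so lags $0$ through $L-1$ contribute. Higher $\alpha$ means a longer-lasting advertising effect. If the transformation is built with `normalize=True`, the weights $\alpha^l$ are divided by their sum, which rescales the series without changing the shape of the decay.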
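A minimal NumPy sketch of the carryover computation, assuming PyMC-Marketing's convention that `l_max` covers lags 0 through `l_max - 1`. `geometric_adstock_np` is a hypothetical helper for illustration, not the library's implementation; its `normalize` flag mirrors the option of dividing the weights by their sum.

```python
import numpy as np


def geometric_adstock_np(x, alpha, l_max, normalize=False):
    """Apply geometric carryover to a 1-D spend series (illustrative sketch).

    Lag l receives weight alpha**l for l = 0, ..., l_max - 1; spend from
    further back is dropped.
    """
    x = np.asarray(x, dtype=float)
    weights = alpha ** np.arange(l_max)
    if normalize:
        weights = weights / weights.sum()
    out = np.zeros_like(x)
    # Lags longer than the series itself cannot contribute, so skip them
    for lag, w in enumerate(weights[: len(x)]):
        # Shift the series forward by `lag` weeks; weeks before t=0 contribute 0
        out[lag:] += w * x[: len(x) - lag]
    return out


# A single unit of spend in week 0 decays geometrically, then is cut off
impulse = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
print(geometric_adstock_np(impulse, alpha=0.5, l_max=4))
# -> weights 1, 0.5, 0.25, 0.125, then 0 from week 4 onward
```

With `normalize=True` the same impulse produces weights that sum to 1, so the adstocked series keeps the scale of the input spend.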

**LogisticSaturation** — models diminishing returns from increased spend:
$$f(x) = \beta \cdot \left(2 \cdot \sigma(\lambda x) - 1\right)$$
where $\sigma$ is the logistic sigmoid, $\lambda$ controls steepness (higher = faster saturation), and $\beta$ scales the channel's overall contribution: $f(x)$ rises from 0 at $x = 0$ toward the asymptote $\beta$ as $x$ grows. PyMC-Marketing MaxAbsScales channel spend to $[0, 1]$ before the adstock step, so the saturation acts on the adstocked, scaled series.
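The sketch below checks numerically that the saturation form above equals $(1 - e^{-\lambda x}) / (1 + e^{-\lambda x}) = \tanh(\lambda x / 2)$, and shows max-abs scaling of a non-negative spend column. The helper names `logistic_saturation` and `max_abs_scale` are hypothetical; only NumPy and SciPy are used.

```python
import numpy as np
from scipy.special import expit


def logistic_saturation(x, lam, beta=1.0):
    """beta * (2 * sigmoid(lam * x) - 1): 0 at x = 0, approaches beta as x grows."""
    return beta * (2.0 * expit(lam * x) - 1.0)


def max_abs_scale(values):
    """Divide by the largest absolute value; non-negative spend maps into [0, 1]."""
    values = np.asarray(values, dtype=float)
    return values / np.abs(values).max()


spend = np.array([0.0, 120.0, 480.0, 960.0])  # illustrative weekly spend
x = max_abs_scale(spend)  # -> 0, 0.125, 0.5, 1.0

lam = 3.0
same_curve = np.tanh(lam * x / 2.0)
alt_form = (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))
print(np.allclose(logistic_saturation(x, lam), same_curve))  # True
print(np.allclose(logistic_saturation(x, lam), alt_form))  # True
```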

In [None]:
# Load weekly data
df = load_mmm_weekly_data()
config = ModelConfig()
feature_cols = [config.date_column, *config.channel_columns, *config.control_columns]
x = df[feature_cols]
y = df[config.target_column]

# Load the most recent saved model (the lexicographic sort assumes timestamped filenames)
model_dir = OUTPUTS_DIR / "models"
model_files = sorted(model_dir.glob("mmm_fit_*.nc"))
if not model_files:
    raise FileNotFoundError(
        "No saved model found in outputs/models/. Run notebook 02 first."
    )

model_path = model_files[-1]
print(f"Loading model: {model_path.name}")
mmm = MMM.load(str(model_path))
idata = mmm.idata
print("Model loaded.")

In [None]:
# Model class and transformation types
print("=" * 55)
print("MODEL CLASS & TRANSFORMATIONS")
print("=" * 55)
print(f"Model class:       {type(mmm).__name__}")
print(f"Adstock type:      {type(mmm.adstock).__name__} (l_max={mmm.adstock.l_max})")
print(f"Saturation type:   {type(mmm.saturation).__name__}")
print()
print(f"Channel columns ({len(mmm.channel_columns)}):")
for ch in mmm.channel_columns:
    print(f"  {ch}")
print()
print(f"Control columns ({len(mmm.control_columns)}):")
for ctrl in mmm.control_columns:
    print(f"  {ctrl}")
print()
print("Sampling hyperparameters:")
print(f"  adstock_max_lag:   {config.adstock_max_lag}")
print(f"  chains:            {config.chains}")
print(f"  draws:             {config.draws}")
print(f"  tune:              {config.tune}")
print(f"  target_accept:     {config.target_accept}")
print("  init:              advi+adapt_diag")

In [None]:
# Prior distributions summary table
model_config = config.get_model_config()

prior_rows = []
for param_name, prior_obj in model_config.items():
    prior_rows.append({"Parameter": param_name, "Prior": str(prior_obj)})

prior_df = pd.DataFrame(prior_rows).set_index("Parameter")
print("Prior distributions:")
print(prior_df.to_string())
print()
print(
    "Note: All channel inputs and the target are MaxAbsScaled to [0, 1] internally.\n"
    "Priors are calibrated for this scaled space."
)

---
## 3. Convergence Diagnostics

MCMC inference is only valid when the sampler has converged: the chains have mixed and are all sampling from the same stationary distribution. Three standard metrics are evaluated:

| Metric | Threshold | Meaning if violated |
|--------|-----------|--------------------|
| **R-hat** | < 1.01 | Chains disagree; the sampler has not converged (often a sign of weak identification or multimodality) |
| **ESS (bulk)** | > 400 | Too few effectively independent draws; posterior means and intervals are noisy |
| **Divergences** | = 0 | The sampler hit high-curvature regions it cannot explore reliably; estimates may be biased |

Investigate any failure before interpreting channel contributions.
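For intuition about the R-hat row, here is the classic split-R-hat statistic computed with NumPy. It is a simplified version: ArviZ's `az.rhat` (used by `az.summary`) defaults to the rank-normalized split R-hat, which is more robust to heavy tails, so its values will differ slightly. `split_rhat` is a hypothetical helper.

```python
import numpy as np


def split_rhat(draws):
    """Classic split R-hat for one scalar parameter.

    draws: array of shape (n_chains, n_draws). Each chain is split in half so
    that within-chain drift also inflates the statistic.
    """
    draws = np.asarray(draws, dtype=float)
    half = draws.shape[1] // 2
    # Treat the first and second half of every chain as separate chains
    chains = np.concatenate([draws[:, :half], draws[:, half : 2 * half]], axis=0)
    n = chains.shape[1]
    chain_means = chains.mean(axis=1)
    within = chains.var(axis=1, ddof=1).mean()  # W: mean within-chain variance
    between = n * chain_means.var(ddof=1)  # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n  # pooled variance estimate
    return np.sqrt(var_hat / within)


rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))  # four chains sampling the same target
stuck = mixed + np.array([[0.0], [0.0], [0.0], [3.0]])  # one chain in another mode
print(f"mixed chains: {split_rhat(mixed):.3f}")  # close to 1.00
print(f"stuck chain:  {split_rhat(stuck):.3f}")  # well above 1.01
```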

In [None]:
diag = check_convergence(idata)

print("=" * 55)
print("CONVERGENCE DIAGNOSTIC SUMMARY")
print("=" * 55)
print(f"Overall:      {'PASSED' if diag.passed else 'FAILED'}")
print()
print(
    f"R-hat         max={diag.max_rhat:.4f}  threshold < {RHAT_THRESHOLD}  "
    f"{'OK' if diag.rhat_ok else 'FAIL'}"
)
print(
    f"ESS (bulk)    min={diag.min_ess:.0f}     threshold > {ESS_THRESHOLD}   "
    f"{'OK' if diag.ess_ok else 'FAIL'}"
)
print(
    f"Divergences   {diag.divergences}           threshold = 0      "
    f"{'OK' if diag.divergences == 0 else 'FAIL'}"
)
print()
print(f"Message: {diag.summary}")

In [None]:
# az.summary() for structural parameters only: skip every posterior variable
# indexed by date (mu, channel_contributions and other time-varying deterministics)
structural_vars = [
    name for name, da in idata.posterior.data_vars.items() if "date" not in da.dims
]
param_summary = az.summary(idata, var_names=structural_vars)
param_summary_sorted = param_summary.sort_values("r_hat", ascending=False)

display_cols = ["mean", "sd", "hdi_3%", "hdi_97%", "ess_bulk", "r_hat"]

print(f"Posterior variables in idata:     {len(idata.posterior.data_vars)}")
print(f"Structural parameter entries:     {len(param_summary)}")
print(
    f"With R-hat > {RHAT_THRESHOLD}:            "
    f"{(param_summary['r_hat'] > RHAT_THRESHOLD).sum()}"
)
print(
    f"With ESS (bulk) < {ESS_THRESHOLD}:         "
    f"{(param_summary['ess_bulk'] < ESS_THRESHOLD).sum()}"
)
print()
print("Parameter summary (sorted by R-hat descending):")
param_summary_sorted[display_cols]

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# R-hat bar chart
rhat_values = param_summary_sorted["r_hat"]
rhat_colors = ["tomato" if r > RHAT_THRESHOLD else "steelblue" for r in rhat_values]
rhat_values.plot.bar(ax=axes[0], color=rhat_colors, edgecolor="black", linewidth=0.4)
axes[0].axhline(
    RHAT_THRESHOLD,
    color="red",
    linestyle="--",
    linewidth=1.5,
    label=f"Threshold ({RHAT_THRESHOLD})",
)
axes[0].set_title("R-hat per Structural Parameter\n(red bars = failed threshold)")
axes[0].set_ylabel("R-hat")
axes[0].tick_params(axis="x", rotation=90, labelsize=7)
axes[0].legend()

# ESS bar chart
ess_values = param_summary_sorted["ess_bulk"]
ess_colors = ["tomato" if e < ESS_THRESHOLD else "steelblue" for e in ess_values]
ess_values.plot.bar(ax=axes[1], color=ess_colors, edgecolor="black", linewidth=0.4)
axes[1].axhline(
    ESS_THRESHOLD,
    color="red",
    linestyle="--",
    linewidth=1.5,
    label=f"Threshold ({ESS_THRESHOLD})",
)
axes[1].set_title("ESS (bulk) per Structural Parameter\n(red bars = failed threshold)")
axes[1].set_ylabel("ESS (bulk)")
axes[1].tick_params(axis="x", rotation=90, labelsize=7)
axes[1].legend()

plt.suptitle(
    "Convergence Metrics by Parameter",
    fontsize=13,
    y=1.02,
)
plt.tight_layout()
plt.show()

---
## 4. Trace Plots

Trace plots are the primary visual convergence diagnostic. Each row shows two panels per variable (with `compact=True`, all channels of a vector-valued parameter are overlaid in one row):

- **Left (KDE):** Marginal posterior density per chain. Overlapping KDEs from all chains confirm they explored the same distribution.
- **Right (trace):** Sample values over MCMC iterations. A converged chain looks like a **fuzzy caterpillar** — no trends, drift, or sticking. Chains that drift or separate indicate non-convergence.

In [None]:
key_params = ["intercept", "adstock_alpha", "saturation_lam", "saturation_beta"]

az.plot_trace(idata, var_names=key_params, compact=True, figsize=(13, 10))
plt.suptitle(
    "Trace Plots — Key Parameters\n"
    "LEFT: KDE per chain should overlap | RIGHT: trace should look like a fuzzy caterpillar",
    y=1.02,
    fontsize=11,
)
plt.tight_layout()
plt.show()

---
## 5. Posterior Predictive Check

A posterior predictive check (PPC) tests whether the fitted model can reproduce the observed data. Replicated series are drawn from the posterior predictive distribution $p(y^{\text{rep}} \mid y)$ and compared with the actual `total_gmv` time series. A few weeks outside a high-probability band are expected by chance; systematic runs of observations outside the band indicate that the model is misspecified or that the chains have not converged.
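One way to make "outside the predictive bands" concrete is to count how often the observed series lands inside the central 94% predictive interval; for a well-calibrated model roughly 94% of weeks should be covered. The sketch uses simulated draws in place of `idata.posterior_predictive`, equal-tailed intervals rather than HDIs, and a hypothetical helper `interval_coverage`; the shape `(n_samples, n_weeks)` is an assumption.

```python
import numpy as np


def interval_coverage(y_obs, y_rep, prob=0.94):
    """Fraction of observations inside the central `prob` interval of y_rep.

    y_obs: shape (n_weeks,); y_rep: predictive draws, shape (n_samples, n_weeks).
    """
    tail = (1.0 - prob) / 2.0
    lower = np.quantile(y_rep, tail, axis=0)
    upper = np.quantile(y_rep, 1.0 - tail, axis=0)
    inside = (y_obs >= lower) & (y_obs <= upper)
    return inside.mean()


rng = np.random.default_rng(42)
n_samples, n_weeks = 4000, 52
truth = 100 + 10 * np.sin(np.arange(n_weeks) / 4)  # illustrative weekly signal
y_obs = truth + rng.normal(0, 5, n_weeks)

good_rep = truth + rng.normal(0, 5, (n_samples, n_weeks))  # correct noise scale
overconfident_rep = truth + rng.normal(0, 1, (n_samples, n_weeks))  # bands too narrow

print(f"well specified: {interval_coverage(y_obs, good_rep):.2f}")
print(f"overconfident:  {interval_coverage(y_obs, overconfident_rep):.2f}")
```

Coverage far below the nominal level points to bands that are too narrow; coverage near 100% with very wide bands suggests the model is not informative.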

In [None]:
# Sample posterior predictive if not already present in idata
if "posterior_predictive" not in list(idata.groups()):
    print("Sampling posterior predictive (this may take a moment)...")
    sample_posterior_predictive(mmm, x)
    idata = mmm.idata
    print("Done.")
else:
    print("Posterior predictive already present in idata — skipping resampling.")

fig = mmm.plot_posterior_predictive(original_scale=True)
if hasattr(fig, "suptitle"):
    fig.suptitle(
        "Posterior Predictive Check\n"
        "Observed series should fall within the posterior predictive bands",
        fontsize=12,
    )
plt.tight_layout()
plt.show()

---
## 6. Prior vs Posterior

Prior vs posterior plots reveal how much the data has updated the prior for each parameter. When the posterior closely tracks the prior (same shape, similar location), the likelihood is weak: the data is not strongly informative about that parameter. When the posterior is noticeably narrower or shifted relative to the prior, the data is driving inference for that parameter.

With only 12 effective monthly media patterns in the data, we expect many parameters to remain close to their priors.
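To put a number on "close to the prior", one common summary is posterior contraction, $1 - \sigma^2_{\text{post}} / \sigma^2_{\text{prior}}$: values near 0 mean the data barely narrowed the prior, values near 1 mean the data dominates, and negative values mean the posterior is wider than the prior. The sketch computes the variances of the three priors quoted below analytically (Gamma in PyMC's shape/rate parameterization). The posterior standard deviations are placeholders, not results from this model; in practice take them from the `sd` column of `az.summary`.

```python
import math


def beta_var(a, b):
    return a * b / ((a + b) ** 2 * (a + b + 1))


def gamma_var(alpha, rate):
    return alpha / rate**2


def half_normal_var(sigma):
    return sigma**2 * (1 - 2 / math.pi)


def contraction(posterior_sd, prior_var):
    """1 - posterior variance / prior variance (0 = no learning, 1 = fully data-driven)."""
    return 1 - posterior_sd**2 / prior_var


prior_vars = {
    "adstock_alpha ~ Beta(1, 3)": beta_var(1, 3),  # 0.0375
    "saturation_lam ~ Gamma(3, 1)": gamma_var(3, 1),  # 3.0
    "saturation_beta ~ HalfNormal(2)": half_normal_var(2),  # about 1.453
}

# Placeholder posterior sds for illustration only (replace with az.summary(...)["sd"])
example_posterior_sd = {
    "adstock_alpha ~ Beta(1, 3)": 0.18,
    "saturation_lam ~ Gamma(3, 1)": 1.6,
    "saturation_beta ~ HalfNormal(2)": 0.3,
}

for name, var in prior_vars.items():
    print(
        f"{name:<34} prior sd={math.sqrt(var):.3f}  "
        f"contraction={contraction(example_posterior_sd[name], var):.2f}"
    )
```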

In [None]:
# Adstock decay: Beta(1, 3) prior concentrates mass toward 0 (short-lived effects)
# If the posterior shifts right (toward 1), data supports longer carryover
fig = mmm.plot_prior_vs_posterior("adstock_alpha", alphabetical_sort=True)
plt.suptitle(
    "Adstock Decay (alpha): Prior vs Posterior\n"
    "Prior: Beta(1, 3). Posterior shift toward 1 = longer-lasting ad effect",
    fontsize=11,
    y=1.02,
)
plt.tight_layout()
plt.show()

In [None]:
# Saturation steepness: Gamma(3, 1) prior has mode at 2, mean at 3
# Higher lam = faster saturation = diminishing returns kick in at lower spend levels
fig = mmm.plot_prior_vs_posterior("saturation_lam", alphabetical_sort=True)
plt.suptitle(
    "Saturation Lambda (steepness): Prior vs Posterior\n"
    "Prior: Gamma(3, 1). Higher lam = faster saturation curve",
    fontsize=11,
    y=1.02,
)
plt.tight_layout()
plt.show()

In [None]:
# Saturation magnitude: HalfNormal(sigma=2) prior; loosely constrains channel contribution scale
# Beta is the asymptotic maximum contribution of a channel (in MaxAbsScaled target units)
fig = mmm.plot_prior_vs_posterior("saturation_beta", alphabetical_sort=True)
plt.suptitle(
    "Saturation Beta (channel effect magnitude): Prior vs Posterior\n"
    "Prior: HalfNormal(sigma=2). Sets the asymptotic maximum contribution per channel",
    fontsize=11,
    y=1.02,
)
plt.tight_layout()
plt.show()

---
## 7. Channel Decomposition

MMM decomposition breaks total GMV into additive components: baseline (intercept), each marketing channel's contribution, and control variable effects. The three plots below show decomposition from three perspectives: cumulative (waterfall), over time (contributions timeline), and as a share with uncertainty (HDI).
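The decomposition rests on a simple identity: in scaled space the model mean is the sum of its components, and a channel's share is its summed contribution divided by the total media contribution. A toy NumPy sketch with made-up component arrays and hypothetical channel names (nothing here comes from the fitted model). Assuming the target was MaxAbsScaled as noted in Section 2, mapping back to GMV units is a multiplication by the target's maximum absolute value.

```python
import numpy as np

rng = np.random.default_rng(7)
n_weeks = 52
channels = ["tv", "digital", "sponsorship"]  # hypothetical channel names

# Illustrative components in MaxAbsScaled target units
intercept = np.full(n_weeks, 0.40)
channel_contrib = rng.uniform(0.0, 0.10, size=(n_weeks, len(channels)))
control_contrib = rng.normal(0.0, 0.02, size=n_weeks)

# Additive model mean in scaled space
mu_scaled = intercept + channel_contrib.sum(axis=1) + control_contrib

# Each channel's share of the total media contribution over the period
totals = channel_contrib.sum(axis=0)
shares = totals / totals.sum()

# Back to GMV units: undo the max-abs scaling of the target
target_max_abs = 2.5e7  # hypothetical max |total_gmv|
mu_original = mu_scaled * target_max_abs

for name, share in zip(channels, shares):
    print(f"{name:<12} share of media contribution: {share:.1%}")
print(f"mean weekly fitted GMV: {mu_original.mean():,.0f}")
```

Repeating the share calculation for every posterior draw, rather than once on point estimates, yields the distribution from which share HDIs are computed.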

In [None]:
# Waterfall chart: each component's posterior-mean contribution summed over the period, stacked cumulatively
fig = mmm.plot_waterfall_components_decomposition(original_scale=True, figsize=(14, 7))
plt.tight_layout()
plt.show()

In [None]:
# Component contributions (channels, controls, baseline) over time with HDI bands
fig = mmm.plot_components_contributions()
if hasattr(fig, "suptitle"):
    fig.suptitle(
        "Component Contributions Over Time (94% HDI)",
        fontsize=12,
    )
plt.tight_layout()
plt.show()

In [None]:
# Channel contribution share with 94% credible intervals
# Wide HDI bands reflect genuine posterior uncertainty given the small effective sample size
fig = mmm.plot_channel_contribution_share_hdi(hdi_prob=0.94)
plt.tight_layout()
plt.show()

---
## 8. Adstock Decay Curves

GeometricAdstock applies exponential decay over time. For a channel with decay parameter $\alpha$, the fraction of the original spend effect that remains $k$ weeks later is $\alpha^k$. A channel with $\alpha = 0.5$ retains 50% of its effect after one week, 25% after two weeks, and 12.5% after three; with `adstock_max_lag = 4`, PyMC-Marketing applies lags 0 through 3, so the effect is truncated to zero from week 4 onward. The shaded bands are obtained by raising the 94% HDI endpoints of $\alpha$ to the power $k$; because $\alpha^k$ is monotone in $\alpha$, they show how tightly the data constrains each channel's decay rate.
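The arithmetic behind the decay curves, as a small sketch: the weights $\alpha^k$ for $k = 0, \dots, L-1$ and the implied half-life $\ln 0.5 / \ln \alpha$ (the lag at which half the effect remains, ignoring truncation). The helpers are hypothetical and assume lags 0 through `l_max - 1`, as in PyMC-Marketing's `GeometricAdstock`.

```python
import math


def decay_weights(alpha, l_max):
    """Fraction of a week-0 spend effect remaining at lags 0..l_max-1."""
    return [alpha**k for k in range(l_max)]


def half_life_weeks(alpha):
    """Lag (in weeks, possibly fractional) at which alpha**k falls to 0.5."""
    return math.log(0.5) / math.log(alpha)


print(decay_weights(0.5, 4))  # [1.0, 0.5, 0.25, 0.125]
for alpha in (0.2, 0.5, 0.8):
    print(f"alpha={alpha}: half-life = {half_life_weeks(alpha):.2f} weeks")
```

If a channel's half-life approaches `adstock_max_lag`, a meaningful part of its carryover is being cut off by the lag window.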

In [None]:
summary_alpha = az.summary(idata, var_names=["adstock_alpha"])

print("Adstock decay posterior estimates (adstock_alpha):")
print(
    f"{'Channel':<15} {'mean':>8} {'sd':>8} {'hdi_3%':>8} {'hdi_97%':>9} "
    f"{'ess_bulk':>10} {'r_hat':>7}  Interpretation"
)
print("-" * 90)

for param, row in summary_alpha.iterrows():
    channel = param.replace("adstock_alpha[", "").rstrip("]")
    alpha = row["mean"]
    remaining_wk1 = alpha * 100
    remaining_wk2 = alpha**2 * 100
    print(
        f"{channel:<15} {alpha:>8.3f} {row['sd']:>8.3f} {row['hdi_3%']:>8.3f} "
        f"{row['hdi_97%']:>9.3f} {row['ess_bulk']:>10.0f} {row['r_hat']:>7.3f}  "
        f"{remaining_wk1:.0f}% at wk+1, {remaining_wk2:.0f}% at wk+2"
    )

print()
print(f"adstock_max_lag = {config.adstock_max_lag} weeks")

In [None]:
k_max = config.adstock_max_lag
weeks = np.arange(k_max)  # lags 0 .. k_max - 1; GeometricAdstock drops lag k_max and beyond
colors = plt.cm.Set2(np.linspace(0, 1, len(config.channel_columns)))

fig, ax = plt.subplots(figsize=(10, 5))

for i, channel in enumerate(config.channel_columns):
    param = f"adstock_alpha[{channel}]"
    alpha_mean = summary_alpha.loc[param, "mean"]
    alpha_lo = summary_alpha.loc[param, "hdi_3%"]
    alpha_hi = summary_alpha.loc[param, "hdi_97%"]

    decay_mean = alpha_mean**weeks * 100
    decay_lo = alpha_lo**weeks * 100
    decay_hi = alpha_hi**weeks * 100

    ax.plot(
        weeks,
        decay_mean,
        color=colors[i],
        linewidth=2,
        marker="o",
        label=f"{channel} (alpha={alpha_mean:.2f})",
    )
    ax.fill_between(weeks, decay_lo, decay_hi, alpha=0.18, color=colors[i])

ax.set_xlabel("Weeks after spend")
ax.set_ylabel("% of original effect remaining")
ax.set_title(
    f"Adstock Decay Curves — Posterior Mean with 94% HDI\n"
    f"(GeometricAdstock, max_lag={k_max} weeks)"
)
ax.set_ylim(0, 105)
ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left")
plt.tight_layout()
plt.show()

---
## 9. Saturation Response Curves

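LogisticSaturation models diminishing returns from increased channel spend. The response curve plots show, for each channel, the predicted contribution (in scaled target units) as a function of scaled spend (0 = no spend, 1 = maximum observed spend). The saturation is actually applied to the adstocked series, whose range can differ from raw scaled spend, so read the horizontal axis as approximate. The formula is $f(x) = \beta \cdot (2 \sigma(\lambda x) - 1)$, where $\sigma$ is the logistic sigmoid. The bands vary only $\lambda$ (between its 94% HDI endpoints) while holding $\beta$ at its posterior mean, so they understate the full uncertainty in each curve.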
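Two quantities that are often easier to communicate than $\lambda$ follow directly from the formula. Solving $2\sigma(\lambda x) - 1 = p$ gives the scaled spend at which a channel reaches a fraction $p$ of its asymptote $\beta$, $x_p = \ln\frac{1+p}{1-p} / \lambda$, and differentiating gives the marginal response $f'(x) = 2\beta\lambda\,\sigma(\lambda x)\,(1 - \sigma(\lambda x))$. A sketch with illustrative parameter values (not fitted estimates) and hypothetical helper names:

```python
import math

from scipy.special import expit


def saturation(x, lam, beta):
    return beta * (2 * expit(lam * x) - 1)


def spend_for_fraction(p, lam):
    """Scaled spend at which the curve reaches p * beta (0 < p < 1)."""
    return math.log((1 + p) / (1 - p)) / lam


def marginal_response(x, lam, beta):
    """Derivative of beta * (2 * sigmoid(lam * x) - 1) with respect to x."""
    s = expit(lam * x)
    return 2 * beta * lam * s * (1 - s)


lam, beta = 3.0, 0.4  # illustrative values in scaled units
x_half = spend_for_fraction(0.5, lam)
print(f"50% of asymptote reached at scaled spend {x_half:.3f}")
print(f"curve value there: {saturation(x_half, lam, beta):.3f} (= 0.5 * beta)")
print(f"marginal response at x=0: {marginal_response(0.0, lam, beta):.3f} (= beta * lam / 2)")
print(f"marginal response at x=1: {marginal_response(1.0, lam, beta):.3f}")
```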

In [None]:
summary_lam = az.summary(idata, var_names=["saturation_lam"])
summary_beta = az.summary(idata, var_names=["saturation_beta"])

print("Saturation parameter posterior estimates:")
print()
print("saturation_lam (steepness — higher = faster saturation):")
display_cols = ["mean", "sd", "hdi_3%", "hdi_97%", "ess_bulk", "r_hat"]
print(summary_lam[display_cols].to_string())
print()
print("saturation_beta (effect magnitude — scales max channel contribution):")
print(summary_beta[display_cols].to_string())

In [None]:
x_spend = np.linspace(0, 1, 200)  # MaxAbsScaled spend: 0 = no spend, 1 = max observed
n_channels = len(config.channel_columns)
colors = plt.cm.Set2(np.linspace(0, 1, n_channels))

fig, axes = plt.subplots(1, n_channels, figsize=(14, 4), sharey=False)

for i, channel in enumerate(config.channel_columns):
    lam_param = f"saturation_lam[{channel}]"
    beta_param = f"saturation_beta[{channel}]"

    lam_mean = summary_lam.loc[lam_param, "mean"]
    lam_lo = summary_lam.loc[lam_param, "hdi_3%"]
    lam_hi = summary_lam.loc[lam_param, "hdi_97%"]
    beta_mean = summary_beta.loc[beta_param, "mean"]

    y_mean = beta_mean * (2 * expit(lam_mean * x_spend) - 1)
    y_lo = beta_mean * (2 * expit(lam_lo * x_spend) - 1)
    y_hi = beta_mean * (2 * expit(lam_hi * x_spend) - 1)

    axes[i].plot(x_spend, y_mean, color=colors[i], linewidth=2, label="Posterior mean")
    axes[i].fill_between(
        x_spend, y_lo, y_hi, alpha=0.25, color=colors[i], label="94% HDI"
    )
    axes[i].set_title(
        f"{channel}\nlam={lam_mean:.2f}, beta={beta_mean:.2f}", fontsize=9
    )
    axes[i].set_xlabel("Scaled spend (0=none, 1=max observed)")
    axes[i].set_ylabel("Contribution (scaled)")
    axes[i].legend(fontsize=7)

plt.suptitle(
    "Saturation Response Curves by Channel\n"
    "f(x) = beta * (2 * sigmoid(lam * x) - 1) — LogisticSaturation, inputs MaxAbsScaled",
    fontsize=11,
)
plt.tight_layout()
plt.show()

print(
    "Steeper curves (higher lam) = diminishing returns kick in at lower spend levels.\n"
    "Wide HDI bands = the data is not strongly informative about the saturation shape."
)
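The curves above vary only $\lambda$ and plug in HDI endpoints, which ignores uncertainty in $\beta$ and any posterior correlation between the two. An alternative, sketched here, is to evaluate the curve for every joint posterior draw and take pointwise quantiles. The draws below are simulated stand-ins (correlated only for illustration); with the fitted model you would take one channel's values from `idata.posterior["saturation_lam"]` and `idata.posterior["saturation_beta"]`, flattened over chain and draw.

```python
import numpy as np
from scipy.special import expit

rng = np.random.default_rng(1)
n_draws = 4000

# Simulated joint draws for one channel (stand-ins for the real posterior)
lam_draws = rng.gamma(shape=3.0, scale=1.0, size=n_draws)
beta_draws = np.abs(0.5 - 0.05 * (lam_draws - 3.0) + rng.normal(0, 0.1, n_draws))

x_spend = np.linspace(0, 1, 200)
# One response curve per draw: shape (n_draws, n_points)
curves = beta_draws[:, None] * (2 * expit(lam_draws[:, None] * x_spend[None, :]) - 1)

# Pointwise central 94% band across draws
lower, median, upper = np.quantile(curves, [0.03, 0.5, 0.97], axis=0)
print(f"band width at max observed spend: {upper[-1] - lower[-1]:.3f}")
```

Because each curve is computed from a matched $(\lambda, \beta)$ pair, the band reflects joint uncertainty in both parameters.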

---
## 10. Model Limitations

The following limitations should be understood when interpreting any results from this model:

- **Effective N = 12.** Media spend (from `MediaInvestment.csv`) is only available at monthly granularity. All weeks within a month receive identical channel spend values via pro-rata distribution, so the channel parameters are informed by only 12 distinct media patterns, far fewer than the 52 weekly rows suggest. With 17 structural parameters and 12 effective observations, the posterior is prior-dominated for most channel-level parameters.

- **Prior-dominated posteriors.** Because the likelihood is weak relative to the prior, the prior vs posterior plots (Section 6) will often show minimal updating. Posterior estimates for adstock decay, saturation steepness, and channel beta reflect our prior beliefs as much as the data. Results should not be treated as data-driven point estimates.

- **Monthly spend distributed pro-rata across weeks.** Each week within a month receives an equal share of that month's total spend. This artificially smooths within-month spend variation and prevents the model from attributing GMV fluctuations within a month to media activity. Week-level attribution below the monthly grain is unreliable.

- **Single market, single year.** The dataset covers one year (Jul 2015 – Jun 2016) for a single DT Mart market. The model cannot generalize across markets, geographies, or time periods outside this window. Structural shifts in channel effectiveness cannot be detected.

- **No competitor or macroeconomic data.** External factors such as competitor promotions, price changes, and macroeconomic conditions are not included. Any effect they have on GMV will be absorbed by the intercept or the control variables, or incorrectly attributed to marketing channels.
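The "Effective N = 12" point can be checked directly on the modelling table by counting the distinct rows of the channel-spend matrix. A self-contained sketch with simulated monthly spend split equally over a simplified calendar (the calendar, spend values, and channel count are illustrative); on the real data you would pass `df[config.channel_columns].to_numpy()` to `np.unique`.

```python
import numpy as np

rng = np.random.default_rng(3)
n_weeks, n_months, n_channels = 52, 12, 5

# Map each week to a month (4 or 5 weeks per month): a simplified calendar
week_month = np.arange(n_weeks) * n_months // n_weeks
weeks_per_month = np.bincount(week_month, minlength=n_months)

# Illustrative monthly spend per channel, split equally across that month's weeks
monthly_spend = rng.uniform(1e5, 1e6, size=(n_months, n_channels))
weekly_spend = monthly_spend[week_month] / weeks_per_month[week_month, None]

distinct_patterns = np.unique(weekly_spend, axis=0).shape[0]
print(f"weekly rows: {n_weeks}, distinct spend patterns: {distinct_patterns}")
```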