# 04 - Channel Decomposition & Contribution Analysis

MMM **decomposition** breaks the model's fitted total GMV (which approximates the observed total) into the
additive contributions of each marketing channel, the baseline, and the control variables.
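
Because the model is additive, the component series sum back to the fitted GMV week by week. Below is a minimal sketch of that bookkeeping. It uses made-up weekly contributions for a hypothetical intercept, two channels, and one control. These numbers are not output from this model.

```python
import numpy as np

# Made-up weekly contributions for 4 weeks (illustrative only, not model output)
components = {
    "intercept": np.array([100.0, 100.0, 100.0, 100.0]),
    "TV": np.array([20.0, 35.0, 30.0, 10.0]),
    "Digital": np.array([15.0, 12.0, 18.0, 25.0]),
    "NPS": np.array([-5.0, 2.0, 0.0, 4.0]),  # controls can contribute negatively
}

# Additivity: the fitted value in each week is the sum of all component series
prediction = sum(components.values())

# Each component's share of total fitted GMV over the whole period
total = prediction.sum()
shares = {name: float(series.sum() / total) for name, series in components.items()}

print(prediction)  # [130. 149. 148. 139.]
print(shares)
```

The shares of all components sum to 1. An individual control's share can still be negative if it mostly pulls GMV down.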

This notebook covers:
1. Posterior predictive check — does the model reproduce observed data?
2. Waterfall decomposition — what drives total GMV?
3. Channel contributions over time — when did each channel contribute most?
4. Contribution share with uncertainty (94% HDI)
5. Adstock decay — how long does each channel's effect last?
6. Saturation — at what spend level does each channel plateau?
7. Summary table saved to `outputs/tables/`

> **Note on convergence:** Results depend on whether the fitted model converged.
> See notebook 03 for diagnostics. If the model has not fully converged, its posterior samples
> are unreliable. With this little data, the posteriors are also largely prior-dominated,
> so interpret them as illustrative, not definitive.

In [None]:
import warnings
from datetime import datetime

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

warnings.filterwarnings("ignore")
sns.set_theme(style="whitegrid")
%matplotlib inline

from pymc_marketing.mmm import MMM

from mmm_demo.config import OUTPUTS_DIR, ModelConfig
from mmm_demo.data import load_mmm_weekly_data
from mmm_demo.diagnostics import ESS_THRESHOLD, RHAT_THRESHOLD, check_convergence
from mmm_demo.model import sample_posterior_predictive

---
## 1. Load Model and Data

> **Prerequisite:** Run notebooks 02 and 03 first.
> This notebook loads the last `mmm_fit_*.nc` file, by sorted filename, from `outputs/models/`.
> After running notebook 03, the v2 model will be loaded automatically.

In [None]:
df = load_mmm_weekly_data()
config = ModelConfig()
feature_cols = [config.date_column, *config.channel_columns, *config.control_columns]
X = df[feature_cols]
y = df[config.target_column]

# Load most recent saved model
model_dir = OUTPUTS_DIR / "models"
model_files = sorted(model_dir.glob("mmm_fit_*.nc"))
if not model_files:
    raise FileNotFoundError("No saved model found. Run notebook 02 first.")

model_path = model_files[-1]
print(f"Loading: {model_path.name}")
mmm = MMM.load(str(model_path))
idata = mmm.idata
# .sizes avoids xarray's FutureWarning for dict-style access on Dataset.dims
print(
    f"Loaded. Chains: {idata.posterior.sizes['chain']}, Draws: {idata.posterior.sizes['draw']}"
)

---
## 2. Diagnostic Gate Check

In [None]:
diag = check_convergence(idata)

print(f"Convergence: {'PASSED' if diag.passed else 'FAILED'}")
print(f"  Max R-hat:   {diag.max_rhat:.4f}  (threshold < {RHAT_THRESHOLD})")
print(f"  Min ESS:     {diag.min_ess:.0f}   (threshold > {ESS_THRESHOLD})")
print(f"  Divergences: {diag.divergences}")
print()

if diag.passed:
    print("All convergence checks passed. Results are reliable.")
else:
    print("WARNING: Convergence checks FAILED.")
    print("Proceeding for demo purposes — results are illustrative.")
    print("Estimates may not reflect the true posterior and are likely dominated by priors.")
    print("See notebook 03 for root cause analysis and remediation.")

---
## 3. Posterior Predictive Check

A posterior predictive check verifies that the fitted model can **reproduce the observed data**.
If the observed line falls consistently outside the posterior predictive bands, the model is
misspecified or has not converged.
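
A simple numeric companion to the plot is the empirical coverage: the fraction of observed weeks that fall inside a central predictive interval. Below is a minimal sketch that uses synthetic draws. The array shapes, the 52-week length, and the equal-tailed 94% interval are assumptions that mirror this notebook; they are not output of `plot_posterior_predictive`.

```python
import numpy as np

rng = np.random.default_rng(0)

n_draws, n_weeks = 2000, 52
# Synthetic stand-ins: posterior predictive draws (draws x weeks) and observations
mu = 100 + 10 * np.sin(np.linspace(0, 4 * np.pi, n_weeks))
pp_draws = rng.normal(loc=mu, scale=5.0, size=(n_draws, n_weeks))
y_obs = rng.normal(loc=mu, scale=5.0)

# Central (equal-tailed) 94% predictive interval per week: 3rd and 97th percentiles
lower, upper = np.percentile(pp_draws, [3, 97], axis=0)

# Fraction of weeks whose observation lies inside its interval
inside = (y_obs >= lower) & (y_obs <= upper)
coverage = inside.mean()
print(f"Coverage of 94% interval: {coverage:.2f}")
```

If coverage is well below 0.94, the bands are too narrow or the predictive mean is off. If coverage is close to 1.0 but the bands are very wide, the check is passing only because the model is vague.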

In [None]:
# Sample posterior predictive if not already present
if "posterior_predictive" not in list(idata.groups()):
    print("Sampling posterior predictive...")
    sample_posterior_predictive(mmm, X)
    idata = mmm.idata
    print("Done.")
else:
    print("Posterior predictive already in idata.")

fig = mmm.plot_posterior_predictive(original_scale=True)
if hasattr(fig, "suptitle"):
    fig.suptitle(
        "Posterior Predictive Check\n"
        "(observed should fall within the predictive bands)",
        fontsize=12,
    )
plt.tight_layout()
plt.show()

---
## 4. Waterfall Decomposition

The waterfall plot shows how each component adds to (or subtracts from) total GMV.
Each bar represents one component's contribution summed over the modeling period, in original GMV units.
Each bar is stacked on the running total of the bars before it.
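
To make the waterfall geometry concrete, each bar starts where the previous one ended, so the bar bottoms are a running sum. Below is a minimal sketch with hypothetical component totals. The names, values, and ordering are illustrative, and the library's own plot may sort components differently.

```python
# Hypothetical contribution totals per component (illustrative values only)
components = [("intercept", 400.0), ("TV", 95.0), ("Digital", 70.0), ("NPS", -12.0)]

bars = []  # (name, bottom, height) for each bar
running = 0.0
for name, value in components:
    # A positive contribution rises from the running total; a negative one drops below it
    bottom = running if value >= 0 else running + value
    bars.append((name, bottom, abs(value)))
    running += value

total = running
for name, bottom, height in bars:
    print(f"{name:<10} bottom={bottom:7.1f} height={height:6.1f}")
print(f"{'total':<10} {total:7.1f}")
```

These `(bottom, height)` pairs are what you would pass to matplotlib's `ax.bar(..., bottom=...)` to draw the chart by hand.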

In [None]:
fig = mmm.plot_waterfall_components_decomposition(original_scale=True, figsize=(14, 7))
plt.tight_layout()
plt.show()

---
## 5. Channel Contributions Over Time

Shows how each channel's contribution varies week by week.
The shaded bands are the 94% highest density interval (HDI) — the uncertainty around each estimate.
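
The HDI is the narrowest interval that contains the requested share of posterior draws. For a unimodal posterior, you can find it by sorting the draws and sliding a fixed-size window over them. Below is a minimal sketch of that idea. The notebook itself relies on ArviZ, whose `az.hdi` may differ in edge-case handling.

```python
import numpy as np


def hdi_1d(draws: np.ndarray, hdi_prob: float = 0.94) -> tuple[float, float]:
    """Narrowest interval containing `hdi_prob` of the draws (assumes a unimodal posterior)."""
    x = np.sort(np.asarray(draws, dtype=float))
    n = len(x)
    # Index offset between the two ends of every candidate interval
    window = int(np.floor(hdi_prob * n))
    # Widths of all candidate intervals [x[i], x[i + window]]
    widths = x[window:] - x[: n - window]
    i = int(np.argmin(widths))
    return float(x[i]), float(x[i + window])


rng = np.random.default_rng(1)
skewed = rng.exponential(scale=1.0, size=10_000)

hdi = hdi_1d(skewed)
equal_tailed = tuple(np.percentile(skewed, [3, 97]))
print("HDI:         ", hdi)           # lower end hugs zero for a right-skewed posterior
print("Equal-tailed:", equal_tailed)  # shifted right and wider
```

For symmetric posteriors the two intervals nearly coincide. For skewed posteriors, such as contributions bounded at zero, the HDI is the shorter of the two.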

In [None]:
fig = mmm.plot_components_contributions()
if hasattr(fig, "suptitle"):
    fig.suptitle("Channel Contributions Over Time (with 94% HDI)", fontsize=12)
plt.tight_layout()
plt.show()

In [None]:
# Grouped breakdown: Offline vs Online vs Baseline vs Controls
fig = mmm.plot_grouped_contribution_breakdown_over_time(
    stack_groups={
        "Baseline": ["intercept"],
        "Offline": ["TV", "Sponsorship"],
        "Online": ["Digital", "Online"],
        "Controls": ["NPS", "total_Discount", "sale_days"],
    },
    original_scale=True,
    figsize=(14, 6),
)
plt.suptitle("Grouped Contribution Breakdown Over Time", fontsize=12)
plt.tight_layout()
plt.show()

---
## 6. Contribution Share with Uncertainty

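Channel contribution share with a **94% HDI** shows the uncertainty around each channel's
estimated share of the total contribution from all marketing channels. This is not a share of total GMV, which also includes the baseline and controls.
Wide HDIs indicate the model cannot reliably attribute GMV to specific channels, which is expected with a prior-dominated model.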
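
Compute the share interval from each posterior draw's share, not by dividing interval endpoints. Below is a minimal sketch with synthetic draws. The channel names and gamma-distributed totals are placeholders, not this model's posterior. The denominator here is the summed media contribution; change it if you want shares of total GMV.

```python
import numpy as np

rng = np.random.default_rng(2)
channels = ["TV", "Sponsorship", "Digital", "Online"]
n_draws = 4000

# Synthetic posterior draws of each channel's total contribution (draws x channels)
totals = rng.gamma(shape=[2.0, 3.0, 5.0, 1.5], scale=100.0, size=(n_draws, len(channels)))

# Share of total media contribution within each draw; every row sums to 1
shares = totals / totals.sum(axis=1, keepdims=True)

# Central 94% interval of each channel's share across draws
lo, hi = np.percentile(shares, [3, 97], axis=0)
for name, m, a, b in zip(channels, shares.mean(axis=0), lo, hi):
    print(f"{name:<12} mean={m:.2f}  94% interval=[{a:.2f}, {b:.2f}]")
```

Computing the share per draw keeps the uncertainties correlated. If one channel's share is high in a draw, the others must be lower in that same draw.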

In [None]:
fig = mmm.plot_channel_contribution_share_hdi(hdi_prob=0.94)
plt.tight_layout()
plt.show()

---
## 7. Mean Contributions Summary Table

In [None]:
contributions = mmm.compute_mean_contributions_over_time(original_scale=True)

# Average weekly contribution per component
avg_contributions = contributions.mean()
total_contrib = avg_contributions.sum()
share_pct = (avg_contributions / total_contrib * 100).round(2)

contrib_summary = pd.DataFrame(
    {
        "avg_weekly_contribution": avg_contributions.round(0),
        "share_%": share_pct,
    }
).sort_values("share_%", ascending=False)

print("Average weekly contribution by component:")
print(contrib_summary.to_string())
print(f"\nTotal (should be close to observed mean GMV ~ {y.mean():.0f}): {total_contrib:.0f}")

In [None]:
# Visualize contribution shares
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Bar chart
contrib_summary["avg_weekly_contribution"].plot.bar(
    ax=axes[0], edgecolor="black", color="steelblue"
)
axes[0].set_title("Average Weekly Contribution by Component")
axes[0].set_ylabel("GMV")
axes[0].tick_params(axis="x", rotation=45)

# Pie chart (channels only)
channel_contribs = contrib_summary[contrib_summary.index.isin(config.channel_columns)]
if len(channel_contribs) > 0:
    channel_contribs["avg_weekly_contribution"].plot.pie(
        ax=axes[1], autopct="%1.1f%%", startangle=90
    )
    axes[1].set_title("Channel Contribution Share (marketing only)")
    axes[1].set_ylabel("")

plt.tight_layout()
plt.show()

---
## 8. Adstock Decay

**Adstock** models the lagged carryover effect of advertising:
```
adstock_t = spend_t + alpha * adstock_{t-1}
```
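where `alpha ∈ (0, 1)` is the carryover (retention) rate: the fraction of last week's adstock that survives into this week.
Higher alpha → longer-lasting effect. In the fitted model, the carryover is cut off after a maximum lag (`config.adstock_max_lag` weeks).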
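
The recursion above, truncated at a maximum lag, is equivalent to convolving spend with weights `alpha**k`. Below is a minimal NumPy sketch of both forms. It is unnormalized; PyMC-Marketing's geometric adstock may normalize the weights, so its output scale can differ. The function names are hypothetical.

```python
import numpy as np


def adstock_recursive(spend: np.ndarray, alpha: float) -> np.ndarray:
    """Untruncated geometric adstock: a_t = x_t + alpha * a_{t-1}."""
    out = np.zeros(len(spend), dtype=float)
    carry = 0.0
    for t, x in enumerate(spend):
        carry = x + alpha * carry
        out[t] = carry
    return out


def adstock_truncated(spend: np.ndarray, alpha: float, l_max: int) -> np.ndarray:
    """Geometric adstock as a causal convolution with weights alpha**k, k = 0..l_max-1."""
    weights = alpha ** np.arange(l_max)
    # 'full' convolution, keeping the first len(spend) values so week t only sees weeks <= t
    return np.convolve(spend, weights)[: len(spend)]


spend = np.array([100.0, 0.0, 0.0, 0.0, 0.0, 0.0])
print(adstock_recursive(spend, alpha=0.5))           # halves each week: 100, 50, 25, 12.5, ...
print(adstock_truncated(spend, alpha=0.5, l_max=3))  # same, but zero after 3 weeks
```

A convenient summary of alpha is the half-life, `ln(0.5) / ln(alpha)` weeks. For example, alpha = 0.5 gives a half-life of exactly 1 week.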

The prior vs posterior comparison shows:
- If the posterior closely tracks the prior → data hasn't informed the parameter much
- If the posterior is sharper/shifted → the data has updated the prior
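
One way to put a number on "sharper than the prior" is posterior contraction, `1 - Var(posterior) / Var(prior)`. A value near 0 means the data barely moved the parameter; a value near 1 means the data dominate. Below is a sketch with synthetic Beta draws standing in for prior and posterior samples of one channel's `adstock_alpha`. The Beta(1, 3) prior is an assumption and not necessarily this model's prior.

```python
import numpy as np

rng = np.random.default_rng(3)

# Synthetic stand-ins for prior and posterior draws of one channel's alpha
prior_draws = rng.beta(1.0, 3.0, size=4000)
weak_posterior = rng.beta(1.5, 3.5, size=4000)      # barely updated by data
strong_posterior = rng.beta(30.0, 60.0, size=4000)  # sharply updated by data


def contraction(prior: np.ndarray, posterior: np.ndarray) -> float:
    """1 - Var(posterior) / Var(prior); close to 0 means prior-dominated."""
    return float(1.0 - posterior.var() / prior.var())


print(f"weak:   {contraction(prior_draws, weak_posterior):.2f}")
print(f"strong: {contraction(prior_draws, strong_posterior):.2f}")
```

In the notebook, you could apply the same calculation per channel to the `adstock_alpha` draws in `idata.prior` and `idata.posterior`, provided a prior group was sampled.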

In [None]:
fig = mmm.plot_prior_vs_posterior("adstock_alpha", alphabetical_sort=True)
plt.suptitle("Adstock Decay (alpha): Prior vs Posterior", fontsize=13)
plt.tight_layout()
plt.show()

In [None]:
# Posterior estimates for adstock_alpha with decay interpretation
summary_alpha = az.summary(idata, var_names=["adstock_alpha"])

print("Adstock decay posterior estimates:")
print(
    f"{'Channel':<15} {'alpha mean':>12} {'HDI 3%':>10} {'HDI 97%':>10}  Interpretation"
)
print("-" * 70)

for param, row in summary_alpha.iterrows():
    channel = param.replace("adstock_alpha[", "").rstrip("]")
    alpha = row["mean"]
    remaining_wk2 = alpha**2 * 100
    print(
        f"{channel:<15} {alpha:>12.3f} {row['hdi_3%']:>10.3f} {row['hdi_97%']:>10.3f}  "
        f"{alpha*100:.0f}% carryover to wk+1, {remaining_wk2:.0f}% to wk+2"
    )

print()
print(f"Max lag: {config.adstock_max_lag} weeks")
print()
print("Wide HDIs indicate the data is not strongly informative about decay rates.")

In [None]:
# Decay curves: how much of week-0 spend remains after k weeks
k_max = config.adstock_max_lag
weeks = np.arange(0, k_max + 1)

fig, ax = plt.subplots(figsize=(10, 5))
colors = plt.cm.Set2(np.linspace(0, 1, len(config.channel_columns)))

for i, channel in enumerate(config.channel_columns):
    param = f"adstock_alpha[{channel}]"
    alpha_mean = summary_alpha.loc[param, "mean"]
    alpha_lo = summary_alpha.loc[param, "hdi_3%"]
    alpha_hi = summary_alpha.loc[param, "hdi_97%"]

    decay_mean = alpha_mean**weeks
    decay_lo = alpha_lo**weeks
    decay_hi = alpha_hi**weeks

    ax.plot(
        weeks,
        decay_mean * 100,
        color=colors[i],
        linewidth=2,
        label=f"{channel} (alpha={alpha_mean:.2f})",
        marker="o",
    )
    ax.fill_between(weeks, decay_lo * 100, decay_hi * 100, alpha=0.15, color=colors[i])

ax.set_xlabel("Weeks after spend")
ax.set_ylabel("% of original effect remaining")
ax.set_title(
    f"Adstock Decay Curves (posterior-mean alpha, band from 94% HDI of alpha, max_lag={k_max})"
)
ax.legend(bbox_to_anchor=(1.02, 1), loc="upper left")
ax.set_ylim(0, 105)
plt.tight_layout()
plt.show()

---
## 9. Saturation (Diminishing Returns)

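`LogisticSaturation` passes each channel's scaled input through a concave curve that rises from zero and flattens toward a ceiling.
For non-negative input it is the upper half of a logistic S-curve, so there is no slow-start region.
- `lam` (lambda): steepness of the curve; higher lambda → faster saturation
- `beta`: channel effect magnitude; the ceiling the contribution approaches at very high spend (in scaled target units)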

The prior vs posterior comparison reveals how much the data has informed the saturation shape.
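
Concretely, the logistic saturation used here is `f(x) = beta * (1 - exp(-lam*x)) / (1 + exp(-lam*x))`, which equals `beta * tanh(lam*x/2)`. Below is a minimal sketch that assumes scaled spend `x` in [0, 1], as in the plotting cell below. It also finds the spend level at which a channel reaches a given fraction of its maximum effect. The 90% threshold and the `lam`/`beta` values are arbitrary illustrative choices, and the helper names are hypothetical.

```python
import numpy as np


def logistic_saturation(x: np.ndarray, lam: float, beta: float) -> np.ndarray:
    """beta * (1 - exp(-lam x)) / (1 + exp(-lam x)); concave for x >= 0, tends to beta."""
    return beta * (1.0 - np.exp(-lam * x)) / (1.0 + np.exp(-lam * x))


def spend_for_fraction(lam: float, frac: float) -> float:
    """Scaled spend at which the curve reaches `frac` of beta: solve tanh(lam x / 2) = frac."""
    return float(2.0 * np.arctanh(frac) / lam)


x = np.linspace(0.0, 1.0, 5)
lam, beta = 4.0, 0.3
y = logistic_saturation(x, lam, beta)

# Same curve via tanh, confirming the identity
assert np.allclose(y, beta * np.tanh(lam * x / 2.0))

x90 = spend_for_fraction(lam, 0.9)
print(f"90% of max effect reached at scaled spend {x90:.3f}")
```

Because the curve is concave from zero, the first unit of spend has the highest marginal return: the slope at `x = 0` is `beta * lam / 2`, and it only falls from there.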

In [None]:
fig = mmm.plot_prior_vs_posterior("saturation_lam", alphabetical_sort=True)
plt.suptitle("Saturation Lambda (steepness): Prior vs Posterior", fontsize=13)
plt.tight_layout()
plt.show()

In [None]:
fig = mmm.plot_prior_vs_posterior("saturation_beta", alphabetical_sort=True)
plt.suptitle("Saturation Beta (effect magnitude): Prior vs Posterior", fontsize=13)
plt.tight_layout()
plt.show()

In [None]:
# Saturation response curves using posterior mean lam and beta
from scipy.special import expit

summary_lam = az.summary(idata, var_names=["saturation_lam"])
summary_beta = az.summary(idata, var_names=["saturation_beta"])

x = np.linspace(0, 1, 200)  # MaxAbsScaled spend: 0 = zero spend, 1 = max observed

fig, axes = plt.subplots(1, len(config.channel_columns), figsize=(14, 4), sharey=False)
colors = plt.cm.Set2(np.linspace(0, 1, len(config.channel_columns)))

for i, channel in enumerate(config.channel_columns):
    lam_param = f"saturation_lam[{channel}]"
    beta_param = f"saturation_beta[{channel}]"

    lam_mean = summary_lam.loc[lam_param, "mean"]
    lam_lo = summary_lam.loc[lam_param, "hdi_3%"]
    lam_hi = summary_lam.loc[lam_param, "hdi_97%"]
    beta_mean = summary_beta.loc[beta_param, "mean"]

    # LogisticSaturation in PyMC-Marketing: beta * (1 - exp(-lam*x)) / (1 + exp(-lam*x)),
    # which is exactly beta * (2*expit(lam*x) - 1) = beta * tanh(lam*x/2)
    y_mean = beta_mean * (2 * expit(lam_mean * x) - 1)
    # The band reflects uncertainty in lam only (beta held at its posterior mean)
    y_lo = beta_mean * (2 * expit(lam_lo * x) - 1)
    y_hi = beta_mean * (2 * expit(lam_hi * x) - 1)

    axes[i].plot(x, y_mean, color=colors[i], linewidth=2, label="Posterior mean")
    axes[i].fill_between(
        x, y_lo, y_hi, alpha=0.25, color=colors[i], label="94% HDI of lam (beta at mean)"
    )
    axes[i].set_title(f"{channel}\nlam={lam_mean:.2f}, beta={beta_mean:.2f}")
    axes[i].set_xlabel("Scaled spend")
    axes[i].set_ylabel("Contribution (scaled units)")
    axes[i].legend(fontsize=7)

plt.suptitle("Saturation Response Curves by Channel (LogisticSaturation)", fontsize=12)
plt.tight_layout()
plt.show()

print("Interpretation:")
print(
    "  - Steeper curves (higher lam) = faster saturation = diminishing returns kick in earlier"
)
print(
    "  - Wide HDI bands = high uncertainty — the data is not informative about saturation shape"
)

---
## 10. Save Contribution Table

In [None]:
# Save full weekly contributions time series
tables_dir = OUTPUTS_DIR / "tables"
tables_dir.mkdir(parents=True, exist_ok=True)

date_str = datetime.now().strftime("%Y-%m-%d")
contrib_path = tables_dir / f"channel_contributions_{date_str}.csv"
contributions.to_csv(contrib_path)
print(f"Weekly contributions saved to: {contrib_path}")

# Save summary table
summary_path = tables_dir / f"contribution_summary_{date_str}.csv"
contrib_summary.to_csv(summary_path)
print(f"Summary table saved to: {summary_path}")

print()
print("Contents of outputs/tables/:")
for f in sorted(tables_dir.glob("*.csv")):
    print(f"  {f.name}")

---
## 11. Conclusions

### What this notebook showed

| Component | Finding |
|-----------|--------|
| **Posterior predictive** | How well the model reproduces observed GMV |
| **Waterfall** | Which components drive total GMV (baseline vs channels vs controls) |
| **Time series** | When each channel contributed most — peak investment months |
| **Contribution share** | How much GMV each channel is estimated to drive, with uncertainty |
| **Adstock** | How long each channel's effect persists (weeks of carryover) |
| **Saturation** | At what spend level each channel starts to plateau |

### Key caveats

- With only 12 effective observations, **contribution shares are prior-dominated**.
  The HDI bands reflect this: wide intervals are honest uncertainty, not a bug.
- Monthly spend distributed uniformly across weeks means **week-level attribution
  cannot distinguish individual weeks within a month** — GMV variation within a month
  is attributed to controls and baseline, not media.
- The results are most useful for **relative ranking** of channels
  (which channels seem to contribute more vs less) rather than absolute attribution.
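
You can check the second caveat directly. If a month's spend is split evenly across its weeks, the weekly spend series is flat within each month. It therefore carries no information about which week in that month drove GMV. Below is a minimal sketch with hypothetical monthly totals and a simplified 4-weeks-per-month calendar. Both are assumptions; the notebook's data pipeline may allocate weeks differently.

```python
import numpy as np

# Hypothetical monthly spend for one channel (3 months), simplified to 4 weeks per month
monthly_spend = np.array([400.0, 800.0, 200.0])
weeks_per_month = 4

# Spread each month's total evenly across its weeks
weekly_spend = np.repeat(monthly_spend / weeks_per_month, weeks_per_month)
month_index = np.repeat(np.arange(len(monthly_spend)), weeks_per_month)

# Within-month variance of spend is zero: no week-to-week media signal inside a month
within_month_var = [
    float(weekly_spend[month_index == m].var()) for m in range(len(monthly_spend))
]
print(weekly_spend)
print(within_month_var)
```
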

### Next step

Notebook 05 uses these channel contributions to run budget optimization.