# Media Mix Modeling (MMM) Tutorial

This notebook demonstrates how to build and analyze a Marketing Mix Model using PyMC-Marketing.

## Prepare Notebook

Let's import the necessary libraries:

In [1]:
import sys
sys.path.insert(0, ".")

In [5]:
from pymc_marketing.mmm.mmm import MMM
from pymc_marketing.mmm import GeometricAdstock, LogisticSaturation
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from pymc_extras.prior import Prior
import pytensor

In [None]:


az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100

%config InlineBackend.figure_format = "retina"

%load_ext autoreload
%autoreload 2

In [None]:
# Set random seed for reproducibility
seed = sum(map(ord, "mmm"))
rng = np.random.default_rng(seed=seed)

## Load Data

We'll use a synthetic dataset that simulates weekly sales data along with spend on two marketing channels (x1 and x2), plus some control variables for special events.

In [None]:
# Load the data
url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/data/mmm_example.csv"
data = pd.read_csv(url, parse_dates=["date_week"])

print(f"Data shape: {data.shape}")
data.head()

Let's visualize our target variable (sales) and the media spend over time:

In [None]:
fig, axes = plt.subplots(3, 1, figsize=(12, 9), sharex=True)

# Sales
axes[0].plot(data["date_week"], data["y"], color="black", linewidth=2)
axes[0].set(ylabel="Sales", title="Target Variable: Sales")

# Channel 1
axes[1].plot(data["date_week"], data["x1"], color="C0", linewidth=2)
axes[1].set(ylabel="Spend", title="Channel x1")

# Channel 2
axes[2].plot(data["date_week"], data["x2"], color="C1", linewidth=2)
axes[2].set(xlabel="Date", ylabel="Spend", title="Channel x2");

## Feature Engineering

For our MMM model, we'll include:

- **Trend**: A linear trend to capture long-term growth
- **Seasonality**: Yearly seasonality, which the model builds internally from Fourier terms of the date (controlled by `yearly_seasonality`), so no extra column is needed
- **Events**: Binary indicators for special events
- **Media channels**: Our two advertising channels
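
Seasonality is the one item above that is not a column in `data`. The sketch below shows what Fourier seasonality features look like. It assumes a 365.25-day period and day-of-year as the time index, and its column names are illustrative. The library's exact construction may differ. A Fourier order of n yields 2n columns (one sine and one cosine per order):

```python
import numpy as np
import pandas as pd


def fourier_features(dates: pd.Series, n_order: int, period: float = 365.25) -> pd.DataFrame:
    """Sketch: sine/cosine columns for Fourier orders 1..n_order of a yearly cycle."""
    # Position within the year, measured in days (assumption: day-of-year origin)
    day = dates.dt.dayofyear.to_numpy()
    cols = {}
    for k in range(1, n_order + 1):
        angle = 2 * np.pi * k * day / period
        cols[f"sin_order_{k}"] = np.sin(angle)  # hypothetical column names
        cols[f"cos_order_{k}"] = np.cos(angle)
    return pd.DataFrame(cols, index=dates.index)


# Two years of hypothetical weekly dates
dates = pd.Series(pd.date_range("2018-04-02", periods=104, freq="W-MON"))
features = fourier_features(dates, n_order=2)
print(features.shape)  # 104 rows, 2 * n_order = 4 columns
```

The model learns one coefficient per column (the `gamma_fourier` prior in the configuration below). A low order such as 2 keeps the seasonal curve smooth.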

In [None]:
# Add a simple linear trend feature
data["t"] = range(len(data))

# Split into features (X) and target (y)
X = data.drop("y", axis=1)
y = data["y"]

print(f"Features: {X.columns.tolist()}")

## Model Specification

Now we'll configure the MMM. The key components are:

- **Adstock transformation**: We use GeometricAdstock with a maximum lag of 8 weeks
- **Saturation transformation**: We use LogisticSaturation to capture diminishing returns
- **Priors**: We can customize priors based on domain knowledge
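
To make the two transformations concrete, here is a minimal NumPy sketch of a geometric adstock followed by a logistic saturation. It assumes adstock is applied before saturation and that inputs are already scaled. The library also rescales spend and sales internally, and its defaults (for example, weight normalization) may differ from this sketch:

```python
import numpy as np


def geometric_adstock(x: np.ndarray, alpha: float, l_max: int, normalize: bool = True) -> np.ndarray:
    """Carry-over: out[t] = sum_{l=0}^{l_max-1} w[l] * x[t - l], with w[l] = alpha**l."""
    weights = alpha ** np.arange(l_max)
    if normalize:
        weights = weights / weights.sum()  # weights sum to 1, so total spend is preserved
    # Full convolution truncated to len(x) keeps only past and current spend (causal)
    return np.convolve(x, weights)[: len(x)]


def logistic_saturation(x: np.ndarray, lam: float) -> np.ndarray:
    """Map non-negative input into [0, 1) with diminishing returns."""
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


# A single burst of spend in week 1 carries over into the following weeks
spend = np.array([0.0, 1.0, 0.0, 0.0, 0.0, 0.0])
carried = geometric_adstock(spend, alpha=0.5, l_max=4, normalize=False)
# carried decays as 1, 0.5, 0.25, 0.125 after the burst

# Channel contribution under these assumptions: beta * saturation(adstock(spend))
beta, lam = 2.0, 3.0  # hypothetical parameter values
contribution = beta * logistic_saturation(carried, lam)
```

In the model, `alpha`, `lam`, and `beta` are not fixed. They are the per-channel parameters `adstock_alpha`, `saturation_lam`, and `saturation_beta`, which we inspect after fitting.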

### Setting Priors

One powerful feature of Bayesian modeling is the ability to incorporate prior knowledge. Here's a simple heuristic for channel priors based on spend share:

In [None]:
# Calculate spend share for each channel
total_spend_per_channel = data[["x1", "x2"]].sum(axis=0)
spend_share = total_spend_per_channel / total_spend_per_channel.sum()

print("Spend share per channel:")
print(spend_share)

# Use spend share to inform the prior on channel contributions.
# Scaling by the number of channels makes the average prior sigma equal to 1.
n_channels = len(spend_share)
prior_sigma = n_channels * spend_share.to_numpy()

print(f"\nPrior sigma for channels: {prior_sigma}")
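
As a worked example of this heuristic, take hypothetical totals of 300 and 200 for the two channels (not the dataset's actual totals). The shares are 0.6 and 0.4, so the prior sigmas are 1.2 and 0.8. Multiplying by the number of channels keeps their average at 1, while the channel with more spend gets a wider `HalfNormal` prior:

```python
import numpy as np

# Hypothetical total spend per channel
total_spend = np.array([300.0, 200.0])
share = total_spend / total_spend.sum()  # 0.6 and 0.4
n_channels = len(total_spend)
sigma = n_channels * share  # 1.2 and 0.8; mean is 1

# The mean of a HalfNormal(sigma) distribution is sigma * sqrt(2 / pi)
prior_mean = sigma * np.sqrt(2 / np.pi)
print(sigma, prior_mean.round(3))
```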

Now let's define our model configuration:

In [None]:
my_model_config = {
    "intercept": Prior("Normal", mu=0.5, sigma=0.2),
    "saturation_beta": Prior("HalfNormal", sigma=prior_sigma),
    "gamma_control": Prior("Normal", mu=0, sigma=0.05),
    "gamma_fourier": Prior("Laplace", mu=0, b=0.2),
    "likelihood": Prior("Normal", sigma=Prior("HalfNormal", sigma=6)),
}

# Sampler configuration
my_sampler_config = {"progressbar": True}

# Initialize the MMM model
mmm = MMM(
    model_config=my_model_config,
    sampler_config=my_sampler_config,
    date_column="date_week",
    adstock=GeometricAdstock(l_max=8),
    saturation=LogisticSaturation(),
    channel_columns=["x1", "x2"],
    control_columns=["event_1", "event_2", "t"],
    yearly_seasonality=2,
)

## Prior Predictive Check

> **Tip**: The prior predictive check is a good way to verify that our priors are reasonable, so it is strongly recommended to run it before fitting the model. If you are new to Bayesian modeling, take a look at our Prior Predictive Modeling guide notebook.

Before fitting, let's check that our priors are reasonable:

In [None]:
# Generate prior predictive samples
mmm.sample_prior_predictive(X, y, samples=1_000, random_seed=rng)

# Plot prior predictive distribution
fig, ax = plt.subplots(figsize=(12, 6))
mmm.plot_prior_predictive(ax=ax, original_scale=True)
ax.legend(loc="upper left")
ax.set_title("Prior Predictive Check");

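Overall, the prior predictive check looks good. What we want to see is that the prior predictive bands cover the observed sales while ruling out implausible values, such as strongly negative sales or sales orders of magnitude larger than observed.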
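
The mechanics of a prior predictive check can be sketched without the library. Draw parameters from the priors, simulate one dataset per draw, and summarize the spread at each time point. The model below is a deliberately simplified linear stand-in with hypothetical priors and data, not the MMM itself:

```python
import numpy as np

rng = np.random.default_rng(42)
n_draws, n_obs = 1_000, 50
x = rng.uniform(0, 1, size=n_obs)  # hypothetical scaled spend

# 1. Draw parameters from (hypothetical) priors
intercept = rng.normal(0.5, 0.2, size=n_draws)  # Normal(0.5, 0.2)
beta = np.abs(rng.normal(0, 1.0, size=n_draws))  # HalfNormal(1)
sigma = np.abs(rng.normal(0, 0.1, size=n_draws))  # HalfNormal(0.1)

# 2. Simulate one dataset per draw: shape (n_draws, n_obs)
mu = intercept[:, None] + beta[:, None] * x[None, :]
y_prior = rng.normal(mu, sigma[:, None])

# 3. Summarize the spread at each time point (94% central interval)
lower, upper = np.percentile(y_prior, [3, 97], axis=0)
```

Comparing `lower` and `upper` with the observed target is exactly what the plot above does visually.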

## Model Fitting

Now let's fit the model to our data using MCMC sampling. Besides PyMC's default NUTS implementation, we can choose another sampler through the `nuts_sampler` argument, such as `numpyro`, `nutpie`, or `blackjax` (each requires its package to be installed; see Other NUTS Samplers for more details). Here we use `numpyro`.

In [None]:
# Fit the model
_ = mmm.fit(
    X=X,
    y=y,
    chains=4,
    target_accept=0.85,
    nuts_sampler="numpyro",
    random_seed=rng,
)

## Model Diagnostics

After fitting, we should check the model quality. Let's start with divergences:

In [None]:
# Check for divergences
n_divergences = mmm.idata["sample_stats"]["diverging"].sum().item()
print(f"Number of divergences: {n_divergences}")

if n_divergences == 0:
    print("✓ No divergences - sampling was successful!")
else:
    print("⚠ Warning: Model had divergences. Consider increasing target_accept.")

### Trace Plots

Let's examine the posterior distributions and sampling traces of the key media parameters:

In [None]:
# Plot traces for key parameters
_ = az.plot_trace(
    data=mmm.fit_result,
    var_names=[
        "saturation_beta",
        "saturation_lam",
        "adstock_alpha",
    ],
    compact=True,
    backend_kwargs={"figsize": (10, 6), "layout": "constrained"},
)
plt.gcf().suptitle("Trace Plots", fontsize=16);

Good trace plots should show:

- **Left side**: Smooth, typically unimodal posterior densities, with the individual chains agreeing with each other
- **Right side**: "Fuzzy caterpillar" patterns (good mixing) with no trends
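
Visual checks can be complemented with the R-hat statistic, which compares between-chain and within-chain variance. In practice ArviZ computes it for you: `az.summary` reports an `r_hat` column, using a rank-normalized variant. The sketch below implements the classic split R-hat on synthetic draws to show what the number measures. A common rule of thumb is R-hat below 1.01:

```python
import numpy as np


def split_rhat(samples: np.ndarray) -> float:
    """Classic split-R-hat for draws of shape (n_chains, n_draws)."""
    n_draws = samples.shape[1]
    half = n_draws // 2
    # Split each chain in half so within-chain drift shows up as between-chain differences
    chains = np.concatenate([samples[:, :half], samples[:, half : 2 * half]], axis=0)
    n = chains.shape[1]
    within = chains.var(axis=1, ddof=1).mean()  # W: average within-chain variance
    between = n * chains.mean(axis=1).var(ddof=1)  # B: between-chain variance
    var_hat = (n - 1) / n * within + between / n
    return float(np.sqrt(var_hat / within))


rng = np.random.default_rng(0)
good = rng.normal(size=(4, 1000))  # four well-mixed chains
bad = good + np.arange(4)[:, None]  # each chain stuck around a different value
print(round(split_rhat(good), 3), round(split_rhat(bad), 3))
```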

## Posterior Predictive Check

How well does our model fit the observed data?

In [None]:
# Sample from posterior predictive distribution
mmm.sample_posterior_predictive(X, extend_idata=True, combined=True)

# Plot model fit
fig = mmm.plot_posterior_predictive(original_scale=True)

The model captures the observed data well if most of the observed sales values fall within the shaded posterior predictive bands, without the bands being so wide that they become uninformative.
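
One way to quantify this check is interval coverage: the fraction of observations that fall inside a central predictive interval. The sketch below uses synthetic data and equal-tailed percentile intervals rather than HDIs. A well-calibrated model should cover roughly the nominal 94%, while an overconfident one covers far less:

```python
import numpy as np


def interval_coverage(y_obs: np.ndarray, y_pred_draws: np.ndarray, prob: float = 0.94) -> float:
    """Fraction of observations inside the central `prob` interval.

    y_pred_draws has shape (n_draws, n_obs).
    """
    tail = (1 - prob) / 2 * 100
    lower, upper = np.percentile(y_pred_draws, [tail, 100 - tail], axis=0)
    inside = (y_obs >= lower) & (y_obs <= upper)
    return float(inside.mean())


rng = np.random.default_rng(1)
y_obs = rng.normal(size=200)
calibrated = rng.normal(size=(4000, 200))  # predictive spread matches the data
narrow = rng.normal(scale=0.2, size=(4000, 200))  # overconfident predictions
print(interval_coverage(y_obs, calibrated), interval_coverage(y_obs, narrow))
```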

## Contribution Analysis

Now for the fun part: understanding how much each component contributes to sales!

### Component Contributions Over Time

Let's visualize the contribution of each component of the model over time:

In [None]:
fig = mmm.plot_components_contributions(original_scale=True)
plt.suptitle("Component Contributions to Sales", fontsize=16, y=1.02);

We see that the model captures the linear trend, the event contributions, and the yearly seasonality in the data. Variation not explained by these components or by observation noise is attributed to the media channels, which is exactly what we want to understand.

### Waterfall Chart: Total Contribution by Component

A waterfall chart shows the total contribution of each component across the entire time period:

In [None]:
# Waterfall decomposition
fig = mmm.plot_waterfall_components_decomposition();

This chart answers the question: "How much did each component contribute to total sales?"
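
Under the hood, a decomposition like this is additive. Sum each component's contribution over time, then express each total as a share of the overall sum. The weekly contributions below are hypothetical, in original sales units:

```python
import pandas as pd

# Hypothetical weekly contributions per component (original scale)
contrib = pd.DataFrame(
    {
        "intercept": [100.0, 100.0, 100.0],
        "trend": [0.0, 5.0, 10.0],
        "x1": [20.0, 30.0, 25.0],
        "x2": [10.0, 15.0, 5.0],
    }
)

totals = contrib.sum(axis=0)  # total contribution per component over the period
share = totals / totals.sum()  # fraction of total predicted sales
print(pd.DataFrame({"total": totals, "share": share.round(3)}))
```

In a fitted model, some components (for example, seasonality or controls) can contribute negatively, which is why a waterfall shows both upward and downward steps.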

### Channel Contribution Share

What percentage of media-driven sales comes from each channel?

In [None]:
# Plot channel contribution share
fig = mmm.plot_channel_contribution_share_hdi(figsize=(7, 5));
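
The uncertainty in a contribution share comes from computing the share within each posterior draw and only then summarizing across draws, rather than dividing posterior means. A sketch with hypothetical posterior draws, using an equal-tailed 94% interval in place of an HDI:

```python
import numpy as np

rng = np.random.default_rng(7)
n_draws = 4000
# Hypothetical posterior draws of each channel's total contribution
x1_total = rng.normal(75.0, 5.0, size=n_draws)
x2_total = rng.normal(30.0, 4.0, size=n_draws)

# Share per draw, so uncertainty in both channels propagates into the share
x1_share = x1_total / (x1_total + x2_total)
low, high = np.percentile(x1_share, [3, 97])
print(f"x1 share: mean={x1_share.mean():.2f}, 94% interval=({low:.2f}, {high:.2f})")
```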

### Direct Contribution Curves

These curves show the relationship between spend and contribution, accounting for saturation:

In [None]:
# Plot direct contribution curves (saturation curves)
fig = mmm.plot_direct_contribution_curves()
plt.suptitle("Direct Contribution Curves", fontsize=16, y=1.02);

Notice how the curves flatten at higher spend levels: this is the saturation effect in action!
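
The flattening can be quantified as a marginal contribution, the extra contribution per extra unit of spend. For the logistic saturation used here, the marginal contribution falls steadily as spend grows. A sketch with hypothetical point values for `beta` and `lam` on a scaled spend grid:

```python
import numpy as np


def logistic_saturation(x: np.ndarray, lam: float) -> np.ndarray:
    """Same form as in the saturation sketch: equals tanh(lam * x / 2)."""
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


beta, lam = 2.0, 3.0  # hypothetical parameter values
spend = np.linspace(0, 2, 9)  # scaled spend grid
contribution = beta * logistic_saturation(spend, lam)

# Finite-difference marginal contribution between neighbouring grid points
marginal = np.diff(contribution) / np.diff(spend)
print(np.round(marginal, 3))  # decreasing: each extra unit of spend adds less
```

The curve also approaches, but never exceeds, `beta`, which acts as the channel's maximum contribution on the scaled axis.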

### Channel Contribution Grid

A complementary view of media performance is to evaluate the channel contribution at different spend levels over the complete training period. Concretely, let α be a multiplier applied to the input channel data, so that α = 1 reproduces the observed spend and α = 1.5 corresponds to a 50% increase in spend. We can then compute the channel contribution on a grid of α values and plot the results:

In [None]:
mmm.plot_channel_contribution_grid(start=0, stop=1.5, num=12, absolute_xrange=True);
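
Conceptually, the grid multiplies the whole spend series by each α and pushes it through the adstock and saturation transformations again. The sketch below does this for one channel with hypothetical single parameter values, whereas the library evaluates over posterior draws. In the code, the grid multiplier is called `scale` to avoid confusion with the adstock parameter `alpha`:

```python
import numpy as np


def geometric_adstock(x: np.ndarray, alpha: float, l_max: int) -> np.ndarray:
    """Normalized geometric carry-over, truncated to keep it causal."""
    weights = alpha ** np.arange(l_max)
    weights = weights / weights.sum()
    return np.convolve(x, weights)[: len(x)]


def logistic_saturation(x: np.ndarray, lam: float) -> np.ndarray:
    return (1 - np.exp(-lam * x)) / (1 + np.exp(-lam * x))


def channel_contribution(x: np.ndarray, alpha: float, lam: float, beta: float, l_max: int = 8) -> float:
    """Total contribution over the period: sum of beta * saturation(adstock(x))."""
    return float(np.sum(beta * logistic_saturation(geometric_adstock(x, alpha, l_max), lam)))


rng = np.random.default_rng(3)
spend = rng.uniform(0, 1, size=104)  # hypothetical scaled weekly spend
scales = np.linspace(0, 1.5, 12)  # grid of spend multipliers (the α values)
grid = [channel_contribution(s * spend, alpha=0.5, lam=2.0, beta=1.0) for s in scales]
```

Because saturation is concave, the total contribution grows more slowly at each step up the grid, so scaling spend by 1.5 yields less than a 50% increase in contribution.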