# Learning with FB-Google
Trying out `pymc-marketing` with simulated data.
## Independent and instant
No AdStock, no saturation, no causal effect from FB to Google.


# Part I: Generating data
Each channel has an independent, instantaneous effect on sales.
Following: https://github.com/pymc-labs/pymc-marketing/blob/main/docs/source/notebooks/mmm/mmm_example.ipynb

In [None]:
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import seaborn as sns

# from pymc_marketing.mmm.delayed_saturated_mmm import MMM
# from pymc_marketing.mmm.delayed_saturated_mmm import DelayedSaturatedMMM
from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation

from utilfbgoogle import MelkDelayedSaturatedMMM


# Silence every FutureWarning raised during the session.
warnings.filterwarnings("ignore", category=FutureWarning)

# Plot styling shared by all figures in this notebook.
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100

# IPython magics: reload all imported modules before executing each cell
# (presumably so edits to utilfbgoogle are picked up) and render inline
# figures at high (retina) resolution.
%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

## Date Range
First we set a time range for our data. We consider the first half of 2024 (2024-01-01 to 2024-06-30) at daily granularity.

In [None]:
# Reproducible RNG: the seed is the sum of the character codes of "mmm".
seed: int = sum(ord(char) for char in "mmm")
rng: np.random.Generator = np.random.default_rng(seed=seed)

# Simulation window: one row per day, both endpoints included.
min_date = pd.to_datetime("2024-01-01")
max_date = pd.to_datetime("2024-06-30")

df = pd.DataFrame({"date": pd.date_range(min_date, max_date, freq="D")})
df["year"] = df["date"].dt.year
df["month"] = df["date"].dt.month
df["dayofyear"] = df["date"].dt.dayofyear

n = len(df)
print(f"Number of observations: {n}")

## Media spend
- Facebook = `x1`
  - Spiky
- Google = `x2`
  - More uniform

In [None]:
# Media data - spend in thousands.
# Facebook (x1): uniform draws in which every value not above 0.9 is halved,
# so the rare draws above 0.9 show up as spikes.
x1 = rng.uniform(low=0.0, high=1.0, size=n)
df["x1"] = np.where(x1 > 0.9, x1, x1 / 2) * 10

# Google (x2): plain uniform draws scaled to [0, 10), i.e. a more even pattern.
x2 = rng.uniform(low=0.0, high=1.0, size=n) * 10
df["x2"] = x2


fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(10, 7), sharex=True, sharey=True, layout="constrained"
)
sns.lineplot(x="date", y="x1", data=df, color="C0", ax=ax[0])
sns.lineplot(x="date", y="x2", data=df, color="C1", ax=ax[1])
# The series is daily (freq="D" date range), so label the shared x-axis as a date.
ax[1].set(xlabel="date")
fig.suptitle("Media Costs Data", fontsize=16);

## Target variable

In [None]:
# Ground-truth components: constant baseline and Gaussian noise.
df["intercept"] = 0.5
df["epsilon"] = rng.normal(loc=0.0, scale=0.25, size=n)

# Overall scale and per-channel linear effects (no adstock, no saturation).
amplitude = 2.0
beta_1, beta_2 = 0.5, 1.0
betas = [beta_1, beta_2]

# Sales = amplitude * (baseline + linear media effects + noise).
df["y"] = amplitude * (
    df["intercept"] + beta_1 * df["x1"] + beta_2 * df["x2"] + df["epsilon"]
)

fig, ax = plt.subplots()
sns.lineplot(data=df, x="date", y="y", color="black", ax=ax)
ax.set(title="Sales (Target Variable)", xlabel="date", ylabel="y (thousands)");

These are the true contributions embedded in the data (the noise term is excluded):

In [None]:
fig, ax = plt.subplots()

# Total (summed over all days) contribution of each component in the original
# target scale. Each term is scaled by `amplitude`, exactly as in the
# data-generating process, instead of a hard-coded factor. The noise term
# `epsilon` is intentionally left out.
contributions = (
    amplitude
    * np.array(
        [
            df["intercept"].sum(),
            (beta_1 * df["x1"]).sum(),
            (beta_2 * df["x2"]).sum(),
        ]
    )
).tolist()

ax.bar(
    ["intercept", "x1", "x2"],
    contributions,
    # Negative contributions (none expected here) would be drawn in red.
    color=["C0" if x >= 0 else "C3" for x in contributions],
    alpha=0.8,
)
ax.bar_label(
    ax.containers[0],
    fmt="{:,.2f}",
    label_type="edge",
    padding=2,
    fontsize=15,
    fontweight="bold",
)
ax.set(title="Sales Attribution", ylabel="Sales (thousands)");

## Media Contribution Interpretation
From the data generating process we can compute the relative contribution of each channel to the target variable. We will recover these values back from the model.

In [None]:
# Share of the combined beta-weighted media effect attributable to each channel.
# `amplitude` scales every term equally, so it cancels out of these ratios.
contribution_share_x1: float = (
    (beta_1 * df["x1"]).sum() / (beta_1 * df["x1"] + beta_2 * df["x2"]).sum()
)
contribution_share_x2: float = (
    (beta_2 * df["x2"]).sum() / (beta_1 * df["x1"] + beta_2 * df["x2"]).sum()
)

print(f"Contribution Share of x1: {contribution_share_x1:.2f}")
print(f"Contribution Share of x2: {contribution_share_x2:.2f}")

For each channel we can plot the daily contribution against spend, which shows the direct (linear) relation between spend and the true contribution.

In [None]:
fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(12, 8), sharex=True, sharey=False, layout="constrained"
)

# Daily spend vs. true (pre-amplitude) contribution beta_i * x_i; for this
# linear model each panel is a straight line through the origin.
for i, x in enumerate(("x1", "x2")):
    sns.scatterplot(x=df[x], y=betas[i] * df[x], color=f"C{i}", ax=ax[i])
    ax[i].set(
        title=f"$x_{i + 1}$ contribution",
        ylabel=f"$\\beta_{i + 1} \\cdot x_{i + 1}$ ",
        xlabel="x",
    )

# Spend and ROAS

In [None]:
# Total spend per channel over the whole simulation window.
fig, ax = plt.subplots(figsize=(7, 5))
df[["x1", "x2"]].sum().plot.bar(color=["C0", "C1"], ax=ax)
ax.set(title="Total Media Spend", xlabel="Media Channel", ylabel="Costs (thousands)");

Looking at the ROAS, we directly recover `amplitude * beta` for each channel.

In [None]:
# Empirical ROAS = total attributed revenue / total spend. With a linear,
# instantaneous effect this equals amplitude * beta (up to float rounding).
roas_1 = (amplitude * beta_1 * df["x1"]).sum() / df["x1"].sum()
roas_2 = (amplitude * beta_2 * df["x2"]).sum() / df["x2"].sum()

fig, ax = plt.subplots(figsize=(7, 5))
# Draw explicitly on `ax` rather than relying on matplotlib's implicit
# "current axes".
pd.Series(data=[roas_1, roas_2], index=["x1", "x2"]).plot(
    kind="bar", color=["C0", "C1"], ax=ax
)

ax.set(title="ROAS (Approximation)", xlabel="Media Channel", ylabel="ROAS");

## Data for modeling
We keep only the columns the model should see.

In [None]:
# Only the date, the target and the raw spend columns are exposed to the model;
# the ground-truth intercept/noise and the calendar columns stay hidden.
columns_to_keep = ["date", "y", "x1", "x2"]

data = df.loc[:, columns_to_keep].copy()

data.head()

# Part II: Modeling

First, let's compute the share of spend per channel:

In [None]:
# Fraction of total spend that went into each channel (sums to 1).
total_spend_per_channel = data[["x1", "x2"]].sum(axis=0)
spend_share = total_spend_per_channel.div(total_spend_per_channel.sum())

spend_share

Choose prior scales (sigmas) for the channel coefficients, proportional to each channel's share of spend.

In [None]:
# The scale helpful to make a HalfNormal distribution have unit variance
HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)

n_channels = 2

prior_sigma = HALFNORMAL_SCALE * n_channels * spend_share.to_numpy()

prior_sigma.tolist()

Following the scikit-learn convention, split the data into predictors `X` and target `y`.

In [None]:
# Target series and predictor frame (date + channel spends).
y = data["y"]
X = data.drop(columns="y")

In [None]:
# Custom priors passed to the model:
# - saturation_beta: per-channel LogNormal with log-scale location mu = [2, 1]
#   and the spend-share-based scales in `prior_sigma`.
# - likelihood: Normal whose sigma has a HalfNormal(sigma=2) prior. A fixed
#   value is also possible, e.g. {"sigma": 5}.
my_model_config = dict(
    saturation_beta=dict(
        dist="LogNormal",
        kwargs=dict(mu=np.array([2, 1]), sigma=prior_sigma),
    ),
    likelihood=dict(
        dist="Normal",
        kwargs=dict(sigma=dict(dist="HalfNormal", kwargs=dict(sigma=2))),
    ),
)

In [None]:
my_sampler_config = dict(progressbar=True)  # show the sampling progress bar

In [None]:
# NOTE(review): the simulated data has no adstock or saturation, yet the model
# presumably still fits an adstock with up to 8 lags (adstock_max_lag=8);
# check that the fitted carryover is negligible.
mmm = MelkDelayedSaturatedMMM(
    date_column="date",
    channel_columns=["x1", "x2"],
    adstock_max_lag=8,
    model_config=my_model_config,
    sampler_config=my_sampler_config,
)

In [None]:
# Fit with 4 chains on 4 cores, reusing the notebook RNG for reproducibility.
mmm.fit(
    X=X,
    y=y,
    chains=4,
    cores=4,
    target_accept=0.90,
    random_seed=rng,
)

In [None]:
# Render the structure of the underlying PyMC model.
pm.model_to_graphviz(mmm.model)

In [None]:
# Model parameters to inspect in the summary table and the trace plot.
var_names = ["intercept", "likelihood_sigma", "beta_channel", "lam", "alpha"]

az.summary(data=mmm.fit_result, var_names=var_names)

In [None]:
# Trace plot of the parameters listed in `var_names`.
_ = az.plot_trace(
    data=mmm.fit_result,
    var_names=var_names,
    compact=True,
    backend_kwargs=dict(figsize=(12, 10), layout="constrained"),
)
plt.gcf().suptitle("Model Trace", fontsize=16);

In [None]:
# In-sample posterior predictive draws; extend_idata=True presumably attaches
# them to the fitted model's InferenceData for the plots below — confirm.
mmm.sample_posterior_predictive(
    X,
    extend_idata=True,
    combined=True,
)

In [None]:
_ = mmm.plot_posterior_predictive(original_scale=True)  # fit vs. observed y

In [None]:
_ = mmm.plot_errors(original_scale=True)  # residuals in the original target scale

In [None]:
# Posterior errors in the original target scale, plotted in the next cell.
errors = mmm.get_errors(
    original_scale=True,
)

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))
# Error density with the quartiles marked.
az.plot_dist(
    errors,
    quantiles=[0.25, 0.5, 0.75],
    color="C3",
    fill_kwargs={"alpha": 0.7},
    ax=ax,
)
# Zero reference line to judge whether the errors are centred.
ax.axvline(x=0, color="black", linestyle="--", linewidth=1, label="zero")
ax.legend()
ax.set_title("Errors Posterior Distribution");

In [None]:
_ = mmm.plot_components_contributions()

In [None]:
# Each stacked area aggregates the listed model components.
groups = {"Base": ["intercept"], "Channel 1": ["x1"], "Channel 2": ["x2"]}

fig = mmm.plot_grouped_contribution_breakdown_over_time(
    stack_groups=groups,
    original_scale=True,
    area_kwargs=dict(
        color={"Channel 1": "C0", "Channel 2": "C1", "Base": "gray"},
        alpha=0.7,
    ),
)

fig.suptitle("Contribution Breakdown over Time", fontsize=16);

In [None]:
_ = mmm.plot_waterfall_components_decomposition()

In [None]:
# Posterior-mean contributions per date in the original target scale; the
# per-channel columns (e.g. "x1", "x2") are used later to overlay the estimated
# contributions on the ground truth.
get_mean_contributions_over_time_df = mmm.compute_mean_contributions_over_time(
    original_scale=True
)

# Preview the frame (the previous misspelled name raised a NameError).
get_mean_contributions_over_time_df.head()

In [None]:
# Placeholder reference values (not taken from the data-generating process,
# which has no adstock or saturation) drawn over the posterior of `alpha`.
alpha1, alpha2 = 0.005, 0.006  # dummy

fig = mmm.plot_channel_parameter(param_name="alpha", figsize=(9, 5))
ax = fig.axes[0]
ax.axvline(x=alpha1, label=r"$\alpha_1$", color="C0", linestyle="--")
ax.axvline(x=alpha2, label=r"$\alpha_2$", color="C1", linestyle="--")
ax.legend(loc="upper right");

In [None]:
# Placeholder reference values (not taken from the data-generating process,
# which has no saturation) drawn over the posterior of `lam`.
lam1, lam2 = 0.5, 0.7  # dummy

fig = mmm.plot_channel_parameter(param_name="lam", figsize=(9, 5))
ax = fig.axes[0]
ax.axvline(x=lam1, label=r"$\lambda_1$", color="C0", linestyle="--")
ax.axvline(x=lam2, label=r"$\lambda_2$", color="C1", linestyle="--")
ax.legend(loc="upper right");

In [None]:
# Posterior contribution-share intervals with the true shares (computed from
# the data-generating process) as dashed reference lines.
fig = mmm.plot_channel_contribution_share_hdi(figsize=(7, 5))
ax = fig.axes[0]
ax.axvline(
    x=contribution_share_x1,
    label="true contribution share ($x_1$)",
    color="C1",
    linestyle="--",
)
ax.axvline(
    x=contribution_share_x2,
    label="true contribution share ($x_2$)",
    color="C2",
    linestyle="--",
)
# Legend placed below the plot area.
ax.legend(loc="upper center", bbox_to_anchor=(0.5, -0.05), ncol=1);

In [None]:
fig = mmm.plot_direct_contribution_curves()
# Relabel every subplot's x-axis. A plain loop is used because this is a pure
# side effect; a list comprehension would build and discard a list.
for panel in fig.axes:
    panel.set(xlabel="x")

In [None]:
_ = mmm.plot_channel_contributions_grid(start=0, stop=1.5, num=12)

In [None]:
_ = mmm.plot_channel_contributions_grid(
    start=0, stop=1.5, num=12, absolute_xrange=True
)

In [None]:
# Display the full simulated dataset, including the ground-truth columns
# (intercept, epsilon) that are excluded from the model's `data`.
df

In [None]:
# Posterior channel contributions in the original target scale, plus their HDIs.
channels_contribution_original_scale = mmm.compute_channel_contribution_original_scale()
channels_contribution_original_scale_hdi = az.hdi(ary=channels_contribution_original_scale)

fig, ax = plt.subplots(
    nrows=2, ncols=1, figsize=(15, 8), sharex=True, sharey=False, layout="constrained"
)

for i, x in enumerate(["x1", "x2"]):
    panel = ax[i]
    channel_hdi = channels_contribution_original_scale_hdi.sel(channel=x)["x"]

    # Ground truth from the data-generating process: amplitude * beta_i * x_i.
    sns.lineplot(
        x=df["date"],
        y=amplitude * betas[i] * df[x],
        color="black",
        label=f"{x} true contribution",
        ax=panel,
    )
    # Posterior HDI band between the two bounds returned by az.hdi.
    panel.fill_between(
        x=df["date"],
        y1=channel_hdi[:, 0],
        y2=channel_hdi[:, 1],
        color=f"C{i}",
        label=rf"{x} $94\%$ HDI contribution",
        alpha=0.4,
    )
    # Posterior mean contribution per date.
    sns.lineplot(
        x=df["date"],
        y=get_mean_contributions_over_time_df[x].to_numpy(),
        color=f"C{i}",
        label=f"{x} posterior mean contribution",
        alpha=0.8,
        ax=panel,
    )
    panel.legend(loc="center left", bbox_to_anchor=(1, 0.5))
    panel.set(title=f"Channel {x}")

In [None]:
channel_contribution_original_scale = mmm.compute_channel_contribution_original_scale()

# Posterior ROAS samples: total contribution over all dates / total spend.
# NOTE(review): the numpy divisor (shape (2, 1)) broadcasts positionally, which
# assumes `channel` is the leading dim after stacking and is ordered x1, x2.
roas_samples = (
    channel_contribution_original_scale.stack(sample=("chain", "draw")).sum("date")
    / data[["x1", "x2"]].sum().to_numpy()[..., None]
)

fig, ax = plt.subplots(figsize=(10, 6))
for channel in ("x1", "x2"):
    sns.histplot(
        roas_samples.sel(channel=channel).to_numpy(),
        binwidth=0.01,
        alpha=0.3,
        kde=True,
        ax=ax,
    )
# True ROAS from the data-generating process.
ax.axvline(x=roas_1, color="C0", linestyle="--", label=r"true ROAS $x_{1}$")
ax.axvline(x=roas_2, color="C1", linestyle="--", label=r"true ROAS $x_{2}$")
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
ax.set(title="Posterior ROAS distribution", xlabel="ROAS");

In [None]:
# Out-of-sample design: `n_new` daily dates right after the last observed date,
# with each channel's spend frozen at its value in the last row of X.
last_date = X["date"].max()

n_new = 5
# Drop the first element: it is `last_date` itself, which is already in-sample.
new_dates = pd.date_range(start=last_date, periods=1 + n_new, freq="D")[1:]

X_out_of_sample = pd.DataFrame({"date": new_dates}).assign(
    x1=X["x1"].iloc[-1],
    x2=X["x2"].iloc[-1],
)

X_out_of_sample

In [None]:
# Print column dtypes and non-null counts of the out-of-sample design frame.
X_out_of_sample.info()

In [None]:
# Posterior predictive draws for the new dates; extend_idata=False presumably
# keeps them off the fitted model's InferenceData — confirm.
y_out_of_sample = mmm.sample_posterior_predictive(
    X_pred=X_out_of_sample,
    extend_idata=False,
)

y_out_of_sample

In [None]:
def plot_in_sample(X, y, ax, n_points: int = 15):
    """Plot the last ``n_points`` observed target values against their dates.

    The values are drawn on ``ax`` as black markers labelled "actuals".
    NOTE(review): ``X`` and ``y`` are assumed to be row-aligned, since
    ``X["date"]`` becomes the index of ``y``.

    Returns:
        The same ``ax`` that was passed in.
    """
    recent_actuals = y.to_frame().set_index(X["date"]).iloc[-n_points:]
    recent_actuals.plot(ax=ax, marker="o", color="black", label="actuals")
    return ax


def plot_out_of_sample(X_out_of_sample, y_out_of_sample, ax, color, label):
    """Overlay out-of-sample posterior predictions on ``ax``.

    Draws a shaded 95% band (2.5%-97.5% quantiles of the draws of ``"y"`` per
    date), labelled ``"{label} interval"``, plus the per-date posterior mean as
    a dashed line labelled ``label``, both in ``color``.
    NOTE(review): the band's x-values come from ``X_out_of_sample["date"]``
    while the quantiles are grouped (and therefore sorted) by date, so the
    dates are assumed to be sorted and identical in both inputs.

    Returns:
        The same ``ax`` that was passed in.
    """
    by_date = y_out_of_sample["y"].to_series().groupby("date")

    interval = [0.025, 0.975]
    bounds = by_date.quantile(interval).unstack()
    ax.fill_between(
        X_out_of_sample["date"].dt.to_pydatetime(),
        bounds[interval[0]],
        bounds[interval[1]],
        alpha=0.25,
        color=color,
        label=f"{label} interval",
    )

    by_date.mean().plot(ax=ax, marker="o", label=label, color=color, linestyle="--")
    ax.set(ylabel="Original Target Scale", title="Out of sample predictions for MMM")
    return ax


# Recent in-sample actuals followed by the out-of-sample forecast.
_, ax = plt.subplots()
plot_in_sample(X=X, y=y, ax=ax)
plot_out_of_sample(
    X_out_of_sample=X_out_of_sample,
    y_out_of_sample=y_out_of_sample,
    ax=ax,
    color="C0",
    label="out of sample",
)
ax.legend(loc="upper left");

In [None]:
# Same forecast, but include_last_observations=True presumably feeds the last
# in-sample rows in so adstock carryover continues into the new dates — confirm.
y_out_of_sample_with_adstock = mmm.sample_posterior_predictive(
    X_pred=X_out_of_sample,
    include_last_observations=True,
    extend_idata=False,
)

In [None]:
# Compare forecasts made without (C0) and with (C1) the last observations.
_, ax = plt.subplots()
plot_in_sample(X=X, y=y, ax=ax)
plot_out_of_sample(
    X_out_of_sample=X_out_of_sample,
    y_out_of_sample=y_out_of_sample,
    ax=ax,
    color="C0",
    label="out of sample",
)
plot_out_of_sample(
    X_out_of_sample=X_out_of_sample,
    y_out_of_sample=y_out_of_sample_with_adstock,
    ax=ax,
    color="C1",
    label="adstock out of sample",
)
ax.legend();

In [None]:
# Simulated contribution of new spend amounts, one panel per amount.
spends = [0.3, 0.5, 1, 2]

fig, axes = plt.subplots(
    nrows=len(spends),
    ncols=1,
    figsize=(11, 9),
    sharex=True,
    sharey=True,
    layout="constrained",
)
axes = axes.ravel()

for spend, ax in zip(spends, axes, strict=True):
    mmm.plot_new_spend_contributions(spend_amount=spend, progressbar=False, ax=ax)

fig.suptitle("New Spend Contribution Simulations", fontsize=18, fontweight="bold");