In [1]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# import seaborn as sns
import plotly.express as px

In [2]:
az.style.use("arviz-darkgrid")
plt.rcParams["figure.figsize"] = [12, 7]
plt.rcParams["figure.dpi"] = 100

%load_ext autoreload
%autoreload 2
%config InlineBackend.figure_format = "retina"

In [3]:
from media_transforms import adstock, hill, calculate_seasonality
import jax.numpy as jnp
## Adstock as defined by us to compare later
# from statsmodels.tsa.filters.filtertools import recursive_filter
# def my_adstock(x, theta):
#     return recursive_filter(x, theta)

### Simulate data

In [4]:
# Reproducible RNG: the integer seed is derived from the characters of "mmm".
seed: int = sum(map(ord, "mmm"))
rng: np.random.Generator = np.random.default_rng(seed=seed)

# Weekly date range: every Monday ("W-MON") between the two bounds.
min_date = pd.to_datetime("2019-01-01")
max_date = pd.to_datetime("2022-12-31")

df = pd.DataFrame(
    data={"date": pd.date_range(start=min_date, end=max_date, freq="W-MON")}
)

n = df.shape[0]

# Media data. The draw order (x1, then x2, then the noise in the target cell) fixes
# the RNG stream, so reordering these draws changes every simulated value.
# x1: uniform draws above 0.7 are scaled by 5e3, the rest by 5e3/3 (occasional bursts).
x1 = rng.uniform(low=0.0, high=1.0, size=n)
df["x1"] = np.where(x1 > 0.7, 5e3*x1, 5e3*x1/3)

# x2: zero unless the draw exceeds 0.9, giving a sparse channel.
x2 = rng.uniform(low=0.0, high=1.0, size=n)
df["x2"] = np.where(x2 > 0.9, 2.5e3*x2, 0)

# One labelled line chart per raw media channel, so each figure stands alone.
fig_1 = px.line(df, x='date', y='x1', title='Raw media: x1')
fig_1.update_xaxes(title_text='Date Week')

fig_2 = px.line(df, x='date', y='x2', title='Raw media: x2')
fig_2.update_xaxes(title_text='Date Week')

# Show the plots
fig_1.show()
fig_2.show()

### Adstock

In [5]:
## Adstocking media
# Ground-truth lag weight per channel; the adstocked series are stored next to the raw ones.
# NOTE(review): in the saved output, x1_adstock[0] == x1[0] * (1 - lag_weight_x1),
# so `adstock` appears to normalise its output - confirm in media_transforms.
lag_weight_x1 = 0.15
lag_weight_x2 = 0.55
df['x1_adstock'] = adstock(jnp.array(df['x1']), lag_weight_x1)
df['x2_adstock'] = adstock(jnp.array(df['x2']), lag_weight_x2)

# One labelled line chart per adstocked channel; the title records the lag weight used.
fig_1 = px.line(
    df, x='date', y='x1_adstock',
    title=f'Adstocked media: x1 (lag weight {lag_weight_x1})'
)
fig_1.update_xaxes(title_text='Date Week')

fig_2 = px.line(
    df, x='date', y='x2_adstock',
    title=f'Adstocked media: x2 (lag weight {lag_weight_x2})'
)
fig_2.update_xaxes(title_text='Date Week')

# Show the plots
fig_1.show()
fig_2.show()

### Hill

In [6]:
# Inspect the median of x2's adstocked series; the next cell uses this median as
# x2's half-max effective concentration.
df['x2_adstock'].median()

45.433292

In [7]:
# Saturation with hill
# Ground-truth Hill parameters: a slope per channel, and a half-max effective
# concentration set to the median of that channel's adstocked series.
slope_x1 = 2
slope_x2 = 0.85
half_max_effective_concentration_x1 = df['x1_adstock'].median()
half_max_effective_concentration_x2 = df['x2_adstock'].median()
df['x1_hill_adstock'] = hill(
    data=jnp.array(df['x1_adstock']),
    half_max_effective_concentration=half_max_effective_concentration_x1,
    slope=slope_x1
)
df['x2_hill_adstock'] = hill(
    data=jnp.array(df['x2_adstock']),
    half_max_effective_concentration=half_max_effective_concentration_x2,
    slope=slope_x2
)

# One labelled line chart per saturated channel; the title records the slope used.
fig_1 = px.line(
    df, x='date', y='x1_hill_adstock',
    title=f'Hill-saturated adstock: x1 (slope {slope_x1})'
)
fig_1.update_xaxes(title_text='Date Week')

fig_2 = px.line(
    df, x='date', y='x2_hill_adstock',
    title=f'Hill-saturated adstock: x2 (slope {slope_x2})'
)
fig_2.update_xaxes(title_text='Date Week')

# Show the plots
fig_1.show()
fig_2.show()

### Time series components

In [8]:
## Adding some trend and seasonality
# Concave (square-root) trend over the n weekly periods.
df["trend"] = (np.linspace(start=0.0, stop=100, num=n) + 10) ** (1/2) - 1
# Seasonality length must equal the number of rows. Use n (from the date range)
# instead of a hard-coded 208, so changing min_date/max_date keeps this aligned.
df["seasonality"] = calculate_seasonality(
    number_periods=n,
    degrees=2,
    gamma_seasonality=0.5
)
fig = px.line(df, x='date', y=['trend', 'seasonality'], title='Trend and Seasonality')
fig.update_xaxes(title_text='Date Week')
fig.update_yaxes(title_text='Value')

### Indicator features

In [9]:
## Adding some indicators: 1.0 in the single week matching each event date, else 0.0
df["event_1"] = df["date"].eq("2019-05-13").astype(float)
df["event_2"] = df["date"].eq("2020-09-14").astype(float)
df.head()

Unnamed: 0,date,x1,x2,x1_adstock,x2_adstock,x1_hill_adstock,x2_hill_adstock,trend,seasonality,event_1,event_2
0,2019-01-07,1061.933393,0.0,902.643372,0.0,0.451056,0.0,2.162278,1.0,0.0,0.0
1,2019-01-14,374.628255,0.0,453.830566,0.0,0.171987,0.0,2.23776,1.161752,0.0,0.0
2,2019-01-21,974.667554,0.0,896.541992,0.0,0.4477,0.0,2.311523,1.280218,0.0,0.0
3,2019-01-28,237.995175,0.0,336.777222,0.0,0.102641,0.0,2.383678,1.350627,0.0,0.0
4,2019-02-04,3867.451536,0.0,3337.850342,0.0,0.918273,0.0,2.454326,1.370614,0.0,0.0


### Target

In [10]:
## Creating target
# Constant intercept plus Gaussian noise (sd 0.25). This is the only RNG draw in
# the cell, so it continues the stream started by the media simulation.
df["intercept"] = 5e2
df["epsilon"] = rng.normal(loc=0.0, scale=0.25, size=n)

# Ground-truth coefficients used to generate y.
beta_x1, beta_x2 = 87.0, 55.0
beta_trend = 10
beta_seasonality = 4.0
beta_event1, beta_event2 = 1.5, 2.5
betas = [beta_x1, beta_x2, beta_trend, beta_seasonality, beta_event1, beta_event2]

# y is the sum of every component, added left to right in a fixed order.
df["y"] = (
    df["intercept"]
    + df["trend"] * beta_trend
    + df["seasonality"] * beta_seasonality
    + df["event_1"] * beta_event1
    + df["event_2"] * beta_event2
    + df["x1_hill_adstock"] * beta_x1
    + df["x2_hill_adstock"] * beta_x2
    + df["epsilon"]
)

fig = (
    px.line(df, x='date', y='y', title='Target')
    .update_xaxes(title_text='Date Week')
    .update_yaxes(title_text='')
)
fig

### Contributions

In [11]:
# Fraction of the summed target that comes from the two media terms.
contribution_share_media = (
    beta_x1 * df["x1_hill_adstock"] + beta_x2 * df["x2_hill_adstock"]
).sum() / df["y"].sum()
contribution_share_media

0.11447574565410344

### ICP

In [12]:
# ICP per channel: beta * (sum of saturated adstock) / (sum of raw media).
icp_x1 = beta_x1 * df["x1_hill_adstock"].sum() / df["x1"].sum()
icp_x2 = beta_x2 * df["x2_hill_adstock"].sum() / df["x2"].sum()
print(f"ICP of x1: {icp_x1:.2f}\nICP of x2: {icp_x2:.2f}")

ICP of x1: 0.03
ICP of x2: 0.10


### Modeling

In [13]:
import utils, preprocessing, lightweight_mmm
from numpyro import distributions as dist


IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html



In [14]:
# Re-display the simulated frame, now including the intercept, noise and target columns.
df.head()

Unnamed: 0,date,x1,x2,x1_adstock,x2_adstock,x1_hill_adstock,x2_hill_adstock,trend,seasonality,event_1,event_2,intercept,epsilon,y
0,2019-01-07,1061.933393,0.0,902.643372,0.0,0.451056,0.0,2.162278,1.0,0.0,0.0,500.0,-0.193179,564.67148
1,2019-01-14,374.628255,0.0,453.830566,0.0,0.171987,0.0,2.23776,1.161752,0.0,0.0,500.0,-0.078831,541.908626
2,2019-01-21,974.667554,0.0,896.541992,0.0,0.4477,0.0,2.311523,1.280218,0.0,0.0,500.0,-0.319946,566.866027
3,2019-01-28,237.995175,0.0,336.777222,0.0,0.102641,0.0,2.383678,1.350627,0.0,0.0,500.0,-0.410632,537.75844
4,2019-02-04,3867.451536,0.0,3337.850342,0.0,0.918273,0.0,2.454326,1.370614,0.0,0.0,500.0,0.490015,610.405453


In [34]:
# Inputs handed to the model: date, raw media and event dummies, plus the target.
# Trend and seasonality are not passed as features (the model summary later reports
# its own coef_trend / gamma_seasonality parameters).
modeling_features = ["date", "x1", "x2", "event_1", "event_2"]

modeling_data = df[modeling_features + ["y"]].copy()
modeling_data

Unnamed: 0,date,x1,x2,event_1,event_2,y
0,2019-01-07,1061.933393,0.0,0.0,0.0,564.671480
1,2019-01-14,374.628255,0.0,0.0,0.0,541.908626
2,2019-01-21,974.667554,0.0,0.0,0.0,566.866027
3,2019-01-28,237.995175,0.0,0.0,0.0,537.758440
4,2019-02-04,3867.451536,0.0,0.0,0.0,610.405453
...,...,...,...,...,...,...
203,2022-11-28,459.306385,0.0,0.0,0.0,611.674631
204,2022-12-05,198.142316,0.0,0.0,0.0,600.002743
205,2022-12-12,1133.755185,0.0,0.0,0.0,638.978740
206,2022-12-19,382.580521,0.0,0.0,0.0,613.420155


In [35]:
# Chronological split: the final data_size // 10 rows are held out for testing.
data_size = len(modeling_data)
split_point = data_size - data_size // 10
train_data = modeling_data.iloc[:split_point]
train_data

Unnamed: 0,date,x1,x2,event_1,event_2,y
0,2019-01-07,1061.933393,0.0,0.0,0.0,564.671480
1,2019-01-14,374.628255,0.0,0.0,0.0,541.908626
2,2019-01-21,974.667554,0.0,0.0,0.0,566.866027
3,2019-01-28,237.995175,0.0,0.0,0.0,537.758440
4,2019-02-04,3867.451536,0.0,0.0,0.0,610.405453
...,...,...,...,...,...,...
183,2022-07-11,746.595342,0.0,0.0,0.0,657.713404
184,2022-07-18,313.315632,0.0,0.0,0.0,628.343828
185,2022-07-25,1089.652221,0.0,0.0,0.0,653.589232
186,2022-08-01,4398.909176,0.0,0.0,0.0,686.216751


In [36]:
# Held-out tail: every row from split_point onwards (positional slice).
test_data = modeling_data.iloc[split_point:]
test_data

Unnamed: 0,date,x1,x2,event_1,event_2,y
188,2022-08-15,3625.026945,0.0,0.0,0.0,675.877566
189,2022-08-22,1133.848609,0.0,0.0,0.0,652.324897
190,2022-08-29,69.960241,0.0,0.0,0.0,597.615322
191,2022-09-05,3906.414407,0.0,0.0,0.0,670.808845
192,2022-09-12,4820.934399,0.0,0.0,0.0,672.727223
193,2022-09-19,1054.267383,0.0,0.0,0.0,651.678111
194,2022-09-26,563.110601,0.0,0.0,0.0,618.32665
195,2022-10-03,364.359686,0.0,0.0,0.0,601.576803
196,2022-10-10,273.160438,0.0,0.0,0.0,594.998672
197,2022-10-17,16.114747,0.0,0.0,0.0,588.74746


In [111]:
# Convert the training frame into the arrays expected by LightweightMMM.
train_media_data, train_extra_features, train_target, train_costs = utils.dataframe_to_jax(
    dataframe=train_data,
    media_features=['x1', 'x2'],
    extra_features=['event_1', 'event_2'],
    date_feature="date",
    target="y"
)

# Keep the raw costs under their own name before any optional scaling, so the later
# cell that displays `unscaled_costs` works on a fresh kernel.
unscaled_costs = train_costs

# Optional mean-scaling of all inputs. Toggle with this flag instead of commenting
# code in and out, so the executed path is explicit. False matches the original
# behavior, where these lines were commented out.
SCALE_DATA = False
if SCALE_DATA:
    media_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    extra_features_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    cost_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)

    train_media_data = media_scaler.fit_transform(train_media_data)
    train_extra_features = extra_features_scaler.fit_transform(train_extra_features)
    train_target = target_scaler.fit_transform(train_target)
    train_costs = cost_scaler.fit_transform(train_costs)

In [109]:
# Inspect the costs that are passed as `media_prior` to `mmm.fit`.
# NOTE(review): this saved output differs from a later display of `train_costs`
# in this notebook, which indicates out-of-order execution; re-run top to bottom.
train_costs

Array([1.7376053, 0.2623946], dtype=float32)

In [129]:
# NOTE(review): make sure an earlier cell binds `unscaled_costs` (e.g. a copy of the
# raw `train_costs` taken before any scaling); otherwise this cell raises NameError
# on a fresh kernel.
unscaled_costs

Array([346472.38,  52320.56], dtype=float32)

In [133]:
# Display the costs passed as `media_prior` to `mmm.fit` in the next cell.
train_costs

Array([346472.38,  52320.56], dtype=float32)

In [79]:
mmm = lightweight_mmm.LightweightMMM(
    model_name="hill_adstock"
)

# MCMC settings. On a single-device host numpyro draws the chains sequentially
# (see the warning in this cell's output).
SEED = 105
number_warmup = 1000
number_samples = 1000
number_chains = 3

# Priors for the hill_adstock transform parameters. The simulated ground truth
# (slopes 2 / 0.85, lag weights 0.15 / 0.55, x2 half-max ~45) lies inside these supports.
# NOTE(review): the half-max range is on the scale of the *unscaled* adstocked media;
# revisit it if mean-scaling of the inputs is enabled.
custom_priors = {
    # Alternatives tried:
    # "half_max_effective_concentration": dist.Normal(loc=0.5, scale=0.4),
    # "half_max_effective_concentration": dist.TruncatedNormal(loc=0.5, low=0.2, high=0.9),
    "half_max_effective_concentration": dist.Uniform(low=10, high=1000),
    'slope': dist.Uniform(low=0.5, high=4.0),
    'lag_weight': dist.Uniform(low=0, high=0.98)
}

# Channel labels for the model, presumably used in its summaries and plots.
# They match the DataFrame columns used as media features ('x1', 'x2').
media_names = ['x1', 'x2']

mmm.fit(
    media=train_media_data,
    media_prior=train_costs,
    target=train_target,
    extra_features=train_extra_features,
    number_warmup=number_warmup,
    number_samples=number_samples,
    number_chains=number_chains,
    weekday_seasonality=False,
    seasonality_frequency=52,
    media_names=media_names,
    custom_priors=custom_priors,
    seed=SEED)

# NOTE(review): relies on the private `_mcmc` attribute; prefer a public accessor
# if the local lightweight_mmm module exposes one.
mmm_params = mmm._mcmc.get_samples()



There are not enough devices to run parallel chains: expected 3 but got 1. Chains will be drawn sequentially. If you are running MCMC in CPU, consider using `numpyro.set_host_device_count(3)` at the beginning of your program. You can double-check how many devices are available in your system using `jax.local_device_count()`.

sample: 100%|██████████| 2000/2000 [00:26<00:00, 74.26it/s, 511 steps of size 5.38e-03. acc. prob=0.95] 
warmup:  38%|███▊      | 768/2000 [00:10<00:16, 72.93it/s, 511 steps of size 6.57e-03. acc. prob=0.84]  


KeyboardInterrupt: 

In [75]:
# Posterior summary (mean, std, median, 5%/95% quantiles, n_eff, r_hat) per parameter.
# NOTE(review): this cell's execution count [75] precedes the fit cell [79], so the
# saved summary may come from an earlier fit than the one shown in this notebook.
mmm.print_summary()


                                         mean       std    median      5.0%     95.0%     n_eff     r_hat
             coef_extra_features[0]      0.00      0.00      0.00      0.00      0.00   2855.99      1.00
             coef_extra_features[1]     -0.00      0.00     -0.00     -0.00      0.00   3305.75      1.00
                      coef_media[0]      0.59      0.02      0.59      0.56      0.62    645.37      1.00
                      coef_media[1]      0.27      0.02      0.27      0.25      0.30    620.20      1.00
                      coef_trend[0]      0.00      0.00      0.00      0.00      0.00   1205.44      1.00
                         expo_trend      0.73      0.06      0.73      0.64      0.84   1104.92      1.00
             gamma_seasonality[0,0]      0.00      0.00      0.00      0.00      0.01   3087.27      1.00
             gamma_seasonality[0,1]      0.01      0.00      0.01      0.00      0.01   3412.44      1.00
             gamma_seasonality[1,0]      0.00

In [57]:
# Leftover inspection of the ground-truth x1 Hill slope.
# TODO: remove, or compare it against the posterior slope once the fit completes.
slope_x1

2