In [None]:
import os
import yaml
import pandas as pd
from pathlib import Path
import pytensor.tensor as pt 
import warnings
import arviz as az
import plotly.io as pio
import matplotlib.pyplot as plt
import pymc as pm
from IPython.display import Markdown, display
import numpy as np
import plotly.graph_objects as go
import json
from sklearn.metrics import r2_score, mean_squared_error
from sklearn.model_selection import train_test_split
from tabulate import tabulate
import pandas as pd
import numpy as np
from sklearn.preprocessing import MaxAbsScaler, MinMaxScaler
# Notebook-wide setup: plot renderer, pandas display options, seed constant,
# warning policy and working directory.
pio.renderers.default = "plotly_mimetype+notebook+vscode"

# Show every column when a DataFrame is displayed.
pd.set_option('display.max_columns', None)
# NOTE(review): the model cells further down seed their generators with 42,
# not with this constant — confirm which seed is intended.
RANDOM_SEED = 69

# NOTE(review): this silences ALL warnings for the rest of the session,
# including pandas and sampler warnings.
warnings.filterwarnings('ignore')
# `base_dir` is captured only the first time this cell runs in a kernel, so
# re-running the cell after the chdir below does not climb two more levels.
if 'base_dir' not in globals():
    base_dir = Path.cwd()
# The project root is taken to be two levels above the starting directory.
main_dir = base_dir.parent.parent
if Path.cwd() != main_dir:
    os.chdir(main_dir)

# Runs after the chdir above; presumably `src` is resolved relative to the
# project root — TODO confirm.
# NOTE(review): the star import is presumably where helpers such as plot_data,
# plot_correlation_matrix and geometric_adstock come from.
from src.helpers.mediator_bayesian_helpers import *

# Master Variables

In [None]:
def load_config(config_path='notebooks/modeling/config_files/14022025_SPLIT_EMEA_MEDIATOR_DAV.yml'):
    """Load the modelling configuration from a YAML file.

    Parameters
    ----------
    config_path : str or path-like
        Location of the YAML file; a relative path is resolved against the
        current working directory.

    Returns
    -------
    The parsed document exactly as returned by ``yaml.safe_load``.

    Raises
    ------
    OSError
        If the file cannot be opened.
    """
    # Decode explicitly as UTF-8 so the parsed config does not depend on the
    # platform's locale encoding.
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)

# Read the run configuration once and unpack the settings used by later cells.
config = load_config()
variables_cfg = config['variables']
data_cfg = config['data']

# Column names.
date_column = variables_cfg['date_column'][0]  # first entry of the configured list
dep_var = variables_cfg['dep_var']
mediator_var = variables_cfg['mediator_var']

# Modelling window.
start_date = data_cfg['start_date']
end_date = data_cfg['end_date']

# Switches for the trend / seasonality terms of each sub-model.
dependent_variable_trend = data_cfg['dependent_variable_trend']
dependent_variable_seasonality = data_cfg['dependent_variable_seasonality']
mediator_variable_trend = data_cfg['mediator_variable_trend']
mediator_variable_seasonality = data_cfg['mediator_variable_seasonality']



# Read datamaster

In [None]:
# Load the master dataset and discard columns that carry no information.
raw_data = pd.read_excel(config['data']['file_path'])
non_empty_columns = raw_data.dropna(axis=1, how='all')     # drop columns that are entirely NaN
has_nonzero_value = (non_empty_columns != 0).any(axis=0)   # True where a column has at least one value != 0
data = non_empty_columns.loc[:, has_nonzero_value]
data.head()


In [None]:
# Overview plot of every column except the date column. `plot_data` is a
# project helper (presumably from the star import), so its exact rendering is
# not visible from this notebook.
plot_data(data, date_column=date_column, scale_data=True, plot_height=800, columns_to_plot=data.drop(columns = [date_column]).columns)

# Filter date

In [None]:
# Restrict the data to the configured modelling window.
data[date_column] = pd.to_datetime(data[date_column])
# Both bounds are exclusive: rows dated exactly start_date or end_date are dropped.
in_window = (data[date_column] > start_date) & (data[date_column] < end_date)
data = data[in_window]
# Explicit copy: later cells add columns to model_df, and those writes must go
# to an independent frame rather than to a slice of `data`.
model_df = data.copy()
model_df.head()

# Plot playback data

In [None]:
# Playback drivers next to the two targets: correlation heatmap, then time series.
playback_columns = list(model_df.filter(regex='playback').columns)
playback_vars = playback_columns + [dep_var, mediator_var]
heatmap_figure = plot_correlation_matrix(model_df, playback_vars)
plot_data(
    model_df,
    date_column=date_column,
    scale_data=True,
    plot_height=800,
    columns_to_plot=playback_vars,
)

# Marketing data

In [None]:
# Marketing drivers next to the two targets: correlation heatmap, then time series.
marketing_columns = list(model_df.filter(regex='marketing').columns)
marketing_vars = marketing_columns + [dep_var, mediator_var]
heatmap_figure = plot_correlation_matrix(model_df, marketing_vars)
plot_data(
    model_df,
    date_column=date_column,
    scale_data=True,
    plot_height=800,
    columns_to_plot=marketing_vars,
)

# Campaigns Data

In [None]:
# Campaign columns are split into chunks and each chunk gets its own
# correlation heatmap against the two targets.
campaigns_vars = model_df.filter(regex='campaign').columns.to_list()
if campaigns_vars:
    n_splits = 3
    # Chunk boundaries: n_splits + 1 evenly spaced offsets into campaigns_vars.
    split_points = [i * len(campaigns_vars) // n_splits for i in range(n_splits + 1)]
    for i in range(n_splits):
        split_vars = [dep_var] + [mediator_var] + campaigns_vars[split_points[i]:split_points[i+1]]
        heatmap_figure = plot_correlation_matrix(
            model_df,
            split_vars,
            # Named like the competitor heatmaps ("<group> - Part N"); the
            # previous title was just " Part N" with a stray leading space.
            title=f"Campaigns - Part {i+1}"
        )

    plot_data(model_df, date_column=date_column, scale_data=True, plot_height=800, columns_to_plot=campaigns_vars)

# Competitor data

In [None]:
# Competitor drivers, with every column whose name contains "snapchat"
# (case-insensitive) left out of the working frame.
is_snapchat_column = model_df.columns.str.contains('snapchat', case=False)
tmp_model_df = model_df.loc[:, ~is_snapchat_column]
comp_vars = list(tmp_model_df.filter(regex='competitors').columns)
if comp_vars:
    n_splits = 4
    # Evenly spaced chunk boundaries; consecutive pairs delimit one heatmap each.
    split_points = [(k * len(comp_vars)) // n_splits for k in range(n_splits + 1)]
    chunk_bounds = zip(split_points, split_points[1:])
    for part_number, (chunk_start, chunk_end) in enumerate(chunk_bounds, start=1):
        split_vars = [dep_var, mediator_var] + comp_vars[chunk_start:chunk_end]
        heatmap_figure = plot_correlation_matrix(
            tmp_model_df,
            split_vars,
            title=f"Competitors - Part {part_number}",
        )
    plot_data(
        tmp_model_df,
        date_column=date_column,
        scale_data=True,
        plot_height=800,
        columns_to_plot=comp_vars,
    )

# Trend

In [None]:
# Linear trend: 0, 1, 2, ... in row order, stored once per sub-model so each
# can be scaled and switched on independently.
n_periods = len(model_df)
trend_feature = pd.Series(range(n_periods), index=model_df.index)
for trend_column in ("dep_variable_trend_feature", "mediator_variable_trend_feature"):
    model_df[trend_column] = trend_feature

# Premium

In [None]:
# Premium drivers next to the two targets: correlation heatmap, then time series.
# The column names are selected from model_df, so model_df is also the frame
# that is plotted. The previous version plotted from `tmp_model_df`, a frame
# left over from the competitor cell with the snapchat columns removed, which
# made this cell depend on that cell's state and fail for any premium column
# containing "snapchat".
premium_vars = model_df.filter(regex='premium').columns.to_list() + [dep_var] + [mediator_var]
heatmap_figure = plot_correlation_matrix(model_df, premium_vars)
plot_data(model_df, date_column=date_column, scale_data=True, plot_height=800, columns_to_plot=premium_vars)

#### NaN Checks

In [None]:
# Per-column NaN counts, listing only the columns that have missing values.
nan_counts = model_df.isna().sum()
nan_counts[nan_counts != 0]

## Feature Preparation: Dependent Variable

In [None]:
# Feature groups for the dependent-variable model, as listed in the config.
original_paid_features = sorted(config['variables']['original_paid_features'])
original_organic_features = sorted(config['variables']['original_organic_features'])
original_competitor_features = sorted(config['variables']['original_competitor_features'])
original_control_features = sorted(config['variables']['original_control_features'])
# Feature names listed here are removed from every group below.
EXCLUDED_FEATURES = sorted([])

all_original_features = sorted(original_paid_features +
                               original_organic_features +
                               original_competitor_features +
                               original_control_features)

# Lookup of feature names per group; every list is sorted, so column order is
# stable between runs.
FEATURES = {
    "features_all_possible": all_original_features,
    "features_included": sorted(set(all_original_features) - set(EXCLUDED_FEATURES)),
    "features_excluded": EXCLUDED_FEATURES,
    "features_paid": sorted(set(original_paid_features) - set(EXCLUDED_FEATURES)),
    "features_organic": sorted(set(original_organic_features) - set(EXCLUDED_FEATURES)),
    "features_competitor": sorted(set(original_competitor_features) - set(EXCLUDED_FEATURES)),
    "features_control": sorted(set(original_control_features) - set(EXCLUDED_FEATURES)),
}

# Target and design matrix for the dependent variable (independent copies).
y = model_df[dep_var].copy()
X = model_df[FEATURES["features_included"] + [date_column, "dep_variable_trend_feature"]].copy()

# Index X by the configured date column. This was hard-coded as
# `model_df.date`, which only works when the date column is literally "date".
X.index = model_df[date_column].copy()

## Feature Preparation: Mediator

In [None]:
# Feature groups for the mediator model, as listed in the config.
mediator_variables_cfg = config['variables']
mediator_paid_features = sorted(mediator_variables_cfg['mediator_original_paid_features'])
mediator_organic_features = sorted(mediator_variables_cfg['mediator_original_organic_features'])
mediator_competitor_features = sorted(mediator_variables_cfg['mediator_original_competitor_features'])
mediator_control_features = sorted(mediator_variables_cfg['mediator_original_control_features'])
EXCLUDED_MEDIATOR_FEATURES = sorted([])  # Update if any mediator features need to be excluded

all_mediator_features = sorted(
    mediator_paid_features
    + mediator_organic_features
    + mediator_competitor_features
    + mediator_control_features
)

# Build the lookup in two steps: the three summary entries first, then one
# sorted, exclusion-filtered list per feature group.
excluded_mediator_set = set(EXCLUDED_MEDIATOR_FEATURES)
MEDIATOR_FEATURES = {
    "mediator_features_all_possible": all_mediator_features,
    "mediator_features_included": sorted(set(all_mediator_features) - excluded_mediator_set),
    "mediator_features_excluded": EXCLUDED_MEDIATOR_FEATURES,
}
for group_key, group_features in (
    ("mediator_paid", mediator_paid_features),
    ("mediator_organic", mediator_organic_features),
    ("mediator_competitor", mediator_competitor_features),
    ("mediator_control", mediator_control_features),
):
    MEDIATOR_FEATURES[group_key] = sorted(set(group_features) - excluded_mediator_set)

# Target and design matrix for the mediator, indexed by date.
y_mediator = model_df[mediator_var].copy()
features_for_mediator = MEDIATOR_FEATURES["mediator_features_included"] + [
    date_column,
    "mediator_variable_trend_feature",
]
X_mediator = model_df[features_for_mediator].copy()
X_mediator.index = model_df[date_column].copy()

# Scaling Dep Variable And Mediator

In [None]:
# Scale every feature group and both targets.
#
# All scalers are fit on the RAW values in model_df (not on X / y themselves),
# so this cell can be re-run safely: the fitted scalers always describe the
# original data and inverse_transform stays valid.
#
# NOTE(review): scalers are fit on the full window, i.e. before the train/test
# split further down, so hold-out rows influence the scaling — confirm this is
# acceptable.
SCALER_MAPPING = {
    "MinMaxScaler": MinMaxScaler,
    "MaxAbsScaler": MaxAbsScaler
}
scaling_config = config.get("scaling", {})


def _make_scaler(group):
    """Return a new scaler instance for ``group`` as configured under ``scaling``.

    Falls back to ``MinMaxScaler`` when the group is not configured or names a
    scaler type that is not in ``SCALER_MAPPING``.
    """
    scaler_name = scaling_config.get(group, {}).get("type")
    return SCALER_MAPPING.get(scaler_name, MinMaxScaler)()


def _fit_scaled_columns(target_frame, columns, scaler):
    """Fit ``scaler`` on the raw ``model_df[columns]`` and store the scaled values in ``target_frame``.

    Does nothing for an empty column list, because scikit-learn scalers reject
    input with zero features.
    """
    columns = list(columns)
    if not columns:
        return
    target_frame[columns] = scaler.fit_transform(model_df[columns])


# ---- Dependent-variable model ----
paid_scaler = _make_scaler("paid")
organic_scaler = _make_scaler("organic")
competitor_scaler = _make_scaler("competitor")
target_scaler = _make_scaler("target")
trend_scaler = MinMaxScaler()

_fit_scaled_columns(X, FEATURES["features_paid"], paid_scaler)
_fit_scaled_columns(X, FEATURES["features_organic"], organic_scaler)
_fit_scaled_columns(X, FEATURES["features_competitor"], competitor_scaler)
# NOTE(review): FEATURES["features_control"] is left unscaled here, whereas the
# mediator's control features are scaled below — confirm this is intended.

X["dep_variable_trend_feature"] = trend_scaler.fit_transform(model_df[["dep_variable_trend_feature"]])

y = target_scaler.fit_transform(model_df[dep_var].to_numpy().reshape(-1, 1)).flatten()

# ---- Mediator model ----
mediator_scaler = _make_scaler("mediator")
mediator_paid_scaler = _make_scaler("mediator_paid")
mediator_organic_scaler = _make_scaler("mediator_organic")
mediator_competitor_scaler = _make_scaler("mediator_competitor")
mediator_control_scaler = _make_scaler("mediator_control")
mediator_trend_scaler = MinMaxScaler()

_fit_scaled_columns(X_mediator, MEDIATOR_FEATURES["mediator_paid"], mediator_paid_scaler)
_fit_scaled_columns(X_mediator, MEDIATOR_FEATURES["mediator_organic"], mediator_organic_scaler)
_fit_scaled_columns(X_mediator, MEDIATOR_FEATURES["mediator_competitor"], mediator_competitor_scaler)
_fit_scaled_columns(X_mediator, MEDIATOR_FEATURES["mediator_control"], mediator_control_scaler)
X_mediator["mediator_variable_trend_feature"] = mediator_trend_scaler.fit_transform(
    model_df[["mediator_variable_trend_feature"]]
)

y_mediator = mediator_scaler.fit_transform(model_df[mediator_var].to_numpy().reshape(-1, 1)).flatten()

# Seasonality

In [None]:
# Yearly seasonality as Fourier terms: one sin/cos pair per order, appended to X.
if not pd.api.types.is_datetime64_any_dtype(X.index):
    X.index = pd.to_datetime(X.index)
seasonality_n_order = config['data'].get('seasonality_n_order', 0)
SEASONALITY_CONFIG = {"seasonality_n_order": seasonality_n_order}
# Position within the year as a fraction (day of year / 365.25).
periods = X.index.dayofyear / 365.25

if seasonality_n_order > 0:
    fourier_features = pd.DataFrame(
        {
            f"{func}_order_{order}": getattr(np, func)(2 * np.pi * periods * order)
            for order in range(1, seasonality_n_order + 1)
            for func in ("sin", "cos")
        },
        index=X.index
    )
    SEASONALITY_CONFIG["seasonality_features"] = list(fourier_features.columns)
    # Drop Fourier columns left in X by an earlier run of this cell before
    # appending, so re-running does not create duplicate column names.
    X = pd.concat(
        [X.drop(columns=fourier_features.columns, errors="ignore"), fourier_features],
        axis=1,
    )
else:
    print("Seasonality order is 0. Skipping Fourier feature generation.")

## Train test split

In [None]:
# Chronological hold-out split: shuffle=False keeps the row order, so the test
# set is the last `test_size` share of the window. The same settings are used
# for both sub-models so their train/test rows line up.
split_settings = dict(test_size=config['data']['test_size'], shuffle=False)

# Main model: features, scaled target and the raw (unscaled) target.
(
    X_train,
    X_test,
    y_train,
    y_test,
    unscaled_y_train,
    unscaled_y_test,
) = train_test_split(X, y, model_df[dep_var].to_numpy(), **split_settings)

# Mediator model: same three-way split.
(
    X_mediator_train,
    X_mediator_test,
    y_mediator_train,
    y_mediator_test,
    unscaled_y_mediator_train,
    unscaled_y_mediator_test,
) = train_test_split(X_mediator, y_mediator, model_df[mediator_var].to_numpy(), **split_settings)

In [None]:
# Plot the raw series behind the model inputs together with both targets.
# X can contain derived columns (the Fourier terms) that do not exist in
# model_df, so only request columns that model_df actually has.
model_input_columns = [
    column
    for column in X.columns
    if column != date_column and column in model_df.columns
]
plot_data(
    model_df,
    date_column=date_column,
    scale_data=True,
    plot_height=800,
    columns_to_plot=model_input_columns + [dep_var] + [mediator_var],
)

## Model
### Parameters

In [None]:
# Random generator passed as `random_seed` to the PyMC sampling calls below.
# NOTE(review): seeded with 42, while the setup cell defines RANDOM_SEED = 69 —
# confirm which seed is intended.
rng = np.random.default_rng(42)
# Passed as `l_max` to every geometric_adstock call in the model; presumably
# the maximum carry-over lag in periods — TODO confirm against the helper.
l_max=13

## Model Building

In [None]:
# Joint model: a mediator regression (`mediator_obs`) and a main regression
# (`y_obs`) whose mean includes beta_mediator_effect * mu_mediator, i.e. the
# mediator's modelled mean feeds into the dependent variable.
coords = {
    "paid": FEATURES['features_paid'],
    "organic": FEATURES['features_organic'],
    "competitor": FEATURES['features_competitor'],
    "control": FEATURES['features_control'],
    "mediator_paid": MEDIATOR_FEATURES['mediator_paid'],
    "mediator_organic": MEDIATOR_FEATURES['mediator_organic'],
    "mediator_competitor": MEDIATOR_FEATURES['mediator_competitor'],
    "mediator_control": MEDIATOR_FEATURES['mediator_control'],
}

# Both sub-models read the same Fourier design matrix, so the matrix and its
# coordinate must exist whenever EITHER seasonality flag is on. Previously they
# were created only for the dependent variable, which broke a mediator-only
# seasonality configuration.
use_seasonality = bool(dependent_variable_seasonality or mediator_variable_seasonality)
if use_seasonality:
    fourier_columns = SEASONALITY_CONFIG.get("seasonality_features")
    if not fourier_columns:
        raise ValueError(
            "Seasonality is enabled in the config but no Fourier features were "
            "generated; set data.seasonality_n_order to a value greater than 0."
        )
    coords["fourier_mode"] = np.arange(len(fourier_columns))
if dependent_variable_trend:
    # NOTE(review): no variable in this cell uses this coordinate; it is kept
    # only so the model's coords stay as before.
    coords["dep_variable_trend_feature"] = X_train["dep_variable_trend_feature"]


with pm.Model(coords=coords) as mmm:
    # The date dimension is resized when the hold-out data is swapped in.
    mmm.add_coord(date_column, X_train.index, mutable=True)

    # ---- Data containers: main model ----
    paid_data = pm.MutableData("paid_data",
                               value=X_train[FEATURES['features_paid']].to_numpy(),
                               dims=(date_column, "paid"))

    organic_data = pm.MutableData("organic_data",
                                  value=X_train[FEATURES['features_organic']].to_numpy(),
                                  dims=(date_column, "organic"))

    competitor_data = pm.MutableData("competitor_data",
                                     value=X_train[FEATURES['features_competitor']].to_numpy(),
                                     dims=(date_column, "competitor"))

    control_data = pm.MutableData("control_data",
                                  value=X_train[FEATURES['features_control']].to_numpy(),
                                  dims=(date_column, "control"))

    if dependent_variable_trend:
        trend_data = pm.MutableData("trend_data",
                                    value=X_train["dep_variable_trend_feature"].to_numpy(dtype=np.float64),
                                    dims=date_column)

    if use_seasonality:
        seasonality_data = pm.MutableData("seasonality_data",
                                          value=X_train[fourier_columns].to_numpy(),
                                          dims=(date_column, "fourier_mode"))

    y_obs_data = pm.MutableData("y_obs_data",
                                value=y_train,
                                dims=date_column)

    # ---- Data containers: mediator model ----
    mediator_paid_data = pm.MutableData("mediator_paid_data",
                                        value=X_mediator_train[MEDIATOR_FEATURES['mediator_paid']].to_numpy(),
                                        dims=(date_column, "mediator_paid"))

    mediator_organic_data = pm.MutableData("mediator_organic_data",
                                           value=X_mediator_train[MEDIATOR_FEATURES['mediator_organic']].to_numpy(),
                                           dims=(date_column, "mediator_organic"))

    mediator_competitor_data = pm.MutableData("mediator_competitor_data",
                                              value=X_mediator_train[MEDIATOR_FEATURES['mediator_competitor']].to_numpy(),
                                              dims=(date_column, "mediator_competitor"))

    mediator_control_data = pm.MutableData("mediator_control_data",
                                           value=X_mediator_train[MEDIATOR_FEATURES['mediator_control']].to_numpy(),
                                           dims=(date_column, "mediator_control"))

    if mediator_variable_trend:
        mediator_trend_data = pm.MutableData("mediator_trend_data",
                                             value=X_mediator_train["mediator_variable_trend_feature"].to_numpy(dtype=np.float64),
                                             dims=date_column)

    mediator_obs_data = pm.MutableData("mediator_obs_data",
                                       value=y_mediator_train,
                                       dims=date_column)

    ####### MAIN MODEL PARAMETERS
    intercept = pm.Normal("intercept", mu=0, sigma=config['priors']['intercept_sigma'])

    beta_paid_coeffs = pm.Normal("beta_paid_coeffs",
                                 mu=0,
                                 sigma=config['priors']['beta_paid_coeffs_sigma'],
                                 dims="paid")

    # Adstock parameter per paid feature, constrained to (0, 1) by the Beta prior.
    # 'alpha_alpa' is the key as spelled in the config file.
    alpha_paid = pm.Beta("alpha_paid",
                         alpha=config['priors']['alpha_alpa'],
                         beta=config['priors']['alpha_beta'],
                         dims="paid")

    beta_organic_coeffs = pm.Normal("beta_organic_coeffs",
                                    mu=0,
                                    sigma=config['priors']['beta_organic_coeffs_sigma'],
                                    dims="organic")

    alpha_organic = pm.Beta("alpha_organic",
                            alpha=config['priors']['alpha_alpa'],
                            beta=config['priors']['alpha_beta'],
                            dims="organic")

    # Competitor coefficients are forced to be <= 0: minus a non-negative scale
    # times the absolute value of a standard-normal offset.
    sigma_beta_competitor = pm.HalfNormal("sigma_beta_competitor",
                                          sigma=config['priors']['beta_competitor_sigma'],
                                          dims="competitor")

    non_neg_beta_competitor_offset = pm.Normal("non_neg_beta_competitor_offset",
                                               mu=0,
                                               sigma=1,
                                               dims="competitor")

    beta_competitor_coeffs = pm.Deterministic("beta_competitor_coeffs",
                                              -sigma_beta_competitor * pt.abs(non_neg_beta_competitor_offset),
                                              dims="competitor")

    beta_control_coeffs = pm.Normal("beta_control_coeffs",
                                    mu=0,
                                    sigma=config['priors']['beta_control_coeffs_sigma'],
                                    dims="control")

    if dependent_variable_trend:
        # HalfNormal: the trend slope cannot be negative.
        beta_trend = pm.HalfNormal("beta_trend", sigma=config['priors']['beta_trend_sigma'])

    if dependent_variable_seasonality:
        beta_fourier = pm.Laplace("beta_fourier",
                                  mu=0,
                                  b=config['priors']['beta_fourier_sigma'],
                                  dims="fourier_mode")

    sigma = pm.HalfNormal("sigma", sigma=1)

    # ---- Main model: adstock and per-feature contributions ----
    # geometric_adstock is a project helper; its carry-over formula is not
    # visible from this notebook.
    paid_adstock = pm.Deterministic("paid_adstock",
                                    geometric_adstock(x=paid_data,
                                                      alpha=alpha_paid,
                                                      l_max=l_max,
                                                      normalize=True),
                                    dims=(date_column, "paid"))

    organic_adstock = pm.Deterministic("organic_adstock",
                                       geometric_adstock(x=organic_data,
                                                         alpha=alpha_organic,
                                                         l_max=l_max,
                                                         normalize=True),
                                       dims=(date_column, "organic"))

    paid_contributions = pm.Deterministic("paid_contributions",
                                          paid_adstock * beta_paid_coeffs,
                                          dims=(date_column, "paid"))

    organic_contributions = pm.Deterministic("organic_contributions",
                                             organic_adstock * beta_organic_coeffs,
                                             dims=(date_column, "organic"))

    # Competitor and control features enter without adstock.
    competitor_contributions = pm.Deterministic("competitor_contributions",
                                                competitor_data * beta_competitor_coeffs,
                                                dims=(date_column, "competitor"))

    control_contributions = pm.Deterministic("control_contributions",
                                             control_data * beta_control_coeffs,
                                             dims=(date_column, "control"))

    if dependent_variable_trend:
        trend = pm.Deterministic("trend",
                                 beta_trend * trend_data,
                                 dims=date_column)

    if dependent_variable_seasonality:
        seasonality_effect = pm.Deterministic("seasonality",
                                              pt.dot(seasonality_data, beta_fourier),
                                              dims=date_column)

    ###### MEDIATOR MODEL PARAMETERS
    mediator_intercept = pm.Normal("mediator_intercept", mu=0, sigma=config['priors']['mediator_intercept_sigma'])

    mediator_beta_paid_coeffs = pm.Normal("mediator_beta_paid_coeffs",
                                          mu=0,
                                          sigma=config['priors']['beta_mediator_paid_coeffs_sigma'],
                                          dims="mediator_paid")

    mediator_alpha_paid = pm.Beta("mediator_alpha_paid",
                                  alpha=config['priors']['alpha_alpa'],
                                  beta=config['priors']['alpha_beta'],
                                  dims="mediator_paid")

    mediator_beta_organic_coeffs = pm.Normal("mediator_beta_organic_coeffs",
                                             mu=0,
                                             sigma=config['priors']['beta_mediator_organic_coeffs_sigma'],
                                             dims="mediator_organic")

    mediator_alpha_organic = pm.Beta("mediator_alpha_organic",
                                     alpha=config['priors']['alpha_alpa'],
                                     beta=config['priors']['alpha_beta'],
                                     dims="mediator_organic")

    # Same sign constraint as in the main model: competitor coefficients <= 0.
    mediator_sigma_beta_competitor = pm.HalfNormal("mediator_sigma_beta_competitor",
                                                   sigma=config['priors']['beta_mediator_competitor_sigma'],
                                                   dims="mediator_competitor")

    mediator_non_neg_beta_competitor_offset = pm.Normal("mediator_non_neg_beta_competitor_offset",
                                                        mu=0,
                                                        sigma=1,
                                                        dims="mediator_competitor")

    mediator_beta_competitor_coeffs = pm.Deterministic("mediator_beta_competitor_coeffs",
                                                       -mediator_sigma_beta_competitor * pt.abs(mediator_non_neg_beta_competitor_offset),
                                                       dims="mediator_competitor")

    mediator_beta_control_coeffs = pm.Normal("mediator_beta_control_coeffs",
                                             mu=0,
                                             sigma=config['priors']['beta_mediator_control_coeffs_sigma'],
                                             dims="mediator_control")

    if mediator_variable_trend:
        mediator_beta_trend = pm.HalfNormal("mediator_beta_trend",
                                            sigma=config['priors']['beta_mediator_trend_sigma'])

    if mediator_variable_seasonality:
        mediator_beta_fourier = pm.Laplace("mediator_beta_fourier",
                                           mu=0,
                                           b=config['priors']['beta_mediator_fourier_sigma'],
                                           dims="fourier_mode")

    mediator_sigma = pm.HalfNormal("mediator_sigma", sigma=1)

    # ---- Mediator model: adstock and per-feature contributions ----
    mediator_paid_adstock = pm.Deterministic("mediator_paid_adstock",
                                             geometric_adstock(x=mediator_paid_data,
                                                               alpha=mediator_alpha_paid,
                                                               l_max=l_max,
                                                               normalize=True),
                                             dims=(date_column, "mediator_paid"))

    mediator_organic_adstock = pm.Deterministic("mediator_organic_adstock",
                                                geometric_adstock(x=mediator_organic_data,
                                                                  alpha=mediator_alpha_organic,
                                                                  l_max=l_max,
                                                                  normalize=True),
                                                dims=(date_column, "mediator_organic"))

    mediator_paid_contributions = pm.Deterministic("mediator_paid_contributions",
                                                   mediator_paid_adstock * mediator_beta_paid_coeffs,
                                                   dims=(date_column, "mediator_paid"))

    mediator_organic_contributions = pm.Deterministic("mediator_organic_contributions",
                                                      mediator_organic_adstock * mediator_beta_organic_coeffs,
                                                      dims=(date_column, "mediator_organic"))

    mediator_competitor_contributions = pm.Deterministic("mediator_competitor_contributions",
                                                         mediator_competitor_data * mediator_beta_competitor_coeffs,
                                                         dims=(date_column, "mediator_competitor"))

    mediator_control_contributions = pm.Deterministic("mediator_control_contributions",
                                                      mediator_control_data * mediator_beta_control_coeffs,
                                                      dims=(date_column, "mediator_control"))

    # Mediator mean = feature contributions summed over the feature axis
    # + intercept (+ optional trend and seasonality).
    mediator_core_components = [
        mediator_paid_contributions.sum(axis=-1),
        mediator_organic_contributions.sum(axis=-1),
        mediator_competitor_contributions.sum(axis=-1),
        mediator_control_contributions.sum(axis=-1),
        mediator_intercept
    ]

    if mediator_variable_trend:
        mediator_trend = pm.Deterministic("mediator_trend",
                                          mediator_beta_trend * mediator_trend_data,
                                          dims=date_column)
        mediator_core_components.append(mediator_trend)

    if mediator_variable_seasonality:
        # Uses the same Fourier design matrix as the main model.
        mediator_seasonality_effect = pm.Deterministic("mediator_seasonality",
                                                       pt.dot(seasonality_data, mediator_beta_fourier),
                                                       dims=date_column)
        mediator_core_components.append(mediator_seasonality_effect)

    mu_mediator = pm.Deterministic("mu_mediator",
                                   sum(mediator_core_components),
                                   dims=date_column)

    mediator_obs = pm.Normal("mediator_obs",
                             mu=mu_mediator,
                             sigma=mediator_sigma,
                             observed=mediator_obs_data,
                             dims=date_column)

    # Strength of the path from the mediator's modelled mean to the dependent variable.
    beta_mediator_effect = pm.Normal("beta_mediator_effect", mu=0, sigma=1)

    # Alternative kept for reference: standardise mu_mediator before it enters
    # the main model.
    #standardized_mediator = pm.Deterministic(
    #    "standardized_mediator",
    #    (mu_mediator - pt.mean(mu_mediator)) / (pt.std(mu_mediator) + 1e-8),
    #    dims=date_column
    #)

    # Main mean = direct feature contributions + intercept + mediated effect
    # (+ optional trend and seasonality).
    mu_core_components = [
        paid_contributions.sum(axis=-1),
        organic_contributions.sum(axis=-1),
        competitor_contributions.sum(axis=-1),
        control_contributions.sum(axis=-1),
        intercept,
        beta_mediator_effect * mu_mediator
    ]

    if dependent_variable_trend:
        mu_core_components.append(trend)

    if dependent_variable_seasonality:
        mu_core_components.append(seasonality_effect)

    # FINAL MODEL: Combine Main and Mediator Effects
    mu_combined = pm.Deterministic("mu",
                                   sum(mu_core_components),
                                   dims=date_column)

    y_obs = pm.Normal("y_obs",
                      mu=mu_combined,
                      sigma=sigma,
                      observed=y_obs_data,
                      dims=date_column)

    mmm_prior_predictive = pm.sample_prior_predictive(samples=1000, random_seed=rng)

In [None]:
# Render the model's variable graph as the cell output.
pm.model_to_graphviz(mmm)


In [None]:
# Fit the joint model, then simulate the training-period observations and map
# them back to the original units.
rng = np.random.default_rng(42)
sampler_settings = dict(
    draws=1000,              # posterior draws kept per chain
    tune=500,                # tuning iterations per chain, discarded afterwards
    chains=4,                # independent Markov chains
    cores=7,                 # cores made available to the sampler
    target_accept=0.9,       # raised from the 0.8 default
    nuts_sampler="numpyro",
    random_seed=rng,
)
with mmm:
    trace = pm.sample(**sampler_settings)
    mmm_posterior_predictive = pm.sample_posterior_predictive(trace=trace, random_seed=rng)

# Both targets are read from the same posterior-predictive group.
train_extract_settings = dict(data=mmm_posterior_predictive, group="posterior_predictive")

# Dependent variable: undo the target scaling, then average along axis 1 to get
# one prediction per date.
train_posterior_predictive_likelihood = az.extract(var_names="y_obs", **train_extract_settings)
train_posterior_predictive_likelihood_inv = target_scaler.inverse_transform(
    X=train_posterior_predictive_likelihood
)
train_preds = train_posterior_predictive_likelihood_inv.mean(axis=1)

# Mediator: same steps with the mediator's own scaler.
mediator_train_posterior_predictive_likelihood = az.extract(var_names="mediator_obs", **train_extract_settings)
mediator_train_posterior_predictive_likelihood_inv = mediator_scaler.inverse_transform(
    X=mediator_train_posterior_predictive_likelihood
)
mediator_train_preds = mediator_train_posterior_predictive_likelihood_inv.mean(axis=1)


## Test data

In [None]:
# Swap the hold-out rows into the model's data containers and simulate
# predictions for the test period from the fitted posterior.
with mmm:
    # Point the date coordinate at the hold-out dates before resizing the data.
    mmm.coords[date_column] = X_test.index

    data_dict = {
        "paid_data": X_test[FEATURES['features_paid']].to_numpy(),
        "organic_data": X_test[FEATURES['features_organic']].to_numpy(),
        "competitor_data": X_test[FEATURES['features_competitor']].to_numpy(),
        "control_data": X_test[FEATURES['features_control']].to_numpy(),
        "y_obs_data": y_test,

        "mediator_paid_data": X_mediator_test[MEDIATOR_FEATURES['mediator_paid']].to_numpy(),
        "mediator_organic_data": X_mediator_test[MEDIATOR_FEATURES['mediator_organic']].to_numpy(),
        "mediator_competitor_data": X_mediator_test[MEDIATOR_FEATURES['mediator_competitor']].to_numpy(),
        "mediator_control_data": X_mediator_test[MEDIATOR_FEATURES['mediator_control']].to_numpy(),
        "mediator_obs_data": y_mediator_test
    }

    if dependent_variable_trend:
        data_dict["trend_data"] = X_test["dep_variable_trend_feature"].to_numpy(dtype=np.float64)

    # The Fourier matrix can feed the main model, the mediator, or both, so
    # refresh it whenever the model has that container. Keying this on the
    # dependent-variable flag alone would leave training rows in place when
    # only the mediator uses seasonality.
    if "seasonality_data" in mmm.named_vars:
        data_dict["seasonality_data"] = X_test[SEASONALITY_CONFIG["seasonality_features"]].to_numpy()

    if mediator_variable_trend:
        data_dict["mediator_trend_data"] = X_mediator_test["mediator_variable_trend_feature"].to_numpy(dtype=np.float64)

    pm.set_data(data_dict, coords=mmm.coords)

    test_posterior_predictive = pm.sample_posterior_predictive(
        trace, var_names=["y_obs", "mediator_obs"], random_seed=rng)

    main_posterior_predictive = az.extract(
        data=test_posterior_predictive,
        group="posterior_predictive",
        var_names="y_obs")

    mediator_posterior_predictive = az.extract(
        data=test_posterior_predictive,
        group="posterior_predictive",
        var_names="mediator_obs")

    # Back to original units, then average along axis 1 for one prediction per date.
    main_posterior_predictive_inv = target_scaler.inverse_transform(X=main_posterior_predictive)
    mediator_posterior_predictive_inv = mediator_scaler.inverse_transform(X=mediator_posterior_predictive)

    y_pred = main_posterior_predictive_inv.mean(axis=1)
    mediator_y_pred = mediator_posterior_predictive_inv.mean(axis=1)


## Analysis

In [None]:

# Evaluate the main model and the mediator model with the project helper.
# Each model is described by observed/predicted pairs for the train split
# and for the test split; y_pred / mediator_y_pred are the test-split
# posterior-predictive means computed in the previous section.
evaluation_results = evaluate_mmm_with_mediator(
    # Main model: train split (observed, predicted)
    unscaled_y_train=unscaled_y_train,
    train_preds=train_preds,
    # Main model: test split (observed, predicted)
    unscaled_y_test=unscaled_y_test,
    y_pred=y_pred,
    # Mediator model: train split (observed, predicted)
    unscaled_y_mediator_train=unscaled_y_mediator_train,
    mediator_train_preds=mediator_train_preds,
    # Mediator model: test split (observed, predicted)
    unscaled_y_mediator_test=unscaled_y_mediator_test,
    mediator_y_pred=mediator_y_pred,
)

# Prediction Plots

In [None]:
# Observed vs. predicted time series for the main target and the mediator,
# covering both the train and the test window.
time_series_plots = plot_mmm_time_series(
    # Time axes of the two splits
    train_index=X_train.index,
    test_index=X_test.index,
    # Main target: observations
    main_y_train=unscaled_y_train,
    main_y_test=unscaled_y_test,
    # Main target: predictions
    main_train_preds=train_preds,
    main_test_preds=y_pred,
    # Mediator: observations
    mediator_y_train=unscaled_y_mediator_train,
    mediator_y_test=unscaled_y_mediator_test,
    # Mediator: predictions
    mediator_train_preds=mediator_train_preds,
    mediator_test_preds=mediator_y_pred,
    # Axis labels
    main_y_label="Watchtime",
    mediator_y_label="DAV",
)


## Bayesian Traces

In [None]:
# Trace plots for every posterior variable whose name contains "beta"
# (filter_vars="like" turns var_names into substring patterns).
# var_names is given as an explicit list: the previous ('beta') was just a
# parenthesised string, not a tuple. The trailing semicolon keeps the cell
# from echoing the returned array of Axes underneath the figure.
az.plot_trace(trace, var_names=["beta"], filter_vars="like");

## Coefficients

In [None]:
# Build the coefficient comparison from the fitted trace. The result is only
# stored here, not displayed as the cell's last expression.
# NOTE(review): presumably compare_coefficients renders its own table/output
# as a side effect — confirm, otherwise end the cell with `comparison_table`.
comparison_table = compare_coefficients(trace)

## Efficiencies

In [None]:
# Normalised-coefficient chart built from the trace, with the mediator
# model included alongside the main model.
fig_plotly = plot_normalized_coefficients(
    trace=trace,
    figsize=(14, 12),
    include_mediator=True,
)

fig_plotly.show()


## Contributions

### Main Model

In [None]:
# Main-model contributions on the training frame. The helper returns an
# (unadjusted, adjusted) pair; only the adjusted one is plotted here.
# Feature groups come from the FEATURES mapping.
# seasonality / trend / intercept terms are switched off for the calculation.
unadj_contributions, adj_contributions = calculate_contributions(
    trace=trace,
    X_train=X_train,
    original_paid_features=FEATURES['features_paid'],
    original_organic_features=FEATURES['features_organic'],
    original_competitor_features=FEATURES['features_competitor'],
    original_control_features=FEATURES['features_control'],
    intercept=False,
    trend=False,
    seasonality=False,
)

# plot_contributions returns both a dataframe and a figure; the figure is
# the cell output.
contributions_df, fig = plot_contributions(
    adj_contributions,
    keep_intercept_trend_season=True,
)
fig


### Mediator

In [None]:
# Mediator-model contributions. Feature groups come from the
# MEDIATOR_FEATURES mapping; seasonality / trend / intercept are switched off.
# NOTE(review): the design matrix passed here is X_train, while the test-time
# mediator inputs for these same feature lists are read from X_mediator_test.
# Confirm X_train really carries the mediator columns (or whether
# X_mediator_train was intended).
mediator_unadj, mediator_adj = calculate_mediator_contributions(
    trace=trace,
    X_train=X_train,
    mediator_paid_features=MEDIATOR_FEATURES['mediator_paid'],
    mediator_organic_features=MEDIATOR_FEATURES['mediator_organic'],
    mediator_competitor_features=MEDIATOR_FEATURES['mediator_competitor'],
    mediator_control_features=MEDIATOR_FEATURES['mediator_control'],
    intercept=False,
    trend=False,
    seasonality=False,
)

# Only the adjusted contributions are plotted; the figure is the cell output.
mediator_contributions_df, mediator_fig = plot_contributions(
    mediator_adj,
    keep_intercept_trend_season=True,
)
mediator_fig


# Adstock

In [None]:
# Adstock window length and impulse size shared by the organic plots and the
# manual impulse-response frame below, so the two can no longer drift apart
# (the frame length used to be a hard-coded 1 + 15 next to l_max=16).
ADSTOCK_L_MAX = 16
ADSTOCK_IMPULSE = 100

# --- Helper-drawn adstock curves -------------------------------------------
organic_figures = plot_adstock_effects(
    trace=trace,
    feature_type="organic",  # Use "organic" instead of "paid"
    include_mediator=True,
    l_max=ADSTOCK_L_MAX,
    impulse_value=ADSTOCK_IMPULSE,
)

organic_figures['main'].show()

# The mediator figure is only shown when the helper returned one.
if 'mediator' in organic_figures:
    organic_figures['mediator'].show()

organic_combined_fig = plot_combined_adstock_effects(
    trace=trace,
    feature_type="organic",
    l_max=ADSTOCK_L_MAX,
)
organic_combined_fig.show()

# Paid channels use the helper's default l_max / impulse_value.
# NOTE(review): include_mediator=True is requested but only the 'main' figure
# is displayed for paid channels — confirm whether that is intentional.
paid_figures = plot_adstock_effects(
    trace=trace,
    feature_type="paid",
    include_mediator=True
)
paid_figures['main'].show()

# --- Manual cumulative impulse response for organic channels ----------------
# Posterior-mean alpha_organic values taken from the ArviZ summary table.
alpha = az.summary(trace, var_names='alpha_organic', round_to=3)['mean'].values
print(alpha)
original_organic_features = list(trace.posterior.coords["organic"].values)

# One impulse at t=0 followed by zeros, one column per organic channel.
# Kept in its own frame: this used to overwrite the global `data` (the
# datamaster loaded at the top of the notebook), which silently broke any
# earlier cell re-run after this one.
adstock_impulse_df = pd.DataFrame(
    {
        channel: [ADSTOCK_IMPULSE] + [0] * (ADSTOCK_L_MAX - 1)
        for channel in original_organic_features
    }
)

adstocked_columns = [f"{channel}_adstocked" for channel in original_organic_features]
adstock_impulse_df[adstocked_columns] = (
    geometric_adstock(
        x=adstock_impulse_df[original_organic_features],
        alpha=alpha,
        l_max=ADSTOCK_L_MAX,
        normalize=True,
    ).eval()
)

# Running total of the adstocked response, computed for all channels at once.
cumulative_columns = [f"{channel}_cumulative" for channel in original_organic_features]
adstock_impulse_df[cumulative_columns] = (
    adstock_impulse_df[adstocked_columns].cumsum().to_numpy()
)

fig = go.Figure()
for channel, cumulative_column in zip(original_organic_features, cumulative_columns):
    fig.add_trace(go.Scatter(
        x=adstock_impulse_df.index,
        y=adstock_impulse_df[cumulative_column],
        mode='lines+markers',
        name=channel.capitalize()
    ))

fig.update_layout(
    title='Cumulative Adstock Effects for Each Campaign',
    xaxis_title='Weeks',
    yaxis_title='Cumulative Adstock Effect',
    legend_title='Campaign',
    template='plotly_white',
    width=1200,
    height=800
)

fig.show()

## Contributions Over Time

In [None]:
# Main-model contributions over time.
# Feature groups are read from FEATURES with keyword arguments, exactly like
# the "Contributions / Main Model" cell, instead of from bare
# original_*_features globals. One of those (original_organic_features) is
# only set as a side effect of the Adstock plotting cell, which made this
# cell depend on execution order.
unadj_contributions, adj_contributions = calculate_contributions(
    trace=trace,
    X_train=X_train,
    original_paid_features=FEATURES['features_paid'],
    original_competitor_features=FEATURES['features_competitor'],
    original_organic_features=FEATURES['features_organic'],
    original_control_features=FEATURES['features_control'],
    seasonality=False,
    trend=False,
    intercept=False
)

# The unadjusted contributions are plotted, with the listed series placed
# according to plot_order.
fig = plot_contributions_over_time(
    unadj_contributions=unadj_contributions,
    plot_order=['DAVs_daily_active_viewers_viewers_activity', 'youtube_brand_index', 'novelty_score_WT'],
    title="Predicted Watchtime and Breakdown"
)

fig.show()

In [None]:
# Mediator-model contributions over time.
# Feature groups are read from MEDIATOR_FEATURES with keyword arguments,
# exactly like the "Contributions / Mediator" cell, instead of from bare
# mediator_*_features globals that this part of the notebook never defines.
# NOTE(review): as in that cell, X_train is passed although the test-time
# mediator inputs come from X_mediator_test — confirm the intended frame.
mediator_unadj, mediator_adj = calculate_mediator_contributions(
    trace=trace,
    X_train=X_train,
    mediator_paid_features=MEDIATOR_FEATURES['mediator_paid'],
    mediator_competitor_features=MEDIATOR_FEATURES['mediator_competitor'],
    mediator_organic_features=MEDIATOR_FEATURES['mediator_organic'],
    mediator_control_features=MEDIATOR_FEATURES['mediator_control'],
    seasonality=False,
    trend=False,
    intercept=False
)

# NOTE(review): plot_order is identical to the main-model cell, including the
# 'DAVs_...' series — verify these names exist among the mediator
# contributions.
mediator_fig = plot_contributions_over_time(
    unadj_contributions=mediator_unadj,
    plot_order=['DAVs_daily_active_viewers_viewers_activity', 'youtube_brand_index', 'novelty_score_WT'],
    title="Predicted Mediator Contribution and Breakdown"
)

mediator_fig.show()