In [1]:
import numpy as np
import pandas as pd

from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.transformers import geometric_adstock, logistic_saturation
from pymc_marketing.prior import Prior
from mmm_eval.data import generate_data
from mmm_eval import PYMCConfig

from mmm_eval.utils import PyMCConfigRehydrator
from mmm_eval.cli.evaluate import load_config

# Generate data

In [2]:
data = generate_data()
data.to_csv("data.csv", index=False)

X = data.drop(columns=["revenue","quantity"])
y = data["quantity"]

# Fit PyMC

In [3]:
my_model = MMM(
    date_column="date_week" ,
    channel_columns=["channel_1","channel_2"],
    adstock=GeometricAdstock(l_max=4),
    saturation=LogisticSaturation()
)

my_model.fit(X=X, y=y, chains=4, target_accept=0.85)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, adstock_alpha, saturation_lam, saturation_beta, y_sigma]


Output()

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 4 seconds.
There were 2 divergences after tuning. Increase `target_accept` or reparameterize.


# Save Config to JSON

Copy the args for MMM() and .fit() into dict structure. Loader will handle string representation.
Need to provide target column name b/c all columns will be in one CSV in the input data.

In [4]:
fit_kwargs = {
    "X": X,  
    "y": y,   
    "chains": 4,
    "target_accept": 0.85,
}
target_column = "revenue"

pymc_config = PYMCConfig(my_model, fit_kwargs=fit_kwargs, target_column=target_column)
pymc_config.model_config.config

{'date_column': 'date_week',
 'channel_columns': ['channel_1', 'channel_2'],
 'adstock': GeometricAdstock(prefix='adstock', l_max=4, normalize=True, mode='After', priors={'alpha': Prior("Beta", alpha=1, beta=3, dims="channel")}),
 'saturation': LogisticSaturation(prefix='saturation', priors={'lam': Prior("Gamma", alpha=3, beta=1, dims="channel"), 'beta': Prior("HalfNormal", sigma=2, dims="channel")}),
 'time_varying_intercept': False,
 'time_varying_media': False,
 'sampler_config': {},
 'validate_data': True,
 'control_columns': None,
 'yearly_seasonality': None,
 'adstock_first': True,
 'dag': None,
 'treatment_nodes': None,
 'outcome_node': None}

## Check hydration

In [5]:
print(pymc_config.model_config.is_hydrated)
print(pymc_config.fit_kwargs.is_hydrated)

True
True


## Save to json and load back into memory

In [6]:
pymc_config.save_config_to_json(save_path=".", file_name="test_config")

<mmm_eval.data.loaders.PYMCConfig at 0x11ded1b90>

In [7]:
## Load from json (no hydration)

dehydrated_config = load_config("test_config.json")
dehydrated_config

{'model_config': {'date_column': "'date_week'",
  'channel_columns': "['channel_1', 'channel_2']",
  'adstock': 'GeometricAdstock(prefix=\'adstock\', l_max=4, normalize=True, mode=\'After\', priors={\'alpha\': Prior("Beta", alpha=1, beta=3, dims="channel")})',
  'saturation': 'LogisticSaturation(prefix=\'saturation\', priors={\'lam\': Prior("Gamma", alpha=3, beta=1, dims="channel"), \'beta\': Prior("HalfNormal", sigma=2, dims="channel")})',
  'time_varying_intercept': 'False',
  'time_varying_media': 'False',
  'sampler_config': '{}',
  'validate_data': 'True',
  'control_columns': 'None',
  'yearly_seasonality': 'None',
  'adstock_first': 'True',
  'dag': 'None',
  'treatment_nodes': 'None',
  'outcome_node': 'None'},
 'fit_config': {'X': '     date_week     price  channel_1  channel_2  event_1  event_2  dayofyear\n0   2018-04-02  5.000322  31.858002   0.000000      0.0      0.0         92\n1   2018-04-09  5.015090  11.238848   0.000000      0.0      0.0         99\n2   2018-04-16  5.

In [8]:
pymc_config.load_config_from_json(".", "test_config")
print(pymc_config.model_config.is_hydrated)
print(pymc_config.fit_kwargs.is_hydrated)

ValidationError: 6 validation errors for PyMCModelSchema
channel_columns
  Input should be a valid list [type=list_type, input_value="['channel_1', 'channel_2']", input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/list_type
sampler_config
  Input should be a valid dictionary [type=dict_type, input_value='{}', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/dict_type
control_columns
  Input should be a valid list [type=list_type, input_value='None', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/list_type
yearly_seasonality
  Input should be a valid integer, unable to parse string as an integer [type=int_parsing, input_value='None', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/int_parsing
treatment_nodes.list[str]
  Input should be a valid list [type=list_type, input_value='None', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/list_type
treatment_nodes.tuple[str]
  Input should be a valid tuple [type=tuple_type, input_value='None', input_type=str]
    For further information visit https://errors.pydantic.dev/2.11/v/tuple_type

In [None]:
c["model_config"]

{'date_column': "'date_week'",
 'channel_columns': "['x1', 'x2']",
 'adstock': 'GeometricAdstock(prefix=\'adstock\', l_max=4, normalize=True, mode=\'After\', priors={\'alpha\': Prior("Beta", alpha=1, beta=3)})',
 'saturation': 'LogisticSaturation(prefix=\'saturation\', priors={\'lam\': Prior("Gamma", alpha=3, beta=1), \'beta\': Prior("HalfNormal", sigma=2)})'}

# Fit a model with the input data + rehydrated config

In [26]:
import pandas as pd

df = pd.read_csv("data.csv")
df["date_week"] = pd.to_datetime(df["date_week"])

model_config = pymc_config.model_config.config
fit_config = pymc_config.fit_kwargs.config
target_column = pymc_config.target_column

X = df.drop(columns=[target_column])
y = df[target_column]

In [27]:
model_config

{'date_column': "'date_week'",
 'channel_columns': "['channel_1', 'channel_2']",
 'adstock': GeometricAdstock(prefix='adstock', l_max=4, normalize=True, mode='After', priors={'alpha': Prior("Beta", alpha=1, beta=3, dims="channel")}),
 'saturation': LogisticSaturation(prefix='saturation', priors={'lam': Prior("Gamma", alpha=3, beta=1, dims="channel"), 'beta': Prior("HalfNormal", sigma=2, dims="channel")}),
 'time_varying_intercept': 'False',
 'time_varying_media': 'False',
 'sampler_config': '{}',
 'validate_data': 'True',
 'control_columns': 'None',
 'yearly_seasonality': 'None',
 'adstock_first': 'True',
 'dag': 'None',
 'treatment_nodes': 'None',
 'outcome_node': 'None'}

In [28]:
date_column = eval(model_config["date_column"])
channel_columns = list(eval(model_config["channel_columns"]))
adstock = model_config["adstock"]
saturation = model_config["saturation"]

m2 = MMM(
    date_column=date_column,
    channel_columns=channel_columns,
    adstock=adstock,
    saturation=saturation
)


In [29]:
chains = int(eval(fit_config["chains"]))
target_accept = float(fit_config["target_accept"])

m2.fit(X=X, y=y, chains=chains, target_accept=target_accept)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [intercept, adstock_alpha, saturation_lam, saturation_beta, y_sigma]


Output()

Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 5 seconds.
There were 16 divergences after tuning. Increase `target_accept` or reparameterize.
