Skip to content

Commit

Permalink
Add prices into starting values calculation.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 479303284
Change-Id: I9a8569e323301c651d8feec24926a6cae286bab2
  • Loading branch information
pabloduque0 authored and Copybara-Service committed Oct 6, 2022
1 parent 5a1c1b5 commit c91b591
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 7 deletions.
39 changes: 32 additions & 7 deletions lightweight_mmm/optimize_media.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,11 @@ def _get_lower_and_upper_bounds(
if media.ndim == 3:
lower_pct = jnp.expand_dims(lower_pct, axis=-1)
upper_pct = jnp.expand_dims(upper_pct, axis=-1)

mean_data = media.mean(axis=0)
lower_bounds = jnp.maximum(mean_data * (1 - lower_pct), 0)
upper_bounds = mean_data * (1 + upper_pct)

if media_scaler:
lower_bounds = media_scaler.inverse_transform(lower_bounds)
upper_bounds = media_scaler.inverse_transform(upper_bounds)
Expand All @@ -142,9 +144,12 @@ def _get_lower_and_upper_bounds(
ub=upper_bounds * n_time_periods)


def _generate_starting_values(n_time_periods: int, media: jnp.ndarray,
media_scaler: preprocessing.CustomScaler,
budget: Union[float, int]) -> jnp.ndarray:
def _generate_starting_values(
n_time_periods: int, media: jnp.ndarray,
media_scaler: preprocessing.CustomScaler,
budget: Union[float, int],
prices: jnp.ndarray,
) -> jnp.ndarray:
"""Generates starting values based on historic allocation and budget.
In order to make a comparison we can take the allocation of the last
Expand All @@ -157,6 +162,8 @@ def _generate_starting_values(n_time_periods: int, media: jnp.ndarray,
media: Historic media data the model was trained with.
media_scaler: Scaler that was used to scale the media data before training.
budget: Total budget to allocate during the optimization time.
prices: An array with shape (n_media_channels,) for the cost of each media
channel unit.
Returns:
An array with the starting value for each media channel for the
Expand All @@ -169,8 +176,11 @@ def _generate_starting_values(n_time_periods: int, media: jnp.ndarray,
if media.ndim == 3:
previous_allocation = previous_allocation.sum(axis=-1)

multiplier = budget / previous_allocation.sum()
return previous_allocation * multiplier
avg_spend_per_channel = previous_allocation * prices
pct_spend_per_channel = avg_spend_per_channel / avg_spend_per_channel.sum()
budget_per_channel = budget * pct_spend_per_channel
media_unit_per_channel = budget_per_channel / prices
return media_unit_per_channel


def find_optimal_budgets(
Expand All @@ -185,6 +195,8 @@ def find_optimal_budgets(
bounds_lower_pct: Union[float, jnp.ndarray] = .2,
bounds_upper_pct: Union[float, jnp.ndarray] = .2,
max_iterations: int = 200,
solver_func_tolerance: float = 1e-06,
solver_step_size: float = 1.4901161193847656e-08,
seed: Optional[int] = None) -> optimize.OptimizeResult:
"""Finds the best media allocation based on MMM model, prices and a budget.
Expand All @@ -210,6 +222,14 @@ def find_optimal_budgets(
consider as new upper bound.
max_iterations: Number of max iterations to use for the SLSQP scipy
optimizer. Default is 200.
solver_func_tolerance: Precision goal for the value of the prediction in
the stopping criterion. Maps directly to scipy's `ftol`. Intended only
for advanced users. For more details see:
https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html#optimize-minimize-slsqp.
solver_step_size: Step size used for numerical approximation of the
Jacobian. Maps directly to scipy's `eps`. Intended only for advanced
users. For more details see:
https://docs.scipy.org/doc/scipy/reference/optimize.minimize-slsqp.html#optimize-minimize-slsqp.
seed: Seed to use for PRNGKey during sampling. For replicability run
this function and any other function that gets predictions with the same
seed.
Expand Down Expand Up @@ -255,7 +275,9 @@ def find_optimal_budgets(
n_time_periods=n_time_periods,
media=media_mix_model.media,
media_scaler=media_scaler,
budget=budget)
budget=budget,
prices=prices,
)
if not media_scaler:
media_scaler = preprocessing.CustomScaler(multiply_by=1, divide_by=1)
if media_mix_model.n_geos == 1:
Expand All @@ -274,9 +296,12 @@ def find_optimal_budgets(
x0=starting_values,
bounds=bounds,
method="SLSQP",
jac="3-point",
options={
"maxiter": max_iterations,
"disp": True
"disp": True,
"ftol": solver_func_tolerance,
"eps": solver_step_size,
},
constraints={
"type": "eq",
Expand Down
38 changes: 38 additions & 0 deletions lightweight_mmm/optimize_media_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,44 @@ def test_find_optimal_budgets_has_right_output_length_datatype(
self.assertIsInstance(results[1], jax.Array)
self.assertIsInstance(results[2], jax.Array)

@parameterized.named_parameters([
    dict(
        testcase_name="national_prices",
        model_name="national_mmm",
        prices=np.array([1., 0.8, 1.2, 1.5, 0.5]),
    ),
    dict(
        testcase_name="national_ones",
        model_name="national_mmm",
        prices=np.ones(5),
    ),
    dict(
        testcase_name="geo_prices",
        model_name="geo_mmm",
        prices=np.array([1., 0.8, 1.2, 1.5, 0.5]),
    ),
    dict(
        testcase_name="geo_ones",
        model_name="geo_mmm",
        prices=np.ones(5),
    ),
])
def test_generate_starting_values_calculates_correct_values(
    self, model_name, prices):
  """Checks the price-aware starting values for national and geo models.

  Args:
    model_name: Name of the test-case attribute holding the fitted model
      fixture to read `media` and `n_media_channels` from.
    prices: Cost per media unit for each channel.
  """
  mmm = getattr(self, model_name)
  n_time_periods = 10
  budget = mmm.n_media_channels * n_time_periods

  starting_values = optimize_media._generate_starting_values(
      # Reuse the local value so the call and the expectation cannot drift.
      n_time_periods=n_time_periods,
      media_scaler=None,
      media=mmm.media,
      budget=budget,
      prices=prices,
  )

  # The fixture media data is all ones, so the historic allocation is uniform
  # across channels, and each price vector above sums to the number of
  # channels. Each channel therefore gets budget / prices.sum() units, which
  # equals n_time_periods here (not the prices themselves).
  np.testing.assert_array_almost_equal(
      starting_values, jnp.repeat(n_time_periods, repeats=len(prices)))

if __name__ == "__main__":
absltest.main()

0 comments on commit c91b591

Please sign in to comment.