Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 544062601
Change-Id: Ifbc8967d1d78fbce458700ec6fd9f562c977250c
  • Loading branch information
michevan authored and Copybara-Service committed Jun 28, 2023
1 parent bb1c1a8 commit b4c99fa
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
2 changes: 1 addition & 1 deletion lightweight_mmm/lightweight_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ def fit(

self.custom_priors = custom_priors
if media_names is not None:
self.media_names = media_names
self.media_names = list(media_names)
else:
self.media_names = [f"channel_{i}" for i in range(media.shape[1])]
self.n_media_channels = media.shape[1]
Expand Down
15 changes: 15 additions & 0 deletions lightweight_mmm/lightweight_mmm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import jax.numpy as jnp
import numpy as np
import numpyro.distributions as dist
import pandas as pd

from lightweight_mmm import lightweight_mmm
from lightweight_mmm import models
Expand Down Expand Up @@ -463,5 +464,19 @@ def test_fitted_mmm_does_not_equal_default_mmm(self, media_mix_model):
fitted_mmm_object = getattr(self, media_mix_model)
self.assertNotEqual(default_mmm_object, fitted_mmm_object)

def test_equality_function_works_with_media_names_as_pandas_index(self):
mmm_object = lightweight_mmm.LightweightMMM()
mmm_object.fit(
media=jnp.ones((50, 5)),
target=jnp.ones(50),
media_prior=jnp.ones(5) * 50,
extra_features=jnp.ones((50, 2)),
number_warmup=2,
number_samples=4,
number_chains=1,
media_names=pd.Index([f'channel_{i}' for i in range(5)]))

self.assertEqual(mmm_object, mmm_object)

if __name__ == "__main__":
absltest.main()

0 comments on commit b4c99fa

Please sign in to comment.