In [1]:
import json 
import pandas as pd
import numpy as np
import jax.numpy as jnp
import numpyro
from sklearn.feature_selection import SelectKBest, chi2
from matplotlib import pyplot as plt
from lightweight_mmm import lightweight_mmm
from lightweight_mmm import optimize_media
from lightweight_mmm import preprocessing
from lightweight_mmm import utils
from lightweight_mmm import plot

df=pd.read_csv('Brother_MMM.csv')
df=df[df['Category']=='B2B Printing']

In [2]:

def open_json(string):                                     
    a = open(string)
    data_a = json.load(a)
    return data_a

In [6]:
def model_summary_plot(df,input_json,model_name,number_warmup,number_samples,number_chains,weekday_seasonality,degrees_seasonality):
    inp=open_json('mmm_input.json')
    
    other_feautures=df[inp['other_feautures']].to_numpy()
    costs=df[inp['costs']].sum().to_numpy()
    target=df[inp['target']].to_numpy()
    media_data=df[inp['media_data']].to_numpy()
        
    split_point = int(len(df) * 0.8)
    media_data_train = media_data[:split_point]
    media_data_test = media_data[split_point:]
    target_train=target[:split_point].reshape(-1)
    
    media_scaler =  preprocessing.CustomScaler(divide_operation=jnp.mean)
    target_scaler = preprocessing.CustomScaler(divide_operation=jnp.mean)
    cost_scaler =   preprocessing.CustomScaler(divide_operation=jnp.mean)
    
    media_data_train = media_scaler.fit_transform(media_data_train)
    target_train = target_scaler.fit_transform(target_train)
    costs2 = cost_scaler.fit_transform(costs)
    
    selector = SelectKBest(chi2, k=5)  
    selector.fit(other_feautures, target)
    df = pd.DataFrame(other_feautures)
    best_features = df.columns[selector.get_support()]
    
    other_feautures = df[best_features].to_numpy()
    
    mmm = lightweight_mmm.LightweightMMM(model_name=model_name)
    
    number_warmup=number_samples
    number_samples=number_samples
    
    mmm.fit(
    media=media_data_train,
    media_prior=costs2,
    target=target_train,
    number_warmup=number_warmup,
    number_samples=number_samples,target_accept_prob=0.85,
    number_chains=number_chains,weekday_seasonality=weekday_seasonality,degrees_seasonality=degrees_seasonality
    )
    
    mmm.print_summary()
    
    model = plot.plot_model_fit(media_mix_model=mmm, target_scaler=target_scaler)
    response = plot.plot_response_curves(media_mix_model=mmm, media_scaler=media_scaler, target_scaler=target_scaler)
    media_effect_hat, roi_hat = mmm.get_posterior_metrics()
    media_effect_hat = plot.plot_bars_media_metrics(metric=media_effect_hat,channel_names=["Paid_Search", "Paid_Social", "display_cost","eml_COST"])
    roi_hat = plot.plot_bars_media_metrics(metric=roi_hat,channel_names=["Paid_Search", "Paid_Social", "display_cost","eml_COST"])
    
    model.savefig('model.png')
    response.savefig('response.png')
    media_effect_hat.savefig('media_effect_hat.png')
    roi_hat.savefig('roi_hat.png')
    
    plt.close(model)
    plt.close(response)
    plt.close(media_effect_hat)
    plt.close(roi_hat)
    

In [20]:
mmm=model_summary_plot(df,'mmm_input.json','hill_adstock',500,1000,2,True,6)

  mcmc = numpyro.infer.MCMC(
sample: 100%|██████████████████████████| 2000/2000 [01:19<00:00, 25.08it/s, 255 steps of size 1.61e-02. acc. prob=0.92]
sample: 100%|██████████████████████████| 2000/2000 [01:11<00:00, 28.15it/s, 255 steps of size 1.60e-02. acc. prob=0.96]



                                         mean       std    median      5.0%     95.0%     n_eff     r_hat
                      coef_media[0]      1.17      0.20      1.16      0.85      1.47    286.71      1.00
                      coef_media[1]      0.14      0.13      0.11      0.00      0.29    562.74      1.00
                      coef_media[2]      0.41      0.32      0.35      0.22      0.49     61.82      1.03
                      coef_media[3]      0.00      0.00      0.00      0.00      0.00    773.41      1.00
                      coef_trend[0]     -0.02      0.01     -0.02     -0.04     -0.00    216.18      1.01
                         expo_trend      0.66      0.15      0.62      0.50      0.89    424.73      1.00
             gamma_seasonality[0,0]     -0.17      0.04     -0.17     -0.25     -0.11    215.00      1.00
             gamma_seasonality[0,1]      0.05      0.03      0.05      0.01      0.09    335.46      1.00
             gamma_seasonality[1,0]      0.05

