# Minimal Media Mix Model Example

This notebook demonstrates fitting a minimal Media Mix Model (MMM) using PyMC-Marketing with default parameters on synthetic marketing data.

## Dataset Overview

The dataset contains:
- **Time period**: 104 weeks (2 years) from 2020-01-05 to 2021-12-26
- **Marketing channels**: Search Ads, Social Media, Local Ads, Email
- **Control variables**: Event (c1), Sale (c2)
- **Target**: Weekly sales (y)

See `DATA.md` for complete dataset documentation.

## Setup

In [None]:
from pathlib import Path

import polars as pl
from pymc_marketing.mmm import MMM, GeometricAdstock, LogisticSaturation
from rich import print as rprint
from rich.console import Console
from rich.table import Table

## Load Data

Load the main MMM dataset using Polars for efficient data handling.

In [None]:
def load_mmm_data(data_path: str | Path) -> pl.DataFrame:
    """Read the MMM dataset from a CSV file.

    Args:
        data_path: Location of the mmm_data.csv file

    Returns:
        Polars DataFrame whose ``date`` column has been converted from
        string to a date type
    """
    raw_frame = pl.read_csv(data_path)
    parsed_date = pl.col("date").str.to_date()
    return raw_frame.with_columns(parsed_date)


# Read the dataset from its location relative to the notebook
data_path = Path("../data/mmm-simple/mmm_data.csv")
df = load_mmm_data(data_path)

# Report what was loaded: size and covered date range
rprint("[bold green]Data loaded successfully[/bold green]")
rprint(f"Shape: {df.shape[0]} rows × {df.shape[1]} columns")
rprint(f"Date range: {df['date'].min()} to {df['date'].max()}")

## Data Preview

In [None]:
def display_data_preview(df: pl.DataFrame, n_rows: int = 5) -> None:
    """Print the first rows of a DataFrame as a Rich table.

    Args:
        df: DataFrame to preview
        n_rows: How many leading rows to show
    """
    preview = Table(title="MMM Data Preview", show_header=True, header_style="bold magenta")

    # One table column per DataFrame column, in the same order
    for column_name in df.columns:
        preview.add_column(column_name)

    # Every value is converted to its string form before becoming a cell
    for record in df.head(n_rows).iter_rows():
        preview.add_row(*map(str, record))

    Console().print(preview)


# Show the leading rows of the loaded dataset
display_data_preview(df=df)

## Data Summary Statistics

In [None]:
def display_summary_stats(df: pl.DataFrame) -> None:
    """Display summary statistics for every column except ``date`` and ``geo``.

    The statistics come from ``DataFrame.describe()``, whose result already
    carries the statistic labels in its first column. That column is shown
    under the "Statistic" header and no further header is added for it, so
    each header sits above its own values.

    Args:
        df: Input DataFrame
    """
    # Columns are filtered by name only: "date" and "geo" are skipped if present
    excluded = {"date", "geo"}
    value_cols = [col for col in df.columns if col not in excluded]

    stats = df.select(value_cols).describe()

    table = Table(title="Summary Statistics", show_header=True, header_style="bold cyan")

    # The first column of the describe() output holds the statistic labels;
    # give it a readable header instead of adding a second, shifted one.
    table.add_column("Statistic")
    for col in stats.columns[1:]:
        table.add_column(col)

    # Floats are shown with two decimals; everything else via str()
    for row in stats.iter_rows():
        table.add_row(*[f"{val:.2f}" if isinstance(val, float) else str(val) for val in row])

    Console().print(table)


# Print descriptive statistics for the loaded dataset
display_summary_stats(df=df)

## Prepare Data for MMM

Convert the data to a pandas DataFrame (required by PyMC-Marketing) and define the channel and control columns.

In [None]:
# PyMC-Marketing currently requires pandas, so leave Polars at this point
df_pandas = df.to_pandas()

# Marketing channel columns passed to the model
channel_columns = ["x1_Search-Ads", "x2_Social-Media", "x3_Local-Ads", "x4_Email"]

# Control variable columns passed to the model
control_columns = ["c1", "c2"]

# Echo the configuration that the following cells rely on
rprint("[bold blue]Model Configuration:[/bold blue]")
rprint("Target column: [yellow]y[/yellow]")
rprint("Date column: [yellow]date[/yellow]")
rprint(f"Channel columns: [yellow]{channel_columns}[/yellow]")
rprint(f"Control columns: [yellow]{control_columns}[/yellow]")

## Initialize MMM Model

Create an `MMM` model with geometric adstock (maximum lag of 8), logistic saturation, and two Fourier modes for yearly seasonality. All priors are left at their defaults.

In [None]:
def create_mmm_model(
    date_column: str,
    channel_columns: list[str],
    control_columns: list[str],
    adstock_max_lag: int = 8,
    yearly_seasonality: int = 2
) -> MMM:
    """Build an unfitted MMM with geometric adstock and logistic saturation.

    Args:
        date_column: Name of the column holding the dates
        channel_columns: Names of the marketing channel columns
        control_columns: Names of the control variable columns
        adstock_max_lag: Value passed as ``l_max`` to the geometric adstock
        yearly_seasonality: Number of Fourier modes for yearly seasonality

    Returns:
        MMM instance configured with the given columns and transformations
    """
    carryover = GeometricAdstock(l_max=adstock_max_lag)
    diminishing_returns = LogisticSaturation()

    model = MMM(
        date_column=date_column,
        channel_columns=channel_columns,
        control_columns=control_columns,
        adstock=carryover,
        saturation=diminishing_returns,
        yearly_seasonality=yearly_seasonality,
    )
    return model


# Build the model: adstock lag window of 8 and 2 yearly Fourier modes
mmm = create_mmm_model(
    "date",
    channel_columns,
    control_columns,
    adstock_max_lag=8,
    yearly_seasonality=2,
)

rprint("[bold green]MMM model initialized successfully[/bold green]")

## Fit the Model

Fit the MMM model using MCMC sampling: 1,000 tuning steps and 1,000 draws on each of 2 chains, with a fixed random seed for reproducibility.

**Note**: This may take several minutes depending on your hardware.

In [None]:
import pandas as pd


def fit_mmm(
    model: MMM,
    X: pd.DataFrame,
    y: pd.Series,
    draws: int = 1000,
    tune: int = 1000,
    chains: int = 2,
    random_seed: int = 42
) -> None:
    """Run ``model.fit`` on the given data and report progress.

    Args:
        model: MMM model instance to fit
        X: DataFrame of input features
        y: Series holding the target variable
        draws: Number of MCMC draws per chain
        tune: Number of tuning steps
        chains: Number of MCMC chains
        random_seed: Seed forwarded to the sampler for reproducibility
    """
    rprint("[bold yellow]Starting model fitting...[/bold yellow]")
    rprint(f"Draws: {draws}, Tune: {tune}, Chains: {chains}")

    # Sampler settings are forwarded unchanged as keyword arguments
    sampler_options = {
        "draws": draws,
        "tune": tune,
        "chains": chains,
        "random_seed": random_seed,
    }
    model.fit(X=X, y=y, **sampler_options)

    rprint("[bold green]Model fitting completed successfully![/bold green]")


# Split the frame into the target and everything else
y = df_pandas["y"]
X = df_pandas.drop(columns=["y"])

# Run the sampler: 1000 tuning steps and 1000 draws on each of 2 chains
fit_mmm(mmm, X, y, draws=1000, tune=1000, chains=2, random_seed=42)

## Model Summary

Display model fit summary statistics.

In [None]:
import arviz as az


def display_fit_summary(model: MMM) -> None:
    """Print the ArviZ summary table of the model's inference data.

    Args:
        model: Fitted MMM model (its ``idata`` attribute is summarised)
    """
    # Build the table first so that nothing is printed if summarising fails
    fit_table = az.summary(model.idata)

    rprint("[bold magenta]Model Fit Summary:[/bold magenta]")
    rprint(fit_table)


# Print the posterior summary of the fitted model
display_fit_summary(model=mmm)

## Channel Contributions

Compute channel contributions over time.

In [None]:
def compute_channel_contributions(model: MMM) -> pd.DataFrame:
    """Return the model's mean contributions over time in original scale.

    Args:
        model: Fitted MMM model

    Returns:
        Result of ``compute_mean_contributions_over_time`` requested with
        ``original_scale=True``
    """
    mean_contributions = model.compute_mean_contributions_over_time(original_scale=True)
    rprint("[bold green]Channel contributions computed (original scale)[/bold green]")
    return mean_contributions


# Keep the contributions for the ROAS computation and show the first rows
contributions = compute_channel_contributions(model=mmm)
rprint(contributions.head(5))

## Compute ROAS

Calculate Return on Ad Spend (ROAS) for each channel and compare with ground truth.

In [None]:
import json


def compute_and_compare_roas(
    contributions: pd.DataFrame,
    channel_spend: pd.DataFrame,
    channel_columns: list[str],
    ground_truth_path: str | Path
) -> None:
    """Compute per-channel ROAS and print it next to the ground-truth values.

    ROAS is the summed contribution of a channel divided by its summed spend.
    Ground-truth values are read from ``roas_values -> Local`` in the JSON
    file and looked up by the channel name with its ``x<N>_`` prefix removed.

    A channel without a ground-truth entry is shown as "n/a" instead of a
    true ROAS of 0.00 with a 0.0% error, which would read like a perfect
    estimate. A ground-truth value of exactly zero is displayed, but its
    relative error is "n/a" because it cannot be computed.

    Args:
        contributions: DataFrame with channel contributions
        channel_spend: DataFrame with channel spend data
        channel_columns: List of channel column names
        ground_truth_path: Path to ground truth parameters JSON
    """
    # Compute ROAS: total contribution / total spend
    total_contributions = contributions[channel_columns].sum()
    total_spend = channel_spend[channel_columns].sum()
    roas = total_contributions / total_spend

    # Read the file as UTF-8 explicitly rather than relying on the platform default
    with open(ground_truth_path, encoding="utf-8") as f:
        ground_truth = json.load(f)

    # NOTE(review): assumes "Local" maps channel names (without the x<N>_
    # prefix) to ROAS values — confirm against the ground-truth file.
    true_roas = ground_truth["roas_values"]["Local"]

    # Create comparison table
    console = Console()
    table = Table(
        title="ROAS Comparison: Estimated vs Ground Truth",
        show_header=True,
        header_style="bold green"
    )

    table.add_column("Channel", style="cyan")
    table.add_column("Estimated ROAS", justify="right")
    table.add_column("True ROAS", justify="right")
    table.add_column("Error %", justify="right")

    for channel in channel_columns:
        # Extract base channel name (remove x1_, x2_, etc. prefix)
        channel_name = channel.split("_", 1)[1] if "_" in channel else channel

        est_val = roas[channel]
        true_val = true_roas.get(channel_name)

        if true_val is None:
            # No reference value for this channel: nothing to compare against
            true_text = "n/a"
            error_text = "n/a"
        elif true_val == 0:
            # Relative error is undefined for a zero reference
            true_text = f"{true_val:.2f}"
            error_text = "n/a"
        else:
            error_pct = (est_val - true_val) / true_val * 100
            true_text = f"{true_val:.2f}"
            error_text = f"{error_pct:+.1f}%"

        table.add_row(channel_name, f"{est_val:.2f}", true_text, error_text)

    console.print(table)


# Compare estimated ROAS with the values stored alongside the dataset;
# the spend columns are taken from the pandas copy of the data
ground_truth_path = Path("../data/mmm-simple/ground_truth_parameters.json")
compute_and_compare_roas(contributions, df_pandas, channel_columns, ground_truth_path)

## Model Diagnostics

Show MCMC trace plots and posterior distributions for key parameters.

In [None]:
import arviz as az
import matplotlib.pyplot as plt


def plot_trace_diagnostics(model: MMM) -> None:
    """Plot MCMC trace diagnostics for the intercept, channel effects and noise.

    For each parameter a list of alternative variable names is tried and the
    first one present in ``model.idata.posterior`` is plotted, so the function
    does not fail merely because the installed library names a parameter
    differently.

    Args:
        model: Fitted MMM model

    Raises:
        KeyError: If none of the expected variables exist in the posterior
    """
    rprint("[bold blue]Plotting trace diagnostics...[/bold blue]")

    # NOTE(review): the names registered by MMM appear to differ between
    # PyMC-Marketing releases (`beta_channel` / `likelihood_sigma` vs
    # `saturation_beta` / `y_sigma`) — confirm against the installed version.
    alias_groups = (
        ("intercept",),
        ("beta_channel", "saturation_beta"),
        ("likelihood_sigma", "y_sigma"),
    )
    available = set(model.idata.posterior.data_vars)

    var_names = []
    for aliases in alias_groups:
        present = next((name for name in aliases if name in available), None)
        if present is None:
            rprint(f"[yellow]No posterior variable found among {list(aliases)}; skipping[/yellow]")
        else:
            var_names.append(present)

    if not var_names:
        raise KeyError(
            "None of the expected variables are in the posterior; "
            f"available variables: {sorted(available)}"
        )

    # Plot trace for key parameters
    az.plot_trace(
        model.idata,
        var_names=var_names,
        compact=True,
        figsize=(12, 8)
    )
    plt.tight_layout()
    plt.show()


# Draw the trace plots for the fitted model
plot_trace_diagnostics(model=mmm)

## Save Model

Save the fitted model for later use.

In [None]:
def save_model(model: MMM, output_path: str | Path) -> None:
    """Write the model to disk, creating the target directory if needed.

    Args:
        model: Fitted MMM model
        output_path: File path the model is saved to
    """
    destination = Path(output_path)

    # Missing parent directories are created; existing ones are left alone
    destination.parent.mkdir(parents=True, exist_ok=True)

    model.save(str(destination))
    rprint(f"[bold green]Model saved to {destination}[/bold green]")


# Persist the fitted model next to the notebook's models directory
model_output_path = Path("../models/mmm_minimal_default.nc")
save_model(model=mmm, output_path=model_output_path)

## Summary

This notebook demonstrated:

1. Loading synthetic MMM data with Polars
2. Initializing an `MMM` model with GeometricAdstock and LogisticSaturation
3. Fitting the model using MCMC sampling
4. Computing channel contributions and ROAS
5. Comparing estimated ROAS with ground truth values
6. Visualizing model diagnostics
7. Saving the fitted model

### Next Steps

- Experiment with different priors and model configurations
- Optimize hyperparameters using Optuna
- Analyze adstock and saturation parameters
- Perform out-of-sample validation