In [None]:
import pymc as pm
import pandas as pd
import numpy as np
from scipy import stats
from typing import Any, Tuple, Dict, Type, List, Union, Optional

class DistSpec:
    """
    Deferred description of a PyMC random variable.

    Stores a distribution class together with its positional and keyword
    arguments so that the variable can be created later, once its name is
    known and a pm.Model context is active.
    """
    def __init__(self, dist_class: Type[pm.Distribution], *args: Any, **kwargs: Any):
        """
        Args:
            dist_class: PyMC distribution class to instantiate later
                (e.g., pm.Normal, pm.Uniform).
            *args: Positional distribution arguments, excluding the variable name.
            **kwargs: Keyword distribution arguments.
        """
        self._dist_class, self._args, self._kwargs = dist_class, args, kwargs

    def build(self, name: str) -> pm.Distribution:
        """Create the variable called 'name' by invoking the stored class with
        the stored arguments; intended to be called inside a pm.Model context."""
        # The variable name always goes first, ahead of the stored positionals.
        positional = (name, *self._args)
        return self._dist_class(*positional, **self._kwargs)

def quap(
    priors: Dict[str, DistSpec],
    data: pd.DataFrame,
    x_col: str,
    y_col: str,
    model_vars: Optional[List[str]] = None,
    interval_prob: float = 0.89,
) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Performs a Quadratic Approximation of the Posterior (QUAP) for a linear model
    (Y ~ a + b*X) using PyMC's MAP finding and a Laplace approximation.

    The curvature is evaluated on PyMC's unconstrained value variables (bounded
    priors such as pm.Uniform are transformed internally) and then mapped back
    to the natural parameter scale with the delta method, so the reported
    standard deviations and intervals are on the same scale as the MAP values.
    All free parameters are assumed to be scalars.

    Args:
        priors: Dictionary mapping variable names to DistSpec objects. Must
            contain 'a', 'b' and 'sigma', which define the linear model.
        data: DataFrame containing the observation data.
        x_col: Name of the independent variable column (e.g., 'height').
        y_col: Name of the dependent variable column (e.g., 'weight').
        model_vars: Names of the variables to report. Defaults to
            ['a', 'b', 'sigma'] when None.
        interval_prob: Probability mass of the central credible interval
            (default 0.89); must lie strictly between 0 and 1.

    Returns:
        A tuple (summary, details). 'summary' is a DataFrame indexed by the
        entries of model_vars with the MAP value, the standard deviation and the
        two interval bounds, rounded to 4 decimals. 'details' holds
        'mp' (the dict returned by pm.find_MAP), 'covariance_matrix' (natural
        scale) and 'var_names' (row/column order of that matrix).

    Raises:
        ValueError: If a required prior or data column is missing, if
            interval_prob is outside (0, 1), if a free parameter is not a
            scalar, or if a requested variable is not a continuous free
            variable of the model.
    """
    # A None sentinel avoids sharing one mutable default list across calls.
    if model_vars is None:
        model_vars = ['a', 'b', 'sigma']

    # 1. Input Validation and Preparation
    if not all(v in priors for v in model_vars):
        raise ValueError(f"Priors for all model variables {model_vars} are required.")
    # The linear predictor below always needs these three, whatever is reported.
    missing = [v for v in ('a', 'b', 'sigma') if v not in priors]
    if missing:
        raise ValueError(f"Priors for the linear model variables {missing} are required.")
    if x_col not in data.columns or y_col not in data.columns:
        raise ValueError(f"DataFrame must contain both '{x_col}' and '{y_col}'.")
    if not 0 < interval_prob < 1:
        raise ValueError(f"interval_prob must be in (0, 1), got {interval_prob}.")

    X_data = data[x_col].values
    Y_data = data[y_col].values

    # Central interval: split the excluded mass evenly between both tails.
    tail_prob = (1 - interval_prob) / 2
    lower_quantile = tail_prob
    upper_quantile = 1 - tail_prob

    # 2. Model Definition and Fitting
    with pm.Model() as model:
        # Build all prior variables dynamically using the dict key as the name.
        built_priors = {name: spec.build(name=name) for name, spec in priors.items()}

        # Linear predictor; the structure a + b*X is fixed for this function.
        mu = pm.Deterministic(
            'mu', built_priors['a'] + built_priors['b'] * X_data
        )

        # Likelihood (Observation)
        pm.Normal('Y_obs', mu=mu, sigma=built_priors['sigma'], observed=Y_data)

        mp = pm.find_MAP(progressbar=False)

        # The Hessian is ordered like the model's continuous value variables,
        # so record the matching random-variable names instead of assuming the
        # order of model_vars.
        hess_vars = model.continuous_value_vars
        var_names = [model.values_to_rvs[v].name for v in hess_vars]
        point = {v.name: np.asarray(mp[v.name]) for v in model.value_vars}
        if any(point[v.name].size != 1 for v in hess_vars):
            raise ValueError("quap only supports scalar free parameters.")

        # find_MAP maximises the log-probability without the Jacobian of the
        # transforms, so the curvature is taken from that same function.
        hess_fn = model.compile_d2logp(jacobian=False, negate_output=False)
        hess = np.atleast_2d(hess_fn(point))
        cov_unconstrained = np.linalg.inv(-hess)

        # Map unconstrained values to natural-scale values of the free
        # variables, in the same order as the Hessian rows.
        outputs = model.unobserved_value_vars
        output_names = [out.name for out in outputs]
        to_natural = model.compile_fn(outputs)

        def natural_values(pt: Dict[str, Any]) -> np.ndarray:
            values = dict(zip(output_names, to_natural(pt)))
            return np.array([float(np.ravel(values[n])[0]) for n in var_names])

        # Jacobian d(natural)/d(unconstrained) at the MAP by central differences.
        jacobian = np.zeros((len(var_names), len(hess_vars)))
        for j, value_var in enumerate(hess_vars):
            x0 = point[value_var.name]
            step = 1e-6 * max(1.0, abs(float(np.ravel(x0)[0])))
            upper_pt = {**point, value_var.name: x0 + step}
            lower_pt = {**point, value_var.name: x0 - step}
            jacobian[:, j] = (
                natural_values(upper_pt) - natural_values(lower_pt)
            ) / (2 * step)

        # Delta method: at the MAP the gradient vanishes, so the covariance
        # transforms as J * Sigma * J^T.
        covariance_matrix = jacobian @ cov_unconstrained @ jacobian.T
        stds = np.sqrt(np.diag(covariance_matrix))

    # 3. Results Extraction and Interval Calculation
    result_cols = ['mean', 'std', f'{lower_quantile*100:.1f}%', f'{upper_quantile*100:.1f}%']
    position = {name: i for i, name in enumerate(var_names)}

    rows: Dict[str, Tuple[float, float, float, float]] = {}
    for var in model_vars:
        if var not in position:
            raise ValueError(f"'{var}' is not a continuous free variable of the model.")
        map_point = float(mp[var])
        sd = float(stds[position[var]])

        # Credible interval from the Normal approximation around the MAP.
        lo, hi = stats.norm.ppf([lower_quantile, upper_quantile], loc=map_point, scale=sd)
        rows[var] = (map_point, sd, float(lo), float(hi))

    result = pd.DataFrame.from_dict(rows, orient='index', columns=result_cols)

    details = {
        'mp': mp,
        'covariance_matrix': covariance_matrix,
        'var_names': var_names,
    }
    return result.round(4), details

# Synthetic demo data: weight is a linear function of height plus Gaussian
# noise, so the regression has a real slope (0.5) and intercept (35) to recover.
np.random.seed(42)
df = pd.DataFrame({'height': np.random.normal(150, 5, 100)})
# Reuse the sampled heights; drawing a second, independent sample here would
# leave weight unrelated to the 'height' column.
df['weight'] = 35 + 0.5 * df['height'] + np.random.normal(0, 2, 100)

# 1. Define Priors as a Dictionary (keys are the variable names)
# NOTE: DistSpec does not take the variable name; quap passes the dict key.

# 2. Run QUAP
results_df, details = quap(
    priors={
        'a': DistSpec(pm.Normal, mu=0, sigma=10),       # Intercept
        'b': DistSpec(pm.Uniform, lower=0, upper=1),     # Slope
        'sigma': DistSpec(pm.Uniform, lower=0, upper=10) # Error SD
    },
    data=df,
    x_col='height',
    y_col='weight'
)

print("✅ QUAP Results (89% CI) using Dictionary Input:\n")
print(results_df)

✅ QUAP Results (89% CI) using Dictionary Input:

          mean     std     5.5%    94.5%
a      42.9310  8.6486  29.1088  56.7532
b       0.4495  0.2338   0.0759   0.8232
sigma   3.6984  0.1263   3.4966   3.9003
