### Transforms on the Target Variable 

In [None]:
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import PowerTransformer

class SignedLogTransformer(BaseEstimator, TransformerMixin):
    """
    Symmetric log (log-modulus) transform for targets of either sign.

    Forward:
      y' = sign(y) * log(1 + |y|)
    Inverse:
      y  = sign(y') * (exp(|y'|) - 1)

    The mapping has no learned state, so ``fit`` simply returns ``self``.
    """

    def fit(self, X=None, y=None):
        # Stateless: present only for scikit-learn API compatibility.
        return self

    def transform(self, y):
        values = np.asarray(y, dtype=float)
        magnitude = np.log1p(np.abs(values))
        return np.sign(values) * magnitude

    def inverse_transform(self, y_prime):
        encoded = np.asarray(y_prime, dtype=float)
        magnitude = np.expm1(np.abs(encoded))
        return np.sign(encoded) * magnitude


class AsinhTransformer(BaseEstimator, TransformerMixin):
    """
    Hyperbolic arcsine transform:
      y' = arcsinh(y / c)
    Inverse:
      y = sinh(y') * c

    Parameters
    ----------
    c : float or None
        Scale parameter. If None, ``fit`` estimates it as median(|y|) over the
        finite target values, falling back to 1.0 when that median is 0 (or
        there are no finite values) so the division is always defined.

    Attributes
    ----------
    c_ : float
        Scale actually used by ``transform`` / ``inverse_transform``. It is
        stored separately so ``fit`` never overwrites the ``c``
        hyper-parameter, and refitting on new data re-estimates the scale.
    """
    def __init__(self, c=None):
        # c = scale parameter; if None, estimated from the data in fit()
        self.c = c

    def fit(self, X=None, y=None):
        """
        Set ``c_``.

        The target may be given positionally as ``X`` (this is what
        ``fit_transform(y)`` does) or via the ``y`` keyword; ``y`` wins when
        both are supplied.

        Raises
        ------
        ValueError
            If ``c`` is None and no data is supplied.
        """
        if self.c is not None:
            self.c_ = float(self.c)
            return self

        data = y if y is not None else X
        if data is None:
            raise ValueError(
                "AsinhTransformer.fit needs data to estimate c; "
                "pass the target or set c explicitly."
            )
        values = np.asarray(data, dtype=float).ravel()
        # NaN/inf would make the median NaN (and NaN is truthy, so the old
        # `or 1.0` fallback never triggered); estimate from finite values only.
        values = values[np.isfinite(values)]
        scale = float(np.median(np.abs(values))) if values.size else 0.0
        # avoid c == 0 (division by zero)
        self.c_ = scale if scale > 0 else 1.0
        return self

    def _scale(self):
        # Fitted scale first; an explicit c also works without calling fit().
        if hasattr(self, "c_"):
            return self.c_
        if self.c is not None:
            return float(self.c)
        raise ValueError(
            "AsinhTransformer is not fitted; call fit() first or pass c explicitly."
        )

    def transform(self, y):
        y = np.asarray(y, dtype=float)
        return np.arcsinh(y / self._scale())

    def inverse_transform(self, y_prime):
        y_prime = np.asarray(y_prime, dtype=float)
        return np.sinh(y_prime) * self._scale()


class YeoJohnsonTransformer(BaseEstimator, TransformerMixin):
    """
    Yeo-Johnson power transform (handles negatives):
      Fits lambda by MLE under Gaussian assumption.

    Thin wrapper around ``sklearn.preprocessing.PowerTransformer`` that
    accepts and returns 1-D targets.

    Parameters
    ----------
    standardize : bool
        Passed through to ``PowerTransformer(standardize=...)``.
    """
    def __init__(self, standardize=False):
        self.standardize = standardize
        # Kept so existing code reading ``.pt`` before fitting still works;
        # fit() replaces it with a freshly configured instance.
        self.pt = PowerTransformer(method='yeo-johnson', standardize=self.standardize)

    def fit(self, X=None, y=None):
        """
        Fit lambda on the target.

        The target may be given positionally as ``X`` (this is what
        ``fit_transform(y)`` does) or via the ``y`` keyword; ``y`` wins when
        both are supplied.

        Raises
        ------
        ValueError
            If no target values are supplied.
        """
        data = y if y is not None else X
        if data is None:
            raise ValueError("YeoJohnsonTransformer.fit needs the target values.")
        column = np.asarray(data, dtype=float).reshape(-1, 1)
        # Rebuild so the current `standardize` (e.g. after set_params) is used.
        self.pt = PowerTransformer(method="yeo-johnson", standardize=self.standardize)
        self.pt.fit(column)
        return self

    def transform(self, y):
        y = np.asarray(y, dtype=float).reshape(-1, 1)
        return self.pt.transform(y).flatten()

    def inverse_transform(self, y_prime):
        y_prime = np.asarray(y_prime, dtype=float).reshape(-1, 1)
        return self.pt.inverse_transform(y_prime).flatten()


# Usage example: round-trip each transformer on a heavy-tailed target
if __name__ == "__main__":
    # Student-t with 2 degrees of freedom, scaled by 50 -> heavy tails on both sides
    y = np.random.standard_t(df=2, size=1000) * 50

    # 1) Signed-log (stateless)
    sl = SignedLogTransformer()
    y_sl = sl.fit_transform(y)
    y_sl_inv = sl.inverse_transform(y_sl)

    # 2) Asinh (scale estimated from the data)
    a = AsinhTransformer()
    y_a = a.fit_transform(y)
    y_a_inv = a.inverse_transform(y_a)

    # 3) Yeo-Johnson (lambda fitted by maximum likelihood)
    yj = YeoJohnsonTransformer()
    y_yj = yj.fit_transform(y)
    y_yj_inv = yj.inverse_transform(y_yj)

    # Worst-case reconstruction error per transformer; all should be ~0
    for label, reconstructed in (
        ("Signed‑log error:", y_sl_inv),
        ("Asinh error:    ", y_a_inv),
        ("Yeo‑Johnson error:", y_yj_inv),
    ):
        print(label, np.max(np.abs(y - reconstructed)))


### Error diagnostic plots 

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Tuple, Union, Callable, Optional

def error_heatmap(
    err: np.ndarray,
    x1: np.ndarray,
    x2: np.ndarray,
    bins: Union[int, Tuple[int, int]] = (10, 10),
    quantile_bins: bool = True,
    min_count: int = 25,
    agg: Union[str, Callable[[np.ndarray], float]] = "mae",
    cmap: str = "YlOrRd",
    annotate: bool = True,
    title: Optional[str] = None,
    x1_name: str = "x1",
    x2_name: str = "x2",
    figsize: Tuple[int, int] = (10, 7),
):
    """
    Make a 2D heatmap of an error metric across two continuous features.

    Parameters
    ----------
    err : array-like
        Error per observation. Pass residuals (y - y_hat) or absolute error.
    x1, x2 : array-like
        Numeric feature values aligned with `err`. Rows where x1 or x2 is
        NaN/inf are dropped before binning.
    bins : int or (int, int)
        Number of bins for (x1, x2); each must be >= 1.
    quantile_bins : bool
        If True, use quantile (equal-count) bins, falling back to equal-width
        bins when quantile edges coincide; else equal-width bins. A constant
        feature is binned over [value - 0.5, value + 0.5].
    min_count : int
        Cells with < min_count observations are masked.
    agg : {"mae","mse","rmse","mean","median"} or callable
        Aggregation for the cell. If callable, it receives a non-empty 1D
        float ndarray of `err`.
        Common choices:
          - "mae": mean absolute error (default)
          - "mse": mean squared error
          - "rmse": root mean squared error
          - "mean": mean of err (signed)
          - "median": median absolute error
    cmap : str
        Matplotlib colormap.
    annotate : bool
        If True, write the value and count in each populated cell. Text is
        white on dark cells and black on light ones.
    title : str
        Optional plot title.
    x1_name, x2_name : str
        Axis labels (also used as the axis names of the returned tables).
    figsize : (w, h)
        Figure size in inches.

    Returns
    -------
    fig, ax, tables : (matplotlib.figure.Figure, matplotlib.axes.Axes, dict)
        tables = {"value": value_df, "count": count_df, "x1_bins": x1_bins, "x2_bins": x2_bins}
        value_df holds the aggregate with sparse cells set to NaN; count_df
        holds integer counts (0 for empty cells). Rows are x2 bins, columns
        are x1 bins.

    Raises
    ------
    ValueError
        If the inputs differ in length, no row has finite x1 and x2, a bin
        count is < 1, or `agg` is an unknown name.
    """
    err = np.asarray(err, dtype=float).ravel()
    x1 = np.asarray(x1, dtype=float).ravel()
    x2 = np.asarray(x2, dtype=float).ravel()
    if not (err.shape == x1.shape == x2.shape):
        raise ValueError("err, x1, x2 must have same length")

    # NaN/inf coordinates would poison the quantile edges and cannot be binned.
    finite = np.isfinite(x1) & np.isfinite(x2)
    err, x1, x2 = err[finite], x1[finite], x2[finite]
    if err.size == 0:
        raise ValueError("no observations with finite x1 and x2")

    if isinstance(bins, (int, np.integer)):
        b1 = b2 = int(bins)
    else:
        b1, b2 = bins
    if b1 < 1 or b2 < 1:
        raise ValueError(f"bins must be >= 1, got {bins!r}")

    # Resolve the aggregator once so an unknown name fails before any work.
    named_aggs = {
        "mae": lambda v: np.mean(np.abs(v)),
        "mse": lambda v: np.mean(v ** 2),
        "rmse": lambda v: np.sqrt(np.mean(v ** 2)),
        "mean": np.mean,
        "median": lambda v: np.median(np.abs(v)),
    }
    if callable(agg):
        agg_fn = agg
    else:
        agg_fn = named_aggs.get(str(agg).lower())
        if agg_fn is None:
            raise ValueError(f"Unknown agg: {agg}")

    def _agg(vals) -> float:
        # pivot_table hands over a pandas Series, and an empty one for
        # unobserved cells; convert and short-circuit to avoid
        # "mean of empty slice" warnings.
        arr = np.asarray(vals, dtype=float)
        if arr.size == 0:
            return float("nan")
        return float(agg_fn(arr))

    def _edges(v, k):
        vmin, vmax = float(np.min(v)), float(np.max(v))
        if vmin == vmax:
            # constant feature: identical edges would make pd.cut fail
            vmin, vmax = vmin - 0.5, vmax + 0.5
        if quantile_bins:
            e = np.unique(np.quantile(v, np.linspace(0, 1, k + 1)))
            # too many duplicates (constant regions) -> fall back to equal-width
            if e.size == k + 1:
                return e
        return np.linspace(vmin, vmax, k + 1)

    e1 = _edges(x1, b1)
    e2 = _edges(x2, b2)

    binned1 = pd.cut(x1, e1, include_lowest=True, duplicates="drop")
    binned2 = pd.cut(x2, e2, include_lowest=True, duplicates="drop")

    # Internal column names so user-chosen x1_name/x2_name can never collide
    # with each other or with the error column.
    df = pd.DataFrame({"_x1": binned1, "_x2": binned2, "_err": err})
    # observed=False keeps the full bin grid (and silences pandas' deprecation
    # warning about the implicit default for categorical groupers).
    grid = dict(values="_err", index="_x2", columns="_x1", dropna=False, observed=False)
    val_tbl = df.pivot_table(aggfunc=_agg, **grid)
    cnt_tbl = df.pivot_table(aggfunc="count", **grid).fillna(0).astype(int)
    for tbl in (val_tbl, cnt_tbl):
        tbl.index.name = x2_name
        tbl.columns.name = x1_name

    # mask sparse cells
    masked_vals = val_tbl.where(cnt_tbl >= min_count)
    cells = masked_vals.to_numpy(dtype=float)
    counts = cnt_tbl.to_numpy()

    # plotting
    fig, ax = plt.subplots(figsize=figsize, dpi=150)
    im = ax.imshow(
        cells,
        origin="lower",
        aspect="auto",
        cmap=cmap,
        interpolation="nearest",
    )
    cbar_label = {
        "mae": "Mean Absolute Error",
        "mse": "Mean Squared Error",
        "rmse": "Root MSE",
        "mean": "Mean Error",
        "median": "Median Absolute Error",
    }.get(str(agg).lower(), "Cell Value")
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(cbar_label)

    # ticks & labels
    ax.set_xlabel(x1_name)
    ax.set_ylabel(x2_name)
    ax.set_xticks(np.arange(cells.shape[1]))
    ax.set_yticks(np.arange(cells.shape[0]))
    ax.set_xticklabels([str(c) for c in masked_vals.columns], rotation=90)
    ax.set_yticklabels([str(r) for r in masked_vals.index])

    if title:
        ax.set_title(title)

    # annotations: masked cells are NaN, so only populated cells get text
    if annotate:
        for (i, j), val in np.ndenumerate(cells):
            if not np.isfinite(val):
                continue
            r, g, b, _ = im.cmap(im.norm(val))
            luminance = 0.299 * r + 0.587 * g + 0.114 * b
            ax.text(
                j, i,
                f"{val:.2f}\n n={int(counts[i, j])}",
                ha="center", va="center",
                fontsize=7, color="white" if luminance < 0.5 else "black",
            )

    fig.tight_layout()

    tables = {
        "value": masked_vals,
        "count": cnt_tbl,
        "x1_bins": e1,
        "x2_bins": e2,
    }
    return fig, ax, tables


# -------------------------
# Example (remove in prod):
if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 10000
    # Two independent standard-normal features (x1 drawn first, then x2).
    x1, x2 = rng.normal(0, 1, n), rng.normal(0, 1, n)

    # Synthetic model errors: Gaussian noise plus an interaction term that is
    # only non-zero when both features are positive.
    interaction = np.maximum(0, x1) * np.maximum(0, x2)
    err = rng.normal(0, 0.05, n) + 0.08 * interaction

    heatmap_kwargs = dict(
        bins=(10, 10),
        quantile_bins=True,
        min_count=40,
        agg="mae",
        title="Error heatmap by x1 × x2 (quantile bins)",
        x1_name="feature_A",
        x2_name="feature_B",
    )
    fig, ax, tables = error_heatmap(err=err, x1=x1, x2=x2, **heatmap_kwargs)
    plt.show()

### Which features are important and why 

In [None]:
# --- requirements ---
# pip install shap lightgbm PyALE  # (PyALE only if you later want ALE; not used here)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import shap
from typing import Optional, Sequence, Tuple, Union, Callable

# ========== CORE HELPERS ==========

def compute_shap_values(
    model,
    X: pd.DataFrame,
    shap_values: Optional[np.ndarray] = None,
    approximate: bool = False,
) -> np.ndarray:
    """
    Return an (n_samples, n_features) array of SHAP values for a single-output
    regression model.

    Parameters
    ----------
    model : fitted estimator
        Used only when `shap_values` is None. Models whose type name contains
        "lightgbm", "xgb" or "catboost" use ``shap.TreeExplainer``; all others
        use ``shap.Explainer`` with `X` as background data.
    X : pd.DataFrame
        Rows to explain.
    shap_values : array-like, optional
        Precomputed SHAP values; if given they are validated and returned as
        an ndarray and `model` is not used.
    approximate : bool
        For the generic ``shap.Explainer`` path, use the "permutation"
        algorithm instead of "auto".

    Raises
    ------
    ValueError
        If SHAP values come as a list (multi-class) or are not a 2-D array of
        shape (len(X), X.shape[1]).
    """
    if shap_values is not None:
        # for multiclass shap returns list -> not supported here
        if isinstance(shap_values, list):
            raise ValueError("Provide regression SHAP values of shape (n, m), not a list.")
        sv = np.asarray(shap_values)
        if sv.ndim != 2 or sv.shape[0] != len(X) or sv.shape[1] != X.shape[1]:
            raise ValueError("shap_values must have shape (n_samples, n_features).")
        return sv

    # Pick the explainer. Having predict_proba says nothing about being a tree
    # model (e.g. logistic regression), so only the type name is used here.
    type_name = str(type(model)).lower()
    is_tree_model = any(tag in type_name for tag in ("lightgbm", "xgb", "catboost"))
    if is_tree_model:
        explainer = shap.TreeExplainer(model)
    else:
        # shap.Explainer's default algorithm is "auto"; None is not accepted.
        explainer = shap.Explainer(model, X, algorithm="permutation" if approximate else "auto")

    # Explain once: this is the expensive step.
    explanation = explainer(X)
    sv = explanation.values if hasattr(explanation, "values") else explainer.shap_values(X)
    if isinstance(sv, list):
        raise ValueError("Got list of SHAP arrays (likely multiclass). Use a regression model or pass precomputed array.")
    sv = np.asarray(sv)
    if sv.ndim != 2:
        # multi-output models yield (n, m, k)
        raise ValueError("Got list of SHAP arrays (likely multiclass). Use a regression model or pass precomputed array.")
    return sv


def ensure_df(X) -> pd.DataFrame:
    """Return `X` itself if it is already a DataFrame, otherwise wrap it in one."""
    if isinstance(X, pd.DataFrame):
        return X
    return pd.DataFrame(X)


def abs_error(y_true: np.ndarray, y_pred: np.ndarray) -> np.ndarray:
    """Element-wise |y_true - y_pred|, with both inputs flattened to 1-D first."""
    truth = np.ravel(np.asarray(y_true))
    prediction = np.ravel(np.asarray(y_pred))
    return np.abs(truth - prediction)


# ========== PLOTS ==========

def plot_shap_beeswarm(shap_values: np.ndarray, X: pd.DataFrame, top_n: int = 10, title: str = "SHAP summary"):
    """Draw and show SHAP's beeswarm ("dot") summary plot for at most `top_n` features."""
    summary_options = dict(plot_type="dot", max_display=top_n, show=False)
    # show=False lets us add the title and layout before displaying.
    shap.summary_plot(shap_values, X, **summary_options)
    plt.title(title)
    plt.tight_layout()
    plt.show()


def plot_shap_vs_error(
    shap_values: np.ndarray,
    errors: np.ndarray,
    X: pd.DataFrame,
    top_n: int = 10,
    title: str = "SHAP value vs Error",
):
    """
    For each of the top_n features by mean|shap|, scatter SHAP value (x) vs error (color).

    Each feature occupies one horizontal row (the most important feature at
    y=0), and every point is one observation coloured by `errors`.

    Raises
    ------
    ValueError
        If `shap_values` is not 2-D, `errors` does not have one entry per
        row, or no feature is selected (e.g. top_n < 1).
    """
    X = ensure_df(X)
    shap_values = np.asarray(shap_values)
    errors = np.asarray(errors).ravel()
    if shap_values.ndim != 2:
        raise ValueError("shap_values must be a 2-D (n_samples, n_features) array.")
    if errors.shape[0] != shap_values.shape[0]:
        raise ValueError("errors must have one entry per row of shap_values.")

    feat_order = np.argsort(np.mean(np.abs(shap_values), axis=0))[::-1][:top_n]
    if feat_order.size == 0:
        # otherwise there is no scatter to attach the colorbar to
        raise ValueError("No features to plot: top_n must be >= 1 and shap_values must have columns.")

    fig, ax = plt.subplots(figsize=(12, 7), dpi=150)
    y_ticks = []
    scatter = None
    for row, j in enumerate(feat_order):
        sv = shap_values[:, j]
        # offset each feature cloud vertically
        scatter = ax.scatter(sv, np.full(sv.shape, row, dtype=float), c=errors, s=8, alpha=0.7)
        y_ticks.append(X.columns[j])
        # dashed separator
        ax.axhline(row + 0.5, color="k", linestyle="--", linewidth=0.6, alpha=0.4)

    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label("Absolute error")
    ax.set_yticks(range(len(y_ticks)))
    ax.set_yticklabels(y_ticks)
    ax.set_xlabel("SHAP value (impact on model output)")
    ax.set_title(title)
    ax.axvline(0, color="k", linestyle=":", linewidth=0.8)
    fig.tight_layout()
    plt.show()


def plot_feature_vs_error(
    X: pd.DataFrame,
    errors: np.ndarray,
    top_n: int = 10,
    title: str = "Feature value vs Error",
    normalise: bool = True,
):
    """
    Scatter feature value (x) vs error (color), stacked by feature.

    Features are ranked by variance (highest first, placed at y=0).
    Non-numeric columns are skipped, and columns whose variance is undefined
    (e.g. a single row) are ranked last.

    Parameters
    ----------
    X : DataFrame or array-like
        Feature matrix.
    errors : array-like
        One value per row of `X`; used as the point colour.
    top_n : int
        Maximum number of features to show.
    title : str
        Plot title.
    normalise : bool
        If True, centre each feature on its median and divide by its IQR
        (NaNs ignored; an IQR of 0 or NaN is replaced by 1.0).

    Raises
    ------
    ValueError
        If `errors` does not have one entry per row of `X`, or there is no
        numeric feature to plot.
    """
    X = ensure_df(X)
    errors = np.asarray(errors).ravel()
    if errors.shape[0] != X.shape[0]:
        raise ValueError("errors must have one entry per row of X.")

    # order by variance (or later by importance you pass separately);
    # X.var() would raise on object/str columns, so rank numeric ones only.
    numeric = [j for j in range(X.shape[1]) if pd.api.types.is_numeric_dtype(X.iloc[:, j])]
    variances = np.array([X.iloc[:, j].var() for j in numeric], dtype=float)
    # A NaN variance would land first after the descending reversal.
    ranking = np.argsort(np.where(np.isnan(variances), -np.inf, variances))[::-1]
    feat_order = [numeric[r] for r in ranking[:top_n]]
    if not feat_order:
        raise ValueError("No numeric features to plot (check X and top_n).")

    fig, ax = plt.subplots(figsize=(12, 7), dpi=150)
    y_ticks = []
    scatter = None
    for row, j in enumerate(feat_order):
        xj = X.iloc[:, j].to_numpy(dtype=float)
        if normalise:
            # robust center (median) and scale (IQR), ignoring missing values
            med = np.nanmedian(xj)
            q75, q25 = np.nanpercentile(xj, [75, 25])
            iqr = q75 - q25
            if not np.isfinite(iqr) or iqr == 0:
                iqr = 1.0
            xj = (xj - med) / iqr
        scatter = ax.scatter(xj, np.full(xj.shape, row, dtype=float), c=errors, s=8, alpha=0.7)
        y_ticks.append(X.columns[j])
        ax.axhline(row + 0.5, color="k", linestyle="--", linewidth=0.6, alpha=0.4)

    cbar = fig.colorbar(scatter, ax=ax)
    cbar.set_label("Absolute error")
    ax.set_yticks(range(len(y_ticks)))
    ax.set_yticklabels(y_ticks)
    ax.set_xlabel("Feature value (robust normalised)" if normalise else "Feature value")
    ax.set_title(title)
    ax.axvline(0, color="k", linestyle=":", linewidth=0.8)
    fig.tight_layout()
    plt.show()


def binned_error_curve(
    values: np.ndarray,
    errors: np.ndarray,
    bins: int = 10,
    quantile_bins: bool = True,
    agg: str = "mae",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Aggregate `errors` within bins of `values`.

    Parameters
    ----------
    values : array-like
        Quantity to bin on (e.g. a feature or its SHAP values). Observations
        whose value is NaN/inf are ignored.
    errors : array-like
        Per-observation errors aligned with `values`.
    bins : int
        Requested number of bins. With quantile bins, tied quantile edges are
        merged, so fewer bins may be returned.
    quantile_bins : bool
        Equal-count bins if True (falling back to equal-width bins when fewer
        than two distinct bins survive), equal-width bins otherwise.
    agg : {"mae", "mse", "rmse", "mean", "median"}
        Per-bin aggregation; "median" is the median absolute error.

    Returns
    -------
    centers, vals : (np.ndarray, np.ndarray)
        Bin mid-points and per-bin aggregates (NaN for empty bins). Both
        arrays always have the same length, so they can be plotted directly.
        Empty arrays are returned when there are no finite values.

    Raises
    ------
    ValueError
        If `agg` is unknown or `values` and `errors` differ in length.
    """
    aggregators = {
        "mae": lambda arr: np.mean(np.abs(arr)),
        "mse": lambda arr: np.mean(arr ** 2),
        "rmse": lambda arr: np.sqrt(np.mean(arr ** 2)),
        "mean": np.mean,
        "median": lambda arr: np.median(np.abs(arr)),
    }
    if agg not in aggregators:
        raise ValueError("Unknown agg")
    aggfunc = aggregators[agg]

    v = np.asarray(values, dtype=float).ravel()
    e = np.asarray(errors, dtype=float).ravel()
    if v.shape != e.shape:
        raise ValueError("values and errors must have the same length")

    # NaN values would turn every quantile edge into NaN.
    finite = np.isfinite(v)
    v, e = v[finite], e[finite]
    if v.size == 0:
        return np.array([], dtype=float), np.array([], dtype=float)

    if quantile_bins:
        edges = np.unique(np.quantile(v, np.linspace(0, 1, bins + 1)))
        if len(edges) < 3:  # fewer than two bins survived -> equal-width fallback
            edges = np.linspace(v.min(), v.max(), bins + 1)
    else:
        edges = np.linspace(v.min(), v.max(), bins + 1)

    # Merged quantile edges can leave fewer bins than requested; size the
    # output from the edges actually used so centers and vals line up.
    n_bins = len(edges) - 1
    idx = np.digitize(v, edges[1:-1], right=False)
    vals = np.array(
        [aggfunc(e[idx == k]) if np.any(idx == k) else np.nan for k in range(n_bins)],
        dtype=float,
    )
    centers = 0.5 * (edges[:-1] + edges[1:])
    return centers, vals


def plot_binned_error_for_top_features(
    shap_values: np.ndarray,
    X: pd.DataFrame,
    errors: np.ndarray,
    top_n: int = 6,
    bins: int = 10,
    quantile_bins: bool = True,
):
    """
    For each top feature by mean|shap|, plot two MAE curves:
    - error vs SHAP value bins (bottom x-axis)
    - error vs FEATURE value bins (top x-axis)

    The two curves get separate x-axes because SHAP values and raw feature
    values are on unrelated scales; both x-axes show bin centres in their
    own units.

    Raises
    ------
    ValueError
        If no feature is selected (e.g. top_n < 1 or no SHAP columns).
    """
    X = ensure_df(X)
    shap_values = np.asarray(shap_values)
    order = np.argsort(np.mean(np.abs(shap_values), axis=0))[::-1][:top_n]
    if len(order) == 0:
        raise ValueError("No features to plot: top_n must be >= 1 and shap_values must have columns.")

    ncols = 3
    # size the grid by the features actually plotted, not the requested top_n
    nrows = int(np.ceil(len(order) / ncols))
    fig, axes = plt.subplots(nrows, ncols, figsize=(14, 4 * nrows), dpi=150, squeeze=False)
    bin_kind = "quantile bins" if quantile_bins else "equal-width bins"

    for i, j in enumerate(order):
        ax = axes[i // ncols, i % ncols]
        # SHAP bins
        c1, y1 = binned_error_curve(shap_values[:, j], errors, bins=bins, quantile_bins=quantile_bins, agg="mae")
        # Feature bins
        c2, y2 = binned_error_curve(X.iloc[:, j].values, errors, bins=bins, quantile_bins=quantile_bins, agg="mae")

        ax_feat = ax.twiny()
        lines = ax.plot(c1, y1, marker="o", color="C0", label="Error vs SHAP")
        lines += ax_feat.plot(c2, y2, marker="s", linestyle="--", color="C1", label="Error vs Feature")
        ax.set_title(str(X.columns[j]))
        ax.set_xlabel(f"SHAP value ({bin_kind}, bin centres)")
        ax_feat.set_xlabel(f"Feature value ({bin_kind}, bin centres)")
        ax.set_ylabel("MAE")
        ax.grid(True, alpha=0.3)
        # one legend covering the curves on both x-axes
        ax.legend(lines, [line.get_label() for line in lines])

    # remove empty subplots
    for k in range(len(order), nrows * ncols):
        fig.delaxes(axes[k // ncols, k % ncols])

    fig.suptitle("Binned error curves (top features)", y=1.02)
    plt.tight_layout()
    plt.show()


def error_heatmap(
    err: np.ndarray,
    x1: np.ndarray,
    x2: np.ndarray,
    bins: Union[int, Tuple[int, int]] = (10, 10),
    quantile_bins: bool = True,
    min_count: int = 25,
    agg: str = "mae",
    title: str = "Error heatmap",
    x1_name: str = "x1",
    x2_name: str = "x2",
):
    """
    2D error heatmap across two features.

    Bins `x1` (columns) and `x2` (rows), aggregates `err` per cell with
    `agg`, and masks cells holding fewer than `min_count` observations.
    Rows where x1 or x2 is NaN/inf are dropped before binning.

    Parameters
    ----------
    err : array-like
        Per-observation error (residuals or absolute errors).
    x1, x2 : array-like
        Numeric feature values aligned with `err`.
    bins : int or (int, int)
        Number of bins for (x1, x2).
    quantile_bins : bool
        Equal-count bins if True (falling back to equal-width bins when
        quantile edges coincide), equal-width bins otherwise. A constant
        feature is binned over [value - 0.5, value + 0.5].
    min_count : int
        Cells with fewer observations are masked (left blank).
    agg : {"mae", "mse", "rmse", "mean", "median"}
        Per-cell aggregation; "median" is the median absolute error.
    title : str
        Plot title.
    x1_name, x2_name : str
        Axis labels (also the axis names of the returned tables).

    Returns
    -------
    fig, ax, tables
        tables = {"value": masked aggregate table, "count": integer counts,
        "x1_bins": x1 edges, "x2_bins": x2 edges}; rows are x2 bins,
        columns are x1 bins.

    Raises
    ------
    ValueError
        If `agg` is unknown, the inputs differ in length, or no row has
        finite x1 and x2.
    """
    aggregators = {
        "mae": lambda v: np.mean(np.abs(v)),
        "mse": lambda v: np.mean(v ** 2),
        "rmse": lambda v: np.sqrt(np.mean(v ** 2)),
        "mean": np.mean,
        "median": lambda v: np.median(np.abs(v)),
    }
    cbar_labels = {
        "mae": "Mean Absolute Error",
        "mse": "Mean Squared Error",
        "rmse": "Root MSE",
        "mean": "Mean Error",
        "median": "Median Absolute Error",
    }
    agg_key = str(agg).lower()
    if agg_key not in aggregators:
        raise ValueError(f"Unknown agg: {agg}")
    agg_fn = aggregators[agg_key]

    err = np.asarray(err, dtype=float).ravel()
    x1 = np.asarray(x1, dtype=float).ravel()
    x2 = np.asarray(x2, dtype=float).ravel()
    if not (err.shape == x1.shape == x2.shape):
        raise ValueError("err, x1, x2 must have same length")
    finite = np.isfinite(x1) & np.isfinite(x2)
    err, x1, x2 = err[finite], x1[finite], x2[finite]
    if err.size == 0:
        raise ValueError("no observations with finite x1 and x2")

    if isinstance(bins, (int, np.integer)):
        b1 = b2 = int(bins)
    else:
        b1, b2 = bins

    def edges(v, k):
        lo, hi = float(v.min()), float(v.max())
        if lo == hi:
            # constant feature: identical edges would make pd.cut fail
            lo, hi = lo - 0.5, hi + 0.5
        if quantile_bins:
            e = np.unique(np.quantile(v, np.linspace(0, 1, k + 1)))
            if len(e) == k + 1:
                return e
        return np.linspace(lo, hi, k + 1)

    def cell_value(vals) -> float:
        # pivot_table passes a Series, empty for unobserved cells
        arr = np.asarray(vals, dtype=float)
        return float(agg_fn(arr)) if arr.size else float("nan")

    e1 = edges(x1, b1)
    e2 = edges(x2, b2)

    b1c = pd.cut(x1, e1, include_lowest=True, duplicates="drop")
    b2c = pd.cut(x2, e2, include_lowest=True, duplicates="drop")
    # internal column names avoid collisions with user-supplied axis names
    frame = pd.DataFrame({"_x1": b1c, "_x2": b2c, "_err": err})
    grid = dict(values="_err", index="_x2", columns="_x1", dropna=False, observed=False)
    val_tbl = frame.pivot_table(aggfunc=cell_value, **grid)
    cnt_tbl = frame.pivot_table(aggfunc="count", **grid).fillna(0).astype(int)
    for tbl in (val_tbl, cnt_tbl):
        tbl.index.name = x2_name
        tbl.columns.name = x1_name
    masked = val_tbl.where(cnt_tbl >= min_count)
    cells = masked.to_numpy(dtype=float)
    counts = cnt_tbl.to_numpy()

    fig, ax = plt.subplots(figsize=(10, 7), dpi=150)
    im = ax.imshow(cells, origin="lower", aspect="auto", cmap="YlOrRd", interpolation="nearest")
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label(cbar_labels[agg_key])

    ax.set_xlabel(x1_name)
    ax.set_ylabel(x2_name)
    ax.set_xticks(np.arange(cells.shape[1]))
    ax.set_yticks(np.arange(cells.shape[0]))
    ax.set_xticklabels([str(c) for c in masked.columns], rotation=90)
    ax.set_yticklabels([str(r) for r in masked.index])
    ax.set_title(title)

    # masked cells are NaN, so only populated cells are annotated
    for (i, j), val in np.ndenumerate(cells):
        if np.isfinite(val):
            ax.text(
                j, i,
                f"{val:.2f}\n n={int(counts[i, j])}",
                ha="center", va="center",
                fontsize=7, color="black",
            )

    fig.tight_layout()
    tables = {"value": masked, "count": cnt_tbl, "x1_bins": e1, "x2_bins": e2}
    return fig, ax, tables


### Correlation / VIF between features 

In [None]:
#!/usr/bin/env python3
"""
feature_analysis.py

A script to visualize feature correlations and compute Variance Inflation Factors (VIF).
"""

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from statsmodels.stats.outliers_influence import variance_inflation_factor
from patsy import dmatrix


def plot_correlation_heatmap(df: pd.DataFrame,
                             features: list,
                             method: str = 'pearson',
                             annot: bool = True,
                             figsize: tuple = (10, 8),
                             cmap: str = 'vlag',
                             save_path: str = None):
    """
    Draw (and optionally save) a correlation heatmap of the chosen columns.

    Parameters
    ----------
    df : pd.DataFrame
        Source data.
    features : list
        Column names whose pairwise correlations are shown.
    method : str
        Passed to ``DataFrame.corr`` (e.g. 'pearson' or 'spearman'); also
        used, capitalised, in the plot title.
    annot : bool
        Print each coefficient (2 decimals) inside its cell.
    figsize : tuple
        Figure size in inches.
    cmap : str
        Seaborn/Matplotlib colormap, centred at 0.
    save_path : str, optional
        When truthy, the figure is written there at 300 dpi before showing.
    """
    corr_matrix = df[features].corr(method=method)
    heatmap_style = dict(annot=annot, cmap=cmap, center=0,
                         fmt=".2f", square=True, linewidths=0.5)

    plt.figure(figsize=figsize)
    sns.heatmap(corr_matrix, **heatmap_style)
    plt.title(method.capitalize() + " Correlation Heatmap")
    plt.tight_layout()
    if save_path:
        plt.savefig(save_path, dpi=300)
    plt.show()


def compute_vif(df: pd.DataFrame,
                features: list,
                add_constant: bool = True) -> pd.DataFrame:
    """
    Compute Variance Inflation Factor (VIF) for each feature.

    Parameters
    ----------
    df : pd.DataFrame
        The input DataFrame.
    features : list
        List of column names to include; they are cast to float.
    add_constant : bool
        Whether to append an intercept (all-ones) column to the design matrix
        before computing VIFs. The intercept itself is not reported.

    Returns
    -------
    pd.DataFrame
        Columns 'feature' and 'VIF', one row per feature, sorted by VIF
        in descending order.
    """
    # Cast once so mixed int/bool/float columns yield a numeric matrix
    # (a bool+float mix would otherwise become an object array).
    X = df[features].astype(float)
    n_features = X.shape[1]

    if add_constant:
        # statsmodels VIF expects an intercept. Pick a column name that
        # cannot clobber a real feature called "Intercept".
        const_col = "Intercept"
        while const_col in X.columns:
            const_col = "_" + const_col
        X[const_col] = 1.0

    # Build the design matrix once rather than on every iteration.
    exog = X.to_numpy()
    vif_df = pd.DataFrame({
        'feature': list(X.columns[:n_features]),
        'VIF': [variance_inflation_factor(exog, i) for i in range(n_features)],
    })
    return vif_df.sort_values(by='VIF', ascending=False)


def main():
    """Demo on the Iris dataset: save/show a Pearson correlation heatmap, then print VIFs."""
    from sklearn.datasets import load_iris

    iris = load_iris(as_frame=True)
    df = iris.frame
    features = iris.feature_names  # the four measurement columns

    # Step 1: correlation heatmap, also written to disk
    plot_correlation_heatmap(df, features, method='pearson',
                             save_path='iris_corr_heatmap.png')

    # Step 2: variance inflation factors
    vif_table = compute_vif(df, features)
    print("\nVariance Inflation Factors:")
    print(vif_table.to_string(index=False))

