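# Outcome-informed balance plots

This notebook compares covariate-balance plots in two forms. The plots show absolute standardized mean differences before and after inverse-probability weighting.

- **Standard form:** every covariate is drawn the same way.
- **Outcome-informed form:** each covariate's marker opacity and size scale with its Oui-ASMD. Oui-ASMD is the covariate's unweighted ASMD multiplied by its importance for predicting the outcome.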

In [2]:
import numpy as np
import pandas as pd
import seaborn as sns
import seaborn.objects as so
import matplotlib as mpl
import matplotlib.pyplot as plt
import statsmodels.api as sm
import statsmodels.formula.api as smf
from scipy.special import expit

from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.linear_model import LinearRegression
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.inspection import permutation_importance
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error

from causallib.estimation import IPW


In [3]:
def generate_data(n=1000, seed=0, d=60, effect=1):
    """Simulate covariates, a binary treatment and an outcome with a constant effect.

    The ``d`` covariates are i.i.d. N(0, 0.5) and split into four consecutive
    quarters with different roles:

    * treatment logit coefficients: ~N(0, 0.01) on quarters 1 and 3 (nearly
      irrelevant), ~N(0, 0.5) on quarters 2 and 4;
    * outcome coefficients: ~N(0, 0.01) on the first half, ~N(-2, 0.5) on the
      second half (quarters 3 and 4).

    Hence quarter 4 affects both treatment and outcome (confounders), quarter 2
    only treatment, quarter 3 only outcome, and quarter 1 neither.

    Parameters
    ----------
    n : int, default 1000
        Number of samples.
    seed : int, default 0
        Seed for ``np.random.default_rng``; the output is fully determined by it.
    d : int, default 60
        Number of covariates. Need not be divisible by 4.
    effect : float, default 1
        Additive treatment effect on ``y``.

    Returns
    -------
    X : pd.DataFrame
        ``n x d`` covariates named ``x00``, ``x01``, ...
    a : pd.Series
        Binary treatment drawn from ``Bernoulli(expit(X @ a_beta))``, named ``"a"``.
    y : pd.Series
        ``X @ y_beta + a * effect + N(0, 1)`` noise, named ``"y"``.
    """
    rng = np.random.default_rng(seed)
    X = rng.normal(0, 0.5, size=(n, d))

    # Quarter sizes always sum to d. Any remainder goes to the leading quarters,
    # so the coefficient vectors match X's width for every d. For d divisible
    # by 4 this is exactly d//4 each, so the RNG stream is unchanged.
    quarters = [d // 4 + (i < d % 4) for i in range(4)]
    a_beta = np.concatenate((
        rng.normal(0, 0.01, size=quarters[0]),  # ~no influence on treatment
        rng.normal(0, 0.5, size=quarters[1]),
        rng.normal(0, 0.01, size=quarters[2]),  # ~no influence on treatment
        rng.normal(0, 0.5, size=quarters[3]),
    ))
    a_prop = expit(X @ a_beta)
    a = rng.binomial(1, a_prop)

    half = d // 2
    y_beta = np.concatenate((
        rng.normal(0, 0.01, size=half),      # ~no influence on outcome
        rng.normal(-2, 0.5, size=d - half),  # strong negative influence on outcome
    ))
    y = X @ y_beta + a * effect + rng.normal(0, 1, size=n)

    X = pd.DataFrame(X, columns=[f"x{i:02d}" for i in range(d)])
    a = pd.Series(a, name="a")
    y = pd.Series(y, name="y")
    return X, a, y

In [4]:
X, a, y = generate_data()
# Preview only the first rows; the full frame has 1000 rows and 62 columns.
X.join(a).join(y).head()

In [5]:
# Propensity model for inverse-probability weighting.
# `penalty=None` fits an unpenalized logistic regression. The string "none" was
# deprecated in scikit-learn 1.2 and removed in 1.4.
# A regularized, cross-validated alternative is
# `IPW(LogisticRegressionCV(max_iter=5000))`.
ipw = IPW(LogisticRegression(penalty=None, max_iter=5000))
ipw.fit(X, a)
w = ipw.compute_weights(X, a)

In [6]:
def calculate_asmd(X, a, w=None, eps=1e-8, average_variances=False):
    """Absolute standardized mean difference (ASMD) of each covariate between groups.

    Parameters
    ----------
    X : pd.DataFrame
        Covariates, one column per covariate.
    a : pd.Series
        Treatment assignment aligned with ``X`` by index. Rows with ``a == 1``
        are treated; all other rows are controls.
    w : pd.Series, optional
        Sample weights indexed like ``a``. They are looked up by row label, so
        their order need not match ``X``. Defaults to uniform weights, which
        gives the unweighted ASMD.
    eps : float, default 1e-8
        Added under the square root to avoid division by zero for covariates
        that are constant in both groups.
    average_variances : bool, default False
        If False (the original behavior), the denominator is
        ``sqrt(var1 + var0 + eps)``. If True, it is the pooled standard deviation
        ``sqrt((var1 + var0) / 2 + eps)``, the convention under which the common
        0.1 balance threshold is usually stated.

    Returns
    -------
    pd.Series
        Per-covariate ASMD indexed by ``X.columns`` and named ``"asmd"``.
        Means and variances are weighted, with variances using ddof=0.

    Raises
    ------
    ZeroDivisionError
        If either treatment group is empty or has zero total weight.
    """
    if w is None:
        w = pd.Series(1, index=a.index)

    is_treated = a == 1

    def _weighted_moments(mask):
        # Weighted mean and weighted (ddof=0) variance of every column of X over
        # the selected rows. Weights are matched to rows by label, not position.
        group = X.loc[mask]
        weights = w.loc[group.index].to_numpy(dtype=float)
        values = group.to_numpy(dtype=float)
        mean = np.average(values, axis=0, weights=weights)
        var = np.average((values - mean) ** 2, axis=0, weights=weights)
        return pd.Series(mean, index=X.columns), pd.Series(var, index=X.columns)

    x1_mean, x1_var = _weighted_moments(is_treated)
    x0_mean, x0_var = _weighted_moments(~is_treated)

    var_sum = x0_var + x1_var
    spread = var_sum / 2 if average_variances else var_sum
    smds = (x1_mean - x0_mean) / ((spread + eps) ** 0.5)
    asmds = smds.abs()
    asmds.name = "asmd"
    return asmds

In [7]:
asmds = pd.concat({
    "weighted": calculate_asmd(X, a, w),
    "unweighted": calculate_asmd(X, a),
}, names=["adjustment", "covariate"])
asmds

adjustment  covariate
weighted    x00          0.015807
            x01          0.029745
            x02          0.036038
            x03          0.029324
            x04          0.017355
                           ...   
unweighted  x55          0.180169
            x56          0.011614
            x57          0.093725
            x58          0.081427
            x59          0.033620
Name: asmd, Length: 120, dtype: float64

In [8]:
def leave_one_out_importance(estimator, X, a, y):
    """Score an outcome model on all covariates, then with each covariate left out.

    For the sentinel ``"full"`` and then for every column of ``X``, the
    estimator is refit in place on the covariates minus that column, with the
    treatment ``a`` joined as an extra feature. It is then scored in-sample
    against ``y``. Dropping ``"full"`` is a no-op (``errors="ignore"``), so
    that row is the all-covariates baseline.

    Returns
    -------
    pd.DataFrame
        One row per fit with columns ``covariate``, ``r2``, ``mse`` and ``mae``.
        The first row is ``"full"``.
    """
    def _fit_and_score(design):
        # Refits the caller's estimator; it ends up fitted on the last design.
        estimator.fit(design, y)
        fitted = estimator.predict(design)
        return {
            "r2": r2_score(y, fitted),
            "mse": mean_squared_error(y, fitted),
            "mae": mean_absolute_error(y, fitted),
        }

    left_out = ["full", *X.columns]
    records = [
        {"covariate": name, **_fit_and_score(X.drop(columns=name, errors="ignore").join(a))}
        for name in left_out
    ]
    return pd.DataFrame(records)

def relative_explained_variation(estimator, X, a, y, metric="mse"):
    """Each leave-one-out fit metric divided by the full-model metric.

    Follows Harrell: https://www.fharrell.com/post/addvalue/

    Returns a DataFrame indexed by covariate, with one column per metric
    (r2, mse, mae). The ``"full"`` baseline row is excluded. ``metric`` is
    currently ignored: all metric columns are returned.
    """
    scores = leave_one_out_importance(estimator, X, a, y).set_index("covariate")
    baseline = scores.loc["full"]
    ratios = scores.div(baseline, axis="columns")
    return ratios.drop(index="full")

def decrease_in_explain_variation(estimator, X, a, y, metric="mse"):
    """Absolute relative change of each fit metric when a covariate is left out.

    Computes ``|(full - leave_one_out) / full|`` per metric.
    See https://stackoverflow.com/q/31343563

    Returns a DataFrame indexed by covariate, with one column per metric
    (r2, mse, mae). The ``"full"`` baseline row is excluded. ``metric`` is
    currently ignored: all metric columns are returned.
    """
    scores = leave_one_out_importance(estimator, X, a, y).set_index("covariate")
    baseline = scores.loc["full"]
    # rsub -> baseline - scores, aligned on metric columns.
    relative_change = scores.rsub(baseline, axis="columns").div(baseline, axis="columns")
    return relative_change.drop(index="full").abs()


In [9]:
# Outcome importance of each covariate: the absolute relative change in the
# in-sample r2 / mse / mae of a linear outcome model when that covariate is
# dropped. Larger values mean dropping the covariate changes the fit more.
feature_importance = decrease_in_explain_variation(LinearRegression(), X, a, y)
feature_importance.head()

In [10]:
# Long format: one row per (adjustment, covariate) ASMD, with that covariate's
# outcome-importance metrics attached.
plot_data = pd.merge(
    asmds.reset_index(),
    feature_importance.reset_index(),
    on="covariate",
)
plot_data

In [11]:
outcome_metric = "mse"
# Oui-ASMD (outcome-informed ASMD): each covariate's *unweighted* ASMD scaled by
# its outcome importance. Covariates that are both imbalanced and prognostic
# therefore stand out.
ouiasmd = plot_data.query("adjustment=='unweighted'").drop(columns="adjustment")
ouiasmd["ouiasmd"] = ouiasmd["asmd"] * ouiasmd[outcome_metric]

# Drop any Oui-ASMD column left by a previous execution so that re-running this
# cell is idempotent (otherwise the rename creates duplicate "Oui-ASMD" columns).
plot_data = (
    plot_data
    .drop(columns="Oui-ASMD", errors="ignore")
    .merge(ouiasmd[["covariate", "ouiasmd"]], on="covariate", how="left")
    .rename(columns={"ouiasmd": "Oui-ASMD"})
)
plot_data

In [12]:
def slope_balance(
    plot_data,
    opacity=False,
    pointsize=False,
    importance_metric="Oui-ASMD",
    legend=True,
    threshold=None, ax=None
):
    """Slope plot of each covariate's ASMD going from unweighted to weighted.

    Parameters
    ----------
    plot_data : pd.DataFrame
        Long-format data with columns ``adjustment`` ("unweighted"/"weighted"),
        ``asmd``, ``covariate`` and ``importance_metric``.
    opacity : bool
        Map ``importance_metric`` to line/dot opacity.
    pointsize : bool
        Map ``importance_metric`` to line width and dot size.
    importance_metric : str
        Column used for the opacity / size mappings.
    legend : bool
        Whether the marks contribute to the legend.
    threshold : float, optional
        If given, draw a dashed horizontal reference line at this ASMD. Requires
        ``ax``.
    ax : matplotlib Axes, optional
        If given, the plot is rendered onto it.

    Returns
    -------
    The unrendered ``so.Plot`` when ``ax`` is None, otherwise the result of
    ``Plot.on(ax).plot()``.

    Raises
    ------
    ValueError
        If ``threshold`` is given without ``ax``.
    """
    if threshold is not None and ax is None:
        raise ValueError("`threshold` requires `ax` to draw the reference line on.")

    p = so.Plot(
        data=plot_data,
        x="adjustment",
        y="asmd",
        group="covariate",
        alpha=importance_metric if opacity else None,
    ).add(
        so.Lines(),
        linewidth=importance_metric if pointsize else None,
        legend=legend,
    ).add(
        # A fixed small dot size when size is not mapped to the importance metric.
        so.Dot() if pointsize else so.Dot(pointsize=3),
        pointsize=importance_metric if pointsize else None,
        legend=legend,
    ).scale(
        x=so.Nominal(order=["unweighted", "weighted"]),
    ).label(
        x="",
        y="Absolute standardized mean difference",
    ).theme(
        sns.axes_style("white")
    ).limit(
        # Pad around the two nominal positions (0 and 1).
        x=(-0.1, 1.1)
    )
    if threshold is not None:
        ax.axhline(threshold, linestyle="--", color="0.6", zorder=0)
    if ax is not None:
        p = p.on(ax).plot()

    return p

In [13]:
def scatter_balance(
    plot_data,
    opacity=False,
    pointsize=False,
    importance_metric="Oui-ASMD",
    legend=True,
    threshold=None, ax=None
):
    """Scatter plot of each covariate's weighted ASMD against its unweighted ASMD.

    Parameters
    ----------
    plot_data : pd.DataFrame
        Long-format data with columns ``adjustment`` ("unweighted"/"weighted"),
        ``asmd``, ``covariate`` and ``importance_metric``.
    opacity : bool
        Map ``importance_metric`` to dot opacity.
    pointsize : bool
        Map ``importance_metric`` to dot size.
    importance_metric : str
        Column used for the opacity / size mappings.
    legend : bool
        Whether the dots contribute to the legend.
    threshold : float, optional
        If given, draw dashed reference lines at this ASMD on both axes.
        Requires ``ax``.
    ax : matplotlib Axes, optional
        If given, the plot is rendered onto it.

    Returns
    -------
    The unrendered ``so.Plot`` when ``ax`` is None, otherwise the result of
    ``Plot.on(ax).plot()``.

    Raises
    ------
    ValueError
        If ``threshold`` is given without ``ax``.
    """
    if threshold is not None and ax is None:
        raise ValueError("`threshold` requires `ax` to draw the reference lines on.")

    # Wide format: one row per covariate with "unweighted" and "weighted" ASMD
    # columns, plus the remaining covariate-level columns (e.g. importance
    # metrics) taken from the weighted rows.
    plot_data = plot_data.pivot_table(
        values="asmd",
        index="covariate",
        columns="adjustment"
    ).merge(
        plot_data.query("adjustment=='weighted'").drop(
            columns=["adjustment", "asmd"]
        ),
        left_index=True,
        right_on="covariate",
    )
    p = so.Plot(
        data=plot_data,
        x="unweighted",
        y="weighted",
        alpha=importance_metric if opacity else None,
        pointsize=importance_metric if pointsize else None,
    ).add(
        # A fixed small dot size when size is not mapped to the importance metric.
        so.Dot() if pointsize else so.Dot(pointsize=3),
        legend=legend,
    ).label(
        x="Unweighted ASMD",
        y="Weighted ASMD",
    ).theme(
        sns.axes_style("white")
    )
    if opacity:
        p = p.scale(alpha=so.Continuous().tick(upto=4))
    if pointsize:
        p = p.scale(pointsize=so.Continuous().tick(upto=4))
    if threshold is not None:
        ax.axhline(threshold, linestyle="--", color="0.6", zorder=0)
        ax.axvline(threshold, linestyle="--", color="0.6", zorder=0)
    if ax is not None:
        p = p.on(ax).plot()

    return p

In [14]:
# 2x2 layout: the top subfigure holds the scatter plots and the bottom
# subfigure holds the slope plots.
fig = plt.figure(figsize=(7.5, 5))
topfig, bottomfig = fig.subfigures(2, 1, hspace=0)
topaxes = topfig.subplots(1, 2)
bottomaxes = bottomfig.subplots(1, 2)
# Close this figure so the still-empty canvas is not rendered by this cell;
# the figure object stays usable and is drawn onto and displayed later.
plt.close(fig)

In [15]:
# Left column: standard plots. Right column: outcome-informed plots, where
# opacity and size follow the importance metric.
balance_threshold = 0.1
outcome_informed = dict(opacity=True, pointsize=True)

scatter_balance(plot_data, threshold=balance_threshold, ax=topaxes[0])
scatter_balance(plot_data, threshold=balance_threshold, ax=topaxes[1], **outcome_informed)

slope_balance(plot_data, threshold=balance_threshold, ax=bottomaxes[0])
slope_balance(plot_data, threshold=balance_threshold, ax=bottomaxes[1], **outcome_informed)

In [16]:
# Column titles go on the top row; the bottom row gets a short y-axis label.
for panel, title in zip(topaxes, ("Standard", "Outcome-informed")):
    panel.set_title(title)
for panel in bottomaxes:
    panel.set_ylabel("ASMD")
for subfig in (topfig, bottomfig):
    subfig.subplots_adjust(wspace=0.3)
fig