Copyright 2025 The Google Research Authors

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
```
 http://www.apache.org/licenses/LICENSE-2.0
```
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Simulation of the effects of selection bias on subgroup calibration

In [None]:
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import itertools

from causal_evaluation import utils
from causal_evaluation.experiments.simulator import Simulator

In [None]:
# Flags
# Rows drawn from each simulator for training, before the selection filter
# (selected == 1) is applied.
N_SAMPLES_TRAIN = 100_000
# Selected rows subsampled per setting to fit the models on.
N_SAMPLES_TRAIN_EFFECTIVE = 50_000
# Rows drawn from each simulator for evaluation.
N_SAMPLES_EVAL = 20_000
# Directory the figures are written to (relative to the working directory).
FIGURE_PATH = './../../figures'

In [None]:
# Create the figure output directory (and any missing parents); no error if it
# already exists.
os.makedirs(name=FIGURE_PATH, exist_ok=True)

In [None]:
# Seed the global RNGs: Python's `random` module and numpy's legacy global
# RNG (the selection functions below draw via np.random.binomial).
random.seed(100)
np.random.seed(173)

In [None]:
def select_x(x):
  """Return the selection probability for covariate value(s) `x`.

  The curve is the downward parabola 1 - (4/25) * x**2 clipped to [0, 1]:
  it peaks at 1 for x = 0 and reaches 0 at |x| = 2.5, staying 0 beyond.
  Works elementwise on numpy arrays as well as on scalars.
  """
  parabola = 1 - (4 / 25) * x**2
  return np.clip(parabola, 0, 1)
# Plot the X-based selection probability (used by SimulatorSelectX) over
# x in [-10, 10].
fig, ax = plt.subplots()
x = np.linspace(start=-10, stop=10, num=1000)
ax.plot(x, select_x(x))

In [None]:
class SimulatorSelectX(Simulator):
  """Simulator whose selection indicator depends only on the covariate X."""

  def selection_function(self, u, a, x, y):
    """Draw selection indicators with probability `select_x(x)`.

    `u`, `a` and `y` are accepted for interface compatibility but unused.
    Draws come from numpy's global RNG (np.random), not from a seeded
    generator passed in.

    Returns:
      Tuple (Bernoulli draws, selection probabilities).
    """
    selection_prob = select_x(x)
    draws = np.random.binomial(n=1, p=selection_prob)
    return draws, selection_prob

class SimulatorSelectY(Simulator):
  """Simulator whose selection indicator depends only on the outcome Y."""

  def selection_function(self, u, a, x, y):
    """Draw selection indicators with probability 0.4 if y == 0, 0.8 if y == 1.

    `y` is used as an index into a length-2 table, so it presumably holds
    integer labels in {0, 1} (TODO confirm against Simulator). `u`, `a` and
    `x` are unused. Draws come from numpy's global RNG.

    Returns:
      Tuple (Bernoulli draws, selection probabilities).
    """
    prob_by_label = np.array([0.4, 0.8])
    selection_prob = prob_by_label[y]
    draws = np.random.binomial(n=1, p=selection_prob)
    return draws, selection_prob

class SimulatorSelectAY(Simulator):
  """Simulator whose selection indicator depends on both A and Y."""

  def selection_function(self, u, a, x, y):
    """Draw selection indicators with a probability looked up by (a, y).

    Table rows are indexed by `a`, columns by `y`:
      a = 0: y = 0 -> 0.8, y = 1 -> 0.5
      a = 1: y = 0 -> 0.8, y = 1 -> 0.25
    `a` and `y` presumably hold integer labels in {0, 1} (TODO confirm).
    `u` and `x` are unused. Draws come from numpy's global RNG.

    Returns:
      Tuple (Bernoulli draws, selection probabilities).
    """
    prob_table = np.array([[0.8, 0.5], [0.8, 0.25]])
    selection_prob = prob_table[a, y]
    draws = np.random.binomial(n=1, p=selection_prob)
    return draws, selection_prob

In [None]:
def get_sim_dict(**kwargs):
  """Build one simulator per selection mechanism.

  All three simulators receive the same constructor arguments
  (`a_to_y=True`, `mu_y_a=np.array([0.1, 0])`, plus `**kwargs`); only the
  subclass, i.e. the selection function, differs.

  Args:
    **kwargs: Forwarded to every simulator constructor (e.g. `num_samples`).

  Returns:
    Dict with keys 'select_x', 'select_y', 'select_ay' (in that order) mapping
    to the corresponding simulator instances.
  """
  simulator_classes = (
      ('select_x', SimulatorSelectX),
      ('select_y', SimulatorSelectY),
      ('select_ay', SimulatorSelectAY),
  )
  # Each simulator gets its own freshly built mu_y_a array, so no two
  # instances share a mutable array.
  return {
      name: simulator_cls(a_to_y=True, mu_y_a=np.array([0.1, 0]), **kwargs)
      for name, simulator_cls in simulator_classes
  }

# Training samples: one DataFrame per setting, generated with seed i, the
# setting's position in get_sim_dict (seeds 0, 1, 2, ...).
sim_samples_dict = {
    key: utils.get_squeezed_df(value.get_samples(seed=i))
    for i, (key, value) in enumerate(
        get_sim_dict(num_samples=N_SAMPLES_TRAIN).items()
    )
}

# Stack the per-setting frames into one long frame: the dict key becomes a
# 'setting' column and the per-setting row index is dropped.
sim_samples_df = (
    pd.concat(sim_samples_dict)
    .reset_index(level=-1, drop=True)
    .rename_axis('setting')
    .reset_index()
)

# Evaluation samples. Seeds start right after the last training seed so
# training and evaluation never share a seed. Deriving them as 2 * i would
# give 0, 2, 4 and collide with training seeds 0 ('select_x' itself) and 2
# ('select_ay', built with the same constructor arguments), which can make
# evaluation rows replicate training rows. Assumes get_samples derives its
# randomness from `seed` -- TODO confirm in Simulator.
eval_seed_offset = len(sim_samples_dict)
sim_samples_dict_eval = {
    key: utils.get_squeezed_df(value.get_samples(seed=eval_seed_offset + i))
    for i, (key, value) in enumerate(
        get_sim_dict(num_samples=N_SAMPLES_EVAL).items()
    )
}


In [None]:
# Count selected (selected == 1) rows per setting. Settings with no selected
# rows at all are reindexed in as 0 so they fail the check below instead of
# silently disappearing from the groupby.
n_selected_per_setting = (
    sim_samples_df.query('selected == 1')['setting']
    .value_counts()
    .reindex(sim_samples_df['setting'].unique(), fill_value=0)
)
# Drawing N_SAMPLES_TRAIN_EFFECTIVE rows without replacement needs at least
# that many selected rows per setting. Raise explicitly (an `assert` would be
# stripped under `python -O`).
insufficient = n_selected_per_setting[
    n_selected_per_setting < N_SAMPLES_TRAIN_EFFECTIVE
]
if not insufficient.empty:
  raise ValueError(
      f'Too few selected samples to draw {N_SAMPLES_TRAIN_EFFECTIVE} per '
      f'setting: {insufficient.to_dict()}'
  )

# Subsample N_SAMPLES_TRAIN_EFFECTIVE selected rows from every setting
# (fixed random_state for reproducibility), keeping 'setting' as a column.
sim_samples_df_selected = (
    sim_samples_df.query('selected == 1')
    .groupby('setting')
    .apply(
        lambda x: x.sample(N_SAMPLES_TRAIN_EFFECTIVE, random_state=10),
        include_groups=False,
    )
    .reset_index(level=-1, drop=True)
    .reset_index()
)
sim_samples_df_selected

In [None]:
model_type = 'gradient_boosting'
settings = sim_samples_dict.keys()
for setting in settings:
  print(f'Setting: {setting}', flush=True)
  selected_df = sim_samples_df_selected.query('setting == @setting')
  # Alias for this setting's evaluation frame; the column assignments below
  # write into sim_samples_dict_eval[setting] itself.
  eval_df = sim_samples_dict_eval[setting]

  # Design matrices: X as a single column, and [X, A] side by side.
  x_train = selected_df['x'].values.reshape(-1, 1)
  xa_train = np.concatenate(
      (x_train, selected_df['a'].values.reshape(-1, 1)), axis=1
  )
  y_train = selected_df['y'].values
  x_eval = eval_df['x'].values.reshape(-1, 1)
  xa_eval = np.concatenate(
      (x_eval, eval_df['a'].values.reshape(-1, 1)), axis=1
  )

  # Each *_y1_* column keeps the last entry of the per-row probability
  # vector -- presumably P(Y = 1) for binary Y (TODO confirm class order).

  # (1) Pooled model of Y given X only.
  model = utils.fit_model(
      x_train, y_train, model_type=model_type, model_cross_val=True
  )
  eval_df['pred_probs_y_x'] = utils.array_to_series(
      model.predict_proba(x_eval)
  )
  eval_df['pred_probs_y1_x'] = eval_df['pred_probs_y_x'].map(
      lambda probs: probs[-1]
  )

  # (2) Separate models of Y given X, one per value of A.
  model_dict = utils.fit_model_stratified(
      x_train,
      y_train,
      group=selected_df['a'].values,
      model_type=model_type,
      model_cross_val=True,
  )
  eval_df['pred_probs_y_xa_stratified'] = utils.array_to_series(
      utils.predict_proba_stratified(
          x_eval, model_dict, group=eval_df['a'].values
      )
  )
  eval_df['pred_probs_y1_xa_stratified'] = eval_df[
      'pred_probs_y_xa_stratified'
  ].map(lambda probs: probs[-1])

  # (3) Pooled model of Y given both X and A as features.
  model_xa = utils.fit_model(
      xa_train, y_train, model_type=model_type, model_cross_val=True
  )
  eval_df['pred_probs_y_xa'] = utils.array_to_series(
      model_xa.predict_proba(xa_eval)
  )
  eval_df['pred_probs_y1_xa'] = eval_df['pred_probs_y_xa'].map(
      lambda probs: probs[-1]
  )

In [None]:
def calibration_plot(
    y_true,
    y_prob,
    group=None,
    ax=None,
    plot_overall=True,
    palette=None,
    legend=False,
    xlabel='Score',
    ylabel='Fraction positive',
    **kwargs,
):
  """Plot calibration curves with shaded intervals, overall and/or per group.

  Args:
    y_true: Observed outcomes.
    y_prob: Predicted scores, plotted against the fraction of positives.
    group: Optional group labels aligned with `y_true`/`y_prob`; one curve is
      drawn per group, labelled with the group value.
    ax: Axes to draw on. When None, a new figure is created and its current
      axes are used.
    plot_overall: Whether to also draw the pooled curve labelled 'Overall'.
    palette: Sequence of colors; defaults to seaborn's 'Set2'. Colors are
      reused cyclically when there are more curves than colors.
    legend: Whether to draw a legend on `ax`.
    xlabel: X-axis label, or None to leave it unset.
    ylabel: Y-axis label, or None to leave it unset.
    **kwargs: Forwarded to `utils.calibration_curve_ci` for the overall curve
      AND for every group curve, so binning is consistent across curves.
      `n_bins` defaults to 10.
  """
  if ax is None:
    plt.figure()
    ax = plt.gca()

  if kwargs.get('n_bins') is None:
    kwargs['n_bins'] = 10

  if palette is None:
    palette = sns.color_palette('Set2')

  def _plot_curve_with_band(curve, color, **line_kwargs):
    """Draw one calibration curve and its shaded interval on `ax`."""
    # Index 1 goes on the x-axis (score), index 0 on the y-axis (fraction
    # positive); indices 2 and 3 bound the shaded band.
    ax.plot(curve[1], curve[0], color=color, **line_kwargs)
    ax.fill_between(curve[1], curve[2], curve[3], alpha=0.3, color=color)

  # Number of palette entries already used by the overall curve.
  palette_count = 0
  if plot_overall:
    overall_curve = utils.calibration_curve_ci(y_true, y_prob, **kwargs)
    _plot_curve_with_band(overall_curve, palette[0], label='Overall')
    palette_count = 1

  if group is not None:
    df = pd.DataFrame({'y_true': y_true, 'y_prob': y_prob, 'group': group})
    for i, (the_group, group_df) in enumerate(df.groupby('group')):
      # Skip empty groups (e.g. unobserved categorical levels).
      if group_df.shape[0] > 0:
        group_curve = utils.calibration_curve_ci(
            group_df['y_true'], group_df['y_prob'], **kwargs
        )
        color = palette[(i + palette_count) % len(palette)]
        _plot_curve_with_band(
            group_curve, color, label=f'{the_group}', linewidth=2
        )

  # Diagonal reference line for perfect calibration.
  ax.plot(
      np.linspace(0, 1, 100),
      np.linspace(0, 1, 100),
      linestyle='-.',
      color='gray',
  )

  ax.set_xlim(0, 1)
  ax.set_ylim(0, 1)

  if xlabel is not None:
    ax.set_xlabel(xlabel, size=14)
  if ylabel is not None:
    ax.set_ylabel(ylabel, size=14)

  if legend:
    ax.legend()

  # Despine the axes actually drawn on, not whatever figure is current.
  sns.despine(ax=ax)

In [None]:
# Setting names in the insertion order of sim_samples_dict ('select_x',
# 'select_y', 'select_ay' from get_sim_dict). This is a live keys view of the
# dict, not a copy.
settings = sim_samples_dict.keys()

In [None]:
# Shallow copy: a new dict, but its values are the same DataFrame objects as
# in sim_samples_dict_eval, so mutating a frame affects both dicts.
pred_eval_dict = dict(sim_samples_dict_eval)

In [None]:
# Column titles for the calibration figure: which variables drive the
# selection indicator S in each setting. In 'select_ay', selection depends on
# both Y and A, so the arrow points to S (like the other two settings).
setting_title_dict = {
    'select_x': r'X $\rightarrow$ S',
    'select_y': r'Y $\rightarrow$ S',
    'select_ay': r'{Y, A} $\rightarrow$ S',
}

In [None]:
# Grid of calibration curves: one row per covariate set, one column per
# selection setting, with one curve per value of A in every panel.
figsize = (5, 5)
feature_sets = ['x', 'xa', 'xa_stratified']
fig, axes = plt.subplots(
    len(feature_sets), len(settings), sharex=True, sharey=True, figsize=figsize
)
for col, setting in enumerate(settings):
  eval_frame = pred_eval_dict[setting]
  for row, feature_set in enumerate(feature_sets):
    calibration_plot(
        eval_frame['y'],
        eval_frame[f'pred_probs_y1_{feature_set}'],
        group=eval_frame['a'],
        plot_overall=False,
        ax=axes[row][col],
        legend=False,
        xlabel=None,
        ylabel=None,
    )

# Column titles on the top row.
for col, setting in enumerate(settings):
  axes[0][col].set_title(setting_title_dict[setting], size=12)

# Row labels, rotated, just right of the last column.
row_labels = (
    (0.3, 'Cov.: X'),
    (0.3, 'Cov.: XA'),
    (0.1, r'Cov.: $\text{X}_{strat}$'),
)
for row_axes, (label_y, label_text) in zip(axes, row_labels):
  row_axes[-1].text(x=1.02, y=label_y, s=label_text, rotation=-90, size=12)

fig.supxlabel('Score', size=14, y=0.05)
fig.supylabel('Fraction positive', size=14, x=0.03)
plt.tight_layout()
fig.set_size_inches(*figsize)
axes[1][-1].legend(
    frameon=False,
    bbox_to_anchor=(1.25, 1),
    title='Group',
    fontsize='medium',
    title_fontsize='large',
)
for filetype in ('png', 'pdf'):
  fig.savefig(
      os.path.join(FIGURE_PATH, f'selection_calibration.{filetype}'),
      format=filetype,
      bbox_inches='tight',
  )