Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License"); you may not use
this file except in compliance with the License. You may obtain a copy of the
License at

```
https://www.apache.org/licenses/LICENSE-2.0
```

Unless required by applicable law or agreed to in writing, software distributed
under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

# Simulation study

This code constructs a simulated data-generating process with three proxies
$Y_1$, $Y_2$, and $Y_3$ of a variable $Y$. $Y$ is generated by covariates $X$
and, depending on the data-generating process, subgroup membership $A$. $Y_1$
and $Y_2$ are "unbiased proxies" of $Y$ because they depend only on $Y$; $Y_2$
is noisier than $Y_1$. $Y_3$ is a biased proxy because it depends on both $Y$
and $A$.

We fit models to predict $Y$ and each of its proxies using $X$ or $\{X, A\}$ for
a collection of data generating processes. We evaluate the sufficiency fairness
criterion ($Y_i \perp A \mid R$ for a continuous score $R$ and label $Y_i$) using calibration curves.

In [None]:
from causal_label_bias import utils
import jax
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.special import expit
import seaborn as sns
from sklearn.preprocessing import OneHotEncoder

In [None]:
class Simulator:
  """Generates simulated data following a causal generative process.

  The process draws a binary latent category `u`, covariates `x`, a binary
  group indicator `a`, a latent outcome logit `y` (with binary realisation
  `y_bin`) and three binary proxies `y1`, `y2` and `y3` of `y`. `y1` and `y2`
  are generated from `y` plus noise; `y3` additionally receives a
  group-dependent offset.
  """

  def __init__(self, **kwargs):
    """Initializes the simulation parameters.

    Arguments:
      **kwargs: Data-generating parameters keyed by the names used in
        `get_default_param_dict`. Provided values override the defaults.
    """
    self.param_dict = self.get_default_param_dict()
    self.param_dict.update(kwargs)

  def get_default_param_dict(self):
    """Returns a new dict holding the default data-generating parameters."""
    param_dict = {
        # Number of rows generated by `get_samples`.
        'num_samples': 5000,
        # Dimension of the identity matrix used for the covariance of x.
        'k_x': 1,
        # Not read by `get_samples`.
        'k_y': 1,
        # Mean of x for each latent category u.
        'mu_x_u': np.array([0.0, 0.0]),
        # Mixing weight between one-hot(u) and `pi_a` when forming P(a).
        'beta_a': 0,
        'pi_a': np.array([0.5, 0.5]),
        # Not read by `get_samples`.
        'mu_y1_a': np.array([0.0, 0.0]),
        # Not read by `get_samples`.
        'mu_y2_a': np.array([0.0, 0.0]),
        # Per-group term for the y3 logit (passed through np.tan).
        'mu_y3_a': np.array([0.0, 0.5]),
        # Per-group additive offset for the y logit.
        'mu_y_a': np.array([0, 2]),
        # Slope of y in x; group 0 uses +mu_y_x and group 1 uses -mu_y_x.
        'mu_y_x': 1,
        # Coefficients applied to y when forming each proxy logit.
        'mu_y1_y': np.array([1]),
        'mu_y2_y': np.array([0.5]),
        'mu_y3_y': np.array([1.5]),
        # Not read by `get_samples`.
        'mu_y_x_base': 2,
        # Multiplies the identity covariance of x, i.e. it acts as a variance
        # despite its name.
        'sd_x': 3,
        # Default mixture proportions over the latent category u.
        'p_u': [0.5, 0.5],
        # Location / scale of the Gaussian noise added to each proxy logit.
        'loc_y1': 0.0,
        'scale_y1': 0.5,
        'loc_y2': 0.0,
        'scale_y2': 0.5,
        'loc_y3': 0.0,
        'scale_y3': 0.0,
    }
    return param_dict

  @staticmethod
  def _one_hot(labels):
    """Returns an (n, 2) float one-hot encoding of binary integer labels.

    Two columns are always produced, even when only one class is present, so
    the matrix products with two-element parameter vectors stay well-defined.
    """
    return np.eye(2)[np.asarray(labels, dtype=int).reshape(-1)]

  def get_samples(self, p_u=None, seed=42, np_rng=None):
    """Generates samples from the simulation.

    Arguments:
      p_u: array that specifies the mixture proportions over latent categories
        u. Only `p_u[1]` is read. Defaults to `param_dict['p_u']`.
      seed: a random seed. It seeds the JAX key used to draw `x` only.
      np_rng: optional NumPy random generator (anything exposing `binomial`
        and `normal`, e.g. `np.random.default_rng(0)`) used for every draw
        other than `x`. When None, NumPy's global random state is used, so
        those draws are not controlled by `seed`.

    Returns:
      a dict containing generated data with keys 'u', 'a', 'x', 'y', 'y_bin',
      'p_y' and, for each proxy k in {1, 2, 3}, 'yk', 'yk_logits', 'p_yk' and
      'yk_one_hot'.
    """
    params = self.param_dict
    num_samples = params['num_samples']
    sampler = np.random if np_rng is None else np_rng

    rng = jax.random.PRNGKey(seed)
    _, k0, _ = jax.random.split(rng, 3)

    ## Generate u
    if p_u is None:
      p_u = params['p_u']

    u = sampler.binomial(1, p_u[1], size=num_samples)
    u_one_hot = self._one_hot(u)

    ## Generate x
    x = jax.random.multivariate_normal(
        key=k0,
        mean=(u_one_hot @ params['mu_x_u']).reshape(-1, 1),
        cov=params['sd_x'] * np.eye(params['k_x']),
    )
    x = np.array(x).astype(np.float64)

    ## Generate a
    # P(a) interpolates between copying u (beta_a = 1) and the marginal pi_a
    # (beta_a = 0).
    p_a = (
        params['beta_a'] * u_one_hot
        + (1 - params['beta_a']) * params['pi_a']
    )
    a = sampler.binomial(1, p_a[:, 1], size=num_samples)
    a_one_hot = self._one_hot(a)

    ## Generate y
    mu_y_x = np.array([[params['mu_y_x'], -1 * params['mu_y_x']]])

    # Column j of x.dot(mu_y_x) is the x-effect for group j; pick each row's
    # own group, then add the group offset.
    y = x.dot(mu_y_x)[np.arange(num_samples), np.squeeze(a)].reshape(
        -1, 1
    ) + (a_one_hot @ params['mu_y_a']).reshape(-1, 1)
    p_y = expit(y)
    y_bin = np.squeeze(sampler.binomial(n=1, p=p_y))

    ## Generate y1
    y1_logits = y.dot(params['mu_y1_y']).reshape(-1, 1) + sampler.normal(
        loc=params['loc_y1'],
        scale=params['scale_y1'],
        size=num_samples,
    ).reshape(-1, 1)

    p_y1 = expit(y1_logits)
    y1 = np.squeeze(sampler.binomial(n=1, p=p_y1))
    y1_one_hot = self._one_hot(y1)

    ## Generate y2
    y2_logits = y.dot(params['mu_y2_y']).reshape(-1, 1) + sampler.normal(
        loc=params['loc_y2'],
        scale=params['scale_y2'],
        size=num_samples,
    ).reshape(-1, 1)

    p_y2 = expit(y2_logits)
    y2 = np.squeeze(sampler.binomial(n=1, p=p_y2))
    y2_one_hot = self._one_hot(y2)

    ## Generate y3
    # NOTE(review): the group term is passed through np.tan before being
    # added to the logit -- confirm this transform is intended.
    y3_logits = (
        y.dot(params['mu_y3_y']).reshape(-1, 1)
        + (np.tan(a_one_hot @ params['mu_y3_a']).reshape(-1, 1))
        + sampler.normal(
            loc=params['loc_y3'],
            scale=params['scale_y3'],
            size=num_samples,
        ).reshape(-1, 1)
    )

    p_y3 = expit(y3_logits)
    y3 = np.squeeze(sampler.binomial(n=1, p=p_y3))
    y3_one_hot = self._one_hot(y3)

    return {
        'u': u,
        'a': a,
        'x': x,
        'y': y,
        'y_bin': y_bin,
        'p_y': p_y,
        'y1': y1,
        'y1_logits': y1_logits,
        'p_y1': p_y1,
        'y1_one_hot': y1_one_hot,
        'y2': y2,
        'y2_logits': y2_logits,
        'p_y2': p_y2,
        'y2_one_hot': y2_one_hot,
        'y3': y3,
        'y3_logits': y3_logits,
        'p_y3': p_y3,
        'y3_one_hot': y3_one_hot,
    }


def get_squeezed_df(data_dict: dict) -> pd.DataFrame:
  """Converts a dict of arrays into a DataFrame.

  Each value is squeezed with `np.squeeze`. A one-dimensional result becomes a
  single column named after its key; a result with more dimensions is split
  along its second axis into columns named `{key}_{i}`. Zero-dimensional
  (scalar) results are skipped.

  Arguments:
    data_dict: A dict mapping column names to array-likes.

  Returns:
    A pd.DataFrame with one column per (split) array.
  """
  columns = {}
  for key, value in data_dict.items():
    squeezed = np.squeeze(value)
    if squeezed.ndim == 1:
      columns[key] = squeezed
    elif squeezed.ndim > 1:
      # Index the squeezed array (not the raw value) so that list inputs and
      # inputs with extra singleton axes are split consistently.
      for i in range(squeezed.shape[1]):
        columns[f'{key}_{i}'] = np.squeeze(squeezed[:, i])
  return pd.DataFrame(columns)

In [None]:
def plot_calibration_curves(
    the_dict: dict,
    x_var: str = 'pred_probs',
    y_var: str = 'calibration_curve',
    ax=None,
    show_legend: bool = False,
    **kwargs,
):
  """Plot calibration curves.

  One line is drawn per subgroup (in sorted key order), followed by a dashed
  identity line from (0, 0) to (1, 1).

  Arguments:
    the_dict: A dictionary containing values to be plotted for each subgroup,
      where subgroup is the key. Each value must be indexable by `x_var` and
      `y_var`.
    x_var: A string specifying the x-axis variable.
    y_var: A string specifying the y-axis variable.
    ax: The matplotlib axes to draw on. A new figure is created when None.
    show_legend: Whether to draw a legend on `ax`.
    **kwargs: Extra keyword arguments forwarded to `ax.plot` for the subgroup
      curves.
  """
  if ax is None:
    _, ax = plt.subplots()

  for group in sorted(the_dict.keys()):
    curve = the_dict[group]
    ax.plot(curve[x_var], curve[y_var], label=group, **kwargs)

  diagonal = np.linspace(0, 1, 100)
  ax.plot(
      diagonal,
      diagonal,
      alpha=0.5,
      linestyle='--',
      color='k',
      label='Identity',
  )
  sns.despine(ax=ax)
  if show_legend:
    # Draw the legend on the axes that were plotted on (not on whichever axes
    # happen to be current) and use the labels attached to the lines, which
    # also lists the identity line.
    ax.legend()

In [None]:
def plot(evals: dict, title: str):
  """Plot a grid of calibration curves for different setups.

  Rows correspond to the training outcome ('y1', 'y2', 'y3', 'y') and columns
  to the evaluation outcome.

  Arguments:
    evals: A dict with keys 'y1', 'y2', 'y3' and 'y' (the training outcome).
      Each value is a sequence with one entry per evaluation outcome; each
      entry is a dict keyed by group whose values provide 'pred_probs' and
      'calibration_curve_y', which are plotted against each other.
    title: A string containing the title for the plot
  """
  plt.close()
  plt.rcParams['xtick.labelsize'] = 12
  plt.rcParams['ytick.labelsize'] = 12
  n_eval_outcomes = len(evals['y'])
  # squeeze=False keeps `ax` two-dimensional even with a single column.
  fig, ax = plt.subplots(
      4,
      n_eval_outcomes,
      figsize=(18, 12),
      sharey=True,
      sharex=True,
      squeeze=False,
  )
  plt.subplots_adjust(wspace=0.1, hspace=0.05)

  # Label the left-most panel of every row. The original label carried an
  # unbalanced '$', which left a stray dollar sign in the rendered text.
  for row_axes in ax:
    row_axes[0].set_ylabel('P(Outcome)', fontsize=24)

  column_titles = (
      'Eval Outcome: $Y_1$ ',
      'Eval Outcome: $Y_2$',
      'Eval Outcome: $Y_3$',
      'Eval Outcome: $Y$',
  )
  for column_ax, column_title in zip(ax[0], column_titles):
    column_ax.set_title(column_title, fontsize=24)

  # One row per training outcome: proxies y1, y2, y3, then the true outcome.
  for row, train_outcome in enumerate(('y1', 'y2', 'y3', 'y')):
    for col, group_curves in enumerate(evals[train_outcome]):
      plot_calibration_curves(
          the_dict=group_curves,
          x_var='pred_probs',
          y_var='calibration_curve_y',
          ax=ax[row][col],
          lw=3,
      )

  # Figure-level annotations: row labels on the right, shared x label at the
  # bottom and the overall title on top.
  fig.text(x=1, y=0.80, s='Train Outcome: $Y_1$', fontsize=24)
  fig.text(x=1, y=0.60, s='Train Outcome: $Y_2$', fontsize=24)
  fig.text(x=1, y=0.40, s='Train Outcome: $Y_3$', fontsize=24)
  fig.text(x=1, y=0.20, s='Train Outcome: $Y$', fontsize=24)
  fig.text(x=0.4, y=0.05, s='Predicted probability', fontsize=24)
  fig.text(x=0.35, y=1, s='Causal graph has ' + title, fontsize=34)

  sns.despine()
  # NOTE(review): this opens a new, empty figure and makes it current --
  # confirm whether plt.show() was intended instead.
  plt.figure()

In [None]:
def train_setups(
    dgp_type: str = 'no_a',
    eval_outcomes=None,
    features=None,
    sim_samples_dict=None,
    sim_samples_dict_eval=None,
    model_type: str = 'gradient_boosting',
    stratified: bool = False,
):
  """Train and evaluate models with different outcomes.

  For every training outcome (y1, y2, y3 and y_bin) and every evaluation
  outcome, `utils.fit_models_df` is called with the training frame as source
  and the evaluation frame as target.

  Arguments:
    dgp_type: A string specifying edge connections in the causal DAG. Used as
      the key into `sim_samples_dict` and `sim_samples_dict_eval`.
    eval_outcomes: An iterable of strings indicating the different evaluation
      outcomes. Defaults to ['y1', 'y2', 'y3', 'y_bin'].
    features: A string or iterable of strings indicating the name(s) of the
      columns used as features. Defaults to ['x'].
    sim_samples_dict: A dictionary containing a pd.DataFrame for model training.
      Defaults to an empty dict.
    sim_samples_dict_eval: A dictionary containing a pd.DataFrame for model
      evaluation. Defaults to an empty dict.
    model_type: A string specifying the model type.
    stratified: A boolean flag for training group specific models.

  Returns:
    A dict keyed by training outcome ('y1', 'y2', 'y3', 'y'; the last one is
    trained on the 'y_bin' column). Each value is a list with one entry per
    evaluation outcome, holding the result of `utils.fit_models_df`.

  Raises:
    KeyError: If `dgp_type` is missing from either samples dict.
  """
  # None sentinels avoid sharing mutable default objects across calls.
  if eval_outcomes is None:
    eval_outcomes = ['y1', 'y2', 'y3', 'y_bin']
  if features is None:
    features = ['x']
  if sim_samples_dict is None:
    sim_samples_dict = {}
  if sim_samples_dict_eval is None:
    sim_samples_dict_eval = {}

  source_df = sim_samples_dict[dgp_type]
  target_df = sim_samples_dict_eval[dgp_type]

  # Maps the key used in the result to the column the model is trained on.
  train_outcome_columns = {
      'y1': 'y1',
      'y2': 'y2',
      'y3': 'y3',
      'y': 'y_bin',
  }

  evals = dict()
  for result_key, outcome_key in train_outcome_columns.items():
    evals[result_key] = [
        utils.fit_models_df(
            source_df=source_df,
            target_df=target_df,
            outcome_key=outcome_key,
            outcome_key_target=eval_outcome,
            features_keys=features,
            model_type=model_type,
            stratified=stratified,
        )
        for eval_outcome in eval_outcomes
    ]
  return evals

In [None]:
# One simulator per data-generating process. The key names which of the
# A -> Y and A -> Y_3 effects are switched on via `mu_y_a` / `mu_y3_a`.
sim_dict = {
    'y3_a_no_y_a': Simulator(
        mu_y3_a=np.array([0, -2]), mu_y_a=np.array([0, 0]), scale_y2=15
    ),
    'y_a_no_y3_a': Simulator(
        mu_y3_a=np.array([0, 0]), mu_y_a=np.array([-1, 1]), scale_y2=15
    ),
    'no_a': Simulator(
        mu_y3_a=np.array([0, 0]), mu_y_a=np.array([0, 0]), scale_y2=15
    ),
    'y_a_y3_a': Simulator(
        mu_y3_a=np.array([0, -2]), mu_y_a=np.array([-1, 1]), scale_y2=4
    ),
}
# Training samples: one draw per data-generating process.
sim_samples_dict = {
    key: get_squeezed_df(value.get_samples(seed=i))
    for i, (key, value) in enumerate(sim_dict.items())
}
# Evaluation samples use seeds that do not overlap with the training seeds.
# With the previous `2 * i` scheme the first process used seed 0 for both
# sets, so its training and evaluation covariates were identical draws.
sim_samples_dict_eval = {
    key: get_squeezed_df(value.get_samples(seed=len(sim_dict) + i))
    for i, (key, value) in enumerate(sim_dict.items())
}
DISPLAY_PLOTS = True
if DISPLAY_PLOTS:
  plt.close()
  # Per-group density of each simulated quantity, one figure per quantity.
  for key, value in sim_samples_dict.items():
    for column in ('x', 'y', 'p_y', 'y1_logits', 'y2_logits', 'y3_logits'):
      plt.figure()
      sns.kdeplot(value, x=column, hue='a')
      plt.title(f'{key}: {column}')

# Scenario 1: Models trained with [$X$]

In [None]:
# Data-generating processes, in the order the cells below evaluate them.
dgp_types = [
    'no_a',
    'y3_a_no_y_a',
    'y_a_no_y3_a',
    'y_a_y3_a',
]
# Columns every fitted model is evaluated against.
eval_outcomes = [
    'y1',
    'y2',
    'y3',
    'y_bin',
]
# Scenario 1 uses the covariate only.
features = [
    'x',
]

### This refers to the data generating process without $A \rightarrow Y$ or $A \rightarrow Y_3$.

In [None]:
# Fit and plot all train/eval outcome pairs for the first process.
print("Evaluation results for {}".format(dgp_types[0]))
evals = train_setups(
    dgp_type=dgp_types[0],
    eval_outcomes=eval_outcomes,
    features=features,
    sim_samples_dict=sim_samples_dict,
    sim_samples_dict_eval=sim_samples_dict_eval,
)
plot(evals=evals, title=dgp_types[0])

### This refers to the data generating process without $A \rightarrow Y$, but with $A \rightarrow Y_3$.

In [None]:
# Fit and plot all train/eval outcome pairs for the second process.
print("Evaluation results for {}".format(dgp_types[1]))
evals = train_setups(
    dgp_type=dgp_types[1],
    eval_outcomes=eval_outcomes,
    features=features,
    sim_samples_dict=sim_samples_dict,
    sim_samples_dict_eval=sim_samples_dict_eval,
)
plot(evals=evals, title=dgp_types[1])

### This refers to the data generating process with $A \rightarrow Y$, but without $A \rightarrow Y_3$.

In [None]:
# Fit and plot all train/eval outcome pairs for the third process.
print("Evaluation results for {}".format(dgp_types[2]))
evals = train_setups(
    dgp_type=dgp_types[2],
    eval_outcomes=eval_outcomes,
    features=features,
    sim_samples_dict=sim_samples_dict,
    sim_samples_dict_eval=sim_samples_dict_eval,
)
plot(evals=evals, title=dgp_types[2])

### This refers to the data generating process with $A \rightarrow Y$ and $A \rightarrow Y_3$.

In [None]:
# Fit and plot all train/eval outcome pairs for the fourth process.
print("Evaluation results for {}".format(dgp_types[3]))
evals = train_setups(
    dgp_type=dgp_types[3],
    eval_outcomes=eval_outcomes,
    features=features,
    sim_samples_dict=sim_samples_dict,
    sim_samples_dict_eval=sim_samples_dict_eval,
)
plot(evals=evals, title=dgp_types[3])

# Scenario 2: Models trained with [$X$,$A$]

In [None]:
# Scenario 2 adds the group indicator to the covariate.
features = [
    'x',
    'a',
]

In [None]:
# Fit and plot all train/eval outcome pairs for the first process.
print("Evaluation results for {}".format(dgp_types[0]))
evals = train_setups(
    dgp_type=dgp_types[0],
    eval_outcomes=eval_outcomes,
    features=features,
    sim_samples_dict=sim_samples_dict,
    sim_samples_dict_eval=sim_samples_dict_eval,
)
plot(evals=evals, title=dgp_types[0])

In [None]:
# Fit and plot all train/eval outcome pairs for the second process.
print("Evaluation results for {}".format(dgp_types[1]))
evals = train_setups(
    dgp_type=dgp_types[1],
    eval_outcomes=eval_outcomes,
    features=features,
    sim_samples_dict=sim_samples_dict,
    sim_samples_dict_eval=sim_samples_dict_eval,
)
plot(evals=evals, title=dgp_types[1])

In [None]:
# Fit and plot all train/eval outcome pairs for the third process.
print("Evaluation results for {}".format(dgp_types[2]))
evals = train_setups(
    dgp_type=dgp_types[2],
    eval_outcomes=eval_outcomes,
    features=features,
    sim_samples_dict=sim_samples_dict,
    sim_samples_dict_eval=sim_samples_dict_eval,
)
plot(evals=evals, title=dgp_types[2])

In [None]:
# Fit and plot all train/eval outcome pairs for the fourth process.
print("Evaluation results for {}".format(dgp_types[3]))
evals = train_setups(
    dgp_type=dgp_types[3],
    eval_outcomes=eval_outcomes,
    features=features,
    sim_samples_dict=sim_samples_dict,
    sim_samples_dict_eval=sim_samples_dict_eval,
)
plot(evals=evals, title=dgp_types[3])