Copyright 2025 The Google Research Authors

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
```
 http://www.apache.org/licenses/LICENSE-2.0
```
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Simulation study -- model fitting

In [None]:
import os
import random
import numpy as np
import pandas as pd
import scipy

from causal_evaluation import utils
from causal_evaluation.experiments import simulator

In [None]:
# Notebook configuration
## All paths are resolved relative to the directory that holds the ipynb file
DATA_PATH = './../../data/simulation' # @param
# FIT_MODELS gates the model-fitting cells; WRITE_PREDS (together with
# FIT_MODELS) gates writing the evaluation frames to DATA_PATH.
FIT_MODELS = True  # @param
WRITE_PREDS = True  # @param

# Number of samples requested from each simulator for the training and the
# evaluation draws; both values are also embedded in the output filenames.
N_SAMPLES_TRAIN = 50_000
N_SAMPLES_EVAL = 20_000

# `model_type` is passed to the outcome-model fits, `group_model_type` to the
# group-membership fits.
model_type = 'gradient_boosting' # @param
group_model_type = 'gradient_boosting' # @param

In [None]:
# Create the output directory (and any missing parents) up front;
# exist_ok=True keeps re-running this cell from failing if it already exists.
os.makedirs(DATA_PATH, exist_ok=True)

In [None]:
# Seed NumPy's global RNG and Python's `random` module with fixed values.
# NOTE(review): presumably this is for reproducibility of the model fitting
# below; the simulators are seeded separately via get_samples(seed=...).
np.random.seed(173)
random.seed(100)

In [None]:
def get_sim_dict(**kwargs):
  """Builds one simulator object per simulation setting.

  Args:
    **kwargs: Keyword arguments forwarded unchanged to every simulator
      constructor (the callers in this notebook pass `num_samples`).

  Returns:
    A dict mapping setting name to a `simulator.Simulator` or
    `simulator.SimulatorAnticausal` instance. The insertion order is fixed;
    callers in this notebook derive per-setting seeds from the position.
  """
  base_sim = simulator.Simulator
  anticausal_sim = simulator.SimulatorAnticausal

  sim_dict = {}

  # Settings backed by simulator.Simulator
  sim_dict['covariate_shift'] = base_sim(**kwargs)
  sim_dict['no_shift'] = base_sim(beta_a=0, **kwargs)
  sim_dict['outcome_shift'] = base_sim(
      a_to_y=True, beta_a=0, mu_y_a=np.array([0.1, 0]), **kwargs
  )
  sim_dict['complex_causal_shift'] = base_sim(
      a_to_y=True, mu_y_a=np.array([0.1, 0]), **kwargs
  )
  sim_dict['low_overlap_causal'] = base_sim(
      a_to_y=True,
      mu_y_a=np.array([0.1, 0]),
      mu_x_u=np.array([-2, 2]),
      **kwargs,
  )

  # Settings backed by simulator.SimulatorAnticausal
  sim_dict['anticausal_label_shift'] = anticausal_sim(**kwargs)
  sim_dict['anticausal_presentation_shift'] = anticausal_sim(
      mu_x_ay=np.array([[1, 0], [-1, 1]]),
      mu_y_u=np.array([[0.5, 0.5], [0.5, 0.5]]),
      **kwargs,
  )
  sim_dict['complex_anticausal_shift'] = anticausal_sim(
      mu_x_ay=np.array([[1, 0], [-1, 1]]),
      **kwargs,
  )

  return sim_dict


def _draw_samples(num_samples, seed_stride):
  """Samples every setting once and returns {setting name: squeezed frame}.

  The i-th setting (in `get_sim_dict` order) is sampled with
  seed `seed_stride * i`.
  """
  simulators = get_sim_dict(num_samples=num_samples)
  return {
      name: utils.get_squeezed_df(sim.get_samples(seed=seed_stride * position))
      for position, (name, sim) in enumerate(simulators.items())
  }


# Training draws use seeds 0, 1, 2, ...; evaluation draws use 0, 2, 4, ...
# NOTE(review): the first setting therefore gets seed 0 for both the training
# and the evaluation draw -- confirm that the simulator does not produce
# overlapping train/eval samples in that case.
sim_samples_dict = _draw_samples(N_SAMPLES_TRAIN, seed_stride=1)

sim_samples_dict_eval = _draw_samples(N_SAMPLES_EVAL, seed_stride=2)

In [None]:
# Stack the per-setting training frames; the dict keys become the outer index
# level of the concatenated frame.
_stacked = pd.concat(sim_samples_dict)
# Discard the innermost index level and keep the setting key as the index.
_by_setting = _stacked.reset_index(level=-1, drop=True)
# Name the remaining index 'setting' and turn it into a regular column.
sim_samples_df = _by_setting.rename_axis('setting').reset_index()
sim_samples_df

## Fit models for all settings

In [None]:
# Settings for which models are fit and predictions are written.
settings = [
    # Settings generated with simulator.Simulator
    'covariate_shift',
    'outcome_shift',
    'complex_causal_shift',
    'low_overlap_causal',
] + [
    # Settings generated with simulator.SimulatorAnticausal
    'anticausal_label_shift',
    'anticausal_presentation_shift',
    'complex_anticausal_shift',
]

In [None]:
# Fit the outcome models: on X alone, on X stratified by A, and on (X, A).
# Each model is fit on the training draw and its predictions are stored as new
# columns on the evaluation frame of the same setting.
if FIT_MODELS:

  for setting in settings:
    print(f'Setting: {setting}', flush=True)

    # The evaluation frame is modified in place, so the new columns are visible
    # through sim_samples_dict_eval[setting] afterwards.
    train_df = sim_samples_dict[setting]
    eval_df = sim_samples_dict_eval[setting]

    # Single-feature design matrices of shape (n, 1) and the training labels.
    x_train = train_df['x'].values.reshape(-1, 1)
    y_train = train_df['y'].values
    x_eval = eval_df['x'].values.reshape(-1, 1)

    # --- Model for E[Y | X] ---
    model = utils.fit_model(
        x_train,
        y_train,
        model_type=model_type,
        model_cross_val=True,
    )
    eval_df['pred_probs_y_x'] = utils.array_to_series(
        model.predict_proba(x_eval)
    )
    # Keep only the last entry of each probability vector
    # (presumably the probability of y = 1).
    eval_df['pred_probs_y1_x'] = eval_df['pred_probs_y_x'].map(
        lambda probs: probs[-1]
    )

    # --- Model on X, stratified by group A ---
    model_dict = utils.fit_model_stratified(
        x_train,
        y_train,
        group=train_df['a'].values,
        model_type=model_type,
        model_cross_val=True,
    )
    eval_df['pred_probs_y_xa_stratified'] = utils.array_to_series(
        utils.predict_proba_stratified(
            x_eval,
            model_dict,
            group=eval_df['a'].values,
        )
    )
    eval_df['pred_probs_y1_xa_stratified'] = eval_df[
        'pred_probs_y_xa_stratified'
    ].map(lambda probs: probs[-1])

    # --- Model with both X and A as features ---
    xa_train = np.concatenate(
        [x_train, train_df['a'].values.reshape(-1, 1)], axis=1
    )
    model_xa = utils.fit_model(
        xa_train,
        y_train,
        model_type=model_type,
        model_cross_val=True,
    )
    xa_eval = np.concatenate(
        [x_eval, eval_df['a'].values.reshape(-1, 1)], axis=1
    )
    eval_df['pred_probs_y_xa'] = utils.array_to_series(
        model_xa.predict_proba(xa_eval)
    )
    eval_df['pred_probs_y1_xa'] = eval_df['pred_probs_y_xa'].map(
        lambda probs: probs[-1]
    )

In [None]:
# Fit models of group membership A: from X, from Y, and from the logit of each
# outcome-model score computed in the previous cell.

if FIT_MODELS:

  for setting in settings:
    print(f'Setting: {setting}', flush=True)

    # The evaluation frame is modified in place.
    train_df = sim_samples_dict[setting]
    eval_df = sim_samples_dict_eval[setting]
    a_train = train_df['a'].values

    # P(A | X): fit on the training draw, apply to the evaluation draw.
    model_group_x = utils.fit_model(
        train_df['x'].values.reshape(-1, 1),
        a_train,
        model_type=group_model_type,
        model_cross_val=True,
    )
    eval_df['pred_probs_group_x'] = utils.array_to_series(
        model_group_x.predict_proba(eval_df['x'].values.reshape(-1, 1))
    )

    # P(A | Y): fit on the training draw, apply to the evaluation draw.
    model_group_y = utils.fit_model(
        train_df['y'].values.reshape(-1, 1),
        a_train,
        model_type=group_model_type,
        model_cross_val=True,
    )
    eval_df['pred_probs_group_y'] = utils.array_to_series(
        model_group_y.predict_proba(eval_df['y'].values.reshape(-1, 1))
    )

    # P(A | R): estimated out-of-sample within the evaluation data using
    # nested cross-validation, once per outcome-model score R:
    #   R_x             -- model fit on X only
    #   R_xa_stratified -- model fit on X, stratified by A
    #   R_xa            -- model fit with A included in the feature set
    # NOTE(review): logit maps probabilities of exactly 0 or 1 to -inf / +inf;
    # confirm the outcome models never emit such values.
    for score_col, out_col in (
        ('pred_probs_y1_x', 'pred_probs_group_r_x'),
        ('pred_probs_y1_xa_stratified', 'pred_probs_group_r_xa_stratified'),
        ('pred_probs_y1_xa', 'pred_probs_group_r_xa'),
    ):
      logit_scores = scipy.special.logit(
          eval_df[score_col].values.reshape(-1, 1)
      )
      eval_df[out_col] = utils.array_to_series(
          utils.fit_cross_val_predict(
              logit_scores,
              eval_df['a'].values,
              model_type=group_model_type,
              model_cross_val=True,
          )
      )

## Write the predictions

In [None]:
# Write one parquet file per setting; the filename records the sample sizes and
# both model types. Nothing is written unless both WRITE_PREDS and FIT_MODELS
# are set.
for setting in settings:
  filename = f'sim_samples_eval_{setting}_{N_SAMPLES_TRAIN}_{N_SAMPLES_EVAL}_{model_type}_{group_model_type}.parquet'
  if not WRITE_PREDS or not FIT_MODELS:
    continue
  output_path = os.path.join(DATA_PATH, filename)
  sim_samples_dict_eval[setting].to_parquet(output_path, index=False)