Copyright 2025 The Google Research Authors

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at
```
 http://www.apache.org/licenses/LICENSE-2.0
```
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

# Visualize results and export figures

In [None]:
import itertools
import os
from IPython.display import display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import warnings
import matplotlib.ticker as ticker
import textwrap

from causal_evaluation import utils

In [None]:
# Flags
# Colab form parameters. DATA_PATH plus the sample sizes and model types are
# interpolated into the parquet filenames read in the cells below, so they must
# match the files on disk.
DATA_PATH = './../../data/simulation' # @param
N_SAMPLES_TRAIN = 50000 # @param
N_SAMPLES_EVAL = 20000 # @param
model_type = 'gradient_boosting' # @param
group_model_type = 'gradient_boosting' # @param
# Directory where every exported PNG/PDF figure is written.
FIGURE_PATH = './../../figures'

In [None]:
# Ensure the figure output directory exists (no error if it already does).
os.makedirs(name=FIGURE_PATH, exist_ok=True)

In [None]:
# Simulation settings to load; this order is also the column order of the
# calibration grid below.
settings = [
    # Causal settings.
    'covariate_shift', 'outcome_shift', 'complex_causal_shift',
    'low_overlap_causal',
    # Anticausal settings.
    'anticausal_label_shift', 'anticausal_presentation_shift',
    'complex_anticausal_shift',
]

In [None]:
# Read the evaluation-set predictions for every setting, keyed by setting name.
pred_eval_dict = {
    setting: pd.read_parquet(
        os.path.join(
            DATA_PATH,
            f'sim_samples_eval_{setting}_{N_SAMPLES_TRAIN}_{N_SAMPLES_EVAL}'
            f'_{model_type}_{group_model_type}.parquet',
        )
    )
    for setting in settings
}

In [None]:
# Read the eval results for every setting into one long-format frame.
# ignore_index=True gives the combined frame a fresh, unique RangeIndex. Without
# it, index labels repeat whenever the per-setting frames share labels (e.g.
# each carries its own default RangeIndex), which makes label-based access and
# index alignment on result_df ambiguous.
result_df = pd.concat(
    [
        pd.read_parquet(
            os.path.join(
                DATA_PATH,
                f'metrics_{setting}_{N_SAMPLES_TRAIN}_{N_SAMPLES_EVAL}_{model_type}_{group_model_type}.parquet',
            )
        )
        for setting in settings
    ],
    ignore_index=True,
)

## Plotting functions

In [None]:
def calibration_plot(
    y_true,
    y_prob,
    group=None,
    ax=None,
    plot_overall=True,
    palette=None,
    legend=False,
    xlabel='Score',
    ylabel='Fraction positive',
    **kwargs,
):
  """Plots calibration curves with shaded intervals, overall and/or per group.

  Args:
    y_true: Outcome labels.
    y_prob: Predicted scores aligned with `y_true`.
    group: Optional group labels aligned with `y_true`/`y_prob`. When given,
      one curve is drawn per group value.
    ax: Axes to draw on. A new figure and axes are created when None.
    plot_overall: Whether to draw the curve pooled over all samples.
    palette: Sequence of colors; defaults to seaborn's 'Set2' palette. Colors
      are reused cyclically when there are more curves than colors.
    legend: Whether to draw a legend on `ax`.
    xlabel: X-axis label, or None to leave it unset.
    ylabel: Y-axis label, or None to leave it unset.
    **kwargs: Forwarded to `utils.calibration_curve_ci` for every curve, both
      overall and per group. `n_bins` defaults to 10 when not given.
  """
  if ax is None:
    plt.figure()
    ax = plt.gca()

  if kwargs.get('n_bins') is None:
    kwargs['n_bins'] = 10

  if palette is None:
    palette = sns.color_palette('Set2')

  def _draw(curve, color, **line_kwargs):
    # Draws one `utils.calibration_curve_ci` result: element [1] is plotted on
    # x, [0] on y, and [2]/[3] bound the shaded band.
    ax.plot(curve[1], curve[0], color=color, **line_kwargs)
    ax.fill_between(curve[1], curve[2], curve[3], alpha=0.3, color=color)

  palette_count = 0
  if plot_overall:
    _draw(
        utils.calibration_curve_ci(y_true, y_prob, **kwargs),
        palette[0],
        label='Overall',
    )
    palette_count = 1

  if group is not None:
    df = pd.DataFrame({'y_true': y_true, 'y_prob': y_prob, 'group': group})
    # observed=False spells out pandas' historical default (it only matters
    # for categorical groups), so empty categories keep their color slot.
    for i, (the_group, group_df) in enumerate(
        df.groupby('group', observed=False)
    ):
      if group_df.shape[0] == 0:
        continue
      # Group curves use the same binning/interval options as the overall
      # curve. Colors cycle so more groups than palette entries cannot raise
      # an IndexError.
      _draw(
          utils.calibration_curve_ci(
              group_df['y_true'], group_df['y_prob'], **kwargs
          ),
          palette[(i + palette_count) % len(palette)],
          label=f'{the_group}',
          linewidth=2,
      )

  # Diagonal reference line for perfect calibration.
  ax.plot(
      np.linspace(0, 1, 100),
      np.linspace(0, 1, 100),
      linestyle='-.',
      color='gray',
  )

  ax.set_xlim(0, 1)
  ax.set_ylim(0, 1)

  if xlabel is not None:
    ax.set_xlabel(xlabel, size=14)
  if ylabel is not None:
    ax.set_ylabel(ylabel, size=14)

  if legend:
    ax.legend()

  # Only despine the axes we drew on, not every axes of the current figure.
  sns.despine(ax=ax)


def pointplot_with_errorbars(
    x, y, hue, ci_low, ci_high, vertical=True, pairwise=False, **kwargs
):
  """Draws a dodged seaborn pointplot and adds precomputed interval bars.

  Written as a `FacetGrid.map` target: it draws on the current axes
  (`plt.gca()`) rather than taking an `ax` argument.

  Args:
    x: Values passed as the pointplot's `x`.
    y: Values passed as the pointplot's `y`.
    hue: Grouping values; one 'Set2' color per level.
    ci_low: Lower interval bounds (read via `.values`, so presumably a pandas
      Series), one per plotted point.
    ci_high: Upper interval bounds, same layout as `ci_low`.
    vertical: If True, bars are drawn vertically at each point's x position;
      otherwise horizontally at each point's y position.
    pairwise: If True, adds a dashed gray reference line at 0 on the value axis.
    **kwargs: Only `point_scale` (default 0.75, passed as pointplot's `scale`)
      and `hue_order` are read; other keys are ignored.

  NOTE(review): the i-th point collected from `points.collections` is paired
  with the i-th entry of `ci_low`/`ci_high`. Callers must therefore pass rows
  in the same hue-major order in which seaborn draws the points; the callers
  in this notebook sort by the hue column first. `points` is the Axes returned
  by `sns.pointplot`, so `collections` holds every collection on those axes.
  seaborn >= 0.13 also changed how pointplot draws markers and deprecated
  `scale` — confirm against the installed seaborn version.
  """
  ax = plt.gca()

  # Set the palette
  palette = sns.color_palette('Set2')

  point_scale = kwargs.get('point_scale', 0.75)

  # Create the pointplot
  points = sns.pointplot(
      x=x,
      y=y,
      ax=ax,
      hue=hue,
      dodge=0.4,
      linestyles='none',
      palette=palette,
      scale=point_scale,
      hue_order=kwargs.get('hue_order'),
  )

  # Get the positions of the points
  point_pos = []
  point_colors = []
  for collection in points.collections:
    # `1 - vertical` selects offset column 0 (x) when vertical and column 1
    # (y) otherwise, i.e. the coordinate the bar is anchored at.
    point_pos.extend(
        collection.get_offsets()[:, (1 - vertical)]
    )  # Get coordinates from offsets
    point_colors.extend(
        [collection.get_facecolor()] * len(collection.get_offsets())
    )  # Repeat color for each point
  # Add error bars with matching colors
  for i, pos in enumerate(point_pos):
    if vertical:
      ax.vlines(
          x=pos,
          ymin=ci_low.values[i],
          ymax=ci_high.values[i],
          color=point_colors[i],
      )
    else:
      ax.hlines(
          y=pos,
          xmin=ci_low.values[i],
          xmax=ci_high.values[i],
          color=point_colors[i],
      )

  # Zero reference line for difference-style plots.
  if pairwise:
    if vertical:
      ax.axhline(y=0, color='gray', linestyle='--')
    else:
      ax.axvline(x=0, color='gray', linestyle='--')

In [None]:
def format_zero(value, tick_number, num_digits=2):
  """Formats a tick value with fixed decimals, showing zero as a bare "0".

  Usable as a `matplotlib.ticker.FuncFormatter` callback.

  Args:
    value: Tick value.
    tick_number: Tick position supplied by FuncFormatter; unused.
    num_digits: Number of decimal places for non-zero values.

  Returns:
    "0" when `value` rounds to zero at `num_digits` decimals. This also covers
    negative zero and float residue such as -1e-17 from tick locators, which
    would otherwise render as "-0.00". Any other value is returned formatted
    with `num_digits` decimals.
  """
  del tick_number  # Part of the FuncFormatter callback signature only.
  formatted = f'{value:.{num_digits}f}'
  # Test the rounded text rather than `value == 0.0` so near-zero ticks and
  # -0.0 still get the clean label.
  if float(formatted) == 0.0:
    return '0'
  return formatted

def wrap_titles(g, width=20):
  """Re-flows each non-empty facet title of a seaborn FacetGrid.

  Args:
    g: The FacetGrid (anything exposing `axes.flat`).
    width: Maximum characters per line for `textwrap.fill`; words are never
      split.

  Returns:
    The same grid object, after `plt.tight_layout()` has been applied.
  """
  for facet_ax in g.axes.flat:
    current = facet_ax.get_title()
    if not current:
      continue
    facet_ax.set_title(
        textwrap.fill(current, width=width, break_long_words=False)
    )
  plt.tight_layout()
  return g

# Simulation study results

In [None]:
# Map population_on_group_comparison_weights to a weight_type. Every weights
# name containing 'population_on_group_comparison' is tagged with that type;
# the normalized variant is then re-tagged with its more specific label.
# Boolean arrays (.values) are used so the masks index positionally.
the_filter = result_df.weights.map(
    lambda name: 'population_on_group_comparison' in name
).values
the_filter_normalized = result_df.weights.map(
    lambda name: 'population_on_group_comparison_weights_normalized' in name
).values
result_df.loc[the_filter, 'weight_type'] = 'population_on_group_comparison'
result_df.loc[the_filter_normalized, 'weight_type'] = (
    'population_on_group_comparison_normalized'
)

In [None]:
# Create a dataframe for plotting
plot_df = result_df.copy()

# Collapse weighting-scheme names into the variable the weights control for
# (None, X, Y or R). Names absent from this mapping are left unchanged by
# `replace`, become NaN when cast to the categorical below, and those rows are
# dropped at the end of this cell.
weight_mapping_dict = {
    'weights_none': 'None',
    'weights_none_overall': 'None',
    'weights_population_x': 'X',
    'weights_population_y': 'Y',
    'weights_population_r_x': 'R',
    'weights_population_r_xa': 'R',
    'weights_population_r_xa_stratified': 'R',
    'weights_stable_x': 'X',
    'weights_stable_y': 'Y',
    'weights_stable_r_x': 'R',
    'weights_stable_r_xa': 'R',
    'weights_stable_r_xa_stratified': 'R',
    'weights_cross_group_x': 'X',
    'weights_cross_group_y': 'Y',
    'weights_cross_group_r_x': 'R',
    'weights_cross_group_r_xa': 'R',
    'weights_cross_group_r_xa_stratified': 'R',
    'population_on_group_comparison_weights_none': 'None',
    'population_on_group_comparison_weights_x': 'X',
    'population_on_group_comparison_weights_y': 'Y',
    'population_on_group_comparison_weights_r_x': 'R',
    'population_on_group_comparison_weights_r_xa': 'R',
    'population_on_group_comparison_weights_r_xa_stratified': 'R',
    'population_on_group_comparison_weights_normalized_x': 'X',
    'population_on_group_comparison_weights_normalized_y': 'Y',
    'population_on_group_comparison_weights_normalized_r_x': 'R',
    'population_on_group_comparison_weights_normalized_r_xa': 'R',
    'population_on_group_comparison_weights_normalized_r_xa_stratified': 'R',
}
# Read from plot_df's own column so the assignment never depends on index
# alignment between two different frames.
plot_df['weights'] = plot_df['weights'].replace(weight_mapping_dict)

plot_df['weights'] = pd.Categorical(
    plot_df['weights'], ['None', 'X', 'Y', 'R']
)

# Display names for the simulation settings. Dict order fixes the categorical
# order of `setting`, and presumably its plotting order.
setting_title_dict = {
    'covariate_shift': 'Covariate Shift',
    'outcome_shift': 'Outcome Shift',
    'complex_causal_shift': 'Complex Causal',
    'low_overlap_causal': 'Separable',
    'anticausal_label_shift': 'Label Shift',
    'anticausal_presentation_shift': 'Presentation Shift',
    'complex_anticausal_shift': 'Complex Anticausal',
}
plot_df['setting'] = plot_df['setting'].replace(setting_title_dict)

plot_df['setting'] = pd.Categorical(
    plot_df['setting'],
    list(setting_title_dict.values())
)

# Short labels for the covariate sets.
plot_df['features'] = plot_df['features'].replace(
    {'features_x': 'X', 'features_xa': 'XA', 'features_xa_stratified': 'X_strat'}
)

# Display names for the metrics; dict order fixes the categorical order.
metric_dict = {
    'log_loss': 'log loss',
    'roc_auc': 'AUC-ROC',
    'label_rate': 'cls rate',
    'recall_0.5': 'sensitivity',
    'specificity_0.5': 'specificity',
    'precision_0.5': 'precision',
    'net_benefit_0.5_0.5': 'net benefit',
}
plot_df['metric'] = plot_df['metric'].replace(metric_dict)

plot_df['metric'] = pd.Categorical(
    plot_df['metric'],
    list(metric_dict.values())
)

# Pooled rows are labeled "Population" in the figures.
plot_df['group'] = plot_df['group'].replace(
    {"overall": "Population"}
)
# Drop rows whose weighting scheme was not mapped above.
plot_df = plot_df.query('~weights.isna()')

In [None]:
# `features` values that hold subgroup-aware vs. unaware comparisons rather
# than a single covariate set; the absolute-performance plots exclude them.
relative_comparison_features = [
    "relative_comparison",
    "relative_comparison_stratified",
]

In [None]:
# Main-text figure: for covariate sets X and XA, each subgroup's log loss minus
# its weighted population estimate, by setting (facet rows: covariate set;
# facet columns: weighting variable).
metric='log loss'
# NOTE(review): `weight_types` is not used by the query below, which filters on
# weight_type == "population_on_group_comparison" instead.
weight_types = ['None']
settings_to_plot = setting_title_dict.values()

feature_set = ['X', 'XA']

# Rows are sorted group-major because pointplot_with_errorbars pairs the i-th
# drawn point with the i-th ci_low/ci_high value. Unused categories are
# dropped, presumably so absent settings/weights get no empty slots or facets.
g = sns.FacetGrid(
  (
      plot_df.query(
          'weight_type == "population_on_group_comparison" &'
          ' metric == @metric & setting in @settings_to_plot &'
          ' features in @feature_set'
      )
      .sort_values(['group', 'setting'])
      .assign(setting=lambda x: x.setting.values.remove_unused_categories(),
              weights=lambda x: x.weights.values.remove_unused_categories())
  ),
  col='weights',
  row='features',
  margin_titles=True,
  sharex=True
)

# Horizontal layout: value on x, one row per setting, one color per group;
# pairwise=True adds a dashed line at 0.
g.map(
  pointplot_with_errorbars,
  'performance',
  'setting',
  'group',
  'ci_low',
  'ci_high',
  vertical=False,
  pairwise=True,
  point_scale=0.5
)

g.set_axis_labels('', '')
g.fig.supylabel(f'Setting', fontsize=16, x=-0.06, y=0.57)
g.fig.supxlabel(f'Subgroup {metric} - weighted population estimate', fontsize=16, x=0.55, y=-0.075)

g.set_titles(
  row_template='Covariates: {row_name}', col_template='Weights: {col_name}', size=12
)

for ax in g.axes.flat:
  ax.tick_params(axis='x', labelsize=12)  # x-axis tick label size 12

# Legend: collect group handles, then move it outside the grid and restyle.
g.add_legend(fontsize=12)
g.legend.set_title('Group')
sns.move_legend(g, 'center right', bbox_to_anchor=(1.03, 0.57))
plt.setp(g._legend.get_title(), fontsize=14)
plt.setp(g._legend.get_texts(), fontsize=12)

# Final size, then export PNG and PDF copies to FIGURE_PATH.
g.figure.set_size_inches(8.5, 3)
for filetype in ['png', 'pdf']:
  g.figure.savefig(
      os.path.join(FIGURE_PATH, f'adjustment_weights_xyr_{metric}_main.{filetype}'), format=filetype, bbox_inches='tight'
  )

In [None]:
# Same layout as the main-text figure, repeated for every metric and with the
# X_strat covariate set added. All warnings raised while plotting are
# suppressed for the duration of this cell.
with warnings.catch_warnings():
  warnings.simplefilter("ignore")
  for metric in metric_dict.values():

    # NOTE(review): `weight_types` is not used by the query below.
    weight_types = ['None']
    settings_to_plot = setting_title_dict.values()

    feature_set = ['X', 'XA', 'X_strat']

    # Group-major sort keeps rows aligned with the point order that
    # pointplot_with_errorbars relies on.
    g = sns.FacetGrid(
        (
            plot_df.query(
                'weight_type == "population_on_group_comparison" &'
                ' metric == @metric & setting in @settings_to_plot &'
                ' features in @feature_set'
            )
            .sort_values(['group', 'setting'])
            .assign(setting=lambda x: x.setting.values.remove_unused_categories(),
                    weights=lambda x: x.weights.values.remove_unused_categories())
        ),
        col='weights',
        row='features',
        margin_titles=True,
        sharex=True
    )

    g.map(
        pointplot_with_errorbars,
        'performance',
        'setting',
        'group',
        'ci_low',
        'ci_high',
        vertical=False,
        pairwise=True,
        point_scale=0.5
    )

    g.set_axis_labels('', '')
    g.fig.supylabel(f'Setting', fontsize=16, x=-0.06, y=0.57)
    g.fig.supxlabel(f'Subgroup {metric} - weighted population estimate', fontsize=16, x=0.55, y=-0.075)

    g.set_titles(
        row_template='Cov: {row_name}', col_template='Weights: {col_name}', size=12
    )

    for ax in g.axes.flat:
      ax.tick_params(axis='x', labelsize=12)  # x-axis tick label size 12

    # Legend
    g.add_legend(fontsize=12)
    g.legend.set_title('Group')
    sns.move_legend(g, 'center right', bbox_to_anchor=(1.03, 0.57))
    plt.setp(g._legend.get_title(), fontsize=14)
    plt.setp(g._legend.get_texts(), fontsize=12)

    # One PNG and one PDF per metric; the metric name is part of the filename.
    g.figure.set_size_inches(8.5, 5)
    for filetype in ['png', 'pdf']:
      g.figure.savefig(
          os.path.join(FIGURE_PATH, f'adjustment_weights_xyr_{metric}.{filetype}'), format=filetype, bbox_inches='tight'
      )

In [None]:
# Calibration plots
# Grid of per-group calibration curves: one row per covariate set (x, xa,
# xa_stratified) and one column per simulation setting.
figsize = (8.5, 3.5)
fig, axes = plt.subplots(3, len(settings), sharex=True, sharey=True, figsize=figsize)

# itertools.product varies the feature set fastest, so `i % 3` is the
# feature-set row and `i // 3` is the setting column.
for i, (setting, feature_set) in enumerate(
    itertools.product(settings, ['x', 'xa', 'xa_stratified'])
):

  # Labels come from column 'y' and groups from column 'a' of the prediction
  # frame; only per-group curves are drawn (no pooled curve).
  calibration_plot(
      pred_eval_dict[setting]['y'],
      pred_eval_dict[setting][f'pred_probs_y1_{feature_set}'],
      group=pred_eval_dict[setting]['a'],
      plot_overall=False,
      ax=axes[i % 3][i // 3],
      legend=False,
      xlabel=None,
      ylabel=None,
  )

for i, setting in enumerate(settings):
  axes[0][i].set_title(setting_title_dict[setting], size=8)

# Row labels as rotated text on the last column. calibration_plot fixes xlim
# to (0, 1), so x=1.02 in data coordinates sits just past the right edge.
axes[0][-1].text(x=1.02, y=0.3, s='Cov.: X', rotation=-90, size=8)
axes[1][-1].text(x=1.02, y=0.3, s='Cov.: XA', rotation=-90, size=8)
axes[2][-1].text(x=1.02, y=0.1, s='Cov.: X_strat', rotation=-90, size=8)

fig.supxlabel('Score', size=14, y=0.005)
fig.supylabel('Fraction positive', size=14, x=0.03)
plt.tight_layout()
fig.set_size_inches(*figsize)
# One shared legend, built from the group curves of the middle row's last axes
# and anchored outside the grid.
axes[1][-1].legend(frameon=False, bbox_to_anchor=(1.25, 1.2), title='Group')
for filetype in ['png', 'pdf']:
  fig.savefig(
      os.path.join(FIGURE_PATH, f'calibration.{filetype}'), format=filetype, bbox_inches='tight'
  )

In [None]:
# Difference in performance between the subgroup-aware and unaware models
# (rows with features == "relative_comparison"), one facet per metric.
weight_types = ['None']
metrics = ['log loss', 'AUC-ROC', 'net benefit', 'sensitivity', 'specificity']
settings_to_plot = setting_title_dict.values()

# Group-major sort keeps rows aligned with pointplot_with_errorbars' point
# order. sharex=False gives each metric its own x scale.
g = sns.FacetGrid(
    (
        plot_df.query(
            'weight_type in @weight_types & features == "relative_comparison" &'
            ' metric.isin(@metrics) & setting in @settings_to_plot'
        )
        .sort_values(['group', 'setting'])
        .assign(
            setting=lambda x: x.setting.values.remove_unused_categories(),
            metric=lambda x: x.metric.values.remove_unused_categories(),
            )
    ),
    col='metric',
    margin_titles=True,
    sharex=False
)

g.map(
    pointplot_with_errorbars,
    'performance',
    'setting',
    'group',
    'ci_low',
    'ci_high',
    vertical=False,
    pairwise=True,
    point_scale=0.5
)

g.set_axis_labels('', '')
g.fig.supylabel(f'Setting', fontsize=16, x=-0.1, y=0.6)
g.fig.supxlabel(f'Difference in performance (subgroup-aware vs. unaware model)', fontsize=16, x=0.5, y=-0.02)

g.set_titles(
    row_template='', col_template='{col_name}', size=12
)

# Two-decimal tick labels, with a zero tick shown as "0".
for ax in g.axes.flat:
  ax.tick_params(axis='x', labelsize=10)
  ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_zero))

# Legend (the plt.setp calls below override add_legend's fontsize)
g.add_legend(fontsize=10)
g.legend.set_title('Group')
sns.move_legend(g, 'center right', bbox_to_anchor=(1.07, 0.57))
plt.setp(g._legend.get_title(), fontsize=14)
plt.setp(g._legend.get_texts(), fontsize=12)

g.figure.set_size_inches(8.5, 2.5)
for filetype in ['png', 'pdf']:
  g.figure.savefig(
      os.path.join(FIGURE_PATH, f'relative_performance.{filetype}'), format=filetype, bbox_inches="tight"
  )

In [None]:
# Population mapping results
# For every metric: per-subgroup weighted performance for each weighting
# variable, with unweighted rows shown as "Subgroup" (facet rows: setting;
# facet columns: covariate set). Warnings are suppressed while plotting.
with warnings.catch_warnings():
  warnings.simplefilter("ignore")
  for metric in metric_dict.values():
    weight_types = ['population', 'None']
    settings_to_plot = setting_title_dict.values()
    # Excludes relative-comparison rows and the pooled "Population" group.
    # NOTE(review): `replace` on the categorical `weights` column renames the
    # 'None' category; recent pandas deprecates replace on categoricals —
    # confirm against the installed version.
    the_plot_df = (plot_df.query(
        'weight_type in @weight_types & ~features.isin(@relative_comparison_features) &'
        ' metric == @metric & setting in @settings_to_plot'
    ).query('group != "Population"')
    .sort_values(['group', 'setting', 'weights'])
    .assign(setting=lambda x: x.setting.values.remove_unused_categories(),
            weights=lambda x: x.weights.replace('None', 'Subgroup')
            )
    )

    g = sns.FacetGrid(
        the_plot_df,
        row='setting',
        col='features',
        margin_titles=True,
        sharex='row',
    )

    g.map(
        pointplot_with_errorbars,
        'performance',
        'weights',
        'group',
        'ci_low',
        'ci_high',
        vertical=False,
        point_scale=0.5
    )

    g.set_axis_labels('', '')
    g.fig.supylabel(f'Control variable', fontsize=16, x=-0.02)
    g.fig.supxlabel(f'Weighted model performance ({metric})', fontsize=16, x=0.5, y=-0.02)

    g.set_titles(
        row_template='{row_name}', col_template='Features: {col_name}', size=10
    )
    g = wrap_titles(g)

    for ax in g.axes.flat:
      ax.tick_params(axis='x', labelsize=10)
      ax.tick_params(axis='y', labelsize=10)
      # Net benefit ticks get three decimals (zero still shown as "0").
      if metric == "net benefit":
        num_digits=3
        ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda *args: format_zero(*args, num_digits=num_digits)))

    # Legend
    g.add_legend(fontsize=12)
    g.legend.set_title('Group')
    sns.move_legend(g, 'center right', bbox_to_anchor=(1.03, 0.50))
    plt.setp(g._legend.get_title(), fontsize=14)
    plt.setp(g._legend.get_texts(), fontsize=12)

    g.fig.subplots_adjust(hspace=0.2, wspace=0.1)
    g.figure.set_size_inches(8.5, 11)
    for filetype in ['png', 'pdf']:
      g.figure.savefig(
          os.path.join(FIGURE_PATH, f'adjustment_weights_absolute_{metric}.{filetype}'), format=filetype, bbox_inches="tight"
      )

In [None]:
# Plot absolute performance across settings using shared space weights
# Facet rows: setting; facet columns: weighting variable; x: covariate set.
# Unlike the neighbouring loops, warnings are not suppressed here.
weight_types = ['stable', 'None']

for metric in metric_dict.values():

  # Within one facet (fixed setting and weights) the sort leaves rows
  # group-major, matching pointplot_with_errorbars' pairing assumption.
  # NOTE(review): `weights` keeps all of its categories here (no
  # remove_unused_categories), so unused levels may still yield facet columns.
  g = sns.FacetGrid(
      (
          plot_df.query(
              'weight_type in @weight_types & ~features.isin(@relative_comparison_features) &'
              ' metric == @metric'
          )
          .sort_values(['weights', 'setting', 'group'])
          .assign(setting=lambda x: x.setting.values.remove_unused_categories())
      ),
      row='setting',
      col='weights',
      margin_titles=True,
      sharey='row',
  )

  # Vertical layout: covariate set on x, performance on y.
  g.map(
      pointplot_with_errorbars,
      'features',
      'performance',
      'group',
      'ci_low',
      'ci_high',
      vertical=True,
      point_scale=0.5
  )

  g.set_axis_labels('', '')
  g.fig.supylabel(f'Shared space performance ({metric})', fontsize=16, x=-0.01)
  g.fig.supxlabel(f'Covariate set', fontsize=16, x=0.45, y=-0.03)

  g.set_titles(
      row_template='{row_name}', col_template='Weights: {col_name}', size=8
  )
  g = wrap_titles(g)

  for ax in g.axes.flat:
    ax.tick_params(axis='x', labelsize=10)  # x-axis tick label size 10

  # Legend
  g.add_legend(fontsize=12)
  g.legend.set_title('Group')
  sns.move_legend(g, 'center right', bbox_to_anchor=(1.05, 0.57))
  plt.setp(g._legend.get_title(), fontsize=14)
  plt.setp(g._legend.get_texts(), fontsize=12)


  g.figure.set_size_inches(8.5, 9)
  for filetype in ['png', 'pdf']:
    g.figure.savefig(
        os.path.join(FIGURE_PATH, f'shared_space_weights_xyr_{metric}.{filetype}'), format=filetype, bbox_inches="tight"
    )

# ACS PUMS Results

In [None]:
# Flags
# Directory holding the ACS PUMS prediction and metric parquet files, laid out
# next to the simulation data under ./../../data/.
DATA_PATH_ACS = './../../data/acs_pums/' # @param

In [None]:
# ACS PUMS prediction tasks to load; also the column order of the ACS
# calibration grid.
tasks = ['ACSIncome', 'ACSPublicCoverage']

In [None]:
# Read the ACS prediction files, keyed by task name.
pred_eval_dict_acs = {
    task: pd.read_parquet(
        os.path.join(
            DATA_PATH_ACS, f'preds_{task}_5-Year_2018_gradient_boosting.parquet'
        )
    )
    for task in tasks
}

In [None]:
# Read the eval results for every ACS task into one long-format frame.
# ignore_index=True gives the combined frame a fresh, unique RangeIndex. Without
# it, index labels repeat whenever the per-task frames share labels (e.g. each
# carries its own default RangeIndex), which makes label-based access
# ambiguous.
result_df_acs = pd.concat(
    [
        pd.read_parquet(
            os.path.join(
                DATA_PATH_ACS,
                f'metrics_{task}_5-Year_2018_gradient_boosting.parquet',
            )
        )
        for task in tasks
    ],
    ignore_index=True,
)

In [None]:
# Map population_on_group_comparison_weights to a weight_type (ACS results).
# Every weights name containing 'population_on_group_comparison' is tagged with
# that type; the normalized variant is then re-tagged with its more specific
# label. Boolean arrays (.values) are used so the masks index positionally.
the_filter = result_df_acs.weights.map(
    lambda name: 'population_on_group_comparison' in name
).values
the_filter_normalized = result_df_acs.weights.map(
    lambda name: 'population_on_group_comparison_weights_normalized' in name
).values
result_df_acs.loc[the_filter, 'weight_type'] = 'population_on_group_comparison'
result_df_acs.loc[the_filter_normalized, 'weight_type'] = (
    'population_on_group_comparison_normalized'
)

In [None]:
# Lookup table from group code to readable group name, taken from the first
# task's predictions. The code is cast to str, presumably to match the dtype
# of `group` in the metrics frame it is merged into below.
group_name_map_df = (
    pred_eval_dict_acs[tasks[0]]
    .loc[:, ['group', 'group_name']]
    .drop_duplicates()
    .assign(group=lambda frame: frame['group'].astype(str))
)
group_name_map_df

In [None]:
# Build the ACS plotting frame using the same label mappings as the simulation
# plot_df (weight_mapping_dict and metric_dict are defined there).
plot_df_acs = result_df_acs.copy()

# Left-join readable group names. With no `on=`, merge joins on every column
# name the two frames share (presumably just `group`). Rows whose code has no
# entry in group_name_map_df (possibly the pooled 'overall' rows — confirm)
# get a NaN group_name.
plot_df_acs = plot_df_acs.merge(group_name_map_df, how='left')

# Unmapped weight names become NaN in the categorical and are dropped below.
plot_df_acs['weights'] = plot_df_acs['weights'].replace(weight_mapping_dict)

plot_df_acs['weights'] = pd.Categorical(
    plot_df_acs['weights'], ['None', 'X', 'Y', 'R']
)

plot_df_acs['task'] = pd.Categorical(
    plot_df_acs['task'],
    list(['ACSIncome', 'ACSPublicCoverage'])
)

plot_df_acs['features'] = plot_df_acs['features'].replace(
    {'features_x': 'X', 'features_xa': 'XA', 'features_xa_stratified': 'X_strat'}
)

plot_df_acs['metric'] = plot_df_acs['metric'].replace(
  metric_dict
)

plot_df_acs['metric'] = pd.Categorical(
    plot_df_acs['metric'],
    list(metric_dict.values())
)

plot_df_acs['group'] = plot_df_acs['group'].replace(
    {"overall": "Population"}
)
# Drop rows whose weighting scheme was not mapped above.
plot_df_acs = plot_df_acs.query('~weights.isna()')

In [None]:
# ACS version of the per-metric adjustment-weights figure: each subgroup's
# metric minus its weighted population estimate, by task (facet rows:
# covariate set; facet columns: weighting variable). Warnings are suppressed.
with warnings.catch_warnings():
  warnings.simplefilter("ignore")
  for metric in metric_dict.values():
    # NOTE(review): `weight_types` is not used by the query below.
    weight_types = ['None']
    tasks_to_plot = ['ACSIncome', 'ACSPublicCoverage']

    feature_set = ['X', 'XA', 'X_strat']

    # Sorted by group_name (the hue column) so rows line up with the point
    # order pointplot_with_errorbars relies on.
    g = sns.FacetGrid(
      (
          plot_df_acs.query(
              'weight_type == "population_on_group_comparison" &'
              ' metric == @metric & task in @tasks_to_plot &'
              ' features in @feature_set'
          )
          .sort_values(['group_name', 'task'])
          .assign(task=lambda x: x.task.values.remove_unused_categories(),
                  weights=lambda x: x.weights.values.remove_unused_categories())
      ),
      col='weights',
      row='features',
      margin_titles=True,
      sharex=True
    )

    g.map(
      pointplot_with_errorbars,
      'performance',
      'task',
      'group_name',
      'ci_low',
      'ci_high',
      vertical=False,
      pairwise=True,
      point_scale=0.5
    )

    g.set_axis_labels('', '')
    g.fig.supylabel(f'Task', fontsize=14, x=-0.06, y=0.57)
    g.fig.supxlabel(f'Subgroup {metric} - weighted population estimate', fontsize=14, x=0.5, y=-0.075)

    g.set_titles(
      row_template='Cov.: {row_name}', col_template='Weights: {col_name}', size=12
    )

    for ax in g.axes.flat:
      ax.tick_params(axis='x', labelsize=12)  # x-axis tick label size 12

    # Legend
    g.add_legend(fontsize=12)
    g.legend.set_title('Group')
    sns.move_legend(g, 'center right', bbox_to_anchor=(1.08, 0.5))
    plt.setp(g._legend.get_title(), fontsize=14)
    plt.setp(g._legend.get_texts(), fontsize=12)

    g.figure.set_size_inches(8.5, 4)
    for filetype in ['png', 'pdf']:
      g.figure.savefig(
          os.path.join(FIGURE_PATH, f'adjustment_weights_xyr_{metric}_acs.{filetype}'), format=filetype, bbox_inches='tight'
      )

In [None]:
# Calibration plots
# ACS grid of per-group calibration curves: one row per covariate set (x, xa,
# xa_stratified) and one column per task.
figsize = (4, 3.5)

fig, axes = plt.subplots(3, len(tasks), sharex=True, sharey=True, figsize=figsize)

# itertools.product varies the feature set fastest, so `i % 3` is the
# feature-set row and `i // 3` is the task column.
for i, (task, feature_set) in enumerate(
    itertools.product(tasks, ['x', 'xa', 'xa_stratified'])
):

  # Labels come from column 'labels' and groups from 'group_name' of the
  # prediction frame; only per-group curves are drawn.
  calibration_plot(
      pred_eval_dict_acs[task]['labels'],
      pred_eval_dict_acs[task][f'pred_probs_y1_{feature_set}'],
      group=pred_eval_dict_acs[task]['group_name'],
      plot_overall=False,
      ax=axes[i % 3][i // 3],
      legend=False,
      xlabel=None,
      ylabel=None,
  )

for i, task in enumerate(tasks):
  axes[0][i].set_title(task, size=8)

# Row labels as rotated text just past the right edge of the last column
# (data coordinates; calibration_plot fixes xlim to (0, 1)).
axes[0][-1].text(x=1.02, y=0.3, s='Cov.: X', rotation=-90, size=8)
axes[1][-1].text(x=1.02, y=0.3, s='Cov.: XA', rotation=-90, size=8)
axes[2][-1].text(x=1.02, y=0.1, s='Cov.: X_strat', rotation=-90, size=8)

fig.supxlabel('Score', size=14, x=0.55, y=0.05)
fig.supylabel('Fraction positive', size=14, x=0.03, y=0.55)
plt.tight_layout()
fig.set_size_inches(*figsize)
# One shared legend built from the middle row's last axes, placed outside.
axes[1][-1].legend(frameon=False, bbox_to_anchor=(1.25, 1.8), title='Group')
for filetype in ['png', 'pdf']:
  fig.savefig(
      os.path.join(FIGURE_PATH, f'calibration_acs.{filetype}'), format=filetype, bbox_inches='tight'
  )

In [None]:
# ACS version of the subgroup-aware vs. unaware comparison, one facet per
# metric with an independent x scale.
weight_types = ['None']
metrics = ['log loss', 'AUC-ROC', 'net benefit', 'sensitivity', 'specificity']
tasks_to_plot = ['ACSIncome', 'ACSPublicCoverage']

# Sorted by group_name (the hue column) to keep rows aligned with the point
# order used by pointplot_with_errorbars.
g = sns.FacetGrid(
    (
        plot_df_acs.query(
            'weight_type in @weight_types & features == "relative_comparison" &'
            ' metric.isin(@metrics) & task in @tasks_to_plot'
        )
        .sort_values(['group_name', 'task'])
        .assign(
            task=lambda x: x.task.values.remove_unused_categories(),
            metric=lambda x: x.metric.values.remove_unused_categories(),
            )
    ),
    col='metric',
    margin_titles=True,
    sharex=False
)

g.map(
    pointplot_with_errorbars,
    'performance',
    'task',
    'group_name',
    'ci_low',
    'ci_high',
    vertical=False,
    pairwise=True,
    point_scale=0.5
)

g.set_axis_labels('', '')
g.fig.supylabel(f'Task', fontsize=14, x=-0.1, y=0.6)
g.fig.supxlabel(f'Difference in performance (subgroup-aware vs. unaware model)', fontsize=14, x=0.5, y=-0.02)

g.set_titles(
    row_template='', col_template='{col_name}', size=12
)

# Two-decimal tick labels, with a zero tick shown as "0".
for ax in g.axes.flat:
  ax.tick_params(axis='x', labelsize=10)
  ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_zero))

# Legend (the plt.setp calls below override add_legend's fontsize)
g.add_legend(fontsize=10)
g.legend.set_title('Group')
sns.move_legend(g, 'center right', bbox_to_anchor=(1.07, 0.57))
plt.setp(g._legend.get_title(), fontsize=14)
plt.setp(g._legend.get_texts(), fontsize=12)

g.figure.set_size_inches(8.5, 2.5)
for filetype in ['png', 'pdf']:
  g.figure.savefig(
      os.path.join(FIGURE_PATH, f'relative_performance_acs.{filetype}'), format=filetype, bbox_inches="tight"
  )

In [None]:
# Sanity check: number of rows per (task, features, weights, group_name)
# combination; the lambda counts every row (shape[0]), NaNs included.
# NOTE(review): this cell does not set `metric` or `weight_types`. It reuses
# whatever earlier cells left behind: in top-to-bottom order, weight_types is
# ['None'] from the cell above and metric is the last value of an earlier
# `for metric in ...` loop. The output therefore depends on execution order.
the_plot_df = (plot_df_acs.query(
        'weight_type in @weight_types & ~features.isin(@relative_comparison_features) &'
        ' metric == @metric & task in @tasks_to_plot'
    ).query('group != "Population"')
    .sort_values(['group', 'task', 'weights'])
    .assign(task=lambda x: x.task.values.remove_unused_categories(),
            weights=lambda x: x.weights.replace('None', 'Subgroup')
            )
    )
the_plot_df.groupby(['task', 'features', 'weights', 'group_name']).agg(lambda x: x.shape[0]).head(20)

In [None]:
# Population mapping results
# ACS version: for every metric, per-subgroup weighted performance for each
# weighting variable, with unweighted rows shown as "Subgroup" (facet rows:
# task; facet columns: covariate set). Warnings are suppressed while plotting.
with warnings.catch_warnings():
  warnings.simplefilter("ignore")
  for metric in metric_dict.values():
    weight_types = ['population', 'None']

    # NOTE(review): `tasks_to_plot` is not set in this cell; it is read from an
    # earlier cell. Rows are sorted by the `group` code while the hue is
    # `group_name`; point/CI pairing stays consistent only if the two map 1:1.
    the_plot_df = (plot_df_acs.query(
        'weight_type in @weight_types & ~features.isin(@relative_comparison_features) &'
        ' metric == @metric & task in @tasks_to_plot'
    ).query('group != "Population"')
    .sort_values(['group', 'task', 'weights'])
    .assign(task=lambda x: x.task.values.remove_unused_categories(),
            weights=lambda x: x.weights.replace('None', 'Subgroup')
            )
    )

    g = sns.FacetGrid(
        the_plot_df,
        row='task',
        col='features',
        margin_titles=True,
        sharex='row',
    )

    g.map(
        pointplot_with_errorbars,
        'performance',
        'weights',
        'group_name',
        'ci_low',
        'ci_high',
        vertical=False,
        point_scale=0.5
    )

    g.set_axis_labels('', '')
    g.fig.supylabel(f'Control variable', fontsize=16, x=-0.02)
    g.fig.supxlabel(f'Weighted model performance ({metric})', fontsize=16, x=0.5, y=-0.02)

    g.set_titles(
        row_template='{row_name}', col_template='Features: {col_name}', size=10
    )
    g = wrap_titles(g)

    for ax in g.axes.flat:
      ax.tick_params(axis='x', labelsize=10)
      ax.tick_params(axis='y', labelsize=10)
      # Net benefit ticks get three decimals (zero still shown as "0").
      if metric == "net benefit":
        num_digits=3
        ax.xaxis.set_major_formatter(ticker.FuncFormatter(lambda *args: format_zero(*args, num_digits=num_digits)))

    # Legend
    g.add_legend(fontsize=12)
    g.legend.set_title('Group')
    sns.move_legend(g, 'center right', bbox_to_anchor=(1.03, 0.50))
    plt.setp(g._legend.get_title(), fontsize=14)
    plt.setp(g._legend.get_texts(), fontsize=12)

    g.fig.subplots_adjust(hspace=0.2, wspace=0.1)
    g.figure.set_size_inches(8.5, 11)
    for filetype in ['png', 'pdf']:
      g.figure.savefig(
          os.path.join(FIGURE_PATH, f'adjustment_weights_absolute_{metric}_acs.{filetype}'), format=filetype, bbox_inches="tight"
      )