Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

## Synthetic experiment that sweeps over a range of distribution shifts

This notebook implements an experiment that trains a model on a single source domain where $P(U=1)=0.1$ and evaluates it on a collection of target domains where $P(U=1) \in \{0.1, \ldots, 0.9\}$. The notebook also produces a visualization of the results, analogous to Figure 3A in the paper. This notebook relies on `colab/synthetic_data_to_file.ipynb` having been run beforehand.

In [None]:
import numpy as np
import pandas as pd
import sklearn
import tensorflow as tf
import ml_collections as mlc
import scipy
import matplotlib.pyplot as plt
import seaborn as sns
import re
import os
import itertools
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import roc_auc_score
from sklearn.metrics import accuracy_score
from sklearn.metrics import log_loss
from IPython.display import display
from latent_shift_adaptation.methods.algorithms_sknp import get_classifier, latent_shift_adaptation
from latent_shift_adaptation.utils import gumbelmax_vae_ci, gumbelmax_vae
from latent_shift_adaptation.methods.vae import gumbelmax_vanilla, gumbelmax_graph

In [None]:
# Number of repetitions of each experiment; the repetition index is used as the random seed.
ITERATIONS = 10 # Set to 10 to replicate experiments in paper
# Number of training epochs forwarded to the VAE `fit` call.
EPOCHS = 200 # Set to 200 to replicate experiments in paper

In [None]:
#@title Library functions

def extract_from_df(samples_df, cols=('u', 'x', 'w', 'c', 'c_logits', 'y', 'y_logits', 'y_one_hot',
                                      'u_one_hot', 'x_scaled',
                                      'w_1', 'w_1_binary', 'w_1_one_hot',
                                      'w_2', 'w_2_binary', 'w_2_one_hot',
                                      )):
  """Extracts a dict of numpy arrays from a dataframe.

  For every name in `cols`:
    * if the dataframe has a column with exactly that name, the column is
      returned as a 1-D array;
    * otherwise all columns named `<name>_<integer>` (case-insensitive) are
      stacked, in dataframe column order, into a 2-D array;
    * if neither exists the name is silently skipped.

  Args:
    samples_df: pandas DataFrame with string column names.
    cols: iterable of logical column names to extract.

  Returns:
    Dict mapping each found name to a numpy array.
  """
  result = {}
  for col in cols:
    if col in samples_df.columns:
      result[col] = samples_df[col].values
      continue
    # Escape the name so regex metacharacters in it are matched literally, and
    # accept multi-digit suffixes so that e.g. `x_10` is not silently dropped.
    pattern = re.compile(rf"^{re.escape(col)}_\d+$", re.IGNORECASE)
    matching_columns = [name for name in samples_df.columns if pattern.match(name)]
    if matching_columns:
      result[col] = samples_df[matching_columns].to_numpy()
  return result

def extract_from_df_nested(samples_df, cols=('u', 'x', 'w', 'c', 'c_logits', 'y', 'y_logits', 'y_one_hot', 'w_binary', 'w_one_hot', 'u_one_hot', 'x_scaled')):
  """Extracts a nested dict of numpy arrays from a dataframe.

  The dataframe must have a `partition` column. If it also has a `domain`
  column the result is `{domain: {partition: data}}`; otherwise it is
  `{partition: data}`. In both cases `data` is whatever
  `extract_from_df(rows, cols=cols)` returns for the matching rows.

  Args:
    samples_df: pandas DataFrame with a `partition` (and optionally `domain`)
      column.
    cols: iterable of logical column names forwarded to `extract_from_df`.

  Returns:
    Nested dict as described above; keys follow order of first appearance.
  """
  def _by_partition(frame):
    # One entry per distinct partition value, in order of first appearance.
    return {
        partition: extract_from_df(frame[frame['partition'] == partition], cols=cols)
        for partition in frame['partition'].unique()
    }

  if 'domain' not in samples_df.columns:
    return _by_partition(samples_df)
  return {
      domain: _by_partition(samples_df[samples_df['domain'] == domain])
      for domain in samples_df['domain'].unique()
  }



In [None]:
# Load one CSV per (p_u_0, w_coeff) setting into nested dicts of numpy arrays.
folder_id = './tmp_data'
p_u_list = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
w_coeff_list = [1, 2, 3]
data_dict_all = {}
for setting in itertools.product(p_u_list, w_coeff_list):
  p_u_0, w_coeff = setting
  print(p_u_0, w_coeff)
  filename = f"synthetic_multivariate_num_samples_10000_w_coeff_{w_coeff}_p_u_0_{p_u_0}.csv"
  csv_path = os.path.join(folder_id, filename)
  # Keyed by the (p_u_0, w_coeff) tuple so later cells can look settings up directly.
  data_dict_all[setting] = extract_from_df_nested(pd.read_csv(csv_path))

In [None]:
# Define Sklearn evaluation functions
def soft_accuracy(y_true, y_pred, threshold=0.5, **kwargs):
  """Accuracy of scores binarised at `threshold` (score >= threshold counts as positive).

  Extra keyword arguments are forwarded to `sklearn.metrics.accuracy_score`.
  """
  hard_labels = y_pred >= threshold
  return sklearn.metrics.accuracy_score(y_true, hard_labels, **kwargs)

def log_loss64(y_true, y_pred, **kwargs):
  """Log loss computed after casting the predictions to float64.

  Extra keyword arguments are forwarded to `sklearn.metrics.log_loss`.
  """
  y_pred_f64 = y_pred.astype(np.float64)
  return sklearn.metrics.log_loss(y_true, y_pred_f64, **kwargs)

# Metric name -> callable(y_true, y_score). Iteration order of this dict
# determines the order of the metric entries in the result dicts built below.
evals_sklearn = {
    'cross-entropy': log_loss64,
    'accuracy': soft_accuracy, 
    'auc': sklearn.metrics.roc_auc_score
}
def fit_and_evaluate_sk(data_dict_source: dict,
                     data_dict_target: dict,
                     model_type: str = 'mlp'
                     ) -> dict:
  """Fits the sklearn baselines and oracle LSA, then scores them.

  Three predictors are produced:
    * `erm-source`: classifier fit on the source training split;
    * `erm-target`: classifier fit on the target training split;
    * `lsa-oracle-sk`: `latent_shift_adaptation` run with the observed `u`
      from the source training split, predicting on the target test split.

  Args:
    data_dict_source: dict with 'train' and 'test' entries, each holding at
      least 'x' and 'y' (and 'u' for 'train').
    data_dict_target: dict with 'train' and 'test' entries holding 'x' and 'y'.
    model_type: classifier type forwarded to `get_classifier` and
      `latent_shift_adaptation`.

  Returns:
    Dict keyed by `(method, eval_set, metric)` with one value per entry of
    `evals_sklearn`.
  """
  ## Fit baselines
  model = get_classifier(model_type)
  model.fit(data_dict_source['train']['x'], data_dict_source['train']['y'])
  model_target = get_classifier(model_type)
  model_target.fit(data_dict_target['train']['x'], data_dict_target['train']['y'])

  # Apply LSA with known U
  lsa_pred_probs_target = latent_shift_adaptation(
      x_source=data_dict_source['train']['x'],
      y_source=data_dict_source['train']['y'],
      u_source=data_dict_source['train']['u'],
      x_target=data_dict_target['test']['x'],
      model_type=model_type)[:, -1]

  x_source_test = data_dict_source['test']['x']
  y_source_test = data_dict_source['test']['y']
  x_target_test = data_dict_target['test']['x']
  y_target_test = data_dict_target['test']['y']

  # Score every (method, eval_set) pair once; the last column of
  # `predict_proba` is used as the score, exactly as for the LSA output above.
  scored = [
      ('erm-source', 'source', y_source_test, model.predict_proba(x_source_test)[:, -1]),
      ('erm-source', 'target', y_target_test, model.predict_proba(x_target_test)[:, -1]),
      ('erm-target', 'source', y_source_test, model_target.predict_proba(x_source_test)[:, -1]),
      ('erm-target', 'target', y_target_test, model_target.predict_proba(x_target_test)[:, -1]),
      ('lsa-oracle-sk', 'target', y_target_test, lsa_pred_probs_target),
  ]

  result = {}
  for metric, eval_fn in evals_sklearn.items():
    for method_name, eval_set, y_true, y_score in scored:
      result[(method_name, eval_set, metric)] = eval_fn(y_true, y_score)

  return result

In [None]:
# Source domain used by every experiment below: the (p_u_0=0.9, w_coeff=1) dataset.
data_dict_source = data_dict_all[(0.9, 1)]

In [None]:
# Repeat the sklearn sweep once per seed; each repetition yields a dict
# {p_u_0: metrics} covering every target domain with w_coeff = 1.
result_list_sk = []
for seed in range(ITERATIONS):
  print(f'Iteration: {seed}')
  np.random.seed(seed)
  result_sk = {}
  for p_u_0 in p_u_list:
    print(p_u_0)
    target_key = (p_u_0, 1)
    result_sk[p_u_0] = fit_and_evaluate_sk(
        data_dict_source,
        data_dict_all[target_key],
        model_type='mlp',
    )
  result_list_sk.append(result_sk)

In [None]:
def _sk_results_to_frame(iteration, results_by_target):
  """Turns one repetition's {p_u_0: {(method, eval_set, metric): value}} into a frame."""
  stacked = pd.concat({key: pd.Series(value) for key, value in results_by_target.items()})
  frame = stacked.to_frame()
  frame = frame.rename_axis(['p_u_target_0', 'method', 'eval_set', 'metric'])
  frame = frame.rename(columns={0: 'performance'})
  return frame.assign(iteration=iteration)


# Long-format table: one row per (target, method, eval_set, metric, iteration).
result_sk_concat = pd.concat(
    [_sk_results_to_frame(i, elem) for i, elem in enumerate(result_list_sk)]
).reset_index()
result_sk_concat

In [None]:
# Spot check: per-iteration target AUC of the source-trained ERM model at p_u_target_0 == 0.1.
result_sk_concat.query('method == "erm-source" & eval_set == "target" & metric == "auc" & p_u_target_0 == 0.1')

### Run VAE

In [None]:
# Create TF datasets

batch_size = 128


def _make_split_datasets(splits):
  """Builds one repeating, shuffled, batched tf.data.Dataset per entry of `splits`.

  Each element is the tuple (x, y_one_hot, c, w_one_hot, u_one_hot).
  """
  datasets = {}
  for key, value in splits.items():
    features = (value['x'], value['y_one_hot'], value['c'], value['w_one_hot'], value['u_one_hot'])
    dataset = tf.data.Dataset.from_tensor_slices(features)
    datasets[key] = dataset.repeat().shuffle(1000).batch(batch_size)
  return datasets


# Source datasets: p_u_0 fixed at 0.9, one entry per w_coeff.
ds_dict_source = {
    w_coeff: _make_split_datasets(data_dict_all[(0.9, w_coeff)])
    for w_coeff in w_coeff_list
}

# Target datasets: w_coeff fixed at 1, one entry per p_u_0.
ds_dict_target = {
    p_u_0: _make_split_datasets(data_dict_all[(p_u_0, 1)])
    for p_u_0 in p_u_list
}

In [None]:
# Infer the feature dimensions from one batch of the w_coeff = 1 source training set.
ds_temp = ds_dict_source[1]
batch = next(iter(ds_temp['train']))
# Batch tuple layout is (x, y_one_hot, c, w_one_hot, u_one_hot), as built in the dataset cell.
x_dim = batch[0].shape[1]
c_dim = batch[2].shape[1]
w_dim = batch[3].shape[1]
u_dim = batch[4].shape[1]
num_classes = 2
# NOTE(review): split fractions are hard-coded here, not read from the data --
# confirm they match the partition sizes in the CSV files.
test_fract = 0.2
val_fract = 0.1

# NOTE(review): presumably mirrors the `num_samples_10000` in the CSV filenames -- confirm.
num_examples = 10_000
# The datasets repeat indefinitely, so the number of steps per epoch is derived here.
steps_per_epoch = num_examples // batch_size
steps_per_epoch_test = int(steps_per_epoch * test_fract)
steps_per_epoch_val = int(steps_per_epoch * val_fract)
steps_per_epoch_train = steps_per_epoch - steps_per_epoch_test - steps_per_epoch_val

# Positions of each variable inside a dataset element tuple; passed to the VAE as `pos`.
pos = mlc.ConfigDict()
pos.x, pos.y, pos.c, pos.w, pos.u = 0, 1, 2, 3, 4

In [None]:
def evaluate_clf(data_dict_source, data_dict_target):
  """Scores the current global `clf` on the source and target test splits.

  Reads the module-level names `clf`, `method` and `evals_sklearn`; they must
  be defined before this function is called.

  Args:
    data_dict_source: dict whose 'test' entry holds 'x' and 'y' for the source.
    data_dict_target: dict whose 'test' entry holds 'x' and 'y' for the target.

  Returns:
    Dict `{metric: {'source': value, 'target': value}}` with one entry per
    metric in `evals_sklearn`.
  """
  def _class_one_scores(features):
    preds = clf.predict(features)
    if 'cbm' in method:
      # hacky workaround for now: these methods return a sequence and the
      # class predictions are its second element
      preds = preds[1]
    if tf.is_tensor(preds):
      preds = preds.numpy()
    # Column 1 is used as the score handed to the metric functions.
    return preds[:, 1]

  y_pred_source = _class_one_scores(data_dict_source['test']['x'])
  y_pred_target = _class_one_scores(data_dict_target['test']['x'])
  y_true_source = data_dict_source['test']['y']
  y_true_target = data_dict_target['test']['y']

  return {
      metric: {
          'source': eval_fn(y_true_source, y_pred_source),
          'target': eval_fn(y_true_target, y_pred_target),
      }
      for metric, eval_fn in evals_sklearn.items()
  }

In [None]:
DEFAULT_LOSS = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

def mlp(num_classes, width, input_shape, learning_rate,
        loss=DEFAULT_LOSS, metrics=None):
  """Builds and compiles a Keras classifier that outputs logits.

  The network has a single ReLU hidden layer of `width` units, or no hidden
  layer at all when `width` is falsy (e.g. 0 or None), followed by a linear
  layer with `num_classes` outputs. It is compiled with SGD (momentum 0.9).

  Args:
    num_classes: number of output units (logits).
    width: number of hidden units; falsy for a purely linear model.
    input_shape: shape of one input example, e.g. `(x_dim,)`.
    learning_rate: SGD learning rate.
    loss: Keras loss; defaults to categorical cross-entropy on logits.
    metrics: optional list of Keras metrics; `None` means no metrics.

  Returns:
    The compiled `tf.keras.Model`.
  """
  model_input = tf.keras.Input(shape=input_shape)
  # hidden layer
  if width:
    hidden = tf.keras.layers.Dense(
        width, use_bias=True, activation='relu'
    )(model_input)
  else:
    hidden = model_input
  model_output = tf.keras.layers.Dense(num_classes,
                                       use_bias=True,
                                       activation="linear")(hidden)  # get logits
  opt = tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
  model = tf.keras.models.Model(model_input, model_output)
  model.build(input_shape)
  # A fresh list per call avoids sharing one mutable default across models.
  model.compile(loss=loss, optimizer=opt,
                metrics=[] if metrics is None else metrics)

  return model

# Global training configuration shared by the VAE cells below.
xlabel = 'x'  # 'x' or 'x_scaled'; not referenced by the cells visible in this notebook
SEED = 0
tf.random.set_seed(SEED)
np.random.seed(SEED)
evals = {  # Keras evaluation metrics, keyed by the same names as `evals_sklearn`
    "cross-entropy": tf.keras.metrics.CategoricalCrossentropy(),
    "accuracy": tf.keras.metrics.CategoricalAccuracy(),
    "auc": tf.keras.metrics.AUC(multi_label = False)
}
# Reduce the learning rate by 10x when the monitored 'loss' has not improved
# by at least 0.01 for 20 epochs, never going below 1e-7.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='loss', min_delta=0.01, factor=0.1, patience=20,
    min_lr=1e-7)

callbacks = [reduce_lr]

do_calib = True
# Metric object handed to the VAE method as its `evaluate` argument.
evaluate = tf.keras.metrics.CategoricalCrossentropy()

learning_rate = 0.01  #@param {type:"number"}
width = 100  #@param {type:"number"}

epochs = EPOCHS
# Keyword arguments forwarded to `clf.fit` in the VAE training cell.
train_kwargs = {
    'epochs': epochs,
    'steps_per_epoch':steps_per_epoch_train,
    'verbose': True,
    'callbacks':callbacks
    }

# NOTE(review): `val_kwargs`, `test_kwargs` and `tmep_kwargs` are not
# referenced by the cells visible here -- confirm whether they are still needed.
val_kwargs = {
    'epochs': epochs,
    'steps_per_epoch':steps_per_epoch_val,
    'verbose': False,
    'callbacks':callbacks
    }

test_kwargs = {'verbose': False,
               'steps': steps_per_epoch_test}
tmep_kwargs = {'verbose': False}
# Output width of the encoder and of the x -> u model (the latent dimension).
latent_dim = 10

In [None]:
# Train the graph-based Gumbel-max VAE once per (seed, w_coeff) on the source
# domain, then re-estimate the frequency ratio and evaluate on every target.
method = 'vae_graph'
input_shape = (x_dim, )
result_list_vae = []
for seed in range(ITERATIONS):
  print(f'Iteration: {seed}')
  # Re-seed TensorFlow and NumPy at the start of every repetition.
  tf.random.set_seed(seed)
  np.random.seed(seed)
  result_vae = {}
  # Target data handed to `clf.fit`; it is always the p_u_0 = 0.9 dataset,
  # independent of the target that is evaluated further down.
  ds_target_dummy = ds_dict_target[0.9]
  for w_coeff in w_coeff_list:
    print(f'Training model for w_coeff: {w_coeff}')
    ds_source = ds_dict_source[w_coeff]
    # Encoder input is the concatenation of x, c, w and the one-hot label.
    encoder = mlp(num_classes=latent_dim, width=width,
                  input_shape=(x_dim + c_dim + w_dim + num_classes,),
                  learning_rate=learning_rate,
                  metrics=['accuracy'])

    model_x2u = mlp(num_classes=latent_dim, width=width, input_shape=(x_dim,),
                    learning_rate=learning_rate,
                    metrics=['accuracy'])
    model_xu2y = mlp(num_classes=num_classes, width=width,
                    input_shape=(x_dim + latent_dim,),
                    learning_rate=learning_rate,
                    metrics=['accuracy'])
    vae_opt = tf.keras.optimizers.RMSprop(learning_rate=1e-4)

    dims = mlc.ConfigDict()
    dims.x = x_dim
    dims.y = num_classes
    dims.c = c_dim
    # Use the measured width of `w_one_hot` (previously `u_dim`), matching the
    # `w_dim` term already used for the encoder input shape above.
    dims.w = w_dim
    dims.u = u_dim

    clf = gumbelmax_graph.Method(encoder, width, vae_opt,
                        model_x2u, model_xu2y, 
                        dims, latent_dim, None,
                        kl_loss_coef=3,
                        num_classes=num_classes, evaluate=evaluate,
                        dtype=tf.float32, pos=pos)

    clf.fit(ds_source['train'], ds_source['val'], ds_target_dummy['train'],
            steps_per_epoch_val, **train_kwargs)
    for p_u_0 in p_u_list:
      print(f'Adapting for p_u_0: {p_u_0}')
      ds_target = ds_dict_target[p_u_0]
      # Only the frequency ratio is re-estimated per target; the trained
      # networks are reused as-is.
      clf.freq_ratio = clf._get_freq_ratio(
          data_source_val=ds_source['val'], 
          data_target=ds_target['train'], 
          num_batches=steps_per_epoch_val
        )
      # NOTE(review): evaluation always uses the w_coeff = 1 source and target
      # arrays, even when the model was trained on another w_coeff -- confirm
      # this is intended.
      result_vae[(p_u_0, w_coeff)] = evaluate_clf(data_dict_source, data_dict_all[(p_u_0, 1)])
  result_list_vae.append(result_vae)

In [None]:
def _vae_results_to_frame(iteration, results_by_setting):
  """Turns one repetition's {(p_u_0, w_coeff): {metric: {eval_set: value}}} into long format."""
  id_columns = ['p_u_target_0', 'w_coeff', 'eval_set']
  # Each value becomes a small frame with metrics as columns and eval sets as rows.
  wide = pd.concat({key: pd.DataFrame(value) for key, value in results_by_setting.items()})
  wide = wide.rename_axis(id_columns).reset_index()
  tidy = wide.melt(id_vars=id_columns, value_vars=['cross-entropy', 'accuracy', 'auc'], var_name='metric', value_name='performance')
  return tidy.assign(iteration=iteration)


result_vae_concat = pd.concat(
    [_vae_results_to_frame(i, elem) for i, elem in enumerate(result_list_vae)]
)
result_vae_concat

In [None]:
# Spot check: per-iteration target AUC of the VAE for every (p_u_target_0, w_coeff).
result_vae_concat.query('eval_set == "target" & metric == "auc"')

In [None]:
def _mean_and_std(results, group_cols):
  """Mean and standard deviation of `performance` across iterations, per group."""
  grouped = results.groupby(group_cols)
  summary = grouped.agg(performance_mean=('performance', 'mean'), performance_std=('performance', 'std'))
  return summary.reset_index()


result_sk_agg = _mean_and_std(result_sk_concat, ['p_u_target_0', 'method', 'eval_set', 'metric'])
# The VAE results carry no `method` column, so tag them before stacking.
result_vae_agg = _mean_and_std(result_vae_concat, ['p_u_target_0', 'w_coeff', 'eval_set', 'metric']).assign(method='vae')
result_agg = pd.concat([result_sk_agg, result_vae_agg])
result_agg

In [None]:
## Make a plot
def _target_auc_curve(agg_df, selector):
  """Returns (mean, std) arrays of target-set AUC for rows matching `selector`, ordered by p_u_target_0."""
  curve = agg_df.query(f'{selector} & eval_set == "target" & metric == "auc"').sort_values('p_u_target_0')
  return curve.performance_mean.values, curve.performance_std.values


plt.close()
plt.figure(figsize=(6, 4))
# x axis is 1 - p_u_0 for the swept values, labelled q(U=1) below.
plot_x = 1-np.array([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

# ERM baselines: dashed black lines, labelled with text annotations rather than legend entries.
erm_source_mean, erm_source_std = _target_auc_curve(result_sk_agg, 'method == "erm-source"')
plt.errorbar(plot_x, erm_source_mean, yerr=erm_source_std, color='k', linestyle='dashed', lw=2)

erm_target_mean, erm_target_std = _target_auc_curve(result_sk_agg, 'method == "erm-target"')
plt.errorbar(plot_x, erm_target_mean, yerr=erm_target_std, color='k', linestyle='dashed', lw=2)

# LSA with observed U.
lsa_result_mean, lsa_result_std = _target_auc_curve(result_sk_agg, 'method == "lsa-oracle-sk"')
plt.errorbar(plot_x, lsa_result_mean, yerr=lsa_result_std, color='r', lw=2, label='LSA-observed')

# VAE-based LSA, one curve per w_coeff.
vae_result_1_mean, vae_result_1_std = _target_auc_curve(result_vae_agg, 'w_coeff == 1')
plt.errorbar(plot_x, vae_result_1_mean, yerr=vae_result_1_std, color="#154360", lw=2, label='WAE-high-noise')
vae_result_2_mean, vae_result_2_std = _target_auc_curve(result_vae_agg, 'w_coeff == 2')
plt.errorbar(plot_x, vae_result_2_mean, yerr=vae_result_2_std, color="#2E86C1", lw=2, label='WAE-med-noise')
vae_result_3_mean, vae_result_3_std = _target_auc_curve(result_vae_agg, 'w_coeff == 3')
plt.errorbar(plot_x, vae_result_3_mean, yerr=vae_result_3_std, color="#85C1E9", lw=2, label='WAE-low-noise')

plt.ylabel('AUROC', size=20)
plt.xlabel('q(U=1)', size=20)
plt.legend(fontsize=14)
plt.xticks(fontsize=14)
plt.yticks(fontsize=14)
# Text labels for the two dashed baselines, placed at x = 0.95 (beyond the largest plotted x of 0.9).
plt.text(0.95, np.array(erm_source_mean).min(), s='ERM-source', fontsize=16)
plt.text(0.95, erm_target_mean[0], s='ERM-target', fontsize=16)
sns.despine()

figure_folder_id = './tmp_data'
figure_filename = 'synthetic_sweep_vae_200_epochs.png'
figure_path = os.path.join(figure_folder_id, figure_filename)
plt.savefig(figure_path, bbox_inches='tight', dpi=90)

In [None]:
# Write the plot data: the aggregated (mean/std) sklearn and VAE results go
# into the same folder as the figure (`figure_folder_id` from the plotting cell).
result_agg.to_csv(os.path.join(figure_folder_id, 'synthetic_sweep_vae_results_200_epochs.csv'), index=False)