Copyright 2022 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

## Discrete simulation study

An implementation of the simulation study with discrete observations from https://arxiv.org/abs/2212.11254. The notebook implements Algorithm 1 of the paper and reproduces the results reported in Table 1.

This notebook expects `colab/synthetic_data_to_file.ipynb` to have been run first, so that the synthetic-data CSV files loaded below exist in `./tmp_data`.

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import warnings
import re
import os
from sklearn.cluster import KMeans

with warnings.catch_warnings():
  warnings.filterwarnings('ignore', category=DeprecationWarning)
  import scipy.spatial

In [None]:
def rmse(pred, true):
  """Returns the root-mean-square error between `pred` and `true`.

  The squared residuals of all entries are pooled into a single mean, so
  multi-dimensional inputs are compared element by element.
  """
  residual = pred - true
  return np.sqrt(np.mean(np.square(residual)))


def extract_from_df_nested(samples_df, cols=None):
  """Extracts a nested dict of numpy arrays from a dataframe.

  If `samples_df` has a 'domain' column the result has the structure
  {domain: {partition: data}}; otherwise it is {partition: data}. Each `data`
  entry is the dict returned by `extract_from_df` for the matching rows.
  Keys appear in order of first appearance in the dataframe.

  Args:
    samples_df: DataFrame with a 'partition' column and optionally a 'domain'
      column.
    cols: Column names to extract. When None (the default), the default
      column list of `extract_from_df` is used; otherwise it is passed through
      unchanged.

  Returns:
    Nested dict of numpy arrays as described above.
  """
  # Only forward `cols` when the caller gave one, so the default column list
  # lives in a single place (`extract_from_df`) instead of being duplicated
  # here as a mutable default argument.
  extract_kwargs = {} if cols is None else {'cols': cols}

  def _by_partition(df):
    # Maps each partition value of `df` to the arrays extracted from its rows.
    # A boolean mask selects the same rows as `query('partition == @...')`
    # without relying on pandas' caller-frame variable lookup.
    return {
        partition: extract_from_df(df[df['partition'] == partition],
                                   **extract_kwargs)
        for partition in df['partition'].unique()
    }

  if 'domain' not in samples_df.columns:
    return _by_partition(samples_df)
  return {
      domain: _by_partition(samples_df[samples_df['domain'] == domain])
      for domain in samples_df['domain'].unique()
  }


def extract_from_df(samples_df,
                    cols=(
                        'u', 'x', 'w', 'c', 'c_logits', 'y', 'y_logits',
                        'y_one_hot', 'w_binary', 'w_one_hot', 'u_one_hot',
                        'x_scaled'
                    )):
  """Extracts a dict of numpy arrays from a dataframe.

  For each name in `cols`:
    * if a column with exactly that name exists, its values are returned as
      a 1-D array;
    * otherwise all columns named `<name>_<digits>` (case-insensitive, e.g.
      'x_0', 'x_1', ..., 'x_12') are stacked, in dataframe column order, into
      a 2-D array of shape (num_rows, num_matching_columns);
    * names with no matching column at all are omitted from the result.

  Args:
    samples_df: pandas DataFrame to read from.
    cols: Iterable of column names / column-name prefixes to extract.

  Returns:
    Dict mapping each found name to a numpy array.
  """
  result = {}
  for col in cols:
    if col in samples_df.columns:
      result[col] = samples_df[col].values
      continue
    # Raw string: a bare '\d' in a normal string literal is an invalid escape
    # sequence (SyntaxWarning on Python 3.12+). '\d+' also accepts suffixes
    # >= 10, which a single '\d' silently dropped; re.escape guards against
    # regex metacharacters in the column name.
    pattern = re.compile(rf'^{re.escape(col)}_\d+$', re.IGNORECASE)
    matching_columns = [
        name for name in samples_df.columns if pattern.match(name)
    ]
    if matching_columns:
      result[col] = samples_df[matching_columns].to_numpy()
  return result


# Pool the data splits and discretize the covariates.
def discretize_data(data_dict, is_univariate=True, *, n_clusters=2,
                    quantiles=(0.25, 0.5, 0.75)):
  """Pools source/target splits and discretizes X (and C when multivariate).

  The source and target train+val splits are concatenated; the target test
  split is kept separate. The discretization of X is fitted on the pooled
  source X only and then applied unchanged to both target splits.

  Args:
    data_dict: Nested dict {'source'|'target': {'train'|'val'|'test': data}},
      e.g. as built from `extract_from_df_nested`. Reads the keys 'x', 'y',
      'c', 'w_binary' and 'u' (source) and 'x', 'y' (target).
    is_univariate: If True, X is binned with np.digitize at the empirical
      `quantiles` of source X (len(quantiles) + 1 bins). Otherwise X is
      clustered with KMeans into `n_clusters` groups and each row of the
      (assumed 2-D, multi-hot) C is collapsed into one integer by reading it
      as a binary number, first column being the most significant bit.
    n_clusters: Number of KMeans clusters (multivariate case only).
    quantiles: Quantile levels used as bin edges (univariate case only).

  Returns:
    Tuple (X, C, Y, W, U, X_shift, Y_shift, X_shift_test, Y_shift_test):
    pooled source arrays, pooled target train+val arrays, then the target
    test arrays.
  """
  def _pooled(domain, key):
    # Concatenates the train and val splits of one domain for one variable.
    return np.concatenate(
        (data_dict[domain]['train'][key], data_dict[domain]['val'][key]))

  X = _pooled('source', 'x')
  Y = _pooled('source', 'y')
  C = _pooled('source', 'c')
  W = _pooled('source', 'w_binary')
  U = _pooled('source', 'u')
  X_shift = _pooled('target', 'x')
  Y_shift = _pooled('target', 'y')
  X_shift_test = data_dict['target']['test']['x']
  Y_shift_test = data_dict['target']['test']['y']

  if is_univariate:
    edges = np.quantile(X, quantiles)
    X = np.digitize(X, edges)
    X_shift = np.digitize(X_shift, edges)
    X_shift_test = np.digitize(X_shift_test, edges)
  else:
    # n_init is pinned to scikit-learn's historical default (10): since 1.4
    # the default is 'auto' (a single k-means++ run), which would change the
    # clusters and therefore the reported numbers across library versions.
    kmeans = KMeans(n_clusters=n_clusters, n_init=10, random_state=0).fit(X)
    X = kmeans.labels_
    X_shift = kmeans.predict(X_shift)
    X_shift_test = kmeans.predict(X_shift_test)
    # Treat each multi-hot row of C as a binary number; with 3 columns the
    # weights are [4, 2, 1].
    bit_weights = 2 ** np.arange(C.shape[1])[::-1]
    C = np.dot(C, bit_weights)
  return X, C, Y, W, U, X_shift, Y_shift, X_shift_test, Y_shift_test


def estimate_q_y_x(X, C, Y, W, U, X_shift, Y_shift, X_shift_test, Y_shift_test):
  """Estimates the target-domain conditional q(Y | X) from discretized data.

  Steps, as implemented below:
    1. Fix a reference value c of C and y of Y (the most frequent ones).
    2. Build A = [[1, p(W|c)], [p(X|c), p(X, W|c)]] and
       B = [[p(y|c), p(y, W|c)], [p(y, X|c), p(y, X, W|c)]], each dropping
       the last value of X and of W, from source samples.
    3. Eigen-decompose pinv(A) @ B; the rows of the inverse eigenvector
       matrix, rescaled to start with 1, give p(W | U) (the last W value is
       filled in as 1 - sum of the others).
    4. Recover p(U | X) = pinv(p(W | U)) p(W | X) and the ratio
       q(U) / p(U) from q(X) / p(X).
    5. For each x, form p(Y | x, U) from p(Y | x, W) and re-weight it by
       p(U | x) q(U) / p(U), normalizing over Y.

  Args:
    X, C, Y, W: source samples (discretized). Presumably 1-D integer label
      arrays — TODO confirm against discretize_data's output.
    U: source latent labels; only the number of distinct values (k_U) is
      used here.
    X_shift, Y_shift: target samples; X_shift gives q(X) and both give an
      empirical q(Y | X).
    X_shift_test, Y_shift_test: held-out target samples for the reference
      empirical q(Y | X).

  Returns:
    Tuple (baseline_pYX, our_qYX, train_qYX, test_qYX) of (k_Y, k_X) arrays
    whose [y, x] entry is respectively: the source p(y | x); the proposed
    estimate of q(y | x); the empirical q(y | x) on X_shift/Y_shift; and the
    empirical q(y | x) on X_shift_test/Y_shift_test.

  NOTE(review): every loop compares labels against range(k_*) where
  k_* = len(np.unique(...)), i.e. it assumes X, W, Y (and C in the
  `C == k` loops) take exactly the values 0..k_*-1. Any empty slice makes
  np.mean return NaN (with a RuntimeWarning).
  """
  # Number of distinct values of each variable; `small` is the tolerance for
  # the "probabilities sum to one" assertions.
  k_W = len(np.unique(W))
  k_U = len(np.unique(U))
  k_X = len(np.unique(X))
  k_C = len(np.unique(C))
  k_Y = len(np.unique(Y))
  small = 1e-9

  # Fix a reference value of C and of Y: the most frequent one. np.unique is
  # sorted and `>` keeps the first maximum, so ties go to the smallest value.
  c_index = 0
  c_sum = 0
  for c in np.unique(C):
    if np.sum(C == c) > c_sum:
      c_sum = np.sum(C == c)
      c_index = c

  y_index = 0
  y_sum = 0
  for y in np.unique(Y):
    if np.sum(Y == y) > y_sum:
      y_sum = np.sum(Y == y)
      y_index = y

  # B
  # -
  # p(y | c)
  ix_c = (C == c_index)  # where C = c_index
  p_y_c = np.mean(Y[ix_c] == y_index)

  # p(y, W | c) = p(W | y , c)p(y | c)
  p_y_W_c = np.zeros((k_W,))
  # Product of boolean masks acts as a logical AND.
  ix_c_y = ix_c * (Y == y_index)  # where C = c_index AND Y = y_index
  for i in range(k_W):
    p_y_W_c[i] = np.mean(W[ix_c_y] == i) * p_y_c

  # p(y, X | c) = p(X | y, c)p(y | c)
  p_y_X_c = np.zeros((k_X,))
  p_X_y_c = np.zeros((k_X,))
  # for later: p(Y | X) over all source samples (NOT conditioned on c);
  # returned as the baseline estimate.
  p_Y_X = np.zeros((k_Y, k_X))
  for i in range(k_X):
    p_X_y_c[i] = np.mean(X[ix_c_y] == i)
    p_y_X_c[i] = p_X_y_c[i] * p_y_c
    for j in range(k_Y):
      p_Y_X[j, i] = np.mean(Y[(X == i)] == j)

  assert np.abs(np.sum(p_X_y_c) - 1) < small

  # p(y, X, W | c) = p(W | y, X, c)p(X | y, c)p(y | c)
  p_y_X_W_c = np.zeros((k_X, k_W))
  # for later: p(y | X, W, C)
  # NOTE(review): p_y_X_W_C is filled but never read afterwards.
  p_y_X_W_C = np.zeros((k_X, k_W, k_C))
  # for later: p(Y | x, W)
  p_Y_X_W = np.zeros((k_Y, k_X, k_W))
  for i in range(k_X):
    ix_x_c_y = ix_c_y * (X == i)  # where C = c_index AND Y = y_index AND X = i
    for j in range(k_W):
      p_y_X_W_c[i, j] = np.mean(W[ix_x_c_y] == j) * p_X_y_c[i] * p_y_c
      for k in range(k_C):
        p_y_X_W_C[i, j,
                  k] = np.mean(Y[(X == i) * (W == j) * (C == k)] == y_index)
      for k in range(k_Y):
        p_Y_X_W[k, i, j] = np.mean(Y[(X == i) * (W == j)] == k)

  # B: first row/column hold the y-marginals given c; the last X and W
  # values are dropped.
  B = np.zeros((k_X, k_W))
  B[0, 0] = p_y_c
  B[0, 1:] = p_y_W_c[:-1]
  B[1:, 0] = p_y_X_c[:-1]
  B[1:, 1:] = p_y_X_W_c[:-1, :-1]

  assert np.isnan(B).any() == False

  # A
  # -
  # p(W | c)
  p_W_c = np.zeros((k_W,))
  for i in range(k_W):
    p_W_c[i] = np.mean(W[ix_c] == i)

  assert np.abs(np.sum(p_W_c) - 1) < small

  # p(X | c)
  p_X_c = np.zeros((k_X,))
  # for later: p(X) on source, q(X) on target (X_shift)
  p_X = np.zeros((k_X,))
  q_X = np.zeros((k_X,))
  for i in range(k_X):
    p_X[i] = np.mean(X == i)
    q_X[i] = np.mean(X_shift == i)
    p_X_c[i] = np.mean(X[ix_c] == i)

  assert np.abs(np.sum(p_X_c) - 1) < small

  # p(X, W | c) = p(W | X, c)p(X | c)
  p_X_W_c = np.zeros((k_X, k_W))
  # for later: p(W | X, C), p(W | X)
  # NOTE(review): p_W_X_C is only read by the assertion below.
  p_W_X_C = np.zeros((k_W, k_X, k_C))
  p_W_X = np.zeros((k_W, k_X))
  for i in range(k_X):
    ix_x_c = ix_c * (X == i)
    for j in range(k_W):
      p_X_W_c[i, j] = np.mean(W[ix_x_c] == j) * p_X_c[i]
      p_W_X[j, i] = np.mean(W[X == i] == j)
      for k in range(k_C):
        p_W_X_C[j, i, k] = np.mean(W[(C == k) * (X == i)] == j)

  # NOTE(review): this fails (NaN sum) if no source sample has C == 0 and
  # X == 0, e.g. when the encoded C values do not include 0.
  assert (np.abs(np.sum(p_W_X_C[:, 0, 0]) - 1) < small)

  # A: same layout as B, with 1 in the corner and the (X, W) marginals
  # given c in the first row/column.
  A = np.zeros((k_X, k_W))
  A[0, 0] = 1.0
  A[0, 1:] = p_W_c[:-1]
  A[1:, 0] = p_X_c[:-1]
  A[1:, 1:] = p_X_W_c[:-1, :-1]

  assert np.isnan(A).any() == False

  # take pseudo-inverse of A
  # if k_X > k_W and A has full column rank, A^+ = (A'A)^{-1}A'
  A_inv = np.linalg.pinv(A)
  AiB = np.dot(A_inv, B)  # (k_W, k_W)
  # NOTE(review): np.linalg.eig may return complex eigenvectors for a
  # non-symmetric AiB; everything derived from H would then be complex.
  # The eigenvalues (Delta) are not used.
  Delta, H = np.linalg.eig(AiB)

  # H_inv is (k_W, k_W); its rows are read as one latent value each, which
  # implicitly assumes k_U == k_W (also required by the reshape to
  # (1, k_U) in the loop below).
  H_inv = np.linalg.pinv(H)  # H_inv is (k_U, k_W)
  # get scaling vector e: rescales each row of H_inv so its first entry is 1
  e = 1 / H_inv[:, 0]
  # multiply by H_inv to get S
  S = np.dot(np.diag(e), H_inv)
  # Columns 1: of S are p(W = w | U) for all but the last w; the last one is
  # filled in so that each distribution sums to 1.
  p_W_U = np.column_stack(
      (S[:, 1:], 1 - np.sum(S[:, 1:], axis=1))).transpose()  # (k_W, k_U)

  # p(U | X) = p(W | U)^{-1}p(W | X)
  # -----------------------------------
  p_U_X = np.dot(np.linalg.pinv(p_W_U), p_W_X)
  assert np.isnan(p_U_X).any() == False

  # q(U)/p(U) = p(U | X)^{-1}[q(X)/p(X)]
  # ------------------------------------
  # q(x)/p(x) = sum_u p(u | x) q(u)/p(u), i.e. r_X = p_U_X' r_U, hence
  # r_U = r_X pinv(p_U_X).
  q_U_p_U = np.dot(q_X / p_X, np.linalg.pinv(p_U_X))
  assert np.isnan(q_U_p_U).any() == False
  # p(Y | x, U) = p(Y | x, W)p(U | x, W)^{-1}
  # p(U | x, W) = (p(W | U) * p(U | x)) / p(W | x)
  q_Y_X_ours = np.zeros((k_Y, k_X))
  for i in range(k_X):
    # p(U | x, W) = (p(W | U) * p(U | x)) / p(W | x)
    # ----------------------------------------------
    # p(W | U), from above
    # p(U | x), from above
    # p(W | x), from above

    p_U_x_W = ((p_W_U * p_U_X[:, i].reshape((1, k_U))) / p_W_X[:, i].reshape(
        (k_W, 1))).transpose()

    # p(Y | x, U) = p(Y | x, W)p(U | x, W)^{-1}
    # -----------------------------------------
    # p(Y | x, W), from above
    p_Y_x_U = np.dot(p_Y_X_W[:, i, :], np.linalg.pinv(p_U_x_W))

    # q(Y | x) = p(Y | x, U) * p(U | x) * q(U)/p(U) / Z, Z is a normalizing constant to ensure np.dot(1', q(Y | x)) = 1
    # -----------------------------------------------------------------------------------------------------------------
    unnorm = np.dot(p_Y_x_U, np.dot(np.diag(q_U_p_U,), p_U_X[:, i]))
    q_Y_X_ours[:, i] = unnorm / unnorm.sum()  # normalize

  # baseline: source p(Y | X) (p_Y_X above, not conditioned on c).
  # reference: empirical target q(Y | X) from the target train+val samples
  # (q_Y_X) and from the held-out target test samples (q_Y_X_test).
  # An empty target cluster yields a NaN column.
  q_Y_X = np.zeros((k_Y, k_X))
  q_Y_X_test = np.zeros((k_Y, k_X))

  for i in range(k_Y):
    for j in range(k_X):
      q_Y_X[i, j] = np.mean(Y_shift[(X_shift == j)] == i)
      q_Y_X_test[i, j] = np.mean(Y_shift_test[(X_shift_test == j)] == i)

  baseline_pYX = p_Y_X
  our_qYX = q_Y_X_ours
  train_qYX = q_Y_X
  test_qYX = q_Y_X_test

  return baseline_pYX, our_qYX, train_qYX, test_qYX

## W=1 (high noise)

In [None]:
# Load the w_coeff = 1 source/target CSVs and split them by domain/partition.
folder_id = './tmp_data'
filename_source = "synthetic_multivariate_num_samples_10000_w_coeff_1_p_u_0_0.9.csv"
filename_target = "synthetic_multivariate_num_samples_10000_w_coeff_1_p_u_0_0.1.csv"
data_df_source, data_df_target = (
    pd.read_csv(os.path.join(folder_id, name))
    for name in (filename_source, filename_target))
data_dict_source, data_dict_target = map(
    extract_from_df_nested, (data_df_source, data_df_target))
data_dict_all_w1 = {'source': data_dict_source, 'target': data_dict_target}

In [None]:
# Discretize, estimate q(Y | X), and report the RMSE of each estimate against
# the empirical target-test q(Y | X): baseline, ours, target-train sampling.
X, C, Y, W, U, X_shift, Y_shift, X_shift_test, Y_shift_test = discretize_data(
    data_dict_all_w1, is_univariate=False)
baseline_pYX, our_qYX, train_qYX, test_qYX = estimate_q_y_x(
    X, C, Y, W, U, X_shift, Y_shift, X_shift_test, Y_shift_test)
rmse_baseline, rmse_ours, rmse_sampling = (
    rmse(estimate, test_qYX)
    for estimate in (baseline_pYX, our_qYX, train_qYX))
print(rmse_baseline, rmse_ours, rmse_sampling, sep='\n')

## W=3 (high noise)

In [None]:
# Load the w_coeff = 3 source/target CSVs and split them by domain/partition.
folder_id = "./tmp_data"
filename_source = "synthetic_multivariate_num_samples_10000_w_coeff_3_p_u_0_0.9.csv"
filename_target = "synthetic_multivariate_num_samples_10000_w_coeff_3_p_u_0_0.1.csv"
data_df_source, data_df_target = (
    pd.read_csv(os.path.join(folder_id, name))
    for name in (filename_source, filename_target))
data_dict_source, data_dict_target = map(
    extract_from_df_nested, (data_df_source, data_df_target))
data_dict_all_w3 = {'source': data_dict_source, 'target': data_dict_target}

In [None]:
# Discretize, estimate q(Y | X), and report the RMSE of each estimate against
# the empirical target-test q(Y | X): baseline, ours, target-train sampling.
X, C, Y, W, U, X_shift, Y_shift, X_shift_test, Y_shift_test = discretize_data(
    data_dict_all_w3, is_univariate=False)
baseline_pYX, our_qYX, train_qYX, test_qYX = estimate_q_y_x(
    X, C, Y, W, U, X_shift, Y_shift, X_shift_test, Y_shift_test)
rmse_baseline, rmse_ours, rmse_sampling = (
    rmse(estimate, test_qYX)
    for estimate in (baseline_pYX, our_qYX, train_qYX))
print(rmse_baseline, rmse_ours, rmse_sampling, sep='\n')