In [1]:
import matplotlib.pyplot as plt
import numpy as np
import ot
import seaborn as sns
import time
import warnings

import condo

In [2]:
# Compact publication styling: short tick marks, thin axis spines, and a
# uniform 7-pt font for body text, titles, axis labels, tick labels and legends.
plt.rcParams.update({
    "xtick.major.size": 2,
    "ytick.major.size": 2,
    "axes.linewidth": 0.5,
    "font.size": 7,          # default text size
    "axes.titlesize": 7,     # axes title
    "axes.labelsize": 7,     # x and y labels
    "xtick.labelsize": 7,    # x tick labels
    "ytick.labelsize": 7,    # y tick labels
    "legend.fontsize": 7,    # legend text
})

In [None]:
# Simulation: a 1-D continuous feature is observed in a target batch (T) and a
# source batch (S). Both depend on a confounder X; the source is additionally
# corrupted by an affine batch effect Sbatch = batch_m * Strue + batch_b.
# Columns of the figure are data-generating settings; rows compare corrections:
#   row 0: none (Sbatch as observed)       row 1: oracle inverse batch effect
#   row 2: Gaussian (linear) OT via POT    row 3: ConDo, heteroscedastic GP + reverse KL
#   row 4: ConDo, empirical model + MMD
# Each panel is annotated with the rMSE between the (corrected) source and Strue.

# One seeded generator for *all* random draws so the figure is reproducible.
rng = np.random.RandomState(42)

N_T = 200
N_S = 100

# How batch effect affects S
batch_m = 2
batch_b = 5

# The true batch correction from Sbatch to S (inverse of the affine batch effect)
true_m = 1. / batch_m
true_b = -1 * batch_b / batch_m

noise_setting = "NoiseFree"
assert noise_setting in ("Noisy", "NoiseFree")

targetshift_setting = "TargetShift"
assert targetshift_setting in ("NoTargetShift", "TargetShift")

featureshift_setting = "FeatureShift"
assert featureshift_setting in ("NoFeatureShift", "FeatureShift")

# Oracle correction that actually applies under the chosen setting: without a
# feature shift Sbatch is Strue (plus optional noise), so the oracle is the identity.
if featureshift_setting == "FeatureShift":
    oracle_m, oracle_b = true_m, true_b
else:
    oracle_m, oracle_b = 1., 0.

# Per-problem parameters (theta_m, theta_b, phi_m, phi_b, g): the confounder X
# enters through g(X); the feature is Normal with mean theta_m * g(X) + theta_b
# and standard deviation phi_m * g(X) + phi_b.
prob_params = {
    "homoscedastic-linear": (4, 1, 0, 2, lambda x: x),
    "heteroscedastic-linear": (4, 1, 1, 1, lambda x: x),
    "nonlinear": (4, 1, 1, 1, lambda x: np.maximum(x - 5, 0) ** 2),
}
prob_settings = ["homoscedastic-linear", "heteroscedastic-linear", "nonlinear"]
num_probs = len(prob_settings)

fig, axes = plt.subplots(
    nrows=5, ncols=num_probs, sharex="all", sharey="all", squeeze=False,
    gridspec_kw={"hspace": 0.03, "wspace": 0.03},
    figsize=(6, 8), dpi=150)
msize = 1
basic_leg = ['target', 'source: true (unobserved)', 'source: batch-effected']
figname = f"figure-continuous1d-{noise_setting}-{targetshift_setting}-{featureshift_setting}.pdf"
fsizemse = 5


def _sample_feature(prob_setting, X):
    """Draw one feature value per confounder value in X for `prob_setting`.

    Raises KeyError for an unknown setting instead of silently reusing data
    left over from a previous loop iteration.
    """
    theta_m, theta_b, phi_m, phi_b, g = prob_params[prob_setting]
    gX = g(X)
    return rng.normal(theta_m * gX + theta_b, phi_m * gX + phi_b)


def _rmse(estimate, truth):
    """Root-mean-squared error between two equally sized arrays.

    Both inputs are flattened first: subtracting an (N,) array from an (N, 1)
    array would otherwise broadcast to an (N, N) matrix and give a
    meaningless number.
    """
    estimate = np.ravel(estimate)
    truth = np.ravel(truth)
    assert estimate.shape == truth.shape, (estimate.shape, truth.shape)
    return np.sqrt(np.mean((estimate - truth) ** 2))


def _plot_panel(ax, X_T, T, X_S, Strue, Sbatch,
                Sadapted=None, adapted_label=None, show_legend=False):
    """Scatter target, true source, batch-effected source and (optionally) an
    adapted source against the confounder, and annotate the rMSE to Strue.

    Without `Sadapted` the annotated rMSE is that of the uncorrected Sbatch.
    The legend is drawn only when `show_legend` is true (rightmost column).
    """
    ax.tick_params(axis="both", which="both", direction="in")
    ax.scatter(X_T, T, s=msize)
    ax.scatter(X_S, Strue, s=msize)
    ax.scatter(X_S, Sbatch, s=msize)
    labels = list(basic_leg)
    estimate = Sbatch
    if Sadapted is not None:
        ax.scatter(X_S, Sadapted, s=msize)
        labels.append(adapted_label)
        estimate = Sadapted
    ax.text(0.5, 70, f"rMSE: {_rmse(estimate, Strue):.3f}", size=fsizemse)
    if show_legend:
        ax.legend(
            labels,
            loc="center right", bbox_to_anchor=(2.2, 0.5),
            frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2)


def _condo_transform(Sbatch_, T_, X_S_, X_T_, **adapter_kwargs):
    """Fit a condo.ConDoAdapter built from `adapter_kwargs` and return the
    transformed source; warnings raised inside ConDo are suppressed."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        cder = condo.ConDoAdapter(**adapter_kwargs)
        cder.fit(Sbatch_, T_, X_S_, X_T_)
        return cder.transform(Sbatch_)


for pix, prob_setting in enumerate(prob_settings):
    # Distribution of confounders (drawn from the seeded rng, not np.random)
    X_T = np.sort(rng.uniform(0, 8, size=(N_T,)))
    if targetshift_setting == "TargetShift":
        X_S = np.sort(rng.uniform(4, 8, size=(N_S,)))
    else:
        X_S = np.sort(rng.uniform(0, 8, size=(N_S,)))

    # How confounder X affects the distribution of T and S
    T = _sample_feature(prob_setting, X_T)
    Strue = _sample_feature(prob_setting, X_S)

    if featureshift_setting == "FeatureShift":
        Sbatch = batch_m * Strue + batch_b
    else:
        Sbatch = Strue.copy()
    if noise_setting == "Noisy":
        Sbatch = Sbatch + rng.normal(0, 1, size=(N_S,))

    # Column-vector (N, 1) views passed to POT and ConDo
    T_ = T.reshape(-1, 1)
    Strue_ = Strue.reshape(-1, 1)
    Sbatch_ = Sbatch.reshape(-1, 1)
    X_T_ = X_T.reshape(-1, 1)
    X_S_ = X_S.reshape(-1, 1)

    panel = dict(X_T=X_T, T=T, X_S=X_S, Strue=Strue, Sbatch=Sbatch,
                 show_legend=(pix == num_probs - 1))

    # Row 0: no correction
    _plot_panel(axes[0, pix], **panel)

    # Row 1: oracle inverse of the batch effect
    Soracle = oracle_m * Sbatch + oracle_b
    _plot_panel(axes[1, pix], Sadapted=Soracle,
                adapted_label='source: oracle-adapted', **panel)

    # Row 2: Gaussian (linear) OT map from Sbatch to T, ignoring confounders
    A_otda, b_otda = ot.da.OT_mapping_linear(Sbatch_, T_)
    Sotda_ = Sbatch_.dot(A_otda) + b_otda
    _plot_panel(axes[2, pix], Sadapted=Sotda_,
                adapted_label='source: Gaussian OT', **panel)

    # Row 3: ConDo using ReverseKL with heteroscedastic GP
    Sreverse = _condo_transform(
        Sbatch_, T_, X_S_, X_T_,
        sampling="proportional",
        transform_type="location-scale",
        model_type="heteroscedastic-gp",
        divergence="reverse",
        debug=False,
        verbose=1,
    )
    _plot_panel(axes[3, pix], Sadapted=Sreverse,
                adapted_label='source: ConDo GP-ReverseKL', **panel)

    # Row 4: ConDo using MMD
    Smmd = _condo_transform(
        Sbatch_, T_, X_S_, X_T_,
        sampling="proportional",
        transform_type="affine",
        model_type="empirical",
        divergence="mmd",
        optim_kwargs={"epochs": 100, "alpha": 0.01, "beta": 0.9},
        debug=False,
        verbose=1,
    )
    _plot_panel(axes[4, pix], Sadapted=Smmd,
                adapted_label='source: ConDo MMD', **panel)

    # Checkpoint after each finished column so an interrupted (long) run still
    # leaves a figure on disk; the last iteration writes the complete figure.
    fig.savefig(figname, bbox_inches="tight")

Optimization terminated successfully.
         Current function value: -427.666687
         Iterations: 16
         Function evaluations: 24
epoch:0 -0.32686->-0.53222 avg:-0.49543
epoch:1 -0.53802->-0.54825 avg:-0.53878
epoch:2 -0.53183->-0.54261 avg:-0.53921
epoch:3 -0.54298->-0.54089 avg:-0.53915
epoch:4 -0.53882->-0.54663 avg:-0.53939
epoch:5 -0.54728->-0.54635 avg:-0.53972
epoch:6 -0.54011->-0.54517 avg:-0.54083
epoch:7 -0.53405->-0.54401 avg:-0.54049
epoch:8 -0.54148->-0.53900 avg:-0.54100
epoch:9 -0.54860->-0.54669 avg:-0.54276
epoch:10 -0.54740->-0.55118 avg:-0.54189
epoch:11 -0.54379->-0.53839 avg:-0.54273
epoch:12 -0.53820->-0.54053 avg:-0.54367
epoch:13 -0.54467->-0.54964 avg:-0.54292
epoch:14 -0.54876->-0.54959 avg:-0.54344
epoch:15 -0.54414->-0.54101 avg:-0.54435
epoch:16 -0.54432->-0.54975 avg:-0.54429
epoch:17 -0.54418->-0.54631 avg:-0.54431
epoch:18 -0.54608->-0.55302 avg:-0.54512
epoch:19 -0.54703->-0.54568 avg:-0.54667
epoch:20 -0.54323->-0.54355 avg:-0.54557
epoch:21