In [None]:
from collections import defaultdict
from itertools import product

import numpy as np
import matplotlib.pyplot as plt
#import ot
import pandas as pd
import seaborn as sns
import sklearn
import time
import warnings

import condo



In [None]:
# Simulation study: one continuous feature, one categorical confounder
# ('hotdog' / 'not_hotdog').
# For every sample size N and every repetition we simulate a target batch T and
# a source batch S with different confounder proportions, apply an affine batch
# effect to S, and score each adapter by the rMSE between its corrected source
# and the batch-effect-free source. Fit + transform time is recorded as well.

time.perf_counter()  # warm-up call before the first timed section
num_random = 10  # number of repetitions; repetition `rix` uses RandomState(rix)
mmd_size = 20  # forwarded to both MMD-based adapters below

Ns = [10, 20, 50, 100, 200, 500]
Ns_np = np.array(Ns)
methods = [
    "source - true (unobserved)",
    "source - batch-effected",
    "source - Gaussian OT",
    "source - MMD",
    "source - ConDo Gaussian KLD", 
    "source - ConDo MMD",
]
mcolors = plt.rcParams['axes.prop_cycle'].by_key()['color'][0:len(methods)]
# method -> (num_Ns, num_random) array
rmses = {method: np.zeros((len(Ns), num_random)) for method in methods}
times = {method: np.zeros((len(Ns), num_random)) for method in methods}

# How confounder X affects the distribution of T and S
mu_hotdog = 5.
sigma_hotdog = 1.0
mu_not = 0.0
sigma_not = 2.0

# How batch effect affects S
batch_m = 2 * np.ones((1, 1))
batch_b = 5 * np.ones((1, 1))
# The true batch correction from Sbatch to S
true_m = 1. / batch_m
true_b = -1 * batch_b / batch_m


def _labelled_frame(values, confounder, batch_label):
    """Build a frame with 'feature', 'confounder' and a constant 'batch' column.

    `values` and `confounder` are passed straight to `pd.DataFrame` as
    single-column data; `batch_label` is written into every row of 'batch'.
    """
    frame = pd.concat([
        pd.DataFrame(values, columns=['feature']),
        pd.DataFrame(confounder, columns=['confounder']),
    ], axis=1)
    frame["batch"] = batch_label
    return frame


def _rmse(reference, estimate):
    """Root-mean-squared difference between the 'feature' columns of two frames."""
    return ((reference.feature - estimate.feature) ** 2).mean() ** 0.5


for rix in range(num_random):
    for nix, N in enumerate(Ns):
        print(f"running N:{N} rix:{rix}") 
        # The seed depends on the repetition only, so one repetition starts
        # from the same random stream at every N.
        rng = np.random.RandomState(rix)

        N_T = N
        N_S = N

        # Target is roughly 3/4 'hotdog', source roughly 1/4 'hotdog'.
        n_hotdogT = round(3 * N_T / 4)
        n_notT = N_T - n_hotdogT
        n_hotdogS = round(N_S / 4)
        n_notS = N_S - n_hotdogS
        X_T = np.array(['hotdog'] * n_hotdogT + ['not_hotdog'] * n_notT).reshape((N_T, 1))
        X_S = np.array(['hotdog'] * n_hotdogS + ['not_hotdog'] * n_notS).reshape((N_S, 1))
        hotdog_T = X_T[:, 0] == 'hotdog'
        hotdog_S = X_S[:, 0] == 'hotdog'

        Strue = np.nan * np.ones((N_S, 1))
        T = np.nan * np.ones((N_T, 1))
        # Draw order (S hotdog, T hotdog, S not, T not) is kept fixed so the
        # seeded stream yields the same samples as before.
        Strue[hotdog_S, 0] = rng.normal(mu_hotdog, sigma_hotdog, size=(n_hotdogS))
        T[hotdog_T, 0] = rng.normal(mu_hotdog, sigma_hotdog, size=(n_hotdogT))
        Strue[~hotdog_S, 0] = rng.normal(mu_not, sigma_not, size=(n_notS))
        T[~hotdog_T, 0] = rng.normal(mu_not, sigma_not, size=(n_notT))
        Sbatch = Strue @ batch_m.T + batch_b

        dfT = _labelled_frame(T, X_T, "target")

        dfStrue = _labelled_frame(Strue, X_S, "source - true (unobserved)")
        # Reference compared with itself: zero by construction.
        rmses["source - true (unobserved)"][nix, rix] = _rmse(dfStrue, dfStrue)

        dfSbatch = _labelled_frame(Sbatch, X_S, "source - batch-effected")
        rmses["source - batch-effected"][nix, rix] = _rmse(dfStrue, dfSbatch)

        # Each timed section covers fit + transform; adapter construction is
        # excluded. perf_counter() is monotonic and high-resolution, which
        # matters for the very short fits at small N.
        lter = condo.AdapterGaussianOT()
        start_time = time.perf_counter()
        lter.fit(Xs=Sbatch, Xt=T)
        Sotda = lter.transform(Xs=Sbatch)
        times["source - Gaussian OT"][nix, rix] = time.perf_counter() - start_time
        dfSotda = _labelled_frame(Sotda, X_S, "source - Gaussian OT")
        rmses["source - Gaussian OT"][nix, rix] = _rmse(dfStrue, dfSotda)

        nder = condo.AdapterMMD(
            transform_type="location-scale",
            n_epochs=100,
            learning_rate=1e-2,
            mmd_size=mmd_size,
            verbose=0,
        )
        start_time = time.perf_counter()
        nder.fit(Sbatch, T)
        Smmd = nder.transform(Sbatch)
        times["source - MMD"][nix, rix] = time.perf_counter() - start_time
        dfSmmd = _labelled_frame(Smmd, X_S, "source - MMD")
        rmses["source - MMD"][nix, rix] = _rmse(dfStrue, dfSmmd)

        # The ConDo adapters additionally receive the confounders of both batches.
        cder = condo.ConDoAdapterKLD(
            transform_type="location-scale",
            verbose=0,
        )
        start_time = time.perf_counter()
        cder.fit(Sbatch, T, X_S, X_T)
        Sclinear = cder.transform(Sbatch)
        times["source - ConDo Gaussian KLD"][nix, rix] = time.perf_counter() - start_time
        dfSclinear = _labelled_frame(Sclinear, X_S, "source - ConDo Gaussian KLD")
        rmses["source - ConDo Gaussian KLD"][nix, rix] = _rmse(dfStrue, dfSclinear)

        cder = condo.ConDoAdapterMMD(
            transform_type="location-scale",
            n_epochs=100,
            learning_rate=1e-2,
            mmd_size=mmd_size,
            verbose=0,
        )
        start_time = time.perf_counter()
        cder.fit(Sbatch, T, X_S, X_T)
        Scmmd = cder.transform(Sbatch)
        times["source - ConDo MMD"][nix, rix] = time.perf_counter() - start_time
        dfScmmd = _labelled_frame(Scmmd, X_S, "source - ConDo MMD")
        rmses["source - ConDo MMD"][nix, rix] = _rmse(dfStrue, dfScmmd)

        df = pd.concat(
            [dfT, dfStrue, dfSbatch, dfSotda, dfSmmd, dfSclinear, dfScmmd], axis=0)

        # Only the first repetition's per-N figure is written to disk, so the
        # strip plot is built only then instead of being drawn and discarded.
        if rix == 0:
            fig = plt.figure(dpi=150, figsize=(6, 3))
            figname = f"figure-categorical1d-{N}-{rix}.pdf"
            sns.stripplot(
                x="confounder",
                y="feature",
                hue="batch",
                jitter=0.3,
                dodge=True,
                s=3,
                data=df)
            plt.legend(
                title='', loc="center left", bbox_to_anchor=(1.05, 0.5),
                frameon=False,
            )
            plt.tight_layout()
            fig.savefig(figname, bbox_inches="tight")
            plt.close(fig)

In [None]:
# Summary figures over all repetitions: for each method, the mean across the
# repetition axis (axis 1) is plotted against N with the standard deviation as
# error bars. One figure for rMSE, one for fit+transform time.
summary_specs = [
    # (per-method results, y-axis label, log-scale y?, methods left out, filename stem)
    (rmses, 'rMSE', False,
     ["source - true (unobserved)"], "rmses"),
    (times, 'time (s)', True,
     ["source - true (unobserved)", "source - batch-effected"], "times"),
]

for results, ylabel, log_y, skipped, stem in summary_specs:
    fig, ax = plt.subplots(dpi=150, figsize=(5, 3))
    for method, mcolor in zip(methods, mcolors):
        if method in skipped:
            continue
        # Label each series directly so the legend cannot drift out of sync
        # with the set of methods that is actually drawn.
        ax.errorbar(
            Ns_np, np.mean(results[method], axis=1), yerr=np.std(results[method], axis=1),
            marker='x', color=mcolor, label=method,
        )
    ax.set_ylabel(ylabel)
    ax.set_xlabel('N')
    ax.set_xscale('log')
    if log_y:
        ax.set_yscale('log')
    ax.set_xticks(Ns)
    ax.set_xticklabels(Ns)
    ax.minorticks_off()
    ax.legend(
        title='', loc="center left", bbox_to_anchor=(1.05, 0.5),
        frameon=False,
    )
    # `rix` is the loop variable left over from the simulation cell (its last
    # value); it is kept in the name so the output filenames stay unchanged.
    # NOTE(review): the following plotting cell saves to these same filenames,
    # so on a full run these PDFs are overwritten by its versions.
    figname = f"figure-categorical1d-{stem}-{rix}.pdf"
    fig.savefig(figname, bbox_inches="tight")

In [None]:
# Same summary figures restricted to the four adapters (methods[2:]), drawn
# with a fixed colour per adapter.
# The colours live in a cell-local name: rebinding the global `mcolors` here
# would leave the earlier plotting cell, which zips `mcolors` against all
# `methods`, with too few colours if it were re-run after this cell.
adapter_methods = methods[2:]
adapter_colors = ["blue", "red", "cyan", "magenta"]

adapter_specs = [
    # (per-method results, y-axis label, log-scale y?, filename stem)
    (rmses, 'rMSE', False, "rmses"),
    (times, 'time (s)', True, "times"),
]

for results, ylabel, log_y, stem in adapter_specs:
    fig, ax = plt.subplots(dpi=150, figsize=(5, 3))
    for method, mcolor in zip(adapter_methods, adapter_colors):
        # Mean over repetitions (axis 1) with the standard deviation as error
        # bar; the label feeds the legend directly.
        ax.errorbar(
            Ns_np, np.mean(results[method], axis=1), yerr=np.std(results[method], axis=1),
            marker='x', color=mcolor, label=method,
        )
    ax.set_ylabel(ylabel)
    ax.set_xlabel('N')
    ax.set_xscale('log')
    if log_y:
        ax.set_yscale('log')
    ax.set_xticks(Ns)
    ax.set_xticklabels(Ns)
    ax.minorticks_off()
    ax.legend(
        title='', loc="center left", bbox_to_anchor=(1.05, 0.5),
        frameon=False,
    )
    # `rix` is the loop variable left over from the simulation cell (its last
    # value); it is kept in the name so the output filenames stay unchanged.
    # NOTE(review): these are the same filenames the previous plotting cell
    # writes, so on a full run this cell's figures replace the earlier PDFs.
    figname = f"figure-categorical1d-{stem}-{rix}.pdf"
    fig.savefig(figname, bbox_inches="tight")