In [1]:
from collections import defaultdict
from itertools import product

import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
import seaborn as sns
import time
import warnings

import condo

In [2]:
plt.rcParams["xtick.major.size"] = 2
plt.rcParams["ytick.major.size"] = 2
plt.rcParams['axes.linewidth'] = 0.5
plt.rc('font', size=7) #controls default text size
plt.rc('axes', titlesize=7) #fontsize of the title
plt.rc('axes', labelsize=7) #fontsize of the x and y labels
plt.rc('xtick', labelsize=7) #fontsize of the x tick labels
plt.rc('ytick', labelsize=7) #fontsize of the y tick labels
plt.rc('legend', fontsize=7) #fontsize of the legend
plt.ioff();

In [3]:
N_T = 200
N_S = 100
mmd_size = 20
# How batch effect affects S
batch_m = 2
batch_b = 5

# The true batch correction from Sbatch to S
true_m = 1. / batch_m
true_b = -1 * batch_b / batch_m

noise_settings = ["NoiseFree", "Noisy",]
targetshift_settings = ["TargetShift", "NoTargetShift",]
featureshift_settings = ["FeatureShift", "NoFeatureShift"]

prob_settings = ["Homoscedastic Linear", "Heteroscedastic Linear", "Nonlinear"]
#prob_settings = ["Nonlinear"]
num_probs = len(prob_settings)

rMSEs = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
rMSEs_test = defaultdict(lambda: defaultdict(lambda: defaultdict(list)))
num_random = 10
for rix in range(num_random):
    rng = np.random.RandomState(rix)
    for setting_combo in product(noise_settings, targetshift_settings, featureshift_settings):
        (noise_setting, targetshift_setting, featureshift_setting) = setting_combo
        setting_str = "-".join(setting_combo)
        assert noise_setting in ("Noisy", "NoiseFree")
        assert targetshift_setting in ("NoTargetShift", "TargetShift")
        assert featureshift_setting in ("NoFeatureShift", "FeatureShift")

        fig, axes = plt.subplots(
            nrows=5, ncols=num_probs, sharex="all", sharey="all", squeeze=False,
            gridspec_kw={"hspace": 0.03, "wspace": 0.03},
            figsize=(6, 7), dpi=150)
        msize = 1
        basic_leg = ['target', 'source: batch-effected', 'source: true (unobserved)']
        figname = f"figure-continuous1d-{setting_str}-{rix}.pdf"
        fsizemse = 7

        for pix, prob_setting in enumerate(prob_settings):
            print(f"rix:{rix} {setting_str} {prob_setting}")
            # Distribution of confounders
            X_T = np.sort(rng.uniform(0, 8, size=(N_T,)))
            if targetshift_setting == "TargetShift":
                X_S = np.sort(rng.uniform(4, 8, size=(N_S,)))
                X_S_test = np.sort(rng.uniform(0, 8, size=(N_S,)))
            elif targetshift_setting == "NoTargetShift":
                X_S = np.sort(rng.uniform(0, 8, size=(N_S,)))
                X_S_test = np.sort(rng.uniform(0, 8, size=(N_S,)))

            if prob_setting == "Homoscedastic Linear":
                # How confounder X affects the distribution of T and S
                theta_m = 4
                theta_b = 1
                phi_m = 0
                phi_b = 2
                mu_T = theta_m * X_T + theta_b
                mu_S = theta_m * X_S + theta_b
                mu_S_test = theta_m * X_S_test + theta_b
                sigma_T = phi_m * X_T + phi_b
                sigma_S = phi_m * X_S + phi_b
                sigma_S_test = phi_m * X_S_test + phi_b
                T = rng.normal(mu_T, sigma_T)
                Strue = rng.normal(mu_S, sigma_S)
                Strue_test = rng.normal(mu_S_test, sigma_S_test)
            elif prob_setting == "Heteroscedastic Linear":
                # How confounder X affects the distribution of T and S
                theta_m = 4
                theta_b = 1
                phi_m = 1
                phi_b = 1
                mu_T = theta_m * X_T + theta_b
                mu_S = theta_m * X_S + theta_b
                mu_S_test = theta_m * X_S_test + theta_b
                sigma_T = phi_m * X_T + phi_b
                sigma_S = phi_m * X_S + phi_b
                sigma_S_test = phi_m * X_S_test + phi_b
                T = rng.normal(mu_T, sigma_T)
                Strue = rng.normal(mu_S, sigma_S)
                Strue_test = rng.normal(mu_S_test, sigma_S_test)
            elif prob_setting == "Nonlinear":
                # How confounder X affects the distribution of T and S
                theta_m = 4
                theta_b = 1
                phi_m = 1
                phi_b = 1
                mu_T = theta_m * (np.maximum(X_T-5, 0) ** 2) + theta_b
                mu_S = theta_m * (np.maximum(X_S-5, 0) ** 2) + theta_b
                mu_S_test = theta_m * (np.maximum(X_S_test-5, 0) ** 2) + theta_b
                sigma_T = phi_m * (np.maximum(X_T-5, 0) ** 2) + phi_b
                sigma_S = phi_m * (np.maximum(X_S-5, 0) ** 2) + phi_b
                sigma_S_test = phi_m * (np.maximum(X_S_test-5, 0) ** 2) + phi_b
                T = rng.normal(mu_T, sigma_T)
                Strue = rng.normal(mu_S, sigma_S)
                Strue_test = rng.normal(mu_S_test, sigma_S_test)
            if featureshift_setting == "FeatureShift":
                Sbatch = batch_m * Strue + batch_b
                Sbatch_test = batch_m * Strue_test + batch_b
                oracle_m = true_m
                oracle_b = true_b
            elif featureshift_setting == "NoFeatureShift":
                oracle_m = 1.0
                oracle_b = 0.0
                Sbatch = Strue.copy()
                Sbatch_test = Strue_test.copy()

            if noise_setting == "NoiseFree":
                pass
            elif noise_setting == "Noisy":
                Sbatch = Sbatch + rng.normal(0, 1, size=(N_S,))
                Sbatch_test = Sbatch_test + rng.normal(0, 1, size=(N_S,))

            T_ = T.reshape(-1, 1) # (N_T, 1)
            Strue_ = Strue.reshape(-1, 1) # (N_S, 1)
            Sbatch_ = Sbatch.reshape(-1, 1)
            X_T_ = X_T.reshape(-1, 1) # (N_T, 1)
            X_S_ = X_S.reshape(-1, 1)
            Strue_test_ = Strue_test.reshape(-1, 1) # (N_S, 1)
            Sbatch_test_ = Sbatch_test.reshape(-1, 1)
            X_S_test_ = X_S_test.reshape(-1, 1)
            
            # Before correction
            method = "Before Correction"
            rMSE = np.sqrt(np.mean((Sbatch - Strue) ** 2));
            rMSEs[setting_str][method][prob_setting].append(rMSE)
            rMSE_test = np.sqrt(np.mean((Sbatch_test - Strue_test) ** 2));
            rMSEs_test[setting_str][method][prob_setting].append(rMSE_test)
            avgrMSE = np.mean(rMSEs[setting_str][method][prob_setting])
            avgrMSE_test = np.mean(rMSEs_test[setting_str][method][prob_setting])
            print(f"    {method} rMSE: {avgrMSE:.3f}  ({avgrMSE_test:.3f})")
            axes[0, pix].tick_params(axis="both", which="both", direction="in")    
            axes[0, pix].scatter(X_T, T, s=msize)
            axes[0, pix].scatter(X_S, Sbatch, s=msize)
            axes[0, pix].scatter(X_S, Strue, s=msize)
            axes[0, pix].text(
                0.1, 0.9, f"rMSE: {avgrMSE:.3f} ({avgrMSE_test:.3f})", size=fsizemse,
                transform = axes[0, pix].transAxes);
            if pix == num_probs - 1:
                axes[0, pix].legend(
                    basic_leg,
                    loc="center left", bbox_to_anchor=(1.05, 0.5),
                    frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);

            # Oracle - not displayed
            method = "Oracle"
            Soracle = oracle_m * Sbatch + oracle_b
            Soracle_test = oracle_m * Sbatch_test + oracle_b
            rMSE = np.sqrt(np.mean((Soracle - Strue) ** 2));
            rMSEs[setting_str][method][prob_setting].append(rMSE)
            rMSE_test = np.sqrt(np.mean((Soracle_test - Strue_test) ** 2));
            rMSEs_test[setting_str][method][prob_setting].append(rMSE_test)
            avgrMSE = np.mean(rMSEs[setting_str][method][prob_setting])
            avgrMSE_test = np.mean(rMSEs_test[setting_str][method][prob_setting])
            print(f"    {method} rMSE: {avgrMSE:.3f}  ({avgrMSE_test:.3f})")

            # OTDA
            method = "Gaussian OT"
            lter = condo.AdapterGaussianOT()
            lter.fit(Xs=Sbatch_, Xt=T_)
            Sotda_ = lter.transform(Xs=Sbatch_)
            Sotda_test_ = lter.transform(Xs=Sbatch_test_)           
            rMSE = np.sqrt(np.mean((Sotda_ - Strue_) ** 2));
            rMSEs[setting_str][method][prob_setting].append(rMSE)
            rMSE_test = np.sqrt(np.mean((Sotda_test_ - Strue_test_) ** 2));
            rMSEs_test[setting_str][method][prob_setting].append(rMSE_test)
            avgrMSE = np.mean(rMSEs[setting_str][method][prob_setting])
            avgrMSE_test = np.mean(rMSEs_test[setting_str][method][prob_setting])
            print(f"    {method} rMSE: {avgrMSE:.3f}  ({avgrMSE_test:.3f})")
            axes[1, pix].tick_params(axis="both", which="both", direction="in")    
            axes[1, pix].scatter(X_T, T, s=msize)
            axes[1, pix].scatter(X_S, Sbatch, s=msize)
            axes[1, pix].scatter(X_S, Strue, s=msize)
            axes[1, pix].scatter(X_S, Sotda_, s=msize)
            axes[1, pix].text(
                0.1, 0.9, f"rMSE: {avgrMSE:.3f} ({avgrMSE_test:.3f})", size=fsizemse,
                transform = axes[1, pix].transAxes);
            if pix == num_probs - 1:
                axes[1, pix].legend(
                    basic_leg + [f"source: {method}"],
                    loc="center left", bbox_to_anchor=(1.05, 0.5),
                    frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);

            # MMD
            method = "MMD"
            cder = condo.AdapterMMD(
                transform_type="location-scale",
                n_epochs=100,
                learning_rate=1e-2,
                mmd_size=mmd_size,
                verbose=0,
            )
            cder.fit(Sbatch_, T_)
            Smmd_ = cder.transform(Sbatch_)
            Smmd_test_ = cder.transform(Sbatch_test_)
            rMSE = np.sqrt(np.mean((Smmd_ - Strue_) ** 2));
            rMSEs[setting_str][method][prob_setting].append(rMSE)
            rMSE_test = np.sqrt(np.mean((Smmd_test_ - Strue_test_) ** 2));
            rMSEs_test[setting_str][method][prob_setting].append(rMSE_test)
            avgrMSE = np.mean(rMSEs[setting_str][method][prob_setting])
            avgrMSE_test = np.mean(rMSEs_test[setting_str][method][prob_setting]);
            print(f"    {method} rMSE: {avgrMSE:.3f}  ({avgrMSE_test:.3f})")
            axes[2, pix].tick_params(axis="both", which="both", direction="in")    
            axes[2, pix].scatter(X_T, T, s=msize)
            axes[2, pix].scatter(X_S, Sbatch, s=msize)
            axes[2, pix].scatter(X_S, Strue, s=msize)
            axes[2, pix].scatter(X_S, Smmd_, s=msize)
            axes[2, pix].text(
                0.1, 0.9, f"rMSE: {avgrMSE:.3f}  ({avgrMSE_test:.3f})", size=fsizemse,
                transform = axes[2, pix].transAxes);
            if pix == num_probs - 1:
                axes[2, pix].legend(
                    basic_leg + [f"source: {method}"],
                    loc="center left", bbox_to_anchor=(1.05, 0.5),
                    frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);

            # ConDo using ReverseKL with linear Gaussian
            method = "ConDo Gaussian KLD"
            cder = condo.ConDoAdapterKLD(
                transform_type="location-scale",
                verbose=0,
            )
            cder.fit(Sbatch_, T_, X_S_, X_T_)
            Sreverse_ = cder.transform(Sbatch_)
            Sreverse_test_ = cder.transform(Sbatch_test_)
            rMSE = np.sqrt(np.mean((Sreverse_ - Strue_) ** 2));
            rMSEs[setting_str][method][prob_setting].append(rMSE)
            rMSE_test = np.sqrt(np.mean((Sreverse_test_ - Strue_test_) ** 2));
            rMSEs_test[setting_str][method][prob_setting].append(rMSE_test)
            avgrMSE = np.mean(rMSEs[setting_str][method][prob_setting])
            avgrMSE_test = np.mean(rMSEs_test[setting_str][method][prob_setting])
            print(f"    {method} rMSE: {avgrMSE:.3f}  ({avgrMSE_test:.3f})")
            axes[3, pix].tick_params(axis="both", which="both", direction="in")    
            axes[3, pix].scatter(X_T, T, s=msize)
            axes[3, pix].scatter(X_S, Sbatch, s=msize)
            axes[3, pix].scatter(X_S, Strue, s=msize)
            axes[3, pix].scatter(X_S, Sreverse_, s=msize)
            axes[3, pix].text(
                0.1, 0.9, f"rMSE: {avgrMSE:.3f} ({avgrMSE_test:.3f})", size=fsizemse,
                transform = axes[3, pix].transAxes);
            if pix == num_probs - 1:
                axes[3, pix].legend(
                    basic_leg + [f"source: {method}"],
                    loc="center left", bbox_to_anchor=(1.05, 0.5),
                    frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);            

            # Condo using MMD
            method = "ConDo MMD"
            cder = condo.ConDoAdapterMMD(
                transform_type="location-scale",
                n_epochs=100,
                learning_rate=1e-2,
                mmd_size=mmd_size,
                verbose=0,
            )
            cder.fit(Sbatch_, T_, X_S_, X_T_)
            Smmd_ = cder.transform(Sbatch_)
            Smmd_test_ = cder.transform(Sbatch_test_)
            rMSE = np.sqrt(np.mean((Smmd_ - Strue_) ** 2));
            rMSEs[setting_str][method][prob_setting].append(rMSE)
            rMSE_test = np.sqrt(np.mean((Smmd_test_ - Strue_test_) ** 2));
            rMSEs_test[setting_str][method][prob_setting].append(rMSE_test)
            avgrMSE = np.mean(rMSEs[setting_str][method][prob_setting])
            avgrMSE_test = np.mean(rMSEs_test[setting_str][method][prob_setting]);
            print(f"    {method} rMSE: {avgrMSE:.3f}  ({avgrMSE_test:.3f})")
            axes[4, pix].tick_params(axis="both", which="both", direction="in")    
            axes[4, pix].scatter(X_T, T, s=msize)
            axes[4, pix].scatter(X_S, Sbatch, s=msize)
            axes[4, pix].scatter(X_S, Strue, s=msize)
            axes[4, pix].scatter(X_S, Smmd_, s=msize)
            axes[4, pix].text(
                0.1, 0.9, f"rMSE: {avgrMSE:.3f}  ({avgrMSE_test:.3f})", size=fsizemse,
                transform = axes[4, pix].transAxes);
            if pix == num_probs - 1:
                axes[4, pix].legend(
                    basic_leg + [f"source: {method}"],
                    loc="center left", bbox_to_anchor=(1.05, 0.5),
                    frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);
        if rix in (0, num_random - 1):
            fig.savefig(figname, bbox_inches="tight")
        plt.close()

rix:0 NoiseFree-TargetShift-FeatureShift Homoscedastic Linear
    Before Correction rMSE: 30.110  (22.992)
    Oracle rMSE: 0.000  (0.000)
    Gaussian OT rMSE: 9.049  (17.721)
    MMD rMSE: 5.160  (4.995)
    ConDo Gaussian KLD rMSE: 0.459  (1.218)
    ConDo MMD rMSE: 0.345  (0.919)
rix:0 NoiseFree-TargetShift-FeatureShift Heteroscedastic Linear
    Before Correction rMSE: 31.425  (26.066)
    Oracle rMSE: 0.000  (0.000)
    Gaussian OT rMSE: 8.538  (11.201)
    MMD rMSE: 7.983  (7.380)
    ConDo Gaussian KLD rMSE: 0.873  (1.374)
    ConDo MMD rMSE: 2.202  (2.164)
rix:0 NoiseFree-TargetShift-FeatureShift Nonlinear
    Before Correction rMSE: 17.877  (15.738)
    Oracle rMSE: 0.000  (0.000)
    Gaussian OT rMSE: 4.535  (4.148)
    MMD rMSE: 8.617  (7.574)
    ConDo Gaussian KLD rMSE: 1.870  (2.034)
    ConDo MMD rMSE: 0.945  (0.832)
rix:0 NoiseFree-TargetShift-NoFeatureShift Homoscedastic Linear
    Before Correction rMSE: 0.000  (0.000)
    Oracle rMSE: 0.000  (0.000)
    Gaussian OT 

In [4]:
s_dict = {}

for setting_str in rMSEs:
    #print(setting_str)
    #print(f" & " + " & ".join(prob_settings) + r" \\")
    cur_min = defaultdict(lambda: np.inf)
    cur_min_test = defaultdict(lambda: np.inf)
    cur_max = defaultdict(lambda: np.NINF)
    cur_max_test = defaultdict(lambda: np.NINF)
    for method in rMSEs[setting_str]:
        if method in ('Before Correction', 'Oracle'):
            continue
        method_str = f"{method}"
        for prob_setting in prob_settings:
            cur_rmse = np.mean(rMSEs[setting_str][method][prob_setting])
            cur_rmse_test = np.mean(rMSEs_test[setting_str][method][prob_setting])
            cur_min[prob_setting] = min(cur_min[prob_setting], cur_rmse)
            cur_min_test[prob_setting] = min(cur_min_test[prob_setting], cur_rmse)
            cur_max[prob_setting] = max(cur_max[prob_setting], cur_rmse)
            cur_max_test[prob_setting] = max(cur_max_test[prob_setting], cur_rmse)
    method_str_all = r""        
    for method in rMSEs[setting_str]:
        method_str = f"{method}"
        for prob_setting in prob_settings:
            cur_rmse = np.mean(rMSEs[setting_str][method][prob_setting])
            cur_rmse_test = np.mean(rMSEs_test[setting_str][method][prob_setting])
            is_min = cur_rmse == cur_min[prob_setting]
            is_max = cur_rmse == cur_max[prob_setting]
            is_min_test = cur_rmse == cur_min_test[prob_setting]
            is_max_test = cur_rmse == cur_max_test[prob_setting]
            mycolor = "black"
            if is_min:
                mycolor = "green"
            if is_max:
                mycolor = "red"
            if method in ('Before Correction', 'Oracle'):
                mycolor = "black"
            mycolortest = "black"
            if is_min_test:
                mycolortest = "green"
            if is_max_test:
                mycolortest = "red"
            if method in ('Before Correction', 'Oracle'):
                mycolortest = "black"
            train_str = " & {\\color{" + mycolor + "}" + f"{cur_rmse:.3f}" + "}"
            test_str = " ({\\color{" + mycolortest + "}" + f"{cur_rmse_test:.3f}" + "})"
            method_str += train_str + test_str
        method_str += r" \\"
        #print(method_str)
        method_str_all += method_str
    s_dict[setting_str] = method_str_all

In [5]:
plt.rcParams["xtick.major.size"] = 2
plt.rcParams["ytick.major.size"] = 2
plt.rcParams['axes.linewidth'] = 0.5
plt.rcParams['font.family'] = "sans-serif"
plt.rcParams['font.sans-serif'] = "Arial"

plt.rc('font', size=7) #controls default text size
plt.rc('axes', titlesize=7) #fontsize of the title
plt.rc('axes', labelsize=7) #fontsize of the x and y labels
plt.rc('xtick', labelsize=7) #fontsize of the x tick labels
plt.rc('ytick', labelsize=7) #fontsize of the y tick labels
plt.rc('legend', fontsize=7) #fontsize of the legend
plt.ioff();

In [6]:
avg_rMSEs = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
std_rMSEs = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
for k1 in rMSEs:
    for k2 in rMSEs[k1]:
        for k3 in rMSEs[k1][k2]:
            avg_rMSEs[k1][k2][k3] = np.mean(rMSEs[k1][k2][k3])
            std_rMSEs[k1][k2][k3] = np.std(rMSEs[k1][k2][k3])           
avg_rMSEs = json.loads(json.dumps(avg_rMSEs))
std_rMSEs = json.loads(json.dumps(std_rMSEs))

avg_rMSEs_test = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
std_rMSEs_test = defaultdict(lambda: defaultdict(lambda: defaultdict(float)))
for k1 in rMSEs_test:
    for k2 in rMSEs_test[k1]:
        for k3 in rMSEs_test[k1][k2]:
            avg_rMSEs_test[k1][k2][k3] = np.mean(rMSEs_test[k1][k2][k3])
            std_rMSEs_test[k1][k2][k3] = np.std(rMSEs_test[k1][k2][k3])
            
avg_rMSEs_test = json.loads(json.dumps(avg_rMSEs_test))
std_rMSEs_test = json.loads(json.dumps(std_rMSEs_test))

pd.DataFrame(avg_rMSEs['NoiseFree-TargetShift-FeatureShift']).T

Unnamed: 0,Homoscedastic Linear,Heteroscedastic Linear,Nonlinear
Before Correction,30.42064,31.923,18.08816
Oracle,9.294636e-16,6.999362e-16,4.573721e-16
Gaussian OT,9.268627,8.944433,4.572918
MMD,5.622962,8.43448,8.793253
ConDo Gaussian KLD,0.535397,1.305096,2.070855
ConDo MMD,0.5888936,1.408509,1.768535


In [8]:
nsets = ["NoiseFree", "Noisy"]

for cur_rMSEs, name_rMSEs in [(avg_rMSEs, "train"), (avg_rMSEs_test, "test")]:
    fig, axes = plt.subplots(4, 2, figsize=(4, 6), dpi=100, sharey="row")
    subrows = list(product(
        *[["TargetShift", "NoTargetShift"], ["FeatureShift", "NoFeatureShift"]]
    ))

    for subrowix in range(4):
        subrow = subrows[subrowix]
        tset, fset = subrow
        for colix, nset in enumerate(nsets):
            sset = f"{nset}-{tset}-{fset}"
            df = pd.DataFrame(cur_rMSEs[sset]).T.drop(index=['Before Correction', 'Oracle'])
            dfmelt = df.reset_index().melt(id_vars="index")
            sns.barplot(
                data=dfmelt, x="variable", y="value", hue="index", 
                ax=axes[subrowix, colix], palette=["blue", "red", "cyan", "magenta"],
            )
            axes[subrowix, colix].get_legend().remove()
            if subrowix == 0 and colix == 0:
                axes[subrowix, colix].set_xticklabels(
                    [
                        'Homo-\nscedastic ',
                        'Hetero-\n scedastic', ' Nonlinear'
                    ],
                    fontsize=6, rotation=0);
            else:
                axes[subrowix, colix].set_xticklabels(['', '', '']);
            axes[subrowix, colix].set_xlabel('');
            axes[subrowix, colix].set_ylabel('');
            if colix == 0:
                axes[subrowix, colix].set_ylabel('rMSE');                
            axes[subrowix, colix].set_title(sset.replace('-', ' '));

    plt.tight_layout();
    plt.savefig(f"figure-continuous1d-{name_rMSEs}.pdf")
fig2, ax2 = plt.subplots(figsize = (1.5, 0.75))
ax2.legend(*axes[subrowix, colix].get_legend_handles_labels(), loc='center')
ax2.axis('off')
fig2.savefig('figure-continuous1d-legend.pdf')

  axes[subrowix, colix].set_xticklabels(
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
  axes[subrowix, colix].set_xticklabels(['', '', '']);
