In [1]:
from collections import defaultdict
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import os
import ot
import pathlib
import seaborn as sns
import time
import warnings

import condo

In [2]:
# Print-sized figure styling: short tick marks, a thin axes frame, and 7-pt text
# for every text element (body text, titles, axis labels, tick labels, legend).
plt.rcParams.update({
    "xtick.major.size": 2,
    "ytick.major.size": 2,
    "axes.linewidth": 0.5,
    "font.size": 7,          # default text size
    "axes.titlesize": 7,     # subplot titles
    "axes.labelsize": 7,     # x and y axis labels
    "xtick.labelsize": 7,    # x tick labels
    "ytick.labelsize": 7,    # y tick labels
    "legend.fontsize": 7,    # legend entries
})
# Turn matplotlib interactive mode off; the trailing ';' hides the cell's return value.
plt.ioff();

In [3]:
num_random = 5

prob_settings = ["No Confounding", "Confounded", "Confounded - Challenging"]
num_probs = len(prob_settings)

rMSEs = defaultdict(lambda: defaultdict(list))
rMSEs_test = defaultdict(lambda: defaultdict(list))
udAccs = defaultdict(lambda: defaultdict(list))
lrAccs = defaultdict(lambda: defaultdict(list))
udAccs_test = defaultdict(lambda: defaultdict(list))
lrAccs_test = defaultdict(lambda: defaultdict(list))

for rix in range(num_random):
    rng = np.random.RandomState(rix)
    
    fig, axes = plt.subplots(
        nrows=6, ncols=num_probs, sharex="all", sharey="all", squeeze=False,
        gridspec_kw={"hspace": 0.03, "wspace": 0.03},
        figsize=(6, 9), dpi=150)
    msize = 0.25
    basic_leg = ['target', 'source: batch-effected', 'source: true (unobserved)']
    figname = f"figure-categorical2d-{rix}.pdf"
    fsizemse = 7
    n_lines = 100

    for pix, prob_setting in enumerate(prob_settings):
        print(f"rix:{rix} {prob_setting}")

        if prob_setting == "No Confounding":
            n = 200
            d = 2
            sigma = .1
            levery = n // n_lines
            n_up = n // 2
            n_down = n - n_up
            ys = np.array(
                [['up']*n_up + ['down']*n_down]).reshape((n, 1))
            yt = np.array(
                [['up']*n_up + ['down']*n_down]).reshape((n, 1))
            
            # source samples
            angles = rng.rand(n, 1) * 2 * np.pi
            XS = (
                np.concatenate([np.cos(angles), np.sin(angles)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XS[:n_up, 1] += 2
            angles_test = rng.rand(n, 1) * 2 * np.pi
            XS_test = (
                np.concatenate([np.cos(angles_test), np.sin(angles_test)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XS_test[:n_up, 1] += 2

            # target samples
            anglet = rng.rand(n, 1) * 2 * np.pi
            XT = (
                np.concatenate([np.sin(anglet), np.cos(anglet)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XT[:n_up, 1] += 2
            
            # nonconfounding variable
            RvL_s_true = np.where(
                (angles < 0.5 * np.pi) | (angles > 1.5 * np.pi), 
                'right', 'left')            
            RvL_s_test_true = np.where(
                (angles_test < 0.5 * np.pi) | (angles_test > 1.5 * np.pi),
                'right', 'left')
            
            A_true = np.array([[1.5, .7], [.7, 1.5]])
            b_true = np.array([[4, 2]])
            XT = XT.dot(A_true) + b_true          
            XS_true = XS.dot(A_true) + b_true
            XS_test_true = XS_test.dot(A_true) + b_true
        elif prob_setting == "Confounded":
            n = 200
            d = 2
            sigma = .1
            levery = n // n_lines
            n_up_s = n // 4
            n_down_s = n - n_up_s
            n_up_t = n // 2
            n_down_t = n // 2
            ys = np.array(
                [['up']*n_up_s + ['down']*n_down_s]).reshape((n, 1))
            yt = np.array(
                [['up']*n_up_t + ['down']*n_down_t]).reshape((n, 1))
            
            # source samples
            angles = rng.rand(n, 1) * 2 * np.pi
            XS = (
                np.concatenate([np.cos(angles), np.sin(angles)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XS[:n_up_s, 1] += 2
            angles_test = rng.rand(n, 1) * 2 * np.pi
            XS_test = (
                np.concatenate([np.cos(angles_test), np.sin(angles_test)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XS_test[:n_up_t, 1] += 2

            # target samples
            anglet = rng.rand(n, 1) * 2 * np.pi
            XT = (
                np.concatenate([np.sin(anglet), np.cos(anglet)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XT[:n_up_t, 1] += 2
            
            # nonconfounding variable
            RvL_s_true = np.where(
                (angles < 0.5 * np.pi) | (angles > 1.5 * np.pi), 
                'right', 'left')            
            RvL_s_test_true = np.where(
                (angles_test < 0.5 * np.pi) | (angles_test > 1.5 * np.pi),
                'right', 'left')
            
            A_true = np.array([[1.5, .7], [.7, 1.5]])
            b_true = np.array([[4, 2]])
            XT = XT.dot(A_true) + b_true
            XS_true = XS.dot(A_true) + b_true
            XS_test_true = XS_test.dot(A_true) + b_true
        elif prob_setting == "Confounded - Challenging":
            n = 200
            d = 2
            sigma = .1
            levery = n // n_lines
            n_up_s = n // 4
            n_down_s = n - n_up_s
            n_up_t = n // 2
            n_down_t = n // 2
            ys = np.array(
                [['up']*n_up_s + ['down']*n_down_s]).reshape((n, 1))
            yt = np.array(
                [['up']*n_up_t + ['down']*n_down_t]).reshape((n, 1))
            
            # source samples
            angles = rng.rand(n, 1) * 2 * np.pi
            XS = (
                np.concatenate([np.cos(angles), np.sin(angles)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XS[:n_up_s, 1] += 2
            angles_test = rng.rand(n, 1) * 2 * np.pi
            XS_test = (
                np.concatenate([np.cos(angles_test), np.sin(angles_test)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XS_test[:n_up_t, 1] += 2

            # target samples
            anglet = rng.rand(n, 1) * 2 * np.pi
            XT = (
                np.concatenate([np.sin(anglet), np.cos(anglet)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XT[:n_up_t, 1] += 2
            
            # nonconfounding variable
            RvL_s_true = np.where(
                (angles < 0.5 * np.pi) | (angles > 1.5 * np.pi), 
                'right', 'left')            
            RvL_s_test_true = np.where(
                (angles_test < 0.5 * np.pi) | (angles_test > 1.5 * np.pi),
                'right', 'left')
            
            A_true = np.array([[ 0.05,  1.1], [-1.5,  1.3]])
            b_true = np.array([[4, 2]])
            XT = XT.dot(A_true) + b_true
            XS_true = XS.dot(A_true) + b_true
            XS_test_true = XS_test.dot(A_true) + b_true
        # Before correction
        method = "Before Correction"
        mix = 0
        rMSE = np.sqrt(np.mean((XS - XS_true) ** 2));
        rMSEs[method][prob_setting].append(rMSE)
        rMSE_test = np.sqrt(np.mean((XS_test - XS_test_true) ** 2));
        rMSEs_test[method][prob_setting].append(rMSE_test)
        avgrMSE = np.mean(rMSEs[method][prob_setting])
        avgrMSE_test = np.mean(rMSEs_test[method][prob_setting])
        rmse_str = f"rMSE:   {avgrMSE:.2f} ({avgrMSE_test:.2f})"
        axes[mix, pix].tick_params(axis="both", which="both", direction="in")    
        axes[mix, pix].scatter(XT[:,0], XT[:,1], s=msize)
        axes[mix, pix].scatter(XS[:,0], XS[:,1], s=msize)
        axes[mix, pix].scatter(XS_true[:,0], XS_true[:,1], marker='+', s=msize);
        axes[mix, pix].plot(
            np.vstack([XS[::levery, [0]].T, XS_true[::levery, [0]].T]),
            np.vstack([XS[::levery, [1]].T, XS_true[::levery, [1]].T]),
            alpha=0.5, linewidth=0.2, color="lightgray");
        axes[mix, pix].set_title(prob_setting);
        axes[mix, pix].plot([-1.5, 1.5], [1, 1], ':m');
        axes[mix, pix].plot([0, 0], [-1.5, 3.5], '--c');
        if prob_setting == "Confounded - Challenging":
            ud_x1 = 2.4
            ud_x2 = 2.6
            lr_x1 = -0.5
            lr_x2 = 6
        else:
            ud_x1 = 3
            ud_x2 = 7
            lr_x1 = 3
            lr_x2 = 6
        axes[mix, pix].plot(
            [lr_x1, lr_x2],
            [
                A_true[1, 1]/A_true[1, 0]*lr_x1+b_true[0,1]-b_true[0, 0]*A_true[1, 1]/A_true[1, 0], 
                A_true[1, 1]/A_true[1, 0]*lr_x2+b_true[0,1]-b_true[0, 0]*A_true[1, 1]/A_true[1, 0]
            ], '--c');
        axes[mix, pix].plot(
            [ud_x1, ud_x2],
            [
                A_true[0, 1]/A_true[0, 0]*ud_x1
                +(A_true[1,1]+b_true[0,1]-A_true[0,1]/A_true[0,0]*(A_true[1,0]+b_true[0,0])),
                A_true[0, 1]/A_true[0, 0]*ud_x2
                +(A_true[1,1]+b_true[0,1]-A_true[0,1]/A_true[0,0]*(A_true[1,0]+b_true[0,0])),
            ], ':m');
        axes[mix, pix].text(
            0.1, 0.9, 
            f"{rmse_str}",
            size=fsizemse, transform = axes[mix, pix].transAxes);
        axes[mix, pix].set_ylim(-1, 10);
        
        # Oracle - not displayed
        method = "Oracle"
        XSoracle = XS @ A_true + b_true
        XSoracle_back = (XSoracle - b_true) @ np.linalg.inv(A_true)
        XSoracle_pred_ud = np.where(XSoracle_back[:,[1]] > 1.5, "up", "down")
        XSoracle_pred_lr = np.where(XSoracle_back[:,[0]] > 0, "right", "left")
        XSoracle_test = XS_test @ A_true + b_true
        XSoracle_test_back = (XSoracle_test - b_true) @ np.linalg.inv(A_true)
        XSoracle_test_pred_ud = np.where(XSoracle_test_back[:,[1]] > 1.5, "up", "down")
        XSoracle_test_pred_lr = np.where(XSoracle_test_back[:,[0]] > 0, "right", "left")        
        # rmse
        rMSE = np.sqrt(np.mean((XSoracle - XS_true) ** 2));
        rMSEs[method][prob_setting].append(rMSE)
        rMSE_test = np.sqrt(np.mean((XSoracle_test - XS_test_true) ** 2));
        rMSEs_test[method][prob_setting].append(rMSE_test)
        avgrMSE = np.mean(rMSEs[method][prob_setting])
        avgrMSE_test = np.mean(rMSEs_test[method][prob_setting])
        rmse_str = f"rMSE:   {avgrMSE:.2f} ({avgrMSE_test:.2f})"
        # acc
        ud_acc = np.mean(XSoracle_pred_ud == ys)
        ud_test_acc = np.mean(XSoracle_test_pred_ud == ys)
        udAccs[method][prob_setting].append(ud_acc)      
        udAccs_test[method][prob_setting].append(ud_test_acc)
        avg_udAccs = np.mean(udAccs[method][prob_setting])
        avg_udAccs_test = np.mean(udAccs_test[method][prob_setting])
        ud_str = f"U-vs-D: {avg_udAccs:.2f} ({avg_udAccs_test:.2f})"
        lr_acc = np.mean(XSoracle_pred_lr == RvL_s_true)
        lr_test_acc = np.mean(XSoracle_test_pred_lr == RvL_s_test_true)
        lrAccs[method][prob_setting].append(lr_acc)      
        lrAccs_test[method][prob_setting].append(lr_test_acc)
        avg_lrAccs = np.mean(lrAccs[method][prob_setting])
        avg_lrAccs_test = np.mean(lrAccs_test[method][prob_setting])
        lr_str = f"L-vs-R:  {avg_lrAccs:.2f} ({avg_lrAccs_test:.2f})"

        # Gaussian OT
        method = "Gaussian OT"
        mix = 1
        A_otda, b_otda = ot.da.OT_mapping_linear(XS, XT)
        XSotda = XS @ A_otda + b_otda
        XSotda_back = (XSotda - b_true) @ np.linalg.inv(A_true)
        XSotda_pred_ud = np.where(XSotda_back[:,[1]] > 1.5, "up", "down")
        XSotda_pred_lr = np.where(XSotda_back[:,[0]] > 0, "right", "left")
        XSotda_test = XS_test @ A_otda + b_otda
        XSotda_test_back = (XSotda_test - b_true) @ np.linalg.inv(A_true)
        XSotda_test_pred_ud = np.where(XSotda_test_back[:,[1]] > 1.5, "up", "down")
        XSotda_test_pred_lr = np.where(XSotda_test_back[:,[0]] > 0, "right", "left")        
        # rmse
        rMSE = np.sqrt(np.mean((XSotda - XS_true) ** 2));
        rMSEs[method][prob_setting].append(rMSE)
        rMSE_test = np.sqrt(np.mean((XSotda_test - XS_test_true) ** 2));
        rMSEs_test[method][prob_setting].append(rMSE_test)
        avgrMSE = np.mean(rMSEs[method][prob_setting])
        avgrMSE_test = np.mean(rMSEs_test[method][prob_setting])
        rmse_str = f"rMSE:   {avgrMSE:.2f} ({avgrMSE_test:.2f})"
        # acc
        ud_acc = np.mean(XSotda_pred_ud == ys)
        ud_test_acc = np.mean(XSotda_test_pred_ud == ys)
        udAccs[method][prob_setting].append(ud_acc)      
        udAccs_test[method][prob_setting].append(ud_test_acc)
        avg_udAccs = np.mean(udAccs[method][prob_setting])
        avg_udAccs_test = np.mean(udAccs_test[method][prob_setting])
        ud_str = f"U-vs-D: {avg_udAccs:.2f} ({avg_udAccs_test:.2f})"
        lr_acc = np.mean(XSotda_pred_lr == RvL_s_true)
        lr_test_acc = np.mean(XSotda_test_pred_lr == RvL_s_test_true)
        lrAccs[method][prob_setting].append(lr_acc)      
        lrAccs_test[method][prob_setting].append(lr_test_acc)
        avg_lrAccs = np.mean(lrAccs[method][prob_setting])
        avg_lrAccs_test = np.mean(lrAccs_test[method][prob_setting])
        lr_str = f"L-vs-R:  {avg_lrAccs:.2f} ({avg_lrAccs_test:.2f})"
        # plot
        axes[mix, pix].tick_params(axis="both", which="both", direction="in")    
        axes[mix, pix].scatter(XT[:,0], XT[:,1], s=msize)
        axes[mix, pix].scatter(XS[:,0], XS[:,1], s=msize)
        axes[mix, pix].scatter(XS_true[:,0], XS_true[:,1], marker='+', s=msize)
        axes[mix, pix].scatter(XSotda[:,0], XSotda[:,1], marker='x', s=msize)
        axes[mix, pix].plot(
            np.vstack([XS[::levery, [0]].T, XSotda[::levery, [0]].T]),
            np.vstack([XS[::levery, [1]].T, XSotda[::levery, [1]].T]),
            alpha=0.5, linewidth=0.2, color="lightgray");
        axes[mix, pix].text(
            0.1, 0.7, 
            f"{rmse_str}\n{ud_str}\n{lr_str}",
            size=fsizemse, transform = axes[mix, pix].transAxes);        
        if pix == num_probs - 1:
            axes[mix, pix].legend(
                basic_leg + [f"source: {method}"],
                loc="center left", bbox_to_anchor=(1.05, 0.5),
                frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);

        # MMD
        method = "MMD"
        mix = 2
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            cder = condo.MMDAdapter(
                transform_type="affine",
                optim_kwargs={"epochs": 25, "alpha": 0.1, "beta": 0.9},
                debug=False,
                verbose=0,
            )
            cder.fit(XS, XT)
            XSmmd = cder.transform(XS)
            XSmmd_test = cder.transform(XS_test)
        XSmmd_back = (XSmmd - b_true) @ np.linalg.inv(A_true)
        XSmmd_pred_ud = np.where(XSmmd_back[:,[1]] > 1.5, "up", "down")
        XSmmd_pred_lr = np.where(XSmmd_back[:,[0]] > 0, "right", "left")
        XSmmd_test_back = (XSmmd_test - b_true) @ np.linalg.inv(A_true)
        XSmmd_test_pred_ud = np.where(XSmmd_test_back[:,[1]] > 1.5, "up", "down")
        XSmmd_test_pred_lr = np.where(XSmmd_test_back[:,[0]] > 0, "right", "left")        
        # rmse
        rMSE = np.sqrt(np.mean((XSmmd - XS_true) ** 2));
        rMSEs[method][prob_setting].append(rMSE)
        rMSE_test = np.sqrt(np.mean((XSmmd_test - XS_test_true) ** 2));
        rMSEs_test[method][prob_setting].append(rMSE_test)
        avgrMSE = np.mean(rMSEs[method][prob_setting])
        avgrMSE_test = np.mean(rMSEs_test[method][prob_setting])
        rmse_str = f"rMSE:   {avgrMSE:.2f} ({avgrMSE_test:.2f})"
        # acc
        ud_acc = np.mean(XSmmd_pred_ud == ys)
        ud_test_acc = np.mean(XSmmd_test_pred_ud == ys)
        udAccs[method][prob_setting].append(ud_acc)      
        udAccs_test[method][prob_setting].append(ud_test_acc)
        avg_udAccs = np.mean(udAccs[method][prob_setting])
        avg_udAccs_test = np.mean(udAccs_test[method][prob_setting])
        ud_str = f"U-vs-D: {avg_udAccs:.2f} ({avg_udAccs_test:.2f})"
        lr_acc = np.mean(XSmmd_pred_lr == RvL_s_true)
        lr_test_acc = np.mean(XSmmd_test_pred_lr == RvL_s_test_true)
        lrAccs[method][prob_setting].append(lr_acc)      
        lrAccs_test[method][prob_setting].append(lr_test_acc)
        avg_lrAccs = np.mean(lrAccs[method][prob_setting])
        avg_lrAccs_test = np.mean(lrAccs_test[method][prob_setting])
        lr_str = f"L-vs-R:  {avg_lrAccs:.2f} ({avg_lrAccs_test:.2f})"
        # plot
        axes[mix, pix].scatter(XT[:,0], XT[:,1], s=msize)
        axes[mix, pix].scatter(XS[:,0], XS[:,1], s=msize)
        axes[mix, pix].scatter(XS_true[:,0], XS_true[:,1], marker='+', s=msize)
        axes[mix, pix].scatter(XSmmd[:,0], XSmmd[:,1], marker='x', s=msize)
        axes[mix, pix].plot(
            np.vstack([XS[::levery, [0]].T, XSmmd[::levery, [0]].T]),
            np.vstack([XS[::levery, [1]].T, XSmmd[::levery, [1]].T]),
            alpha=0.5, linewidth=0.2, color="lightgray");
        axes[mix, pix].text(
            0.1, 0.7, 
            f"{rmse_str}\n{ud_str}\n{lr_str}",
            size=fsizemse, transform = axes[mix, pix].transAxes);        
        if pix == num_probs - 1:
            axes[mix, pix].legend(
                basic_leg + [f"source: {method}"],
                loc="center left", bbox_to_anchor=(1.05, 0.5),
                frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);
            
        # ConDo Linear-ReverseKL
        method = "ConDo Linear-ReverseKL"
        mix = 3
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            cder = condo.ConDoAdapter(
                sampling="product",
                transform_type="affine",
                model_type="linear",
                divergence="reverse",
                debug=False,
                verbose=0
            )
            cder.fit(XS, XT, ys, yt)
            XSreverse = cder.transform(XS)
            XSreverse_test = cder.transform(XS_test)
        XSreverse_back = (XSreverse - b_true) @ np.linalg.inv(A_true)
        XSreverse_pred_ud = np.where(XSreverse_back[:,[1]] > 1.5, "up", "down")
        XSreverse_pred_lr = np.where(XSreverse_back[:,[0]] > 0, "right", "left")
        XSreverse_test_back = (XSreverse_test - b_true) @ np.linalg.inv(A_true)
        XSreverse_test_pred_ud = np.where(XSreverse_test_back[:,[1]] > 1.5, "up", "down")
        XSreverse_test_pred_lr = np.where(XSreverse_test_back[:,[0]] > 0, "right", "left")        
        # rmse
        rMSE = np.sqrt(np.mean((XSreverse - XS_true) ** 2));
        rMSEs[method][prob_setting].append(rMSE)
        rMSE_test = np.sqrt(np.mean((XSreverse_test - XS_test_true) ** 2));
        rMSEs_test[method][prob_setting].append(rMSE_test)
        avgrMSE = np.mean(rMSEs[method][prob_setting])
        avgrMSE_test = np.mean(rMSEs_test[method][prob_setting])
        rmse_str = f"rMSE:   {avgrMSE:.2f} ({avgrMSE_test:.2f})"
        # acc
        ud_acc = np.mean(XSreverse_pred_ud == ys)
        ud_test_acc = np.mean(XSreverse_test_pred_ud == ys)
        udAccs[method][prob_setting].append(ud_acc)      
        udAccs_test[method][prob_setting].append(ud_test_acc)
        avg_udAccs = np.mean(udAccs[method][prob_setting])
        avg_udAccs_test = np.mean(udAccs_test[method][prob_setting])
        ud_str = f"U-vs-D: {avg_udAccs:.2f} ({avg_udAccs_test:.2f})"
        lr_acc = np.mean(XSreverse_pred_lr == RvL_s_true)
        lr_test_acc = np.mean(XSreverse_test_pred_lr == RvL_s_test_true)
        lrAccs[method][prob_setting].append(lr_acc)      
        lrAccs_test[method][prob_setting].append(lr_test_acc)
        avg_lrAccs = np.mean(lrAccs[method][prob_setting])
        avg_lrAccs_test = np.mean(lrAccs_test[method][prob_setting])
        lr_str = f"L-vs-R:  {avg_lrAccs:.2f} ({avg_lrAccs_test:.2f})"
        # plot
        axes[mix, pix].scatter(XT[:,0], XT[:,1], s=msize)
        axes[mix, pix].scatter(XS[:,0], XS[:,1], s=msize)
        axes[mix, pix].scatter(XS_true[:,0], XS_true[:,1], marker='+', s=msize)
        axes[mix, pix].scatter(XSreverse[:,0], XSreverse[:,1], marker='x', s=msize)
        axes[mix, pix].plot(
            np.vstack([XS[::levery, [0]].T, XSreverse[::levery, [0]].T]),
            np.vstack([XS[::levery, [1]].T, XSreverse[::levery, [1]].T]),
            alpha=0.5, linewidth=0.2, color="lightgray");
        axes[mix, pix].text(
            0.1, 0.7, 
            f"{rmse_str}\n{ud_str}\n{lr_str}",
            size=fsizemse, transform = axes[mix, pix].transAxes);        
        if pix == num_probs - 1:
            axes[mix, pix].legend(
                basic_leg + [f"source: {method}"],
                loc="center left", bbox_to_anchor=(1.05, 0.5),
                frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);
        
        # ConDo PoGMM-ReverseKL
        method = "ConDo PoGMM-ReverseKL"
        mix = 4
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            cder = condo.ConDoAdapter(
                sampling="product",
                transform_type="affine",
                model_type="pogmm",
                divergence="reverse",
                debug=False,
                verbose=0
            )
            cder.fit(XS, XT, ys, yt)
            XSreverse = cder.transform(XS)
            XSreverse_test = cder.transform(XS_test)
        XSreverse_back = (XSreverse - b_true) @ np.linalg.inv(A_true)
        XSreverse_pred_ud = np.where(XSreverse_back[:,[1]] > 1.5, "up", "down")
        XSreverse_pred_lr = np.where(XSreverse_back[:,[0]] > 0, "right", "left")
        XSreverse_test_back = (XSreverse_test - b_true) @ np.linalg.inv(A_true)
        XSreverse_test_pred_ud = np.where(XSreverse_test_back[:,[1]] > 1.5, "up", "down")
        XSreverse_test_pred_lr = np.where(XSreverse_test_back[:,[0]] > 0, "right", "left")        
        # rmse
        rMSE = np.sqrt(np.mean((XSreverse - XS_true) ** 2));
        rMSEs[method][prob_setting].append(rMSE)
        rMSE_test = np.sqrt(np.mean((XSreverse_test - XS_test_true) ** 2));
        rMSEs_test[method][prob_setting].append(rMSE_test)
        avgrMSE = np.mean(rMSEs[method][prob_setting])
        avgrMSE_test = np.mean(rMSEs_test[method][prob_setting])
        rmse_str = f"rMSE:   {avgrMSE:.2f} ({avgrMSE_test:.2f})"
        # acc
        ud_acc = np.mean(XSreverse_pred_ud == ys)
        ud_test_acc = np.mean(XSreverse_test_pred_ud == ys)
        udAccs[method][prob_setting].append(ud_acc)      
        udAccs_test[method][prob_setting].append(ud_test_acc)
        avg_udAccs = np.mean(udAccs[method][prob_setting])
        avg_udAccs_test = np.mean(udAccs_test[method][prob_setting])
        ud_str = f"U-vs-D: {avg_udAccs:.2f} ({avg_udAccs_test:.2f})"
        lr_acc = np.mean(XSreverse_pred_lr == RvL_s_true)
        lr_test_acc = np.mean(XSreverse_test_pred_lr == RvL_s_test_true)
        lrAccs[method][prob_setting].append(lr_acc)      
        lrAccs_test[method][prob_setting].append(lr_test_acc)
        avg_lrAccs = np.mean(lrAccs[method][prob_setting])
        avg_lrAccs_test = np.mean(lrAccs_test[method][prob_setting])
        lr_str = f"L-vs-R:  {avg_lrAccs:.2f} ({avg_lrAccs_test:.2f})"
        # plot
        axes[mix, pix].scatter(XT[:,0], XT[:,1], s=msize)
        axes[mix, pix].scatter(XS[:,0], XS[:,1], s=msize)
        axes[mix, pix].scatter(XS_true[:,0], XS_true[:,1], marker='+', s=msize)
        axes[mix, pix].scatter(XSreverse[:,0], XSreverse[:,1], marker='x', s=msize)
        axes[mix, pix].plot(
            np.vstack([XS[::levery, [0]].T, XSreverse[::levery, [0]].T]),
            np.vstack([XS[::levery, [1]].T, XSreverse[::levery, [1]].T]),
            alpha=0.5, linewidth=0.2, color="lightgray");
        axes[mix, pix].text(
            0.1, 0.7, 
            f"{rmse_str}\n{ud_str}\n{lr_str}",
            size=fsizemse, transform = axes[mix, pix].transAxes);        
        if pix == num_probs - 1:
            axes[mix, pix].legend(
                basic_leg + [f"source: {method}"],
                loc="center left", bbox_to_anchor=(1.05, 0.5),
                frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);

        # ConDo MMD
        method = "ConDo MMD"
        mix = 5
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            cder = condo.ConDoAdapter(
                sampling="product",
                transform_type="affine",
                model_type="empirical",
                divergence="mmd",
                optim_kwargs={"epochs": 25, "alpha": 0.01, "beta": 0.9},
                debug=False,
                verbose=0,
            )
            cder.fit(XS, XT, ys, yt)
            XSmmd = cder.transform(XS)
            XSmmd_test = cder.transform(XS_test)
        XSmmd_back = (XSmmd - b_true) @ np.linalg.inv(A_true)
        XSmmd_pred_ud = np.where(XSmmd_back[:,[1]] > 1.5, "up", "down")
        XSmmd_pred_lr = np.where(XSmmd_back[:,[0]] > 0, "right", "left")
        XSmmd_test_back = (XSmmd_test - b_true) @ np.linalg.inv(A_true)
        XSmmd_test_pred_ud = np.where(XSmmd_test_back[:,[1]] > 1.5, "up", "down")
        XSmmd_test_pred_lr = np.where(XSmmd_test_back[:,[0]] > 0, "right", "left")        
        # rmse
        rMSE = np.sqrt(np.mean((XSmmd - XS_true) ** 2));
        rMSEs[method][prob_setting].append(rMSE)
        rMSE_test = np.sqrt(np.mean((XSmmd_test - XS_test_true) ** 2));
        rMSEs_test[method][prob_setting].append(rMSE_test)
        avgrMSE = np.mean(rMSEs[method][prob_setting])
        avgrMSE_test = np.mean(rMSEs_test[method][prob_setting])
        rmse_str = f"rMSE:   {avgrMSE:.2f} ({avgrMSE_test:.2f})"
        # acc
        ud_acc = np.mean(XSmmd_pred_ud == ys)
        ud_test_acc = np.mean(XSmmd_test_pred_ud == ys)
        udAccs[method][prob_setting].append(ud_acc)      
        udAccs_test[method][prob_setting].append(ud_test_acc)
        avg_udAccs = np.mean(udAccs[method][prob_setting])
        avg_udAccs_test = np.mean(udAccs_test[method][prob_setting])
        ud_str = f"U-vs-D: {avg_udAccs:.2f} ({avg_udAccs_test:.2f})"
        lr_acc = np.mean(XSmmd_pred_lr == RvL_s_true)
        lr_test_acc = np.mean(XSmmd_test_pred_lr == RvL_s_test_true)
        lrAccs[method][prob_setting].append(lr_acc)      
        lrAccs_test[method][prob_setting].append(lr_test_acc)
        avg_lrAccs = np.mean(lrAccs[method][prob_setting])
        avg_lrAccs_test = np.mean(lrAccs_test[method][prob_setting])
        lr_str = f"L-vs-R:  {avg_lrAccs:.2f} ({avg_lrAccs_test:.2f})"
        # plot
        axes[mix, pix].tick_params(axis="both", which="both", direction="in")    
        axes[mix, pix].scatter(XT[:,0], XT[:,1], s=msize)
        axes[mix, pix].scatter(XS[:,0], XS[:,1], s=msize)
        axes[mix, pix].scatter(XS_true[:,0], XS_true[:,1], marker='+', s=msize)
        axes[mix, pix].scatter(XSmmd[:,0], XSmmd[:,1], marker='x', s=msize)
        axes[mix, pix].plot(
            np.vstack([XS[::levery, [0]].T, XSmmd[::levery, [0]].T]),
            np.vstack([XS[::levery, [1]].T, XSmmd[::levery, [1]].T]),
            alpha=0.5, linewidth=0.2, color="lightgray");
        axes[mix, pix].text(
            0.1, 0.7, 
            f"{rmse_str}\n{ud_str}\n{lr_str}",
            size=fsizemse, transform = axes[mix, pix].transAxes);        
        if pix == num_probs - 1:
            axes[mix, pix].legend(
                basic_leg + [f"source: {method}"],
                loc="center left", bbox_to_anchor=(1.05, 0.5),
                frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);
    if rix in (0, num_random - 1):
        fig.savefig(figname, bbox_inches="tight")
    plt.close()

rix:0 No Confounding
rix:0 Confounded
rix:0 Confounded - Challenging
rix:1 No Confounding
rix:1 Confounded
rix:1 Confounded - Challenging
rix:2 No Confounding
rix:2 Confounded
rix:2 Confounded - Challenging
rix:3 No Confounding
rix:3 Confounded
rix:3 Confounded - Challenging
rix:4 No Confounding
rix:4 Confounded
rix:4 Confounded - Challenging
