In [None]:
from collections import defaultdict
from itertools import product

import matplotlib.pyplot as plt
import numpy as np
import os
import ot
import pathlib
import seaborn as sns
import time
import warnings

import condo

In [None]:
# Figure styling applied through matplotlib's global rcParams: short major
# ticks, a 0.5-wide axes outline, and size 7 for every text element.
plt.rcParams.update({
    "xtick.major.size": 2,
    "ytick.major.size": 2,
    "axes.linewidth": 0.5,
    "font.size": 7,          # default text size
    "axes.titlesize": 7,     # axes titles
    "axes.labelsize": 7,     # x and y axis labels
    "xtick.labelsize": 7,    # x tick labels
    "ytick.labelsize": 7,    # y tick labels
    "legend.fontsize": 7,    # legend entries
})
# plt.ioff()  # kept disabled, exactly as in the original cell

In [None]:
# Experiment configuration.
# n: number of points drawn for each sample set.
# d: not read anywhere in the visible code (presumably the feature dimension).
n, d = 200, 2
sigma = 0.1       # scale of the Gaussian noise added to every sampled point
num_random = 1    # number of repetitions; repetition index is also the seed

# Only "Toy8" is enabled; settings that were listed but commented out:
# "Toy8Confounded", "Toy8Hard", "TwoMoons"
prob_settings = ["Toy8"]
num_probs = len(prob_settings)

# Two independent two-level accumulators: outer key -> inner key -> list.
rMSEs, rMSEs_test = (
    defaultdict(lambda: defaultdict(list)) for _ in range(2)
)
for rix in range(num_random):
    # The repetition index is also the seed of this repetition's generator.
    rng = np.random.RandomState(seed=rix)

    # Plot constants for this repetition.
    msize = 1       # scatter marker size
    fsizemse = 7    # size of the text annotations drawn on the panels
    figname = f"figure-categorical2d-{rix}.pdf"
    # Legend labels, in the order the first three scatters are drawn per panel.
    basic_leg = ['target', 'source: batch-effected', 'source: true (unobserved)']

    # Six rows of panels, one column per problem setting. Every panel shares
    # both axes; squeeze=False keeps `axes` two-dimensional even when there is
    # a single column.
    fig, axes = plt.subplots(
        nrows=6,
        ncols=num_probs,
        sharex="all",
        sharey="all",
        squeeze=False,
        gridspec_kw=dict(hspace=0.03, wspace=0.03),
        figsize=(6, 10),
        dpi=150,
    )

    for pix, prob_setting in enumerate(prob_settings):
        print(f"rix:{rix} {prob_setting}")

        if prob_setting == "Toy8":
            # Label layout shared by source and target: the first n_up rows
            # are 'up', the remaining n_down rows are 'down'.
            n_up = n // 2
            n_down = n - n_up
            ys = np.array(
                [['up']*n_up + ['down']*n_down]).reshape((n, 1))
            yt = np.array(
                [['up']*n_up + ['down']*n_down]).reshape((n, 1))

            # source samples: noisy unit circle, with the 'up' rows shifted by
            # +2 along the second coordinate. The order of the rng calls in
            # this branch determines the generated data, so it is kept fixed.
            angles = rng.rand(n, 1) * 2 * np.pi
            XS = (
                np.concatenate([np.cos(angles), np.sin(angles)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XS[:n_up, 1] += 2
            # held-out source samples, drawn the same way and with the same
            # up/down row layout as ys
            angles_test = rng.rand(n, 1) * 2 * np.pi
            XS_test = (
                np.concatenate([np.cos(angles_test), np.sin(angles_test)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XS_test[:n_up, 1] += 2

            # target samples: same construction with sin/cos swapped; mapped
            # through the true affine transform further down
            anglet = rng.rand(n, 1) * 2 * np.pi
            XT = (
                np.concatenate([np.sin(anglet), np.cos(anglet)], axis=1)
                + sigma * rng.randn(n, 2)
            )
            XT[:n_up, 1] += 2

            # nonconfounding variable: 'right' when the sampling angle lies in
            # the half of the circle where cos(angle) > 0, otherwise 'left'
            RvL_s_true = np.where(
                (angles < 0.5 * np.pi) | (angles > 1.5 * np.pi),
                'right', 'left')

            RvL_s_test_true = np.where(
                (angles_test < 0.5 * np.pi) | (angles_test > 1.5 * np.pi),
                'right', 'left')

            # True affine transform x -> x @ A_true + b_true; the target is
            # observed only after this transform.
            A_true = np.array([[1.5, .7], [.7, 1.5]])
            b_true = np.array([[4, 2]])
            XT = XT.dot(A_true) + b_true

            # Source sets pushed through the same transform (plotted under the
            # 'source: true (unobserved)' legend label).
            XS_true = XS.dot(A_true) + b_true
            XS_test_true = XS_test.dot(A_true) + b_true
        else:
            # Fail loudly with the offending name. A bare `assert False` has no
            # message and is stripped under `python -O`, which would let the
            # loop continue with stale or undefined data.
            raise ValueError(f"Unsupported prob_setting: {prob_setting!r}")

        # Row 0: the data before any correction is applied.
        method = "Before Correction"
        ax_before = axes[0, pix]
        ax_before.tick_params(axis="both", which="both", direction="in")
        # Draw order: target, observed source, transformed ("true") source.
        for pts in (XT, XS, XS_true):
            ax_before.scatter(pts[:, 0], pts[:, 1], s=msize)
        ax_before.set_title("No Confounding")
        
        # Row 1 -- Oracle: push the source through the true affine map.
        method = "Oracle"
        ax_oracle = axes[1, pix]
        A_true_inv = np.linalg.inv(A_true)

        XSoracle = XS @ A_true + b_true
        XSoracle_test = XS_test @ A_true + b_true

        # Map the corrected points back with the true inverse. Up/down is then
        # decided by thresholding the second coordinate at 1.5, and left/right
        # by the sign of the first coordinate.
        XSoracle_back = (XSoracle - b_true) @ A_true_inv
        XSoracle_test_back = (XSoracle_test - b_true) @ A_true_inv
        XSoracle_pred_ud = np.where(XSoracle_back[:, 1:2] > 1.5, "up", "down")
        XSoracle_pred_lr = np.where(XSoracle_back[:, 0:1] > 0, "right", "left")
        XSoracle_test_pred_ud = np.where(XSoracle_test_back[:, 1:2] > 1.5, "up", "down")
        XSoracle_test_pred_lr = np.where(XSoracle_test_back[:, 0:1] > 0, "right", "left")

        # Accuracies; the held-out up/down score is also measured against ys
        # (the held-out set uses the same up/down row layout).
        ud_acc = (XSoracle_pred_ud == ys).mean()
        ud_test_acc = (XSoracle_test_pred_ud == ys).mean()
        lr_acc = (XSoracle_pred_lr == RvL_s_true).mean()
        lr_test_acc = (XSoracle_test_pred_lr == RvL_s_test_true).mean()
        ud_str = f"U-vs-D: {ud_acc:.3f} ({ud_test_acc:.3f})"
        lr_str = f"L-vs-R:  {lr_acc:.3f} ({lr_test_acc:.3f})"

        ax_oracle.tick_params(axis="both", which="both", direction="in")
        # Draw order: target, observed source, true source, corrected source.
        for pts in (XT, XS, XS_true, XSoracle):
            ax_oracle.scatter(pts[:, 0], pts[:, 1], s=msize)
        # One faint segment per point, from its observed position to its
        # corrected position (each column of the 2 x n arrays is one segment).
        ax_oracle.plot(
            np.vstack((XS[:, 0], XSoracle[:, 0])),
            np.vstack((XS[:, 1], XSoracle[:, 1])),
            alpha=0.5, linewidth=0.2, color="lightgray")
        ax_oracle.text(
            0.1, 0.8,
            "\n".join((ud_str, lr_str)),
            size=fsizemse, transform=ax_oracle.transAxes)
        # Only the last column carries a legend, placed outside the panel.
        if pix == num_probs - 1:
            ax_oracle.legend(
                [*basic_leg, f"source: {method}"],
                loc="center left", bbox_to_anchor=(1.05, 0.5),
                frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2)

        # Gaussian OT
        # Row 2: affine map (A_otda, b_otda) estimated by POT's linear OT
        # mapping fitted on the observed source and target samples.
        method = "Gaussian OT"
        A_otda, b_otda = ot.da.OT_mapping_linear(XS, XT)
        XSotda = XS @ A_otda + b_otda
        # Undo the *true* transform so the fixed thresholds (1.5 on the second
        # coordinate, 0 on the first) are applied in the source's own coordinates.
        XSotda_back = (XSotda - b_true) @ np.linalg.inv(A_true)
        XSotda_pred_ud = np.where(XSotda_back[:,[1]] > 1.5, "up", "down")
        XSotda_pred_lr = np.where(XSotda_back[:,[0]] > 0, "right", "left")
        # Held-out source points must go through the ESTIMATED map. Using
        # (A_true, b_true) here would only reproduce the Oracle's held-out
        # numbers and say nothing about how Gaussian OT generalizes.
        XSotda_test = XS_test @ A_otda + b_otda
        XSotda_test_back = (XSotda_test - b_true) @ np.linalg.inv(A_true)
        XSotda_test_pred_ud = np.where(XSotda_test_back[:,[1]] > 1.5, "up", "down")
        XSotda_test_pred_lr = np.where(XSotda_test_back[:,[0]] > 0, "right", "left")
        # The held-out set uses the same up/down row layout as the training
        # source, so ys serves as its up/down ground truth too.
        ud_acc = np.mean(XSotda_pred_ud == ys)
        ud_test_acc = np.mean(XSotda_test_pred_ud == ys)
        ud_str = f"U-vs-D: {ud_acc:.3f} ({ud_test_acc:.3f})"
        lr_acc = np.mean(XSotda_pred_lr == RvL_s_true)
        lr_test_acc = np.mean(XSotda_test_pred_lr == RvL_s_test_true)
        lr_str = f"L-vs-R:  {lr_acc:.3f} ({lr_test_acc:.3f})"
        axes[2, pix].tick_params(axis="both", which="both", direction="in")
        # Draw order: target, observed source, true source, corrected source.
        axes[2, pix].scatter(XT[:,0], XT[:,1], s=msize)
        axes[2, pix].scatter(XS[:,0], XS[:,1], s=msize)
        axes[2, pix].scatter(XS_true[:,0], XS_true[:,1], s=msize)
        axes[2, pix].scatter(XSotda[:,0], XSotda[:,1], s=msize)
        # One faint segment per point, from its observed position to its
        # corrected position (each column of the 2 x n arrays is one segment).
        axes[2, pix].plot(
            np.vstack([XS[:, [0]].T, XSotda[:, [0]].T]),
            np.vstack([XS[:, [1]].T, XSotda[:, [1]].T]),
            alpha=0.5, linewidth=0.2, color="lightgray");
        # Accuracies are shown as "train (held-out)".
        axes[2, pix].text(
            0.1, 0.8, 
            f"{ud_str}\n{lr_str}",
            size=fsizemse, transform = axes[2, pix].transAxes);
        # Only the last column carries a legend, placed outside the panel.
        if pix == num_probs - 1:
            axes[2, pix].legend(
                basic_leg + [f"source: {method}"],
                loc="center left", bbox_to_anchor=(1.05, 0.5),
                frameon=False, borderpad=0.2, handlelength=0.3, borderaxespad=0.2);