In [None]:
! pip install densratio

In [10]:
# TODO

# 1.
# Make non-linear in such a way that you cannot extrapolate to unseen P(Y,X)

# 2. 
# Use importance estimation to extrapolate.

# 3. 
# Show that when H|Z vs Z|H, extrapolation fails even with importance estimation. 

# 4
# Make true model such that excluding variables should recover a model that is "robust" (P(X|H))

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
plt.rcParams['figure.figsize'] = [10, 3]

In [3]:
def plot_dat(vars_):
    f, axes = plt.subplots(len(vars_), 1, sharex=True, figsize=(20, 10))

    for (H,title),ax in zip(vars_, axes):
        for i,h in zip(['A', 'B', 'C', 'D'], H):
            sns.distplot(h, label = i, ax=ax)
        ax.legend()
        ax.set_title(title)
        
    # plt.title(title)
    # plt.show()

In [4]:
import numpy as np
from scipy.stats import gamma
import seaborn as sns
from copy import deepcopy

# Y := f(H, W, X, Z, N_Y)
# Y := f(W, X, Z, N_Y)
def fn(h, v, z, w):
    val = -1*(h > np.mean(h))*w*h**2 + w*h + w*z + np.random.normal(0, 5, size = h.shape[0])
    return val/10


def generate_data(N, fn, hidden_cause = True, plot=False, hiddens = [(20,2)]*4, v_conds = [(250,5,5)]*4):
    # H is latent variable, distribution changes (not )
    H = [gamma.rvs(a, loc=b, scale=1, size=N) for a,b in hiddens]

    # V := f(H, N_X)
    V = [(1/(h))*c + np.random.gamma(a, b, size=N) for h,(c,a,b) in zip(H, v_conds)]


    if not hidden_cause:
        V,H = deepcopy(H), deepcopy(V)

    # Z = [gamma.rvs(int(np.random.normal(40, 10)), loc=0, scale=1, size=N) for h in H]
    # Z := f(N_Z) 
    Z = [gamma.rvs(2, loc=1, scale=3, size=N) for h in H]

    # W := f(N_W) -- TREATMENT
    W = [np.random.binomial(1, 0.5, size=N) for h in H]

    # Y:= fn(H, V, Z, W, N_Y)
    Y = [fn(H[idx], V[idx], Z[idx], W[idx]) for idx in range(4)]

    taus = [fn(h,v,z,1) - fn(h,v,z,0) for h,v,z in zip(H,V,Z)]
    
    if plot:
        plot_dat([(H,'H'), (V, 'V') , (Z, 'Z'), (Y, 'Y'), (taus, 'tau')])

    return [(y, np.array([w,v,z]).T, tau) for y,v,z,w,h,tau in zip(Y, V, Z, W, H, taus)]

In [5]:
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.neighbors.kde import KernelDensity
from scipy.stats import gaussian_kde
from scipy.stats import wasserstein_distance
from itertools import combinations

def kde_score(X):
    return gaussian_kde(X).evaluate

def score_wasserstein(dat, phi, model, train_idx):
    y_train, X_train, _ = dat[train_idx]
    model.fit(phi(X_train), y_train)
    resids = [model.predict(phi(x)) - y for y,x,_ in dat]
    dists = [wasserstein_distance(resids[0], r) for r in resids]
    mss = [np.mean(r**2) for r in resids]
    score = np.sum(dists)
    return score


def filter_dat(d, idxs):
    return [(y, x[:,idxs],tau) for y,x,tau in d]

def search_wasserstein(dat, phi, model, train_idx):
    s = dat[0][1].shape[1]

    combs = [j for i in range(1,s) 
             for j in combinations(range(1, s), i)]

    combs = [[0] + list(c) for c in combs]

    scores = [score_wasserstein(filter_dat(dat, i), phi, model, train_idx) for i in combs]

    return combs[np.argmin(scores)]


def run_model(dat, model, phi, train_idx, target_idx, use_weights = None, model_search = False):
    X_train, X_target = dat[train_idx][1], dat[target_idx][1]

    if model_search:
        # exclude the target in the model search
        dd = [d for i,d in enumerate(dat) if i != target_idx]
        idxs = search_wasserstein(dd, phi, model, train_idx)
        # print('Choosing variables: ', idxs)
    else:
        idxs = range(0, X_train.shape[1])


    if use_weights is not None:
        ps, pt = kde_score(X_train[:, use_weights]), kde_score(X_target[:, use_weights])
        weights = pt(X_train[:, use_weights]) / ps(X_train[:, use_weights])
    else:
        weights = np.ones(X_train.shape[0])


    d = [(phi(x[:, idxs]), x[:, 0], y) for y,x,tau in dat]

    model.fit(d[train_idx][0], d[train_idx][2], sample_weight=weights)

    y0 = [np.mean(model.predict(p[w == 0])) for p,w,y in d]
    y1 = [np.mean(model.predict(p[w == 1])) for p,w,y in d]

    pred_ates = [a-b for a,b in zip(y1, y0)]
    _, _, true_ates = zip(*dat)

    return [np.round(np.abs(p-np.mean(t)), 4) for p,t in zip(pred_ates, true_ates)]

phi = PolynomialFeatures(degree=2, include_bias=True).fit_transform

In [6]:
hiddens = [(25, 2), (4, 15), (4, 20), (2, 25)]
# hiddens = [(10,2)]*4
v_conds = [(500,2,2)]*4
# v_conds = [(250,2,30), (250,5,10), (500,15,5), (500,40,2)]

dat = generate_data(2000, 
                    fn, 
                    hidden_cause = False, 
                    plot = False, 
                    hiddens = hiddens,
                    v_conds = v_conds)

In [7]:
from trees import build_tree, mse, predict

In [8]:
y, X, tau = dat[0]

In [9]:
phi(X).shape

(2000, 10)

In [10]:
tree = build_tree(mse, phi(X), y, np.arange(1,10), 5, 0.0, 10)

In [11]:
tree

Node(dim=6, thresh=1.6751720397780472, gain=250.6995844541998, left=Node(dim=8, thresh=35.08843057675259, gain=1.0049337677610772, left=Leaf(prediction=-10.054217787279395, score=495.4986611030132, N=10), right=Node(dim=1, thresh=0.5, gain=0.04886421368746485, left=Node(dim=8, thresh=42.24091763323145, gain=0.0011156622558234908, left=Leaf(prediction=0.33937912628930905, score=0.16175088802730778, N=10), right=Node(dim=8, thresh=45.63447084189335, gain=0.0018710982962970857, left=Leaf(prediction=-0.33812641465206406, score=0.09244867000261951, N=15), right=Leaf(prediction=0.012441948298714913, score=0.2521327209107257, N=955))), right=Leaf(prediction=1.9551452355894579, score=0.07364086568450202, N=13))), right=Node(dim=2, thresh=24.22412455627345, gain=717.9626243435771, left=Node(dim=2, thresh=19.585729236286323, gain=235.25491068782162, left=Node(dim=2, thresh=16.318933304986967, gain=260.25115249661695, left=Leaf(prediction=-139.29775655879394, score=338.8996932048696, N=10), right

In [12]:
yt, Xt, tt = dat[1]

In [13]:
preds = np.array([predict(x, tree) for x in phi(X)])

np.mean((preds - y)**2)

347.9654010757514

In [14]:
from sklearn.tree import DecisionTreeRegressor

model = DecisionTreeRegressor(max_depth=5, min_samples_leaf=10)

In [15]:
model.fit(phi(X)[:, 1:], y)

# np.unique(model.predict(), return_counts = True)

preds = model.predict(phi(X)[:, 1:])

np.mean((preds - y)**2)

347.9654010757514

In [16]:
np.argsort(model.feature_importances_)

array([3, 8, 0, 7, 2, 4, 6, 5, 1])

In [17]:
model = RandomForestRegressor(n_estimators=50, max_depth=2, min_samples_leaf=.05)

np.mean(np.array([run_model(dat, model, phi, 0, 3, use_weights = None, model_search = True) for i in range(10)]), 0)

array([ 0.10618, 19.78116,  0.85286,  6.71827])

In [None]:
np.mean(np.array([run_model(dat, model, phi, 0, 3, use_weights = None, model_search = False) for i in range(10)]), 0)

In [None]:
np.mean(np.array([run_model(dat, model, phi, 0, 3, use_weights = 1, model_search = True) for i in range(10)]), 0)

In [None]:
np.mean(np.array([run_model(dat, model, phi, 0, 3, use_weights = 1, model_search = False) for i in range(10)]), 0)

In [1893]:
model = LinearRegression(fit_intercept=False)
run_model(dat, model, phi, 0, 1, use_weights = None, model_search = False)

[0.0418, 23.8089, 15.9848, 6.2811]

In [1894]:
run_model(dat, model, phi, 0, 3, use_weights = None, model_search = True)

[0.0418, 19.8237, 0.1206, 7.0669]

In [1903]:
run_model(dat, model, phi, 0, 1, use_weights = 1, model_search = False)

[18.8894, 43.8729, 12.3517, 12.0683]

In [1904]:
run_model(dat, model, phi, 0, 3, use_weights = 1, model_search = True)

[12.3778, 32.2376, 12.4323, 5.3211]

In [1802]:
# get residuals for "sets" separately
# compute distance between residuals
# optimize squared errors + penalty for residual distance

# search for "sets" by looking at residuals and fitting a mixture model
# then optimize to remove that mixture...

# set up an adversarial problem: the adversary tries to find a 
# mixture model in your reiduals, the classifier tries to make force the
# adversary to fit a 1-component mixture, for example... 