In [None]:
### Magic functions
# autoreload 2: re-import edited modules (e.g. the mliv package) before each cell executes.
%load_ext autoreload
%autoreload 2
# Registers the %tensorboard magic used further down.
%load_ext tensorboard
%matplotlib inline

In [None]:
### imports
import warnings
warnings.simplefilter('ignore')
import copy
import itertools

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split

from mliv.dgps import get_data, get_tau_fn, fn_dict
from mliv.neuralnet.deepiv_fit import deep_iv_fit
from mliv.neuralnet.utilities import log_metrics, plot_results
from mliv.neuralnet.rbflayer import gaussian, inverse_multiquadric
from mliv.neuralnet import AGMM, KernelLayerMMDGMM, CentroidMMDGMM, KernelLossAGMM, MMDGMM

# DGPs

In [None]:
# Simulated IV dataset.
n = 100             # number of samples
n_z = 3             # number of instrument columns
n_t = 3             # treatment width expected by the learner network below
# NOTE(review): n_t is not passed to get_data; it must match T.shape[1] — confirm.
iv_strength = .6
fname = 'abs'       # key into fn_dict; the selected function is wrapped by get_tau_fn
dgp_num = 5

# Seed the global NumPy RNG so the simulated data are identical on every
# Restart & Run All (assumes get_data draws from np.random — confirm).
data_seed = 1234
np.random.seed(data_seed)
Z, T, Y, true_fn = get_data(n, n_z, iv_strength, get_tau_fn(fn_dict[fname]), dgp_num)

In [None]:
# Test inputs for partial-dependence style plots: sweep column `ind` of T over its
# 1%-99% quantile range while every other column is held at its sample median.
ind = 0
n_grid = 100  # single source of truth for the grid length (used twice below)
x_grid = np.linspace(np.quantile(T[:, ind], .01), np.quantile(T[:, ind], .99), n_grid)
T_test = np.tile(np.median(T, axis=0, keepdims=True), (n_grid, 1))
T_test[:, ind] = x_grid

In [None]:
# Raw data: Y against the first instrument column, and Y against the first
# treatment column with true_fn overlaid. The curve evaluates true_fn on the
# observed rows sorted by T[:, 0], so the other columns keep their observed values.
order = np.argsort(T[:, 0])
T_sorted = T[order]

fig, (ax_z, ax_t) = plt.subplots(1, 2, figsize=(10, 3))
ax_z.scatter(Z[:, 0], Y)
ax_z.set(xlabel='Z[:, 0]', ylabel='Y')
ax_t.scatter(T[:, 0], Y)
ax_t.plot(T_sorted[:, 0], true_fn(T_sorted))
ax_t.set(xlabel='T[:, 0]', ylabel='Y')
plt.show()

# DeepIV

In [None]:
# DeepIV baseline: the first Z column and the first T column are passed
# positionally, and the remaining T columns are passed as `x`.
# NOTE(review): only Z[:, [0]] is used here, while the AGMM variants below use
# all n_z instrument columns — confirm this is intended.
deepiv_z = Z[:, [0]]
deepiv_t = T[:, [0]]
deepiv_x = T[:, 1:]
deep_iv = deep_iv_fit(deepiv_z, deepiv_t, Y, x=deepiv_x, epochs=100, hidden=[100])

In [None]:
# DeepIV predictions on the test grid against the true function. The title reports
# R^2 = 1 - MSE / Var(truth), computed over the grid points.
y_pred = deep_iv.predict([T_test[:, 1:], T_test[:, [0]]])
deepiv_truth = true_fn(T_test).flatten()      # evaluated once, reused below
deepiv_pred = np.asarray(y_pred).flatten()
deepiv_r2 = 1 - np.mean((deepiv_truth - deepiv_pred) ** 2) / np.var(deepiv_truth)

fig, ax = plt.subplots()
ax.scatter(deepiv_truth, deepiv_pred)
ax.set(xlabel='true_fn(T_test)', ylabel='DeepIV prediction',
       title=f'DeepIV $R^2$ = {deepiv_r2:.3f}')
plt.show()

# AGMM Variants

## Learner Network

In [None]:
# Settings shared by every network defined in this notebook.
p = 0.1          # dropout probability
n_hidden = 100   # hidden-layer width

# Learner h_theta: n_t inputs -> one LeakyReLU hidden layer -> scalar output,
# with dropout in front of each Linear layer. To deepen it, insert another
# [Dropout, Linear(n_hidden, n_hidden), ReLU] group before the output layer.
learner_layers = [
    nn.Dropout(p=p), nn.Linear(n_t, n_hidden), nn.LeakyReLU(),
    nn.Dropout(p=p), nn.Linear(n_hidden, 1),
]
learner = nn.Sequential(*learner_layers)

## Adversary Networks

In [None]:
# Feature map g(z) for every method that projects z into features before
# applying a kernel: n_z -> n_hidden -> g_features. The final ReLU makes the
# features non-negative.
g_features = 100
g_layers = [
    nn.Dropout(p=p), nn.Linear(n_z, n_hidden), nn.LeakyReLU(),
    nn.Dropout(p=p), nn.Linear(n_hidden, g_features), nn.ReLU(),
]
adversary_g = nn.Sequential(*g_layers)

# Kernel applied to the features; `inverse_multiquadric` is the imported alternative.
kernel_fn = gaussian

# Unstructured adversary test function f(z) (no kernel structure): n_z inputs,
# one LeakyReLU hidden layer, scalar output, dropout before each Linear layer.
adversary_fn = nn.Sequential(
    nn.Dropout(p=p),
    nn.Linear(n_z, n_hidden),
    nn.LeakyReLU(),
    nn.Dropout(p=p),
    nn.Linear(n_hidden, 1),
)

# Hyperparameters

In [None]:
# Optimizer learning rates.
learner_lr = 1e-4
adversary_lr = 1e-4

# Penalties passed to every .fit call: l2 weight penalties, plus the
# adversary-norm penalty (presumably the approximate RKHS-norm term described
# in the MMD-GMM sections below).
learner_l2 = 1e-3
adversary_l2 = 1e-4
adversary_norm_reg = 1e-3

# Training schedule.
n_epochs = 300
bs = 100

# Kernel setup: initial sigma, scaled inversely with the feature dimension,
# and the number of kernel centers.
sigma = 2.0 / g_features
n_centers = 100

# Current CUDA device index when a GPU is present, otherwise None.
if torch.cuda.is_available():
    device = torch.cuda.current_device()
else:
    device = None

In [None]:
# Launch TensorBoard inline on ./runs. Assumes this is where the writers handed
# to the `logger` callbacks save their summaries (torch SummaryWriter's default
# directory) — confirm.
%tensorboard --logdir=runs

# Train-Val Split

In [None]:
# Release cached GPU memory left over from earlier runs of the notebook.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

# 90/10 train/validation split. A fixed random_state makes the split
# reproducible; without it train_test_split draws from the global NumPy RNG.
split_seed = 0
Z_train, Z_val, T_train, T_val, Y_train, Y_val = train_test_split(
    Z, T, Y, test_size=.1, shuffle=True, random_state=split_seed)

# Training tensors stay on the CPU (each .fit call below receives `device`).
# Validation and test tensors are moved to `device` for the loggers and plots.
Z_train, T_train, Y_train = (torch.Tensor(a) for a in (Z_train, T_train, Y_train))
Z_val, T_val, Y_val = (torch.Tensor(a).to(device) for a in (Z_val, T_val, Y_val))

# True-function values used as reference targets in metrics and plots.
G_train = true_fn(T_train).to(device)
G_val = true_fn(T_val).to(device)
T_test_tens = torch.Tensor(T_test).to(device)
G_test_tens = true_fn(T_test_tens).to(device)

## $\ell_2$-Regularized AGMM with Neural Net Test Function

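We solve the problem:
\begin{equation}
\min_{\theta} \max_{w} \frac{1}{n} \sum_i \left[(y_i - h_{\theta}(x_i)) f_w(z_i) - f_w(z_i)^2\right]
\end{equation}
where $h_{\theta}$ (the learner) and $f_w$ (the adversary test function) are two neural nets. In the code, both networks additionally receive $\ell_2$ weight penalties (`learner_l2`, `adversary_l2`).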

In [None]:
def logger(learner, adversary, epoch, writer):
    """Per-epoch logging callback handed to ``AGMM.fit`` via ``logger=``.

    Writes histograms of the last-layer weights of the learner and of the
    adversary network to ``writer``, then delegates validation metrics to
    ``log_metrics``.
    """
    writer.add_histogram('learner', learner[-1].weight, epoch)
    writer.add_histogram('adversary', adversary[-1].weight, epoch)
    # NOTE(review): the validation tensors fill both the first and the second
    # (Z, T, Y) argument triples of log_metrics — confirm this is intended.
    log_metrics(Z_val, T_val, Y_val, Z_val, T_val, Y_val, T_test_tens,
                learner, adversary, epoch, writer, true_of_T=G_val)

# Seed torch too: NumPy alone does not control torch weight init or dropout masks.
np.random.seed(12356)
torch.manual_seed(12356)
# Deep-copy the template networks so this estimator owns its modules instead of
# sharing (and having overwritten) the objects that every later estimator trains.
agmm = AGMM(copy.deepcopy(learner), copy.deepcopy(adversary_fn)).fit(
    Z_train, T_train, Y_train, learner_lr=learner_lr, adversary_lr=adversary_lr,
    learner_l2=learner_l2, adversary_l2=adversary_l2,
    n_epochs=n_epochs, bs=bs, logger=logger,
    model_dir='agmm_model', device=device)

In [None]:
# Plot against the same column `ind` that was swept to build T_test.
plot_results(agmm, T_test_tens, G_test_tens, ind=ind)

## $\ell_2$-regularized MMD-GMM with Fixed Kernel Points and No Kernel Learning

We choose a fixed set of centers and represent the test function as a weighted sum of kernels centered at these points.

Here we learn only the weights $\beta_i$ and keep $c_i, \sigma_i$ fixed, i.e.:
\begin{align}
f(z) = \sum_{i=1}^s \beta_i K_{\sigma_i}(c_i, z)
\end{align}
where the centers $c_i$ are evenly spaced points on the diagonal of $[-4, 4]^{n_z}$, and every $\sigma_i$ equals the same constant. We also approximate the RKHS norm with:
\begin{align}
\|f\|_{K}^2 = \sum_{i, j \in [s]} \beta_i \beta_j \frac{K_{\sigma_i}(c_i, c_j) + K_{\sigma_j}(c_j, c_i)}{2}
\end{align}

In [None]:
def logger(learner, adversary, epoch, writer):
    """Per-epoch logging callback for the fixed-kernel MMD-GMM fit.

    Writes histograms of the learner's last-layer weights and of
    ``adversary.beta.weight`` (presumably the kernel-combination weights beta),
    then delegates validation metrics to ``log_metrics``.
    """
    writer.add_histogram('learner', learner[-1].weight, epoch)
    writer.add_histogram('adversary', adversary.beta.weight, epoch)
    log_metrics(Z_val, T_val, Y_val, Z_val, T_val, Y_val, T_test_tens,
                learner, adversary, epoch, writer, true_of_T=G_val)

# Fixed centers c_1, ..., c_s: n_centers evenly spaced points on the diagonal of
# [-4, 4]^n_z. All kernel precisions are the same constant, 2 / n_z.
centers = np.tile(np.linspace(-4, 4, n_centers).reshape(-1, 1), (1, n_z))
sigmas = np.ones((n_centers,)) * 2/n_z

np.random.seed(12356)
torch.manual_seed(12356)
# Fresh copy of the learner so this estimator shares no modules with the others.
# The identity map means the kernel acts directly on z.
mmdgmm_fixed = KernelLayerMMDGMM(copy.deepcopy(learner), lambda x: x, n_z, n_centers, kernel_fn,
                                 centers=centers, sigmas=sigmas, trainable=False)
mmdgmm_fixed.fit(Z_train, T_train, Y_train, learner_l2=learner_l2, adversary_l2=adversary_l2,
                 adversary_norm_reg=adversary_norm_reg,
                 learner_lr=learner_lr, adversary_lr=adversary_lr, n_epochs=n_epochs, bs=bs, logger=logger,
                 model_dir='mmd_fixed_model', device=device)

In [None]:
# Plot against the same column `ind` that was swept to build T_test.
plot_results(mmdgmm_fixed, T_test_tens, G_test_tens, ind=ind)

## $\ell_2$-regularized MMD-GMM with Kernel Learning and Approximation via Kernel Layer


Here we learn the feature map $g$ together with the centers and sigmas, i.e.:
\begin{align}
f(z) = \sum_{i=1}^s \beta_i K_{\sigma_i}(c_i, g(z))
\end{align}
where both $c_i$ and $\sigma_i$ are trained as well. The $c_i$ are $n_{\text{features}}$-dimensional vectors, initialized uniformly at random in $[-4, 4]^{n_{\text{features}}}$. The feature map $g(z)$ is a fully connected network with (Leaky)ReLU activations. We also approximate the RKHS norm with:
\begin{align}
\|f\|_{K}^2 = \sum_{i, j \in [s]} \beta_i \beta_j \frac{K_{\sigma_i}(c_i, c_j) + K_{\sigma_j}(c_j, c_i)}{2}
\end{align}


In [None]:
def logger(learner, adversary, epoch, writer):
    """Per-epoch logging callback for the kernel-layer MMD-GMM fit.

    Writes histograms of the learner's last-layer weights and of
    ``adversary.beta.weight`` (presumably the kernel-combination weights beta),
    then delegates validation metrics to ``log_metrics``.
    """
    writer.add_histogram('learner', learner[-1].weight, epoch)
    writer.add_histogram('adversary', adversary.beta.weight, epoch)
    log_metrics(Z_val, T_val, Y_val, Z_val, T_val, Y_val, T_test_tens,
                learner, adversary, epoch, writer, true_of_T=G_val)

# Seed first so the random center initialization below, torch weight init, and
# dropout are all reproducible.
np.random.seed(12356)
torch.manual_seed(12356)

# Trainable centers c_1, ..., c_s, initialized uniformly in [-4, 4]^g_features,
# and precisions sigma_1, ..., sigma_s, all initialized to `sigma`.
centers = np.random.uniform(-4, 4, size=(n_centers, g_features))
sigmas = np.ones((n_centers,)) * sigma

# Deep copies so this estimator shares no modules with the other estimators.
klayermmdgmm = KernelLayerMMDGMM(copy.deepcopy(learner), copy.deepcopy(adversary_g), g_features,
                                 n_centers, kernel_fn, centers=centers, sigmas=sigmas)
klayermmdgmm.fit(Z_train, T_train, Y_train, learner_l2=learner_l2, adversary_l2=adversary_l2,
                 adversary_norm_reg=adversary_norm_reg,
                 learner_lr=learner_lr, adversary_lr=adversary_lr, n_epochs=n_epochs, bs=bs, logger=logger,
                 model_dir='klayer_model', device=device)

In [None]:
# Plot against the same column `ind` that was swept to build T_test.
plot_results(klayermmdgmm, T_test_tens, G_test_tens, ind=ind)

## $\ell_2$-regularized MMD-GMM with Kernel Learning and Approximation via KMeans Centroids

Here we follow the low-rank approximation described in the paper. We take the centers $c_1, \ldots, c_m$ of k-means clusters in $z$-space and use test functions of the form:
\begin{equation}
f(z) = \sum_{i=1}^m \beta_i K_{\sigma_i}(g_w(c_i), g_w(z))
\end{equation}
and penalize the approximate RKHS norm:
\begin{align}
\|f\|_{K}^2 = \sum_{i, j \in [m]} \beta_i \beta_j \frac{K_{\sigma_i}(g_w(c_i), g_w(c_j)) + K_{\sigma_j}(g_w(c_j), g_w(c_i))}{2}
\end{align}

In [None]:
def logger(learner, adversary, epoch, writer):
    """Per-epoch logging callback for the k-means-centroid MMD-GMM fit.

    Writes histograms of the learner's last-layer weights and of
    ``adversary.beta.weight`` (presumably the kernel-combination weights beta),
    then delegates validation metrics to ``log_metrics``.
    """
    writer.add_histogram('learner', learner[-1].weight, epoch)
    writer.add_histogram('adversary', adversary.beta.weight, epoch)
    log_metrics(Z_val, T_val, Y_val, Z_val, T_val, Y_val, T_test_tens,
                learner, adversary, epoch, writer, true_of_T=G_val)

np.random.seed(12356)
torch.manual_seed(12356)

# K-means centers z_1, ..., z_s, fitted on the *training* instruments only so the
# validation points scored by the logger do not leak into the adversary.
# KMeans needs at least as many samples as clusters, hence the cap.
n_kmeans = min(n_centers, len(Z_train))
centers = KMeans(n_clusters=n_kmeans, random_state=12356).fit(Z_train.cpu().numpy()).cluster_centers_

# Deep copies so this estimator shares no modules with the other estimators.
centroid_mmd = CentroidMMDGMM(copy.deepcopy(learner), copy.deepcopy(adversary_g), kernel_fn,
                              centers, np.ones(n_kmeans) * sigma)
centroid_mmd.fit(Z_train, T_train, Y_train, learner_l2=learner_l2, adversary_l2=adversary_l2,
                 adversary_norm_reg=adversary_norm_reg,
                 learner_lr=learner_lr, adversary_lr=adversary_lr, n_epochs=n_epochs, bs=bs, logger=logger,
                 model_dir='centroid_model', device=device)

In [None]:
# Plot against the same column `ind` that was swept to build T_test.
plot_results(centroid_mmd, T_test_tens, G_test_tens, ind=ind)

## Un-Regularized MMD-GMM with Kernel Loss and Learned Kernel

We are minimizing the objective:
\begin{equation}
\min_{\theta} \max_{w} \frac{1}{n^2} \sum_{i,j} (y_i - h_{\theta}(x_i)) K_{\sigma}(g_w(z_i), g_w(z_j)) (y_j - h_{\theta}(x_j))
\end{equation}

In [None]:
def logger(learner, adversary, epoch, writer):
    """Per-epoch logging callback for the kernel-loss GMM fit.

    Writes histograms of the learner's last-layer weights and of
    ``adversary.sigma``, then delegates validation metrics to ``log_metrics``
    with ``loss='kernel'``.
    """
    writer.add_histogram('learner', learner[-1].weight, epoch)
    writer.add_histogram('adversary', adversary.sigma, epoch)
    log_metrics(Z_val, T_val, Y_val, Z_val, T_val, Y_val, T_test_tens,
                learner, adversary, epoch, writer, true_of_T=G_val, loss='kernel')

np.random.seed(12356)
torch.manual_seed(12356)
# Deep copies so this estimator shares no modules with the other estimators.
kernelgmm = KernelLossAGMM(copy.deepcopy(learner), copy.deepcopy(adversary_g), kernel_fn, sigma)
# NOTE(review): learner_l2 is squared here but used as-is for every other
# estimator — confirm this rescaling is intended.
kernelgmm.fit(Z_train, T_train, Y_train, learner_l2=learner_l2**2, adversary_l2=adversary_l2,
              learner_lr=learner_lr, adversary_lr=adversary_lr, n_epochs=n_epochs,
              bs=bs, logger=logger, model_dir='kernel_model', device=device)

In [None]:
# Plot against the same column `ind` that was swept to build T_test.
plot_results(kernelgmm, T_test_tens, G_test_tens, ind=ind)

## $\ell_2$-Regularized MMD-GMM without Kernel Approximation

Here we use test functions of the form:
\begin{equation}
f(z) = \sum_{j=1}^n \beta_j K(g_w(z_j), g_w(z))
\end{equation}
where $j$ ranges over all the training samples. Since the function itself depends on all the data, at each step we need unbiased stochastic estimates of the test function, as well as unbiased stochastic estimates of its RKHS and $\ell_{2,n}$ norms.

We sample a subset $S$ of the indices and test for:
\begin{equation}
\hat{f}(z) = \frac{n}{|S|}\sum_{j\in S} \beta_j K(g_w(z_j), g_w(z)).
\end{equation}
To create an unbiased estimate of the second moment of the test function on the training data, we also draw a second, independent set of indices $T$ and a minibatch of data $B$, and approximate it via:
\begin{equation}
\widehat{\|f\|_{2,n}^2} = \frac{1}{|B|} \frac{n}{|S|} \frac{n}{|T|}\sum_{i\in B, j\in S, k\in T} \beta_j \beta_k K(g_w(z_j), g_w(z_i)) K(g_w(z_k), g_w(z_i)). 
\end{equation}
Because $S$ and $T$ are drawn independently, the product of the two estimates is unbiased for the squared function value. Similarly, an unbiased estimate of the RKHS norm of the function is:
\begin{equation}
\widehat{\|f\|_{K}^2} = \frac{n}{|S|} \frac{n}{|T|}\sum_{j \in S, k\in T} \beta_j \beta_k K(g_w(z_j), g_w(z_k))
\end{equation}

This might require more fine-tuning: the current version learns too slowly because the gradients have huge variance. Most probably, the indices $j$ should be sampled proportionally to $|\beta_j|$.

In [None]:
# Seed torch too: NumPy alone does not control torch weight init or dropout masks.
np.random.seed(123)
torch.manual_seed(123)

def logger(learner, adversary, epoch, writer):
    """Per-epoch logging callback for the full (non-approximated) MMD-GMM fit.

    Writes histograms of the learner's last-layer weights and of
    ``adversary.sigma``, then delegates validation metrics to ``log_metrics``
    with ``loss=None``.
    """
    writer.add_histogram('learner', learner[-1].weight, epoch)
    writer.add_histogram('adversary', adversary.sigma, epoch)
    log_metrics(Z_val, T_val, Y_val, Z_val, T_val, Y_val, T_test_tens,
                learner, adversary, epoch, writer, true_of_T=G_val, loss=None)

# One kernel term (and one sigma) per training sample, presumably matching the
# beta_j, j = 1..n of the test function described above.
n_train = Y_train.shape[0]
# Deep copies so this estimator shares no modules with the other estimators.
mmdgmm = MMDGMM(copy.deepcopy(learner), copy.deepcopy(adversary_g), n_train, kernel_fn, sigma * np.ones(n_train))
mmdgmm.fit(Z_train, T_train, Y_train, learner_l2=learner_l2, adversary_l2=adversary_l2,
           adversary_norm_reg=adversary_norm_reg,
           learner_lr=learner_lr, adversary_lr=adversary_lr, n_epochs=n_epochs, bs1=bs, bs2=300, bs3=100,
           logger=logger, model_dir='mmd_model', device=device)

In [None]:
# Plot against the same column `ind` that was swept to build T_test.
plot_results(mmdgmm, T_test_tens, G_test_tens, ind=ind)