In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

from sklearn.datasets import make_moons

In [None]:
# Two interleaved half-moons: X holds the 2 covariates, y the moon label (0/1),
# which is used as the treatment group below. Fixing random_state makes the
# shuffled row order reproducible. Note that `y` is reassigned to the response
# in the parameter-initialisation cell further down.
X, y = make_moons(n_samples=1000, random_state=42)

In [None]:
import pandas as pd
import janitor
from random import random

# Simulate a binary response whose success probability depends on treatment:
# 0.8 for treatment_group == 1 and 0.2 otherwise (by default).
def assign_response(x, p_treated=0.8, p_control=0.2):
    """Draw a Bernoulli response for one row, with probability set by treatment.

    :param x: a row (e.g. a pandas Series) with a ``"treatment_group"`` entry.
    :param p_treated: success probability when ``x["treatment_group"] == 1``.
    :param p_control: success probability for any other treatment value.
    :returns: ``1`` with probability ``p`` (drawn with ``random.random``), else ``0``.
    """
    p = p_treated if x["treatment_group"] == 1 else p_control
    return int(random() < p)

from random import seed

# Seed the stdlib RNG used by `assign_response` so the simulated responses are
# reproducible across kernel restarts.
seed(42)
# `y` must still be the make_moons label here (it is later overwritten).
data = (
    pd.DataFrame(X)
    .rename_columns({0: "covariate_1", 1: "covariate_2"})
    .add_column("treatment_group", y)
    .join_apply(assign_response, new_column_name="response")
)

In [None]:
# Preview the simulated dataset rather than rendering all 1000 rows.
data.head()

In [None]:
import matplotlib.pyplot as plt
import holoviews as hv
hv.extension('bokeh')
import hvplot.pandas

# Covariate space coloured by the simulated binary response.
data.hvplot.scatter(
    x="covariate_1",
    y="covariate_2",
    c="response",
    cmap="viridis",
    title="Simulated covariates coloured by response",
).opts(width=600, height=400)

In [None]:
from fundl.layers import dense
from fundl.activations import relu, sigmoid
from fundl.losses import _cross_entropy_loss
import numpy.random as npr
import jax.numpy as np

def reps(params, x):
    """Shared representation: ``dense1`` then ``dense2``, both ReLU-activated."""
    for layer_name in ("dense1", "dense2"):
        x = dense(params[layer_name], x, nonlin=relu)
    return x


def phi(params, x):
    """Representation fed to the outcome head: ``reps`` followed by ReLU ``dense3``."""
    return dense(params["dense3"], reps(params, x), nonlin=relu)


def h(params, x):
    """Outcome head: ReLU ``dense4`` followed by a sigmoid-activated ``dense5``."""
    hidden = dense(params["dense4"], x, nonlin=relu)
    return dense(params["dense5"], hidden, nonlin=sigmoid)

def model(params, x, t):
    """Predict the response from covariates ``x`` and treatment indicator ``t``.

    The treatment is appended as an extra column to the ``phi`` representation
    before the outcome head ``h``.

    :param params: parameter dict with entries ``dense1`` ... ``dense5``.
    :param x: covariates, one row per sample.
    :param t: treatment indicator, one entry per sample; shape ``(n,)`` or
        ``(n, 1)``. It is reshaped to a column so a 1-D vector also works with
        ``np.hstack``.
    :returns: output of ``h``.
    """
    x = phi(params, x)
    x = np.hstack([x, np.reshape(t, (-1, 1))])
    return h(params, x)


def cross_entropy_loss(params, model, x, y, t):
    """Cross-entropy term of the treatment-conditioned ``model`` on ``(x, t) -> y``.

    Returns fundl's ``_cross_entropy_loss(y, y_hat)`` unchanged. ``combined_loss``
    negates this value, which suggests it is a mean log-likelihood rather than a
    positive loss. NOTE(review): confirm the sign convention against fundl.
    """
    return _cross_entropy_loss(y, model(params, x, t))


def maximum_mean_discrepancy_loss(params, x, indicator):
    """Linear-kernel MMD between treated and non-treated representations.

    Computes ``|| mean(r[control]) - mean(r[treated]) ||^2`` where
    ``r = reps(params, x)``. Control rows have ``indicator == 0`` and treated
    rows have ``indicator == 1``. The result is deterministic, so the value
    reported during training is the same quantity that was differentiated.

    :param params: parameter dict consumed by ``reps``.
    :param x: model inputs, one row per sample.
    :param indicator: 0/1 treatment indicator with one entry per row of ``x``;
        shape ``(n,)`` or ``(n, 1)``.
    :returns: non-negative scalar; 0 when either group is empty.
    """
    r = reps(params, x)  # (n, k): n samples, k representation dimensions

    # Per-row group weights as an (n, 1) column so they broadcast over k.
    # These masks replace `np.take(r, np.where(indicator == 0, 0, 1), axis=0)`,
    # which only ever indexed rows 0 and 1: `np.where(cond, 0, 1)` returns the
    # values 0/1, not the indices of matching rows.
    ind = np.reshape(indicator, (-1, 1))
    w_control = (ind == 0).astype(r.dtype)
    w_treated = (ind == 1).astype(r.dtype)
    n_control = np.sum(w_control)
    n_treated = np.sum(w_treated)

    # np.maximum(..., 1) avoids 0/0 for an empty group; that case is zeroed below.
    mean_control = np.sum(r * w_control, axis=0) / np.maximum(n_control, 1.0)
    mean_treated = np.sum(r * w_treated, axis=0) / np.maximum(n_treated, 1.0)

    mmd = np.sum((mean_control - mean_treated) ** 2)
    return np.where((n_control > 0) & (n_treated > 0), mmd, 0.0)

def combined_loss(params, model, x, y, t, alpha=1.0):
    """Training objective: negated cross-entropy term plus a weighted MMD penalty.

    :param params: parameter dict for ``model`` and ``reps``.
    :param model: callable ``model(params, x, t)``.
    :param x: covariates.
    :param y: observed responses.
    :param t: treatment indicator; also used as the group label for the MMD.
    :param alpha: weight of the MMD (representation-balancing) term; the
        default 1.0 reproduces the unweighted sum.
    :returns: scalar ``-cross_entropy_loss(...) + alpha * mmd``.
    """
    ce_loss = cross_entropy_loss(params, model, x, y, t)
    mmd_loss = maximum_mean_discrepancy_loss(params, x, t)
    # The cross-entropy term is negated, presumably because fundl's helper
    # returns a log-likelihood (higher is better). NOTE(review): confirm.
    return -ce_loss + alpha * mmd_loss


from jax import grad

# Gradient of the combined objective with respect to `params` (argument 0) only.
dloss = grad(combined_loss, argnums=0)

In [None]:
# Compare learned representation values of treated vs. non-treated samples.
# NOTE: needs `params`, `x` and `t` from the parameter-initialisation cell
# below; run that cell first on a fresh kernel.
r = reps(params, x)
treated_mask = t.ravel() == 1
# Boolean masks select the intended rows (np.where(cond, 0, 1) returns values, not indices).
non_treated = r[~treated_mask]  # (n_control, k)
treated = r[treated_mask]       # (n_treated, k)

fig, ax = plt.subplots()
ax.hist(treated.flatten(), alpha=0.5, label="treated")
ax.hist(non_treated.flatten(), alpha=0.5, label="non-treated")
ax.set(
    title="Representation values by treatment group",
    xlabel="representation value",
    ylabel="count",
)
ax.legend();

In [None]:
from fundl.weights import add_dense_params
# Layer sizes: dense1 takes the 2 covariates; dense4 takes the 10-dim `phi`
# output plus the single treatment column appended by `model` (10 + 1 = 11).
params = dict()
params = add_dense_params(params, "dense1", 2, 10)
params = add_dense_params(params, "dense2", 10, 10)
params = add_dense_params(params, "dense3", 10, 10)
params = add_dense_params(params, "dense4", 11, 10)
params = add_dense_params(params, "dense5", 10, 1)

# Test of forward pass
x = data[["covariate_1", "covariate_2"]].values
t = data["treatment_group"].values.reshape(-1, 1)
y = data["response"].values.reshape(-1, 1)
# `model`'s third argument is the treatment indicator `t`, not the response `y`.
y_hat = model(params, x, t)
assert y_hat.shape == y.shape, y_hat.shape
y_hat[:5]

In [None]:
# Smoke-test the combined objective at the initial parameters.
combined_loss(params, model, x=x, y=y, t=t)

In [None]:
try:  # newer JAX releases moved the reference optimizers here
    from jax.example_libraries.optimizers import adam
except ImportError:  # older JAX releases
    from jax.experimental.optimizers import adam

N_STEPS = 1000
LOG_EVERY = 100

init, update, get_params = adam(step_size=0.005)

# NOTE: this continues from the current `params`; re-running the cell without
# re-running the initialisation cell resumes training instead of restarting it.
state = init(params)
losses = []
for i in range(N_STEPS):
    g = dloss(params, model, x, y, t)
    state = update(i, g, state)
    params = get_params(state)
    losses.append(float(combined_loss(params, model, x, y, t)))
    # Log periodically instead of printing every step.
    if i % LOG_EVERY == 0 or i == N_STEPS - 1:
        print(f"step {i:4d}  loss {losses[-1]:.4f}")

In [None]:
# Project the learned representations to 2-D with a seeded random Gaussian
# matrix (a quick visual check, not PCA) and colour points by treatment group.
r = reps(params, x)
projection = npr.default_rng(0).normal(size=(r.shape[1], 2))
twodims = np.dot(r, projection)
(
    pd.DataFrame(twodims)
    .rename_columns({0: "x", 1: "y"})
    .add_column("treatment_group", t.ravel())
    .hvplot.scatter(
        x="x",
        y="y",
        c="treatment_group",
        title="Random 2-D projection of learned representations",
    )
)

In [None]:
# Rows of the control (non-treated) group.
data[data["treatment_group"] == 0]

In [None]:
# Select non-treated rows with a boolean mask. `np.take(x, np.where(t == 0, 0, 1), axis=0)`
# used the 0/1 *values* from np.where as row indices, returning copies of rows 0 and 1.
non_treated = x[t.ravel() == 0]
assert len(non_treated) == int((data["treatment_group"] == 0).sum())
non_treated[:5]

In [None]:
# Finding from inspecting `np.where`: with three arguments it returns *values*
# chosen from the second/third arguments, not the indices of matching rows.
# Feeding its output to `np.take(..., axis=0)` therefore only picks rows 0 and 1;
# boolean masks (e.g. `x[t.ravel() == 0]`) select the intended rows.
demo_mask = np.array([True, False, True])
assert (np.where(demo_mask, 0, 1) == np.array([0, 1, 0])).all()