# TODO

* I want to be able to choose the architecture of each node independently. Maybe add a dictionary of architectures with a set default value for every node that is not in the dictionary?

# Introductory Tutorial -- Proxy Equalizer

## Imports

In [None]:
import torch

from mlp import MLP
from sem import SEM
from interventions import Interventions
import utils

In [None]:
from importlib import reload

In [None]:
reload(utils)

## Input the graph for the SEM

First we have to set up a structural equation model.
It consists of a graph and the corresponding equations.
We initialize an `SEM` object by passing in the graph as a dictionary that maps each vertex to the list of its parents (or to `None` for a root vertex). Details of the data structure are in the docstring of the `SEM` class.

We can then draw the graph with `sem.draw()` and print a lot of information about it with `sem.summary()`.
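
From such a parents dictionary, the roots, the non-roots and a topological order (parents before children) can be derived. The following is a hypothetical sketch for illustration only: `graph_info` is not part of the `SEM` class, and how `SEM` computes these internally is not shown here.

```python
def graph_info(graph):
    """Split a parents dictionary into roots, non-roots and a topological order.

    Hypothetical helper: `graph` maps each vertex name to a list of parent
    names, or to None for a root vertex.
    """
    roots = [v for v, parents in graph.items() if not parents]
    non_roots = [v for v, parents in graph.items() if parents]

    # Repeatedly emit all vertices whose parents have already been emitted.
    order = []
    remaining = {v: set(parents or []) for v, parents in graph.items()}
    while remaining:
        ready = [v for v, parents in remaining.items() if parents <= set(order)]
        if not ready:
            raise ValueError("graph contains a cycle or an unknown parent")
        for v in ready:
            order.append(v)
            del remaining[v]
    return roots, non_roots, order


graph = {"Np": None, "A": None, "Nx": None,
         "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]}
roots, non_roots, order = graph_info(graph)
print(roots)      # ['Np', 'A', 'Nx']
print(non_roots)  # ['P', 'X', 'Y']
print(order)      # ['Np', 'A', 'Nx', 'P', 'X', 'Y']
```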

In [None]:
sem = SEM({"Np": None, "A": None, "Nx": None, "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]})
sem.summary()
sem.draw()

## Specify the structural equations

Let us first check the status of the vertices to make sure we attach valid equations.

In [None]:
# All vertices
print("All vertices: ", sem.vertices())
# Root vertices => provide distributions
print("Roots: ", sem.roots())
# Non root vertices => provide equations making use of all parents
print("Non-roots: ", sem.non_roots())

Now we attach structural equations to the vertices with `sem.attach_equation(vertex, callable)`.
For the root vertices, we draw from a standard normal.

For root vertices, the only argument to the callable is an integer `n`, the number of samples to draw. We could also attach a different distribution to each root vertex separately.

**Note**: The `callable` attached to a vertex needs to return a `torch.Tensor`.

In [None]:
for v in sem.roots():
    sem.attach_equation(v, lambda n: torch.randn(n, 1))

For the non-root vertices we attach made-up functions.

For non-root vertices, the only argument to the callable is a dictionary `data` that must have the vertex names as keys. The example shows how the values of the parent vertices are accessed. We construct a fully linear model in which all coefficients are 1.

In [None]:
sem.attach_equation("P", lambda data: 1 * data['Np'] + 1 * data['A'])
sem.attach_equation("X", lambda data: 1 * data['A'] + 1 * data['P'] + 1 * data['Nx'])
sem.attach_equation("Y", lambda data: 1 * data['P'] + 1 * data['X'])

## Sample from the SEM

Now the SEM is fully specified and we can draw samples from it.
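
Conceptually, sampling from a fully specified SEM is ancestral sampling: draw the roots first, then evaluate each non-root equation in topological order so that its parents are already available. A minimal sketch under that assumption: `ancestral_sample` is a hypothetical stand-in for `sem.sample()`, and it uses NumPy arrays instead of `torch` tensors so that it runs without the tutorial's modules.

```python
import numpy as np

rng = np.random.default_rng(0)

def ancestral_sample(graph, order, equations, n):
    """Draw n joint samples from an SEM (hypothetical stand-in for sem.sample).

    graph:     dict vertex -> list of parents (None for roots)
    order:     all vertices in topological order
    equations: dict vertex -> callable; roots take n, non-roots take the data dict
    """
    data = {}
    for v in order:
        if not graph[v]:
            data[v] = equations[v](n)      # root: draw n values from its distribution
        else:
            data[v] = equations[v](data)   # non-root: its parents are already in `data`
    return data

graph = {"Np": None, "A": None, "Nx": None,
         "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]}
order = ["Np", "A", "Nx", "P", "X", "Y"]
# Same linear equations as above, all coefficients equal to 1
equations = {v: (lambda n: rng.standard_normal((n, 1))) for v in ["Np", "A", "Nx"]}
equations["P"] = lambda data: data["Np"] + data["A"]
equations["X"] = lambda data: data["A"] + data["P"] + data["Nx"]
equations["Y"] = lambda data: data["P"] + data["X"]

sample = ancestral_sample(graph, order, equations, 8192)
print({v: x.shape for v, x in sample.items()})
```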

In [None]:
orig_sample = sem.sample(8192)

The `utils` module contains functions for plotting whole samples, where each variable is plotted as a function of its parents.

In [None]:
utils.plot_samples(sem, orig_sample)

## Learn the structural equations from data

While in this example we provided analytical equations for the structural equation model, in reality we only get data. We assume that we guessed the causal graph correctly, but that we do not know the structural equations, and that we have observed samples from the graph. In this example, we use the generated sample as our observed data.

Given the graph and the observed data, we can now try to learn the structural equations. **Note**: This can be done even if we had not attached structural equations to the `SEM` object.

**Arguments**: We pass in our "observed" sample and can specify the number and sizes of hidden layers with `hidden_sizes` as a tuple (default: `()`, i.e. no hidden layers). For a single hidden layer, use `(n,)` instead of `(n)`. Moreover, we can pass a list of vertices to the `binarize` keyword to add a `torch.nn.Sigmoid()` layer at the end when predicting those vertices (default: `[]`). Further, we can pass `epochs` (default: `50`) and `batchsize` (default: `32`) as named arguments.
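
The `(n,)` versus `(n)` distinction is plain Python: parentheses alone do not create a tuple, so `(256)` is just the integer 256. Later examples also pass a dictionary such as `{None: (128, 128)}`, which appears to use the key `None` as the default architecture for all vertices (compare the TODO at the top). The helper below is hypothetical, not part of `SEM`; it sketches how such an argument could be resolved per vertex while rejecting the bare-integer mistake.

```python
def resolve_hidden_sizes(hidden_sizes, vertex):
    """Return the hidden-layer sizes for `vertex` as a tuple (hypothetical helper).

    `hidden_sizes` is either a tuple used for all vertices, or a dict mapping
    vertex names to tuples, where the key None holds the default.
    """
    if isinstance(hidden_sizes, dict):
        hidden_sizes = hidden_sizes.get(vertex, hidden_sizes.get(None, ()))
    if isinstance(hidden_sizes, int):
        raise TypeError(f"hidden sizes must be a tuple, got {hidden_sizes!r}; "
                        f"write ({hidden_sizes},) for a single hidden layer")
    return tuple(hidden_sizes)

print(type((256)), type((256,)))                                  # <class 'int'> <class 'tuple'>
print(resolve_hidden_sizes((), "Y"))                               # ()
print(resolve_hidden_sizes({None: (128, 128), "P": (8,)}, "P"))    # (8,)
print(resolve_hidden_sizes({None: (128, 128), "P": (8,)}, "Y"))    # (128, 128)
```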

In [None]:
sem.learn_from_sample(sample=orig_sample, hidden_sizes=(), binarize=[])

We can look at what networks have been learned.

In [None]:
sem.learned

For smaller networks (especially in the linear case with no hidden layers), it can be insightful to check whether the learned parameters match the actual coefficients in the analytical equations from which the sample was generated. In our simple case all learned weights are (close to) 1, so we have learned the linear equations almost perfectly (unsurprisingly).
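
With `hidden_sizes=()` every learned equation is a single linear layer, so in the linear example learning amounts to regressing each non-root vertex on its parents. As a rough NumPy sketch of this special case, ordinary least squares via `numpy.linalg.lstsq` (rather than the epoch-based training of `learn_from_sample`) recovers weights of 1 up to floating-point error, because the non-root equations contain no extra noise. The function `fit_linear_equations` is hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 8192
Np, A, Nx = (rng.standard_normal((n, 1)) for _ in range(3))
P = Np + A
X = A + P + Nx
Y = P + X
sample = {"Np": Np, "A": A, "Nx": Nx, "P": P, "X": X, "Y": Y}
parents = {"P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]}

def fit_linear_equations(sample, parents):
    """Least-squares fit of every non-root vertex on its parents, with an intercept."""
    fitted = {}
    for v, pa in parents.items():
        # Design matrix: one column per parent plus a constant column for the bias
        design = np.hstack([sample[p] for p in pa] + [np.ones_like(sample[v])])
        coef, *_ = np.linalg.lstsq(design, sample[v], rcond=None)
        fitted[v] = {"weights": coef[:-1, 0], "bias": coef[-1, 0]}
    return fitted

for v, params in fit_linear_equations(sample, parents).items():
    print(v, np.round(params["weights"], 3), round(float(params["bias"]), 3))
```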

In [None]:
sem.print_learned_parameters(weights=True, biases=False)

## Sample from the learned equations

Similarly to how we sampled from the analytical structural equations before, we can now sample from the learned equations.

Note, however, that we did not learn the distributions of the root vertices. Hence we have to provide values for the root vertices, which the `predict_from_sample()` function then passes down to predict the other vertices with our learned functions. Without further arguments, it does not mutate the input, but returns a new sample that has identical values for the root vertices and in which all non-root vertices are replaced by predictions from the learned functions.

**Note**: The `predict_from_sample()` function is more flexible. One can manually choose which vertices to update (`update` argument), whether to mutate the passed-in sample instead of creating a new one (`mutate=True`, in which case the return value is `None`), and a different predictor for specified vertices (`replace={vertex: predictor}`).
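
The behaviour described for `predict_from_sample()` can be pictured with the following hypothetical re-implementation. Only the argument names `update`, `mutate` and `replace` come from the description above; the extra `graph`/`order` arguments, the plain-callable predictors acting on the stacked parent columns, and all internals are assumptions for illustration.

```python
import copy
import numpy as np

def predict_from_sample_sketch(sample, graph, order, learned,
                               update=None, mutate=False, replace=None):
    """Hypothetical re-implementation of the behaviour of predict_from_sample().

    sample:  dict vertex -> array of shape (n, 1)
    graph:   dict vertex -> list of parents (None for roots)
    order:   all vertices in topological order
    learned: dict vertex -> callable mapping the stacked parents (n, k) to (n, 1)
    """
    target = sample if mutate else copy.deepcopy(sample)
    predictors = {**learned, **(replace or {})}      # `replace` overrides learned predictors
    if update is None:
        update = [v for v in order if graph[v]]       # default: every non-root vertex
    elif isinstance(update, str):
        update = [update]
    for v in order:                                   # topological order: parents first
        if v in update:
            parents = np.hstack([target[p] for p in graph[v]])
            target[v] = predictors[v](parents)
    return None if mutate else target

graph = {"A": None, "P": ["A"], "Y": ["P"]}
order = ["A", "P", "Y"]
learned = {"P": lambda x: 2 * x, "Y": lambda x: x + 1}
sample = {"A": np.array([[1.0], [2.0]]), "P": np.zeros((2, 1)), "Y": np.zeros((2, 1))}

new = predict_from_sample_sketch(sample, graph, order, learned)
print(new["Y"].ravel())     # [3. 5.]  (P = 2 * A, then Y = P + 1)
print(sample["Y"].ravel())  # [0. 0.]  (input left untouched)
```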

In [None]:
learned_sample = sem.predict_from_sample(orig_sample)

We can now plot the original sample and the learned sample simultaneously by passing a list of samples to `utils.plot_samples()`.

In [None]:
utils.plot_samples(sem, [orig_sample, learned_sample], legend=['analytic', 'learned'], alpha=0.5)

In the fully linear case, we recover the original sample almost perfectly, i.e. we have learned the structural equations essentially exactly.

## Specify the interventions

This is our self-made format for specifying interventions. It is a dict that stores, for each proxy variable, another dict, which we call `functions`. The keys of `functions` are preset strings that correspond to the `known_functions` in the `Interventions` class. Current options: `'randn'`, `'rand'`, `'const'`, `'range'`, `'bernoulli'`. Every value of `functions` must be a list of tuples (!), where each tuple holds one or more scalar arguments (depending on the key).

**Example:**

This specifies five different interventions on the proxy `'P1'` and four different interventions on the proxy `'P2'`, i.e. a total of `5 * 4 = 20` different intervened samples.

```python
intervention_spec = {
    'P1': {
        'randn': [(0, 3), (0, 3), (0, 5)],
        'const': [(1,), (0,)],
    },
    'P2': {
        'range': [(-1, 1), (-5, 5)],
        'rand': [(-1, 1), (-5, 5)],
    },
}
```
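
The count of 20 comes from taking every combination of one intervention per proxy. A minimal sketch of that bookkeeping with `itertools.product`; the function `expand_spec` is hypothetical, and the real `Interventions` class may order or store the combinations differently.

```python
import itertools

def expand_spec(intervention_spec):
    """List all combinations of one (function, args) pair per proxy."""
    proxies = list(intervention_spec)
    # For each proxy, flatten {'randn': [(0, 3), ...], ...} into [('randn', (0, 3)), ...]
    options = [
        [(name, args) for name, arg_list in intervention_spec[p].items() for args in arg_list]
        for p in proxies
    ]
    return [dict(zip(proxies, combo)) for combo in itertools.product(*options)]

intervention_spec = {
    'P1': {'randn': [(0, 3), (0, 3), (0, 5)], 'const': [(1,), (0,)]},
    'P2': {'range': [(-1, 1), (-5, 5)], 'rand': [(-1, 1), (-5, 5)]},
}
combos = expand_spec(intervention_spec)
print(len(combos))  # 20
print(combos[0])    # {'P1': ('randn', (0, 3)), 'P2': ('range', (-1, 1))}
```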

Note that `Interventions` also takes a sample as an argument. Currently, interventions are performed on an existing sample: first, we compute the intervened graph given the proxies specified in `intervention_spec`. Then we copy the sample `n_interventions` times and fill in the proxy values of each copy with one of the possible combinations of the specified interventions. In the intervened graph, we then update all descendants of the proxies (in topological order), which may also require values of other root vertices. This is why we already provide a sample.
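
For a single proxy, the procedure above (copy the base sample once per intervention, overwrite the proxy, then recompute only the proxy's descendants while reusing the stored values of all other vertices) might look roughly like this. The helper names and the use of plain NumPy functions as structural equations are assumptions for illustration.

```python
import copy
import numpy as np

def descendants(graph, source):
    """All vertices reachable from `source` along parent -> child edges."""
    children = {v: [c for c, pa in graph.items() if pa and v in pa] for v in graph}
    found, stack = set(), [source]
    while stack:
        for c in children[stack.pop()]:
            if c not in found:
                found.add(c)
                stack.append(c)
    return found

def intervene(base, graph, order, equations, proxy, values_list):
    """Return one intervened copy of `base` per array in `values_list`."""
    desc = descendants(graph, proxy)
    samples = []
    for values in values_list:
        s = copy.deepcopy(base)
        s[proxy] = values                 # do(proxy = values): no longer depends on its parents
        for v in order:                   # topological order: parents are updated first
            if v in desc:
                s[v] = equations[v](s)
        samples.append(s)
    return samples

graph = {"Np": None, "A": None, "Nx": None,
         "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]}
order = ["Np", "A", "Nx", "P", "X", "Y"]
equations = {"X": lambda d: d["A"] + d["P"] + d["Nx"], "Y": lambda d: d["P"] + d["X"]}

rng = np.random.default_rng(0)
n = 1000
base = {v: rng.standard_normal((n, 1)) for v in ["Np", "A", "Nx"]}
base["P"] = base["Np"] + base["A"]
base["X"] = equations["X"](base)
base["Y"] = equations["Y"](base)

values_list = [3 * rng.standard_normal((n, 1)) for _ in range(3)]
intervened = intervene(base, graph, order, equations, "P", values_list)
```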

Strictly, this corresponds to neither counterfactuals nor interventions. As always, there is no "right" way to do this, but I would be happy to hear your opinions on the following options:

1. Always use one single sample for the other root vertices in the intervened graph:
    - a. Use the same original sample that was used to learn the equations.
    - b. Draw a new "base sample" for the retraining part.
2. For each intervened sample, draw the other root vertices in the intervened graph anew.

Consider also:

* In reality, we do not observe a full sample of the graph (root vertices are not observed).
* Can we make assumptions about the distribution of the root vertices in real life, e.g. that they are Gaussian? If so, how do we find the root vertex values belonging to one specific observation? (If we see P, X, Y, how do we find the corresponding Nx, A, Np?) While the distributions are enough to sample new values, the specific corresponding values are needed to learn the equations in the first step.

For the linear example, we use three `'randn'` interventions on `'P'`, all with the arguments `(0, 3)`.

In [None]:
intervention_spec = {
    'P': {
         'randn': [(0, 3), (0, 3), (0, 3)],
         },
    }
interventions = Interventions(sem, orig_sample, intervention_spec)
interventions.summary()

## Train a corrected version

Finally, we can retrain part of the target network, in this case the network for `'Y'`, to minimize the variance of its predictions across all intervened samples. Note that this seems to make sense only for identical values of the root vertices (closer to a counterfactual?), because why would we want similar `'Y'` values for completely different starting values? On the other hand, we want invariance to hold in distribution, so with a large batch size we could also try to enforce the criterion with different root-vertex values in each intervened sample.
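
The variance criterion itself is easy to write down for the "same root values" reading. A NumPy sketch, assuming the predictions for `'Y'` are stacked into an array of shape `(n_interventions, n_samples)`, where all rows share the same root values; whether `train_corrected` adds further loss terms is not covered here.

```python
import numpy as np

def intervention_variance(predictions):
    """Mean over data points of the variance of the predictions across interventions.

    predictions: array of shape (n_interventions, n_samples); row k holds the
    predictions on the k-th intervened copy of the same base sample.
    """
    return predictions.var(axis=0).mean()

rng = np.random.default_rng(0)
n = 5000
A, Nx = rng.standard_normal((2, n))      # roots shared by all intervened copies
P_do = 3 * rng.standard_normal((4, n))   # four interventions on P
X_do = A + P_do + Nx                     # X recomputed under each intervention
Y_original = P_do + X_do                 # learned predictor with weights (1, 1)
Y_corrected = -P_do + X_do               # weights (-1, 1) cancel the effect of P
print(intervention_variance(Y_original))   # clearly positive
print(intervention_variance(Y_corrected))  # zero up to floating-point error
```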

In [None]:
corrected = interventions.train_corrected(epochs=100, batchsize=64, biases=False)

## Evaluate the corrected model

### Small linear models: check parameters directly

For this small linear network we can look directly at the learned parameters. Indeed, originally all weights are 1, whereas the corrected version has a weight of -1 for `'P'` instead, exactly what theory demands.
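
The value -1 follows from a short calculation. Write the corrected linear predictor as $\hat Y = w_P P + w_X X$ (bias ignored) and substitute the structural equation $X = A + P + N_x$. Under an intervention that sets `'P'` to $p$ while the roots keep their values,

$$
\hat Y = w_P\,p + w_X\,(A + p + N_x) = (w_P + w_X)\,p + w_X\,(A + N_x),
$$

so the prediction is invariant across interventions on `'P'` exactly when $w_P = -w_X$. With the weight on `'X'` kept at 1, this gives $w_P = -1$.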

In [None]:
from pprint import pprint
print("Original weights:")
sem.print_learned_parameters(show=['Y'], weights=True, biases=False)

print("")
print("Fair parameters:")
for name, param in corrected.named_parameters():
    if 'bias' not in name:
        print(param.data.numpy())

### Comparison on a new sample

Let's look at the full sample plots we have already encountered above for a new sample, its learned reproduction and the corrected results.

In [None]:
base, orig, fair = utils.evaluate_on_new_sample(sem, 'Y', corrected, plot=True)

As we have already seen, the learned model recovers the original sample from the analytical structural equation model almost perfectly. The fair sample coincides with it except for the target value `'Y'`, because we did not touch any other part. The dependence of `'Y'` on both `'P'` and `'X'` has been reduced, but is **not** zero (see the next section for an explanation).

### Evaluation tools for linear prediction

In the linear case, we can also look at (print and plot) all sorts of correlations, i.e. the slopes, r-values (Pearson Correlation Coefficient), p-values and standard errors of these tests.

We see that the correlation between `'Yfair'` and `'P'` is lower than that between `'Y'` and `'P'`, but it is **not** zero. Some correlation between `'Yfair'` and `'P'` remains through the confounder `'A'`. This is the main difference from all "learning fair representation" approaches so far.
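
The remaining correlation can be computed by hand. With independent standard normal roots and the corrected predictor $\hat Y_{fair} = X - P = A + N_x$, we get $\mathrm{Cov}(\hat Y_{fair}, P) = \mathrm{Cov}(A + N_x, N_p + A) = \mathrm{Var}(A) = 1$, while both variables have variance 2, so the Pearson correlation is $1/2$. The sketch below checks this numerically with `scipy.stats.linregress`, which reports slope, r-value, p-value and standard error; that `utils.print_correlations` computes something similar internally is an assumption.

```python
import numpy as np
from scipy.stats import linregress

rng = np.random.default_rng(0)
n = 100_000
Np, A, Nx = rng.standard_normal((3, n))
P = Np + A
X = A + P + Nx
Y = P + X          # original predictor (all weights 1)
Y_fair = -P + X    # corrected predictor (weight -1 on P), equals A + Nx

for name, target in [("Y", Y), ("Yfair", Y_fair)]:
    res = linregress(P, target)
    print(f"P -> {name}: slope={res.slope:.3f}, r={res.rvalue:.3f}, "
          f"p={res.pvalue:.2g}, stderr={res.stderr:.2g}")
```

The r-value for `Yfair` should come out close to 0.5, not 0, which is the confounding path through `A`.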

In [None]:
utils.print_correlations(orig, sem=sem, sources=['A', 'P', 'X'], targets=['Y', 'Yfair'])

In [None]:
all_vars = sem.vertices() + ['Yfair']
utils.plot_correlations(orig, sem=sem, sources=all_vars, targets=all_vars)

## Quick run through a binarized example

Now we go through the whole workflow from specifying a graph to the final evaluation (without unnecessary intermediate steps), where we binarize the value of `'P'`.

In [None]:
# Construct the graph
sem = SEM({"Np": None, "A": None, "Nx": None, "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]})

# Attach equations
for v in sem.roots():
    sem.attach_equation(v, lambda n: torch.randn(n, 1))
sem.attach_equation("P", lambda data: (1 * data['Np'] + 1 * data['A'] > 0.0).float())
sem.attach_equation("X", lambda data: 1 * data['A'] + 5 * data['P'] + 1 * data['Nx'])
sem.attach_equation("Y", lambda data: 1 * data['P'] + 1 * data['X'])

# Learn the equations (internally computes sample) and binarize the proxy P
orig_sample = sem.learn_from_sample(hidden_sizes=(), epochs=50, binarize=['P'])
learned_sample = sem.predict_from_sample(orig_sample)

# Specify interventions, this time 4 bernoulli interventions with p=1/2
intervention_spec = {'P': {'bernoulli': [(0.5,), (0.5,), (0.5,), (0.5,)]}}
interventions = Interventions(sem, orig_sample, intervention_spec)

# Remove proxy discrimination
corrected = interventions.train_corrected(epochs=100, batchsize=64, biases=False)
                    
# Evaluate on new sample
base, orig, fair = utils.evaluate_on_new_sample(sem, 'Y', corrected, plot=True)

We can see that in this case the predictions `Y` for `P = 1` have been pushed down drastically to match the predictions for `P = 0` more closely. Since `X` is dominated by `P` (coefficient of `5` compared to `1` for `A` and `Nx`), the distribution of `Y` conditioned on `X` has also changed drastically. The remaining discrepancy between the outcomes for `P = 0` and `P = 1` can be explained by the common confounding through `A`. However, I have not quantified that yet.
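
The size of the shift follows from the same calculation as in the linear case. With $X = A + 5P + N_x$ and a linear predictor $\hat Y = w_P P + w_X X$ (bias ignored), an intervention that sets `'P'` to $p$ gives

$$
\hat Y = w_P\,p + w_X\,(A + 5p + N_x) = (w_P + 5\,w_X)\,p + w_X\,(A + N_x),
$$

so exact invariance under interventions on `'P'` requires $w_P = -5\,w_X$. If the weight on `'X'` stayed at 1, the corrected prediction would be $X - 5P = A + N_x$, i.e. for the same `X`, predictions with $P = 1$ would drop by 6 relative to the original $P + X$. This is an idealized calculation; the trained network need not reach it exactly.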

## Polynomials

In [None]:
# Parameters
n_sample = 8192

# Construct the graph
sem = SEM({"Np": None, "A": None, "Nx": None, "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]})

# Attach equations
for v in sem.roots():
    sem.attach_equation(v, lambda n: torch.randn(n, 1) * 0.1)
sem.attach_equation("P", lambda data: - 0.5 * data['Np'] + \
                                      0.5 * data['Np']**2 + \
                                      1.0 * data['A'] - \
                                      1.7 * data['A']**2 + \
                                      0.4 * data['A']**3)
sem.attach_equation("X", lambda data: 1.0 * data['A'] + \
                                      0.3 * data['A']**3 - \
                                      0.8 * data['P']**3 - \
                                      0.6 * data['Nx'] + \
                                      0.7 * data['Nx']**3)
sem.attach_equation("Y", lambda data: 0.2 * data['P'] - \
                                      0.6 * data['P']**2 + \
                                      0.5 * data['X'])

# Learn the equations from a sample of size n_sample (computed internally)
hidden_sizes = {None: (128, 128)}
orig_sample = sem.learn_from_sample(hidden_sizes=hidden_sizes, epochs=25, sample=n_sample, batchsize=64, weight_decay=0.0001)
learned_sample = sem.predict_from_sample(orig_sample)

# Specify interventions, this time four 'rand' interventions with arguments (-1, 1)
intervention_spec = {'P': {'rand': [(-1,1), (-1,1), (-1,1), (-1,1)]}}
interventions = Interventions(sem, orig_sample, intervention_spec)

# Remove proxy discrimination
corrected = interventions.train_corrected(epochs=50, batchsize=64, biases=False, weight_decay=0.0001)
                    
# Evaluate on new sample
n_sample_test = 1024
base, orig, fair = utils.evaluate_on_new_sample(sem, 'Y', corrected, plot=True, n_sample=n_sample_test)

In [None]:
# Construct the same graph
sem_fair = SEM({"Np": None, "A": None, "Nx": None, "P": ["Np", "A"], "X": ["A", "P", "Nx"], "Y": ["P", "X"]})

# Attach fair equations
for v in sem_fair.roots():
    sem_fair.attach_equation(v, lambda n: torch.randn(n, 1) * 0.1)
sem_fair.attach_equation("P", lambda data: - 0.5 * data['Np'] + \
                                           0.5 * data['Np']**2 + \
                                           1.0 * data['A'] - \
                                           1.7 * data['A']**2 + \
                                           0.4 * data['A']**3)
sem_fair.attach_equation("X", lambda data: 1.0 * data['A'] + \
                                           0.3 * data['A']**3 - \
                                           0.8 * data['P']**3 - \
                                           0.6 * data['Nx'] + \
                                           0.7 * data['Nx']**3)
sem_fair.attach_equation("Y", lambda data: 0.5 * data['X'] + \
                                           0.4 * data['P']**3)

# Draw a sample from the true fair SEM and predict 'Y' on it with the corrected network
true_fair = sem_fair.sample(n_sample_test)

pred_fair = sem.predict_from_sample(true_fair, update='Y', replace={'Y': corrected})

utils.plot_samples(sem, [base, true_fair, pred_fair], legend=['base', 'true_fair', 'pred_fair'], alpha=0.5)


### Evaluate in 3D Plots?

In [None]:
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np

In [None]:
s1 = true_fair
s2 = pred_fair
y1 = s1['Y'].numpy().squeeze()
y2 = s2['Y'].numpy().squeeze()
p1 = s1['P'].numpy().squeeze()
p2 = s2['P'].numpy().squeeze()
x1 = s1['X'].numpy().squeeze()
x2 = s2['X'].numpy().squeeze()
dat1 = np.hstack((p1[:, np.newaxis], x1[:, np.newaxis], y1[:, np.newaxis]))
dat2 = np.hstack((p2[:, np.newaxis], x2[:, np.newaxis], y2[:, np.newaxis]))

In [None]:
fig = plt.figure(figsize=(15,15))
ax = fig.add_subplot(projection='3d')
ax.scatter3D(p1, x1, y1, alpha=0.3)
ax.scatter3D(p2, x2, y2, alpha=0.3)

# Experiments

## Adult data set

In [None]:
import pandas as pd
import numpy as np

In [None]:
def import_adult(to_drop=['fnlwgt', 'education']):
    """Import the relevant parts of the adult data set.

    We drop 'fnlwgt', because it is externally computed and highly predictive of the income.
    We drop 'education', because it is redundantly encoded in 'education-num'.
    """
    base = '/Users/nkilbertus/ownCloud/PhD/projects/others_fairness/fairtest/data/'
    filename = base + 'adult/adult.csv'

    print("Importing adult data set...")
    # Raw string, so that the regex separator is passed through unchanged
    data = pd.read_csv(filename, engine='python', delimiter=r"\s*,\s*", na_values=['?', ' ?', '? ', ' ? '])

    if to_drop:
        data = data.drop(columns=list(to_drop))

    print("Data set size: ", len(data))
    print("Dropped: ", to_drop)
    print("Remove nan values...", end=' ')
    data = data.dropna()
    print("DONE")
    print("New data set size: ", len(data))
    print("DONE import adult data set")
    return data

def preprocess(data, features):
    """Pool rare categories, binarize two-valued columns and one-hot encode the rest."""
    data = data.copy()  # work on a copy instead of modifying the caller's DataFrame
    to_one_hot = []
    for v in features:
        values = set(data[v])
        if all(isinstance(x, str) for x in values):
            print('{}: string, categorical'.format(v))
            data[v] = pool_small_groups(data[v])
            data[v], binarized = try_binarize(data[v])
            if not binarized:
                to_one_hot.append(v)
        elif all(int(x) == x for x in values):
            print('{}: int, categorical'.format(v))
        elif all(float(x) == x for x in values):
            print('{}: float'.format(v))
        else:
            raise ValueError('Could not determine type of column {}'.format(v))
        print()
    print("Encode the following columns to one-hot:")
    print(to_one_hot)
    return pd.get_dummies(data, prefix=to_one_hot, columns=to_one_hot)

def pool_small_groups(data, min_examples=100):
    """Return a copy of the Series in which all values with fewer than
    `min_examples` occurrences are merged into a single '<name>_rest' value."""
    counts = data.value_counts()
    too_small = [k for (k, v) in counts.items() if v < min_examples]

    if len(too_small) > 1:
        name = data.name + '_rest'
        print("The following values contain fewer than {} examples and are combined into {}:".format(min_examples, name))
        print(too_small)
        data = data.replace(too_small, name)
    return data

def try_binarize(data):
    """Map a Series with at most two distinct values to 0/1.

    Returns the (possibly) mapped Series and whether it was binarized."""
    values = sorted(set(data.values))  # sorted, so that the mapping is deterministic
    if len(values) <= 2:
        print("Binarize values for {} to {}".format(data.name, list(range(len(values)))))
        return data.map({x: i for i, x in enumerate(values)}), True
    return data, False

def get_sample(data, features):
    """Collect the (possibly one-hot encoded) columns of each feature in a float tensor."""
    import torch
    sample = {}
    for v in features:
        columns = data.columns[data.columns.str.startswith(v)]
        print("{} is {}-dimensional".format(v, len(columns)))
        # to_numpy(dtype=...) also handles the boolean one-hot columns from get_dummies
        sample[v] = torch.from_numpy(data[columns].to_numpy(dtype=np.float32))
    return sample

In [None]:
data = import_adult()
features = list(data.columns)
protected = ['race']
use = list(set(features) - set(protected))
print("All features: ", features)

In [None]:
d = preprocess(data, features)

In [None]:
sample = get_sample(d, features)

In [None]:
binarize = ['Workclass', 'marital-status', 'occupation', 'relationship', 'race', 'sex', 'native-country', 'income']

g = {
    'age': None,
    'native-country': None,
    'sex': None,
    'race': ['native-country'],
    'relationship': ['sex', 'education', 'age'],
    'marital-status': ['relationship', 'sex', 'age', 'native-country', 'race'],
    'hours-per-week': ['education', 'marital-status', 'relationship', 'native-country', 'race', 'occupation', 'Workclass'],
    'education': ['native-country', 'race', 'sex', 'age'],
    'occupation': ['education', 'age', 'native-country', 'race', 'marital-status', 'relationship', 'sex', 'Workclass'],
    'Workclass': ['education', 'age', 'native-country', 'race', 'marital-status', 'relationship', 'sex'],
    'income': ['education', 'age', 'native-country', 'race', 'marital-status', 'sex', 'Workclass', 'occupation', 'hours-per-week']
}

In [None]:
sem = SEM(g)
sem.summary()
sem.draw()

In [None]:
# Learn the equations from the adult sample and binarize the vertices listed in `binarize`
hidden_sizes = {None: (256,)}
sem.learn_from_sample(sample=sample, hidden_sizes=hidden_sizes, epochs=10, batchsize=256, weight_decay=0.0001, binarize=binarize)
learned_sample = sem.predict_from_sample(sample)

In [None]:
# Specify interventions: four 'rand' interventions with arguments (-1, 1)
# NOTE: 'P' is carried over from the toy example and is not a vertex of the adult graph;
#       replace it with the proxy to intervene on.
intervention_spec = {'P': {'rand': [(-1,1), (-1,1), (-1,1), (-1,1)]}}
interventions = Interventions(sem, sample, intervention_spec)

# Remove proxy discrimination
corrected = interventions.train_corrected(epochs=50, batchsize=64, biases=False, weight_decay=0.0001)

# Evaluate on new sample (the target of the adult graph is 'income')
n_sample_test = 1024
base, orig, fair = utils.evaluate_on_new_sample(sem, 'income', corrected, plot=True, n_sample=n_sample_test)

## MISC

### Evaluation tools for binary target

In [None]:
import copy
from sklearn.metrics import confusion_matrix

In [None]:
s1 = copy.deepcopy(orig)
s2 = copy.deepcopy(fair)
s1['Y'] = (s1['Y'] > 0.5).float()
s2['Y'] = (s2['Y'] > 0.5).float()

In [None]:
utils.plot_samples(sem, [base, s1, s2], legend=['analytical', 'learned', 'fair'], alpha=0.3)

In [None]:
confusion_matrix(s1['Y'].int().numpy(), s2['Y'].int().numpy())

### Maximum mean discrepancy (MMD)
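
The functions below compute the unbiased estimate of the squared maximum mean discrepancy between samples $x_1, \dots, x_m$ and $y_1, \dots, y_n$ for a kernel $k$, here the RBF kernel $k(a, b) = \exp(-\gamma \lVert a - b \rVert^2)$ used by `rbf`:

$$
\widehat{\mathrm{MMD}}^2_u = \frac{1}{m(m-1)} \sum_{i \neq j} k(x_i, x_j) + \frac{1}{n(n-1)} \sum_{i \neq j} k(y_i, y_j) - \frac{2}{mn} \sum_{i=1}^{m} \sum_{j=1}^{n} k(x_i, y_j).
$$

`kernel_two_sample_test` then compares this statistic with its distribution under random permutations of the pooled sample to obtain a p-value.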

In [None]:
# Unbiased MMD^2 estimate with an RBF kernel
from sklearn.preprocessing import scale
from scipy.spatial.distance import cdist

def median_heuristic(x):
    d = cdist(x, x, 'sqeuclidean')
    return 1. / np.median(d.ravel())

def rbf(x1, x2, gamma=1):
    d = cdist(x1, x2, 'sqeuclidean')
    return np.exp(-gamma * d)

def mmd(x, y, gamma=1):
    Kxx = rbf(x, x, gamma)
    Kyy = rbf(y, y, gamma)
    Kxy = rbf(x, y, gamma)
    m, n = Kxx.shape[0], Kyy.shape[0]
    t1  = (1. / (m*(m-1))) * np.sum(Kxx - np.diag(np.diagonal(Kxx)))
    t2  = (2. / (m*n)) * np.sum(Kxy)
    t3  = (1. / (n*(n-1))) * np.sum(Kyy - np.diag(np.diagonal(Kyy)))
    return t1 - t2 + t3

In [None]:
mmd(dat1, dat2)

In [None]:
from sklearn.metrics import pairwise_kernels
from sys import stdout


def MMD2u(K, m, n):
    """The MMD^2_u unbiased statistic.
    """
    Kx = K[:m, :m]
    Ky = K[m:, m:]
    Kxy = K[:m, m:]
    return 1.0 / (m * (m - 1.0)) * (Kx.sum() - Kx.diagonal().sum()) + \
        1.0 / (n * (n - 1.0)) * (Ky.sum() - Ky.diagonal().sum()) - \
        2.0 / (m * n) * Kxy.sum()


def compute_null_distribution(K, m, n, iterations=10000, verbose=False,
                              random_state=None, marker_interval=1000):
    """Compute the bootstrap null-distribution of MMD2u.
    """
    if isinstance(random_state, np.random.RandomState):
        rng = random_state
    else:
        rng = np.random.RandomState(random_state)

    mmd2u_null = np.zeros(iterations)
    for i in range(iterations):
        if verbose and (i % marker_interval) == 0:
            print(i, end=' ')
            stdout.flush()
        idx = rng.permutation(m+n)
        K_i = K[idx, idx[:, None]]
        mmd2u_null[i] = MMD2u(K_i, m, n)

    if verbose:
        print("")

    return mmd2u_null


def compute_null_distribution_given_permutations(K, m, n, permutation,
                                                 iterations=None):
    """Compute the bootstrap null-distribution of MMD2u given
    predefined permutations.
    Note:: verbosity is removed to improve speed.
    """
    if iterations is None:
        iterations = len(permutation)

    mmd2u_null = np.zeros(iterations)
    for i in range(iterations):
        idx = permutation[i]
        K_i = K[idx, idx[:, None]]
        mmd2u_null[i] = MMD2u(K_i, m, n)

    return mmd2u_null


def kernel_two_sample_test(X, Y, kernel_function='rbf', iterations=10000,
                           verbose=False, random_state=None, **kwargs):
    """Compute MMD^2_u, its null distribution and the p-value of the
    kernel two-sample test.
    Note that extra parameters captured by **kwargs will be passed to
    pairwise_kernels() as kernel parameters. E.g. if
    kernel_two_sample_test(..., kernel_function='rbf', gamma=0.1),
    then this will result in getting the kernel through
    kernel_function(metric='rbf', gamma=0.1).
    """
    m = len(X)
    n = len(Y)
    XY = np.vstack([X, Y])
    K = pairwise_kernels(XY, metric=kernel_function, **kwargs)
    mmd2u = MMD2u(K, m, n)
    if verbose:
        print("MMD^2_u = %s" % mmd2u)
        print("Computing the null distribution.")

    mmd2u_null = compute_null_distribution(K, m, n, iterations,
                                           verbose=verbose,
                                           random_state=random_state)
    p_value = max(1.0/iterations, (mmd2u_null > mmd2u).sum() /
                  float(iterations))
    if verbose:
        print("p-value ~= %s \t (resolution : %s)" % (p_value, 1.0/iterations))

    return mmd2u, mmd2u_null, p_value


In [None]:
kernel_two_sample_test(dat1, dat2, verbose=True, iterations=1000)