### Tests for MMD (Maximum Mean Discrepancy)

In [8]:
import sys
import torch
import matplotlib.pyplot as plt  # not used in this cell; kept since later cells may use `plt`
sys.path.append('..')  # make the project-local `utils` module (one directory up) importable
import utils

# Sample sizes and feature dimension of the synthetic data.
n_x = 1000
n_y = 1000
d = 128
min_n = min(n_x, n_y)  # not used in this cell; kept for later cells sharing the namespace

# Fixed seed so the printed MMD^2 values and the assertions below are reproducible.
torch.manual_seed(1)

# X1 and Y1 are independent samples from the same distribution (standard normal).
X1 = torch.randn(n_x, d)
Y1 = torch.randn(n_y, d)

# --- Rational-quadratic (RQ) kernel ---
k = lambda x, y: utils.rq_kernel(x, y, alpha=1)

# The cross Gram matrix between the two samples must be n_x x n_y.
K = k(X1, Y1)
assert K.size() == (n_x, n_y), (
    f'RQ kernel returned shape {tuple(K.size())}, expected {(n_x, n_y)}'
)

# Check all entries are nonnegative.
assert (K >= 0).all(), 'RQ kernel produced negative entries'

# utils.mmd2 presumably returns a scalar tensor (`.item()` is called on it).
mmd1 = utils.mmd2(X1, Y1, kernel=k)
print(f'Same Distributions, RQ Kernel: MMD^2: {mmd1.item():.8f}')

# Different distributions: Y2 has std scaled by 1.1 and mean shifted by 0.1 relative to X2.
X2 = torch.randn(n_x, d)
Y2 = 1.1*torch.randn(n_y, d) + 0.1
mmd2 = utils.mmd2(X2, Y2, kernel=k)

print(f'Diff Distributions, RQ Kernel: MMD^2: {mmd2.item():.8f}')

# The discrepancy between differing distributions should exceed the same-distribution one.
assert mmd2 > mmd1, (
    f'RQ kernel: diff-dist MMD^2 {mmd2.item():.8f} <= same-dist MMD^2 {mmd1.item():.8f}'
)

# --- Mixture of RQ kernels over several alpha values ---
alphas = [0.1, 0.2, 0.5, 1.0]
k = lambda x, y: utils.mixture_rq_kernel(x, y, alphas=alphas)

# Apply the same Gram-matrix shape check used for the single RQ kernel.
K_mix = k(X1, Y1)
assert K_mix.size() == (n_x, n_y), (
    f'Mixture kernel returned shape {tuple(K_mix.size())}, expected {(n_x, n_y)}'
)

mmd3 = utils.mmd2(X1, Y1, kernel=k)
print(f'Same Distributions, Mixed Kernel: MMD^2: {mmd3.item():.8f}')

mmd4 = utils.mmd2(X2, Y2, kernel=k)
print(f'Diff Distributions, Mixed Kernel: MMD^2: {mmd4.item():.8f}')

assert mmd4 > mmd3, (
    f'Mixture kernel: diff-dist MMD^2 {mmd4.item():.8f} <= same-dist MMD^2 {mmd3.item():.8f}'
)

Same Distributions, RQ Kernel: MMD^2: 0.00000234
Diff Distributions, RQ Kernel: MMD^2: 0.00019234
Same Distributions, Mixed Kernel: MMD^2: 0.00004208
Diff Distributions, Mixed Kernel: MMD^2: 0.00285149
