In [37]:
import torch
from MMDLearning.training.losses import MMDLoss, LinearizedMMDLoss, RBF

In [38]:
num_batches = 10
batch_size = 128

#dim of latent space
n_latent = 32

# params of X and Y distributions
mu_x = 0.0
sigma_x = 1.0

mu_y = 0.2
sigma_y = 1.0

# kernel parameters
n_kernels = 1
kernel_mul = 2.0


In [39]:
X = (torch.randn(num_batches, batch_size, n_latent) * sigma_x + mu_x).detach().requires_grad_()
Y = (torch.randn(num_batches, batch_size, n_latent) * sigma_y + mu_y).detach().requires_grad_()
print('is a leaf:', X.is_leaf, Y.is_leaf)

is a leaf: True True


In [40]:
kernel = RBF(mul_factor=kernel_mul, n_kernels=n_kernels)
mmd_orig = MMDLoss(kernel=kernel)

In [41]:
#number of fourier features for linearized MMD
n_feat = 10000

mmd_new = LinearizedMMDLoss(n_feat=n_feat, n_latent=n_latent, n_kernels=n_kernels, mul_factor=kernel_mul)

In [42]:
res_orig = [mmd_orig(X[i], Y[i]) for i in range(num_batches)]
res_new = [mmd_new(X[i], Y[i]) for i in range(num_batches)]
frac_diff = [((r1-r2)/r1) for r1, r2 in zip(res_orig, res_new)]


In [43]:
print("Original MMD results:")
print(res_orig)
print("Linearized MMD results:")
print(res_new)
print("fractional difference:")
print(frac_diff)

Original MMD results:
[tensor(0.0090, grad_fn=<AddBackward0>), tensor(0.0147, grad_fn=<AddBackward0>), tensor(0.0162, grad_fn=<AddBackward0>), tensor(0.0120, grad_fn=<AddBackward0>), tensor(0.0102, grad_fn=<AddBackward0>), tensor(0.0165, grad_fn=<AddBackward0>), tensor(0.0211, grad_fn=<AddBackward0>), tensor(0.0136, grad_fn=<AddBackward0>), tensor(0.0171, grad_fn=<AddBackward0>), tensor(0.0103, grad_fn=<AddBackward0>)]
Linearized MMD results:
[tensor(0.0090, grad_fn=<SumBackward0>), tensor(0.0146, grad_fn=<SumBackward0>), tensor(0.0162, grad_fn=<SumBackward0>), tensor(0.0119, grad_fn=<SumBackward0>), tensor(0.0103, grad_fn=<SumBackward0>), tensor(0.0166, grad_fn=<SumBackward0>), tensor(0.0211, grad_fn=<SumBackward0>), tensor(0.0137, grad_fn=<SumBackward0>), tensor(0.0172, grad_fn=<SumBackward0>), tensor(0.0104, grad_fn=<SumBackward0>)]
fractional difference:
[tensor(-0.0075, grad_fn=<DivBackward0>), tensor(0.0059, grad_fn=<DivBackward0>), tensor(-0.0024, grad_fn=<DivBackward0>), tensor

In [44]:
res_tot_orig = sum(res_orig)
res_tot_new = sum(res_new)
X.grad = None
Y.grad = None
res_tot_orig.backward(retain_graph=True)
X_grad_orig = X.grad.clone()
Y_grad_orig = Y.grad.clone()
X.grad = None
Y.grad = None
res_tot_new.backward()
X_grad_new = X.grad.clone()
Y_grad_new = Y.grad.clone()


In [45]:
X_max_diff = torch.max(torch.abs(X_grad_orig - X_grad_new))
X_grad_diff_std = torch.std(X_grad_orig-X_grad_new)
X_grad_std = torch.std(X_grad_orig)
print('max diff X:', X_max_diff)
print('std diff X:', X_grad_diff_std)
print('std X:', X_grad_std)

Y_max_diff = torch.max(torch.abs(Y_grad_orig - Y_grad_new))
Y_grad_diff_std = torch.std(Y_grad_orig-Y_grad_new)
Y_grad_std = torch.std(Y_grad_orig)
print('max diff Y:', Y_max_diff)
print('std diff Y:', Y_grad_diff_std)
print('std Y:', Y_grad_std)

max diff X: tensor(1.2518e-05)
std diff X: tensor(3.0542e-06)
std X: tensor(2.4192e-05)
max diff Y: tensor(1.3312e-05)
std diff Y: tensor(3.0473e-06)
std Y: tensor(2.4832e-05)


In [36]:
print(torch.mean(X_grad_orig/X_grad_new))

tensor(0.9992)
