# Sample notebook of Permutation test with MMD

In [1]:
import sys
sys.path.append("../")
sys.path.append(".")
import numpy as np
import torch
from model_criticism_mmd import MMD
from model_criticism_mmd.backends.kernels_torch import BasicRBFKernelFunction
from model_criticism_mmd.supports.permutation_tests import PermutationTest
from model_criticism_mmd.models import TwoSampleDataSet
from model_criticism_mmd import ModelTrainerTorchBackend, MMD, TwoSampleDataSet, split_data
from model_criticism_mmd.backends import kernels_torch



Next, we set dataset. The input type into the Permutation class is `TwoSampleDataSet`.
We can set either `numpy.ndarray` or `torch.tensor`.

In [2]:
np.random.seed(seed=1)
x = np.random.normal(3, 0.5, size=(500, 2))
y = np.random.normal(3, 0.5, size=(500, 2))

Then, we run the Permutation test.

In [7]:
init_scale = torch.tensor(np.array([0.05, 0.55]))
device_obj = torch.device(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
kernel_function = BasicRBFKernelFunction(log_sigma=0.0, device_obj=device_obj, opt_sigma=True)
mmd_estimator = MMD(kernel_function_obj=kernel_function, device_obj=device_obj, scales=init_scale)
dataset_train = TwoSampleDataSet(x, y, device_obj)

permutation_tester = PermutationTest(is_normalize=True, mmd_estimator=mmd_estimator, dataset=dataset_train)
statistics = permutation_tester.compute_statistic()
threshold = permutation_tester.compute_threshold(alpha=0.05)
p_value = permutation_tester.compute_p_value(statistics)
print(f'statistics: {statistics}, threshold: {threshold}, p-value: {p_value}')
if p_value > 0.05:
    print('Same distribution!')
else:
    print('Probably different distribution!')

100%|██████████| 1000/1000 [00:10<00:00, 91.05it/s]

statistics: 0.005706118359427581, threshold: 0.26779616255229177, p-value: 0.894
Same distribution!





## Permutation test with optimized kernels

To run the permutation test, we have to define a MMD estimator who has a designed kernel function.

In normal cases, we search the optimal kernel on the given datset (i.e. trainings, optimizations).

In [5]:
n_train = 400
x_train = x[:n_train]
y_train = y[:n_train]
x_test = x[n_train:]
y_test = y[n_train:]
dataset_val = TwoSampleDataSet(x_test, y_test, device_obj=device_obj)

In [6]:
init_scale = torch.tensor(np.array([0.05, 0.55]))
kernel_function = BasicRBFKernelFunction(log_sigma=0.0, device_obj=device_obj, opt_sigma=True)
mmd_estimator = MMD(kernel_function_obj=kernel_function, device_obj=device_obj, scales=init_scale)
trainer = ModelTrainerTorchBackend(mmd_estimator=mmd_estimator, device_obj=device_obj)
trained_obj = trainer.train(dataset_training=dataset_train, dataset_validation=dataset_val, num_epochs=500, batchsize=200)

2021-08-05 14:43:14,024 - model_criticism_mmd.logger_unit - INFO - Validation at 0. MMD^2 = 0.006782989967585085, ratio = [1.33677] obj = [-0.29025626]
2021-08-05 14:43:14,386 - model_criticism_mmd.logger_unit - INFO -      5: [avg train] MMD^2 0.001336887696456938 obj [-2.38391235] val-MMD^2 0.0022629932290427757 val-ratio [22.62993229] val-obj [-3.11927347]  elapsed: 0.0
2021-08-05 14:43:15,724 - model_criticism_mmd.logger_unit - INFO -     25: [avg train] MMD^2 0.013374156857882182 obj [-4.24255325] val-MMD^2 0.01999721436950821 val-ratio [85.63179015] val-obj [-4.45005659]  elapsed: 0.0
2021-08-05 14:43:17,361 - model_criticism_mmd.logger_unit - INFO -     50: [avg train] MMD^2 0.013346477760713434 obj [-4.83768322] val-MMD^2 0.019990796354718437 val-ratio [199.90796355] val-obj [-5.29785708]  elapsed: 0.0
2021-08-05 14:43:20,554 - model_criticism_mmd.logger_unit - INFO -    100: [avg train] MMD^2 0.013346521481474754 obj [-4.83768504] val-MMD^2 0.019990946634031313 val-ratio [199.

Now, we have the trained MMD estimator.

In [8]:
trained_mmd_estimator = MMD(trained_obj.kernel_function_obj, device_obj=device_obj)

Finally, we run a permutation test. For that, we call a class named `PermutationTest`.

In [11]:
permutation_tester = PermutationTest(n_permutation_test=1000, 
                                     mmd_estimator=trained_mmd_estimator, 
                                     dataset=dataset_train, 
                                     batch_size=-1)
statistics = permutation_tester.compute_statistic()
threshold = permutation_tester.compute_threshold(alpha=0.05)
p_value = permutation_tester.compute_p_value(statistics)
print(f"MMD-statistics: {statistics}, threshold: {threshold}, p-value: {p_value}")
if p_value > 0.05:
    print('Same distribution!')
else:
    print('Probably different distribution!')

100%|██████████| 1000/1000 [00:17<00:00, 56.63it/s]

MMD-statistics: 0.004041565216796814, threshold: 0.004046224310567938, p-value: 0.07199999999999995
Same distribution!



