# High-dimensional MMD

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gamma
import pickle
from sklearn.metrics import pairwise_distances
from sklearn.metrics import pairwise_kernels

In [None]:
# number of observation points per sample (the data dimension)
dim = 10

# number of samples drawn from each distribution
sample_size = 256

# `dim` evenly spaced observation points covering [0, 1]
sample_points = np.linspace(0, 1, num=dim)

# Generating high-dimensional distributions 

## Generating process for mean shift

In [None]:
# generating distributions over time for mean shift

def gen_mean_shift(sample_size, sample_points, delta=0, sd=np.sqrt(0.25), random_state=None):
    """Draw noisy random curves whose mean function is ``t + delta * t**3``.

    sample_size : number of curves to draw (rows of the result)
    sample_points : array of observation points t at which the curves are evaluated
    delta : coefficient of the cubic term t**3 in the mean function
    sd : the standard deviation of the observation noise
    random_state : seed handed to np.random.RandomState (None leaves it unseeded)

    Returns an array of shape (sample_size, len(sample_points)).
    """
    rng = np.random.RandomState(random_state)
    angle = 2 * np.pi * sample_points

    # one random coefficient per curve for each Fourier basis function;
    # the sine coefficient is drawn before the cosine coefficient
    sin_coef = rng.normal(0, np.sqrt(10), (sample_size, 1))
    cos_coef = rng.normal(0, np.sqrt(5), (sample_size, 1))
    curves = sin_coef * np.sqrt(2) * np.sin(angle) + cos_coef * np.sqrt(2) * np.cos(angle)

    # deterministic mean function
    curves += sample_points + delta * sample_points**3

    # independent observation noise epsilon at every point
    curves += rng.normal(0, sd, (sample_size, len(sample_points)))
    return curves

## Generating process for variance shift

In [None]:
# generating distributions over time for variance shift

def gen_variance_shift(sample_size, sample_points, delta=0, sd=np.sqrt(0.25), random_state=None):
    """Draw noisy zero-mean random curves whose sine coefficient has variance ``10 + delta``.

    sample_size : number of curves to draw (rows of the result)
    sample_points : array of observation points at which the curves are evaluated
    delta : amount added to the variance (10) of the sine coefficient
    sd : the standard deviation of the observation noise
    random_state : seed handed to np.random.RandomState (None leaves it unseeded)

    Returns an array of shape (sample_size, len(sample_points)).
    """
    rng = np.random.RandomState(random_state)
    angle = 2 * np.pi * sample_points

    # the sine coefficient (shifted variance) is drawn before the cosine coefficient
    sin_coef = rng.normal(0, np.sqrt(10 + delta), (sample_size, 1))
    cos_coef = rng.normal(0, np.sqrt(5), (sample_size, 1))
    curves = sin_coef * np.sqrt(2) * np.sin(angle) + cos_coef * np.sqrt(2) * np.cos(angle)

    # independent observation noise epsilon at every point
    curves += rng.normal(0, sd, (sample_size, len(sample_points)))
    return curves

### Median heuristic

In [None]:
# median heuristic for kernel width
def width(Z):
    """Median-heuristic kernel width for the sample Z.

    Returns the median of all strictly positive pairwise Euclidean
    distances between the rows of Z.
    """
    distances = pairwise_distances(Z, metric='euclidean')
    # zero distances (the diagonal and any duplicated rows) are left out
    positive = distances[distances > 0]
    return np.median(positive)

------------------------------

# Statistical test based on MMD
We test statistically whether $\mathcal{H}_0 : P_X = P_Y$ holds true.

## MMD with permutations

In [None]:
def _mmd2_unbiased(K, L, KL):
    """Unbiased MMD^2 estimate from the Gram blocks K (X-X), L (Y-Y) and KL (X-Y)."""
    m = K.shape[0]
    n = L.shape[0]
    # within-sample terms only use the off-diagonal entries (i != j)
    k_off = np.sum(K) - np.trace(K)
    l_off = np.sum(L) - np.trace(L)
    return (1/(m*(m-1)))*k_off + (1/(n*(n-1)))*l_off - (2/(m*n))*np.sum(KL)


def MMD_permutations(X, Y, alpha, width_XY, shuffle):
    """Two-sample MMD test with a permutation estimate of the rejection threshold.

    X, Y : samples to compare, one observation per row
    alpha : test level
    width_XY : Gaussian kernel width; pass -1 to use the median heuristic
        on the pooled sample
    shuffle : number of random permutations used to build the null distribution

    Returns (stat, threshold): the unbiased MMD^2 statistic and the
    (1 - alpha) quantile of the permuted statistics. H0 would be rejected
    when stat > threshold.
    """
    m = X.shape[0]

    # median heuristics for kernel width
    if width_XY == -1:
        width_XY = width(np.concatenate([X, Y]))   # aggregating samples

    # compute Gram matrices
    rbf_gamma = 0.5/(width_XY**2)
    K = pairwise_kernels(X, X, metric='rbf', gamma=rbf_gamma)
    L = pairwise_kernels(Y, Y, metric='rbf', gamma=rbf_gamma)
    KL = pairwise_kernels(X, Y, metric='rbf', gamma=rbf_gamma)

    # unbiased test statistic
    stat = _mmd2_unbiased(K, L, KL)

    # Gram matrix of the pooled sample [X; Y]
    Kz = np.block([[K, KL], [KL.T, L]])
    n_total = Kz.shape[0]

    # null distribution: relabel the pooled observations at random and
    # recompute the statistic on the permuted Gram matrix
    MMD_arr = np.zeros(shuffle)
    for sh in range(shuffle):
        index_perm = np.random.permutation(n_total)
        Kz_perm = Kz[np.ix_(index_perm, index_perm)]
        MMD_arr[sh] = _mmd2_unbiased(Kz_perm[:m, :m], Kz_perm[m:, m:], Kz_perm[:m, m:])

    # computing 1-alpha threshold; the index is clamped so that a very small
    # alpha cannot point one past the end of the sorted array
    quantile_index = int(round((1-alpha)*shuffle))
    quantile_index = max(0, min(quantile_index, shuffle - 1))
    threshold = np.sort(MMD_arr)[quantile_index]

    return stat, threshold

## MMD with Gamma distribution approximation

In [None]:
def MMD_gamma(X, Y, alpha, width_XY):
    """Two-sample MMD test with a Gamma approximation of the null distribution.

    X, Y : samples to compare, one observation per row; both must contain
        the same number of rows
    alpha : test level
    width_XY : Gaussian kernel width; pass -1 to use the median heuristic
        on the pooled sample

    Returns (stat, threshold): the unbiased MMD^2 statistic and the
    (1 - alpha) quantile of the fitted Gamma distribution.

    Raises ValueError if X and Y have different numbers of rows.
    """
    m = X.shape[0]
    n = Y.shape[0]
    if m != n:
        # zeroing the diagonal of the cross Gram matrix below needs a square KL
        raise ValueError('MMD_gamma requires X and Y to have the same number of samples')

    # median heuristics for kernel width
    if width_XY == -1:
        width_XY = width(np.concatenate([X, Y]))   # aggregating samples

    # compute Gram matrices
    rbf_gamma = 0.5/(width_XY**2)
    K = pairwise_kernels(X, X, metric='rbf', gamma=rbf_gamma)
    L = pairwise_kernels(Y, Y, metric='rbf', gamma=rbf_gamma)
    KL = pairwise_kernels(X, Y, metric='rbf', gamma=rbf_gamma)

    # Gram matrices with their main diagonals set to zero
    K_offdiag = K - np.diag(np.diagonal(K))
    L_offdiag = L - np.diag(np.diagonal(L))
    KL_offdiag = KL - np.diag(np.diagonal(KL))

    # unbiased test statistic: within-sample terms sum the off-diagonal entries
    stat = (1/(m*(m-1)))*np.sum(K_offdiag) + (1/(n*(n-1)))*np.sum(L_offdiag) - (2/(m*n))*np.sum(KL)

    # fitting Gamma distribution: moments under H0
    mMMD = 2/m * (1 - 1/m * np.trace(KL))    # mean under H0
    varMMD = 2/(m*(m-1)) * 1/(m*(m-1)) * np.sum(np.power((K_offdiag + L_offdiag - KL_offdiag - KL_offdiag.T), 2))    # variance under H0

    # shape and scale of the Gamma distribution
    # NOTE(review): the factor m in the scale suggests the fitted distribution is
    # for m times the statistic -- confirm that `stat` is on the same scale
    # before comparing it with the threshold.
    al = mMMD**2 / varMMD
    bet = varMMD * m / mMMD

    # computing 1-alpha threshold
    threshold = gamma.ppf(1-alpha, al, scale=bet)

    return stat, threshold

## Power estimation

We estimate the statistical power from 200 replications of each setting. The experimental settings consist of various dimensions, sample sizes, mean shifts `delta_m`, and variance shifts `delta_var`.

In [None]:
# data dimensions (number of observation points per sample)
dims = [5, 10, 25, 50, 100]

# numbers of samples drawn per distribution
sample_sizes = [100, 200, 300, 500]

# shift grids: 17 mean shifts in [0, 8] and 33 variance shifts in [0, 32]
delta_m = np.linspace(0, 8, num=17)
delta_var = np.linspace(0, 32, num=33)

### Power estimation for mean shift

In [None]:
# Mean-shift experiment with the median-heuristic kernel width.
# Each key (dim, sample_size, delta) maps to the list of (stat, threshold)
# pairs returned by the 200 replications of that setting.
MMD_p_m, MMD_g_m = {}, {}

for dim in dims:
    print('Dimensions:', dim)
    sample_points = np.linspace(0, 1, dim)

    for sample_size in sample_sizes:
        print('Sample size:', sample_size)

        for delta in delta_m:
            print('delta:', delta)

            MMD_p_m_list, MMD_g_m_list = [], []

            # 200 replications per setting
            for i in range(200):
                # X has no shift (delta=0); Y carries the mean shift delta
                X = gen_mean_shift(sample_size, sample_points, delta=0)
                Y = gen_mean_shift(sample_size, sample_points, delta=delta)

                # permutation test: level alpha = 0.05, median heuristic (-1), 5000 permutations
                perm_result = MMD_permutations(X, Y, 0.05, -1, 5000)
                MMD_p_m_list.append(perm_result)

                # Gamma approximation: level alpha = 0.05, median heuristic (-1)
                gamma_result = MMD_gamma(X, Y, 0.05, -1)
                MMD_g_m_list.append(gamma_result)

            setting = (dim, sample_size, delta)
            MMD_p_m[setting] = MMD_p_m_list
            MMD_g_m[setting] = MMD_g_m_list

In [None]:
# saving the mean-shift results; the `with` blocks close each file even if
# pickling raises
with open('mean_shifts_p_{}.pkl'.format(dims), 'wb') as m_shift_p:
    pickle.dump(MMD_p_m, m_shift_p)

with open('mean_shifts_g_{}.pkl'.format(dims), 'wb') as m_shift_g:
    pickle.dump(MMD_g_m, m_shift_g)

### Power estimation for variance shift

In [None]:
# Variance-shift experiment with the median-heuristic kernel width.
# Each key (dim, sample_size, delta) maps to the list of (stat, threshold)
# pairs returned by the 200 replications of that setting.
MMD_p_var, MMD_g_var = {}, {}

for dim in dims:
    print('Dimensions:', dim)
    sample_points = np.linspace(0, 1, dim)

    for sample_size in sample_sizes:
        print('Sample size:', sample_size)

        for delta in delta_var:
            print('delta:', delta)

            MMD_p_var_list, MMD_g_var_list = [], []

            # 200 replications per setting
            for i in range(200):
                # X has no shift (delta=0); Y carries the variance shift delta
                X = gen_variance_shift(sample_size, sample_points, delta=0)
                Y = gen_variance_shift(sample_size, sample_points, delta=delta)

                # permutation test: level alpha = 0.05, median heuristic (-1), 5000 permutations
                perm_result = MMD_permutations(X, Y, 0.05, -1, 5000)
                MMD_p_var_list.append(perm_result)

                # Gamma approximation: level alpha = 0.05, median heuristic (-1)
                gamma_result = MMD_gamma(X, Y, 0.05, -1)
                MMD_g_var_list.append(gamma_result)

            setting = (dim, sample_size, delta)
            MMD_p_var[setting] = MMD_p_var_list
            MMD_g_var[setting] = MMD_g_var_list

In [None]:
# saving the variance-shift results; the `with` blocks close the file that was
# actually opened, even if pickling raises
with open('var_shifts_p_{}.pkl'.format(dims), 'wb') as var_shift_p:
    pickle.dump(MMD_p_var, var_shift_p)

with open('var_shifts_g_{}.pkl'.format(dims), 'wb') as var_shift_g:
    pickle.dump(MMD_g_var, var_shift_g)

## Maximising test power

In [None]:
def var_MMD(X, Y, width_XY=1):
    """Estimate the variance of the MMD^2 statistic between samples X and Y.

    X, Y : samples, one observation per row
    width_XY : Gaussian kernel width used for the Gram matrices; the default
        of 1 reproduces the previously hard-coded kernel

    Returns the variance estimate as a scalar.

    NOTE(review): only m = X.shape[0] enters the normalising constants, so
    X and Y presumably need the same number of rows -- confirm with callers.
    """
    m = X.shape[0]

    rbf_gamma = 0.5/(width_XY**2)
    K = pairwise_kernels(X, X, metric='rbf', gamma=rbf_gamma)
    L = pairwise_kernels(Y, Y, metric='rbf', gamma=rbf_gamma)
    KL = pairwise_kernels(X, Y, metric='rbf', gamma=rbf_gamma)
    K_diag = np.diag(K)
    L_diag = np.diag(L)

    # row sums of K and L without their diagonal entries
    K_sums = np.sum(K, 1) - K_diag
    L_sums = np.sum(L, 1) - L_diag

    K_sum = np.sum(K_sums)
    L_sum = np.sum(L_sums)

    # column sums and row sums of the cross Gram matrix
    KL_sums_0 = np.sum(KL, 0)
    KL_sums_1 = np.sum(KL, 1)

    KL_sum = np.sum(KL_sums_0)

    K_diag_sum2 = np.sum(np.power(K_diag, 2))
    L_diag_sum2 = np.sum(np.power(L_diag, 2))

    # sums of squared row/column sums
    K_sqsum = np.sum(np.power(K_sums, 2))
    L_sqsum = np.sum(np.power(L_sums, 2))
    KL_sqsum_0 = np.sum(np.power(KL_sums_0, 2))
    KL_sqsum_1 = np.sum(np.power(KL_sums_1, 2))

    # sums of squared entries (diagonal excluded for K and L)
    K_2_sqsum = np.sum(np.power(K, 2)) - K_diag_sum2
    L_2_sqsum = np.sum(np.power(L, 2)) - L_diag_sum2
    KL_2_sqsum = np.sum(np.power(KL, 2))

    variance = (
        2 / (m**2 * (m-1)**2) * (2*K_sqsum - K_2_sqsum + 2*L_sqsum - L_2_sqsum)
        - (4*m-6) / (m**3 * (m-1)**3) * (K_sum**2 + L_sum**2)
        + 4*(m-2) / (m**3 * (m-1)**2) * (KL_sqsum_1 + KL_sqsum_0)
        - 4 * (m-3) / (m**3 * (m-1)**2) * KL_2_sqsum
        - (8*m - 12) / (m**5 * (m-1)) * KL_sum**2
        + 8 / (m**3 * (m-1)) * (1/m * (K_sum + L_sum) * KL_sum - np.dot(K_sums, KL_sums_1) - np.dot(L_sums, KL_sums_0))
    )

    return variance

In [None]:
def maximise(MMD, MMD_var, threshold, m):
    """Criterion used to pick a kernel width: larger values are preferred.

    MMD : MMD statistic
    MMD_var : variance estimate of the statistic
    threshold : rejection threshold of the test
    m : sample size

    Returns MMD / sqrt(MMD_var) - threshold / (m * sqrt(MMD_var)).
    """
    std = np.sqrt(MMD_var)
    signal = MMD / std
    penalty = threshold / (m*std)
    return signal - penalty

## Power estimation

We estimate the statistical power from 200 replications of each setting. The experimental settings consist of various dimensions, sample sizes, mean shifts `delta_m`, and variance shifts `delta_var`. For each setting we search a pre-defined grid of candidate Gaussian kernel bandwidths $\sigma$ and keep the one that maximises the test-power criterion.

In [None]:
# data dimensions (number of observation points per sample)
dims = [5, 10, 25, 50, 100]

# numbers of samples drawn per distribution
sample_sizes = [100, 200, 300, 500]

# shift grids: 17 mean shifts in [0, 8] and 33 variance shifts in [0, 32]
delta_m = np.linspace(0, 8, num=17)
delta_var = np.linspace(0, 32, num=33)

### Power estimation for mean shift

In [None]:
# Mean-shift experiment with a kernel width selected on separate training data.
# Each key (dim, sample_size, delta) maps to the list of (stat, threshold)
# pairs returned on the test data by the 200 replications of that setting.
MMD_p_m, MMD_g_m = {}, {}

for dim in dims:
    print('dimensions:', dim)
    sample_points = np.linspace(0, 1, dim)

    for sample_size in sample_sizes:
        print('sample size:', sample_size)

        for delta in delta_m:
            print('delta:', delta)

            MMD_p_m_list, MMD_g_m_list = [], []

            # 200 replications per setting
            for i in range(200):

                # X has no shift (delta=0): one draw to select sigma, one to test
                X_train = gen_mean_shift(sample_size, sample_points, delta=0)
                X_test = gen_mean_shift(sample_size, sample_points, delta=0)

                # Y carries the mean shift delta: one draw to select sigma, one to test
                Y_train = gen_mean_shift(sample_size, sample_points, delta=delta)
                Y_test = gen_mean_shift(sample_size, sample_points, delta=delta)

                m = X_train.shape[0]

                # candidate widths: the search grid moves up with delta
                if 0 <= delta <= 2:
                    sigmas = np.linspace(1, 21, 11)
                elif 2 < delta <= 3:
                    sigmas = np.linspace(6, 26, 11)
                elif 3 < delta <= 5:
                    sigmas = np.linspace(11, 31, 11)
                elif 5 < delta <= 8:
                    sigmas = np.linspace(16, 36, 11)

                # criterion value of every candidate width on the training data
                # NOTE(review): var_MMD is called without sigma, so the variance
                # term does not change with the candidate width -- confirm intended.
                ratios = []
                for sigma in sigmas:
                    MMD, threshold = MMD_permutations(X_train, Y_train, 0.05, sigma, 5000)
                    MMD_var = var_MMD(X_train, Y_train)
                    ratios.append(maximise(MMD, MMD_var, threshold, m))

                # width with the largest criterion value
                sigma_max = sigmas[np.argmax(ratios)]

                # permutation test on the test data: alpha = 0.05, 5000 permutations
                MMD_p_m_list.append(MMD_permutations(X_test, Y_test, 0.05, sigma_max, 5000))

                # Gamma approximation on the test data: alpha = 0.05
                MMD_g_m_list.append(MMD_gamma(X_test, Y_test, 0.05, sigma_max))

            setting = (dim, sample_size, delta)
            MMD_p_m[setting] = MMD_p_m_list
            MMD_g_m[setting] = MMD_g_m_list

In [None]:
# saving the mean-shift results obtained with the selected kernel width; the
# `with` blocks close each file even if pickling raises
with open('mean_shifts_p_{}_max.pkl'.format(dims), 'wb') as m_shift_p:
    pickle.dump(MMD_p_m, m_shift_p)

with open('mean_shifts_g_{}_max.pkl'.format(dims), 'wb') as m_shift_g:
    pickle.dump(MMD_g_m, m_shift_g)

### Power estimation for variance shift

In [None]:
# Variance-shift experiment with a kernel width selected on separate training data.
# Results go into MMD_p_var / MMD_g_var, the containers that the saving cell
# below pickles. Each key (dim, sample_size, delta) maps to the list of
# (stat, threshold) pairs returned on the test data by the 200 replications.
MMD_p_var = {}
MMD_g_var = {}

for dim in dims:
    print('dimensions:', dim)
    sample_points = np.linspace(0, 1, dim)

    for sample_size in sample_sizes:
        print('sample size:', sample_size)
        for delta in delta_var:
            print('delta:', delta)

            MMD_p_var_list = []
            MMD_g_var_list = []

            # repeating 200 times
            for i in range(200):

                # X has no shift (delta=0): one draw to select sigma, one to test
                X_train = gen_variance_shift(sample_size, sample_points, delta=0)
                X_test = gen_variance_shift(sample_size, sample_points, delta=0)

                # Y carries the variance shift delta: one draw to select sigma, one to test
                Y_train = gen_variance_shift(sample_size, sample_points, delta=delta)
                Y_test = gen_variance_shift(sample_size, sample_points, delta=delta)

                m = X_train.shape[0]

                # candidate widths: the search grid moves up with delta
                if 0 <= delta <= 4:
                    sigmas = np.linspace(10, 30, 11)
                elif 4 < delta <= 14:
                    sigmas = np.linspace(20, 40, 11)
                elif 14 < delta <= 32:
                    sigmas = np.linspace(30, 50, 11)

                # criterion value of every candidate width on the training data
                # NOTE(review): var_MMD is called without sigma, so the variance
                # term does not change with the candidate width -- confirm intended.
                ratios = []

                for sigma in sigmas:
                    MMD, threshold = MMD_permutations(X_train, Y_train, 0.05, sigma, 5000)
                    MMD_var = var_MMD(X_train, Y_train)
                    ratios.append(maximise(MMD, MMD_var, threshold, m))

                # sigma of maximum ratio
                sigma_max = sigmas[np.argmax(ratios)]

                # test level alpha = 0.05, 5000 permutations
                MMD_p_var_list.append(MMD_permutations(X_test, Y_test, 0.05, sigma_max, 5000))

                # test level alpha = 0.05
                MMD_g_var_list.append(MMD_gamma(X_test, Y_test, 0.05, sigma_max))

            MMD_p_var[(dim, sample_size, delta)] = MMD_p_var_list
            MMD_g_var[(dim, sample_size, delta)] = MMD_g_var_list

In [None]:
# saving the variance-shift results; the `with` blocks close the file that was
# actually opened, even if pickling raises
with open('var_shifts_p_{}_max.pkl'.format(dims), 'wb') as var_shift_p:
    pickle.dump(MMD_p_var, var_shift_p)

with open('var_shifts_g_{}_max.pkl'.format(dims), 'wb') as var_shift_g:
    pickle.dump(MMD_g_var, var_shift_g)