# High-dimensional MMD

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import gamma
import pickle

Different combinations of $(X, Y)$ to evaluate the approximation methods.

## Examples of one to three-dimensional distributions

In [None]:
#X = np.random.normal(0, 1, 256).reshape(-1,1)
#X = np.random.exponential(size=100).reshape(-1,1)
#X = np.random.chisquare(1, size=20).reshape(-1,1)
#X = np.random.multivariate_normal([0,1,2], [[1,1,1], [1,1,1], [1,1,1]], size=256)


#Y = np.random.normal(0, 1, 256).reshape(-1,1)
#Y = np.random.exponential(size=256).reshape(-1,1)
#Y = np.random.randn(20*128).reshape(20,-1)
#Y = np.random.multivariate_normal([0,1,2], [[1,1,1], [1,1,1], [1,1,1]], size=256)

## Examples of higher-dimensional distributions

`delta` (mean shift) and `gam` (variance shift) control the departure from the null hypothesis $H_0: P_X = P_Y$ of the MMD test. If either of them is non-zero, the alternative hypothesis $H_1: P_X \neq P_Y$ is true.

In [None]:
# dimensions: the grids built from dim have dim+1 points, so the generated
# samples have dim+1 dimensions
dim = 10   # resulting dimensions is dim+1

# mean shift: the grid for Y is time1 + delta * time1**3, so delta = 0 leaves
# the means of Y equal to those of X
delta = 0

# variance shift: gam is added to the scale used to draw Y in gen_var_shift
gam = 0

# how many samples
sample_size = 256

# set how far to shift Y (+ shifts to the left, - shifts to the right);
# the generator functions read this value as a global
shift_par = 0

print('Shifting for:', shift_par)

In [None]:
# Grid for X: dim+1 equally spaced points on [0, 1]. Each grid value is later
# used as the mean of one dimension of X.
x1 = np.linspace(0,1, dim+1)
time1 = x1
print(time1)
# plot the grid values against their index as red crosses
plt.figure(figsize=(16,10))
plt.plot(time1, 'xr', markersize=16);
plt.show()

In [None]:
# NOTE(review): x2 is not used anywhere in the visible part of the notebook.
x2 = np.linspace(0, 1, dim+1)

# mean shift: the grid for Y is the grid for X plus a cubic term scaled by
# delta (identical to time1 when delta = 0)
time2 = time1 + delta * time1**3

print(time2)
# plot the shifted grid values against their index as red crosses
plt.figure(figsize=(16,10))
plt.plot(time2, 'xr', markersize=16);
plt.show()

# Generating high-dimensional distributions 

## Generating process for mean shift

In [None]:
# generating distributions over time for mean shift

def _draw_gaussian_rows(means, seed, sample_size):
    """Seed NumPy's global RNG and draw one row of Gaussian samples per mean.

    For every entry of `means` a scale is drawn from U(0, 0.25) and then
    `sample_size` normal samples with that mean and scale are drawn.
    The result has shape (len(means), sample_size).
    """
    means = np.asarray(means)
    np.random.seed(seed)

    # pre-allocate instead of growing the array with np.append on every step
    rows = np.empty((len(means), sample_size))
    for k, mean in enumerate(means):
        # passed as `scale` to np.random.normal, i.e. used as a standard
        # deviation (the original code called this value `var`)
        scale = np.random.uniform(0, 0.25, 1)
        rows[k] = np.random.normal(mean, scale, sample_size)
    return rows


def gen_mean_shift(time1, time2, s, sample_size, shift_par=None):
    """Generate two samples whose per-dimension means follow two grids.

    Parameters
    ----------
    time1, time2 : 1-D sequences
        Means of the dimensions of X and Y respectively.
    s : int
        Random seed. The RNG is re-seeded with `s` before drawing each
        sample, so both samples use the same scales and the same underlying
        random draws; they differ only through their means.
    sample_size : int
        Number of samples (rows of the returned arrays).
    shift_par : int, optional
        Number of dimensions to drop from Y: a value >= 0 drops the first
        `shift_par` dimensions, a negative value drops the last
        `-shift_par` dimensions. If omitted, the module-level `shift_par`
        is used (0 if it is not defined), which keeps existing calls working.

    Returns
    -------
    time1d, time2d : ndarray
        Arrays of shape (sample_size, len(time1)) and
        (sample_size, len(time2) - abs(shift_par)).
    """
    if shift_par is None:
        # backward compatibility: the shift used to be read from a global
        shift_par = globals().get('shift_par', 0)

    t1d = _draw_gaussian_rows(time1, s, sample_size)
    t2d = _draw_gaussian_rows(time2, s, sample_size)    # same scales due to same seed

    # shifting time2
    if shift_par >= 0:
        t2d = t2d[shift_par:]
    else:
        t2d = t2d[:shift_par]

    return t1d.T, t2d.T

In [None]:
# use the configured sample_size instead of repeating the literal 256, so that
# changing the setting also changes this example
time1d, time2d = gen_mean_shift(time1, time2, s=1, sample_size=sample_size)

## Generating process for variance shift

In [None]:
# generating distributions over time for variance shift

def gen_var_shift(time1, time2, s, sample_size, gam, shift_par=None):
    """Generate two zero-mean samples that differ only in their scale.

    Parameters
    ----------
    time1 : 1-D sequence
        Only its length is used: it fixes the number of dimensions. This
        replaces the previous dependency on the global `dim`; every call in
        this notebook passes a grid of length dim+1. All means are zero,
        whatever the values in `time1`.
    time2 : 1-D sequence
        Unused; kept so that the signature matches `gen_mean_shift`.
    s : int
        Random seed for NumPy's global RNG.
    sample_size : int
        Number of samples (rows of the returned arrays).
    gam : float
        Value added to the scale of X to obtain the scale of Y.
    shift_par : int, optional
        Number of dimensions to drop from Y: a value >= 0 drops the first
        `shift_par` dimensions, a negative value drops the last
        `-shift_par` dimensions. If omitted, the module-level `shift_par`
        is used (0 if it is not defined), which keeps existing calls working.

    Returns
    -------
    time1d, time2d : ndarray
        Arrays of shape (sample_size, len(time1)) and
        (sample_size, len(time1) - abs(shift_par)).
    """
    if shift_par is None:
        # backward compatibility: the shift used to be read from a global
        shift_par = globals().get('shift_par', 0)

    n_dims = len(time1)

    # pre-allocate instead of growing the arrays with np.append on every step
    t1d = np.empty((n_dims, sample_size))
    t2d = np.empty((n_dims, sample_size))

    np.random.seed(s)
    for k in range(n_dims):
        # distribution for time1; the value is passed as `scale` to
        # np.random.normal, i.e. used as a standard deviation
        scale1 = np.random.uniform(0, 0.25, 1)
        t1d[k] = np.random.normal(0.0, scale1, sample_size)

        # distribution for time2: gam added to the same scale
        scale2 = scale1 + gam
        t2d[k] = np.random.normal(0.0, scale2, sample_size)

    # shifting time2
    if shift_par >= 0:
        t2d = t2d[shift_par:]
    else:
        t2d = t2d[:shift_par]

    return t1d.T, t2d.T

In [None]:
# use the configured sample_size and gam instead of repeating the literals 256
# and 0, so that changing the settings also changes this example
time1d, time2d = gen_var_shift(time1, time2, s=1, sample_size=sample_size, gam=gam)

### Plots

In [None]:
# different dimensions: time1d is plotted before alignment with time2d, so the
# two arrays can still have a different number of columns when shift_par != 0.
# time1d holds the output of whichever generator cell was executed last.
print(time1d.shape)
plt.figure(figsize=(16,10))
# x-range covers all dimensions, y-range is fixed to [-1, 2]
plt.axis([-1, time1d.shape[1], -1, 2])
# time1d.T has one column per sample, so each sample is drawn as one line
plt.plot(time1d.T);
plt.show()

In [None]:
# same plot for time2d (already shortened by shift_par inside the generator)
print(time2d.shape)
plt.figure(figsize=(16,10))
# x-range covers all dimensions, y-range is fixed to [-1, 2]
plt.axis([-1, time2d.shape[1], -1, 2])
# time2d.T has one column per sample, so each sample is drawn as one line
plt.plot(time2d.T);
plt.show()

In [None]:
# same dimensions: Y was already shortened by the generator, so only X has to
# be trimmed (at the opposite end) to the same number of columns
Y = time2d
if shift_par == 0:
    X = time1d
elif shift_par > 0:
    # drop the last shift_par columns
    X = time1d[:, :-shift_par]
else:
    # drop the first -shift_par columns
    X = time1d[:, -shift_par:]

In [None]:
# X after alignment: it now has the same number of columns as Y
print(X.shape)
plt.figure(figsize=(16,10))
# x-range covers all dimensions, y-range is fixed to [-1, 2]
plt.axis([-1, X.shape[1], -1, 2])
# X.T has one column per sample, so each sample is drawn as one line
plt.plot(X.T);
plt.show()

In [None]:
# Y after alignment (identical to time2d)
print(Y.shape)
plt.figure(figsize=(16,10))
# x-range covers all dimensions, y-range is fixed to [-1, 2]
plt.axis([-1, Y.shape[1], -1, 2])
# Y.T has one column per sample, so each sample is drawn as one line
plt.plot(Y.T);
plt.show()

### Auxiliary functions

In [None]:
# median heuristic for kernel width
def width(Z):
    """Return a kernel width for the samples in Z via the median heuristic.

    Z is a 2-D array with one sample per row. At most the first 100 rows are
    used. The result is sqrt(0.5 * median of the positive pairwise squared
    distances), multiplied by the number of columns of Z.
    """
    # at most the first 100 samples enter the median computation
    subset = Z[:100]

    # squared Euclidean distances: |a|^2 + |b|^2 - 2 a.b
    sq_norms = np.sum(subset * subset, axis=1).reshape(-1, 1)
    sq_dists = sq_norms + sq_norms.T - 2 * subset @ subset.T

    # zero the lower triangle and the diagonal so every pair is counted once;
    # the > 0 filter then removes these zeros again
    upper = (sq_dists - np.tril(sq_dists)).reshape(-1, 1)
    base_width = np.sqrt(0.5 * np.median(upper[upper > 0]))

    n_columns = Z.shape[1]
    return n_columns * base_width


# rbf dot product
def rbf_dot(X, Y, width):
    """Return the RBF (Gaussian) kernel matrix between the rows of X and Y.

    X and Y are 2-D arrays with one sample per row and the same number of
    columns. Entry (i, j) of the result is
    exp(-|X[i] - Y[j]|^2 / (2 * width**2)), so the result has shape
    (X.shape[0], Y.shape[0]). Only width**2 is used, so the sign of `width`
    does not matter.
    """
    # squared norms as a column (for X) and as a row (for Y)
    sq_norm_X = np.sum(X * X, axis=1).reshape(-1, 1)
    sq_norm_Y = np.sum(Y * Y, axis=1).reshape(1, -1)

    # squared Euclidean distances: |x|^2 + |y|^2 - 2 x.y
    sq_dists = sq_norm_X + sq_norm_Y - 2 * X @ Y.T

    # rbf kernel
    return np.exp(-sq_dists / (2 * width**2))

---------------------------
#### Interlude: can X and Y have different sample sizes?

We test whether our implemented test statistic `stat_d` produces the same result as the original statistic `stat`, which only works when **X** and **Y** have the same sample size. Both samples are drawn from multivariate Gaussians of the same dimension and with the same sample size.

In [None]:
# Two 3-dimensional Gaussian samples of 2048 rows each. Every entry of the
# covariance matrix is 1 for X and 11 for Y; the means differ by 10 in every
# coordinate. This overwrites the X and Y defined in earlier cells.
X = np.random.multivariate_normal([0,1,2], [[1,1,1], [1,1,1], [1,1,1]], size=2048)
Y = np.random.multivariate_normal([10,11,12], [[11,11,11], [11,11,11], [11,11,11]], size=2048)

print(X.shape)
print(Y.shape)

In [None]:
# sample sizes of X and Y
m = X.shape[0]
n = Y.shape[0]

# compute Gram matrices
# NOTE(review): rbf_dot does not treat -1 as a "median heuristic" sentinel; it
# only uses width**2, so -1 gives the same kernel as width = 1.
K = rbf_dot(X, X, -1)
L = rbf_dot(Y, Y, -1)
KL = rbf_dot(X, Y, -1)

# MMD
# stat_d normalises the X terms with m and the Y terms with n, stat uses m for
# everything; np.sum(KL.T) equals np.sum(KL), so both agree when m == n.
# NOTE(review): for m != n the cross terms would usually be normalised with
# 1/(m*n) -- confirm before using stat_d with different sample sizes.
stat_d = 1/(m*(m-1)) * np.sum(K) + 1/(n*(n-1)) * np.sum(L) - 1/(m*(m-1)) * np.sum(KL) - 1/(n*(n-1)) * np.sum(KL.T)
stat = 1/(m*(m-1)) * (np.sum(K + L - KL - KL.T))

In [None]:
# print both statistics for comparison
print(stat_d)
print(stat)

We can see that the two statistics `stat` and `stat_d` give the same value when **X** and **Y** have the same sample size.

------------------------------

# Statistical test based on MMD
We test statistically whether $\mathcal{H}_0 : P_X = P_Y$ holds true.

## MMD with permutations

In [None]:
def MMD_permutations(X, Y, alpha, width_X, width_Y, width_XY, shuffle): # set widths to -1 for median heuristics
    """Two-sample test based on an RBF-kernel MMD statistic and permutations.

    Parameters
    ----------
    X, Y : ndarray
        2-D arrays with one sample per row. They must have the same number
        of rows, because their Gram matrices are added element-wise.
    alpha : float
        Test level; the threshold is taken at position (1 - alpha) * shuffle
        of the sorted permuted statistics.
    width_X, width_Y, width_XY : float
        Kernel widths for the X-X, Y-Y and X-Y Gram matrices. Pass -1 to
        choose a width with the median heuristic (`width`); for width_XY the
        heuristic is applied to the concatenation of X and Y.
    shuffle : int
        Number of random permutations.

    Returns
    -------
    stat : float
        1/m * sum(K + L - KL - KL.T), with m the number of rows of X.
    threshold : float
        Rejection threshold estimated from the permuted statistics.
    """
    m = X.shape[0]

    # median heuristics for kernel width
    # NOTE(review): the three Gram matrices may end up with three different
    # widths -- confirm that this is intended.
    if width_X == -1:
        width_X = width(X)
    if width_Y == -1:
        width_Y = width(Y)
    if width_XY == -1:
        width_XY = width(np.concatenate([X, Y]))

    # compute Gram matrices
    K = rbf_dot(X, X, width_X)
    L = rbf_dot(Y, Y, width_Y)
    KL = rbf_dot(X, Y, width_XY)

    # test statistic
    stat = 1/m * (np.sum(K + L - KL - KL.T))

    # joint Gram matrix of the pooled sample: [[K, KL], [KL.T, L]]
    Kz = np.concatenate((np.concatenate((K, KL), axis=1), np.concatenate((KL.T, L), axis=1)), axis=0)

    # initiating MMD
    MMD_arr = np.zeros(shuffle)

    # Permuting rows and columns of Kz with the same index vector reassigns the
    # pooled samples to the two groups; the statistic is then recomputed from
    # the blocks of the permuted matrix.
    for sh in range(shuffle):
        index_perm = np.random.permutation(Kz.shape[0])
        Kz_perm = Kz[np.ix_(index_perm, index_perm)]

        K_perm = Kz_perm[:m, :m]
        L_perm = Kz_perm[m:, m:]
        KL_perm = Kz_perm[:m, m:]

        MMD_arr[sh] = 1/m * (np.sum(K_perm + L_perm - KL_perm - KL_perm.T))

    MMD_arr_sort = np.sort(MMD_arr)

    # computing 1-alpha threshold
    # round() can return `shuffle` itself (e.g. alpha=0.05 with shuffle=10, or
    # any alpha below 0.5/shuffle), which is one past the last valid index, so
    # the index is clamped to the largest permuted statistic.
    threshold_index = min(int(round((1-alpha)*shuffle)), shuffle - 1)
    threshold = MMD_arr_sort[threshold_index]

    # H0 is rejected when stat > threshold; the decision is left to the caller.
    return stat, threshold

## MMD with Gamma distribution approximation

In [None]:
def MMD_gamma(X, Y, alpha, width_X, width_Y, width_XY):    # set widths to -1 for median heuristics
    """Two-sample test based on an RBF-kernel MMD statistic and a Gamma fit.

    X and Y are 2-D arrays with one sample per row and the same number of
    rows. A width of -1 is replaced by the median heuristic (`width`); for
    width_XY the heuristic is applied to the concatenation of X and Y.

    A Gamma distribution is fitted to the statistic under H0 by matching its
    mean and variance, and its (1 - alpha) quantile is used as threshold.
    Returns the tuple (stat, threshold); H0 is rejected when stat > threshold.
    """
    n_samples = X.shape[0]

    # median heuristics for kernel width
    width_X = width(X) if width_X == -1 else width_X
    width_Y = width(Y) if width_Y == -1 else width_Y
    width_XY = width(np.concatenate([X, Y])) if width_XY == -1 else width_XY

    # compute Gram matrices
    gram_X = rbf_dot(X, X, width_X)
    gram_Y = rbf_dot(Y, Y, width_Y)
    gram_XY = rbf_dot(X, Y, width_XY)

    # test statistic
    stat = 1/n_samples * (np.sum(gram_X + gram_Y - gram_XY - gram_XY.T))

    # mean under H0
    null_mean = 2/n_samples * (1 - 1/n_samples * np.trace(gram_XY))

    def _without_diagonal(G):
        # copy of G with its main diagonal set to zero
        return G - np.diag(np.diag(G))

    offdiag_X = _without_diagonal(gram_X)
    offdiag_Y = _without_diagonal(gram_Y)
    offdiag_XY = _without_diagonal(gram_XY)

    # variance under H0
    n_pairs = n_samples * (n_samples - 1)
    null_var = 2/n_pairs * 1/n_pairs * np.sum(np.power((offdiag_X + offdiag_Y - offdiag_XY - offdiag_XY.T), 2))

    # Gamma shape and scale matched to the mean and variance above
    al = null_mean**2 / null_var
    bet = null_var * n_samples / null_mean

    # computing 1-alpha threshold
    threshold = gamma.ppf(1-alpha, al, scale=bet)

    return stat, threshold

### Evaluations

In [None]:
# test level alpha = 0.05, 5000 permutations; all widths are -1, so they are
# chosen by the median heuristic
# NOTE(review): when the notebook is run top to bottom, X and Y are at this
# point the 2048-sample Gaussians from the interlude, not the aligned
# time1d/time2d samples -- confirm that this is intended.
MMD_permutations(X, Y, 0.05, -1, -1, -1, 5000)

In [None]:
# test level alpha = 0.05; all widths are -1, so they are chosen by the median
# heuristic (same X and Y as for the permutation test above)
MMD_gamma(X, Y, 0.05, -1, -1, -1)

## Power estimation

We estimate the statistical power based on 1000 replications for each setting. Our experimental settings consist of various dimensions, sample sizes, mean shifts `delta`, and variance shifts `gam`.

In [None]:
# dimensions (each value d gives samples with d+1 dimensions)
dims = [1, 5, 10, 25]

# sample sizes
sample_sizes = [64, 128, 256, 512]

# mean shift: 11 values from 0 to 0.2, followed by 39 values from 0.25 to 4
deltas = np.concatenate([np.linspace(0, 0.2, 11), np.linspace(0.25, 4, 39)])

# variance shift: 11 values from 0 to 0.005, followed by 39 values from 0.008 to 1
gams = np.concatenate([np.linspace(0, 0.005, 11), np.linspace(0.008, 1, 39)])

# shifting process Y against X (only the unshifted case is active)
shift_pars = [0] #, 1, 2]

### Power estimation for mean shift

In [None]:
# Results of the mean-shift experiment, keyed by
# (dim, sample_size, delta, shift_par). Each value is a list with one
# (stat, threshold) tuple per replication.
MMD_p_m = {}
MMD_g_m = {}

for dim in dims:
    print('Dimensions:', dim)

    # defining X: the grid depends only on dim, so it is built once here
    time1 = np.linspace(0,1, dim+1)

    for sample_size in sample_sizes:
        print('Sample size:', sample_size)
        for delta in deltas:
            print('delta:', delta)

            # defining Y: the grid depends only on dim and delta
            time2 = time1 + delta * time1**3

            # shift_par is the loop variable on purpose: gen_mean_shift reads
            # the module-level shift_par
            for shift_par in shift_pars:
                #print('Shift:', shift_par)

                MMD_p_m_list = []
                MMD_g_m_list = []

                # repeating 200 times
                for i in range(200):

                    time1d, time2d = gen_mean_shift(time1, time2, s=i, sample_size=sample_size)    # having each time different random seed

                    # bringing X and Y in same space: Y was shortened by the
                    # generator, X is trimmed at the opposite end
                    Y = time2d
                    if shift_par > 0:
                        X = time1d[:, :-shift_par]
                    elif shift_par < 0:
                        X = time1d[:, -shift_par:]
                    else:
                        X = time1d

                    # test level alpha = 0.05, 1000 permutations
                    MMD_p_m_list.append(MMD_permutations(X, Y, 0.05, -1, -1, -1, 1000))

                    # test level alpha = 0.05
                    MMD_g_m_list.append(MMD_gamma(X, Y, 0.05, -1, -1, -1))

                MMD_p_m[(dim, sample_size, delta, shift_par)] = MMD_p_m_list
                MMD_g_m[(dim, sample_size, delta, shift_par)] = MMD_g_m_list

In [None]:
# saving
# The files are opened in `with` blocks so that they are closed even if
# pickling fails. The file names contain the printed form of the `dims` list.
with open('mean_shifts_p_{}.pkl'.format(dims), 'wb') as m_shift_p:
    pickle.dump(MMD_p_m, m_shift_p)

with open('mean_shifts_g_{}.pkl'.format(dims), 'wb') as m_shift_g:
    pickle.dump(MMD_g_m, m_shift_g)

### Power estimation for variance shift

In [None]:
# Results of the variance-shift experiment, keyed by
# (dim, sample_size, gam, shift_par). Each value is a list with one
# (stat, threshold) tuple per replication.
MMD_p_var = {}
MMD_g_var = {}

for dim in dims:
    print('Dimensions:', dim)

    # defining X and Y: both grids are all-zero and depend only on dim, so
    # they are built once here
    time1 = np.zeros(dim+1)
    time2 = np.zeros(dim+1)

    for sample_size in sample_sizes:
        print('Sample size:', sample_size)
        for gam in gams:
            print('gamma:', gam)

            # shift_par is the loop variable on purpose: gen_var_shift reads
            # the module-level shift_par
            for shift_par in shift_pars:
                print('Shift:', shift_par)

                MMD_p_var_list = []
                MMD_g_var_list = []

                # repeating 500 times
                for i in range(500):

                    time1d, time2d = gen_var_shift(time1, time2, s=i, sample_size=sample_size, gam=gam)    # having each time different random seed

                    # bringing X and Y in same space: Y was shortened by the
                    # generator, X is trimmed at the opposite end
                    Y = time2d
                    if shift_par > 0:
                        X = time1d[:, :-shift_par]
                    elif shift_par < 0:
                        X = time1d[:, -shift_par:]
                    else:
                        X = time1d

                    # test level alpha = 0.05, 5000 permutations
                    MMD_p_var_list.append(MMD_permutations(X, Y, 0.05, -1, -1, -1, 5000))

                    # test level alpha = 0.05
                    MMD_g_var_list.append(MMD_gamma(X, Y, 0.05, -1, -1, -1))

                # Key on gam, the parameter varied in this experiment. The key
                # previously used `delta`, a leftover from the mean-shift
                # experiment, so every gam overwrote the same entry.
                MMD_p_var[(dim, sample_size, gam, shift_par)] = MMD_p_var_list
                MMD_g_var[(dim, sample_size, gam, shift_par)] = MMD_g_var_list

In [None]:
# saving
# The files are opened in `with` blocks so that each one is closed after
# writing. The previous version closed `m_shift_p` (the mean-shift file)
# instead of `var_shift_p`, which left `var_shift_p` open.
with open('var_shifts_p.pkl', 'wb') as var_shift_p:
    pickle.dump(MMD_p_var, var_shift_p)

with open('var_shifts_g.pkl', 'wb') as var_shift_g:
    pickle.dump(MMD_g_var, var_shift_g)