In [1]:
from itertools import product
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import scipy

import warnings
warnings.filterwarnings("ignore")

In [2]:
# Small integer vector used to sanity-check the broadcasting construction in the next cell.
a = np.arange(1, 4)

In [3]:
# Outer product a a^T, whose diagonal holds the squared entries a_i**2;
# broadcast that diagonal down every row (read-only view), i.e. the
# "||x_j||^2" term of the pairwise squared-distance expansion.
aa = np.outer(a, a)
np.broadcast_to(np.diag(aa)[np.newaxis, :], aa.shape)

array([[1, 4, 9],
       [1, 4, 9],
       [1, 4, 9]])

In [4]:
def MMD(x: np.ndarray, y: np.ndarray, kernel: str, bandwidths=None):
    """Empirical (biased, V-statistic) maximum mean discrepancy between two samples.

    The lower the result, the more evidence that the two samples come from the
    same distribution; two identical samples give exactly 0.

    Args:
        x: first sample, drawn from distribution P. A 1-D array is treated as
           ``n`` scalar observations; a 2-D array as ``n`` observations of
           dimension ``d`` (one per row).
        y: second sample, drawn from distribution Q. Same convention as ``x``.
           Its number of observations may differ from ``x``'s, but its
           dimension must match.
        kernel: "multiscale" (sum over bandwidths ``a`` of
           ``a**2 / (a**2 + ||u - v||**2)``) or "rbf" (sum over bandwidths ``a``
           of ``exp(-0.5 * ||u - v||**2 / a)``).
        bandwidths: optional non-empty sequence of bandwidths ``a`` to sum the
           kernel over. Defaults to [0.2, 0.5, 0.9, 1.3] for "multiscale" and
           [10, 15, 20, 50] for "rbf".

    Returns:
        mean(k(x_i, x_j)) + mean(k(y_i, y_j)) - 2 * mean(k(x_i, y_j)), with the
        diagonal (i == j) terms included (hence biased).

    Raises:
        ValueError: if ``kernel`` is not recognised, ``bandwidths`` is empty,
            an input is not 1-D/2-D, a sample is empty, or the sample
            dimensions differ.
    """
    default_bandwidths = {"multiscale": [0.2, 0.5, 0.9, 1.3], "rbf": [10, 15, 20, 50]}
    if kernel not in default_bandwidths:
        raise ValueError(f"unknown kernel {kernel!r}; expected 'multiscale' or 'rbf'")
    bandwidths = list(default_bandwidths[kernel] if bandwidths is None else bandwidths)
    if not bandwidths:
        raise ValueError("bandwidths must contain at least one value")

    # Cast to float so integer inputs can neither overflow nor hit
    # NumPy's "integers to negative powers" error in the multiscale kernel.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    # 1-D input means scalar observations -> column of shape (n, 1).
    if x.ndim == 1:
        x = x[:, None]
    if y.ndim == 1:
        y = y[:, None]
    if x.ndim != 2 or y.ndim != 2:
        raise ValueError("x and y must be 1-D or 2-D arrays")
    if x.shape[0] == 0 or y.shape[0] == 0:
        raise ValueError("x and y must each contain at least one observation")
    if x.shape[1] != y.shape[1]:
        raise ValueError(f"dimension mismatch: x has {x.shape[1]}, y has {y.shape[1]}")

    def sq_dists(u, v):
        # Pairwise squared Euclidean distances, shape (len(u), len(v)).
        # Direct differences avoid the cancellation error of ||u||^2 + ||v||^2 - 2<u, v>
        # and work for samples of different sizes.
        diff = u[:, None, :] - v[None, :, :]
        return np.sum(diff * diff, axis=-1)

    if kernel == "multiscale":
        def single_kernel(d, a):
            return a**2 * (a**2 + d) ** -1
    else:  # "rbf"
        def single_kernel(d, a):
            return np.exp(-0.5 * d / a)

    def kernel_matrix(d):
        # Sum of the chosen kernel over all bandwidths.
        total = np.zeros(d.shape)
        for a in bandwidths:
            total += single_kernel(d, a)
        return total

    XX = kernel_matrix(sq_dists(x, x))
    YY = kernel_matrix(sq_dists(y, y))
    XY = kernel_matrix(sq_dists(x, y))

    # Average each term on its own: XX, YY and XY have different shapes
    # whenever the sample sizes differ.
    return XX.mean() + YY.mean() - 2. * XY.mean()

In [5]:
# Sanity check on two samples with very different locations: a large value is expected.
MMD(x=np.array([1, 2, 3]), y=np.array([1000, 1000, 10000]), kernel="rbf")

6.072698104628811

In [1]:
def simulate_nn_location_scale(M: int, N: int, n: int, n_runs: int, seed: int = 0):
    """Simulate a location-scale table and score the estimate of one held-out cell.

    Each run builds an M x N table of Gaussian samples in which every row
    shares a mean (drawn from U(-5, 5)) and every column shares a standard
    deviation (drawn from U(1, 5)); cell (i, j) holds ``n`` sorted draws from
    N(mean_rows[i], std_cols[j]). Cell (0, 0) is masked out and re-estimated
    from the table, and both that estimate and the cell's own observed sample
    are scored with ``wasserstein2`` against the cell's true distribution.

    Relies on helpers defined elsewhere in the notebook: ``normal_ppf``,
    ``search_eta``, ``get_user_user_distances_fast``, ``estimate_fast``,
    ``empirical_quantile`` and ``wasserstein2``.

    Args:
        M: number of table rows (distinct means).
        N: number of table columns (distinct standard deviations).
        n: number of samples drawn per cell.
        n_runs: number of independent simulation runs.
        seed: seed for NumPy's global RNG, applied once at the start of the
            call. Defaults to 0, the previously hard-coded value.

    Returns:
        (error_est, error_observed): two lists of length ``n_runs`` with, per
        run, the ``wasserstein2`` score of the estimated cell (0, 0) and of
        its directly observed sample against the true distribution.
    """
    # Seeds the *global* RNG because the data below is drawn via np.random.*;
    # presumably the helpers may draw from it too — TODO confirm.
    np.random.seed(seed)
    error_est = []
    error_observed = []
    target = (0, 0)  # the cell that is held out and re-estimated every run

    for _ in tqdm(range(n_runs)):
        # M x N table of Gaussians: rows share a mean, columns share a std.
        mean_rows = np.random.uniform(-5, 5, M)
        std_cols = np.random.uniform(1, 5, N)

        data_table = np.zeros((M, N, n))
        for i, j in product(range(M), range(N)):
            data_table[i, j, :] = np.sort(np.random.normal(mean_rows[i], std_cols[j], n))

        # Only the target cell's true distribution is ever used for scoring.
        # NOTE(review): assumes normal_ppf is deterministic (does not draw from the RNG).
        true_dist = normal_ppf(mean_rows[target[0]], std_cols[target[1]])

        mask = np.ones((M, N), dtype=int)
        mask[target] = 0  # hide the target cell (presumably 0 = unobserved)

        # NOTE(review): the meaning of the literal 3 is defined by search_eta.
        eta = search_eta(data_table, mask, 3)
        d = get_user_user_distances_fast(data_table, mask)
        estimate = estimate_fast(data_table, mask, target[0], target[1], d, eta)

        error_est.append(wasserstein2(empirical_quantile(estimate), true_dist))
        error_observed.append(wasserstein2(empirical_quantile(data_table[target]), true_dist))

    return error_est, error_observed