# $\varepsilon$ comparison

In [None]:
import numpy as np

from conditional_expectation_methods import compute_ols_parameters, nadaraya_watson, knn_conditional_expectation_improved
from wasserstein_distances import W2_empirical
from EM import euler_maruyama_coupling

def compute_gradient_U(x, mu1=0.0, sigma1=1.0):
    """Return ``(x - mu1) / sigma1**2``.

    This is the derivative of the quadratic potential
    ``(x - mu1)**2 / (2 * sigma1**2)`` evaluated at ``x``.

    Args:
        x: Point(s) at which to evaluate; anything supporting ``-`` and ``/``
            with floats (scalars or NumPy arrays).
        mu1: Centre of the potential.
        sigma1: Scale of the potential; it enters squared.

    Returns:
        The centred value divided by the squared scale.
    """
    variance = sigma1 ** 2
    centred = x - mu1
    return centred / variance

def compute_gradient_V(y, mu2=2.0, sigma2=0.5):
    """Return ``(y - mu2) / sigma2**2``.

    This is the derivative of the quadratic potential
    ``(y - mu2)**2 / (2 * sigma2**2)`` evaluated at ``y``.

    Args:
        y: Point(s) at which to evaluate; anything supporting ``-`` and ``/``
            with floats (scalars or NumPy arrays).
        mu2: Centre of the potential.
        sigma2: Scale of the potential; it enters squared.

    Returns:
        The centred value divided by the squared scale.
    """
    variance = sigma2 ** 2
    centred = y - mu2
    return centred / variance

def sample_bivariate_normal(n, mu1, sigma1, mu2, sigma2, rho, seed=None):
    """Draw ``n`` correlated pairs from a bivariate normal distribution.

    Args:
        n: Number of pairs to draw.
        mu1: Mean of the first coordinate.
        sigma1: Standard deviation of the first coordinate.
        mu2: Mean of the second coordinate.
        sigma2: Standard deviation of the second coordinate.
        rho: Correlation coefficient between the two coordinates.
        seed: Seed handed to ``np.random.default_rng``; ``None`` gives a
            non-reproducible generator.

    Returns:
        A pair ``(x, y)`` of length-``n`` arrays holding the first and second
        coordinate of every draw.
    """
    generator = np.random.default_rng(seed)

    # The off-diagonal covariance is shared by both triangles of the matrix.
    cross_cov = rho * sigma1 * sigma2
    covariance = np.array([
        [sigma1**2, cross_cov],
        [cross_cov, sigma2**2],
    ])
    location = np.array([mu1, mu2])

    draws = generator.multivariate_normal(location, covariance, size=n)
    first, second = draws[:, 0], draws[:, 1]
    return first, second

# --- Experiment configuration -------------------------------------------
num_samples = 1_000   # number of (X0, Y0) pairs
d = 1                 # dimension used in the neighbour-count exponent below
T = 20                # final time of the simulation
seed = 42             # shared by the sampler and by every simulation run

# Neighbour count for the kNN estimator: round(num_samples ** (2 / (d + 4))).
k_neighbors = int(np.round(num_samples ** (2 / (d + 4))))

# --- Initial coupled sample ---------------------------------------------
# Marginals: mean 0 / std 1 and mean 2 / std 0.5, with correlation 0.6.
X0, Y0 = sample_bivariate_normal(
    n=num_samples, mu1=0, sigma1=1, mu2=2, sigma2=0.5, rho=0.6, seed=seed,
)

# Reference value the simulated errors are compared against: the square of
# whatever W2_empirical returns for the initial sample (presumably the
# empirical Wasserstein-2 distance -- confirm in wasserstein_distances).
W2_sq_emp = W2_empirical(X0, Y0) ** 2

# 4 values of epsilon, log-spaced from 1e-1 down to 1e-8
# (approximately 1e-1, 4.64e-4, 2.15e-6, 1e-8).
eps_list = np.logspace(-1, -8, num=4)

# Fix the time step dt and derive the step count N so that N * dt ~= T.
dt = 0.01
N = int(np.round(T / dt))

# Per-epsilon results; one entry is appended for every value in eps_list.
rel_errors = []
time_arrays = []

# Sweep over epsilon: run the coupling once per conditional-expectation
# estimator and record how close the kNN run gets to the reference W2_sq_emp.
for epsilon in eps_list:
    # Arguments shared by both runs; only cond_method differs between them.
    em_kwargs = dict(
        epsilon=epsilon,
        T=T,
        N=N,
        h=None,
        k=k_neighbors,
        grad_U=compute_gradient_U,
        grad_V=compute_gradient_V,
        compute_W2_sq=None,
        seed=seed,
    )

    # k-nearest-neighbour estimator.
    err_kNN, X_kNN, Y_kNN, _, mY_kNN, mX_kNN = euler_maruyama_coupling(
        X0, Y0, cond_method='knn', **em_kwargs
    )

    # Ordinary-least-squares estimator (same seed and settings).
    err_OLS, X_OLS, Y_OLS, _, mY_OLS, mX_OLS = euler_maruyama_coupling(
        X0, Y0, cond_method='ols', **em_kwargs
    )

    # Smallest value in the kNN error sequence, then its squared relative
    # deviation from the reference value W2_sq_emp.
    min_err_kNN = np.min(err_kNN)
    rel_err = (min_err_kNN - W2_sq_emp)**2 / (W2_sq_emp**2)
    rel_errors.append(rel_err)

    # Uniform grid on [0, T] with one point per entry of err_kNN.
    t = np.linspace(0, T, len(err_kNN))
    time_arrays.append(t)