In [None]:
import numpy as np
from pyfrechet.metric_spaces import *
from pyfrechet.regression.trees import Tree
from pyfrechet.regression.bagged_regressor import BaggedRegressor
from pyfrechet.metrics import mse
from datetime import datetime
import sklearn
import time
import json

import sys, os; sys.path.append(os.path.dirname(os.getcwd())) 
from benchmark import bench

import numpy as np
from scipy import stats

from pyfrechet.metric_spaces import *
import pyfrechet.metric_spaces.wasserstein_1d as W1d

def gen_data(N, p, alpha, beta, sig0=1, gam=2.5, signal_scale=10, noise_l=2):
    """Simulate one Wasserstein-1D regression sample for the benchmark.

    Covariates ``x`` are drawn i.i.d. uniform on [0, 1)^p. For each row the
    noiseless target is the standard-normal quantile function on
    ``W1d.Wasserstein1D.GRID``, shifted and scaled by::

        eta = signal_scale * (x - 0.5) . beta / sqrt(p) + alpha
        mu  = eta
        sig = sig0 + gam * logistic_cdf(eta)
        target = mu + sig * Q_std_normal

    The observed response is each target row passed through
    ``W1d.noise_2`` (the noise model itself is defined there).

    Args:
        N: number of samples.
        p: number of covariates.
        alpha: intercept added to the linear predictor ``eta``.
        beta: coefficient vector of length ``p``; a ``(p, 1)`` column is
            also accepted.
        sig0: constant part of the scale ``sig``.
        gam: weight of the ``logistic_cdf(eta)`` part of ``sig``.
        signal_scale: multiplier on the centred linear term. The default 10
            is the value previously hard-coded.
        noise_l: value forwarded as ``l`` to ``W1d.noise_2``. The default 2
            is the value previously hard-coded.

    Returns:
        Tuple ``(x, y, mx)``:
        - ``x`` is an ``(N, p)`` array.
        - ``y`` is ``MetricData`` of the noisy responses.
        - ``mx`` is ``MetricData`` of the noiseless targets.
        Both ``y`` and ``mx`` use a ``Wasserstein1D`` metric.
    """
    M = W1d.Wasserstein1D()

    std_normal_q = stats.norm.ppf(W1d.Wasserstein1D.GRID)
    # Replace both endpoints by linear extrapolation from their two
    # neighbours, so that ppf(0) / ppf(1) never produce infinities.
    std_normal_q[0] = 2 * std_normal_q[1] - std_normal_q[2]
    std_normal_q[-1] = 2 * std_normal_q[-2] - std_normal_q[-3]

    # Flatten so that a (p,) vector and a (p, 1) column give the same (N,) eta.
    beta = np.asarray(beta).reshape(-1)

    # Identical draws to np.random.rand(N*p).reshape((N, p)).
    x = np.random.rand(N, p)

    # Compute all N noiseless targets at once: row i is mu_i + sig_i * Q.
    eta = signal_scale * (x - 0.5).dot(beta) / np.sqrt(p) + alpha
    sig = sig0 + gam * stats.logistic.cdf(eta)
    mx = eta[:, None] + sig[:, None] * std_normal_q

    # noise_2 works one row at a time, so it is still applied row by row.
    y = np.array([W1d.noise_2(row, l=noise_l) for row in mx])

    return x, MetricData(M, y), MetricData(M, mx)

# JSON file path handed to bench() for the results of this run.
OUT_FILE = './20231108-result_wasserstein.json'

try:
    # Settings grid for the Wasserstein-1D benchmark; the meaning of each
    # keyword is defined by benchmark.bench.
    bench(gen_data, OUT_FILE,
          ps=[2, 5, 10, 20], Ns=[50, 100, 200, 400],
          min_split_sizes=[5], subsample_fracs=[0.75], mtry_fracs=[0.5],
          replicas=100, n_trees=100, is_honest=False)
except Exception:
    # Best-effort handling: print the traceback and do not re-raise.
    import traceback
    traceback.print_exc()

