In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import dirichlet
from scipy.integrate import cumtrapz, simps, trapz
from scipy.interpolate import PchipInterpolator

from pwass.distributions import Distribution
from pwass.regression.distrib_on_distrib import DistribOnDistribReg
from pwass.regression.simplicial import SimpliciadDistribOnDistrib

from pwass.spline import SplineBasis, MonotoneQuadraticSplineBasis

In [None]:
def _quantile_from_pdf(pdf_grid, pdf_eval):
    """Build an interpolated quantile function from a sampled density.

    The density is integrated with the trapezoidal rule to obtain a CDF on
    ``pdf_grid``; the CDF is then inverted by monotone (PCHIP) interpolation.

    Parameters
    ----------
    pdf_grid : array-like
        Increasing abscissae at which the density is sampled.
    pdf_eval : array-like
        Density values at ``pdf_grid``.

    Returns
    -------
    PchipInterpolator
        Callable mapping probability levels in [0, 1] to quantiles.

    Raises
    ------
    ValueError
        If the density carries no positive mass.
    """
    pdf_grid = np.asarray(pdf_grid, dtype=float)
    # initial=0 keeps the CDF aligned with pdf_grid and anchors it at level 0,
    # so that level 0 maps to the left end of the support without extrapolation.
    cdf = cumtrapz(pdf_eval, pdf_grid, initial=0)
    # Tiny negative density values would make the CDF locally decreasing.
    cdf = np.maximum.accumulate(cdf)
    if not cdf[-1] > 0:
        raise ValueError("density integrates to a non-positive mass")
    # Rescale so the CDF ends exactly at 1: the quantile function is then
    # defined on the whole of [0, 1] even if the density is not normalised.
    cdf = cdf / cdf[-1]
    # PchipInterpolator needs strictly increasing abscissae: on flat stretches
    # of the CDF keep only the first point.
    keep = np.concatenate(([True], np.diff(cdf) > 0))
    return PchipInterpolator(cdf[keep], pdf_grid[keep])


def error_from_simp(ytrue, ypred):
    """Squared L2 distance between the quantile functions of two densities.

    Parameters
    ----------
    ytrue, ypred : objects exposing ``pdf_grid`` and ``pdf_eval``
        True and predicted distributions, given through their densities.

    Returns
    -------
    float
        Trapezoidal approximation, on 200 equispaced levels in [0, 1], of the
        integral of the squared difference between the two quantile functions.
    """
    grid = np.linspace(0, 1, 200)
    # reconstruct the quantiles
    qtrue = _quantile_from_pdf(ytrue.pdf_grid, ytrue.pdf_eval)
    qpred = _quantile_from_pdf(ypred.pdf_grid, ypred.pdf_eval)

    return trapz((qpred(grid) - qtrue(grid)) ** 2, grid)

def error_from_wass(ytrue, ypred):
    """Squared L2 distance between a true and a predicted quantile function.

    Parameters
    ----------
    ytrue : object exposing ``quantile_grid`` and ``quantile_eval``
        Reference distribution; its grid is used as the integration grid.
    ypred : object exposing ``wbasis`` and ``quantile_coeffs``
        Prediction; its spline is evaluated on ``ytrue.quantile_grid``.

    Returns
    -------
    float
        Trapezoidal integral of the squared residual over the grid.
    """
    levels = ytrue.quantile_grid
    predicted = ypred.wbasis.eval_spline(ypred.quantile_coeffs, levels)
    residual = ytrue.quantile_eval - predicted
    return trapz(residual ** 2, levels)


def loo_wass(qx, qy):
    """Leave-one-out errors of ``DistribOnDistribReg``.

    For every index, a fresh regressor is fitted on all the other
    (predictor, response) pairs and scored on the held-out pair with
    ``error_from_wass``.

    Note: the spline basis is read from the notebook-level variable
    ``wbasis`` at call time.

    Parameters
    ----------
    qx, qy : sequences of distribution objects
        Predictors and responses, in matching order.

    Returns
    -------
    list
        One error per held-out observation.
    """
    predictors = np.array(qx)
    responses = np.array(qy)

    def _held_out_error(idx):
        # Fit on everything except observation `idx`, then score on it.
        model = DistribOnDistribReg(spline_basis=wbasis, compute_spline=False)
        model.fit(np.delete(predictors, idx), np.delete(responses, idx))
        prediction = model.predict([predictors[idx]])[0]
        return error_from_wass(responses[idx], prediction)

    return [_held_out_error(idx) for idx in range(len(predictors))]


def loo_simp(pdfx, pdfy):
    """Leave-one-out errors of ``SimpliciadDistribOnDistrib``.

    Each observation is held out in turn: a new regressor is fitted on the
    remaining pairs and its prediction for the held-out predictor is compared
    with the held-out response through ``error_from_simp``.

    Note: the spline basis is read from the notebook-level variable
    ``simpbasis`` at call time.

    Parameters
    ----------
    pdfx, pdfy : sequences of distribution objects
        Predictors and responses, in matching order.

    Returns
    -------
    list
        One error per held-out observation.
    """
    predictors = np.array(pdfx)
    responses = np.array(pdfy)

    scores = []
    for held_out, (x_out, y_out) in enumerate(zip(predictors, responses)):
        model = SimpliciadDistribOnDistrib(spline_basis=simpbasis, compute_spline=False)
        model.fit(np.delete(predictors, held_out), np.delete(responses, held_out))
        y_hat = model.predict([x_out])[0]
        scores.append(error_from_simp(y_out, y_hat))

    return scores

## Generate data from our model (regression on quantile functions)

In [None]:
# Notebook-level spline basis on [0, 1]; `loo_wass` picks up `wbasis` from here.
nbasis = 20
zero_one_grid = np.linspace(0, 1, num=1000)
wbasis = MonotoneQuadraticSplineBasis(nbasis, zero_one_grid)

In [None]:
def generate_quantiles(ndata, nbasis=10):
    """Simulate paired spline curves (meant as quantile functions) on [0, 1].

    Predictor coefficients are cumulative sums of a zero followed by a
    Dirichlet draw, so each row is non-decreasing and starts at 0. Response
    coefficients are the predictor coefficients multiplied by a random matrix
    whose rows are cumulative sums of Uniform(0, 0.5) draws.

    Parameters
    ----------
    ndata : int
        Number of (X, Y) pairs to simulate.
    nbasis : int, default 10
        Number of spline basis functions.

    Returns
    -------
    eval_grid : ndarray, shape (500,)
        Evaluation grid on [0, 1].
    X_evals, Y_evals : ndarray, shape (ndata, 500)
        Spline evaluations of predictors and responses, one row per pair.
    """
    eval_grid = np.linspace(0, 1, 500)
    basis = SplineBasis(3, nbasis=nbasis, xgrid=eval_grid)

    # Random draws happen in this order: Dirichlet increments first, then beta.
    increments = dirichlet.rvs(np.ones(nbasis - 1) * 5, size=ndata)
    leading_zero = np.zeros(ndata).reshape(-1, 1)
    x_coeffs = np.cumsum(np.hstack([leading_zero, increments]), axis=1)

    beta = np.cumsum(np.random.uniform(0, 0.5, size=(nbasis, nbasis)), axis=1)
    y_coeffs = np.matmul(x_coeffs, beta)

    X_evals = np.zeros((ndata, len(eval_grid)))
    Y_evals = np.zeros((ndata, len(eval_grid)))
    for row, (cx, cy) in enumerate(zip(x_coeffs, y_coeffs)):
        X_evals[row, :] = basis.eval_spline(cx)
        Y_evals[row, :] = basis.eval_spline(cy)

    return eval_grid, X_evals, Y_evals

In [None]:
def _quantile_distribution(levels, quantile_values, basis):
    """Wrap one quantile curve into a Distribution expanded on `basis`."""
    distrib = Distribution(wbasis=basis)
    distrib.init_from_quantile(levels, quantile_values)
    distrib.compute_spline_expansions()
    return distrib


def _padded_pdf_distribution(levels, quantile_values, basis, lower, upper):
    """Turn one quantile curve into a density-based Distribution on `basis`.

    The density is the finite-difference ratio d(level) / d(quantile), placed
    at the right end of each quantile increment. The grid is then extended
    towards `lower` and `upper` in steps of 0.1 with density 1e-5
    (presumably so the density stays strictly positive for the clr
    transform - confirm).
    """
    density = np.diff(levels) / np.diff(quantile_values)
    support = quantile_values[1:]
    left_pad = np.arange(lower, support[0], 0.1)
    # [1:] drops the first point, which coincides with support[-1].
    right_pad = np.arange(support[-1], upper, 0.1)[1:]

    full_grid = np.concatenate([left_pad, support, right_pad])
    full_density = np.concatenate(
        [np.ones_like(left_pad) * 1e-5, density, np.ones_like(right_pad) * 1e-5])

    distrib = Distribution(xbasis=basis)
    distrib.init_from_pdf(full_grid, full_density)
    distrib.compute_clr()
    distrib.compute_spline_expansions()
    return distrib


ndata = 100
grid, X_evals, Y_evals = generate_quantiles(ndata)

# Common grid for the densities, spanning every simulated quantile value.
simpgrid = np.linspace(np.min([X_evals, Y_evals]), np.max([X_evals, Y_evals]), 100)
simpbasis = SplineBasis(3, nbasis=100, xgrid=simpgrid)
xmin = simpgrid[0]
xmax = simpgrid[-1]

qx, qy, pdfx, pdfy = [], [], [], []

for i in range(ndata):
    print("\r{0} / {1}".format(i+1, ndata), end=" ", flush=True)

    # Quantile-based representations (Wasserstein regression inputs).
    qx.append(_quantile_distribution(grid, X_evals[i, :], wbasis))
    qy.append(_quantile_distribution(grid, Y_evals[i, :], wbasis))

    # Density-based representations (simplicial regression inputs).
    pdfx.append(_padded_pdf_distribution(grid, X_evals[i, :], simpbasis, xmin, xmax))
    pdfy.append(_padded_pdf_distribution(grid, Y_evals[i, :], simpbasis, xmin, xmax))

In [None]:
# Leave-one-out errors of both regressions on the same simulated pairs.
er_wass = loo_wass(qx=qx, qy=qy)
er_simp = loo_simp(pdfx=pdfx, pdfy=pdfy)

In [None]:
# Mean and standard deviation of the leave-one-out errors.
wass_mean, wass_std = np.mean(er_wass), np.std(er_wass)
print("WASS. ERROR: {0:.10f}, STD: {1:.10f}".format(wass_mean, wass_std))
simp_mean, simp_std = np.mean(er_simp), np.std(er_simp)
print("SIMP. ERROR: {0:.5f}, STD: {1:.5f}".format(simp_mean, simp_std))

## Generate data from the simplicial model

In [None]:
def inv_clr(f_eval, grid):
    """Map a clr-type function back to a density that integrates to one.

    Computes ``exp(f_eval)`` divided by its Simpson integral over ``grid``.

    Parameters
    ----------
    f_eval : array-like
        Function values on ``grid``.
    grid : array-like
        Abscissae matching ``f_eval``.

    Returns
    -------
    ndarray
        ``exp(f_eval)`` rescaled so that its Simpson integral over ``grid`` is 1.
    """
    f_eval = np.asarray(f_eval, dtype=float)
    # The result is invariant to adding a constant to f_eval; subtracting the
    # maximum keeps exp() from overflowing to inf (which would give inf/inf = nan).
    out = np.exp(f_eval - np.max(f_eval))
    # `x` is passed by keyword: newer SciPy releases make it keyword-only.
    den = simps(out, x=grid)
    return out / den


def generate_pdfs(ndata, nbasis=20):
    """Simulate paired densities on [0, 1] from a linear model on spline coefficients.

    Predictor coefficients are i.i.d. N(0, 0.2^2); response coefficients are
    the predictor coefficients multiplied by a standard-normal matrix. Each
    coefficient vector is evaluated as a spline and mapped to a density with
    ``inv_clr``.

    Parameters
    ----------
    ndata : int
        Number of (X, Y) pairs to simulate.
    nbasis : int, default 20
        Number of spline basis functions (previously hard-coded to 20).

    Returns
    -------
    zero_one_grid : ndarray, shape (500,)
        Evaluation grid on [0, 1].
    X_evals, Y_evals : ndarray, shape (ndata, 500)
        Density evaluations of predictors and responses, one row per pair.
    """
    zero_one_grid = np.linspace(0, 1, 500)
    basis = SplineBasis(3, nbasis=nbasis, xgrid=zero_one_grid)

    x_coeffs = np.random.normal(scale=0.2, size=(ndata, nbasis))
    beta = np.random.normal(size=(nbasis, nbasis))
    y_coeffs = np.matmul(x_coeffs, beta)

    X_evals = np.zeros((ndata, len(zero_one_grid)))
    Y_evals = np.zeros((ndata, len(zero_one_grid)))

    for i in range(ndata):
        X_evals[i, :] = inv_clr(basis.eval_spline(x_coeffs[i, :]), zero_one_grid)
        Y_evals[i, :] = inv_clr(basis.eval_spline(y_coeffs[i, :]), zero_one_grid)

    return zero_one_grid, X_evals, Y_evals

In [None]:
# Rebind the notebook-level bases read by `loo_wass` (`wbasis`) and
# `loo_simp` (`simpbasis`); both live on the same [0, 1] grid.
nbasis = 20
zero_one_grid = np.linspace(0, 1, num=1000)
wbasis = MonotoneQuadraticSplineBasis(nbasis, zero_one_grid)
simpbasis = SplineBasis(2, nbasis=nbasis, xgrid=zero_one_grid)

In [None]:
def _simplicial_from_pdf(pdf_grid, pdf_values, basis):
    """Density-based Distribution: clr transform, then spline expansion on `basis`."""
    distrib = Distribution(xbasis=basis)
    distrib.init_from_pdf(pdf_grid, pdf_values)
    distrib.compute_clr()
    distrib.compute_spline_expansions()
    return distrib


def _wasserstein_from_pdf(pdf_grid, pdf_values, basis):
    """Quantile-based Distribution built from a density, expanded on `basis`."""
    distrib = Distribution(wbasis=basis)
    distrib.init_from_pdf(pdf_grid, pdf_values)
    # Private pwass method; presumably derives the quantile function from the
    # density before the spline expansion - confirm against the library.
    distrib._invert_cdf()
    distrib.compute_spline_expansions()
    return distrib


ndata = 100
grid, X_evals, Y_evals = generate_pdfs(ndata)

qx, qy, pdfx, pdfy = [], [], [], []

for x_pdf, y_pdf in zip(X_evals, Y_evals):
    # Simplicial regression inputs.
    pdfx.append(_simplicial_from_pdf(grid, x_pdf, simpbasis))
    pdfy.append(_simplicial_from_pdf(grid, y_pdf, simpbasis))

    # Wasserstein regression inputs.
    qx.append(_wasserstein_from_pdf(grid, x_pdf, wbasis))
    qy.append(_wasserstein_from_pdf(grid, y_pdf, wbasis))

In [None]:
# Leave-one-out errors of both regressions on the simplicial-model data.
er_wass = loo_wass(qx=qx, qy=qy)
er_simp = loo_simp(pdfx=pdfx, pdfy=pdfy)

wass_mean, wass_std = np.mean(er_wass), np.std(er_wass)
print("WASS. ERROR: {0:.4f}, STD: {1:.5f}".format(wass_mean, wass_std))
simp_mean, simp_std = np.mean(er_simp), np.std(er_simp)
print("SIMP. ERROR: {0:.4f}, STD: {1:.5f}".format(simp_mean, simp_std))

In [None]:
# Fit both regressions on the full data set (no hold-out).
reg1 = DistribOnDistribReg(spline_basis=wbasis, compute_spline=False)
reg1.fit(qx, qy)

# The simplicial regressor must use the basis on which pdfx / pdfy were
# expanded (`simpbasis`, as in `loo_simp`), not the Wasserstein basis.
reg2 = SimpliciadDistribOnDistrib(spline_basis=simpbasis, compute_spline=False)
reg2.fit(pdfx, pdfy)

In [None]:
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

# Shared colour limits so the two coefficient matrices are drawn on the same
# scale. Taken per matrix, so the two betas need not have the same shape.
vmin = min(np.min(reg1.beta), np.min(reg2.beta))
vmax = max(np.max(reg1.beta), np.max(reg2.beta))

axes[0].imshow(reg1.beta, vmin=vmin, vmax=vmax)
axes[0].set_title("reg1.beta (DistribOnDistribReg)")
axes[1].imshow(reg2.beta, vmin=vmin, vmax=vmax)
axes[1].set_title("reg2.beta (SimpliciadDistribOnDistrib)")
plt.show()