In [None]:
import matplotlib.pyplot as plt
import numpy as np
from scipy import stats

from halt.stats import moments
%matplotlib widget

In [None]:
# Single seeded generator; `sample` draws from it, so Restart & Run All
# reproduces the same samples every time.
rng = np.random.default_rng(seed=94105)

In [None]:
# Sample sizes: N1 draws for the reference distribution, N2 draws for each
# trial distribution in the scale sweep.
N1 = 10_000
N2 = 10_000

In [None]:
def chisq(a, b, ndof=None):
    """Sum of squared differences between ``a`` and ``b`` divided by ``ndof``.

    No per-bin division by an expected value or variance is performed, so this
    is an unweighted (mean) squared difference rather than a Pearson chi-square.

    Parameters
    ----------
    a, b : numpy.ndarray
        Values to compare; must support elementwise subtraction.
    ndof : int, optional
        Divisor for the summed squares. Defaults to ``len(a)``.

    Returns
    -------
    float
        ``sum((a - b)**2) / ndof``.
    """
    if ndof is None:
        ndof = len(a)
    squared_residuals = (a - b) ** 2
    return np.sum(squared_residuals) / ndof

In [None]:
def jensen_shannon_divergence(a, b, base=None):
    """Jensen-Shannon divergence between two discrete distributions.

    ``a`` and ``b`` are each normalized to unit sum *before* the mixture
    ``m = (p + q) / 2`` is formed. Forming the mixture from the raw inputs is
    only correct when ``a`` and ``b`` already have equal totals, because
    ``scipy.stats.entropy`` normalizes its two arguments independently.

    Parameters
    ----------
    a, b : array_like
        Non-negative weights over the same bins (e.g. histogram densities).
    base : float, optional
        Logarithm base passed to ``scipy.stats.entropy``; natural log (nats)
        when None. The result lies in ``[0, log_base(2)]``.

    Returns
    -------
    float
        ``(KL(p || m) + KL(q || m)) / 2`` with ``p``, ``q`` the normalized inputs.
    """
    # asarray also lets plain lists work (``list + list`` would concatenate).
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    p = a / a.sum()
    q = b / b.sum()
    m = (p + q) / 2
    return (stats.entropy(p, m, base=base) + stats.entropy(q, m, base=base)) / 2

In [None]:
def jeffreys_distance(a, b):
    """Sum of squared differences of square roots: ``sum((sqrt(a) - sqrt(b))**2)``.

    When ``a`` and ``b`` each sum to one this equals ``2 * H**2``, where
    ``H**2 = 1 - sum(sqrt(a * b))`` is the squared Hellinger distance. Unlike
    ``scipy.stats.entropy``, no normalization is applied. The value therefore
    depends on the overall scale of the inputs; for example, density histograms
    give values inversely proportional to the bin width.

    Parameters
    ----------
    a, b : array_like
        Non-negative values over the same bins.

    Returns
    -------
    float
    """
    root_gap = np.sqrt(a) - np.sqrt(b)
    return np.sum(np.square(root_gap))

In [None]:
def sample(distribution, n, random_state=None):
    """Draw ``n`` variates from a frozen ``scipy.stats`` distribution.

    Parameters
    ----------
    distribution : frozen scipy.stats distribution
        Any object exposing ``rvs(size, random_state=...)``.
    n : int
        Number of variates to draw.
    random_state : numpy.random.Generator, optional
        Source of randomness. When None (the default), the notebook-level
        seeded ``rng`` is used, so existing calls keep drawing from the shared
        reproducible stream. Note that this differs from scipy's own meaning
        of None, which is the global NumPy state.

    Returns
    -------
    numpy.ndarray
        Array of ``n`` samples.
    """
    if random_state is None:
        random_state = rng
    return distribution.rvs(n, random_state=random_state)

In [None]:
def reldiff(a, b):
    """Symmetric relative difference ``2 * (a - b) / (a + b)``.

    Equal inputs (including ``a == b == 0``) give 0 rather than ``nan`` or
    ``ZeroDivisionError``. ``a == -b != 0`` gives ``+/-inf`` without a
    warning. The result is ill-conditioned whenever ``a + b`` is close to
    zero, so interpret large values for near-zero quantities with care.

    Parameters
    ----------
    a, b : float or array_like
        Values to compare; broadcast against each other.

    Returns
    -------
    numpy.float64 or numpy.ndarray
        A scalar for scalar inputs, otherwise an array of the broadcast shape.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    diff = a - b
    with np.errstate(divide='ignore', invalid='ignore'):
        rel = diff * 2 / (a + b)
    # Identical values have no relative difference, even when both are zero.
    rel = np.where(diff == 0, 0.0, rel)
    # Unwrap 0-d results to a scalar; arrays are returned unchanged.
    return rel[()]

In [None]:
def compare(x, a, b, *, out=None):
    """Append moment and divergence comparisons of ``a`` vs ``b`` to ``out``.

    For every name in ``moments.all``, appends ``reldiff`` of that moment of
    ``a`` and ``b`` (both evaluated on ``x``). It then appends three divergence
    metrics under the keys ``'DKL'`` (``stats.entropy(a, b)``, i.e. KL(a || b)),
    ``'JSD'`` (Jensen-Shannon) and ``'J'`` (``jeffreys_distance``).

    Parameters
    ----------
    x : numpy.ndarray
        Positions passed to ``moments`` (presumably the bin centres of ``a``
        and ``b`` -- TODO confirm against ``halt.stats.moments``).
    a, b : numpy.ndarray
        Distributions sampled on ``x``.
    out : dict, optional
        Accumulator mapping metric name -> list of values. A new dict is
        created when None; otherwise it is extended in place.

    Returns
    -------
    dict
        ``out`` (or the new dict), with one value appended per metric.
    """
    results = {} if out is None else out

    moments_a = moments(x, a, moments.all)
    moments_b = moments(x, b, moments.all)
    for moment_name in moments.all:
        results.setdefault(moment_name, []).append(
            reldiff(moments_a[moment_name], moments_b[moment_name]))

    # Insertion order here fixes the key (and therefore subplot) order.
    divergences = {
        'DKL': stats.entropy,
        'JSD': jensen_shannon_divergence,
        'J': jeffreys_distance,
    }
    for metric_name, metric in divergences.items():
        results.setdefault(metric_name, []).append(metric(a, b))

    return results

In [None]:
# Evaluation grid: 20 equally spaced bin centres on [-3, 3]. The edges sit
# half a grid step either side of each centre, so histogram bins are centred
# on `x` (21 edges for 20 bins).
x = np.linspace(start=-3, stop=3, num=20)
delta_x = x[1] - x[0]  # uniform grid spacing
bin_edges = np.append(x - delta_x / 2, x[-1] + delta_x / 2)

In [None]:
# Reference distribution: N1 draws from a normal with mean ref_loc and standard
# deviation ref_scale, histogrammed on the shared bins as a probability density.
ref_loc, ref_scale = 1.0, 0.4
ref_sample = sample(stats.norm(ref_loc, ref_scale), N1)
ref_dist = np.histogram(ref_sample, bins=bin_edges, density=True)[0]

In [None]:
# Sweep the trial width over 40 evenly spaced values in [0.1, 1.0], with the
# same centre as the reference. Each pass histograms N2 fresh draws and
# appends one value per metric to `metrics`.
# NOTE: for narrow trial widths, `dist` can be zero where `ref_dist` is not,
# in which case the 'DKL' entry is inf and will not be drawn in the plot.
scales = np.linspace(0.1, 1.0, num=40)  # endpoint is included by default
metrics = {}
for scale in scales:
    dist = np.histogram(
        sample(stats.norm(loc=ref_loc, scale=scale), N2),
        bins=bin_edges,
        density=True,
    )[0]
    compare(x, ref_dist, dist, out=metrics)

In [None]:
# One panel per metric, plotted against the trial scale: at most 4 rows and
# as many columns as needed.
n_plots = len(metrics)
if n_plots == 0:
    raise ValueError("metrics is empty; run the scale-sweep cell first")
nrows = min(n_plots, 4)
ncols = int(np.ceil(n_plots / nrows))
# squeeze=False keeps `axs` 2-D even for a single panel, so `axs.flat` always works.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
# Pad the x-range by a tenth of the scale step on each side.
x_pad = (scales[1] - scales[0]) / 10
xlim = (scales[0] - x_pad, scales[-1] + x_pad)
for ax, (name, vals) in zip(axs.flat, metrics.items()):
    ax.set_title(name)
    ax.set_xlabel('scale')
    ax.plot(scales, vals)
    ax.set_xlim(xlim)
# Hide grid cells left over when n_plots does not fill the rows x cols grid.
for ax in axs.flat[n_plots:]:
    ax.set_visible(False)
fig.tight_layout()