# Testing inverse-transform sampling of the Gaussian with `torch.special.ndtri`

In [None]:
import math
from random import random

import numpy as np
import torch
from torch.distributions import (
    constraints,
    Transform,
    TransformedDistribution,
    Normal,
    Uniform,
    VonMises,
)
import matplotlib.pyplot as plt

%load_ext lab_black

In [None]:
class GaussToUnif(Transform):
    """Bijection from a zero-mean Gaussian variable to a uniform variable.

    The forward map is the scaled Gaussian CDF,
    ``z = ndtr(x / gauss_sigma) * (unif_high - unif_low) + unif_low``,
    which sends ``Normal(0, gauss_sigma)`` to ``Uniform(unif_low, unif_high)``.
    The inverse (reachable via ``.inv``) uses ``torch.special.ndtri`` and turns
    uniform samples into Gaussian ones (inverse-transform sampling).

    Args:
        gauss_sigma: Standard deviation of the Gaussian side.
        unif_low: Lower bound of the uniform side.
        unif_high: Upper bound of the uniform side; assumed > ``unif_low``.
    """

    bijective = True
    # ndtr is strictly increasing and the rescaling factor is positive
    # (assuming unif_high > unif_low). Declaring the sign lets
    # TransformedDistribution.cdf work through this transform.
    sign = +1
    domain = constraints.real
    # Class-level default only; __init__ replaces it with the actual interval.
    codomain = constraints.interval(0, 1)

    def __init__(self, gauss_sigma=1, unif_low=0, unif_high=1):
        super().__init__()
        self._gauss_sigma = gauss_sigma
        self._unif_low = unif_low
        self._unif_high = unif_high
        self._unif_interval = unif_high - unif_low
        self._gauss = Normal(loc=0, scale=gauss_sigma)
        # The image of the forward map is [unif_low, unif_high], not [0, 1].
        self.codomain = constraints.interval(unif_low, unif_high)

    def _call(self, x):
        """Gauss -> Uniform"""
        return (
            torch.special.ndtr(x / self._gauss_sigma) * self._unif_interval
            + self._unif_low
        )

    def _inverse(self, z):
        """Uniform -> Gauss"""
        return (
            torch.special.ndtri((z - self._unif_low) / self._unif_interval)
            * self._gauss_sigma
        )

    def log_abs_det_jacobian(self, x, z):
        """Return log|dz/dx| of the forward (Gauss -> Uniform) map at ``x``.

        dz/dx = (unif_high - unif_low) * N(x; 0, gauss_sigma), so the result
        is the Gaussian log-pdf plus log(unif_high - unif_low). The second
        term vanishes only for a unit-width uniform interval.
        """
        log_interval = torch.log(
            torch.as_tensor(self._unif_interval, dtype=x.dtype, device=x.device)
        )
        return self._gauss.log_prob(x) + log_interval

In [None]:
# Draw samples via inverse-transform sampling (Uniform -> ndtri -> Gaussian),
# histogram them, and compare against the exact Gaussian bin densities.
GAUSS_SIGMA = 0.05
UNIF_LOW = 0
UNIF_HIGH = 1
N_SAMPLE = 10_000_000
N_BINS = 100

target = Normal(loc=0, scale=GAUSS_SIGMA)
model = TransformedDistribution(
    base_distribution=Uniform(low=UNIF_LOW, high=UNIF_HIGH),
    transforms=GaussToUnif(GAUSS_SIGMA, UNIF_LOW, UNIF_HIGH).inv,
)


# N_BINS + 1 edges give N_BINS bins. float64 keeps the CDF differences below
# accurate in the upper tail, where float32 CDF values round to 1.
bins = torch.linspace(
    -5 * GAUSS_SIGMA, 5 * GAUSS_SIGMA, N_BINS + 1, dtype=torch.float64
)
midpoints = 0.5 * (bins[:-1] + bins[1:])
bin_widths = bins[1:] - bins[:-1]
# Exact expected density per bin: P(bin) / width. Evaluating the pdf at the
# bin midpoints instead adds a systematic O(width^2) bias, largest in the
# tails, that would be indistinguishable from genuine sampling error.
target_func = (target.cdf(bins).diff() / bin_widths).numpy()

model_sample = model.sample([N_SAMPLE])
counts, _ = np.histogram(model_sample.numpy(), bins=bins.numpy())
# Normalise by the total sample count rather than the in-range count (which
# is what density=True would use), so this is directly comparable to
# target_func.
model_hist = counts / (N_SAMPLE * bin_widths.numpy())
rel_error = (model_hist - target_func) / target_func

fig, ax = plt.subplots(figsize=(12, 8))
# Reuse the histogram computed above instead of re-binning all samples.
ax.hist(midpoints.numpy(), bins=bins.numpy(), weights=model_hist, label="model")
ax.plot(midpoints.numpy(), target_func, "r--", label="target")
ax.set_ylabel("density")
ax.set_xlabel("x")
ax.legend()
# fig.savefig("ndtri_test_density.png")

fig2, ax2 = plt.subplots(figsize=(12, 8))
# A log axis only draws positive values, so plotting rel_error and -rel_error
# separates the bins by the sign of the error.
ax2.plot(midpoints.numpy(), rel_error, "bo", label="model > target")
ax2.plot(midpoints.numpy(), -rel_error, "ro", label="model < target")
ax2.set_yscale("log")
ax2.set_ylabel("|(model - target) / target|")
ax2.set_xlabel("x")
ax2.legend()
# fig2.savefig("ndtri_test_rel_error.png")