In [1]:
import torch 
import torch.nn.functional as F
import jax.numpy as np
import numpy as onp

In [2]:

def searchsorted(bin_locations, inputs, eps=1e-6):
    bin_locations[..., -1] += eps
    return torch.sum(inputs[..., None] >= bin_locations, dim=-1) - 1


def unconstrained_linear_spline(
    inputs, unnormalized_pdf, inverse=False, tail_bound=1.0, tails="linear"
):
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails == "linear":
        outputs[outside_interval_mask] = inputs[outside_interval_mask]
        logabsdet[outside_interval_mask] = 0
    else:
        raise RuntimeError("{} tails are not implemented.".format(tails))

    if torch.any(inside_interval_mask):
        outputs[inside_interval_mask], logabsdet[inside_interval_mask] = linear_spline(
            inputs=inputs[inside_interval_mask],
            unnormalized_pdf=unnormalized_pdf[inside_interval_mask, :],
            inverse=inverse,
            left=-tail_bound,
            right=tail_bound,
            bottom=-tail_bound,
            top=tail_bound,
        )

    return outputs, logabsdet




def linear_spline(
    inputs, unnormalized_pdf, inverse=False, left=0.0, right=1.0, bottom=0.0, top=1.0
):
    """
    Reference:
    > Müller et al., Neural Importance Sampling, arXiv:1808.03856, 2018.
    """

    if inverse:
        inputs = (inputs - bottom) / (top - bottom)
    else:
        inputs = (inputs - left) / (right - left)

    num_bins = unnormalized_pdf.size(-1)

    # print("unnormalized_pdf:", unnormalized_pdf.shape, unnormalized_pdf)
    pdf = F.softmax(unnormalized_pdf, dim=-1)    


    cdf = torch.cumsum(pdf, dim=-1)
    cdf[..., -1] = 1.0
    cdf = F.pad(cdf, pad=(1, 0), mode="constant", value=0.0)

    if inverse:
        inv_bin_idx = searchsorted(cdf, inputs)

        bin_boundaries = (
            torch.linspace(0, 1, num_bins + 1)
            .view([1] * inputs.dim() + [-1])
            .expand(*inputs.shape, -1)
        )

        slopes = (cdf[..., 1:] - cdf[..., :-1]) / (
            bin_boundaries[..., 1:] - bin_boundaries[..., :-1]
        )
        offsets = cdf[..., 1:] - slopes * bin_boundaries[..., 1:]

        inv_bin_idx = inv_bin_idx.unsqueeze(-1)
        input_slopes = slopes.gather(-1, inv_bin_idx)[..., 0]
        input_offsets = offsets.gather(-1, inv_bin_idx)[..., 0]

        outputs = (inputs - input_offsets) / input_slopes
        outputs = torch.clamp(outputs, 0, 1)

        logabsdet = -torch.log(input_slopes)
    else:
        bin_pos = inputs * num_bins

        bin_idx = torch.floor(bin_pos).long()
        bin_idx[bin_idx >= num_bins] = num_bins - 1

        alpha = bin_pos - bin_idx.float()

        input_pdfs = pdf.gather(-1, bin_idx[..., None])[..., 0]

        outputs = cdf.gather(-1, bin_idx[..., None])[..., 0]
        outputs += alpha * input_pdfs
        outputs = torch.clamp(outputs, 0, 1)

        bin_width = 1.0 / num_bins
        logabsdet = torch.log(input_pdfs) - onp.log(bin_width)

    if inverse:
        outputs = outputs * (right - left) + left
    else:
        outputs = outputs * (top - bottom) + bottom

    return outputs, logabsdet

In [3]:
import lbi.models.flows.splines.linear as me

In [4]:
batch_size = 10
data_dim = 3
num_bins = 5

inputs = torch.rand(size=(batch_size, data_dim))
trans_params = torch.rand(size=(batch_size, num_bins * data_dim))

torch_params = trans_params.view(batch_size, -1, num_bins)

jax_params = trans_params.numpy().reshape(batch_size, num_bins, -1)
jax_params = np.transpose(jax_params, (0, 2, 1))




In [5]:
torch_outputs, torch_logabsdet = unconstrained_linear_spline(
    inputs, torch_params, inverse=False, tail_bound=20.2,
)

In [6]:
torch_inputs, torch_logabsdet = unconstrained_linear_spline(
    torch_outputs, torch_params, inverse=True, tail_bound=20.2,
)

In [7]:
jax_outputs, jax_logabsdet = me.unconstrained_linear_spline(
    inputs.numpy(), jax_params, inverse=False, tail_bound=20.2,
)

In [8]:
jax_inputs, jax_logabsdet = me.unconstrained_linear_spline(
    jax_outputs, jax_params, inverse=True, tail_bound=20.2,
)

In [9]:
np.mean(torch_outputs.numpy() - jax_outputs), np.mean(torch_inputs.numpy() - jax_inputs), np.mean(inputs.numpy() - jax_inputs)

(DeviceArray(-0.18788618, dtype=float32),
 DeviceArray(3.1789145e-07, dtype=float32),
 DeviceArray(-5.0266584e-07, dtype=float32))