In [16]:
from __future__ import annotations

import relaxed
import jax
import jax.numpy as jnp
from jax.random import PRNGKey, multivariate_normal
import pyhf
from typing import Callable, Any
from functools import partial

jax.config.update("jax_enable_x64", True)

Array = jnp.ndarray

# specific also
def generate_data(
    rng=0,
    num_points=10000,
    sig_mean=(-1, 1),
    bup_mean=(4.5, 2),
    bdown_mean=(-2.5, -2.5),
    b_mean=(0, 0),
):
    sig = multivariate_normal(
        PRNGKey(rng),
        jnp.asarray(sig_mean),
        jnp.asarray([[1, 0], [0, 1]]),
        shape=(num_points,),
    )
    bkg_up = multivariate_normal(
        PRNGKey(rng),
        jnp.asarray(bup_mean),
        jnp.asarray([[1, 0], [0, 1]]),
        shape=(num_points,),
    )
    bkg_down = multivariate_normal(
        PRNGKey(rng),
        jnp.asarray(bdown_mean),
        jnp.asarray([[1, 0], [0, 1]]),
        shape=(num_points,),
    )

    bkg_nom = multivariate_normal(
        PRNGKey(rng),
        jnp.asarray(b_mean),
        jnp.asarray([[1, 0], [0, 1]]),
        shape=(num_points,),
    )
    return sig, bkg_nom, bkg_up, bkg_down


def hists_from_pars(
    pars: dict[str, Array],
    data: dict[str, Array],
    nn: Callable,
    bandwidth: float,
    bins: Array | None = None,
    scale_factors: dict[str, float] | None = None,
    overall_scale: float = 10.0,
) -> dict[str, Array]:
    """Function that takes in data + analysis config parameters, and constructs yields."""
    nn_output = {k: nn(pars["nn_pars"], data[k]).ravel() for k in data}
    make_hist = partial(
        relaxed.hist, bandwidth=bandwidth, bins=pars["bins"] if "bins" in pars else bins
    )
    scale_factors = scale_factors or {k: 1.0 for k in nn_output}
    # every histogram is scaled to the number of points from that data source in the batch
    # so we have more control over the scaling of sig/bkg for realism
    hists = {
        k: make_hist(nn_output[k]) * scale_factors[k] * overall_scale / len(v)
        + 1e-3  # no zeros!
        for k, v in nn_output.items()
    }
    return hists


# specific to use case
def model_from_hists(hists: dict[str, Array]) -> pyhf.Model:
    """How to make your HistFactory model from your histograms."""
    spec = {
        "channels": [
            {
                "name": "singlechannel",
                "samples": [
                    {
                        "name": "signal",
                        "data": hists["sig"],
                        "modifiers": [
                            {"name": "mu", "type": "normfactor", "data": None},
                        ],
                    },
                    {
                        "name": "background",
                        "data": hists["bkg_nominal"],
                        "modifiers": [
                            {
                                "name": "correlated_bkg_uncertainty",
                                "type": "histosys",
                                "data": {
                                    "hi_data": hists["bkg_up"],
                                    "lo_data": hists["bkg_down"],
                                },
                            },
                        ],
                    },
                ],
            },
        ],
    }
    return pyhf.Model(spec, validate=False)


def poi_uncert(model: pyhf.Model) -> float:
    hypothesis_pars = (
        jnp.asarray(model.config.suggested_init()).at[model.config.poi_index].set(1.0)
    )
    observed_hist = jnp.asarray(model.expected_data(hypothesis_pars))
    return relaxed.cramer_rao_uncert(model, hypothesis_pars, observed_hist)[
        model.config.poi_index
    ]


def discovery_significance(model: pyhf.Model, fit_lr: float) -> float:
    test_stat = "q0"
    test_poi = 0.0
    hypothesis_pars = (
        jnp.asarray(model.config.suggested_init()).at[model.config.poi_index].set(1.0)
    )
    observed_hist = jnp.asarray(model.expected_data(hypothesis_pars))
    return relaxed.infer.hypotest(
        test_poi=test_poi,
        data=observed_hist,
        model=model,
        test_stat=test_stat,
        expected_pars=hypothesis_pars,
        lr=fit_lr,
    )


def cls_value(model: pyhf.Model, fit_lr: float) -> float:
    test_stat = "q"
    test_poi = 1.0
    hypothesis_pars = (
        jnp.asarray(model.config.suggested_init()).at[model.config.poi_index].set(0.0)
    )
    observed_hist = jnp.asarray(model.expected_data(hypothesis_pars))
    return relaxed.infer.hypotest(
        test_poi=test_poi,
        data=observed_hist,
        model=model,
        test_stat=test_stat,
        expected_pars=hypothesis_pars,
        lr=fit_lr,
    )


def generalised_variance(model: pyhf.Model) -> float:
    hypothesis_pars = (
        jnp.asarray(model.config.suggested_init()).at[model.config.poi_index].set(0.0)
    )
    observed_hist = jnp.asarray(model.expected_data(hypothesis_pars))
    return jnp.linalg.det(
        jnp.linalg.inv(relaxed.fisher_info(model, hypothesis_pars, observed_hist))
    )


def loss_from_model(
    model: pyhf.Model,
    loss: str | Callable[[dict[str, Any]], float] = "neos",
    fit_lr: float = 1e-3,
) -> float:
    if isinstance(loss, Callable):
        # everything
        return 0
    # loss specific
    if loss.lower() == "discovery":
        return discovery_significance(model, fit_lr)
    elif loss.lower() in ["neos", "cls"]:
        return cls_value(model, fit_lr)
    elif loss.lower() in ["inferno", "poi_uncert", "mu_uncert"]:
        return poi_uncert(model)
    elif loss.lower() in [
        "general_variance",
        "generalised_variance",
        "generalized_variance",
    ]:
        return generalised_variance(model)
    else:
        raise ValueError(f"loss function {loss} not recognised")


def pipeline(pars, data, bins, nn, loss, bandwidth, keys, scale_factors):
    hists = hists_from_pars(
        pars=pars,
        nn=nn,
        data={k: v for k, v in zip(keys, data)},
        bandwidth=bandwidth,
        bins=bins,
        scale_factors=scale_factors,
    )
    model = model_from_hists(hists)
    return loss_from_model(model, loss=loss)

In [17]:
from jax.experimental import stax

rng_state = 0  # random state

# feel free to modify :)
init_random_params, nn = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid,
)

num_features = 2
_, init = init_random_params(PRNGKey(rng_state), (-1, num_features))
init_pars = dict(nn_pars=init)

In [18]:
from sklearn.model_selection import train_test_split
import numpy.random as npr

batch_size = 256

data = generate_data(rng=rng_state, num_points=10000)
split = train_test_split(*data, random_state=rng_state)
train, test = split[::2], split[1::2]

num_train = train[0].shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

# batching mechanism
def data_stream():
    rng = npr.RandomState(rng_state)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size : (i + 1) * batch_size]
            yield [points[batch_idx] for points in train]


batch_iterator = data_stream()

In [24]:
from jaxopt import OptaxSolver
import optax
from time import perf_counter


bins = jnp.linspace(0, 1, 5)
lr = 1e-3
num_steps = 100
data_types = ["sig", "bkg_nominal", "bkg_up", "bkg_down"]
loss = partial(
    pipeline,
    bandwidth=1e-1,
    nn=nn,
    keys=data_types,
    scale_factors={k: 2.0 if k == "sig" else 10.0 for k in data_types},
)

solver = OptaxSolver(loss, opt=optax.adam(lr), jit=True)

pyhf.set_backend("jax", default=True)
state = solver.init_state(init_pars, bins=bins)
params = init_pars
metrics = {k: [] for k in ["cls", "discovery", "generalised_variance"]}
objective = "cls"
for i in range(num_steps):
    print(f"step {i}: loss={objective}")
    data = next(batch_iterator)
    start = perf_counter()
    params, state = solver.update(params, state, bins=bins, data=data, loss=objective)
    end = perf_counter()
    print(f"update took {end-start:.4f}s")
    for metric in metrics:
        test_metric = loss(params, bins=bins, data=test, loss=metric)
        print(f"{metric}={test_metric:.4f}")
        metrics[metric].append(test_metric)
    print()

step 0: loss=cls




update took 2.0456s
cls=0.0287
discovery=0.0114
generalised_variance=2301.2326

step 1: loss=cls
update took 1.9931s
cls=0.0135
discovery=0.0039
generalised_variance=1934.1711

step 2: loss=cls
update took 2.0036s
cls=0.0118
discovery=0.0029
generalised_variance=880.1147

step 3: loss=cls
update took 2.0046s
cls=0.0132
discovery=0.0031
generalised_variance=172.8951

step 4: loss=cls
update took 2.4035s
cls=0.0139
discovery=0.0033
generalised_variance=80.4252

step 5: loss=cls
update took 2.0239s
cls=0.0140
discovery=0.0034
generalised_variance=49.0265

step 6: loss=cls
update took 2.0591s
cls=0.0129
discovery=0.0030
generalised_variance=41.0883

step 7: loss=cls
update took 2.1175s


KeyboardInterrupt: 

In [None]:
loss(dict(nn_pars=init_pars), next(batch_iterator), bins)

DeviceArray(0.05952097, dtype=float64)