In [None]:
from __future__ import annotations

import relaxed
import jax
import jax.numpy as jnp
from jax.random import PRNGKey, multivariate_normal
import pyhf
from typing import Callable, Any
from functools import partial

# Shorthand for JAX arrays, used in the type annotations of the functions below.
Array = jnp.ndarray


def generate_data(
    rng=0,
    num_points=10000,
    sig_mean=(-1, 1),
    bup_mean=(4.5, 2),
    bdown_mean=(-2.5, -2.5),
    b_mean=(0, 0),
):
    """Draw four 2-D Gaussian point clouds with identity covariance.

    Args:
        rng: Integer seed handed to ``jax.random.PRNGKey``.
        num_points: Number of points drawn for each sample.
        sig_mean: Mean of the signal sample.
        bup_mean: Mean of the "up" background variation.
        bdown_mean: Mean of the "down" background variation.
        b_mean: Mean of the nominal background sample.

    Returns:
        A dict with keys ``"sig"``, ``"bkg_nominal"``, ``"bkg_up"`` and
        ``"bkg_down"``; each value is an array of shape ``(num_points, 2)``.
    """
    # Reusing one PRNG key for every draw would give all four samples the
    # exact same noise (they would differ only by a shift of the mean).
    # Splitting the seed key yields statistically independent samples while
    # keeping the output fully reproducible for a given ``rng``.
    key_sig, key_nom, key_up, key_down = jax.random.split(PRNGKey(rng), 4)
    unit_cov = jnp.eye(2)

    def draw(key, mean):
        return multivariate_normal(
            key,
            jnp.asarray(mean),
            unit_cov,
            shape=(num_points,),
        )

    return dict(
        sig=draw(key_sig, sig_mean),
        bkg_nominal=draw(key_nom, b_mean),
        bkg_up=draw(key_up, bup_mean),
        bkg_down=draw(key_down, bdown_mean),
    )


def hists_from_pars(
    pars: dict[str, Array],
    data: dict[str, Array],
    nn: Callable,
    bandwidth: float,
    bins: Array | None = None,
    scale_factors: dict[str, float] | None = None,
) -> dict[str, Array]:
    """Turn datasets into (optionally scaled) relaxed histograms of the network output.

    Args:
        pars: Parameter dict; must hold ``"nn_pars"`` and may hold ``"bins"``.
        data: Mapping from sample name to the inputs fed to ``nn``.
        nn: Callable invoked as ``nn(pars["nn_pars"], data[name])``.
        bandwidth: Bandwidth forwarded to ``relaxed.hist``.
        bins: Bin specification used when ``pars`` has no ``"bins"`` entry.
        scale_factors: Per-sample multiplicative factors; when missing or
            empty every sample is scaled by ``1.0``.

    Returns:
        Mapping from sample name to its scaled histogram.
    """
    # Evaluate the network on every sample and flatten each output to 1-D.
    flat_scores = {}
    for name in data:
        flat_scores[name] = nn(pars["nn_pars"], data[name]).ravel()

    # Bin edges stored alongside the parameters take precedence over ``bins``.
    if "bins" in pars:
        bin_spec = pars["bins"]
    else:
        bin_spec = bins
    soft_hist = partial(relaxed.hist, bandwidth=bandwidth, bins=bin_spec)

    if not scale_factors:
        scale_factors = dict.fromkeys(flat_scores, 1.0)

    yields = {}
    for name, scores in flat_scores.items():
        yields[name] = soft_hist(scores) * scale_factors[name]
    return yields


def model_from_hists(hists: dict[str, Array]) -> pyhf.Model:
    """Assemble a single-channel HistFactory model from the given histograms.

    ``hists`` must provide the keys ``"sig"``, ``"bkg_nominal"``, ``"bkg_up"``
    and ``"bkg_down"``. The signal sample carries a ``normfactor`` modifier
    named ``"mu"``; the background sample carries a ``histosys`` modifier
    built from the up/down histograms. The spec is passed to ``pyhf.Model``
    with schema validation switched off.
    """
    signal_yields = hists["sig"]
    nominal_bkg_yields = hists["bkg_nominal"]
    bkg_variations = {
        "hi_data": hists["bkg_up"],
        "lo_data": hists["bkg_down"],
    }

    signal_sample = {
        "name": "signal",
        "data": signal_yields,
        "modifiers": [{"name": "mu", "type": "normfactor", "data": None}],
    }
    background_sample = {
        "name": "background",
        "data": nominal_bkg_yields,
        "modifiers": [
            {
                "name": "correlated_bkg_uncertainty",
                "type": "histosys",
                "data": bkg_variations,
            },
        ],
    }
    channel = {
        "name": "singlechannel",
        "samples": [signal_sample, background_sample],
    }
    return pyhf.Model({"channels": [channel]}, validate=False)


def _hypothesis_pars(model, poi_value):
    """Return the model's suggested initial parameters with the POI set to ``poi_value``."""
    return (
        jnp.asarray(model.config.suggested_init())
        .at[model.config.poi_index]
        .set(poi_value)
    )


def loss_from_model(
    model: pyhf.Model,
    loss: str | Callable[[dict[str, Any]], float] = "neos",
    fit_lr=1e-3,
):
    """Evaluate the requested loss on ``model`` using Asimov-style data.

    The "observed" data is always ``model.expected_data(...)`` evaluated at
    the model's suggested initial parameters with the POI overridden.

    Args:
        model: The pyhf model to evaluate.
        loss: Either a callable or a case-insensitive loss name:

            * ``"discovery"``: ``relaxed.infer.hypotest`` with test statistic
              ``"q0"`` and ``test_poi=0.0``, data generated with POI = 1.
            * ``"neos"`` / ``"cls"``: ``relaxed.infer.hypotest`` with test
              statistic ``"q"`` and ``test_poi=1.0``, data generated with
              POI = 0.
            * ``"inferno"`` / ``"poi_uncert"`` / ``"mu_uncert"``: the POI
              entry of ``relaxed.cramer_rao_uncert``, data generated with
              POI = 1.
            * ``"general_variance"``: determinant of the inverse of
              ``relaxed.fisher_info``, data generated with POI = 1.
        fit_lr: Learning rate forwarded as ``lr`` to ``relaxed.infer.hypotest``.

    Returns:
        The value of the selected loss (``0`` for a callable ``loss``, see
        the note in the body).

    Raises:
        TypeError: If ``loss`` is neither callable nor a string.
        ValueError: If ``loss`` is a string that names no known loss.
    """
    if callable(loss):
        # NOTE(review): custom callable losses are not implemented; the
        # placeholder return value of 0 is preserved as-is. Confirm the
        # intended contract before relying on this branch.
        return 0
    if not isinstance(loss, str):
        raise TypeError(
            f"loss must be a string or a callable, got {type(loss).__name__}"
        )

    # ``str`` has no ``tolower`` method; ``lower`` is the correct spelling.
    loss_name = loss.lower()

    if loss_name in ("inferno", "poi_uncert", "mu_uncert"):
        hypothesis_pars = _hypothesis_pars(model, 1.0)
        observed_hist = jnp.asarray(model.expected_data(hypothesis_pars))
        return relaxed.cramer_rao_uncert(model, hypothesis_pars, observed_hist)[
            model.config.poi_index
        ]

    if loss_name == "general_variance":
        hypothesis_pars = _hypothesis_pars(model, 1.0)
        observed_hist = jnp.asarray(model.expected_data(hypothesis_pars))
        return jnp.linalg.det(
            jnp.linalg.inv(relaxed.fisher_info(model, hypothesis_pars, observed_hist))
        )

    # The remaining losses are hypothesis tests that differ only in the test
    # statistic, the tested POI value, and the POI used to generate the data.
    if loss_name == "discovery":
        test_stat, test_poi, data_poi = "q0", 0.0, 1.0
    elif loss_name in ("neos", "cls"):
        test_stat, test_poi, data_poi = "q", 1.0, 0.0
    else:
        # Fail loudly instead of hitting an unbound local further down.
        raise ValueError(f"unknown loss: {loss!r}")

    hypothesis_pars = _hypothesis_pars(model, data_poi)
    observed_hist = jnp.asarray(model.expected_data(hypothesis_pars))
    return relaxed.infer.hypotest(
        test_poi=test_poi,
        data=observed_hist,
        model=model,
        test_stat=test_stat,
        expected_pars=hypothesis_pars,
        lr=fit_lr,
    )


def pipeline(pars, data, nn, loss, bandwidth):
    """Run the full chain: histograms from parameters, then model, then loss.

    ``pars``, ``data``, ``nn`` and ``bandwidth`` are forwarded to
    ``hists_from_pars``; the resulting histograms are turned into a model by
    ``model_from_hists``, which is evaluated by ``loss_from_model`` with the
    given ``loss``.
    """
    yields = hists_from_pars(pars=pars, data=data, nn=nn, bandwidth=bandwidth)
    return loss_from_model(model_from_hists(yields), loss=loss)