In [4]:
import jax
import jax.numpy as jnp
import jax.scipy as jsc
import matplotlib.pyplot as plt
import numpy as np
import pyhf
from celluloid import Camera
from model import hepdata_like
import optax
from jaxopt import OptaxSolver
import relaxed

# Switch on JAX's "jax_enable_x64" flag before any arrays are created
# (NOTE(review): this is expected to make JAX use 64-bit dtypes instead of its
# 32-bit default -- confirm against the JAX configuration docs).
jax.config.update("jax_enable_x64", True)

In [5]:
def yields(x):
    """Toy yields as a function of the analysis configuration ``x``.

    Returns a 3-tuple of length-1 arrays ``(s, b, db)`` where
    ``s = 15 + x``, ``b = 45 - 2 * x`` and ``db = 1 + 0.2 * x**2``.
    """
    signal = x + 15
    background = 45 - 2 * x
    background_unc = 0.2 * x**2 + 1
    # Wrap every scalar into a one-element array, one entry per yield.
    return tuple(
        jnp.asarray([component])
        for component in (signal, background, background_unc)
    )

In [9]:
@jax.jit
def pipeline(phi):
    """Hypothesis-test result for analysis configuration ``phi``.

    Builds a ``hepdata_like`` model from all three outputs of ``yields(phi)``,
    generates the model's expected data at the suggested initial parameters
    with the POI entry set to 0.0, and returns what
    ``relaxed.infer.hypotest`` gives for that data (first argument 1.0 --
    presumably the tested POI value; confirm against the relaxed docs).
    """
    sig, bkg, bkg_unc = yields(phi)
    model = hepdata_like(sig, bkg, bkg_unc)

    # Suggested initial parameters, with the parameter of interest zeroed.
    poi = model.config.poi_index
    null_pars = model.config.suggested_init().at[poi].set(0.0)
    expected = model.expected_data(null_pars)

    return relaxed.infer.hypotest(
        1.0, expected, model, lr=1e-3, expected_pars=null_pars
    )

@jax.jit
def pipeline_stat_only(phi):
    """Same chain as ``pipeline`` but with the background uncertainty fixed.

    Only the first two outputs of ``yields(phi)`` are used; the third model
    argument is replaced by the constant 0.01. The rest (expected data at the
    suggested initial parameters with the POI entry set to 0.0, then
    ``relaxed.infer.hypotest`` with first argument 1.0) is unchanged.
    """
    sig, bkg, _ = yields(phi)
    model = hepdata_like(sig, bkg, 0.01)

    # Suggested initial parameters, with the parameter of interest zeroed.
    poi = model.config.poi_index
    null_pars = model.config.suggested_init().at[poi].set(0.0)
    expected = model.expected_data(null_pars)

    return relaxed.infer.hypotest(
        1.0, expected, model, lr=1e-3, expected_pars=null_pars
    )

@jax.jit
def loss(phi):
    """Objective for the optimiser: hypothesis-test result plus the yields.

    Returns ``(value, aux)`` where ``value`` is the output of
    ``relaxed.infer.hypotest`` (computed exactly as in ``pipeline``) and
    ``aux`` is the ``(s, b, db)`` tuple from ``yields(phi)``, so the caller
    can inspect the yields that produced the value.
    """
    phi_yields = yields(phi)
    model = hepdata_like(*phi_yields)

    # Suggested initial parameters, with the parameter of interest zeroed.
    poi = model.config.poi_index
    null_pars = model.config.suggested_init().at[poi].set(0.0)
    expected = model.expected_data(null_pars)

    test_result = relaxed.infer.hypotest(
        1.0, expected, model, lr=1e-3, expected_pars=null_pars
    )
    return test_result, phi_yields

# Evaluate both pipelines on a 100-point grid of configurations in [0, 10];
# jax.vmap maps each pipeline over the whole grid in one call.
phis = jnp.linspace(0, 10, num=100)
cls_vals, cls_vals_stat = (
    jax.vmap(scan_fn)(phis) for scan_fn in (pipeline, pipeline_stat_only)
)



In [19]:
# Minimise `loss` with Adam (through jaxopt's OptaxSolver) starting from
# phi = 9 and record one animation frame per optimisation step.
solver = OptaxSolver(loss, opt=optax.adam(1e-1), has_aux=True)
pars = 9.
state = solver.init_state(pars)

plt.rc('figure', figsize=(6,3), dpi=220, facecolor="w")
fig, axs = plt.subplots(1,2)
cam = Camera(fig)
steps = 200
for i in range(steps):
    # OptaxSolver.update evaluates `loss` at the parameters it is handed and
    # stores that value/aux in the returned state, while the returned
    # parameters are already one step further. Keep the pre-update parameters
    # so the marker is drawn at the phi its value and yields belong to.
    pars_evaluated = pars
    pars, state = solver.update(pars, state)
    s, b, db = state.aux
    val = state.value

    # Left panel: the two pre-computed scans plus the optimiser's position.
    # The curves are redrawn every iteration (presumably because each camera
    # snapshot only keeps artists drawn since the previous one -- confirm
    # against the celluloid docs).
    ax = axs[0]
    ax.plot(phis, cls_vals, c='C0', label='p-value (with uncertainty)')
    ax.plot(phis, cls_vals_stat, c='green', label = 'p-value (without uncertainty)')
    ax.scatter(pars_evaluated, val, c='C0')
    ax.set_xlabel(r'analysis config $\phi$')
    ax.set_ylabel('p-value')
    if i==0:
        ax.legend()

    # Right panel: stacked background/signal bar with a band of total height
    # db centred on b.
    ax = axs[1]
    ax.bar(.5, b, facecolor='C1', label='b')
    ax.bar(.5, s, bottom=b, facecolor ='C9', label='s')
    ax.bar(.5, db, bottom=b-db/2, facecolor='k', alpha=.5, label=r'$\sigma_b$')
    ax.set_xticks([])
    if i==0:
        ax.legend()
    cam.snap()

# Assemble the captured frames and write the animation to disk.
ani = cam.animate()
ani.save('a.gif')



In [2]:
import numpy as np
import matplotlib.pyplot as plt

# Scan grid over the analysis configuration (np.linspace default: 50 points)
# and the corresponding yields: signal, background, background uncertainty.
x = np.linspace(0, 10, num=50)
s = x + 15
b = 45 - 2 * x
db = 0.2 * x**2 + 1


def get_cls(s, b, db):
    """Return pyhf's hypothesis-test result for a single-bin counting model.

    Parameters
    ----------
    s, b, db : scalar
        Signal yield, background yield and background uncertainty; each is
        wrapped in a one-element list to form a one-bin model.

    Returns
    -------
    The value returned by ``pyhf.infer.hypotest(1.0, data, model)``, where the
    observed main-bin count is set to ``b`` and the auxiliary data are taken
    from the model configuration.
    """
    # `hepdata_like` was renamed `uncorrelated_background` in later pyhf
    # releases and the old name was subsequently removed. Prefer the new name
    # when it exists so this runs on both old and new pyhf versions.
    build_model = getattr(pyhf.simplemodels, "uncorrelated_background", None)
    if build_model is None:
        build_model = pyhf.simplemodels.hepdata_like

    model = build_model([s], [b], [db])
    data = [b] + model.config.auxdata
    cls_val = pyhf.infer.hypotest(1.0, data, model)
    return cls_val


# get_cls at every scan point: once with the scanned background uncertainty,
# once with the uncertainty fixed to 0.01.
cls_v = [get_cls(s_i, b_i, db_i) for s_i, b_i, db_i in zip(s, b, db)]
cls_v_nosyst = [get_cls(s_i, b_i, 0.01) for s_i, b_i in zip(s, b)]

# Index of the smallest value of each scan, computed once and reused below.
best_syst = np.argmin(cls_v)
best_nosyst = np.argmin(cls_v_nosyst)

# Figure 1: both scans with their minima marked.
plt.scatter(x[best_syst], cls_v[best_syst])
plt.vlines(x[best_syst], 0, 0.1, colors="steelblue", linestyles="dashed")
plt.plot(x, cls_v, label="CLs w/ syst.")
plt.plot(x, cls_v_nosyst, label="CLs no syst.")
plt.vlines(x[best_nosyst], 0, 0.1, colors="green", linestyles="dashed")
plt.xlabel(r"Analysis Configuration $\phi$")
plt.ylim(0, 0.1)
# The curves carry labels, so render the legend to make them visible.
plt.legend()

plt.show()

# Figure 2: yields versus configuration (signal stacked on background, plus a
# b +/- db band), with the two minima from figure 1 marked.
plt.fill_between(x, s + b, b)
plt.fill_between(x, b)
plt.fill_between(x, b - db, b + db, facecolor="k", alpha=0.2)
plt.vlines(x[best_syst], 0, 80, colors="orange", linestyles="dashed")
plt.vlines(x[best_nosyst], 0, 80, colors="green", linestyles="dashed")
plt.ylim(0, 80)
plt.xlabel(r"Analysis Configuration $\phi$")

NameError: name 'np' is not defined