In [None]:
import jax
import jax.numpy as jnp
from event2vec.datasets import VBFHDataset

In [None]:
# File pattern handed to the loader as a single string.
LHE_PATTERN = "../data/smeftsim_VBFH-*.lhe.gz"
data = VBFHDataset.from_lhe(LHE_PATTERN)
data  # bare last expression so the notebook shows the dataset's repr

In [None]:
from event2vec.prior import SMPlusNormalParameterPrior
import matplotlib.pyplot as plt


# Prior with zero mean and a diagonal covariance. Entries follow the order
# ["cHbox", "cHDD", "cHW", "cHB", "cHWB"].
prior = SMPlusNormalParameterPrior(
    mean=jnp.array([0.0] * 5),
    cov=jnp.diag(jnp.array([1.e+1, 1.e+1, 1.e-1, 1.e-1, 1.e+0])),
)

# Every histogram below uses the same observable column and the same binning.
bins = jnp.linspace(0, 1000, 50)
observable = data.observables[:, -3]
hist_style = dict(bins=bins, histtype="step")

# Black curve: the column histogrammed without any per-event weights.
plt.hist(observable, color="k", linewidth=2, **hist_style)
plt.yscale("log")

# Coloured curves: the same column weighted by data.weight(...) evaluated at
# ten draws from the prior, one PRNG key (seeds 0..9) per draw.
for seed in range(10):
    key = jax.random.PRNGKey(seed)
    plt.hist(observable, weights=data.weight(prior.sample(key)), **hist_style)

In [None]:
from event2vec.experiment import run_experiment
from event2vec.model import E2VMLPConfig, CARLMLPConfig
from event2vec.prior import UncorrelatedJointPrior
from event2vec.training import TrainingConfig
from event2vec.loss import BCELoss

# PRNG key built from a fixed seed; passed to run_experiment later in this cell.
key = jax.random.PRNGKey(42)


def data_factory(key):
    """Return whatever the global ``data`` is bound to at call time.

    The ``key`` argument is accepted but never used, so the result does not
    depend on it.
    """
    return data


# Model settings; the input dimensions are read off the dataset.
model_options = dict(
    event_dim=data.observable_dim,
    param_dim=data.parameter_dim,
    hidden_size=64,
    depth=3,
    quadratic=True,
)
model_config = CARLMLPConfig(**model_options)

# The loss receives `prior` wrapped in an UncorrelatedJointPrior.
joint_prior = UncorrelatedJointPrior(prior)
bce_loss = BCELoss(parameter_prior=joint_prior)
train_config = TrainingConfig(
    test_fraction=0.1,
    batch_size=64,
    learning_rate=0.005,
    epochs=1_000,
    loss_fn=bce_loss,
)

# Debugging switch, left disabled:
# jax.config.update("jax_debug_nans", True)

# NOTE(review): `data` is rebound to the second value returned by
# run_experiment, and data_factory returns the global `data`. Re-running this
# cell therefore feeds the previous run's result back in -- confirm that this
# is harmless for whatever run_experiment returns.
model, data, loss_train, loss_test = run_experiment(
    data_factory, model_config, train_config, key=key
)


In [None]:
# Draw both loss histories from run_experiment on a single set of axes.
fig, ax = plt.subplots()

for history, label in ((loss_train, "train loss"), (loss_test, "test loss")):
    ax.plot(history, label=label)

ax.set(xlabel="Epoch", ylabel="Loss")
ax.legend()

In [None]:
# Two side-by-side panels; this part fills the left one (axl) with a
# per-event scatter of predicted versus true log-likelihood ratio.
fig, (axl, axr) = plt.subplots(1, 2, figsize=(10, 5))

key = jax.random.PRNGKey(22)

# param_0 is written out by hand; param_1 is a single draw from the prior.
param_0 = jnp.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
param_1 = prior.sample(key=key)
print("Param 0:", param_0)
print("Param 1:", param_1)

# Model prediction: map over axis 0 of the observables while both parameter
# points are passed unchanged to every call.
batched_llr = jax.vmap(model.llr_pred, in_axes=(0, None, None))
llr_pred = batched_llr(data.observables, param_0, param_1)

# Reference value: difference of the log of data.likelihood at the two points.
log_like_1 = jnp.log(data.likelihood(param_1))
log_like_0 = jnp.log(data.likelihood(param_0))
llr_true = log_like_1 - log_like_0

# One common range for both axes, wide enough for predicted and true values.
amin = min(float(jnp.min(llr_pred)), float(jnp.min(llr_true)))
amax = max(float(jnp.max(llr_pred)), float(jnp.max(llr_true)))
axl.set_xlim(amin, amax)
axl.set_ylim(amin, amax)
axl.set_aspect("equal")

# Because x and y limits are identical here, the corner-to-corner line in
# axes coordinates coincides with y = x.
axl.plot([0, 1], [0, 1], color="grey", linestyle="--", transform=axl.transAxes)
axl.scatter(llr_pred, llr_true, s=1)

p1short = ",".join(f"{p:.2f}" for p in param_1)
axl.set_xlabel("Predicted LLR")
axl.set_ylabel("True LLR")
axl.set_title(fr"$\theta_1$: [{p1short}], $\theta_0$: SM")

# Right panel (axr): mean true likelihood ratio inside bins of the predicted
# likelihood ratio, compared against the line "mean true LR = predicted LR".
lr_true = jnp.exp(llr_true)
lr_pred = jnp.exp(llr_pred)

# 21 edges at evenly spaced quantiles of the prediction -> 20 bins.
qbins = jnp.quantile(lr_pred, jnp.linspace(0, 1, 21))

# Per-bin event count, sum of true LR and sum of squared true LR.
# Indexing with [0] keeps only the histogram values and avoids assigning the
# bin edges to `_`, which IPython uses for the previous cell output.
sumc = jnp.histogram(lr_pred, bins=qbins)[0]
sumw = jnp.histogram(lr_pred, bins=qbins, weights=lr_true)[0]
sumw2 = jnp.histogram(lr_pred, bins=qbins, weights=lr_true**2)[0]

mean = sumw / sumc
# <x^2> - <x>^2 can come out marginally negative through floating-point
# cancellation; clamp at zero so the square root does not produce NaN.
variance = jnp.maximum(sumw2 / sumc - mean**2, 0.0)
std = jnp.sqrt(variance)

bin_centers = 0.5 * (qbins[1:] + qbins[:-1])
bin_halfwidths = 0.5 * (qbins[1:] - qbins[:-1])
axr.errorbar(
    bin_centers,
    mean,
    xerr=bin_halfwidths,
    yerr=std,
    fmt="o",
    markersize=5,
    capsize=3,
)

# Reference diagonal in DATA coordinates. The x and y limits of this panel
# are autoscaled independently, so a corner-to-corner line in axes
# coordinates would not be y = x here.
diag = [float(qbins[0]), float(qbins[-1])]
axr.plot(diag, diag, color="grey", linestyle="--")

# axr.set_aspect("equal")
axr.set_xlabel("Predicted LR")
axr.set_ylabel("Mean true LR")