In [None]:
from collections.abc import Sequence

import energyflow as ef
import h5py
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from cycler import cycler
from torch import Tensor
from torch.distributions import Normal
from torchdyn.core import NeuralODE
from tqdm import tqdm
from zuko.utils import odeint

In [None]:
# Global matplotlib style for every figure in this notebook.
# The colour cycle has four entries; histogram cells below rely on their order
# (first entry for the filled reference histogram, later entries for overlays).
mpl.rcParams["axes.prop_cycle"] = cycler(
    color=[
        "#B6BFC3",
        "#3B515B",
        "#0271BB",
        "#E2001A",
    ]
)
mpl.rcParams["font.size"] = 15
mpl.rcParams["patch.linewidth"] = 1.25

In [None]:
# Load the full "jet_data" dataset from the HDF5 file into memory as a NumPy array.
# NOTE(review): this is an absolute, user-specific cluster path; the notebook will
# not run elsewhere without editing it — consider a configurable data directory.
path = "/beegfs/desy/user/ewencedr/data/lhco/final_data/processed_data_background_rel.h5"
with h5py.File(path, "r") as f:
    # `[:]` materialises the dataset so it stays usable after the file is closed.
    jets = f["jet_data"][:]

In [None]:
# Sanity check of the loaded array; later cells reshape it to (-1, 8),
# presumably (events, 2 jets, 4 features) — TODO confirm against the output.
print(jets.shape)

In [None]:
# Convert the jet features to four-vectors with energyflow.
# NOTE(review): going by the helper's name it expects (pt, y, phi, m) — rapidity,
# not the eta used in the plot labels below; confirm the input convention.
p4_jets = ef.p4s_from_ptyphims(jets)

In [None]:
# Build the per-event "mjj" variable used for the window cut below.
# For each of the two entries along axis 1 of p4_jets, take the Euclidean norm of
# its components (sqrt of the plain sum of squares), then square the sum of the
# two norms.
# NOTE(review): this is not a Minkowski invariant mass (no E^2 - |p|^2, and the
# four-vectors are not added before taking a norm). The window thresholds below
# are expressed in this variable's units — confirm this definition is intended.
pj_x = np.sqrt(np.sum(p4_jets[:, 0] ** 2, axis=1))
pj_y = np.sqrt(np.sum(p4_jets[:, 1] ** 2, axis=1))
mjj = (pj_x + pj_y) ** 2
print(mjj.shape)

In [None]:
# Define the window to cut out, in the same units as `mjj` computed above.
window_left = 0.33e8
window_right = 0.37e8
# Boolean mask: True for events inside the closed interval [window_left, window_right].
args_to_remove = (mjj >= window_left) & (mjj <= window_right)
# Keep only the events outside the window.
mjj_cut = mjj[~args_to_remove]

In [None]:
# Filled histogram of the full mjj distribution; `hist[1]` holds the bin edges,
# which are reused so the cut distribution is drawn on identical bins.
hist = plt.hist(
    mjj, bins=np.arange(0.005e8, 1.8e8, 0.005e8), histtype="stepfilled", label="mjj", alpha=0.5
)
plt.hist(mjj_cut, bins=hist[1], histtype="step", label="mjj with cut")
plt.legend()
# Optional zoom onto the cut window (disabled):
# plt.xlim(window_left-0.01e8, window_right+0.01e8)
# print(np.arange(0.01e8,1.80e8, 0.01e8))
plt.show()

In [None]:
# Apply the same window mask to the jet features (events outside the window are kept).
jets_cut = jets[~args_to_remove]
print(jets_cut.shape)

In [None]:
# Flatten everything after the event axis so each event becomes one feature row.
train_data = np.reshape(jets_cut, (jets_cut.shape[0], -1))
print(train_data.shape)

In [None]:
# Compare every flattened jet feature before and after the mjj window cut.
jets_plot = jets.reshape(-1, 8)
# Axis labels keyed by the (stringified) column index of the flattened features.
label_map = {
    "0": r"${p_T}_1$",
    "1": r"$\eta_1$",
    "2": r"$\phi_1$",
    "3": r"$m_1$",
    "4": r"${p_T}_2$",
    "5": r"$\eta_2$",
    "6": r"$\phi_2$",
    "7": r"$m_2$",
}
fig, axs = plt.subplots(2, 4, figsize=(15, 10))
for index, ax in enumerate(axs.reshape(-1)):
    # Reference histogram; its bin edges (hist1[1]) are reused for the overlay.
    hist1 = ax.hist(jets_plot[:, index], bins=100, label="original")

    # Draw the overlay in the third colour of the style cycle ("C2"). This used to
    # be achieved by skipping one colour via the private `ax._get_lines.prop_cycler`,
    # which no longer exists in Matplotlib >= 3.8 and raised AttributeError there.
    ax.hist(
        train_data[:, index],
        bins=hist1[1],
        label="with cut",
        histtype="step",
        color="C2",
    )
    ax.set_xlabel(f"{label_map[str(index)]}")
    ax.set_yscale("log")
    # Only the phi panels get a legend and a fixed y-range.
    if index == 2 or index == 6:
        ax.legend(frameon=False)
        ax.set_ylim(1e-1, 1e6)
plt.tight_layout()
plt.show()

In [None]:
class MLP(nn.Sequential):
    """Fully connected network with ELU activations between linear layers.

    The layer widths are ``in_features -> hidden_features[0] -> ... ->
    hidden_features[-1] -> out_features``. An ELU follows every linear layer
    except the last one, so the output is unbounded.

    Args:
        in_features: Size of the input's last dimension.
        out_features: Size of the output's last dimension.
        hidden_features: Widths of the hidden layers. Any sequence of ints is
            accepted (list, tuple, ...); an empty sequence yields a single
            linear layer. The default is an immutable tuple so no mutable
            object is shared between instances.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        hidden_features: Sequence[int] = (64, 64),
    ):
        # Unpacking works for any sequence type, unlike list concatenation,
        # which rejects tuples.
        widths = [in_features, *hidden_features, out_features]

        layers = []
        for a, b in zip(widths[:-1], widths[1:]):
            layers.extend([nn.Linear(a, b), nn.ELU()])

        # Drop the trailing ELU so the final layer is purely linear.
        super().__init__(*layers[:-1])

In [None]:
class CNF(nn.Module):
    """Continuous normalizing flow with an MLP velocity field.

    The velocity field takes a sinusoidal embedding of the time ``t`` concatenated
    with the state ``x``. Integrating it from t=0 to t=1 maps data to latent
    space (``encode``/``log_prob``); integrating from t=1 to t=0 maps latent
    samples back to data space (``decode``).

    Args:
        features: Dimensionality of the state (last axis of ``x``).
        freqs: Number of frequencies in the time embedding; the embedding has
            ``2 * freqs`` entries (cosines and sines).
        **kwargs: Forwarded to ``MLP`` (e.g. ``hidden_features``).
    """

    def __init__(
        self,
        features: int,
        freqs: int = 3,
        **kwargs,
    ):
        super().__init__()

        self.net = MLP(2 * freqs + features, features, **kwargs)

        # Frequencies pi, 2*pi, ..., freqs*pi; a buffer so they follow the module's device.
        self.register_buffer("freqs", torch.arange(1, freqs + 1) * torch.pi)

    def forward(self, t: Tensor, x: Tensor, *args, **kwargs) -> Tensor:
        """Evaluate the velocity field at time ``t`` and state ``x``.

        Args:
            t: Time, either a scalar tensor or a tensor matching ``x.shape[:-1]``.
            x: State with ``features`` entries on the last axis.
            *args, **kwargs: Ignored. Accepted so that ODE wrappers which pass
                extra arguments to the vector field (torchdyn appears to forward
                an ``args=`` keyword — confirm against the installed version)
                can call the flow directly.

        Returns:
            Velocity with the same shape as ``x``.
        """
        t = self.freqs * t[..., None]
        t = torch.cat((t.cos(), t.sin()), dim=-1)
        # Broadcast the embedding over the batch dimensions of x.
        t = t.expand(*x.shape[:-1], -1)

        return self.net(torch.cat((t, x), dim=-1))

    def encode(self, x: Tensor) -> Tensor:
        """Map data ``x`` to latent space by integrating the flow from t=0 to t=1.

        Previously this was a stub that silently returned ``None``; it now uses
        the same ``odeint`` call convention as ``log_prob``.
        """
        return odeint(self, x, 0.0, 1.0, phi=self.parameters())

    def decode(self, z: Tensor) -> Tensor:
        """Map latent samples ``z`` to data space by integrating from t=1 to t=0.

        Uses torchdyn's ``NeuralODE`` with a fixed grid of 50 time points and the
        midpoint solver, and returns the final state of the trajectory.
        """
        # The vector field is this module itself; passing the tensor `z` here
        # (as before) hands NeuralODE a state instead of a callable.
        node = NeuralODE(self, solver="midpoint", sensitivity="adjoint")
        # Keep the time grid on the same device/dtype as the samples.
        t_span = torch.linspace(1.0, 0.0, 50).to(z)
        traj = node.trajectory(z, t_span)
        return traj[-1]

    def log_prob(self, x: Tensor) -> Tensor:
        """Log-density of ``x`` under the flow with a standard normal base.

        Integrates the state together with the trace of the velocity Jacobian
        from t=0 to t=1 and adds the accumulated log-determinant to the base
        log-probability of the resulting latent.

        Returns:
            Tensor of shape ``x.shape[:-1]``.
        """
        # One-hot cotangents, one per state dimension, laid out as
        # (features, *batch, features) for batched vector-Jacobian products.
        i = torch.eye(x.shape[-1]).to(x)
        i = i.expand(x.shape + x.shape[-1:]).movedim(-1, 0)

        def augmented(t: Tensor, x: Tensor, ladj: Tensor) -> Tensor:
            # Gradients must be enabled here even if the caller runs under no_grad.
            with torch.enable_grad():
                x = x.requires_grad_()
                dx = self(t, x)

            jacobian = torch.autograd.grad(dx, x, i, is_grads_batched=True, create_graph=True)[0]
            # Diagonal sum over the first and last axes -> exact trace per batch element.
            trace = torch.einsum("i...i", jacobian)

            # The trace is scaled down inside the ODE and scaled back up after
            # integration — presumably to balance solver tolerances; confirm.
            return dx, trace * 1e-2

        ladj = torch.zeros_like(x[..., 0])
        z, ladj = odeint(augmented, (x, ladj), 0.0, 1.0, phi=self.parameters())

        return Normal(0.0, z.new_tensor(1.0)).log_prob(z).sum(dim=-1) + ladj * 1e2

In [None]:
class FlowMatchingLoss(nn.Module):
    """Mean-squared regression loss between a velocity model and a target field.

    For each sample a time ``t`` is drawn uniformly from [0, 1) and noise ``z``
    from a standard normal. The model ``v`` is evaluated at the interpolated
    point ``(1 - t) * x + (1e-4 + (1 - 1e-4) * t) * z`` and regressed onto
    ``(1 - 1e-4) * z - x``.

    Args:
        v: Module called as ``v(t, y)`` with ``t`` of shape ``x.shape[:-1]`` and
            ``y`` of the same shape as ``x``.
    """

    def __init__(self, v: nn.Module):
        super().__init__()

        self.v = v

    def forward(self, x: Tensor) -> Tensor:
        """Return the scalar loss for a batch ``x``."""
        sigma_min = 1e-4

        # Draw order matters for reproducibility: time first, then noise.
        time = torch.rand_like(x[..., 0])
        noise = torch.randn_like(x)

        # Trailing axis so the per-sample time broadcasts over the features.
        t_col = time.unsqueeze(-1)
        interpolant = (1 - t_col) * x + (sigma_min + (1 - sigma_min) * t_col) * noise
        target = (1 - sigma_min) * noise - x

        residual = self.v(time, interpolant) - target
        return residual.square().mean()

In [None]:
# Train the flow with the flow-matching loss, then draw samples from it.
# NOTE(review): no random seed is set, so training and sampling are not reproducible.
if __name__ == "__main__":
    # 8 features per event, three hidden layers of width 256.
    flow = CNF(8, hidden_features=[256] * 3)

    # Training
    loss = FlowMatchingLoss(flow)
    optimizer = torch.optim.AdamW(flow.parameters(), lr=1e-3)

    # Toy dataset used earlier (disabled):
    # data, _ = make_moons(4096, noise=0.05)
    # data = torch.from_numpy(data).float()
    data = torch.from_numpy(train_data).float()

    # Despite the name `epoch`, each iteration is a single optimisation step on
    # 256 rows drawn at random with replacement.
    for epoch in tqdm(range(100000), ncols=88):
        subset = torch.randint(0, len(data), (256,))
        x = data[subset]

        optimizer.zero_grad()
        loss(x).backward()
        optimizer.step()

    # Sampling
    # `x` is rebound here to the generated samples (NumPy array); the next cell reads it.
    with torch.no_grad():
        z = torch.randn(4096, 8)
        x = flow.decode(z).numpy()

In [None]:
# `x` holds the samples produced by the sampling step of the previous cell.
generated_data = x
# Take the first len(generated_data) training rows so both histograms have equal counts.
plot_train_data = train_data[: len(generated_data)]

In [None]:
# Compare every feature of the generated samples against the training data.
jets_plot = jets.reshape(-1, 8)
# Axis labels keyed by the (stringified) column index of the flattened features.
label_map = {
    "0": r"${p_T}_1$",
    "1": r"$\eta_1$",
    "2": r"$\phi_1$",
    "3": r"$m_1$",
    "4": r"${p_T}_2$",
    "5": r"$\eta_2$",
    "6": r"$\phi_2$",
    "7": r"$m_2$",
}
fig, axs = plt.subplots(2, 4, figsize=(15, 10))
for index, ax in enumerate(axs.reshape(-1)):
    # Reference histogram; its bin edges (hist1[1]) are reused for the overlay.
    hist1 = ax.hist(plot_train_data[:, index], bins=100, label="train data")

    # Draw the overlay in the third colour of the style cycle ("C2"). This used to
    # be achieved by skipping one colour via the private `ax._get_lines.prop_cycler`,
    # which no longer exists in Matplotlib >= 3.8 and raised AttributeError there.
    ax.hist(
        generated_data[:, index],
        bins=hist1[1],
        label="generated",
        histtype="step",
        color="C2",
    )
    ax.set_xlabel(f"{label_map[str(index)]}")
    ax.set_yscale("log")
    # Only the phi panels get a legend and a fixed y-range.
    if index == 2 or index == 6:
        ax.legend(frameon=False)
        ax.set_ylim(1e-1, 1e6)
plt.tight_layout()
plt.show()