In [2]:
from data.particle_clouds.jets import JetDataModule
from pipeline.configs import ExperimentConfigs

# Load the experiment configuration and build two data modules from it:
# `raw_jets` is constructed with preprocess=False, while `jets` uses the
# module's default (presumably preprocessing enabled — confirm in JetDataModule).
# NOTE(review): 'config.yaml' is resolved relative to the kernel's working dir.
config = ExperimentConfigs('config.yaml')
raw_jets = JetDataModule(config, preprocess=False)
jets = JetDataModule(config)


In [3]:
# Set up the 'fit' (training) datasets of both modules. In the recorded output,
# only the second call (presumably `jets`) logs a preprocessing step.
raw_jets.setup('fit')
jets.setup('fit')

[94m[1mINFO: [0m[00m Setting up datasets for training...
[94m[1mINFO: [0m[00m Setting up datasets for training...
[94m[1mINFO: [0m[00m Preprocessing source/target fit datasets...


  self.continuous = (self.continuous - torch.tensor(mean)) / (
  torch.tensor(std)


In [36]:

def get_flavor_counts(tensor, num_flavors=8):
    """Print how many real particles carry each discrete flavor token.

    Args:
        tensor: object exposing ``discrete`` (integer flavor token per particle
            slot) and ``mask_bool`` (True for real, non-padded slots), e.g. a
            ``ParticleClouds`` instance.
        num_flavors: number of token values to report; tokens
            ``0 .. num_flavors - 1`` are counted. Defaults to the 8 flavors
            used in this notebook.

    Prints one ``"<token> <count>"`` line per token, in ascending token order.
    """
    # Drop padded slots before counting; only real particles are tallied.
    jet_flavors = tensor.discrete[tensor.mask_bool]
    for flavor in range(num_flavors):
        # Summing the boolean match avoids building a filtered copy per flavor.
        print(flavor, int((jet_flavors == flavor).sum()))

# Per-flavor token counts on the target set of `jets` (the module built without
# preprocess=False), restricted to real particles.
get_flavor_counts(jets.target)


0 154312
1 38183
2 142701
3 146611
4 665
5 742
6 471
7 542


In [40]:
def get_one_hot_flavor_counts(tensor):
    """Print per-flavor particle counts for a one-hot (raw) discrete encoding.

    Counterpart of ``get_flavor_counts`` for data that has not been tokenized.
    Each flavor is identified by one of the first five one-hot columns of
    ``tensor.discrete`` equal to 1, combined with the value of the last column
    (compared against 0, -1 and +1; presumably the electric charge — confirm
    against the data loader). As in ``get_flavor_counts``, only real particles
    (``tensor.mask_bool``) are counted.

    Args:
        tensor: object exposing ``discrete`` (assumed ``(..., D)`` with one-hot
            type columns first and the charge-like column last — TODO confirm)
            and ``mask_bool`` with the same leading shape.

    Prints one ``"<flavor index> <count>"`` line per flavor, using the same
    index order as ``get_flavor_counts``.
    """
    # (one-hot column that must be 1, required value of the last column),
    # listed in flavor-index order 0..7.
    flavor_signatures = [
        (0, 0),   # photon
        (1, 0),   # neutral hadron
        (2, -1),  # negative hadron
        (2, 1),   # positive hadron
        (3, -1),  # electron
        (3, 1),   # positron
        (4, -1),  # muon
        (4, 1),   # anti-muon
    ]
    discrete = tensor.discrete
    charge = discrete[..., -1]
    for flavor, (column, sign) in enumerate(flavor_signatures):
        # Mask out padded slots so they can never be counted as a flavor.
        selected = (discrete[..., column] == 1) & (charge == sign) & tensor.mask_bool
        print(flavor, int(selected.sum()))

# Same per-flavor counts on the raw (preprocess=False) target set, which is
# expected to use the one-hot + charge discrete layout.
# NOTE(review): in the recorded outputs, both encodings total 484,227 particles,
# but flavors 0, 1 and 3 differ (e.g. 154312 vs 154307 photons), so a handful of
# particles change class in the one-hot -> token mapping. Worth investigating.
get_one_hot_flavor_counts(raw_jets.target)

0 154307
1 38186
2 142701
3 146613
4 665
5 742
6 471
7 542


In [None]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
import h5py

# Global matplotlib styling for every figure in this notebook.
plt.rcParams.update(
    {
        "mathtext.fontset": "cm",  # Computer Modern glyphs for $...$ labels
        "figure.autolayout": False,
    }
)

from data.dataclasses import MultiModeState
from data.particle_clouds.utils import (
    extract_jetclass_features,
    extract_aoj_features,
    sample_noise,
    sample_masks,
    map_basis_to_tokens,
    map_tokens_to_basis,
)


class ParticleClouds(MultiModeState):
    """Padded batch of particle-cloud jets with continuous, discrete and mask modes.

    Every jet holds a fixed number of particle slots; ``mask`` marks the real
    (non-padded) ones, and the derived quantities (``mask_bool``,
    ``multiplicity``, statistics, histograms) are all computed from it.
    """

    def __post_init__(
        self,
        dataset="JetClass",
        path=None,
        num_jets=100_000,
        min_num_particles=0,
        max_num_particles=128,
        multiplicity_dist=None,
    ):
        """Populate the continuous / discrete / mask modes from ``dataset``.

        Args:
            dataset: one of
                * a ``torch.Tensor`` whose last axis is laid out as
                  ``[3 continuous features | discrete features | mask]``; the
                  discrete mode is dropped when it has no columns;
                * a ``MultiModeState`` whose available modes and mask are reused;
                * a string containing ``"JetClass"`` or ``"AspenOpenJets"``
                  (features are extracted from ``path``) or ``"Noise"``
                  (noise is sampled with masks drawn via ``multiplicity_dist``).
            path: dataset location; required for the JetClass and AOJ sources.
            num_jets: number of jets to load or sample.
            min_num_particles: lower multiplicity bound forwarded to the
                loaders / mask sampler.
            max_num_particles: forwarded to the loaders / sampler and stored as
                ``self.max_num_particles`` (upper histogram edge in
                ``get_data_stats``); stored unchanged for tensor / MultiModeState
                inputs.
            multiplicity_dist: forwarded to ``sample_masks`` for ``"Noise"``.

        Raises:
            AssertionError: if ``path`` is missing for a file-backed dataset.
            ValueError: if ``dataset`` matches none of the supported sources.
        """
        self.max_num_particles = max_num_particles

        if isinstance(dataset, torch.Tensor):
            self.continuous, self.discrete, self.mask = (
                dataset[..., :3],
                dataset[..., 3:-1].long(),
                dataset[..., -1].unsqueeze(-1).long(),
            )
            if not self.discrete.nelement():
                # No columns between the continuous block and the mask column.
                del self.discrete

        elif isinstance(dataset, MultiModeState):
            if "continuous" in dataset.available_modes():
                self.continuous = dataset.continuous
            if "discrete" in dataset.available_modes():
                self.discrete = dataset.discrete
            self.mask = dataset.mask

        elif "JetClass" in dataset:
            assert path is not None, "Specify the path to the JetClass dataset"
            self.continuous, self.discrete, self.mask = extract_jetclass_features(
                path,
                num_jets,
                min_num_particles,
                max_num_particles,
            )

        elif "AspenOpenJets" in dataset:
            assert path is not None, "Specify the path to the AOJ dataset"
            self.continuous, self.discrete, self.mask = extract_aoj_features(
                path,
                num_jets,
                min_num_particles,
                max_num_particles,
            )

        elif "Noise" in dataset:
            self.continuous, self.discrete = sample_noise(num_jets, max_num_particles)
            self.mask = sample_masks(
                multiplicity_dist,
                num_jets,
                min_num_particles,
                max_num_particles,
            )
            # Multiply features by the mask (presumably 0/1) so padded slots are zero.
            self.continuous *= self.mask
            self.discrete *= self.mask

        else:
            # Previously an unmatched source fell through and failed later with
            # an opaque error on `self.continuous`; fail early and explicitly.
            raise ValueError(
                f"Unknown dataset {dataset!r}: expected a torch.Tensor, a "
                "MultiModeState, or a name containing 'JetClass', "
                "'AspenOpenJets' or 'Noise'"
            )

        # ... useful attributes:
        # NOTE(review): pt / eta_rel / phi_rel are taken from self.continuous
        # here, once; code that later rebinds self.continuous (e.g. a
        # preprocessing step) must refresh them — confirm.
        self.pt = self.continuous[..., 0]
        self.eta_rel = self.continuous[..., 1]
        self.phi_rel = self.continuous[..., 2]
        # Real particles per jet: mask summed over axis 1 (assumes mask is
        # (num_jets, num_particles, 1), giving (num_jets, 1) — TODO confirm).
        self.multiplicity = torch.sum(self.mask, dim=1)
        self.mask_bool = self.mask.squeeze(-1) > 0


    def compute_4mom(self):
        """Attach Cartesian four-momentum components ``px``, ``py``, ``pz``, ``e``.

        Built from ``pt``, ``eta_rel`` and ``phi_rel`` with
        ``e = pt * cosh(eta_rel)``, i.e. every particle is treated as massless.
        Because relative angles are used, the components are relative to the
        jet axis rather than absolute (presumably intended — confirm).
        """
        self.px = self.pt * torch.cos(self.phi_rel)
        self.py = self.pt * torch.sin(self.phi_rel)
        self.pz = self.pt * torch.sinh(self.eta_rel)
        self.e = self.pt * torch.cosh(self.eta_rel)

    def get_data_stats(self):
        """Summary statistics of the real (unmasked) particles and multiplicities.

        Returns:
            dict with
              * ``mean`` / ``std`` / ``min`` / ``max``: per-feature statistics of
                the continuous mode over real particles (lists);
              * ``multinomial_num_particles``: normalized multiplicity histogram
                with one unit-width bin per value ``0 .. max_num_particles``;
              * ``num_particles_mean`` / ``num_particles_std``: mean and std of
                the per-jet multiplicity (floats).
        """
        # Gather the real particles once instead of once per statistic.
        particles = self.continuous[self.mask_bool]
        multiplicities = self.multiplicity.squeeze(-1).float()
        hist, _ = np.histogram(
            self.multiplicity,
            bins=np.arange(0, self.max_num_particles + 2, 1),
            density=True,
        )
        return {
            "mean": particles.mean(0).tolist(),
            "std": particles.std(0).tolist(),
            "min": particles.min(0).values.tolist(),
            "max": particles.max(0).values.tolist(),
            "multinomial_num_particles": hist.tolist(),
            "num_particles_mean": torch.mean(multiplicities).item(),
            "num_particles_std": torch.std(multiplicities).item(),
        }

    # ...data visualization methods

    def histplot(
        self,
        feature="pt",
        idx=None,
        xlim=None,
        ylim=None,
        xlabel=None,
        ylabel=None,
        figsize=(3, 3),
        fontsize=12,
        ax=None,
        **kwargs,
    ):
        """Step histogram of a per-particle feature, drawn with seaborn.

        Args:
            feature: name of a per-particle attribute, e.g. ``"pt"``,
                ``"eta_rel"``, ``"phi_rel"`` (or ``"px"`` etc. after
                ``compute_4mom``).
            idx: if given, plot only the ``idx``-th particle slot of each jet,
                restricted to jets where that slot holds a real particle;
                otherwise plot all real particles.
            xlim, ylim: axis limits forwarded to ``set_xlim`` / ``set_ylim``.
            xlabel: x-axis label; defaults to ``feature``.
            ylabel: y-axis label, passed to ``ax.set_ylabel`` as-is.
            figsize: size of the new figure when ``ax`` is not given.
            fontsize: font size of the axis labels.
            ax: matplotlib Axes to draw on; a new figure is created if ``None``.
            **kwargs: forwarded to ``seaborn.histplot``.
        """
        if ax is None:
            _, ax = plt.subplots(figsize=figsize)
        values = getattr(self, feature)
        if idx is None:
            x = values[self.mask_bool]
        else:
            # Keep only jets whose idx-th slot is a real particle; previously
            # padded slots were histogrammed together with real ones.
            x = values[:, idx][self.mask_bool[:, idx]]
        sns.histplot(x=x, element="step", ax=ax, **kwargs)
        ax.set_xlabel(feature if xlabel is None else xlabel, fontsize=fontsize)
        ax.set_ylabel(ylabel, fontsize=fontsize)
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)