In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

def Bins(xmin, xmax, bins=100):
    """Return evenly spaced histogram bin edges covering ``[xmin, xmax]``.

    Args:
        xmin: lower edge of the first bin.
        xmax: upper edge of the last bin (included in the result).
        bins: number of bins; the result has ``bins + 1`` edges.

    Returns:
        1-D ``np.ndarray`` of ``bins + 1`` edges, suitable for ``np.histogram`` /
        ``plt.hist``.

    Raises:
        ValueError: if ``bins`` is smaller than 1.
    """
    if bins < 1:
        raise ValueError(f"bins must be a positive integer, got {bins!r}")
    # linspace (not arange with a float step) guarantees the endpoint is present
    # and the edge count is exact, independent of floating-point rounding.
    return np.linspace(xmin, xmax, bins + 1)

In [None]:

from src.datasets import JetNetDataset
from torch.utils.data import DataLoader
import torch
class DataConfig:
    """Registry of the jet samples compared in this notebook.

    ``sets`` maps a sample name to a ``(file name, key)`` pair; ``labels`` maps
    the same names to integer class labels.
    """

    # sample name -> (file name under the data directory, key presumably naming
    # the stored feature array inside that file -- TODO confirm)
    sets : dict = {
        'fm_midpoint' : ('fm_tops150_mp200nfe.h5', 'etaphipt'),
        'fm_euler'    : ('fm_tops150_eu200nfe.h5', 'etaphipt'),
        'diff_ddim'   : ('ddim_200.h5', 'etaphipt_frac'),
        'jetnet'      : ('t150.hdf5', 'particle_features'),
    }

    # sample name -> class label, numbered in the insertion order of `sets`
    labels : dict = {name: label for label, name in enumerate(sets)}


# Build the combined dataset: every file in `config.sets`, tagged with its class label.
config = DataConfig
data = JetNetDataset(
    dir_path='data/',
    data_files=config.sets,
    data_class_labels=config.labels,
    particle_features=['eta_rel', 'phi_rel', 'pt_rel', 'R', 'e_rel'],
)


In [None]:

data_loader = DataLoader(dataset=data, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)

# Peek at a single batch: print the first jet together with its *own* class label.
batch, labels = next(iter(data_loader))
print(batch[0], labels[0])

In [None]:
class JetNetPreprocess:
    """Invertible per-feature transforms for a masked jet tensor.

    The last entry of the final axis of ``jet`` is treated as a particle mask.
    Every transform acts on the remaining feature columns, and the result is
    multiplied by the mask so padded particles stay at zero. ``self.jet`` is
    then rebuilt as ``features + mask``.

    Args:
        jet: tensor whose final axis is ``(features..., mask)``. Leading axes are
            untouched, so ``(particles, F + 1)`` and batched
            ``(jets, particles, F + 1)`` inputs both work.
        method: stored as ``self.method``; not used by this class.
        info: per-feature statistics in the order ``(mean, std, max, min)``,
            each broadcastable against the ``F`` feature columns.

    Raises:
        ValueError: if ``jet`` or ``info`` is not provided.
    """

    _EPS = 1e-8  # guards the division for a zero-variance feature

    def __init__(self,
                 jet: torch.Tensor = None,
                 method: dict = None,
                 info: tuple = None):
        if jet is None:
            raise ValueError("JetNetPreprocess requires a `jet` tensor")
        if info is None:
            raise ValueError("JetNetPreprocess requires `info` = (mean, std, max, min)")

        self.jet = jet
        self.dim_features = self.jet.shape[-1] - 1
        self.method = method
        self.mean, self.std, self.max, self.min = info
        # Keep a trailing singleton axis so the mask broadcasts over all feature columns.
        self.mask = self.jet[..., -1:]
        self.jet_unmask = self.jet[..., :self.dim_features]

    def _apply_mask(self):
        """Zero padded particles and rebuild ``self.jet`` as ``features + mask``."""
        self.jet_unmask = self.jet_unmask * self.mask
        self.jet = torch.cat((self.jet_unmask, self.mask), dim=-1)

    def standardize(self, sigma: float = 1.0, inverse: bool = False):
        """Map features to ``(x - mean) * sigma / (std + eps)``.

        With ``inverse=True``, apply the exact inverse
        ``x * (std + eps) / sigma + mean``.
        """
        scale = sigma / (self.std + self._EPS)
        if inverse:
            self.jet_unmask = self.jet_unmask / scale + self.mean
        else:
            self.jet_unmask = (self.jet_unmask - self.mean) * scale
        self._apply_mask()

    def normalize(self, inverse: bool = False):
        """Min-max scale features: ``(x - min) / (max - min)``; ``inverse=True`` undoes it."""
        span = self.max - self.min
        if inverse:
            self.jet_unmask = self.jet_unmask * span + self.min
        else:
            self.jet_unmask = (self.jet_unmask - self.min) / span
        self._apply_mask()

    def logit_tramsform(self, alpha=1e-6, inverse: bool = False):
        """Apply :meth:`logit` to the features; ``inverse=True`` applies :meth:`expit`.

        The method name keeps its historical spelling for existing callers;
        ``logit_transform`` is a correctly spelled alias.
        """
        if inverse:
            self.jet_unmask = self.expit(self.jet_unmask, alpha=alpha)
        else:
            self.jet_unmask = self.logit(self.jet_unmask, alpha=alpha)
        self._apply_mask()

    logit_transform = logit_tramsform

    @staticmethod
    def logit(t, alpha=1e-6):
        """Squeeze ``t`` into ``[alpha, 1 - alpha]``, then return ``log(x / (1 - x))``.

        The squeeze keeps inputs of exactly 0 or 1 finite.
        """
        x = alpha + (1 - 2 * alpha) * t
        return torch.log(x / (1 - x))

    @staticmethod
    def expit(t, alpha=1e-6):
        """Inverse of :meth:`logit`: apply a sigmoid, then undo the ``alpha`` squeeze."""
        return (torch.sigmoid(t) - alpha) / (1 - 2 * alpha)



# Sanity check on one jet: the first element of the first dataset item.
jet = data[0][0]

# Per-feature statistics for three features.
# NOTE(review): the dataset above is built with 5 particle_features -- confirm the
# jet's feature count matches these 3-entry statistics.
feat_mean = torch.Tensor([1.8317e-05, -6.5716e-05, 1.5960e-02])
feat_std = torch.Tensor([0.1364, 0.1361, 0.0300])
feat_min = torch.Tensor([-3.1485e+00, -6.8031e-01, 2.0489e-08])
feat_max = torch.Tensor([1.6063, 0.8665, 0.9089])
# JetNetPreprocess unpacks `info` as (mean, std, max, min).
info = (feat_mean, feat_std, feat_max, feat_min)

print(jet)
jet_prepr = JetNetPreprocess(jet, info=info)
jet_prepr.standardize(inverse=True, sigma=5.0)
jet_prepr.jet

In [None]:
# Display the jet tensor of the first dataset item.
first_sample = data[0]
first_sample[0]

# Jet images

In [None]:
# 2-D jet images on a common 200 x 200 grid over [-1, 1] x [-1, 1].
# NOTE(review): `jetnet`, `flowmatch_mp`, `flowmatch_eu` and `diffusion` are created in
# the feature-distribution cell further down -- run that cell first, or move this one after it.
bins = (Bins(-1, 1, 200), Bins(-1, 1, 200))

for sample in (jetnet, flowmatch_mp, flowmatch_eu, diffusion):
    sample.image(bins=bins)

# Compare particle- and jet-level feature distributions (raw data)

In [None]:

# Load the raw (un-preprocessed) samples.
data_raw = JetNetDataLoader(dir_path='data/',
                            data_files=DataConfig.sets,
                            preprocess=None,
                            num_jets=175000,
                            num_constituents=30,
                            clip_neg_pt=True,
                            particle_features=['eta_rel', 'phi_rel', 'pt_rel', 'R', 'e_rel'])

# Select each sample by its DataConfig label instead of a hard-coded position.
# Assumes JetNetDataLoader returns samples in DataConfig.sets order -- TODO confirm.
sample_index = DataConfig.labels
jetnet = JetNetFeatures(data_raw[sample_index['jetnet']])
flowmatch_mp = JetNetFeatures(data_raw[sample_index['fm_midpoint']])
flowmatch_eu = JetNetFeatures(data_raw[sample_index['fm_euler']])
diffusion = JetNetFeatures(data_raw[sample_index['diff_ddim']])

# Reference JetNet histogram is filled; generated samples are drawn as coloured outlines.
overlay = [
    (jetnet, {}),
    (flowmatch_eu, dict(fill=False, color='r')),
    (flowmatch_mp, dict(fill=False, color='purple')),
    (diffusion, dict(fill=False, color='b')),
]

# One figure per entry: (plot method, feature kwargs, bin edges, title).
panels = [
    ('particle_plot', dict(feature='pt_rel'), Bins(0, 1), None),
    ('particle_plot', dict(feature='R'), Bins(0, 3), None),
    ('particle_plot', dict(feature='pt_rel', nth_particle=1), Bins(0, 0.8), r'hardest particles in jet'),
    ('jet_plot', dict(feature='m_rel'), Bins(0, 0.35), None),
    ('jet_plot', dict(feature='multiplicity'), range(30), None),
]

for method, feature_kwargs, bins, title in panels:
    fig, ax = plt.subplots(1, figsize=(5, 5))
    for sample, style in overlay:
        getattr(sample, method)(bins=bins, ax=ax, **feature_kwargs, **style)
    if title is not None:
        ax.set_title(title)

# Feature distributions after preprocessing

In [None]:
# Load the same samples with preprocessing applied.
# NOTE(review): DataConfig defines only `sets` and `labels`; a `preprocess`
# attribute must be added to it before this cell can run.
data = JetNetDataLoader(dir_path='data/',
                        data_files=DataConfig.sets,
                        preprocess=DataConfig.preprocess,
                        num_jets=175000,
                        num_constituents=30,
                        clip_neg_pt=True,
                        particle_features=['eta_rel', 'phi_rel', 'pt_rel', 'R', 'e_rel'])

# Select each sample by its DataConfig label instead of a hard-coded position.
# Assumes JetNetDataLoader returns samples in DataConfig.sets order -- TODO confirm.
sample_index = DataConfig.labels
jetnet = JetNetFeatures(data[sample_index['jetnet']])
flowmatch_mp = JetNetFeatures(data[sample_index['fm_midpoint']])
flowmatch_eu = JetNetFeatures(data[sample_index['fm_euler']])
diffusion = JetNetFeatures(data[sample_index['diff_ddim']])

# Reference JetNet histogram is filled; generated samples are drawn as coloured outlines.
overlay = [
    (jetnet, {}),
    (flowmatch_eu, dict(fill=False, color='r')),
    (flowmatch_mp, dict(fill=False, color='purple')),
    (diffusion, dict(fill=False, color='b')),
]

# A single wide range is shared by every preprocessed feature.
bins = Bins(-10, 50)

# One figure per entry: (feature kwargs, title).
panels = [
    (dict(feature='pt_rel'), None),
    (dict(feature='R'), None),
    (dict(feature='pt_rel', nth_particle=1), r'hardest particles in jet'),
]

for feature_kwargs, title in panels:
    fig, ax = plt.subplots(1, figsize=(5, 5))
    for sample, style in overlay:
        sample.particle_plot(bins=bins, ax=ax, **feature_kwargs, **style)
    if title is not None:
        ax.set_title(title)
