In [1]:
import os
import ast
import glob
import functools
import math
import torch

from torch.utils.data import DataLoader
from src.logger.logger import _logger, _configLogger
from src.dataset.dataset import EventDatasetCollection, EventDataset
from src.utils.import_tools import import_module
from src.dataset.functions_graph import graph_batch_func
from src.dataset.functions_data import concat_events
from src.utils.paths import get_path


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the preprocessed signal sample used throughout this notebook.
# NOTE(review): this is an absolute, machine-specific path; consider resolving
# it through the project's path helper so the notebook runs elsewhere.
TRAIN_DATA_DIR = "/work/gkrzmanc/jetclustering/preprocessed_data/scouting_PFNano_signals1/SVJ_hadronic_std3/s-channel_mMed-1100_mDark-20_rinv-0.3_alpha-peak_13TeV-pythia8_n-2000"

# Load the events from disk (mmap=True is forwarded to the dataset loader).
train_data = EventDataset.from_directory(TRAIN_DATA_DIR, mmap=True)
print("N events:", len(train_data))

# Batches of 8 events are merged into a single object by `concat_events`.
# `persistent_workers` is documented as a bool, so pass True rather than 1.
train_loader = DataLoader(
    train_data,
    batch_size=8,
    drop_last=True,
    pin_memory=True,
    num_workers=1,
    collate_fn=concat_events,
    persistent_workers=True,
)


N events: 91316


In [3]:
# Pull the first collated batch from the loader for interactive inspection.
b = next(iter(train_loader))

In [4]:
# Number of events held by the batch (the loader was built with batch_size=8).
b.n_events

8

In [5]:
# Exploration aid: list the attributes exposed by the batch object.
dir(b)

['MET',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__len__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 'evt_collections',
 'fatjets',
 'genjets',
 'init_attrs',
 'jets',
 'matrix_element_gen_particles',
 'n_events',
 'offline_pfcands',
 'pfcands',
 'serialize',
 'special_pfcands']

In [6]:
# Total number of PF candidates in the batch together with `batch_number`.
# NOTE(review): `batch_number` looks like cumulative per-event offsets (it is
# sliced as [i], [i+1] further down) — confirm against the dataset code.
len(b.pfcands), b.pfcands.batch_number

(2023, tensor([   0,  331,  636,  821, 1044, 1404, 1619, 1787, 2023]))

In [22]:
def get_idx_for_event(obj, i):
    """Return the pair ``(batch_number[i], batch_number[i + 1])`` of ``obj``.

    The two values are used by callers as the start/end slice bounds of
    event ``i`` inside the flat per-object arrays of ``obj``.
    """
    offsets = obj.batch_number
    start = offsets[i]
    stop = offsets[i + 1]
    return start, stop

def get_labels(pfcands, gen_particles=None, R=0.8):
    """Label every PF candidate with the index of its closest gen particle.

    For each event, the angular distance ``sqrt(d_eta**2 + d_phi**2)`` between
    every candidate and every matrix-element gen particle (the "dark quarks"
    of the original comment) is computed, with ``d_phi`` wrapped into
    ``[-pi, pi)`` so that objects on either side of the phi boundary are
    treated as neighbours.

    Args:
        pfcands: object exposing ``eta``, ``phi``, ``batch_number`` and
            ``__len__``. ``eta``/``phi`` are assumed to be 1-D tensors and
            ``batch_number`` the cumulative per-event offsets — TODO confirm.
        gen_particles: object exposing ``eta``, ``phi`` and ``batch_number``
            in the same layout. Defaults to
            ``b.matrix_element_gen_particles`` of the notebook-global batch
            ``b``, which is what the original implementation always used.
        R: maximum matching distance; candidates farther than ``R`` from
            every gen particle are labelled ``-1``.

    Returns:
        A ``torch.long`` tensor of length ``len(pfcands)``. Each entry is the
        index of the closest gen particle *within that candidate's event*,
        or ``-1`` when nothing lies within ``R`` (including events that have
        no gen particles at all).
    """
    if gen_particles is None:
        # Backward-compatible fallback to the notebook-global batch.
        gen_particles = b.matrix_element_gen_particles

    # Derive the event count from the offsets of the argument itself rather
    # than from the global batch.
    n_events = len(pfcands.batch_number) - 1

    # Start from -1 so anything that is never matched stays "unassigned".
    labels = torch.full((len(pfcands),), -1, dtype=torch.long)

    for i in range(n_events):
        q_start = int(gen_particles.batch_number[i])
        q_end = int(gen_particles.batch_number[i + 1])
        c_start = int(pfcands.batch_number[i])
        c_end = int(pfcands.batch_number[i + 1])
        if q_end == q_start or c_end == c_start:
            # No gen particles (min over an empty axis would raise) or no
            # candidates: nothing to match in this event.
            continue

        dq_eta = gen_particles.eta[q_start:q_end]
        dq_phi = gen_particles.phi[q_start:q_end]
        # dq_pt = gen_particles.pt[q_start:q_end] # Maybe we can somehow weigh the loss by pt?
        pfcands_eta = pfcands.eta[c_start:c_end]
        pfcands_phi = pfcands.phi[c_start:c_end]

        # Pairwise differences, shape (n_candidates, n_gen_particles).
        d_eta = pfcands_eta[:, None] - dq_eta[None, :]
        d_phi = pfcands_phi[:, None] - dq_phi[None, :]
        # phi is periodic: fold the difference back into [-pi, pi).
        d_phi = torch.remainder(d_phi + math.pi, 2 * math.pi) - math.pi
        dist_matrix = torch.sqrt(d_eta ** 2 + d_phi ** 2)

        closest_quark_dist, closest_quark_idx = dist_matrix.min(dim=1)
        closest_quark_idx[closest_quark_dist > R] = -1
        labels[c_start:c_end] = closest_quark_idx
    return labels