In [1]:
import uproot
from coffea.nanoevents import NanoEventsFactory
import numpy as np
import matplotlib.pyplot as plt
import awkward as ak
from torch.utils.data import DataLoader
import torch

from saja.modules import SaJa
from saja import JetPartonAssignmentDataset
from saja.training import trainNet
from saja.losses import object_wise_cross_entropy

In [2]:
def filterEvents(jets, electrons, muons, genpart,
                 lepton_pt_min=25, jet_pt_min=25,
                 b_tag_threshold=0.5, match_threshold=1.0):
    """Select single-lepton, >=4 jet, >=2 b-tag events and label each jet by parton origin.

    Parameters
    ----------
    jets, electrons, muons, genpart
        Per-event collections; in this notebook they are ``events.Jet``,
        ``events.Electron``, ``events.Muon`` and ``events.GenPart`` of a
        coffea NanoEvents object.
    lepton_pt_min, jet_pt_min
        Strict lower pt cuts applied to electrons/muons and to jets.
    b_tag_threshold
        A jet counts as b-tagged when ``btagCSVV2 >= b_tag_threshold``.
    match_threshold
        Forwarded as ``threshold`` to ``jets.nearest(...)`` when matching jets
        to generator-level quarks.

    Returns
    -------
    (jets, electrons, muons, labels)
        The selected collections restricted to events usable for training,
        plus one integer label per jet:
        0 = other / unmatched,
        1 = matched quark whose parent is a W,
        2 = "top1": matched quark whose parent is a top and whose top has both
            a charged lepton and a neutrino among its grandchildren,
        3 = "top2": matched quark whose parent is a top, otherwise.
        Only events with exactly two W jets, one top1 jet and one top2 jet
        are kept.
    """
    # integer class ids stored in the returned label array
    LABEL_OTHER, LABEL_W, LABEL_TOP1, LABEL_TOP2 = 0, 1, 2, 3
    # sentinel |pdgId| given to jets without a matched parent; never equals 6 or 24
    NO_PARENT = 100

    #### object selection ####
    selected_electrons = electrons[electrons.pt > lepton_pt_min]
    selected_muons = muons[muons.pt > lepton_pt_min]
    selected_jets = jets[jets.pt > jet_pt_min]

    #### event selection: exactly one lepton, >=4 jets, >=2 b-tags (4j2b region) ####
    n_leptons = ak.count(selected_electrons.pt, axis=1) + ak.count(selected_muons.pt, axis=1)
    n_jets = ak.count(selected_jets.pt, axis=1)
    # one tag definition (>=) is used throughout; requiring >=2 tags implies the
    # ">=1 tag" preselection, so a single mask is sufficient
    n_btag = ak.sum(selected_jets.btagCSVV2 >= b_tag_threshold, axis=1)
    region_filter = (n_leptons == 1) & (n_jets >= 4) & (n_btag >= 2)

    selected_jets_region = selected_jets[region_filter]
    selected_electrons_region = selected_electrons[region_filter]
    selected_muons_region = selected_muons[region_filter]
    selected_genpart_region = genpart[region_filter]

    #### filter genPart to valid matching candidates ####

    # get rid of particles without parents
    has_parent = ~ak.is_none(selected_genpart_region.distinctParent, axis=1)
    candidates = selected_genpart_region[has_parent]

    # ensure that parents are top quark or W
    parent_abs_id = np.abs(candidates.distinctParent.pdgId)
    candidates = candidates[(parent_abs_id == 6) | (parent_abs_id == 24)]

    # ensure particle itself is a quark (|pdgId| in 1..6)
    candidate_abs_id = np.abs(candidates.pdgId)
    candidates = candidates[(candidate_abs_id > 0) & (candidate_abs_id < 7)]

    # get rid of duplicates
    candidates = candidates[candidates.hasFlags("isLastCopy")]

    #### match jets to nearest valid genPart candidate ####
    nearest_genpart = selected_jets_region.nearest(candidates, threshold=match_threshold)
    nearest_parent = nearest_genpart.distinctParent  # parent of matched particle

    parent_pdgid = nearest_parent.pdgId  # pdgId of parent particle
    # pdgId of the parent's grandchildren (the matched particle's "cousins")
    grandchild_pdgid = nearest_parent.distinctChildren.distinctChildren.pdgId

    jet_counts = ak.num(selected_jets_region)
    # flatten innermost axis so each jet carries one list of grandchildren
    grandchildren_abs = np.abs(ak.flatten(grandchild_pdgid, axis=-1))

    # |pdgId| 11..18 are leptons: odd ids are charged leptons, even ids are neutrinos
    is_lepton_family = (grandchildren_abs > 10) & (grandchildren_abs < 19)
    has_charged_lepton_cousin = ak.sum(is_lepton_family & (grandchildren_abs % 2 == 1), axis=-1) > 0
    has_neutrino_cousin = ak.sum(is_lepton_family & (grandchildren_abs % 2 == 0), axis=-1) > 0
    # unmatched jets have no parent; treat them as having no cousins
    has_both_cousins = ak.fill_none(has_charged_lepton_cousin & has_neutrino_cousin, False)
    has_both_cousins_flat = ak.flatten(has_both_cousins).to_numpy()

    # |pdgId| of each jet's matched parent; -6 marks the top1 class
    labels_flat = np.abs(ak.fill_none(ak.flatten(parent_pdgid), NO_PARENT).to_numpy())
    labels_flat[has_both_cousins_flat] = -6

    # translate parent ids into class labels in a single pass
    labels_all_flat = np.select(
        [labels_flat == 24, labels_flat == -6, labels_flat == 6],
        [LABEL_W, LABEL_TOP1, LABEL_TOP2],
        default=LABEL_OTHER,
    ).astype(labels_flat.dtype)
    labels_all = ak.unflatten(labels_all_flat, jet_counts)

    #### keep only events with a complete assignment ####
    has_W = ak.sum(labels_all == LABEL_W, axis=-1) == 2
    has_top1 = ak.sum(labels_all == LABEL_TOP1, axis=-1) == 1
    has_top2 = ak.sum(labels_all == LABEL_TOP2, axis=-1) == 1
    training_event_filter = has_W & has_top2 & has_top1

    selected_jets_region = selected_jets_region[training_event_filter]
    selected_electrons_region = selected_electrons_region[training_event_filter]
    selected_muons_region = selected_muons_region[training_event_filter]
    labels_all = labels_all[training_event_filter]

    return selected_jets_region, selected_electrons_region, selected_muons_region, labels_all

In [3]:
# Training and validation events are read from the same ttbar NanoAOD file;
# only the entry range differs (entries up to 20000 for training,
# 20000-25000 for validation).
TTBAR_NANOAOD_URL = "https://xrootd-local.unl.edu:1094//store/user/AGC/nanoAOD/TT_TuneCUETP8M1_13TeV-powheg-pythia8/cmsopendata2015_ttbar_19980_PU25nsData2015v1_76X_mcRun2_asymptotic_v12_ext3-v1_00000_0004.root"

events_train = NanoEventsFactory.from_root(
    TTBAR_NANOAOD_URL, treepath="Events", entry_stop=20000
).events()

events_val = NanoEventsFactory.from_root(
    TTBAR_NANOAOD_URL, treepath="Events", entry_start=20000, entry_stop=25000
).events()



In [4]:
# Apply the same selection and jet labelling to both samples.
jets_train, electrons_train, muons_train, labels_train = filterEvents(
    jets=events_train.Jet,
    electrons=events_train.Electron,
    muons=events_train.Muon,
    genpart=events_train.GenPart,
)
jets_val, electrons_val, muons_val, labels_val = filterEvents(
    jets=events_val.Jet,
    electrons=events_val.Electron,
    muons=events_val.Muon,
    genpart=events_val.GenPart,
)

In [5]:
# Wrap the selected objects and per-jet labels as datasets for the DataLoaders.
dataset_train = JetPartonAssignmentDataset(
    jets=jets_train, electrons=electrons_train, muons=muons_train,
    target=labels_train,
)

dataset_val = JetPartonAssignmentDataset(
    jets=jets_val, electrons=electrons_val, muons=muons_val,
    target=labels_val,
)

In [6]:
# Reshuffle the training events every epoch so batches are not identical from
# one epoch to the next; a seeded generator keeps the order reproducible.
# NOTE(review): shuffle=True assumes JetPartonAssignmentDataset is a map-style
# dataset (__getitem__/__len__) -- confirm.
train_loader = DataLoader(dataset_train,
                          batch_size=32,
                          shuffle=True,
                          generator=torch.Generator().manual_seed(0),
                          collate_fn=dataset_train.collate)
# Validation order does not affect the metric, so it stays deterministic.
val_loader = DataLoader(dataset_val,
                        batch_size=32,
                        shuffle=False,
                        collate_fn=dataset_val.collate)

In [7]:
# Network, per-object loss and optimizer used by the training loop.
# NOTE(review): the meaning of the two SaJa(4, 4) arguments is not visible
# here -- confirm against saja.modules.SaJa.
model = SaJa(4, 4)
loss = object_wise_cross_entropy
optimizer = torch.optim.Adam(
    model.parameters(),
    betas=(0.9, 0.999),
    lr=1e-3,
)

In [8]:
# Build the training iterator. The training output only appears once
# train_iter is consumed in the next cell; the progress bar there reports
# 20 epochs, matching the 6th argument.
epoch_results = trainNet(
    model, optimizer, loss, train_loader, val_loader, 20, notebook=True
)
train_iter = enumerate(epoch_results)

In [9]:
# Run training: print the cost reported for each epoch and overwrite the saved
# model after every epoch.
# NOTE(review): the recorded output shows exactly the same train and val cost
# for all 20 epochs, i.e. the model does not appear to be updating --
# investigate trainNet / the loss before trusting the saved model.
for epoch, result in train_iter:
    print(result.cost)
    torch.save(model, 'run_stats.pyt')

Number of batches: train = 23, val = 6


Epochs:   0%|          | 0/20 [00:00<?, ?it/s, train=start, val=start]

Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 0: train=0.740303, val=0.621655, took 5.2346 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 1: train=0.740303, val=0.621655, took 4.9032 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 2: train=0.740303, val=0.621655, took 5.0663 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 3: train=0.740303, val=0.621655, took 4.923 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 4: train=0.740303, val=0.621655, took 5.2207 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 5: train=0.740303, val=0.621655, took 4.9381 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 6: train=0.740303, val=0.621655, took 4.7013 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 7: train=0.740303, val=0.621655, took 4.7092 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 8: train=0.740303, val=0.621655, took 4.8479 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 9: train=0.740303, val=0.621655, took 4.6425 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 10: train=0.740303, val=0.621655, took 5.0587 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 11: train=0.740303, val=0.621655, took 4.7366 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 12: train=0.740303, val=0.621655, took 4.548 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 13: train=0.740303, val=0.621655, took 4.6814 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 14: train=0.740303, val=0.621655, took 5.0083 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 15: train=0.740303, val=0.621655, took 4.7287 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 16: train=0.740303, val=0.621655, took 4.5402 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 17: train=0.740303, val=0.621655, took 18.48 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 18: train=0.740303, val=0.621655, took 4.9095 s
0.7403028322302777


Training:   0%|          | 0/23 [00:00<?, ?it/s, train=start]

Epoch 19: train=0.740303, val=0.621655, took 4.6746 s
0.7403028322302777
