In [30]:
using EDM4hep
using EDM4hep.RootIO
using JetReconstruction
using StaticArrays
using LorentzVectorHEP
using JSON
using ONNXRunTime
using LinearAlgebra
using StructArrays
using UnROOT
using Plots

# Import our custom modules
include("JetFlavourHelper.jl")  # Our main module
include("JetConstituentUtils.jl")  # For constituent feature extraction
include("ReconstructedParticle.jl")  # For handling reconstructed particles



Main.ReconstructedParticle

In [3]:
# Paths to model files
model_dir = "wc_pt_7classes_12_04_2023"
onnx_path = joinpath(model_dir, "fccee_flavtagging_edm4hep_wc_v1.onnx")
json_path = joinpath(model_dir, "fccee_flavtagging_edm4hep_wc_v1.json")

# Load the configuration and model
config = JSON.parsefile(json_path)
model = ONNXRunTime.load_inference(onnx_path)

# Display the output classes we'll predict
println("The model predicts these flavor classes:")
for class_name in config["output_names"]
    println(" - ", class_name)
end

The model predicts these flavor classes:
 - recojet_isG
 - recojet_isQ
 - recojet_isS
 - recojet_isC
 - recojet_isB


In [5]:
# Path to ROOT file with EDM4hep data
edm4hep_path = "events_080263084.root"
reader = RootIO.Reader(edm4hep_path)

# Get event information
events = RootIO.get(reader, "events")
println("Loaded $(length(events)) events")

# Choose a specific event to analyze (event #13)
event_id = 13
evt = events[event_id]
println("Processing event #$event_id")

# Get reconstructed particles and tracks
recps = RootIO.get(reader, evt, "ReconstructedParticles")
tracks = RootIO.get(reader, evt, "EFlowTrack_1")

# Set the magnetic field strength (in Tesla)
bz = 2.0  # This should be obtained from your data if possible

# Get other needed collections for feature extraction
trackdata = RootIO.get(reader, evt, "EFlowTrack")
trackerhits = RootIO.get(reader, evt, "TrackerHits")
gammadata = RootIO.get(reader, evt, "EFlowPhoton")
nhdata = RootIO.get(reader, evt, "EFlowNeutralHadron")
calohits = RootIO.get(reader, evt, "CalorimeterHits")
dNdx = RootIO.get(reader, evt, "EFlowTrack_2")

# Load track length information from the ROOT file
# Note: This is a placeholder. You need to replace this with the loading from EDM4hep.jl once I fix the bug in EDM4hep.
f = ROOTFile(edm4hep_path)
mytree = LazyTree(f, "events", ["EFlowTrack_L"])
track_L = collect(mytree[event_id].EFlowTrack_L)

println("Loaded $(length(recps)) reconstructed particles")
println("Loaded $(length(tracks)) tracks")

Loaded 100000 events
Processing event #13
Loaded 107 reconstructed particles
Loaded 44 tracks


In [15]:
# Find muons in the event
muons = RootIO.get(reader, evt, "Muon#0")
muons_all = ReconstructedParticle.get(muons, recps)
muons_selected = ReconstructedParticle.sel_p(20.0)(muons_all)
# TODO: There is a bug in the sel_p function. 
# When there's only one muon, it should returns a single value instead of a collection but it's buggy.

# Remove high-momentum muons from the input particles
recps_no_muons = ReconstructedParticle.remove(recps, muons_selected)
vrecps = collect(recps_no_muons)

# Cluster jets using the EEkt algorithm with R=2.0 and p=1.0
cs = jet_reconstruct(vrecps; p = 1.0, R = 2.0, algorithm = JetAlgorithm.EEKt)

# Get 2 exclusive jets
jets = exclusive_jets(cs; njets=2, T=EEJet)

# For each jet, get its constituent particles
constituent_indices = [constituent_indexes(jet, cs) for jet in jets]
jet_constituents = JetConstituentUtils.build_constituents_cluster(recps_no_muons, constituent_indices)


2-element Vector{StructVector{EDM4hep.ReconstructedParticle}}:
 [EDM4hep.ReconstructedParticle(#76, 22, 0.1520507f0, (-0.089623965, -0.118329994, 0.032938913), (0.0, 0.0, 0.0), 0.0f0, 3.1475472f-5, 0.0f0, Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Cluster#[32], Track#[], EDM4hep.ReconstructedParticle#[], ParticleID#[], #0, #0), EDM4hep.ReconstructedParticle(#75, 22, 0.38633013f0, (-0.3512648, 0.14457734, 0.07043707), (0.0, 0.0, 0.0), 0.0f0, 9.40061f-5, 0.0f0, Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Cluster#[31], Track#[], EDM4hep.ReconstructedParticle#[], ParticleID#[], #0, #0), EDM4hep.ReconstructedParticle(#78, 22, 0.049976002f0, (-0.041344654, -0.018857136, 0.020799734), (0.0, 0.0, 0.0), 0.0f0, -9.3014905f-6, 0.0f0, Float32[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], Cluster#[34], Track#[], EDM4hep.ReconstructedParticle#[], ParticleID#[], #0, #0), EDM4hep.ReconstructedParticle(#45, 22, 0.13387372f0, (-0.039546303, 0.03148594, -0.1239633), (0

In [35]:
println("Extracting features for flavor tagging...")
feature_data = JetFlavourHelper.extract_features(
    jets, 
    jet_constituents, 
    tracks, 
    bz, 
    track_L, 
    trackdata, 
    trackerhits, 
    gammadata, 
    nhdata, 
    calohits, 
    dNdx
)

model, config = JetFlavourHelper.setup_weaver(
    onnx_path,
    json_path
)

input_tensors = JetFlavourHelper.prepare_input_tensor(
    jet_constituents,
    jets,
    config,
    feature_data
)

println("Running flavor tagging inference...")
weights = JetFlavourHelper.get_weights(
    0,  # Thread slot
    feature_data,
    jets,
    jet_constituents,
    config,
    model
)

jet_scores = Dict{String, Vector{Float32}}()
for (i, score_name) in enumerate(config["output_names"])
    jet_scores[score_name] = JetFlavourHelper.get_weight(weights, i-1)
end

println("Jet scores:")
for (name, scores) in jet_scores
    println(" - $name: $(scores[1])")
end

Extracting features for flavor tagging...
Running flavor tagging inference...
Jet scores:
 - recojet_isG: 0.63997483
 - recojet_isB: 0.13061157
 - recojet_isQ: 0.05580642
 - recojet_isC: 0.030684045
 - recojet_isS: 0.14292312
