In [None]:
import ROOT

In [None]:
from tqdm import tqdm

In [None]:
from pathlib import Path

import hist
import matplotlib.pyplot as plt
import scienceplots

# Apply the "science" and "notebook" style sheets, then override a few defaults.
plt.style.use(["science", "notebook"])
plt.rcParams.update(
    {
        "font.size": 14,
        "axes.formatter.limits": (-5, 4),
        "figure.figsize": (6, 4),
    }
)
# Colour list of the active property cycle, kept for manual reuse in later cells.
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

In [None]:
# 2D histogram with axes named "x" and "y": 50 regular bins each, Double storage.
h_xy = hist.Hist(
    hist.axis.Regular(50, -22, 26, name="x"),
    hist.axis.Regular(50, -21, 28, name="y"),
    storage=hist.storage.Double(),
)

In [None]:
# 2D histogram with axes named "z" and "x": 50 regular bins each, Double storage.
h_zx = hist.Hist(
    hist.axis.Regular(50, -200, -80, name="z"),
    hist.axis.Regular(50, -22, 26, name="x"),
    storage=hist.storage.Double(),
)

In [None]:
import numpy as np

In [None]:
import pandas as pd

In [None]:
# Number of input files to chain; it also scales the progress-bar total of the event loop.
n_files = 1_000

In [None]:
# Chain the "cbmsim" trees of all input files; the per-file directories are numbered from 1.
tree = ROOT.TChain("cbmsim")
for job in range(1, n_files + 1):
    infile = f"root://eospublic.cern.ch//eos/experiment/sndlhc/users/olantwin/advsnd/2024/07/nu12/CCDIS/{job}/sndLHC.Genie-TGeant4_dig.root"
    tree.AddFile(infile)

In [None]:
# Per-event feature records; the event loop appends one dict per event.
ntuple = list()

In [None]:
for event in tqdm(tree, "Event loop: ", total=n_files * 100):
    start_z = None
    nu_energy = None
    target_stations = {}
    mufilter_stations = {}
    for hit in event.Digi_AdvTargetHits:
        detID = hit.GetDetectorID()
        station = hit.GetStation()
        if station not in target_stations:
            target_stations[station] = [
                detID,
            ]
        else:
            target_stations[station].append(detID)

    for hit in event.Digi_AdvMuFilterHits:
        detID = hit.GetDetectorID()
        station = hit.GetStation()
        if station not in mufilter_stations:
            mufilter_stations[station] = [
                detID,
            ]
        else:
            mufilter_stations[station].append(detID)
    for track_id, track in enumerate(event.MCTrack):
        match track.GetMotherId():
            case -1:  # neutrino can't be part of shower
                nu_energy = track.GetEnergy()
            case 0:  # shower initiating electron
                start_z = track.GetStartZ()
                h_xy.fill(track.GetStartX(), track.GetStartY())
                h_zx.fill(track.GetStartZ(), track.GetStartX())
    energy_dep_mufilter = 0
    for point in event.AdvMuFilterPoint:
        energy_dep_mufilter += point.GetEnergyLoss()
        # print(point.GetEnergyLoss())
    energy_dep_target = 0
    for point in event.AdvTargetPoint:
        energy_dep_target += point.GetEnergyLoss()
    target_strip_dict = {
        f"target_n_hits_station_{i}": len(target_stations[i])
        if i in target_stations
        else 0
        for i in range(100)
    }
    mufilter_strip_dict = {
        f"mufilter_n_hits_station_{i}": len(mufilter_stations[i])
        if i in mufilter_stations
        else 0
        for i in range(20)
    }
    ntuple.append(
        dict(
            {
                "start_z": start_z,
                "nu_energy": nu_energy,
                "energy_dep_target": energy_dep_target,
                "energy_dep_mufilter": energy_dep_mufilter,
                "target_n_hits": len(event.Digi_AdvTargetHits),
                "target_n_stations": len(target_stations),
                "mufilter_n_hits": len(event.Digi_AdvMuFilterHits),
                "mufilter_n_stations": len(mufilter_stations),
            },
            **target_strip_dict,
            **mufilter_strip_dict,
        )
    )

In [None]:
# Turn the list of per-event dicts into a table: one row per record, one column per key.
df = pd.DataFrame(data=ntuple)

In [None]:
# Show the feature table built from `ntuple` as this cell's output.
df

In [None]:
# Persist the feature table next to the notebook.
# NOTE(review): `index` is not passed, so pandas' default applies (the row index is
# presumably written as an extra first column) -- confirm downstream readers expect it.
df.to_csv(path_or_buf="features.csv")

In [None]:
# Draw h_xy on an explicit Axes so labels and text do not depend on pyplot's
# notion of the "current" axes after the histogram has been plotted.
fig, ax = plt.subplots()
h_xy.plot(ax=ax)
ax.set_xlabel(r"$x\;[\mathrm{cm}]$")
ax.set_ylabel(r"$y\;[\mathrm{cm}]$")

# Header labels just above the frame, positioned in axes-fraction coordinates.
label_style = dict(
    fontfamily="sans-serif",
    fontsize=16,
    transform=ax.transAxes,
    usetex=False,
)
ax.text(0.8, 1.02, "AdvSND", fontweight="bold", **label_style)
ax.text(0.0, 1.02, "preliminary", **label_style)

# savefig does not create missing directories, so make sure plots/ exists first.
plot_dir = Path("plots")
plot_dir.mkdir(parents=True, exist_ok=True)
for suffix in ("pdf", "png"):
    fig.savefig(plot_dir / f"h_xy.{suffix}")

In [None]:
# Draw h_zx on an explicit Axes so labels and text do not depend on pyplot's
# notion of the "current" axes after the histogram has been plotted.
fig, ax = plt.subplots()
h_zx.plot(ax=ax)
ax.set_xlabel(r"$z\;[\mathrm{cm}]$")
ax.set_ylabel(r"$x\;[\mathrm{cm}]$")

# Header labels just above the frame, positioned in axes-fraction coordinates.
label_style = dict(
    fontfamily="sans-serif",
    fontsize=16,
    transform=ax.transAxes,
    usetex=False,
)
ax.text(0.8, 1.02, "AdvSND", fontweight="bold", **label_style)
ax.text(0.0, 1.02, "preliminary", **label_style)

# savefig does not create missing directories, so make sure plots/ exists first.
plot_dir = Path("plots")
plot_dir.mkdir(parents=True, exist_ok=True)
for suffix in ("pdf", "png"):
    fig.savefig(plot_dir / f"h_zx.{suffix}")