In [None]:
import os

import numpy as np
import uproot
from tqdm import tqdm

In [None]:
# Fixed seed so that anything drawn from `rng` is reproducible after a
# Restart Kernel -> Run All; change SEED to obtain a different random stream.
SEED = 42
rng = np.random.default_rng(SEED)

In [None]:
# Input ROOT file to split; the train/test output file names are derived
# from this name.
file_name = "df_pi_plus.root"

In [None]:
# Strip only the final extension. Splitting on the first "." would truncate
# paths that contain other dots (e.g. "./data/run.v2.root" would collapse to
# an empty string), producing wrong output file names.
basename = os.path.splitext(file_name)[0]

In [None]:
# Output file names share the input's stem and differ only in the subset tag.
train_file_name, test_file_name = (
    f"{basename}_{subset}.root" for subset in ("train", "test")
)
# Tree name used both for reading the input and for the trees that are written.
treename = "df"

In [None]:
# Open the input and fetch the tree in one step using uproot's
# "file:object" path syntax. The underlying file stays open for as long as
# `tree` is being read from.
tree = uproot.open(f"{file_name}:{treename}")

In [None]:
# Fraction of the input entries, counted from the start of the tree, that is
# written to the training file; the remaining entries go to the test file.
train_frac = 0.8

In [None]:
# Create the two output files. They remain open until they are closed
# explicitly further down.
# NOTE(review): uproot.recreate is expected to overwrite an existing file of
# the same name -- confirm before pointing this at files you want to keep.
train_file = uproot.recreate(train_file_name)
test_file = uproot.recreate(test_file_name)

In [None]:
# Per-entry array shapes used when declaring the fixed-size "X" and
# "X_mufilter" branches of the output trees.
# NOTE(review): presumably these must match the shapes stored in the input
# tree -- confirm against the producer of the input file.
target_dims = (3279, 116)
mufilter_dims = (3279, 68)

In [None]:
# The training and test trees must have exactly the same layout, so the
# schema is declared once and applied to both files instead of being
# copy-pasted (where the two copies could silently drift apart).
branch_types = {
    # Fixed-size 2D float32 arrays, one per entry.
    "X": (">f4", target_dims),
    "X_mufilter": (">f4", mufilter_dims),
    # Scalar float64 branches.
    "start_x": ">f8",
    "start_y": ">f8",
    "start_z": ">f8",
    "nu_energy": ">f8",
    "hadron_energy": ">f8",
    "lepton_energy": ">f8",
    "energy_dep_target": ">f8",
    "energy_dep_mufilter": ">f8",
    # Scalar integer and boolean branches.
    "nu_flavour": ">i8",
    "is_cc": "bool",
}

for output_file in (train_file, test_file):
    output_file.mktree(
        treename,
        branch_types,
        title="Dataframe for CNN studies",
    )

In [None]:
# Stream the input tree in chunks and route entries by their global index:
# entries [0, partition) go to the training file, the rest to the test file.
# The split is sequential (no shuffling), so entry order is preserved.
partition = int(tree.num_entries * train_frac)

# NOTE(review): if the input "X" branch has the declared 3279x116 float32
# shape (~1.5 MB per entry), a 1 MB step presumably yields one entry per
# chunk; consider a larger step_size if throughput matters -- confirm.

# The context manager closes the progress bar even if a write fails part-way.
with tqdm(total=tree.num_entries) as t:
    for batch, report in tree.iterate(step_size="1MB", library="np", report=True):
        batch_size = report.stop - report.start
        if report.stop <= partition:
            # Chunk lies entirely inside the training range.
            train_file[treename].extend(batch)
        elif report.start < partition:
            # Chunk straddles the boundary: cut every branch at the same
            # local offset so the rows stay aligned across branches.
            batch_partition = partition - report.start
            train_file[treename].extend(
                {key: column[:batch_partition] for key, column in batch.items()}
            )
            test_file[treename].extend(
                {key: column[batch_partition:] for key, column in batch.items()}
            )
        else:
            # Chunk lies entirely inside the test range.
            test_file[treename].extend(batch)
        t.update(batch_size)

In [None]:
# Report how many entries landed in each output tree and verify that the
# split neither dropped nor duplicated any entry of the input tree.
n_train = train_file[treename].num_entries
n_test = test_file[treename].num_entries
print(n_train)
print(n_test)
assert n_train + n_test == tree.num_entries, (
    f"split wrote {n_train + n_test} entries but the input has {tree.num_entries}"
)

In [None]:
# Close the output files so that everything written is finalized on disk.
train_file.close()
test_file.close()
# Also release the input file that backs `tree`; it was opened without a
# context manager and would otherwise stay open for the kernel's lifetime.
tree.close()