In [None]:
import numpy as np
import matplotlib.pyplot as plt
import h5py
from pathlib import Path

file = "datasets/neutrino/bnb_WithWire_00.h5"


def print_structure(name, obj):
    """Print one line describing an HDF5 object visited by ``visititems``.

    Datasets are reported with their shape and dtype, groups by name only;
    any other object type is ignored. Always returns ``None`` so the
    traversal continues over the whole file.
    """
    if isinstance(obj, h5py.Dataset):
        print(f"[dataset] {name}, shape={obj.shape}, dtype={obj.dtype}")
        return
    if isinstance(obj, h5py.Group):
        print(f"[group]   {name}")


# Walk every group and dataset of the example file and list its layout.
with h5py.File(file, "r") as f:
    f.visititems(print_structure)

In [None]:
import h5py
import numpy as np
from pathlib import Path
from glob import glob

# --- input locations ---------------------------------------------------------
BASE_DIR = "datasets/neutrino"
PAT_INCL = f"{BASE_DIR}/bnb_WithWire_*.h5"   # glob pattern of the inclusive raw files
PAT_NUE = f"{BASE_DIR}/nue_WithWire_*.h5"    # glob pattern of the nue raw files

# --- preprocessing parameters ------------------------------------------------
# Number of consecutive raw time ticks summed into one output time bin.
TIME_DS = 6
# Half-width of the label window drawn around each hit, in units of the hit rms.
NRMS = 2.0
# Minimum number of hits with g4 id >= 0 a plane needs in order to be stored;
# only applied when store_all=False.
SIG_HIT_THRESHOLD = 50

# Wire count per plane; sets the wire dimension of the output datasets.
W_PER_PLANE = {
    0: 2400,
    1: 2400,
    2: 3456,
}
# Raw ticks per wire waveform and the resulting number of downsampled bins.
T_RAW = 6400
T_DS = T_RAW // TIME_DS

# --- output locations (one HDF5 file per plane) --------------------------------
OUTS_INCL = {
    p: f"{BASE_DIR}/data_inclusive_plane{p}.h5"
    for p in (0, 1, 2)
}
OUTS_NUE = {
    p: f"{BASE_DIR}/data_nue_plane{p}.h5"
    for p in (0, 1, 2)
}
# Default for run_batch: delete and recreate existing output files.
OVERWRITE_OUTPUTS = True

def changepoint_index(eid_array):
    """Split a table's event-id column into contiguous per-event row ranges.

    Rows that belong to the same event are expected to be stored
    consecutively; a new range starts wherever any component of the event id
    differs from the previous row.

    Parameters
    ----------
    eid_array : array-like, shape (n_rows, n_id_components)
        Event id of every row of the table.

    Returns
    -------
    starts : np.ndarray of int64
        Index of the first row of each run.
    lengths : np.ndarray of int64
        Number of rows in each run.
    eids : list of tuple
        Event id of each run, taken from its first row.

    An empty table yields two empty arrays and an empty list.
    """
    eid_array = np.asarray(eid_array)
    n_rows = len(eid_array)
    if n_rows == 0:
        # Without this guard `starts` would be [0] and eid_array[0] would raise.
        return np.zeros((0,), dtype=np.int64), np.zeros((0,), dtype=np.int64), []

    # change[i] is True when row i+1 differs from row i in any id component.
    change = np.any(eid_array[1:] != eid_array[:-1], axis=1)
    starts = np.concatenate(([0], np.where(change)[0] + 1)).astype(np.int64)
    # Appending n_rows closes the final run so diff() gives every run length.
    lengths = np.diff(np.concatenate((starts, [n_rows]))).astype(np.int64)
    eids = [tuple(eid_array[s]) for s in starts]
    return starts, lengths, eids

def build_table_index_by_changes(f, table):
    """Index the rows of ``table`` in the open HDF5 file ``f`` by event id.

    Reads the whole ``<table>/event_id`` dataset, splits it into contiguous
    runs with ``changepoint_index`` and returns
    ``(starts, counts, eids, lookup)`` where ``lookup`` maps each event-id
    tuple to ``(first_row, n_rows)`` as plain Python ints.
    """
    event_ids = f[f"{table}/event_id"][:]
    starts, counts, eids = changepoint_index(event_ids)

    lookup = {}
    for event_id, first_row, n_rows in zip(eids, starts, counts):
        lookup[event_id] = (int(first_row), int(n_rows))
    return starts, counts, eids, lookup

def downsample_time_sum(img_WT, factor=TIME_DS, apply_nb_clip=True):
    """Sum a (wire, time) array over groups of ``factor`` consecutive time bins.

    Trailing time bins that do not fill a complete group are dropped. When
    ``apply_nb_clip`` is true, summed values below ``10 * factor / 6`` are set
    to zero and the remaining values are capped at ``100 * factor / 6``.

    Returns an array of shape ``(n_wires, n_time // factor)``.
    """
    leftover = img_WT.shape[1] % factor
    usable = img_WT[:, :-leftover] if leftover else img_WT

    # Fold the time axis into (n_bins, factor) and collapse the last axis.
    grouped = usable.reshape(usable.shape[0], -1, factor)
    out = grouped.sum(axis=2)
    if not apply_nb_clip:
        return out

    adccutoff = 10 * factor / 6
    adcsaturation = 100 * factor / 6
    thresholded = np.where(out < adccutoff, 0.0, out)
    return np.minimum(thresholded, adcsaturation)

def create_plane_out(path, plane, t_bins=T_DS, w_bins=None, compression="gzip", overwrite=False):
    """Create (or reuse) the per-plane output HDF5 file at ``path``.

    If the file exists and ``overwrite`` is false, only missing file
    attributes are filled in and the existing datasets are left untouched.
    Otherwise any existing file is removed and a fresh file is written with
    empty, resizable datasets: ``image``, ``sigmask``, ``bkgmask`` (one
    ``(t_bins, w_bins)`` frame per record) plus ``event_id``,
    ``event_idx_in_file`` and ``source_file``.

    ``w_bins`` defaults to ``W_PER_PLANE[plane]``.
    """
    if w_bins is None:
        w_bins = W_PER_PLANE[plane]

    target = Path(path)
    if target.exists():
        if not overwrite:
            # Reuse the file as-is; just make sure the descriptive attributes exist.
            with h5py.File(path, "a") as g:
                g.attrs.setdefault("plane", int(plane))
                g.attrs.setdefault("time_downsample", int(TIME_DS))
                g.attrs.setdefault("wire_downsample", 1)
                g.attrs.setdefault("nb_clip", True)
            return
        target.unlink()

    with h5py.File(path, "w") as g:
        g.attrs["plane"] = int(plane)
        g.attrs["time_downsample"] = int(TIME_DS)
        g.attrs["wire_downsample"] = 1
        g.attrs["nb_clip"] = True

        # Frame datasets: grow along axis 0, one chunk per stored frame.
        frame = (t_bins, w_bins)
        frame_layout = dict(
            shape=(0,) + frame,
            maxshape=(None,) + frame,
            chunks=(1,) + frame,
            compression=compression,
        )
        for name, dtype in (("image", "float32"), ("sigmask", "uint8"), ("bkgmask", "uint8")):
            g.create_dataset(name, dtype=dtype, **frame_layout)

        # Per-record bookkeeping datasets.
        g.create_dataset(
            "event_id",
            shape=(0, 3), maxshape=(None, 3), chunks=(1024, 3),
            dtype="int32", compression=compression,
        )
        g.create_dataset(
            "event_idx_in_file",
            shape=(0,), maxshape=(None,), chunks=(1024,),
            dtype="int64", compression=compression,
        )
        g.create_dataset(
            "source_file",
            shape=(0,), maxshape=(None,), chunks=(1024,),
            dtype=h5py.string_dtype(encoding="utf-8"), compression=compression,
        )

def append_plane_record(path, img_TW, sigmask_TW, bkgmask_TW, eid, evt_idx, src_file):
    """Append one record to the per-plane output file at ``path``.

    Parameters
    ----------
    path : str or path-like
        Output file that already contains the ``image``, ``sigmask``,
        ``bkgmask``, ``event_id``, ``event_idx_in_file`` and ``source_file``
        datasets.
    img_TW, sigmask_TW, bkgmask_TW : array-like
        Frames to store; each must match the per-record shape of its dataset.
    eid : sequence of int
        Event id, stored as int32 in ``event_id``.
    evt_idx : int
        Stored as int64 in ``event_idx_in_file``.
    src_file : str or path-like
        Stored as a string in ``source_file``.

    Raises
    ------
    ValueError
        If the datasets in the file do not all have the same length, or if an
        array does not match the per-record shape of its dataset. Both checks
        run before anything is resized, so a rejected record leaves the file
        unchanged instead of leaving the datasets with different lengths.
    """
    arrays = (
        ("image", np.asarray(img_TW, dtype=np.float32)),
        ("sigmask", np.asarray(sigmask_TW, dtype=np.uint8)),
        ("bkgmask", np.asarray(bkgmask_TW, dtype=np.uint8)),
        ("event_id", np.asarray(eid, dtype=np.int32)),
    )
    scalars = (
        ("event_idx_in_file", np.int64(evt_idx)),
        ("source_file", str(src_file)),
    )

    with h5py.File(path, "a") as g:
        n = g["image"].shape[0]

        # Validate everything first: resizing one dataset and then failing on
        # the write would leave the datasets misaligned for all later appends.
        for name, _ in arrays + scalars:
            if g[name].shape[0] != n:
                raise ValueError(
                    f"{path}: dataset '{name}' has {g[name].shape[0]} records but 'image' has {n}"
                )
        for name, arr in arrays:
            expected = tuple(g[name].shape[1:])
            if arr.shape != expected:
                raise ValueError(
                    f"{path}: '{name}' record has shape {arr.shape}, expected {expected}"
                )

        for name, value in arrays + scalars:
            dset = g[name]
            dset.resize((n + 1,) + dset.shape[1:])
            dset[n] = value

def dedupe_edep_max_energyfraction(hit_id_e, ef_e, g4_e):
    """Map each hit id to the g4 id of its largest-energy-fraction deposit.

    The three arrays are parallel: entry ``i`` says that ``g4_e[i]``
    contributed fraction ``ef_e[i]`` to hit ``hit_id_e[i]``. When a hit has
    several equally large fractions, the entry that appears first in the
    input is used.

    Returns a ``dict`` of ``int`` hit id -> ``int`` g4 id.
    """
    # Stable descending sort so ties keep their input order.
    by_fraction = np.argsort(-ef_e, kind="mergesort")
    ranked_hits = hit_id_e[by_fraction]
    ranked_g4 = g4_e[by_fraction]

    # After ranking, the first occurrence of a hit id is its best deposit.
    first_positions = np.unique(ranked_hits, return_index=True)[1]

    best = {}
    for pos in first_positions:
        best[int(ranked_hits[pos])] = int(ranked_g4[pos])
    return best

def build_label_notebook_exact(plane_wire, h_wire, h_time, h_rms, h_g4, T_raw=T_RAW, time_ds=TIME_DS, nrms=NRMS):
    """Rasterise hits into a downsampled per-pixel label map for one plane.

    Each hit labels the ticks from ``floor(h_time - nrms * h_rms)`` to
    ``ceil(h_time + nrms * h_rms)`` (inclusive, clipped to ``[0, T_raw - 1]``)
    on its wire: ``+1`` when its g4 id is ``>= 0``, ``-1`` when it is ``< 0``.
    Negative labels are written last, so they win where windows overlap.

    The full-resolution map is then summed over groups of ``time_ds`` ticks
    (dropping an incomplete trailing group) and reduced to its sign, so a bin
    holding equal numbers of ``+1`` and ``-1`` ticks ends up as ``0``.

    Every wire in ``h_wire`` must be present in ``plane_wire``.

    Returns an ``int8`` array of shape ``(T_raw // time_ds, len(plane_wire))``.
    """
    n_wires = len(plane_wire)

    # Row of the label map that belongs to each wire number.
    row_of_wire = {}
    for row, wire in enumerate(plane_wire):
        row_of_wire[int(wire)] = row
    rows = np.fromiter((row_of_wire[int(w)] for w in h_wire), dtype=np.int64, count=h_wire.size)

    # Inclusive tick window of every hit, clipped to the waveform.
    half_width = nrms * h_rms
    last_tick = T_raw - 1
    first = np.clip(np.floor(h_time - half_width).astype(np.int64), 0, last_tick)
    final = np.clip(np.ceil(h_time + half_width).astype(np.int64), 0, last_tick)

    labels = np.zeros((n_wires, T_raw), dtype=np.int8)
    # Two passes: +1 labels first, then -1 labels override them.
    for value, selected in ((1, h_g4 >= 0), (-1, h_g4 < 0)):
        for row, start, stop, flag in zip(rows, first, final, selected):
            if flag:
                labels[int(row), int(start):int(stop) + 1] = value

    remainder = T_raw % time_ds
    if remainder:
        labels = labels[:, :-remainder]
    binned = labels.reshape(n_wires, -1, time_ds).sum(axis=2)
    return np.sign(binned).astype(np.int8).T  # shape (T_ds, W)

def _read_event_hits(f, eid, h_map, e_map):
    """Load one event's hits and the g4 id assigned to each of them.

    Returns ``(hit_lp, hit_lw, hit_time, hit_rms, g4)`` as parallel 1-D
    arrays. Hits without an entry in the edep table get g4 id ``-1``. An
    event with no hits yields five separate empty arrays, and the edep table
    is not read for it.
    """
    if eid not in h_map:
        return (
            np.zeros((0,), dtype=np.int64),
            np.zeros((0,), dtype=np.int64),
            np.zeros((0,), dtype=np.float32),
            np.zeros((0,), dtype=np.float32),
            np.zeros((0,), dtype=np.int32),
        )

    hs, hc = h_map[eid]
    hit_id = f["hit_table/hit_id"][hs:hs+hc, 0]
    hit_lp = f["hit_table/local_plane"][hs:hs+hc, 0]
    hit_lw = f["hit_table/local_wire"][hs:hs+hc, 0].astype(np.int64)
    hit_time = f["hit_table/local_time"][hs:hs+hc, 0]
    hit_rms = f["hit_table/rms"][hs:hs+hc, 0]

    g4_by_hit = {}
    if eid in e_map:
        es, ec = e_map[eid]
        if ec > 0:
            g4_by_hit = dedupe_edep_max_energyfraction(
                f["edep_table/hit_id"][es:es+ec, 0],
                f["edep_table/energy_fraction"][es:es+ec, 0],
                f["edep_table/g4_id"][es:es+ec, 0],
            )
    g4 = np.array([g4_by_hit.get(int(h), -1) for h in hit_id], dtype=np.int32)
    return hit_lp, hit_lw, hit_time, hit_rms, g4


def process_raw_file_wirewise(raw_path, out_paths, sig_hit_threshold=SIG_HIT_THRESHOLD, nrms=NRMS, store_all=False):
    """Convert one raw file into per-plane image/mask records.

    For every event in the file's wire table and every plane 0, 1, 2, the
    plane's wires are sorted by wire number, their ADC waveforms are
    time-downsampled into an image, and signal/background masks are built
    from the event's hits (g4 id ``>= 0`` -> signal, ``< 0`` -> background).
    Each stored record is appended to ``out_paths[plane]``.

    Parameters
    ----------
    raw_path : str
        Input HDF5 file with ``wire_table``, ``hit_table`` and, optionally,
        ``edep_table``.
    out_paths : mapping of int -> path
        Output file per plane; the files must already exist.
    sig_hit_threshold : int
        With ``store_all=False``, a plane is stored only if it has at least
        this many hits with g4 id ``>= 0``.
    nrms : float
        Half-width of each hit's label window in units of the hit rms.
    store_all : bool
        Store every plane of every event; planes with no wires in the event
        are stored as all-zero records.

    Returns
    -------
    bool
        Always ``True``.
    """
    print(f"\n>> Processing: {raw_path}")

    with h5py.File(raw_path, "r") as f:
        w_starts, w_counts, w_eids, _ = build_table_index_by_changes(f, "wire_table")
        _, _, _, h_map = build_table_index_by_changes(f, "hit_table")
        if "edep_table/event_id" in f:
            _, _, _, e_map = build_table_index_by_changes(f, "edep_table")
        else:
            e_map = {}

        kept = {0: 0, 1: 0, 2: 0}
        for evt_idx, eid in enumerate(w_eids):
            ws, wc = int(w_starts[evt_idx]), int(w_counts[evt_idx])
            lp = f["wire_table/local_plane"][ws:ws+wc, 0]
            lw = f["wire_table/local_wire"][ws:ws+wc, 0].astype(np.int64)
            # The ADC block is the bulk of the event; read it only once a
            # plane of this event is actually going to be stored.
            adc = None

            hit_lp, hit_lw, hit_time, hit_rms, g4 = _read_event_hits(f, eid, h_map, e_map)

            for p in (0, 1, 2):
                m_rows = (lp == p)
                if not np.any(m_rows):
                    if store_all:
                        # No wires for this plane: store an all-zero placeholder.
                        Wp = W_PER_PLANE[p]
                        img_TW = np.zeros((T_DS, Wp), dtype=np.float32)
                        sigmask_TW = np.zeros((T_DS, Wp), dtype=np.uint8)
                        bkgmask_TW = np.zeros((T_DS, Wp), dtype=np.uint8)
                        append_plane_record(out_paths[p], img_TW, sigmask_TW, bkgmask_TW, eid=eid, evt_idx=evt_idx, src_file=raw_path)
                        kept[p] += 1
                    continue

                plane_wire = lw[m_rows]
                order = np.argsort(plane_wire)
                plane_wire = plane_wire[order]

                # Hits on this plane whose wire is present in the wire table.
                # With no hits all arrays are empty and so is the mask.
                m_hits = (hit_lp == p) & np.isin(hit_lw, plane_wire)
                h_wire = hit_lw[m_hits]
                h_time = hit_time[m_hits]
                h_rms_ = hit_rms[m_hits]
                h_g4 = g4[m_hits]
                nu_count = int(np.count_nonzero(h_g4 >= 0))

                # Apply the filter before doing any image work.
                if not store_all and nu_count < int(sig_hit_threshold):
                    continue

                if adc is None:
                    adc = f["wire_table/adc"][ws:ws+wc, :]
                plane_adc = adc[m_rows][order, :]
                img_TW = downsample_time_sum(plane_adc, factor=TIME_DS, apply_nb_clip=True).T

                label_TW = build_label_notebook_exact(plane_wire, h_wire, h_time, h_rms_, h_g4, T_raw=T_RAW, time_ds=TIME_DS, nrms=nrms)
                sigmask_TW = (label_TW > 0).astype(np.uint8)
                bkgmask_TW = (label_TW < 0).astype(np.uint8)

                append_plane_record(out_paths[p], img_TW, sigmask_TW, bkgmask_TW, eid=eid, evt_idx=evt_idx, src_file=raw_path)
                kept[p] += 1

        print(f"Kept per plane: {kept}")
    return True

def run_batch(which="both", overwrite_outputs=OVERWRITE_OUTPUTS, store_all=False, sig_hit_threshold=SIG_HIT_THRESHOLD):
    """Prepare the per-plane output files and process every matching raw file.

    Parameters
    ----------
    which : {"inclusive", "nue", "both"}
        Which sample(s) to process. Inclusive files are matched by
        ``PAT_INCL`` and written to ``OUTS_INCL``; nue files are matched by
        ``PAT_NUE`` and written to ``OUTS_NUE``.
    overwrite_outputs : bool
        Passed to ``create_plane_out``: when true, existing output files of
        the selected sample(s) are deleted and recreated before processing.
    store_all : bool
        Passed to ``process_raw_file_wirewise``.
    sig_hit_threshold : int
        Passed to ``process_raw_file_wirewise``.

    Raises
    ------
    ValueError
        If ``which`` is not one of the accepted values (previously such a
        call silently processed nothing and still reported "all done").
    """
    if which not in ("inclusive", "nue", "both"):
        raise ValueError(f"which must be 'inclusive', 'nue' or 'both', got {which!r}")

    do_incl = which in ("inclusive", "both")
    do_nue = which in ("nue", "both")

    # Create/reset the outputs of every selected sample before reading inputs.
    for enabled, outs in ((do_incl, OUTS_INCL), (do_nue, OUTS_NUE)):
        if enabled:
            for p in (0, 1, 2):
                create_plane_out(outs[p], plane=p, overwrite=overwrite_outputs)

    incl_files = sorted(glob(PAT_INCL)) if do_incl else []
    nue_files = sorted(glob(PAT_NUE)) if do_nue else []
    print(f"there are {len(incl_files)} inclusive files and {len(nue_files)} nue files (which='{which}').")

    for raw_files, outs in ((incl_files, OUTS_INCL), (nue_files, OUTS_NUE)):
        for raw_path in raw_files:
            process_raw_file_wirewise(raw_path, outs, sig_hit_threshold=sig_hit_threshold, nrms=NRMS, store_all=store_all)

    print("\nall done")
    if do_incl:
        print("inclusive outputs:", OUTS_INCL)
    if do_nue:
        print("nue outputs      :", OUTS_NUE)

In [None]:
# Rebuild the three inclusive per-plane files from scratch: overwrite_outputs=True
# deletes any existing output files first, and store_all=True stores every
# event/plane without applying the neutrino-hit threshold.
# Filtered alternative (keep only planes with >= 50 hits whose g4 id is >= 0):
#run_batch(which="inclusive", overwrite_outputs=True, store_all=False, sig_hit_threshold=50)
run_batch(which="inclusive", overwrite_outputs=True, store_all=True)

In [None]:
# Rebuild the three nue per-plane files from scratch (existing output files are
# deleted), storing every event/plane without the neutrino-hit threshold.
run_batch(which="nue", overwrite_outputs=True, store_all=True)

In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt

def plot_plane_sample(plane_file, idx=0, p='raw', cmap='viridis',
                      jet_vmin=0.0, jet_vmax=100.0,
                      point_size=6, alpha=0.6,
                      figsize=(15,6), show_colorbar=True):
    """Plot one stored record of a per-plane output file.

    Parameters
    ----------
    plane_file : str
        Per-plane HDF5 file containing ``image``, ``sigmask``, ``bkgmask``,
        ``event_id``, ``source_file`` and ``event_idx_in_file``.
    idx : int
        Record index, ``0 <= idx < number of records``.
    p : {'raw', 'sig', 'bkg', 'sigbkg'}
        ``'raw'`` draws the image; ``'sig'`` / ``'bkg'`` scatter the non-zero
        pixels of the signal / background mask; ``'sigbkg'`` scatters both.
    cmap, jet_vmin, jet_vmax
        Colormap and colour limits of the ``'raw'`` image.
    point_size, alpha
        Marker size and opacity of the mask scatter plots.
    figsize : tuple
        Figure size passed to ``plt.subplots``.
    show_colorbar : bool
        Accepted for backward compatibility; no colorbar is drawn at present.

    Raises
    ------
    ValueError
        If ``p`` is not one of the supported modes (previously this produced
        an empty plot without any message).
    IndexError
        If the file holds no records or ``idx`` is out of range.
    """
    modes = ('raw', 'sig', 'bkg', 'sigbkg')
    if p not in modes:
        raise ValueError(f"p must be one of {modes}, got {p!r}")

    with h5py.File(plane_file, "r") as g:
        N = g["image"].shape[0]
        if N == 0:
            raise IndexError(f"{plane_file} contains no records")
        if not (0 <= idx < N):
            raise IndexError(f"idx out of [0,{N-1}]")

        img = g["image"][idx]
        sigm = g["sigmask"][idx]
        bkgm = g["bkgmask"][idx]
        eid = tuple(g["event_id"][idx])
        src = g["source_file"][idx]

        # Variable-length strings may come back as bytes.
        if isinstance(src, (bytes, bytearray)): src = src.decode("utf-8", errors="ignore")

        evt_ord = int(g["event_idx_in_file"][idx])
        plane = int(g.attrs["plane"])
        time_ds = int(g.attrs["time_downsample"])

    T, W = img.shape
    print(f"Loaded plane={plane}, idx={idx} | eid={eid} | src={src} | event_idx_in_file={evt_ord}")
    print(f"  labeled pixels: sig={int(sigm.sum())}, bkg={int(bkgm.sum())}")

    fig, ax = plt.subplots(figsize=figsize)

    if p == 'raw':
        ax.imshow(img, origin='lower', aspect='auto', cmap=cmap, vmin=jet_vmin, vmax=jet_vmax)
    else:
        # keep axes identical to raw
        ax.set_xlim(0, W-1)
        ax.set_ylim(0, T-1)
        if p in ('sig', 'sigbkg'):
            # Rows of the stored frames are time bins, columns are wires.
            y_sig, x_sig = np.where(sigm > 0)
            if x_sig.size:
                ax.scatter(x_sig, y_sig, s=point_size, marker='.', linewidths=0, alpha=alpha, color='tab:red', label='signal')
        if p in ('bkg', 'sigbkg'):
            y_bkg, x_bkg = np.where(bkgm > 0)
            if x_bkg.size:
                ax.scatter(x_bkg, y_bkg, s=point_size, marker='.', linewidths=0, alpha=alpha, color='tab:blue', label='background')
        # Only add a legend when something was actually drawn.
        if ax.get_legend_handles_labels()[0]:
            ax.legend(loc='upper right')

    ax.set_title(f"plane={plane} eid={eid} (time downsample x{time_ds})")
    ax.set_xlabel("wire")
    ax.set_ylabel("time (downsampled ticks)")
    plt.tight_layout(); plt.show()

In [None]:
# Single-sample views, one per display mode:
#   plot_plane_sample("datasets/neutrino/data_inclusive_plane2.h5", idx=0, p='raw', jet_vmin=0, jet_vmax=100, cmap='gist_ncar')
#   plot_plane_sample("datasets/neutrino/data_inclusive_plane2.h5", idx=0, p='sig', point_size=3)
#   plot_plane_sample("datasets/neutrino/data_inclusive_plane2.h5", idx=0, p='bkg', point_size=3)
#   plot_plane_sample("datasets/neutrino/data_inclusive_plane2.h5", idx=0, p='sigbkg', point_size=3)

# Signal + background mask overlay for records 20..29 of the plane-2 inclusive file.
for sample_idx in range(20, 30):
    plot_plane_sample(
        "datasets/neutrino/data_inclusive_plane2.h5",
        idx=sample_idx,
        p='sigbkg',
        point_size=3,
    )