In [None]:
import h5py
import torch
from pathlib import Path
from warpconvnet.geometry.types.voxels import Voxels
import numpy as np

# Name of the dataset, inside each top-level HDF5 group, holding the 2-D frame.
FRAME_NAME = "frame_rebinned_reco"

# Channel-index range [start, end) of each view within the frame. Used both as
# a slice (frame[start:end]) and as a half-open mask on sparse coordinates.
VIEW_RANGES = dict(
    U=(0, 800),
    V=(800, 1600),
    W=(1600, 2650),
)

In [None]:
# Input HDF5 file and output sparse (.pt) file for one run. Both paths are
# derived from the same run tag / file stem, so switching runs only requires
# editing the constants below (the tag used to be copy-pasted four times).
DATA_ROOT = "../apa-test-data"
RUN_TAG = "monte-carlo-011984-000100_279569_164_1_20260116T164131Z"
FILE_STEM = f"{RUN_TAG}_g4-tru-anode0"

inpath = f"{DATA_ROOT}/gzip2/out_{RUN_TAG}/{FILE_STEM}.h5"
outpath = f"{DATA_ROOT}/pickle/out_{RUN_TAG}/{FILE_STEM}_sparse.pt"

In [None]:
def debug_print_voxels(vox, prefix="", full=False):
    """Print a human-readable summary of a sparse ``Voxels`` object.

    Only reads attributes of ``vox``; nothing is modified.

    Args:
        vox: the voxel container to describe. Always read: batch_size,
            num_spatial_dims, num_channels, coordinate_tensor, feature_tensor,
            offsets. With ``full=True`` also read: batch_indexed_coordinates,
            extra_attributes, tensor_stride, voxel_size, ordering.
        prefix: text prepended to the header line only (e.g. ``"[1] "`` to tag
            which group the summary belongs to).
        full: when True, also print batch-indexed coordinates, extra
            attributes and derived properties.
    """
    print(f"{prefix}Voxels summary")
    print("Type:", type(vox))
    print("Batch size:", vox.batch_size)
    print("Num spatial dims:", vox.num_spatial_dims)
    print("Num channels:", vox.num_channels)
    print("Num active voxels:", vox.coordinate_tensor.shape[0])

    # --- core tensors: shape / dtype / device of each ---
    print("Core tensors:")
    for name in ("coordinate_tensor", "feature_tensor", "offsets"):
        tensor = getattr(vox, name)
        print(f"  {name}:", tensor.shape, tensor.dtype, tensor.device)

    if not full:
        return

    # --- batch indexed coords ---
    bic = vox.batch_indexed_coordinates
    print("Batch-indexed coordinates:")
    print("  batch_indexed_coordinates:", bic.shape, bic.dtype)

    # --- extras / metadata ---
    print("\nExtras:")
    extras = vox.extra_attributes
    if not extras:
        print("  (none)")
    else:
        for k, v in extras.items():
            if torch.is_tensor(v):
                # tensors can be large: show only their shape and dtype
                print(f"  {k}: tensor shape={v.shape}, dtype={v.dtype}")
            else:
                print(f"  {k}: {v}")

    # --- derived properties ---
    print("\nDerived properties:")
    print("  tensor_stride:", vox.tensor_stride)
    print("  voxel_size:", vox.voxel_size)
    print("  ordering:", vox.ordering)


In [None]:
# Convert every HDF5 group that contains FRAME_NAME into sparse voxels and
# save all of them together as a single torch file at `outpath`.
# NOTE(review): ch_start/ch_end are not used in this cell; kept only in case
# another cell still refers to them.
ch_start, ch_end = VIEW_RANGES["W"]
h5_path = Path(inpath)

sparse_groups = {}

with h5py.File(h5_path, "r") as f:
    for group, node in f.items():
        # Only Groups support a name-membership test; a top-level Dataset
        # would make `FRAME_NAME in node` iterate its rows instead.
        if not isinstance(node, h5py.Group) or FRAME_NAME not in node:
            continue

        print(f"group: {group}")

        frame = node[FRAME_NAME][()]  # (channels, ticks)

        x = torch.from_numpy(frame).float()
        x = x.unsqueeze(0).unsqueeze(0)  # (B=1, C=1, H, W)

        vox = Voxels.from_dense(x)

        if group == "1":
            debug_print_voxels(vox, prefix=f"[{group}] ")

        # keep plain CPU tensors so the file is independent of the device used here
        sparse_groups[group] = {
            "coords": vox.coordinate_tensor.cpu(),
            "features": vox.feature_tensor.cpu(),
            "offsets": vox.offsets.cpu(),
        }

# create output directory if it doesn't exist
out_dir = Path(outpath).parent
out_dir.mkdir(parents=True, exist_ok=True)
torch.save(sparse_groups, outpath, pickle_protocol=5)

print(f"Saved sparse file: {outpath}")

In [None]:
from warpconvnet.geometry.types.voxels import Voxels
from warpconvnet.geometry.coords.integer import IntCoords
from warpconvnet.geometry.features.cat import CatFeatures

# Reload the sparse file written above and rebuild one Voxels object per group.
# The file only contains str-keyed dicts of tensors, so the restricted
# `weights_only=True` loader is sufficient and avoids running arbitrary pickled
# code from the file.
sparse_groups = torch.load(outpath, map_location="cpu", weights_only=True)
voxels_per_group = {}

for group, data in sparse_groups.items():
    coords = data["coords"]
    feats = data["features"]
    offsets = data["offsets"]

    # coordinates and features share the same per-batch offsets
    vox = Voxels(
        batched_coordinates=IntCoords(coords, offsets=offsets),
        batched_features=CatFeatures(feats, offsets=offsets),
        offsets=offsets,
    )
    voxels_per_group[group] = vox

    if group == "1":
        debug_print_voxels(vox, prefix=f"[reconstructed {group}] ")

In [None]:
def split_view_sparse(vox: Voxels, view: str) -> Voxels:
    """Extract one view from a sparse frame as a new ``Voxels`` object.

    Keeps the voxels whose first coordinate lies in the half-open channel
    range ``VIEW_RANGES[view] = [lo, hi)`` and shifts that coordinate by
    ``-lo`` so the view starts at channel 0. Per-batch offsets are recomputed
    from ``vox.offsets``, so inputs with several batches keep their batch
    boundaries instead of being merged into one.

    Args:
        vox: sparse frame; ``coordinate_tensor[:, 0]`` is assumed to be the
            channel index (TODO confirm for inputs not built by this notebook).
        view: a key of ``VIEW_RANGES``.

    Returns:
        A ``Voxels`` holding only the selected voxels; its offsets are an
        int64 CPU tensor of the same length as ``vox.offsets``.

    Raises:
        KeyError: if ``view`` is not a key of ``VIEW_RANGES``.
    """
    lo, hi = VIEW_RANGES[view]
    all_coords = vox.coordinate_tensor
    channel = all_coords[:, 0]
    keep = (channel >= lo) & (channel < hi)

    # boolean-mask indexing returns a copy, so the in-place rebase below
    # never modifies `vox`
    coords = all_coords[keep]
    feats = vox.feature_tensor[keep]
    coords[:, 0] -= lo  # rebase channel coordinates to 0 for this view

    # kept_before[i] = number of kept voxels among the first i voxels; sampling
    # it at the input batch boundaries yields the output batch boundaries.
    kept_before = torch.zeros(keep.numel() + 1, dtype=torch.int64, device=keep.device)
    kept_before[1:] = torch.cumsum(keep.to(torch.int64), dim=0)
    boundaries = vox.offsets.to(device=keep.device, dtype=torch.int64)
    offsets = kept_before[boundaries].cpu()

    return Voxels(
        batched_coordinates=IntCoords(coords, offsets=offsets),
        batched_features=CatFeatures(feats, offsets=offsets),
        offsets=offsets,
    )

In [None]:
import matplotlib.pyplot as plt

# Compare the dense frame with its sparse reconstruction for one group.
group = "1"
VIEWS = list(VIEW_RANGES)  # ["U", "V", "W"]
TICK_MAX = 1500  # NOTE(review): fixed time-tick display window for sparse plots — confirm
print(f"Group: {group}")

# --- original dense image: read it, then release the HDF5 file before plotting ---
h5_path = Path(inpath)
with h5py.File(h5_path, "r") as f:
    frame = f[group][FRAME_NAME][()]  # shape (channels, ticks)

# --- reconstructed sparse voxels, as rebuilt by the reload cell ---
vox = voxels_per_group[group]

### DIRECT DENSE PLOTTING
fig, axes = plt.subplots(1, len(VIEWS), figsize=(10, 5), sharey=True)
for ax, view in zip(axes, VIEWS):
    lo, hi = VIEW_RANGES[view]
    ax.imshow(frame[lo:hi, :].T, cmap="twilight", origin="lower")
    ax.set_title(f"Dense {view}")
    ax.set_xlabel("Wire")
axes[0].set_ylabel("Time tick")
plt.tight_layout()
plt.show()

### SPARSE PLOTTING ALL VIEWS TOGETHER (TENSORS AS SAVED)
coords_np = vox.coordinate_tensor.cpu().numpy()
vals_np = vox.feature_tensor.cpu().numpy()[:, 0]

fig, ax = plt.subplots(figsize=(15, 5))
ax.scatter(coords_np[:, 0], coords_np[:, 1], c=vals_np, s=4, cmap="viridis", linewidths=0)
# dashed lines at the first channel of every view after the first
for view in VIEWS[1:]:
    ax.axvline(x=VIEW_RANGES[view][0], color="red", linestyle="--")
ax.set_xlim(0, max(hi for _, hi in VIEW_RANGES.values()))
ax.set_ylim(0, TICK_MAX)
ax.set_title("Sparse [ALL VIEWS]")
ax.set_xlabel("Wire")
ax.set_ylabel("Time tick")
plt.tight_layout()
plt.show()

### SPARSE PLOTTING SPLIT VIEWS
fig, axes = plt.subplots(1, len(VIEWS), figsize=(10, 5), sharey=True)
for ax, view in zip(axes, VIEWS):
    lo, hi = VIEW_RANGES[view]
    vox_view = split_view_sparse(vox, view)

    # scatter plot from sparse coords (channel already rebased to 0)
    coords_np = vox_view.coordinate_tensor.cpu().numpy()
    vals_np = vox_view.feature_tensor.cpu().numpy()[:, 0]

    ax.scatter(coords_np[:, 0], coords_np[:, 1], c=vals_np, s=4, cmap="viridis", linewidths=0)
    ax.set_title(f"Sparse {view}")
    ax.set_xlabel("Wire")
    ax.set_xlim(0, hi - lo)  # full width of this view
    ax.set_ylim(0, TICK_MAX)
axes[0].set_ylabel("Time tick")
plt.tight_layout()
plt.show()

