In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from loader.dataset import DUNEImageDataset
from loader.apa_dataset import APAImageDataset
from loader.splits import train_val_split

## Test DUNECVN dataset class

In [None]:
# Root directory handed to DUNEImageDataset.
rootdir = "/nfs/data/1/rrazakami/work/data_cvn/data/dune/2023_trainings/latest/dunevd"

# Class tokens passed to the dataset as `class_names`; adapt if different.
label_tokens = ["numu", "nue", "nutau", "NC"]

# Collect the dataset options in one place so the configuration reads as a unit.
dataset_kwargs = dict(
    rootdir=rootdir,
    class_names=label_tokens,
    view_index=2,  # e.g. Z plane
    use_cache=True,
)
ds = DUNEImageDataset(**dataset_kwargs)

print("Dataset size:", len(ds))

In [None]:
# Fetch one (image, label) pair and print its basic properties.
idx = 2000000  # try a few different indices later
x, y = ds[idx]

# Fall back to the raw label when the dataset exposes no `idx_to_class` mapping.
class_lookup = getattr(ds, "idx_to_class", None)
label_name = class_lookup[y] if class_lookup is not None else y

print("Index:", idx)
print("Label (int):", y)
print("Label (name):", label_name)
print("Image shape:", x.shape)
print("Image dtype:", x.dtype)

# Value range of the image, reduced to plain Python scalars.
x_min = x.min().item()
x_max = x.max().item()
print("Min / Max:", x_min, x_max)


In [None]:
# Draw the sample fetched above as a 2D image.

# Drop the leading dimension (when it has size 1) so the array is 2D for plotting.
img = x.squeeze(0).cpu().numpy()

# Resolve the class name with the same guard the sample-inspection cell uses,
# so the title still renders when the dataset has no `idx_to_class` mapping.
class_label = ds.idx_to_class[y] if hasattr(ds, "idx_to_class") else y

fig, ax = plt.subplots(figsize=(6, 6))
# The image is transposed before drawing; imshow maps array rows to the
# vertical axis and array columns to the horizontal axis.
im = ax.imshow(img.T, interpolation="none", cmap="twilight")
fig.colorbar(im, ax=ax, label="Normalized ADC")
ax.set_title(f"Sample {idx} | class = {class_label}")
ax.set_xlabel("Wire")
ax.set_ylabel("Time tick")
fig.tight_layout()
plt.show()


In [None]:
from warpconvnet.geometry.types.voxels import Voxels

# Round-trip check: convert the dense sample into a Voxels object and re-plot
# it from the voxel coordinates and features.

# Add a batch dimension in front of the sample.
# NOTE(review): presumably the result is (1, 1, 960, 1500) -- confirm against the dataset.
xbatch = x.unsqueeze(0)
vox = Voxels.from_dense(xbatch)

# --- extract voxel data ---
coords = vox.batch_indexed_coordinates.cpu().numpy()        # expect (N, 3)
feats  = vox.batched_features.batched_tensor.cpu().numpy()  # expect (N, 1)

print("coords shape:", coords.shape)    # expect (N, 3)
print("feats shape:", feats.shape)      # expect (N, 1)

# Drop the batch index (column 0) -> 2D image coords.
# Dedicated names are used on purpose: rebinding `x` / `y` here would clobber
# the sample tensor and label fetched earlier and break re-running the cells
# that still rely on them.
vox_rows = coords[:, 1]
vox_cols = coords[:, 2]

# scatter plot, coloured by the first feature column
fig, ax = plt.subplots(figsize=(6, 6))
points = ax.scatter(
    vox_cols, vox_rows,
    c=feats[:, 0],
    s=4,
    cmap="viridis",
    linewidths=0,
)
ax.invert_yaxis()   # matches image convention (origin at the top)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("Replotted image from Voxels (2D)")
fig.colorbar(points, ax=ax, label="value")
fig.tight_layout()
plt.show()

### Test training/validation split

In [None]:
# Split the dataset into training and validation parts (validation fraction 0.2)
# with a fixed seed; the index collections are kept for the label statistics below.
train_ds, val_ds, train_idx, val_idx = train_val_split(
    ds,
    val_fraction=0.2,
    seed=42,
    use_cache=True,
)

In [None]:
# Print the split sizes next to the full dataset size for a visual sanity check.
n_total, n_train, n_val = len(ds), len(train_ds), len(val_ds)

print(n_total, n_train + n_val)      # total vs. sum of both splits
print(n_train, n_train / n_total)    # training size and fraction
print(n_val, n_val / n_total)        # validation size and fraction

In [None]:
# Per-class sample counts for the full dataset and for each split.
num_classes = len(ds.class_names)
all_labels = ds.labels

# `minlength` keeps one bin per class even when a class has no samples.
full_counts = torch.bincount(all_labels, minlength=num_classes)
train_counts = torch.bincount(all_labels[train_idx], minlength=num_classes)
val_counts = torch.bincount(all_labels[val_idx], minlength=num_classes)

In [None]:
# For each split, print the raw class counts followed by the class fractions.
for split_name, counts in (
    ("Full:", full_counts),
    ("Train:", train_counts),
    ("Val:", val_counts),
):
    print(split_name, counts, counts.float() / counts.sum())

## Test APA custom dataset class

In [None]:
# NOTE: `rootdir` and `ds` are rebound here; from this cell on they refer to
# the APA data rather than the DUNE CVN dataset built earlier.
rootdir = "/nfs/data/1/mvicenzi/apa-test-data/gzip2"

apa_dataset_options = {
    "rootdir": rootdir,
    "apa": 0,
    "view": "W",  # "U", "V", or "W"
    "use_cache": True,
    "cache_dir": "./data",
}
ds = APAImageDataset(**apa_dataset_options)

print("Dataset size:", len(ds))

In [None]:
# Fetch one APA image (this dataset returns the image only) and print its basics.
idx = 100  # try a few different indices
x = ds[idx]

print("Index:", idx)
for caption, attr_name in (("Image shape:", "shape"), ("Image dtype:", "dtype")):
    print(caption, getattr(x, attr_name))

# Value range of the image, reduced to plain Python scalars.
x_min = x.min().item()
x_max = x.max().item()
print("Min / Max:", x_min, x_max)


In [None]:
# Draw the APA sample fetched above as a 2D image.

# Drop the leading dimension (when it has size 1) so the array is 2D for plotting.
img = x.squeeze(0).cpu().numpy()

fig, ax = plt.subplots(figsize=(8, 5))
# NOTE(review): imshow maps array rows to the vertical axis. With the axis
# labels below, the transposed image is expected to be (ticks, wires), i.e.
# `img` itself (wires, ticks) -- confirm against the dataset layout.
im = ax.imshow(
    img.T,
    interpolation="none",
    #aspect="auto",
    cmap="twilight",
)
fig.colorbar(im, ax=ax, label="ADC")
ax.set_title(f"APA {ds.apa} | View {ds.view} | Sample {idx}")
ax.set_xlabel("Wire")
ax.set_ylabel("Time tick")
fig.tight_layout()
plt.show()

In [None]:
# Show the same sample index side by side for the three views of APA 0.
view_names = ["U", "V", "W"]
fig, axes = plt.subplots(1, 3, figsize=(10, 5), sharey=True)

for view, ax in zip(view_names, axes):
    # A separate dataset is built per view.
    # NOTE(review): unlike the dataset built above, no `cache_dir` is passed
    # here -- confirm the default cache location is the intended one.
    ds_view = APAImageDataset(rootdir, apa=0, view=view, use_cache=True)
    img = ds_view[idx].squeeze(0).cpu().numpy()

    im = ax.imshow(img.T, cmap="twilight")
    ax.set_title(f"View {view}")
    ax.set_xlabel("Wire")

# The vertical axis is shared, so only the left-most panel is labelled.
axes[0].set_ylabel("Time tick")
#fig.colorbar(im, ax=axes, label="ADC")
fig.tight_layout()
plt.show()


In [None]:
from warpconvnet.geometry.types.voxels import Voxels

# Re-fetch the sample so this cell can be run on its own after the dataset cell.
idx = 100  # try a few different indices
x = ds[idx]

print("Index:", idx)
print("Image shape:", x.shape)
print("Image dtype:", x.dtype)
x_min = x.min().item()
x_max = x.max().item()
print("Min / Max:", x_min, x_max)

# Add a batch dimension in front of the sample and convert it to Voxels.
# NOTE(review): presumably the dense shape is (1, 1, 960, 1500) -- confirm.
dense = x.unsqueeze(0)
vox = Voxels.from_dense(dense)

print("\nVoxels object")
print("Type:", type(vox))

# Quick introspection: report a few attributes, skipping those that are absent.
for attr in ("batched_coordinates", "batched_features", "offsets"):
    if not hasattr(vox, attr):
        continue
    v = getattr(vox, attr)
    try:
        print(f"{attr}: type={type(v)}, shape={v.shape}, dtype={v.dtype}")
    except Exception:
        # Best effort: fall back to the plain repr when shape/dtype are unavailable.
        print(f"{attr}: {v}")


            

In [None]:
# Re-plot the APA sample from the Voxels object built in the previous cell.

# extract coords + values
coords = vox.batch_indexed_coordinates.cpu().numpy()              # expect (N, 3)
vals   = vox.batched_features.batched_tensor.cpu().numpy()[:, 0]  # first feature column, (N,)

# Drop the batch index (column 0) -> 2D image coords.
# Dedicated names are used on purpose: rebinding `x` here would replace the
# image tensor with a coordinate array and break re-running the cells that
# call `x.squeeze(0).cpu()`.
vox_rows = coords[:, 1]
vox_cols = coords[:, 2]

# scatter plot, coloured by the voxel value
fig, ax = plt.subplots(figsize=(6, 6))
points = ax.scatter(
    vox_cols, vox_rows,
    c=vals,
    s=4,
    cmap="viridis",
    linewidths=0,
)
ax.invert_yaxis()   # matches image convention (origin at the top)
ax.set_xlabel("x")
ax.set_ylabel("y")
ax.set_title("Replotted image from Voxels (2D)")
fig.colorbar(points, ax=ax, label="value")
fig.tight_layout()
plt.show()

