In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from loader.dataset import DUNEImageDataset
from loader.splits import train_val_split

## Test custom dataset class

In [None]:
# Dataset root directory. NOTE: this is an absolute, site-specific path; edit it
# to match the filesystem the notebook is being run on.
rootdir = "/nfs/data/1/rrazakami/work/data_cvn/data/dune/2023_trainings/latest/dunevd"

# Class tokens passed to the dataset as `class_names`; adapt if different.
label_tokens = ["numu", "nue", "nutau", "NC"]

# Gather the constructor arguments in one place so they are easy to tweak.
dataset_kwargs = dict(
    rootdir=rootdir,
    class_names=label_tokens,
    view_index=2,    # e.g. Z plane
    use_cache=True,
)
ds = DUNEImageDataset(**dataset_kwargs)

print(f"Dataset size: {len(ds)}")

In [None]:
# Index of the sample to inspect; try a few different indices later.
# The requested index is clamped to the last valid one so this cell still runs
# on datasets with fewer than 2,000,001 entries (e.g. a reduced test sample).
idx = min(2_000_000, len(ds) - 1)
x, y = ds[idx]

# Convert the label to a plain Python int before using it as a lookup key:
# a 0-d tensor label would not match an int key if `idx_to_class` is a dict.
label = int(y)
# Fall back to the integer label when the dataset exposes no name mapping.
label_name = ds.idx_to_class[label] if hasattr(ds, "idx_to_class") else label

print("Index:", idx)
print("Label (int):", label)
print("Label (name):", label_name)
print("Image shape:", x.shape)
print("Image dtype:", x.dtype)
print("Min / Max:", x.min().item(), x.max().item())


In [None]:
# Remove channel dim for plotting
# (assumes `x` is a torch tensor with a leading channel axis of size 1 -- TODO confirm).
img = x.squeeze(0).cpu().numpy()

# Resolve the class name for the title. Use a plain int as the lookup key and
# fall back to the integer label when the dataset has no `idx_to_class` mapping,
# so the plot does not fail with an AttributeError in that case.
label = int(y)
class_name = ds.idx_to_class[label] if hasattr(ds, "idx_to_class") else label

# Explicit figure/axes interface instead of the implicit pyplot state machine.
fig, ax = plt.subplots(figsize=(6, 6))
# The image is transposed so its first axis runs horizontally
# (presumably the wire axis, matching the x label -- verify against the loader).
im = ax.imshow(img.T, interpolation="none", cmap="twilight")
fig.colorbar(im, ax=ax, label="Normalized ADC")
ax.set_title(f"Sample {idx} | class = {class_name}")
ax.set_xlabel("Wire")
ax.set_ylabel("Time tick")
fig.tight_layout()
plt.show()


## Test training/validation split

In [None]:
# Split the full dataset, asking for 20% of it as validation data.
# The call returns the two subset datasets followed by the index collections
# used to build them; the fixed seed is presumably what makes the split
# repeatable between runs (confirm in loader.splits).
train_ds, val_ds, train_idx, val_idx = train_val_split(
    ds,
    val_fraction=0.2,
    seed=42,
    use_cache=True,
)

In [None]:
# Sanity-check the split sizes: the two subsets together should account for
# every sample, and the second column shows each subset's share of the total.
n_total = len(ds)
n_train = len(train_ds)
n_val = len(val_ds)

print(n_total, n_train + n_val)
print(n_train, n_train / n_total)
print(n_val, n_val / n_total)

In [None]:
num_classes = len(ds.class_names)

# Per-class sample counts for the full dataset and for each side of the split.
# `minlength` keeps all three count vectors the same length even when a class
# is absent from one of the subsets.
full_counts, train_counts, val_counts = (
    torch.bincount(subset_labels, minlength=num_classes)
    for subset_labels in (ds.labels, ds.labels[train_idx], ds.labels[val_idx])
)

In [None]:
# For each subset, show the raw per-class counts followed by the class
# fractions, so the class balance of train/val can be compared with the full set.
for tag, counts in (
    ("Full:", full_counts),
    ("Train:", train_counts),
    ("Val:", val_counts),
):
    print(tag, counts, counts.float() / counts.sum())