# DataLoader Benchmark

Measure wall-clock time for one full epoch of data loading (no model forward pass).
Uses the pre-decimated h5 files for fast I/O.

In [None]:
import time
import numpy as np
from dataset_ecei_tcn import ECEiTCNDataset, create_loaders

# Directory handed to ECEiTCNDataset as `root`.
ROOT           = '/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt'
# Directory handed to ECEiTCNDataset as `decimated_root`
# (presumably the pre-decimated h5 files mentioned in the header — confirm).
DECIMATED_ROOT = '/global/cfs/cdirs/m5187/proj-share/ECEi_excerpt/dsrpt_decimated'

In [None]:
# Gather every dataset setting in one mapping so the configuration can be
# inspected (or printed) separately from the constructor call.
dataset_config = {
    'root': ROOT,
    'decimated_root': DECIMATED_ROOT,
    'Twarn': 300_000,
    'baseline_length': 40_000,
    'data_step': 10,
    'nsub': 781_250,    # ~781 ms (matches disruptcnn)
    'stride': 481_260,  # overlap by receptive field
    'normalize': True,
}

ds = ECEiTCNDataset(**dataset_config)
ds.summary()

In [None]:
# Loader settings shared by every split.
BATCH_SIZE, NUM_WORKERS = 8, 4

loaders = create_loaders(
    ds,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
)

# One summary line per split: subsequence count and resulting batch count.
for name, loader in loaders.items():
    subseq_count = len(loader.dataset)
    batch_count = len(loader)
    summary_line = (
        f'{name:>5s}: {subseq_count:5d} subseqs, '
        f'{batch_count:4d} batches (bs={BATCH_SIZE})'
    )
    print(summary_line)

In [None]:
def benchmark_epoch(loader, name=''):
    """Iterate one full epoch over ``loader`` and print timing statistics.

    Every item yielded by ``loader`` is unpacked as ``(X, target, weight)``.
    ``X.shape[0]`` is added to the running sample count, and the entries of
    ``target`` equal to 1 and to 0 are tallied for the label-fraction line.
    ``weight`` is unpacked but not otherwise used.

    Args:
        loader: Iterable yielding ``(X, target, weight)`` batches. ``target``
            must support ``(target == v).sum().item()``.
        name: Label shown in the report header.

    Returns:
        Elapsed wall-clock time, in seconds, for the whole iteration.
    """
    n_batches = 0
    n_samples = 0
    total_pos = 0
    total_neg = 0

    t0 = time.perf_counter()
    for X, target, _weight in loader:
        n_batches += 1
        n_samples += X.shape[0]
        total_pos += (target == 1).sum().item()
        total_neg += (target == 0).sum().item()
    elapsed = time.perf_counter() - t0

    print(f'[{name}]  {n_batches} batches, {n_samples} samples')
    print(f'  Elapsed : {elapsed:.2f} s')

    # An empty split (or one yielding only zero-length batches) would make
    # the per-batch / per-sample averages divide by zero, so report that
    # instead of crashing the whole benchmark run.
    if n_samples == 0:
        print('  (no samples loaded; per-batch/per-sample stats skipped)')
        return elapsed

    print(f'  Per batch : {elapsed/n_batches*1e3:.1f} ms')
    print(f'  Per sample: {elapsed/n_samples*1e3:.1f} ms')
    # Guard against a timer reading of exactly zero on very small loaders.
    throughput = n_samples / elapsed if elapsed > 0 else float('inf')
    print(f'  Throughput: {throughput:.1f} samples/s')
    n_labels = total_pos + total_neg
    frac = total_pos / n_labels if n_labels > 0 else 0
    print(f'  Label 1 fraction: {frac:.3f}')
    return elapsed

In [None]:
# Time every split in turn; `results` maps split name -> elapsed seconds.
results = {}
for name in loaders:
    loader = loaders[name]
    elapsed_seconds = benchmark_epoch(loader, name=name)
    results[name] = elapsed_seconds
    print()  # blank line between the per-split reports

In [None]:
import matplotlib.pyplot as plt

# Split names and their epoch times, kept in matching order.
names = list(results.keys())
times = [results[split] for split in names]

fig, ax = plt.subplots(figsize=(6, 3))
bars = ax.barh(names, times, color='steelblue')

# Annotate each bar with its time, just past the bar's right edge and
# vertically centred on the bar.
for bar, t in zip(bars, times):
    label_x = bar.get_width() + 0.1
    label_y = bar.get_y() + bar.get_height() / 2
    ax.text(label_x, label_y, f'{t:.2f} s', va='center', fontweight='bold')

ax.set_xlabel('Epoch time (seconds)')
ax.set_title(
    f'Data-loading epoch time (bs={BATCH_SIZE}, workers={NUM_WORKERS})'
)
plt.tight_layout()
plt.show()

In [None]:
# Sanity check: pull a single batch and show the shape and dtype of each part.
# Prefer the 'train' split; otherwise fall back to the first split available.
if 'train' in loaders:
    split_name = 'train'
else:
    split_name = list(loaders.keys())[0]

first_batch = next(iter(loaders[split_name]))
X, target, weight = first_batch

target_values = target.unique().tolist()
print(f'X      : {X.shape}  {X.dtype}')
print(f'target : {target.shape}  {target.dtype}  unique={target_values}')
weight_values = weight.unique().tolist()
print(f'weight : {weight.shape}  {weight.dtype}  unique={weight_values}')