From 9bb235a5f23faa859843d66bde433f11af91f617 Mon Sep 17 00:00:00 2001 From: ncullen93 Date: Mon, 8 Apr 2024 10:32:30 +0200 Subject: [PATCH] add shuffle --- nitrain/loaders/dataset_loader.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nitrain/loaders/dataset_loader.py b/nitrain/loaders/dataset_loader.py index c0dbc22..7fda509 100644 --- a/nitrain/loaders/dataset_loader.py +++ b/nitrain/loaders/dataset_loader.py @@ -31,6 +31,7 @@ def __init__(self, self.x_transforms = x_transforms self.y_transforms = y_transforms self.co_transforms = co_transforms + self.shuffle = shuffle if sampler is None: sampler = samplers.BaseSampler(sub_batch_size=batch_size) @@ -101,6 +102,7 @@ def __iter__(self): if self.expand_dims is not None: if isinstance(x_batch[0], list): + print('got multiple inputs') x_batch_return = [] for i in range(len(x_batch[0])): tmp_x_batch = np.array([np.expand_dims(xx[i].numpy(), self.expand_dims) for xx in x_batch])