Skip to content

Commit

Permalink
add shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed Apr 8, 2024
1 parent cad8c93 commit 9bb235a
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nitrain/loaders/dataset_loader.py
Expand Up @@ -31,6 +31,7 @@ def __init__(self,
self.x_transforms = x_transforms
self.y_transforms = y_transforms
self.co_transforms = co_transforms
self.shuffle = shuffle

if sampler is None:
sampler = samplers.BaseSampler(sub_batch_size=batch_size)
Expand Down Expand Up @@ -101,6 +102,7 @@ def __iter__(self):

if self.expand_dims is not None:
if isinstance(x_batch[0], list):
print('got multiple inputs')
x_batch_return = []
for i in range(len(x_batch[0])):
tmp_x_batch = np.array([np.expand_dims(xx[i].numpy(), self.expand_dims) for xx in x_batch])
Expand Down

0 comments on commit 9bb235a

Please sign in to comment.