Skip to content

Commit

Permalink
ENH: support multi-inputs in loaders
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed Apr 2, 2024
1 parent a47899b commit 8823faf
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 2 deletions.
11 changes: 10 additions & 1 deletion nitrain/datasets/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __getitem__(self, idx):
class ImageConfig:
def __init__(self, images):
self.values = images
self.ids = None

def __getitem__(self, idx):
return self.values[idx]
Expand Down Expand Up @@ -149,7 +150,15 @@ def _infer_config(x, base_dir=None):
elif isinstance(x[0], dict):
configs = [_infer_config(config, base_dir=base_dir) for config in x]
return ComposeConfig(configs)
# list that is meant to be an array
# list that is meant to be an array or multiple-images
elif isinstance(x[0], list):
if isinstance(x[0][0], ants.ANTsImage):
configs = []
for i in range(len(x[0])):
configs.append(ImageConfig([xx[i] for xx in x]))
return ComposeConfig(configs)
else:
return ArrayConfig(np.array(x))
else:
return ArrayConfig(np.array(x))

Expand Down
9 changes: 8 additions & 1 deletion nitrain/loaders/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,14 @@ def __iter__(self):
for x_batch, y_batch in sampled_batch:

if self.expand_dims is not None:
x_batch = np.array([np.expand_dims(xx.numpy(), self.expand_dims) for xx in x_batch])
if isinstance(x_batch[0], list):
x_batch_return = []
for i in range(len(x_batch[0])):
tmp_x_batch = np.array([np.expand_dims(xx[i].numpy(), self.expand_dims) for xx in x_batch])
x_batch_return.append(tmp_x_batch)
x_batch = x_batch_return
else:
x_batch = np.array([np.expand_dims(xx.numpy(), self.expand_dims) for xx in x_batch])
if 'ANTsImage' in str(type(y[0])):
y_batch = np.array([np.expand_dims(yy.numpy(), self.expand_dims) for yy in y_batch])
else:
Expand Down
13 changes: 13 additions & 0 deletions tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,19 @@ def test_image_to_image(self):
x_batch, y_batch = next(iter(loader))
self.assertTrue(x_batch.shape == (4, 256, 256, 1))
self.assertTrue(y_batch.shape == (4, 256, 256, 1))

def test_multi_image_to_image(self):
img = ants.image_read(ants.get_data('r16'))
dataset = datasets.MemoryDataset([[img, img] for _ in range(10)],
[img for _ in range(10)])
loader = loaders.DatasetLoader(dataset,
batch_size=4)

x_batch, y_batch = next(iter(loader))
self.assertTrue(len(x_batch) == 2)
self.assertTrue(x_batch[0].shape == (4, 256, 256, 1))
self.assertTrue(x_batch[1].shape == (4, 256, 256, 1))
self.assertTrue(y_batch.shape == (4, 256, 256, 1))

def test_image_to_image_with_slice_sampler(self):
img = ants.image_read(ants.get_data('mni'))
Expand Down

0 comments on commit 8823faf

Please sign in to comment.