Skip to content

Commit

Permalink
ENH: support image-to-image loading
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed Apr 2, 2024
1 parent 37c7113 commit 84e00ea
Show file tree
Hide file tree
Showing 8 changed files with 28 additions and 6 deletions.
4 changes: 4 additions & 0 deletions nitrain/loaders/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,12 @@ def __iter__(self):

if self.expand_dims is not None:
x_batch = np.array([np.expand_dims(xx.numpy(), self.expand_dims) for xx in x_batch])
if 'ANTsImage' in str(type(y[0])):
y_batch = np.array([np.expand_dims(yy.numpy(), self.expand_dims) for yy in y_batch])
else:
x_batch = np.array([xx.numpy() for xx in x_batch])
if 'ANTsImage' in str(type(y[0])):
y_batch = np.array([yy.numpy() for yy in y_batch])

yield x_batch, y_batch

Expand Down
1 change: 0 additions & 1 deletion nitrain/models/fetch_architecture.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@

from inspect import getmembers, isfunction


def fetch_architecture(name, dim=None):
"""
Fetch an architecture function based on its name and input image
Expand Down
7 changes: 4 additions & 3 deletions nitrain/models/fetch_pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

def fetch_pretrained(name, cache_dir=None):
"""
Fetch a pretrained model from ANTsPyNet. Pretrained
models can be used to make predictions (inference) on
your data or as a starting point for fine-tuning to help
Fetch a pretrained model.
Pretrained models can be used to make predictions (inference)
on your data or as a starting point for fine-tuning to help
improve model fitting on your data.
Returns
Expand Down
2 changes: 2 additions & 0 deletions nitrain/models/load.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@


def load(path):
import tensorflow as tf
model = tf.keras.models.load_model(path)
Expand Down
5 changes: 4 additions & 1 deletion nitrain/samplers/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def __next__(self):
if self.shuffle:
indices = random.sample(range(len(y)), len(y))
x = [x[i] for i in indices]
y = y[indices]
if 'ANTsImage' in str(type(y[0])):
y = [y[i] for i in indices]
else:
y = y[indices]

return x, y
else:
Expand Down
1 change: 1 addition & 0 deletions nitrain/samplers/slice_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __repr__(self):


def create_slices(images, values, axis):
# TODO: support image-to-image
slices = []
new_values = []
for image, value in zip(images, values):
Expand Down
2 changes: 1 addition & 1 deletion nitrain/trainers/cloud_trainer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@



class CloudTrainer:
class GoogleCloudTrainer:
"""
Launch a nitrain training job in the cloud using your own
Google Cloud or AWS account. It is recommended to use this
Expand Down
12 changes: 12 additions & 0 deletions tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,18 @@ def test_3d(self):

x_batch, y_batch = next(iter(loader))
self.assertTrue(x_batch.shape == (4, 182, 218, 182, 1))

def test_image_to_image(self):
img = ants.image_read(ants.get_data('r16'))
x = [img for _ in range(10)]
dataset = datasets.MemoryDataset(x, x)
loader = loaders.DatasetLoader(dataset,
batch_size=4)

x_batch, y_batch = next(iter(loader))
self.assertTrue(x_batch.shape == (4, 256, 256, 1))
self.assertTrue(y_batch.shape == (4, 256, 256, 1))


if __name__ == '__main__':
run_tests()

0 comments on commit 84e00ea

Please sign in to comment.