Skip to content

Commit

Permalink
test channels first
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 24, 2024
1 parent 4442ca8 commit b4ba534
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 16 deletions.
18 changes: 9 additions & 9 deletions nitrain/loaders/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self,
dataset,
images_per_batch,
transforms=None,
channel_axis=-1,
channels_first=False,
shuffle=False,
sampler=None):
"""
Expand All @@ -33,7 +33,7 @@ def __init__(self,

self.dataset = dataset
self.images_per_batch = images_per_batch
self.channel_axis = channel_axis
self.channels_first = channels_first
self.transforms = transforms
self.shuffle = shuffle

Expand All @@ -45,7 +45,7 @@ def copy(self, dataset=None, drop_transforms=False):
new_loader = Loader(
dataset = copy(self.dataset) if dataset is None else dataset,
images_per_batch = self.images_per_batch,
channel_axis = self.channel_axis,
channels_first = self.channels_first,
transforms = self.transforms if not drop_transforms else None,
shuffle = self.shuffle,
sampler = self.sampler
Expand Down Expand Up @@ -103,9 +103,9 @@ def __iter__(self):

for x_batch, y_batch in sampled_batch:

if self.channel_axis is not None:
x_batch = expand_image_dims(x_batch, self.channel_axis)
y_batch = expand_image_dims(y_batch, self.channel_axis)
if self.channels_first is not None:
x_batch = expand_image_dims(x_batch, self.channels_first)
y_batch = expand_image_dims(y_batch, self.channels_first)

x_batch = convert_to_numpy(x_batch)
y_batch = convert_to_numpy(y_batch)
Expand Down Expand Up @@ -155,10 +155,10 @@ def convert_to_numpy(x):
else:
return np.array(x)

def expand_image_dims(x, axis):
mytx = tx.AddChannel(axis)
def expand_image_dims(x, channels_first):
mytx = tx.AddChannel(channels_first)
if isinstance(x, list):
return [expand_image_dims(xx, axis) for xx in x]
return [expand_image_dims(xx, channels_first) for xx in x]
else:
if ants.is_image(x):
return mytx(x) if not x.has_components else x
Expand Down
6 changes: 3 additions & 3 deletions nitrain/transforms/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@
]

class AddChannel(BaseTransform):
def __init__(self, axis=-1):
def __init__(self, channels_first=False):
"""
import ants
from nitrain import transforms as tx
img = ants.image_read(ants.get_data('r16'))
mytx = tx.AddChannel()
img2 = mytx(img)
"""
self.axis = axis
self.channels_first = channels_first

def __call__(self, *images):
images = [ants.merge_channels([image], axis=self.axis) for image in images]
images = [ants.merge_channels([image], channels_first=self.channels_first) for image in images]
return images if len(images) > 1 else images[0]

class Reorient(BaseTransform):
Expand Down
36 changes: 32 additions & 4 deletions tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,28 @@ def test_2d(self):
self.assertEqual(xb.shape, (4, 256, 256, 1))
self.assertEqual(yb.shape, (4,))

loader = nt.Loader(dataset_2d, images_per_batch=4, channel_axis=None)
loader = nt.Loader(dataset_2d, images_per_batch=4, channels_first=None)
xb, yb = next(iter(loader))
self.assertEqual(xb.shape, (4, 256, 256))
self.assertEqual(yb.shape, (4,))

# test repr
rep = loader.__repr__()


def test_2d_channels_first(self):
loader = nt.Loader(self.dataset_2d, images_per_batch=4,
channels_first=True)

x_batch, y_batch = next(iter(loader))
self.assertTrue(x_batch.shape == (4, 1, 256, 256))

loader2 = loader.to_keras()
x_batch, y_batch = next(iter(loader2))
self.assertTrue(x_batch.shape == (4, 1, 256, 256))

gen = record_generator(loader)
xb,yb = next(iter(gen))

def test_copy(self):
import ants
import nitrain as nt
Expand Down Expand Up @@ -112,10 +126,24 @@ def test_3d(self):

x_batch, y_batch = next(iter(loader))
self.assertEqual(x_batch.shape, (4, 182, 218, 182, 1))

def test_3d_channels_first(self):
loader = nt.Loader(self.dataset_3d, images_per_batch=4,
channels_first=True)

x_batch, y_batch = next(iter(loader))
self.assertTrue(x_batch.shape == (4, 1, 182, 218, 182))

loader2 = loader.to_keras()
x_batch, y_batch = next(iter(loader2))
self.assertTrue(x_batch.shape == (4, 1, 182, 218, 182))

gen = record_generator(loader)
xb,yb = next(iter(gen))

def test_3d_no_expand(self):
loader = nt.Loader(self.dataset_3d, images_per_batch=4,
channel_axis=None)
channels_first=None)

x_batch, y_batch = next(iter(loader))

Expand Down Expand Up @@ -279,7 +307,7 @@ def test_multiclass_segmentation_no_expand_dims(self):
loader = nt.Loader(data_train,
images_per_batch=4,
shuffle=True,
channel_axis=None,
channels_first=None,
sampler=SliceSampler(batch_size=20, axis=2))

xb, yb = next(iter(loader))
Expand Down

0 comments on commit b4ba534

Please sign in to comment.