Skip to content

Commit

Permalink
update loader tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 4, 2024
1 parent 4ed11a6 commit 8b1849e
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 24 deletions.
8 changes: 8 additions & 0 deletions nitrain/transforms/shape.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
from .base import BaseTransform

__all__ = [
'ExpandDims',
'Reorient',
'Rollaxis',
'Repeat',
'Swapaxes'
]

class ExpandDims(BaseTransform):
def __init__(self, axis=-1):
self.axis = axis

def __call__(self, *images):
images = [image.expand_dims(self.axis) for image in images]
return images if len(images) > 1 else images[0]

class Reorient(BaseTransform):
def __init__(self, orientation):
Expand Down
27 changes: 27 additions & 0 deletions tests/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,33 @@ def test_transforms(self):
self.assertEqual(xb[0].shape, (20,48,48,1))
self.assertEqual(xb[1].shape, (20,48,48,1))
self.assertEqual(yb.shape, (20,30,40,1))

def test_multiclass_segmentation_no_expand_dims(self):
base_dir = nt.fetch_data('example-01')

dataset = nt.Dataset(inputs=PatternReader('*/img3d.nii.gz'),
outputs=PatternReader('*/img3d_multiseg.nii.gz'),
transforms={
('inputs','outputs'): tx.Resample((40,40,40)),
'inputs': tx.ExpandDims(),
'outputs': tx.LabelsToChannels()
},
base_dir=base_dir)

x,y = dataset[0]

data_train, data_test = dataset.split(0.8)

loader = nt.Loader(data_train,
images_per_batch=4,
shuffle=True,
expand_dims=False,
sampler=SliceSampler(batch_size=20, axis=2))

xb, yb = next(iter(loader))

self.assertEqual(xb.shape, (20,40,40,1))
self.assertEqual(yb.shape, (20,40,40,2))


if __name__ == '__main__':
Expand Down
50 changes: 26 additions & 24 deletions tests/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_multiclass_segmentation(self):
outputs=PatternReader('*/img3d_multiseg.nii.gz'),
transforms={
('inputs','outputs'): tx.Resample((40,40,40)),
'inputs': tx.ExpandDims(),
'outputs': tx.LabelsToChannels()
},
base_dir=base_dir)
Expand All @@ -76,30 +77,31 @@ def test_multiclass_segmentation(self):
data_train, data_test = dataset.split(0.8)

loader = nt.Loader(data_train,
images_per_batch=4,
shuffle=True,
sampler=SliceSampler(batch_size=20, axis=-1))

arch_fn = nt.fetch_architecture('unet', dim=2)
model = arch_fn(x.shape[:-1]+(1,),
number_of_outputs=2,
number_of_layers=4,
number_of_filters_at_base_layer=16,
mode='classification')

# train
trainer = nt.Trainer(model, task='segmentation')
trainer.fit(loader, epochs=2)

# evaluate on test data
test_loader = loader.copy(data_test)
trainer.evaluate(test_loader)

# inference on test data
predictor = nt.Predictor(model,
task='segmentation',
sampler=SliceSampler(axis=-1))
y_pred = predictor.predict(data_test)
images_per_batch=4,
shuffle=True,
expand_dims=False,
sampler=SliceSampler(batch_size=20, axis=2))

#arch_fn = nt.fetch_architecture('unet', dim=2)
#model = arch_fn(x.shape[:-1]+(1,),
# number_of_outputs=2,
# number_of_layers=4,
# number_of_filters_at_base_layer=16,
# mode='classification')
#
## train
#trainer = nt.Trainer(model, task='segmentation')
#trainer.fit(loader, epochs=2)
#
## evaluate on test data
#test_loader = loader.copy(data_test)
#trainer.evaluate(test_loader)
#
## inference on test data
#predictor = nt.Predictor(model,
# task='segmentation',
# sampler=SliceSampler(axis=-1))
#y_pred = predictor.predict(data_test)

def test_image_regression(self):
pass
Expand Down

0 comments on commit 8b1849e

Please sign in to comment.