Skip to content

Commit

Permalink
add LabelsToChannels transform
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 4, 2024
1 parent e09880c commit 4ed11a6
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 6 deletions.
7 changes: 7 additions & 0 deletions nitrain/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,12 @@ def fetch_data(name, path=None, overwrite=False):
img3d = nti.ones((30,40,50))
img3d_seg = nti.zeros_like(img3d).astype('uint8')
img3d_seg[10:20,10:30,10:40] = 1

img3d_multiseg = nti.zeros_like(img3d).astype('uint8')
img3d_multiseg[:20,:20,:20] = 1
img3d_multiseg[20:30,20:30,20:30] = 2
img3d_multiseg[30:,30:,30:]=0

img3d_large = nti.ones((60,80,100))
for i in range(10):
sub_dir = os.path.join(save_dir, f'sub_{i}')
Expand All @@ -73,6 +79,7 @@ def fetch_data(name, path=None, overwrite=False):
nti.save(img3d_large + i, os.path.join(sub_dir, 'img3d_large.nii.gz'))
nti.save(img3d + i + 100, os.path.join(sub_dir, 'img3d_100.nii.gz'))
nti.save(img3d_seg, os.path.join(sub_dir, 'img3d_seg.nii.gz'))
nti.save(img3d_multiseg, os.path.join(sub_dir, 'img3d_multiseg.nii.gz'))
nti.save(img3d + i + 1000, os.path.join(sub_dir, 'img3d_1000.nii.gz'))

# write csv file
Expand Down
5 changes: 1 addition & 4 deletions nitrain/samplers/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,7 @@ def __iter__(self):
if self.shuffle:
indices = random.sample(range(len(self.y)), len(self.y))
self.x = [self.x[i] for i in indices]
if nti.is_image(self.y[0]):
self.y = [self.y[i] for i in indices]
else:
self.y = self.y[indices]
self.y = [self.y[i] for i in indices]

return self

Expand Down
1 change: 1 addition & 0 deletions nitrain/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from .image import *
from .intensity import *
from .labels import *
from .math import *
from .shape import *
from .spatial import *
Expand Down
25 changes: 25 additions & 0 deletions nitrain/transforms/labels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from .base import BaseTransform

__all__ = [
'LabelsToChannels'
]

class LabelsToChannels(BaseTransform):
"""
Create a channel dimension for each separate value in a
segmentation image.
If an image has shape (100,100) and has three unique values (0,1,2),
then this transform will return an image with shape (100,100,3) where
(100,100,0) = 1 if the original value is 0, (100,100,1) = 1 if the original
value is 1, and (100,100,2) = 2 if the original value is 2.
It is also possible to keep the original values in the channels
instead of making all values equal to 1.
"""
def __init__(self, keep_values=False):
self.keep_values = keep_values

def __call__(self, *images):
images = [image.labels_to_channels(self.keep_values) for image in images]
return images if len(images) > 1 else images[0]
24 changes: 24 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,30 @@ def test_Translate(self):
img2d_tx = my_tx(img2d)

img3d_tx = my_tx(img3d)

class TestLabels(unittest.TestCase):

def setUp(self):
pass

def tearDown(self):
pass

def test_LabelsToChannels(self):
img2d = nti.zeros((100,100))
img2d[:20,:] = 1
img2d[20:40,:] = 2
img2d[40:60,:] = 3

img3d = nti.zeros((100,100,100))
img3d[:20,:,:] = 1
img3d[20:40,:,:] = 2
img3d[40:60,:,:] = 3

my_tx = tx.LabelsToChannels()

img2d_tx = my_tx(img2d)
img3d_tx = my_tx(img3d)

class TestErrors(unittest.TestCase):

Expand Down
56 changes: 54 additions & 2 deletions tests/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,14 @@
from nitrain.readers import PatternReader
from nitrain.samplers import SliceSampler

class TestClass_Keras(unittest.TestCase):
class TestClass_OneInput_OneOutput(unittest.TestCase):
def setUp(self):
pass

def tearDown(self):
pass

def test_segmentation(self):
def test_binary_segmentation(self):
base_dir = nt.fetch_data('example-01')

dataset = nt.Dataset(inputs=PatternReader('*/img3d.nii.gz'),
Expand Down Expand Up @@ -59,7 +59,59 @@ def test_segmentation(self):
task='segmentation',
sampler=SliceSampler(axis=-1))
y_pred = predictor.predict(data_test)

def test_multiclass_segmentation(self):
base_dir = nt.fetch_data('example-01')

dataset = nt.Dataset(inputs=PatternReader('*/img3d.nii.gz'),
outputs=PatternReader('*/img3d_multiseg.nii.gz'),
transforms={
('inputs','outputs'): tx.Resample((40,40,40)),
'outputs': tx.LabelsToChannels()
},
base_dir=base_dir)

x,y = dataset[0]

data_train, data_test = dataset.split(0.8)

loader = nt.Loader(data_train,
images_per_batch=4,
shuffle=True,
sampler=SliceSampler(batch_size=20, axis=-1))

arch_fn = nt.fetch_architecture('unet', dim=2)
model = arch_fn(x.shape[:-1]+(1,),
number_of_outputs=2,
number_of_layers=4,
number_of_filters_at_base_layer=16,
mode='classification')

# train
trainer = nt.Trainer(model, task='segmentation')
trainer.fit(loader, epochs=2)

# evaluate on test data
test_loader = loader.copy(data_test)
trainer.evaluate(test_loader)

# inference on test data
predictor = nt.Predictor(model,
task='segmentation',
sampler=SliceSampler(axis=-1))
y_pred = predictor.predict(data_test)

def test_image_regression(self):
pass

def test_scalar_regression(self):
pass

def test_binary_scalar_classification(self):
pass

def test_multiclass_scalar_classification(self):
pass

if __name__ == '__main__':
run_tests()

0 comments on commit 4ed11a6

Please sign in to comment.