Skip to content

Commit

Permalink
update tx
Browse files Browse the repository at this point in the history
  • Loading branch information
ncullen93 committed May 7, 2024
1 parent ed5c8f3 commit bb90d7b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
9 changes: 4 additions & 5 deletions nitrain/transforms/labels.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from .base import BaseTransform

__all__ = [
'ExpandLabels'
'LabelsToChannels'
]

class ExpandLabels(BaseTransform):
class LabelsToChannels(BaseTransform):
"""
Create a channel dimension for each separate value in a
segmentation image.
Expand All @@ -17,10 +17,9 @@ class ExpandLabels(BaseTransform):
It is also possible to keep the original values in the channels
instead of making all values equal to 1.
"""
def __init__(self, keep_values=False, as_channels=True):
def __init__(self, keep_values=False):
self.keep_values = keep_values
self.as_channels = as_channels

def __call__(self, *images):
images = [image.expand_labels(self.keep_values, self.as_channels) for image in images]
images = [image.labels_to_channels(self.keep_values) for image in images]
return images if len(images) > 1 else images[0]
3 changes: 2 additions & 1 deletion tests/test_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def test_multiclass_segmentation(self):
outputs=ImageReader('*/img3d_multiseg.nii.gz'),
transforms={
('inputs','outputs'): tx.Resample((40,40,40)),
'outputs': tx.ExpandLabels()
'outputs': tx.LabelsToChannels()
},
base_dir=base_dir)

Expand All @@ -81,6 +81,7 @@ def test_multiclass_segmentation(self):
data_train, data_test = dataset.split(0.8)

loader = nt.Loader(data_train,
sampler=SliceSampler(batch_size=20, axis=-1),
images_per_batch=4)

xb, yb = next(iter(loader))
Expand Down

0 comments on commit bb90d7b

Please sign in to comment.