Refactor transforms #213
Conversation
To be compatible with both a single "sample" and a "list of samples", I implemented the following decorator:

```python
def list_capable(wrapped):
    @functools.wraps(wrapped)
    def wrapper(self, sample, metadata):
        if isinstance(sample, list):
            list_data, list_metadata = [], []
            for s_cur, m_cur in zip(sample, metadata):
                # Run the wrapped function on each sample of the list
                data_cur, metadata_cur = wrapped(self, s_cur, m_cur)
                list_data.append(data_cur)
                list_metadata.append(metadata_cur)
            return list_data, list_metadata
        return wrapped(self, sample, metadata)
    return wrapper
```

The idea: if the input is a list, the decorated transform is applied to each (sample, metadata) pair and the results are collected into lists; otherwise it is applied directly.
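As a sanity check of the decorator's behavior, here is a minimal standalone sketch (the `AddOne` transform is hypothetical, for illustration only):

```python
import functools

import numpy as np


def list_capable(wrapped):
    @functools.wraps(wrapped)
    def wrapper(self, sample, metadata):
        if isinstance(sample, list):
            list_data, list_metadata = [], []
            for s_cur, m_cur in zip(sample, metadata):
                data_cur, metadata_cur = wrapped(self, s_cur, m_cur)
                list_data.append(data_cur)
                list_metadata.append(metadata_cur)
            return list_data, list_metadata
        return wrapped(self, sample, metadata)
    return wrapper


class AddOne:
    """Hypothetical transform: adds 1 to every voxel."""
    @list_capable
    def __call__(self, sample, metadata):
        return sample + 1, metadata


t = AddOne()
# Single sample: the wrapped function is called directly
out, meta = t(np.zeros((2, 2)), {})
# List of samples (e.g. one array per modality): applied element-wise
outs, metas = t([np.zeros((2, 2)), np.ones((2, 2))], [{}, {}])
```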
For instance, the class `HistogramClipping`.

Before:

```python
class HistogramClipping(IMEDTransform):
    def __init__(self, min_percentile=5.0, max_percentile=95.0):
        self.min_percentile = min_percentile
        self.max_percentile = max_percentile

    def do_clipping(self, data):
        data = np.copy(data)
        # Ensure that data is a numpy array
        data = np.array(data)
        # Run clipping
        percentile1 = np.percentile(data, self.min_percentile)
        percentile2 = np.percentile(data, self.max_percentile)
        data[data <= percentile1] = percentile1
        data[data >= percentile2] = percentile2
        return data

    def __call__(self, sample):
        input_data = sample['input']
        # TODO: Decorator?
        if isinstance(input_data, list):
            output_data = [self.do_clipping(data) for data in input_data]
        else:
            output_data = self.do_clipping(input_data)
        # Update
        rdict = {'input': output_data}
        sample.update(rdict)
        return sample
```

After:

```python
class HistogramClipping(IMEDTransform):
    def __init__(self, min_percentile=5.0, max_percentile=95.0):
        self.min_percentile = min_percentile
        self.max_percentile = max_percentile

    @list_capable
    def __call__(self, sample, metadata={}):
        data = np.copy(sample)
        # Run clipping
        percentile1 = np.percentile(sample, self.min_percentile)
        percentile2 = np.percentile(sample, self.max_percentile)
        data[sample <= percentile1] = percentile1
        data[sample >= percentile2] = percentile2
        return data, metadata
```

Note: it also simplifies the code regarding the labeled data.
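As a rough illustration of what the clipping does (a standalone numpy sketch, independent of the class above): values below the 5th percentile and above the 95th percentile are clamped to those percentiles:

```python
import numpy as np

rng = np.random.default_rng(0)
data = rng.normal(size=(100, 100))

min_percentile, max_percentile = 5.0, 95.0
p1 = np.percentile(data, min_percentile)
p2 = np.percentile(data, max_percentile)

clipped = np.copy(data)
clipped[data <= p1] = p1
clipped[data >= p2] = p2
# The result stays within [p1, p2]; roughly 10% of voxels were clamped
```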
Another example, with `RandomTensorChannelShift`.

Before:

```python
class RandomTensorChannelShift(IMEDTransform):
    def __init__(self, shift_range):
        self.shift_range = shift_range

    @staticmethod
    def get_params(shift_range):
        sampled_value = np.random.uniform(shift_range[0],
                                          shift_range[1])
        return sampled_value

    @staticmethod
    def sample_augment(input_data, params):
        np_input_data = np.array(input_data)
        np_input_data += params
        input_data = Image.fromarray(np_input_data, mode='F')
        return input_data

    def __call__(self, sample):
        input_data = sample['input']
        params = self.get_params(self.shift_range)
        if isinstance(input_data, list):
            ret_input = [self.sample_augment(item, params) for item in input_data]
        else:
            ret_input = self.sample_augment(input_data, params)
        rdict = {'input': ret_input}
        sample.update(rdict)
        return sample
```

After:

```python
class RandomShiftIntensity(IMEDTransform):
    def __init__(self, shift_range):
        self.shift_range = shift_range

    @list_capable
    def __call__(self, sample, metadata={}):
        # Get random offset
        offset = np.random.uniform(self.shift_range[0], self.shift_range[1])
        # Update metadata
        metadata['offset'] = offset
        # Shift intensity
        data = sample + offset
        return data, metadata

    @list_capable
    def undo_transform(self, sample, metadata={}):
        assert 'offset' in metadata
        # Get offset
        offset = metadata['offset']
        # Subtract offset
        data = sample - offset
        return data, metadata
```
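The do/undo pair can be exercised end-to-end. This standalone sketch (plain classes, no `IMEDTransform` base, decorator inlined in compact form) checks that `undo_transform` recovers the input exactly:

```python
import functools

import numpy as np


def list_capable(wrapped):
    @functools.wraps(wrapped)
    def wrapper(self, sample, metadata):
        if isinstance(sample, list):
            pairs = [wrapped(self, s, m) for s, m in zip(sample, metadata)]
            return [d for d, _ in pairs], [m for _, m in pairs]
        return wrapped(self, sample, metadata)
    return wrapper


class RandomShiftIntensity:
    def __init__(self, shift_range):
        self.shift_range = shift_range

    @list_capable
    def __call__(self, sample, metadata):
        # Draw a random offset and store it so the shift can be undone
        offset = np.random.uniform(self.shift_range[0], self.shift_range[1])
        metadata['offset'] = offset
        return sample + offset, metadata

    @list_capable
    def undo_transform(self, sample, metadata):
        return sample - metadata['offset'], metadata


im = [np.zeros((4, 4)), np.ones((4, 4))]
t = RandomShiftIntensity([0., 10.])
do_im, do_meta = t(im, [{}, {}])
undo_im, _ = t.undo_transform(do_im, do_meta)
```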
For testing purposes, I implemented a function to create dummy data with labels:

```python
def create_test_image_2d(width, height, num_modalities, noise_max=10.0,
                         num_objs=1, rad_max=30, num_seg_classes=1):
    """Create test image.

    Create a test image and its segmentation with a given number of objects,
    classes, and maximum radius.

    Args:
        width (int): image width
        height (int): image height
        num_modalities (int): number of modalities
        noise_max (float): noise drawn from the uniform distribution [0, noise_max)
        num_objs (int): number of objects
        rad_max (int): maximum radius of objects
        num_seg_classes (int): number of classes

    Returns:
        list, list: image and segmentation, each a list of num_modalities
        elements of shape (width, height).

    Adapted from: https://github.com/Project-MONAI/MONAI/blob/master/monai/data/synthetic.py#L17
    """
```
We can now run tests such as:

```python
@pytest.mark.parametrize('im_seg', (create_test_image_2d(100, 100, 1),
                                    create_test_image_2d(100, 100, 3)))
def test_RandomShiftIntensity(im_seg):
    im, _ = im_seg
    # Transform
    transform = RandomShiftIntensity(shift_range=[0., 10.])

    # Apply Do Transform
    metadata_in = [{} for _ in im] if isinstance(im, list) else {}
    do_im, do_metadata = transform(sample=im, metadata=metadata_in)
    # Check result has the same number of modalities
    assert len(do_im) == len(im)
    # Check metadata update
    assert all('offset' in m for m in do_metadata)
    # Check shifting
    for idx, i in enumerate(im):
        assert isclose(np.max(do_im[idx] - i), do_metadata[idx]['offset'], rel_tol=1e-02)

    # Apply Undo Transform
    undo_im, undo_metadata = transform.undo_transform(sample=do_im, metadata=do_metadata)
    # Check result has the same number of modalities
    assert len(undo_im) == len(im)
    # Check undo
    for idx, i in enumerate(im):
        assert np.allclose(undo_im[idx], i, rtol=1e-02)
```
ivadomed/transforms.py (outdated)

```diff
@@ -302,7 +304,6 @@ def __getitem__(self, item):
 class Crop(ImedTransform):
     def __init__(self, size):
         self.size = size if len(size) == 3 else size + [0]
-        self.is_2D = True if len(size) == 2 else False
```
The reason I check 2D versus 3D here via the crop size, instead of via the sample, was to be robust to 3D data with shape[2] == 1 (yes, I know, not common). By checking the crop size, we are certain that the user wants a 2D transform.
What do you think @andreanne-lemay ?
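To make the trade-off concrete, a small sketch of the two checks (function names are mine, for illustration only):

```python
import numpy as np

# Check via the crop size requested by the user (size-based)
def is_2d_from_size(size):
    return len(size) == 2

# Check via the sample itself, which misfires on genuinely 3D data
# that happens to have shape[2] == 1
def is_2d_from_sample(sample):
    return sample.ndim == 2 or sample.shape[2] == 1

vol = np.zeros((64, 64, 1))  # a 3D volume with a single slice
size_says_2d = is_2d_from_size([64, 64, 1])    # user asked for a 3D crop
sample_says_2d = is_2d_from_sample(vol)        # wrongly treated as 2D
```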
Yes, I wanted to discuss that as well! The way it was before worked fine; the reason I changed it is so that the transform automatically adapts to 2D or 3D output without having to change the transforms. For instance, in test_orientation.py I can use the same transforms for 2D and 3D. I also found it useful not to have to change multiple parameters when switching between 2D and 3D training (now I only have to change the unet_3D bool).
It is convenient indeed. Ok, we can go with that.
ivadomed/transforms.py (outdated)

```diff
-        params_do = metadata["resample"]
-        params_undo = [1. / x for x in params_do]
+        original_shape = metadata["data_shape"]
+        current_shape = sample.shape
+        params_undo = [x / y for x, y in zip(original_shape, current_shape)]
```
Works too. Did you have issues with the other version? Curious.
I also wanted to comment on this! I changed it because using the inverse was not precise enough, so I didn't get the same input and output shape after doing and undoing the transforms (i.e. 51 instead of 52). When using the dimensions, I get a zoom of, for example, 2.01343 instead of 2, which always gives me the right output size.
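The rounding issue can be reproduced with plain arithmetic (the numbers below are illustrative, not taken from the PR): inverting the forward zoom factor can land on the wrong integer size, while the ratio of the stored original shape to the current shape recovers it exactly:

```python
original = 51        # original size along one axis
zoom_do = 0.5        # forward resampling factor
current = round(original * zoom_do)        # 25.5 rounds to 26

# Undo with the inverse factor: overshoots the original size
bad = round(current * (1.0 / zoom_do))     # 26 * 2.0 -> 52, not 51

# Undo with the shape ratio (the new approach): exact by construction
zoom_undo = original / current             # ~1.9615 instead of exactly 2
good = round(current * zoom_undo)          # -> 51
```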
Excellent :)
```python
params_resample = (hfactor, wfactor, dfactor) if not is_2d else (hfactor, wfactor, 1.0)
# Save params
metadata['resample'] = params_resample
```
Cool
ivadomed/transforms.py (outdated)

```diff
@@ -20,45 +20,45 @@
 def multichannel_capable(wrapped):
     @functools.wraps(wrapped)
-    def wrapper(self, sample, metadata):
+    def wrapper(self, sample, metadata, data_type):
```
Oh, what I meant in Slack was to add data_type to the metadata. Are there advantages to doing it the way you did? It is only used for the resampling.
By adding it to the metadata, we would only need to modify one line in the loader.
Changed it :)!
ivadomed/transforms.py (outdated)

```diff
@@ -23,9 +23,9 @@ def multichannel_capable(wrapped):
     def wrapper(self, sample, metadata, data_type):
         if isinstance(sample, list):
             list_data, list_metadata = [], []
             for s_cur, m_cur, d_cur in zip(sample, metadata, data_type):
```
Same comment
ivadomed/loader/utils.py (outdated)

```diff
@@ -21,6 +18,8 @@
     'uint8': torch.ByteTensor,
 }

+TRANSFORM_PARAMS = ['resample', 'elastic', 'rotation', 'offset', 'crop_params', 'reverse', 'affine', 'gaussian_noise']
```
TODO charley: `resample` can be removed.

On my end, everything runs well; I'm running my usual pipeline to check the performance -> looks good for now. I didn't test or modify the files in dev (we should probably check whether they are all still functional?). I didn't change 'gt', wasn't sure if we wanted to. As requested here, I removed the explicit calls in all tests.
```diff
@@ -155,7 +155,9 @@ def test_NumpyToTensor(im_seg):
 def _test_Resample(im_seg, resample_transform, native_resolution, is_2D=False):
     im, seg = im_seg
     metadata_ = {'zooms': native_resolution,
-                 'data_shape': im[0].shape if len(im[0].shape) == 3 else list(im[0].shape) + [1]}
+                 'data_shape': im[0].shape if len(im[0].shape) == 3 else list(im[0].shape) + [1],
+                 'data_type': 'im'
```
TODO (charley or ...): make another metadata_ for seg with data_type = 'gt'.
In the long run, we might want to change this. If we test a small shift or rotation, will this problem disappear? For these we might want to always use the same image (as opposed to a randomly generated one).
🥇 🥇 🥇
Following #209, some refactoring in the transforms:
- `numpy` operations (i.e. no `PIL`)