
Refactor transforms #213

Merged
merged 235 commits into from May 21, 2020

Conversation

@charleygros charleygros commented May 1, 2020

Following #209, some refactoring in the transforms:

  • im and seg transforms run separately, see comment here.
  • Only use numpy operations (i.e. no PIL)
  • Be robust to both a single sample and a list of samples
  • Split the file into several small files, using Folder imports #212
  • Implement test functions

@charleygros charleygros self-assigned this May 1, 2020

charleygros commented May 1, 2020

To be compatible with both a single sample and a list of samples, I implemented the following decorator, @list_capable:

def list_capable(wrapped):
    @functools.wraps(wrapped)
    def wrapper(self, sample, metadata):
        if isinstance(sample, list):
            list_data, list_metadata = [], []
            for s_cur, m_cur in zip(sample, metadata):
                # Run function for each sample of the list
                data_cur, metadata_cur = wrapped(self, s_cur, m_cur)
                list_data.append(data_cur)
                list_metadata.append(metadata_cur)
            return list_data, list_metadata
        return wrapped(self, sample, metadata)
    return wrapper

The idea is:

  • the transformation function (and its undo) only ever receives a single sample
  • if the input is a list, the decorator iterates through the samples, calls the transformation for each of them, and returns a list of transformed samples
  • the same idea applies to the metadata
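As a minimal illustration (the AddOne transform below is a hypothetical stand-in, not one of the project's transforms):

```python
import functools

def list_capable(wrapped):
    @functools.wraps(wrapped)
    def wrapper(self, sample, metadata):
        if isinstance(sample, list):
            list_data, list_metadata = [], []
            for s_cur, m_cur in zip(sample, metadata):
                # Run function for each sample of the list
                data_cur, metadata_cur = wrapped(self, s_cur, m_cur)
                list_data.append(data_cur)
                list_metadata.append(metadata_cur)
            return list_data, list_metadata
        return wrapped(self, sample, metadata)
    return wrapper

class AddOne:
    """Hypothetical transform, used only to illustrate the decorator."""
    @list_capable
    def __call__(self, sample, metadata):
        return sample + 1, metadata

t = AddOne()
# Single sample: the wrapped function is called directly.
data, meta = t(10, {})                       # -> (11, {})
# List of samples: the decorator iterates and returns lists.
datas, metas = t([1, 2, 3], [{}, {}, {}])    # -> ([2, 3, 4], [{}, {}, {}])
```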

@charleygros
Member Author

For instance, the class HistogramClipping:

Before:

class HistogramClipping(IMEDTransform):

    def __init__(self, min_percentile=5.0, max_percentile=95.0):
        self.min_percentile = min_percentile
        self.max_percentile = max_percentile

    def do_clipping(self, data):
        data = np.copy(data)
        # Ensure that data is a numpy array
        data = np.array(data)
        # Run clipping
        percentile1 = np.percentile(data, self.min_percentile)
        percentile2 = np.percentile(data, self.max_percentile)
        data[data <= percentile1] = percentile1
        data[data >= percentile2] = percentile2
        return data

    def __call__(self, sample):
        input_data = sample['input']

        # TODO: Decorator?
        if isinstance(input_data, list):
            output_data = [self.do_clipping(data) for data in input_data]
        else:
            output_data = self.do_clipping(input_data)

        # Update
        rdict = {'input': output_data}
        sample.update(rdict)
        return sample

After:

class HistogramClipping(IMEDTransform):

    def __init__(self, min_percentile=5.0, max_percentile=95.0):
        self.min_percentile = min_percentile
        self.max_percentile = max_percentile

    @list_capable
    def __call__(self, sample, metadata={}):
        data = np.copy(sample)
        # Run clipping
        percentile1 = np.percentile(sample, self.min_percentile)
        percentile2 = np.percentile(sample, self.max_percentile)
        data[sample <= percentile1] = percentile1
        data[sample >= percentile2] = percentile2
        return data, metadata

Note: it also simplifies the code regarding the labeled data.
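For reference, the clipping logic that the decorated __call__ applies to a single sample can be sketched standalone (dummy data; the seed and array size are arbitrary):

```python
import numpy as np

np.random.seed(0)
sample = 100 * np.random.randn(64, 64)   # stand-in for one modality

min_percentile, max_percentile = 5.0, 95.0
p_low = np.percentile(sample, min_percentile)
p_high = np.percentile(sample, max_percentile)

# Clip every value outside the [p_low, p_high] percentile window
clipped = np.copy(sample)
clipped[sample <= p_low] = p_low
clipped[sample >= p_high] = p_high
```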

@charleygros
Member Author

Another example with RandomShiftIntensity:

Before:

class RandomTensorChannelShift(IMEDTransform):

    def __init__(self, shift_range):
        self.shift_range = shift_range

    @staticmethod
    def get_params(shift_range):
        sampled_value = np.random.uniform(shift_range[0],
                                          shift_range[1])
        return sampled_value

    @staticmethod
    def sample_augment(input_data, params):
        np_input_data = np.array(input_data)
        np_input_data += params
        input_data = Image.fromarray(np_input_data, mode='F')
        return input_data

    def __call__(self, sample):
        input_data = sample['input']
        params = self.get_params(self.shift_range)

        if isinstance(input_data, list):
            ret_input = [self.sample_augment(item, params) for item in input_data]
        else:
            ret_input = self.sample_augment(input_data, params)

        rdict = {'input': ret_input}

        sample.update(rdict)
        return sample

After:

class RandomShiftIntensity(IMEDTransform):

    def __init__(self, shift_range):
        self.shift_range = shift_range

    @list_capable
    def __call__(self, sample, metadata={}):
        # Get random offset
        offset = np.random.uniform(self.shift_range[0], self.shift_range[1])
        # Update metadata
        metadata['offset'] = offset
        # Shift intensity
        data = sample + offset
        return data, metadata

    @list_capable
    def undo_transform(self, sample, metadata={}):
        assert 'offset' in metadata
        # Get offset
        offset = metadata['offset']
        # Subtract offset
        data = sample - offset
        return data, metadata
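Stripped of the class scaffolding, the do/undo round trip above amounts to this minimal numpy sketch:

```python
import numpy as np

rng = np.random.default_rng(42)
sample = rng.normal(size=(32, 32))   # stand-in for one modality
metadata = {}

# do: draw a random offset, record it in the metadata, shift the data
offset = rng.uniform(0.0, 10.0)
metadata['offset'] = offset
shifted = sample + offset

# undo: read the offset back from the metadata and subtract it
restored = shifted - metadata['offset']
```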

@charleygros
Member Author

For testing purposes, I implemented a function to create dummy data with labels:

def create_test_image_2d(width, height, num_modalities, noise_max=10.0, num_objs=1, rad_max=30, num_seg_classes=1):
    """Create test image.

    Create test image and its segmentation with a given number of objects, classes, and maximum radius.

    Args:
        width (int): image width
        height (int): image height
        num_modalities (int): number of modalities
        noise_max (float): noise sampled from the uniform distribution [0, noise_max)
        num_objs (int): number of objects
        rad_max (int): maximum radius of objects
        num_seg_classes (int): number of classes
    Returns:
        list, list: image and segmentation, each a list of num_modalities elements of shape (width, height).

    Adapted from: https://github.com/Project-MONAI/MONAI/blob/master/monai/data/synthetic.py#L17
    """


charleygros commented May 1, 2020

We can now run tests such as:

@pytest.mark.parametrize('im_seg', (create_test_image_2d(100, 100, 1),
                                    create_test_image_2d(100, 100, 3)))
def test_RandomShiftIntensity(im_seg):
    im, _ = im_seg
    # Transform
    transform = RandomShiftIntensity(shift_range=[0., 10.])

    # Apply Do Transform
    metadata_in = [{} for _ in im] if isinstance(im, list) else {}
    do_im, do_metadata = transform(sample=im, metadata=metadata_in)
    # Check result has the same number of modalities
    assert len(do_im) == len(im)
    # Check metadata update
    assert all('offset' in m for m in do_metadata)
    # Check shifting
    for idx, i in enumerate(im):
        assert isclose(np.max(do_im[idx]-i), do_metadata[idx]['offset'], rel_tol=1e-02)

    # Apply Undo Transform
    undo_im, undo_metadata = transform.undo_transform(sample=do_im, metadata=do_metadata)
    # Check result has the same number of modalities
    assert len(undo_im) == len(im)
    # Check undo
    for idx, i in enumerate(im):
        assert np.allclose(undo_im[idx], i, rtol=1e-02)

@@ -302,7 +304,6 @@ def __getitem__(self, item):
class Crop(ImedTransform):
    def __init__(self, size):
        self.size = size if len(size) == 3 else size + [0]
        self.is_2D = True if len(size) == 2 else False
Member Author

The reason I check 2D versus 3D here, instead of via the sample, was to be robust to 3D data with shape[2] == 1 (yes, I know, not common). By checking the crop_size, we are certain that the user wants a 2D transform.
What do you think @andreanne-lemay?

Member

Yes, I wanted to discuss that as well! The way it was before worked fine; I changed it so the transform automatically adapts to 2D or 3D output without having to change the transforms. For instance, in test_orientation.py I can use the same transforms for 2D and 3D. I also found it useful not to have to change multiple parameters when I want to train in 2D or 3D (now I only have to change the unet_3D bool).

Member Author

It is convenient indeed. Ok, we can go with that.

Comment on lines 187 to 190

-    params_do = metadata["resample"]
-    params_undo = [1. / x for x in params_do]
+    original_shape = metadata["data_shape"]
+    current_shape = sample.shape
+    params_undo = [x / y for x, y in zip(original_shape, current_shape)]
Member Author

That works too. Did you have issues with the other version? Curious.

Member

I also wanted to comment on this! I changed it because using the inverse was not precise enough, so I didn't get the same input and output shape after doing and undoing the transforms (i.e. 51 instead of 52). When using the dimensions, I get a zoom of e.g. 2.01343 instead of 2, which always gives me the right output size.
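To make the precision issue concrete, a toy numeric sketch (assuming, for illustration only, that the resampler floors the computed output size; real backends may round differently):

```python
import math

original = 51          # original shape along one axis
zoom = 0.49            # forward resample factor (toy value)

# Toy model: assume the resampler floors the computed output size.
resampled = math.floor(original * zoom)

# Undo via the inverse of the forward factor: floating-point and
# rounding error lose pixels (48 instead of 51).
undo_inverse = math.floor(resampled * (1.0 / zoom))

# Undo via the shape ratio: the zoom (2.125 here) differs from the
# "true" inverse, but it lands exactly on the original shape.
undo_ratio = math.floor(resampled * (original / resampled))
```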

Member Author

Excellent :)


    params_resample = (hfactor, wfactor, dfactor) if not is_2d else (hfactor, wfactor, 1.0)
    # Save params
    metadata['resample'] = params_resample
Member Author

Cool

@@ -20,45 +20,45 @@

def multichannel_capable(wrapped):
    @functools.wraps(wrapped)
-    def wrapper(self, sample, metadata):
+    def wrapper(self, sample, metadata, data_type):
Member Author

Oh, what I meant in Slack was to add data_type in the metadata. Are there advantages to doing it the way you did? It is only used for the resampling.
By adding it to the metadata, we would only need to modify one line in the loader.
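By way of illustration of the metadata approach, a hedged sketch (the function name, key, and order values are all hypothetical, not the project's actual API):

```python
def pick_interpolation_order(metadata):
    """Hypothetical helper: choose the spline order from the metadata.

    Reading the data type from the metadata means the resampling
    transform needs no extra wrapper argument: nearest neighbour
    (order 0) for labels so no new label values are invented, a
    smoother interpolation for images.
    """
    return 0 if metadata.get('data_type') == 'gt' else 2

order_gt = pick_interpolation_order({'data_type': 'gt'})  # 0
order_im = pick_interpolation_order({'data_type': 'im'})  # 2
```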

Member

Changed it :)!

@@ -23,9 +23,9 @@ def multichannel_capable(wrapped):
    def wrapper(self, sample, metadata, data_type):
        if isinstance(sample, list):
            list_data, list_metadata = [], []
            for s_cur, m_cur, d_cur in zip(sample, metadata, data_type):
Member Author

Same comment

@@ -21,6 +18,8 @@
    'uint8': torch.ByteTensor,
}

TRANSFORM_PARAMS = ['resample', 'elastic', 'rotation', 'offset', 'crop_params', 'reverse', 'affine', 'gaussian_noise']
Member Author

TODO charley: resample can be removed.

@andreanne-lemay
Member

On my end, everything runs well; I'm running my usual pipeline to check the performance -> looks good for now. I didn't test or modify the files in dev (we should probably check whether they are all still functional?).

I didn't change 'gt'; I wasn't sure whether we wanted label or seg?

As requested here, I removed the explicit calls in all tests.

@@ -155,7 +155,9 @@ def test_NumpyToTensor(im_seg):
def _test_Resample(im_seg, resample_transform, native_resolution, is_2D=False):
    im, seg = im_seg
-    metadata_ = {'zooms': native_resolution,
-                 'data_shape': im[0].shape if len(im[0].shape) == 3 else list(im[0].shape) + [1]}
+    metadata_ = {'zooms': native_resolution,
+                 'data_shape': im[0].shape if len(im[0].shape) == 3 else list(im[0].shape) + [1],
+                 'data_type': 'im'
Member Author

TODO: charley or.. : make another metadata_ for seg with data_type = 'gt'

@andreanne-lemay
Member

@andreanne-lemay & @olix86: just so that you are aware, test_RandomAffine and test_RandomRotation may "fail" from time to time because of the problem illustrated below, even if the do and undo transforms are technically correct.
For now, we could just (i) reduce the translation/rotation params (e.g. test only rotations of 5°), (ii) add a specific warning (e.g. "the test may have failed because the segmentation lost 80% of its coverage after the transform"), and (iii) rerun the test. If it happens too often and becomes annoying, we can think of an alternative. Does that sound okay?

[Image: IMG_20200521_114729]

In the long run, we might want to change this. If we test a small shift or rotation, will this problem disappear? Or, for these tests, we might want to always use the same image (as opposed to a randomly generated one).

@andreanne-lemay (Member) left a comment

🥇 🥇 🥇

@charleygros charleygros merged commit 5cd654c into master May 21, 2020
@charleygros charleygros deleted the cg/im-seg-roi_transforms branch May 21, 2020 06:46
@olix86 olix86 mentioned this pull request May 26, 2020