Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Enable heatmaps when tiling on the fly #491

Open
wants to merge 36 commits into
base: main
Choose a base branch
from

Conversation

vale-salvatelli
Copy link
Contributor

@vale-salvatelli vale-salvatelli commented Jul 5, 2022

In this PR we enable heatmap outputs when tiling on the fly. Specifically:

  • we changed the MONAI transform used for tiling to be GridPatch. The main advantage is that this transform returns coordinates for each of the patch
  • we update the collate function to handle generic arrays
  • the environment is updated to MONAI dev because the transform and the patches we added to it in MONAI are not released yet

@vale-salvatelli vale-salvatelli changed the title Vsalva/monai transform update ENH: Enable heatmaps when tiling on the fly Jul 5, 2022
@codecov
Copy link

codecov bot commented Jul 5, 2022

Codecov Report

Merging #491 (33c52b1) into main (c73103a) will decrease coverage by 22.82%.
The diff coverage is 5.66%.

Impacted file tree graph

Flag Coverage Δ
hi-ml-cpath 25.69% <5.66%> (-51.71%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

Impacted Files Coverage Δ
hi-ml-cpath/src/health_cpath/utils/output_utils.py 6.62% <0.00%> (-65.05%) ⬇️
hi-ml-cpath/src/health_cpath/utils/wsi_utils.py 15.62% <3.84%> (-84.38%) ⬇️
...-cpath/src/health_cpath/datamodules/base_module.py 7.69% <7.69%> (-85.74%) ⬇️
hi-ml-cpath/src/health_cpath/utils/naming.py 2.53% <20.00%> (-96.12%) ⬇️
...-ml-cpath/src/health_cpath/preprocessing/tiling.py 10.52% <0.00%> (-85.97%) ⬇️
...ml-cpath/src/health_cpath/datasets/base_dataset.py 6.55% <0.00%> (-75.41%) ⬇️
hi-ml-cpath/src/health_cpath/utils/layer_utils.py 19.44% <0.00%> (-75.00%) ⬇️
hi-ml-cpath/src/health_cpath/utils/viz_utils.py 12.28% <0.00%> (-72.81%) ⬇️
... and 63 more

Copy link
Contributor

@kenza-bouzid kenza-bouzid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good, I left some suggestions.

hi-ml-histopathology/src/histopathology/models/deepmil.py Outdated Show resolved Hide resolved
return faulty_slides_idx

def get_slide_patch_coordinates(
self, slide_offset: List, patches_location: List, patch_size: List
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add doc string please?

results.update(metadata_dict)
# each slide can have a different number of patches
for i in range(n_slides):
updated_metadata_dict = self.compute_slide_metadata(batch, i, metadata_dict)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if all this metadata processing should be wrapped into a transform. Data is prefetched anyways it should be more efficient to apply it as a transform that is muti processed by the dataloader. Also from a design perspective, the model shouldn't be handling any metadata processing...

hi-ml-histopathology/src/histopathology/utils/naming.py Outdated Show resolved Hide resolved
@@ -72,6 +72,8 @@ def normalize_dict_for_df(dict_old: Dict[ResultsKey, Any]) -> Dict[str, Any]:
value = value.squeeze(0).cpu().numpy()
if value.ndim == 0:
value = np.full(bag_size, fill_value=value)
if isinstance(value, List) and isinstance(value[0], torch.Tensor):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why do we need to do that here? is that for the coordinates? Maybe turn it into numpy arrays in the first place?

hi-ml-histopathology/src/histopathology/models/deepmil.py Outdated Show resolved Hide resolved
self, slide_offset: List, patches_location: List, patch_size: List
) -> Tuple[List, List, List, List]:
""" computing absolute coordinates for all patches in a slide"""
top, bottom, left, right = self.get_empty_lists(len(patches_location), 4)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might this be cleaner using numpy arrays, perhaps?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dccastro , I like your suggestion above to use the Box class instead. Much cleaner than using a 4-tuple of ints, that's just calling for errors to happen.

return ll

@staticmethod
def get_patch_coordinate(slide_offset: List, patch_location: List, patch_size: List) -> Tuple[int, int, int, int]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What would you think of using our Box class?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 on that. Tuples with 4 elements of the same type are just too easy to mix up


def compute_slide_metadata(self, batch: Dict, index: int, metadata_dict: Dict) -> Dict:
"""compute patch-dependent and patch-invariante metadata for a single slide """
offset = batch[SlideKey.OFFSET.value][index]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Minor] The .value shouldn't be necessary

top[i], bottom[i], left[i], right[i] = self.get_patch_coordinate(slide_offset, location, patch_size)
return top, bottom, left, right

def compute_slide_metadata(self, batch: Dict, index: int, metadata_dict: Dict) -> Dict:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this method need to mutate the input metadata_dict in-place, instead of returning a new dictionary?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. If needs to mutate in place, then the signature and name of the function should reflect that mutation. (update_metadata, -> None)

],
if all(key.value in batch.keys() for key in [SlideKey.OFFSET, SlideKey.PATCH_LOCATION, SlideKey.PATCH_SIZE]):
n_slides = len(batch[SlideKey.SLIDE_ID])
metadata_dict = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Couldn't this be replaced by just adding if key not in results: results[key] = [] in the loop below?

hi-ml-histopathology/src/histopathology/utils/wsi_utils.py Outdated Show resolved Hide resolved
hi-ml-histopathology/src/histopathology/utils/wsi_utils.py Outdated Show resolved Hide resolved
top[i], bottom[i], left[i], right[i] = self.get_patch_coordinate(slide_offset, location, patch_size)
return top, bottom, left, right

def compute_slide_metadata(self, batch: Dict, index: int, metadata_dict: Dict) -> Dict:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1. If needs to mutate in place, then the signature and name of the function should reflect that mutation. (update_metadata, -> None)

ResultsKey.IMAGE_PATH: [
[img_path] * bag_sizes[i] for i, img_path in enumerate(batch[SlideKey.IMAGE_PATH])
],
if all(key.value in batch.keys() for key in [SlideKey.OFFSET, SlideKey.PATCH_LOCATION, SlideKey.PATCH_SIZE]):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a set operation would be easier to understand. set(SlideKey.OFFSET, SlideKey....) <= set(batch.keys()

hi-ml-histopathology/src/histopathology/utils/naming.py Outdated Show resolved Hide resolved
@@ -72,6 +72,8 @@ def normalize_dict_for_df(dict_old: Dict[ResultsKey, Any]) -> Dict[str, Any]:
value = value.squeeze(0).cpu().numpy()
if value.ndim == 0:
value = np.full(bag_size, fill_value=value)
if isinstance(value, List) and isinstance(value[0], torch.Tensor):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function could do with a lot more documentation. A docstring would be great. And each of the if branches should have a clear description which case we are handling here, and where those cases arise.

@@ -33,6 +33,47 @@ dependencies:
- ruamel.yaml==0.16.12
- tensorboard==2.6.0
# Histopathology requirements
- coloredlogs==15.0.1
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes to primary_deps.yml should be made via requirements_run.txt

If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
:param intensity_threshold: a value to keep only the patches whose sum of intensities are less than the
threshold. Defaults to no filtering.
:pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
:pad_mode: refer to NumpyPadMode and PytorchPadMode. If None, no padding will be applied.
:param pad_mode: refer to NumpyPadMode and PytorchPadMode. If `None`, no padding will be applied.

monai transform for tiling on the fly.
:param filter_mode: when `num_patches` is provided, it determines if keep patches with highest values
(`"max"`), lowest values (`"min"`), or in their default order (`None`). Default to None.
:param overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the order of values? (width, height) or the other way around?

max_offset = None if (self.random_offset and stage == ModelKey.TRAIN) else 0

if stage != ModelKey.TRAIN:
grid_transform = RandGridPatchd(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something I don't get here: We are using random tiles when we are NOT training?

bottom = slide_offset[0] + patch_location[0] + patch_size[0]
left = slide_offset[1] + patch_location[1]
right = slide_offset[1] + patch_location[1] + patch_size[1]
return top, bottom, left, right
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tuples of 4 integers are really error prone. Can we use the Box class instead?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(maybe changing it to use top, bottom, left, right in the same washup?)

@staticmethod
def expand_slide_constant_metadata(id: str, path: str, n_patches: int, top: List[int],
bottom: List[int], left: List[int], right: List[int]) -> Tuple[List, List, List]:
"""Duplicate metadata that is patch invariant to match the shape of other arrays"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you expand the documentation a bit here? also "match the shape of other arrays" is not completely correct, it is matching the number given in n_patches (and assumes that many arrays have matching lengths).

@@ -72,6 +72,8 @@ def normalize_dict_for_df(dict_old: Dict[ResultsKey, Any]) -> Dict[str, Any]:
value = value.squeeze(0).cpu().numpy()
if value.ndim == 0:
value = np.full(bag_size, fill_value=value)
if isinstance(value, List) and isinstance(value[0], torch.Tensor):
value = [value[i].item() for i in range(len(value))]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
value = [value[i].item() for i in range(len(value))]
value = [v.item() for v in value]

@@ -39,7 +39,7 @@ def __getitem__(self, index: int) -> List[Dict[SlideKey, Any]]:


@pytest.mark.parametrize("random_n_tiles", [False, True])
def test_image_collate(random_n_tiles: bool) -> None:
def test_array_collate(random_n_tiles: bool) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

there is no test coverage for the new functionality?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants