ENH: Enable heatmaps when tiling on the fly #491

Open
wants to merge 36 commits into base: main

Changes from 22 commits

Commits (36)
df92a48
transform works locally, coordinates need to be propagated
vale-salvatelli Jun 1, 2022
a7754d3
switch branch
vale-salvatelli Jun 7, 2022
6f29abe
fix merge
vale-salvatelli Jun 7, 2022
ce04dab
fix merge
vale-salvatelli Jun 7, 2022
be25cf6
adding coordinates to the batch
vale-salvatelli Jun 7, 2022
8183fb6
fixing collate
vale-salvatelli Jun 9, 2022
1a93d3a
updating transform parameters
vale-salvatelli Jun 9, 2022
bbaa5fa
method that updates slide results updated, tiles method to be refactore…
vale-salvatelli Jun 13, 2022
5c10885
shape and type results aligned to Tiles
vale-salvatelli Jun 14, 2022
585c893
updating env to pin MONAI dev commit
vale-salvatelli Jun 14, 2022
7970037
heatmaps produced locally when tiling on the fly
vale-salvatelli Jun 15, 2022
792385e
problematic slides now skipped
vale-salvatelli Jun 20, 2022
85bd4ec
bug fix runs locally
vale-salvatelli Jun 21, 2022
5281940
reduce logging
vale-salvatelli Jun 21, 2022
2ea8c94
works locally skipping slide with problematic patches (1 in val on th…
vale-salvatelli Jun 27, 2022
a720a28
issue with coordinates being equal fixed
vale-salvatelli Jul 5, 2022
a996323
cleaning up checks no longer needed
vale-salvatelli Jul 5, 2022
547889f
fixing merge conflicts
vale-salvatelli Jul 5, 2022
7550cae
leftover from merge
vale-salvatelli Jul 5, 2022
e989b4f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 5, 2022
73f0798
flake8 fixes
vale-salvatelli Jul 5, 2022
3418f47
flake8 fixes
vale-salvatelli Jul 5, 2022
73b044f
addressing some PR feedback, thanks @Kenza
vale-salvatelli Jul 5, 2022
58c74ec
more changes
vale-salvatelli Jul 6, 2022
fec8f8e
fixing merge with latest main
vale-salvatelli Jul 6, 2022
87e693f
more feedback implemented
vale-salvatelli Jul 6, 2022
bf16cb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 6, 2022
b656599
fixing current test plus update env
vale-salvatelli Jul 6, 2022
6408a2e
Merge branch 'vsalva/monai_transform_update' of https://github.com/mi…
vale-salvatelli Jul 6, 2022
aa3e6f1
fix flake8
vale-salvatelli Jul 6, 2022
4226ba3
minor fixes
vale-salvatelli Jul 6, 2022
4a97089
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 6, 2022
50a58c0
fixing issue with unexpected batch size due to changes in MONAI dev
vale-salvatelli Jul 6, 2022
6dfa5d6
Merge branch 'vsalva/monai_transform_update' of https://github.com/mi…
vale-salvatelli Jul 6, 2022
600dd06
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 6, 2022
33c52b1
merging main with cpath renaming
vale-salvatelli Aug 1, 2022
31 changes: 28 additions & 3 deletions hi-ml-histopathology/.vscode/launch.json
@@ -41,16 +41,41 @@
"name": "Python: Run SlidesPandaImageNetMIL locally",
"type": "python",
"request": "launch",
"justMyCode": false,
"program": "${workspaceFolder}/../hi-ml/src/health_ml/runner.py",
"args": [
"--model=histopathology.SlidesPandaImageNetMIL",
"--pl_fast_dev_run=10",
"--crossval_count=0",
"--batch_size=2",
"--max_bag_size=4",
"--max_bag_size_inf=4",
"--max_bag_size=5",
"--max_bag_size_inf=5",
"--num_top_slides=2",
"--num_top_tiles=2"
"--num_top_tiles=2",
"--max_num_gpus=1",
],
"console": "integratedTerminal",
},
{
"name": "Python: Long Run SlidesPandaImageNetMIL locally",
"type": "python",
"request": "launch",
"justMyCode": false,
"program": "${workspaceFolder}/../hi-ml/src/health_ml/runner.py",
"args": [
"--model=histopathology.SlidesPandaImageNetMIL",
"--crossval_count=0",
"--batch_size=50",
"--max_bag_size=10",
"--max_bag_size_inf=10",
"--num_top_slides=2",
"--num_top_tiles=2",
"--max_num_gpus=1",
"--max_epochs 2",
"--pl_limit_train_batches 20",
"--pl_limit_test_batches 20",
"--pl_limit_val_batches 20"

],
"console": "integratedTerminal",
},
5 changes: 4 additions & 1 deletion hi-ml-histopathology/environment.yml
@@ -43,7 +43,10 @@ dependencies:
- jupyter-client==7.3.4
- lightning-bolts==0.4.0
- mlflow==1.17.0
- monai==0.8.0
# commit of dev branch containing transform with coordinates
# - git+https://github.com/Project-MONAI/MONAI.git@df4a7d72e1d231b898f88d92cf981721c49ceaeb
# commit of dev branch including latest fixes to GridPatch 22/06
- git+https://github.com/Project-MONAI/MONAI.git@669bddf581201f994d1bcc0cb780854901605d9b
- more-itertools==8.10.0
- mypy==0.961
- mypy-extensions==0.4.3
@@ -139,7 +139,7 @@ def __init__(self, **kwargs: Any) -> None:
# declared in DatasetParams:
local_datasets=[Path("/tmp/datasets/PANDA")],
azure_datasets=["PANDA"],
save_output_slides=False,)
save_output_slides=True,)
default_kwargs.update(kwargs)
super().__init__(**default_kwargs)

70 changes: 35 additions & 35 deletions hi-ml-histopathology/src/histopathology/datamodules/base_module.py
@@ -15,14 +15,14 @@
from health_ml.utils.bag_utils import BagDataset, multibag_collate
from health_ml.utils.common_utils import _create_generator

from histopathology.utils.wsi_utils import image_collate
from histopathology.utils.wsi_utils import array_collate
from histopathology.models.transforms import LoadTilesBatchd
from histopathology.datasets.base_dataset import SlidesDataset, TilesDataset
from histopathology.utils.naming import ModelKey

from monai.transforms.compose import Compose
from monai.transforms.io.dictionary import LoadImaged
from monai.apps.pathology.transforms import TileOnGridd
from monai.transforms import RandGridPatchd
from monai.data.image_reader import WSIReader

_SlidesOrTilesDataset = TypeVar('_SlidesOrTilesDataset', SlidesDataset, TilesDataset)
@@ -250,67 +250,67 @@ def __init__(
pad_full: Optional[bool] = False,
background_val: Optional[int] = 255,
filter_mode: Optional[str] = "min",
overlap: Optional[float] = 0,
intensity_threshold: Optional[float] = 0,
**kwargs: Any,
) -> None:
"""
:param level: the whole slide image level at which the image is extracted, defaults to 1
this param is passed to the LoadImaged monai transform that loads a WSI with cucim backend
:param tile_size: size of the square tile, defaults to 224
this param is passed to TileOnGridd monai transform for tiling on the fly.
:param step: step size to create overlapping tiles, defaults to None (same as tile_size)
Use a step < tile_size to create overlapping tiles; analogously, a step > tile_size will skip some chunks in
the wsi. This param is passed to TileOnGridd monai transform for tiling on the fly.
:param random_offset: randomize position of the grid, instead of starting from the top-left corner,
defaults to True. This param is passed to TileOnGridd monai transform for tiling on the fly.
:param pad_full: pad image to the size evenly divisible by tile_size, defaults to False
This param is passed to TileOnGridd monai transform for tiling on the fly.
:param background_val: the background constant to ignore background tiles (e.g. 255 for white background),
defaults to 255. This param is passed to TileOnGridd monai transform for tiling on the fly.
:param filter_mode: mode must be in ["min", "max", "random"]. If total number of tiles is greater than
tile_count, then sort by intensity sum, and take the smallest (for min), largest (for max) or random (for
random) subset, defaults to "min" (which assumes background is high value). This param is passed to TileOnGridd
random) subset, defaults to "min" (which assumes background is high value). This param is passed to
monai transform for tiling on the fly.
: param overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
: param intensity_threshold: a value to keep only the patches whose sum of intensities are less than the
threshold. Defaults to no filtering.
"""
super().__init__(**kwargs)
self.level = level
self.tile_size = tile_size
self.step = step
self.random_offset = random_offset
self.pad_full = pad_full
self.background_val = background_val
self.filter_mode = filter_mode
# TileOnGridd transform expects None to select all foreground tile so we hardcode max_bag_size and
# Tiling transform expects None to select all foreground tile so we hardcode max_bag_size and
# max_bag_size_inf to None if set to 0
self.max_bag_size = None if self.max_bag_size == 0 else self.max_bag_size # type: ignore
self.max_bag_size_inf = None if self.max_bag_size_inf == 0 else self.max_bag_size_inf # type: ignore
self.overlap = overlap
self.intensity_threshold = intensity_threshold

def _load_dataset(self, slides_dataset: SlidesDataset, stage: ModelKey) -> Dataset:
base_transform = Compose(
[
LoadImaged(
keys=slides_dataset.IMAGE_COLUMN,
reader=WSIReader,
backend="cuCIM",
dtype=np.uint8,
level=self.level,
image_only=True,
),
TileOnGridd(
keys=slides_dataset.IMAGE_COLUMN,
tile_count=self.max_bag_size if stage == ModelKey.TRAIN else self.max_bag_size_inf,
tile_size=self.tile_size,
step=self.step,
random_offset=self.random_offset if stage == ModelKey.TRAIN else False,
pad_full=self.pad_full,
background_val=self.background_val,
filter_mode=self.filter_mode,
return_list_of_dicts=True,
),
]
load_image_transform = LoadImaged(
keys=slides_dataset.IMAGE_COLUMN,
reader=WSIReader, # type: ignore
backend="cuCIM",
dtype=np.uint8,
level=self.level,
image_only=True,
)
if self.transforms_dict and self.transforms_dict[stage]:
max_offset = None if (self.random_offset and stage == ModelKey.TRAIN) else 0
random_grid_transform = RandGridPatchd(
keys=[slides_dataset.IMAGE_COLUMN],
patch_size=[self.tile_size, self.tile_size], # type: ignore
num_patches=self.max_bag_size if stage == ModelKey.TRAIN else self.max_bag_size_inf,
sort_fn=self.filter_mode,
pad_mode="constant",
constant_values=self.background_val,
overlap=self.overlap, # type: ignore
threshold=self.intensity_threshold,
max_offset=max_offset,
)
base_transform = Compose([load_image_transform, random_grid_transform])

transforms = Compose([base_transform, self.transforms_dict[stage]]).flatten()
if self.transforms_dict and self.transforms_dict[stage]:
transforms = Compose([base_transform, self.transforms_dict[stage]]).flatten() # type: ignore
else:
transforms = base_transform
# The tiling transform is randomized. Make them deterministic. This call needs to be
@@ -325,7 +325,7 @@ def _get_dataloader(self, dataset: SlidesDataset, stage: ModelKey, shuffle: bool
return DataLoader(
transformed_slides_dataset,
batch_size=self.batch_size,
collate_fn=image_collate,
collate_fn=array_collate,
shuffle=shuffle,
generator=generator,
**dataloader_kwargs,
105 changes: 90 additions & 15 deletions hi-ml-histopathology/src/histopathology/models/deepmil.py
@@ -352,20 +352,95 @@ def get_bag_label(labels: Tensor) -> Tensor:
# SlidesDataModule attributes a single label to a bag of tiles already no need to do majority voting
return labels

@staticmethod
def get_empty_lists(shape: int, n: int) -> List:
ll = []
for _ in range(n):
ll.append([None] * shape)
return ll

@staticmethod
def get_patch_coordinate(slide_offset: List, patch_location: List, patch_size: List) -> Tuple[int, int, int, int]:
Reviewer (Member): What would you think of using our Box class?

Reviewer (Collaborator): +1 on that. Tuples with 4 elements of the same type are just too easy to mix up.
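A minimal sketch of what that suggestion could look like for get_patch_coordinate below (PatchBox and absolute_patch_box are hypothetical names for illustration, not the repository's Box class):

```python
from typing import List, NamedTuple


class PatchBox(NamedTuple):
    """Absolute patch coordinates, following the [y, x] convention of get_patch_coordinate."""
    top: int
    bottom: int
    left: int
    right: int


def absolute_patch_box(slide_offset: List[int], patch_location: List[int], patch_size: List[int]) -> PatchBox:
    # Absolute coordinate = slide offset + patch-local location; the patch size gives the extent.
    top = slide_offset[0] + patch_location[0]
    left = slide_offset[1] + patch_location[1]
    return PatchBox(top=top, bottom=top + patch_size[0], left=left, right=left + patch_size[1])


# Example: a slide offset of (100, 200) and a 224 x 224 patch at local location (0, 448)
# gives PatchBox(top=100, bottom=324, left=648, right=872).
```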

""" computing absolute patch coordinate """
# PATCH_LOCATION is expected to have shape [y, x]
top = slide_offset[0] + patch_location[0]
bottom = slide_offset[0] + patch_location[0] + patch_size[0]
left = slide_offset[1] + patch_location[1]
right = slide_offset[1] + patch_location[1] + patch_size[1]
return top, bottom, left, right

@staticmethod
def expand_slide_constant_metadata(id: str, path: str, n_patches: int) -> Tuple[List, List, List]:
"""Duplicate metadata that is patch invariant to match the shape of other arrays"""
slide_id = [id] * n_patches
image_paths = [path] * n_patches
tile_id = [f"{id}_{tile_id}" for tile_id in range(n_patches)]
return slide_id, image_paths, tile_id

@staticmethod
def check_patch_location_format(batch):
"""Workaround for bug in MONAI that returns not consistent location"""
faulty_slides_idx = []
for i, locations in enumerate(batch[SlideKey.PATCH_LOCATION]):
for location in locations:
if len(location) != 2:
print(f'Slide {batch[SlideKey.SLIDE_ID][i]} '
f'will be skipped as its patches contained unexpected values in patch_location {location}')
faulty_slides_idx.append(batch[SlideKey.SLIDE_ID][i])
break
n = len(faulty_slides_idx)
if n > 0:
print(f'{n} slides will be skipped because something was wrong in the patch location')
return faulty_slides_idx

def get_slide_patch_coordinates(
self, slide_offset: List, patches_location: List, patch_size: List
Reviewer (Contributor): Can you add a docstring please?
) -> Tuple[List, List, List, List]:
""" computing absolute coordinates for all patches in a slide"""
top, bottom, left, right = self.get_empty_lists(len(patches_location), 4)
Reviewer (Member): Might this be cleaner using numpy arrays, perhaps?

Reviewer (Collaborator): @dccastro, I like your suggestion above to use the Box class instead. Much cleaner than using a 4-tuple of ints; that's just calling for errors to happen.
for i, location in enumerate(patches_location):
top[i], bottom[i], left[i], right[i] = self.get_patch_coordinate(slide_offset, location, patch_size)
return top, bottom, left, right

def compute_slide_metadata(self, batch: Dict, index: int, metadata_dict: Dict) -> Dict:
Reviewer (Member): Why does this method need to mutate the input metadata_dict in place, instead of returning a new dictionary?

Reviewer (Collaborator): +1. If it needs to mutate in place, then the signature and name of the function should reflect that mutation (update_metadata, -> None).
"""compute patch-dependent and patch-invariante metadata for a single slide """
offset = batch[SlideKey.OFFSET.value][index]
Reviewer (Member): [Minor] The .value shouldn't be necessary.
patches_location = batch[SlideKey.PATCH_LOCATION.value][index]
patch_size = batch[SlideKey.PATCH_SIZE.value][index]
n_patches = len(patches_location)
id = batch[SlideKey.SLIDE_ID][index]
path = batch[SlideKey.IMAGE_PATH][index]

top, bottom, left, right = self.get_slide_patch_coordinates(offset, patches_location, patch_size)
slide_id, image_paths, tile_id = self.expand_slide_constant_metadata(id, path, n_patches)

metadata_dict[ResultsKey.TILE_TOP] = top
metadata_dict[ResultsKey.TILE_BOTTOM] = bottom
metadata_dict[ResultsKey.TILE_LEFT] = left
metadata_dict[ResultsKey.TILE_RIGHT] = right
metadata_dict[ResultsKey.SLIDE_ID] = slide_id
metadata_dict[ResultsKey.TILE_ID] = tile_id
metadata_dict[ResultsKey.IMAGE_PATH] = image_paths
return metadata_dict

def update_results_with_data_specific_info(self, batch: Dict, results: Dict) -> None:
# WARNING: This is a dummy input until we figure out tiles coordinates retrieval in the next iteration.
bag_sizes = [tiles.shape[0] for tiles in batch[SlideKey.IMAGE]]
results.update(
{
ResultsKey.SLIDE_ID: [
[slide_id] * bag_sizes[i] for i, slide_id in enumerate(batch[SlideKey.SLIDE_ID])
],
ResultsKey.TILE_ID: [
[f"{slide_id}_{tile_id}" for tile_id in range(bag_sizes[i])]
for i, slide_id in enumerate(batch[SlideKey.SLIDE_ID])
],
ResultsKey.IMAGE_PATH: [
[img_path] * bag_sizes[i] for i, img_path in enumerate(batch[SlideKey.IMAGE_PATH])
],
if all(key.value in batch.keys() for key in [SlideKey.OFFSET, SlideKey.PATCH_LOCATION, SlideKey.PATCH_SIZE]):
Reviewer (Collaborator): I think a set operation would be easier to understand: set(SlideKey.OFFSET, SlideKey....) <= set(batch.keys()).
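A minimal sketch of that suggestion (has_patch_metadata is a hypothetical helper name; it assumes, as the current code does, that the batch is keyed by the plain string values of the enum members):

```python
from typing import Any, Dict

from histopathology.utils.naming import SlideKey


def has_patch_metadata(batch: Dict[str, Any]) -> bool:
    """Return True if the batch carries the keys needed to compute absolute tile coordinates."""
    required_keys = {SlideKey.OFFSET.value, SlideKey.PATCH_LOCATION.value, SlideKey.PATCH_SIZE.value}
    # Subset comparison on sets replaces the all(... in batch.keys() ...) generator expression.
    return required_keys <= set(batch.keys())
```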

n_slides = len(batch[SlideKey.SLIDE_ID])
metadata_dict = {
Reviewer (Member): Couldn't this be replaced by just adding `if key not in results: results[key] = []` in the loop below?

ResultsKey.TILE_TOP: [],
ResultsKey.TILE_BOTTOM: [],
ResultsKey.TILE_LEFT: [],
ResultsKey.TILE_RIGHT: [],
ResultsKey.SLIDE_ID: [],
ResultsKey.TILE_ID: [],
ResultsKey.IMAGE_PATH: [],
}
)
results.update(metadata_dict)
# each slide can have a different number of patches
for i in range(n_slides):
updated_metadata_dict = self.compute_slide_metadata(batch, i, metadata_dict)
Reviewer (Contributor): I wonder if all this metadata processing should be wrapped into a transform. Data is prefetched anyway, so it should be more efficient to apply it as a transform that is multi-processed by the dataloader. Also, from a design perspective, the model shouldn't be handling any metadata processing...

for key in metadata_dict.keys():
results[key].append(updated_metadata_dict[key])
else:
rank_zero_warn(message="Offset, patch location or patch size are not found in the batch. "
"Make sure to use RandGridPatchd.")
5 changes: 5 additions & 0 deletions hi-ml-histopathology/src/histopathology/utils/naming.py
@@ -4,6 +4,7 @@
# ------------------------------------------------------------------------------------------

from enum import Enum
from monai.utils import WSIPatchKeys


class SlideKey(str, Enum):
@@ -19,6 +20,10 @@ class SlideKey(str, Enum):
FOREGROUND_THRESHOLD = 'foreground_threshold'
METADATA = 'metadata'
LOCATION = 'location'
PATCH_SIZE = WSIPatchKeys.SIZE.value # 'patch_size'
PATCH_LOCATION = WSIPatchKeys.LOCATION.value # 'patch_location'
OFFSET = 'offset'
SHAPE = 'original_spatial_shape'


class TileKey(str, Enum):
10 changes: 9 additions & 1 deletion hi-ml-histopathology/src/histopathology/utils/output_utils.py
@@ -72,6 +72,8 @@ def normalize_dict_for_df(dict_old: Dict[ResultsKey, Any]) -> Dict[str, Any]:
value = value.squeeze(0).cpu().numpy()
if value.ndim == 0:
value = np.full(bag_size, fill_value=value)
if isinstance(value, List) and isinstance(value[0], torch.Tensor):
Reviewer (Contributor): Can you explain why we need to do that here? Is that for the coordinates? Maybe turn it into numpy arrays in the first place?

Reviewer (Collaborator): This function could do with a lot more documentation. A docstring would be great, and each of the if branches should have a clear description of which case we are handling and where those cases arise.

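A hedged illustration of what the branch below handles (presumably the tile coordinates added in this PR, which arrive as lists of 0-dim tensors):

```python
import torch

# A column such as ResultsKey.TILE_TOP may arrive as a list of 0-dim tensors; .item()
# converts each element to a plain Python number so pandas can build a column from it.
value = [torch.tensor(3), torch.tensor(7)]
value = [v.item() for v in value]  # -> [3, 7]
```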
value = [value[i].item() for i in range(len(value))]
dict_new[key] = value
elif key == ResultsKey.CLASS_PROBS:
if isinstance(value, torch.Tensor):
@@ -134,11 +136,17 @@ def save_outputs_csv(results: ResultsType, outputs_dir: Path) -> None:

# Collect the list of dictionaries in a list of pandas dataframe and save
df_list = []
skipped_slides = 0
for slide_dict in list_slide_dicts:
slide_dict = normalize_dict_for_df(slide_dict) # type: ignore
df_list.append(pd.DataFrame.from_dict(slide_dict))
try:
df_list.append(pd.DataFrame.from_dict(slide_dict))
except ValueError:
skipped_slides += 1
logging.warning(f"something wrong in the dimension of slide {slide_dict[ResultsKey.SLIDE_ID][0]}")
df = pd.concat(df_list, ignore_index=True)
df.to_csv(csv_filename, mode='w+', header=True)
logging.warning(f"{skipped_slides} slides have not been included in the ouputs because of issues with the outputs")


def save_features(results: ResultsType, outputs_dir: Path) -> None: