Add histopathology module and add hi-ml as submodule #603

Merged · 24 commits · Dec 6, 2021
5 changes: 3 additions & 2 deletions .amlignore
@@ -21,7 +21,7 @@ pull_request_template.md
SECURITY.md
__pycache__
azure-pipelines
datasets
/datasets
docs
sphinx-docs
modelweights
@@ -35,4 +35,5 @@ tensorboard_runs
InnerEyeTestVariables.txt
InnerEyePrivateSettings.yml
cifar-10-batches-py
cifar-100-python
cifar-100-python
!**/InnerEye/ML/Histopathology/datasets
4 changes: 3 additions & 1 deletion .gitignore
@@ -46,7 +46,7 @@ packages-microsoft-prod.deb

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
# before PyInstaller builds the exe, so as to inject date/other infos into it
*.manifest
*.spec

@@ -166,3 +166,5 @@ InnerEye-DataQuality/name_stats_scoring.png
InnerEye-DataQuality/cifar-10-batches-py
InnerEye-DataQuality/logs
InnerEye-DataQuality/data

!**/InnerEye/ML/Histopathology/datasets
3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "fastMRI"]
path = fastMRI
url = https://github.com/facebookresearch/fastMRI
[submodule "hi-ml"]
path = hi-ml
url = https://github.com/microsoft/hi-ml
4 changes: 3 additions & 1 deletion CHANGELOG.md
@@ -14,7 +14,7 @@ created.

### Added
- ([#594](https://github.com/microsoft/InnerEye-DeepLearning/pull/594)) When supplying a "--tag" argument, the AzureML jobs use that value as the display name, to more easily distinguish runs.
- ([#577](https://github.com/microsoft/InnerEye-DeepLearning/pull/577)) Commandline switch `monitor_gpu` to monitor
- ([#577](https://github.com/microsoft/InnerEye-DeepLearning/pull/577)) Commandline switch `monitor_gpu` to monitor
GPU utilization via Lightning's `GpuStatsMonitor`, switch `monitor_loading` to check batch loading times via
`BatchTimeCallback`, and `pl_profiler` to turn on the Lightning profiler (`simple`, `advanced`, or `pytorch`)
- ([#544](https://github.com/microsoft/InnerEye-DeepLearning/pull/544)) Add documentation for segmentation model evaluation.
@@ -31,6 +31,8 @@ jobs that run in AzureML.
- ([#559](https://github.com/microsoft/InnerEye-DeepLearning/pull/559)) Adding the accompanying code for the ["Active label cleaning: Improving dataset quality under resource constraints"](https://arxiv.org/abs/2109.00574) paper. The code can be found in the [InnerEye-DataQuality](InnerEye-DataQuality/README.md) subfolder. It provides tools for training noise robust models, running label cleaning simulation and loading our label cleaning benchmark datasets.
- ([#589](https://github.com/microsoft/InnerEye-DeepLearning/pull/589)) Add `LightningContainer.update_azure_config()`
hook to enable overriding `AzureConfig` parameters from a container (e.g. `experiment_name`, `cluster`, `num_nodes`).
- ([#603](https://github.com/microsoft/InnerEye-DeepLearning/pull/603)) Add histopathology module


### Changed
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
149 changes: 149 additions & 0 deletions InnerEye/ML/Histopathology/datamodules/base_module.py
@@ -0,0 +1,149 @@
import pickle
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple, Union

from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from health_ml.utils.bag_utils import BagDataset, multibag_collate
from health_ml.utils.common_utils import _create_generator
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
from InnerEye.ML.Histopathology.models.transforms import LoadTilesBatchd


class CacheMode(Enum):
NONE = 'none'
MEMORY = 'memory'
DISK = 'disk'


class TilesDataModule(LightningDataModule):
"""Base class to load the tiles of a dataset as train, val, test sets"""

def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
seed: Optional[int] = None, transform: Optional[Callable] = None,
cache_mode: CacheMode = CacheMode.NONE, save_precache: bool = False,
cache_dir: Optional[Path] = None,
number_of_cross_validation_splits: int = 0,
cross_validation_split_index: int = 0) -> None:
"""
:param root_path: Root directory of the source dataset.
:param max_bag_size: Upper bound on number of tiles in each loaded bag. If 0 (default),
will return all samples in each bag. If > 0, bags larger than `max_bag_size` will yield
random subsets of instances.
:param batch_size: Number of slides to load per batch.
:param seed: Pseudorandom number generator seed to use for shuffling instances and bags. Note that randomness in
train/val/test splits is handled independently in `get_splits()`. (default: `None`)
:param transform: A transform to apply to the source tiles dataset, or a composition of
transforms using `monai.transforms.Compose`. By default (`None`), applies `LoadTilesBatchd`.
:param cache_mode: The type of caching to perform, i.e. whether the results of all
transforms up to the first randomised one should be computed only once and reused in
subsequent iterations:
- `MEMORY`: the entire transformed dataset is kept in memory for fastest access;
- `DISK`: each transformed sample is saved to disk and loaded on-demand;
- `NONE` (default): no caching is performed.
:param save_precache: Whether to pre-cache the entire transformed dataset upfront and save
it to disk. This is done once in `prepare_data()` only on the local rank-0 process, so
multiple processes can afterwards access the same cache without contention in DDP settings.
:param cache_dir: The directory in which to cache data if caching is enabled.
:param number_of_cross_validation_splits: Number of folds to perform.
:param cross_validation_split_index: Index of the cross validation split to be performed.
"""
if save_precache and cache_mode is CacheMode.NONE:
raise ValueError("Can only pre-cache if caching is enabled")
if save_precache and cache_dir is None:
raise ValueError("A cache directory is required for pre-caching")
if cache_mode is CacheMode.DISK and cache_dir is None:
raise ValueError("A cache directory is required for on-disk caching")
super().__init__()

self.root_path = root_path
self.max_bag_size = max_bag_size
self.transform = transform
self.cache_mode = cache_mode
self.save_precache = save_precache
self.cache_dir = cache_dir
self.batch_size = batch_size
self.number_of_cross_validation_splits = number_of_cross_validation_splits
self.cross_validation_split_index = cross_validation_split_index
self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits()
self.class_weights = self.train_dataset.get_class_weights()
self.seed = seed

def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
"""Create the training, validation, and test datasets"""
raise NotImplementedError

def prepare_data(self) -> None:
if self.save_precache:
self._load_dataset(self.train_dataset, stage='train', shuffle=True)
self._load_dataset(self.val_dataset, stage='val', shuffle=True)
self._load_dataset(self.test_dataset, stage='test', shuffle=True)

def _dataset_pickle_path(self, stage: str) -> Optional[Path]:
if self.cache_dir is None:
return None
return self.cache_dir / f"{stage}_dataset.pkl"

def _load_dataset(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool) -> Dataset:
dataset_pickle_path = self._dataset_pickle_path(stage)

if dataset_pickle_path and dataset_pickle_path.exists():
with dataset_pickle_path.open('rb') as f:
return pickle.load(f)

generator = _create_generator(self.seed)
bag_dataset = BagDataset(tiles_dataset, # type: ignore
bag_ids=tiles_dataset.slide_ids,
max_bag_size=self.max_bag_size,
shuffle_samples=shuffle,
generator=generator)
transform = self.transform or LoadTilesBatchd(tiles_dataset.IMAGE_COLUMN)

# Save and restore PRNG state for consistency across (pre-)caching options
generator_state = generator.get_state()
transformed_bag_dataset = self._get_transformed_dataset(bag_dataset, transform) # type: ignore
generator.set_state(generator_state)

if dataset_pickle_path:
dataset_pickle_path.parent.mkdir(parents=True, exist_ok=True)
with dataset_pickle_path.open('wb') as f:
pickle.dump(transformed_bag_dataset, f)

return transformed_bag_dataset

def _get_transformed_dataset(self, base_dataset: BagDataset,
transform: Union[Sequence[Callable], Callable]) -> Dataset:
if self.cache_mode is CacheMode.MEMORY:
dataset = CacheDataset(base_dataset, transform, num_workers=1) # type: ignore
elif self.cache_mode is CacheMode.DISK:
dataset = PersistentDataset(base_dataset, transform, cache_dir=self.cache_dir) # type: ignore
if self.save_precache:
import tqdm # TODO: Make optional

for i in tqdm.trange(len(dataset), desc="Loading dataset"):
dataset[i] # empty loop to pre-compute all transformed samples
else:
dataset = Dataset(base_dataset, transform) # type: ignore
return dataset

def _get_dataloader(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool,
**dataloader_kwargs: Any) -> DataLoader:
transformed_bag_dataset = self._load_dataset(tiles_dataset, stage=stage, shuffle=shuffle)
bag_dataset: BagDataset = transformed_bag_dataset.data # type: ignore
generator = bag_dataset.bag_sampler.generator
return DataLoader(transformed_bag_dataset, batch_size=self.batch_size,
collate_fn=multibag_collate, shuffle=shuffle, generator=generator,
pin_memory=False, # disable pinning as loaded data may already be on GPU
**dataloader_kwargs)

def train_dataloader(self) -> DataLoader:
return self._get_dataloader(self.train_dataset, 'train', shuffle=True)

def val_dataloader(self) -> DataLoader:
return self._get_dataloader(self.val_dataset, 'val', shuffle=True)

def test_dataloader(self) -> DataLoader:
return self._get_dataloader(self.test_dataset, 'test', shuffle=True)
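
For orientation, here is a minimal sketch of what a concrete `TilesDataModule` subclass involves; it is not part of this PR. `MyTilesDataset`, `MyTilesDataModule`, the `'split'` column values and all paths are hypothetical, and the dataset CSV is assumed to follow the default `TilesDataset` column layout. Note that `TilesDataModule.__init__` calls `get_splits()` immediately, so the splits must be constructible from the constructor arguments alone.

```python
from pathlib import Path
from typing import Tuple

import pandas as pd

from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, TilesDataModule
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset


class MyTilesDataset(TilesDataset):
    """Hypothetical tiles dataset that keeps the default TilesDataset column names."""


class MyTilesDataModule(TilesDataModule):
    def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
        # Assumes dataset.csv carries a 'split' column with values 'train'/'val'/'test'.
        df = pd.read_csv(self.root_path / MyTilesDataset.DEFAULT_CSV_FILENAME)
        return tuple(MyTilesDataset(self.root_path, dataset_df=df[df['split'] == s])  # type: ignore
                     for s in ('train', 'val', 'test'))


# On-disk caching plus pre-caching: deterministic transforms are computed once in
# prepare_data() and the cached samples are then reused by all processes.
# A cache_dir is required whenever CacheMode.DISK or save_precache is used.
data_module = MyTilesDataModule(root_path=Path("/tmp/datasets/my_tiles"),
                                max_bag_size=100,
                                batch_size=2,
                                seed=42,
                                cache_mode=CacheMode.DISK,
                                save_precache=True,
                                cache_dir=Path("outputs/tile_cache"))
```
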
23 changes: 23 additions & 0 deletions InnerEye/ML/Histopathology/datamodules/panda_module.py
@@ -0,0 +1,23 @@
from typing import Tuple

from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
from InnerEye.ML.utils.split_dataset import DatasetSplits


class PandaTilesDataModule(TilesDataModule):
""" PandaTilesDataModule is the child class of TilesDataModule specific to PANDA dataset
Method get_splits() returns the train, val, test splits from the PANDA dataset
"""

def get_splits(self) -> Tuple[PandaTilesDataset, PandaTilesDataset, PandaTilesDataset]:
dataset = PandaTilesDataset(self.root_path)
splits = DatasetSplits.from_proportions(dataset.dataset_df.reset_index(),
proportion_train=.8,
proportion_test=.1,
proportion_val=.1,
subject_column=dataset.TILE_ID_COLUMN,
group_column=dataset.SLIDE_ID_COLUMN)
return (PandaTilesDataset(self.root_path, dataset_df=splits.train),
PandaTilesDataset(self.root_path, dataset_df=splits.val),
PandaTilesDataset(self.root_path, dataset_df=splits.test))
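
A possible usage sketch, not part of this PR: building the PANDA tiles data module and drawing one training batch. It assumes the PANDA tiles dataset is mounted at the default location defined in `default_paths.py` (the last file in this diff); the parameter values are arbitrary.

```python
from pathlib import Path

from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule
from InnerEye.ML.Histopathology.datasets.default_paths import PANDA_TILES_DATASET_DIR

data_module = PandaTilesDataModule(root_path=Path(PANDA_TILES_DATASET_DIR),
                                   max_bag_size=100,  # cap the number of tiles per slide bag
                                   batch_size=2,      # two slide bags per batch
                                   seed=42)
data_module.prepare_data()                   # no-op unless save_precache is enabled
train_loader = data_module.train_dataloader()
bags = next(iter(train_loader))              # one batch of tile bags, collated by multibag_collate
```
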
33 changes: 33 additions & 0 deletions InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py
@@ -0,0 +1,33 @@
from typing import Tuple, Any

from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from InnerEye.ML.utils.split_dataset import DatasetSplits


class TcgaCrckTilesDataModule(TilesDataModule):
""" TcgaCrckTilesDataModule is the child class of TilesDataModule specific to TCGA-Crck dataset
Method get_splits() returns the train, val, test splits from the TCGA-Crck dataset
Methods train_dataloader(), val_dataloader() and test_dataloader() override the base class methods for bag loading
"""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

def get_splits(self) -> Tuple[TcgaCrck_TilesDataset, TcgaCrck_TilesDataset, TcgaCrck_TilesDataset]:
trainval_dataset = TcgaCrck_TilesDataset(self.root_path, train=True)
splits = DatasetSplits.from_proportions(trainval_dataset.dataset_df.reset_index(),
proportion_train=0.8,
proportion_test=0.0,
proportion_val=0.2,
subject_column=trainval_dataset.TILE_ID_COLUMN,
group_column=trainval_dataset.SLIDE_ID_COLUMN,
random_seed=5)

if self.number_of_cross_validation_splits > 1:
# Function get_k_fold_cross_validation_splits() will concatenate train and val splits
splits = splits.get_k_fold_cross_validation_splits(self.number_of_cross_validation_splits)[self.cross_validation_split_index]

return (TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.train),
TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.val),
TcgaCrck_TilesDataset(self.root_path, train=False))
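
A possible sketch of the cross-validation path, not part of this PR: with `number_of_cross_validation_splits` > 1, each module instance re-splits the pooled train+val data into the requested fold, while the held-out TCGA-CRCk test split (loaded with `train=False`) stays the same across folds. The dataset location again assumes `default_paths.py`, and the fold count is arbitrary.

```python
from pathlib import Path

from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import TcgaCrckTilesDataModule
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR

num_folds = 5
for fold in range(num_folds):
    data_module = TcgaCrckTilesDataModule(root_path=Path(TCGA_CRCK_DATASET_DIR),
                                          batch_size=2,
                                          number_of_cross_validation_splits=num_folds,
                                          cross_validation_split_index=fold)
    # Train and validate one model per fold here.
    print(f"Fold {fold}: {len(data_module.train_dataset)} training tiles, "
          f"{len(data_module.val_dataset)} validation tiles")
```
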
107 changes: 107 additions & 0 deletions InnerEye/ML/Histopathology/datasets/base_dataset.py
@@ -0,0 +1,107 @@
from pathlib import Path
from typing import Any, Dict, Optional, Union

import numpy as np
import pandas as pd
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset


class TilesDataset(Dataset):
"""Base class for datasets of WSI tiles, iterating dictionaries of image paths and metadata.

:param TILE_ID_COLUMN: CSV column name for tile ID.
:param SLIDE_ID_COLUMN: CSV column name for slide ID.
:param IMAGE_COLUMN: CSV column name for relative path to image file.
:param PATH_COLUMN: CSV column name for relative path to image file. Replicated to propagate the path to the batch.
:param LABEL_COLUMN: CSV column name for tile label.
:param SPLIT_COLUMN: CSV column name for train/test split (optional).
:param TILE_X_COLUMN: CSV column name for horizontal tile coordinate (optional).
:param TILE_Y_COLUMN: CSV column name for vertical tile coordinate (optional).
:param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`.
:param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`.
:param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset root directory.
:param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`.
"""
TILE_ID_COLUMN: str = 'tile_id'
SLIDE_ID_COLUMN: str = 'slide_id'
IMAGE_COLUMN: str = 'image'
PATH_COLUMN: str = 'image_path'
LABEL_COLUMN: str = 'label'
SPLIT_COLUMN: Optional[str] = 'split'
TILE_X_COLUMN: Optional[str] = 'tile_x'
TILE_Y_COLUMN: Optional[str] = 'tile_y'

TRAIN_SPLIT_LABEL: str = 'train'
TEST_SPLIT_LABEL: str = 'test'

DEFAULT_CSV_FILENAME: str = "dataset.csv"

N_CLASSES: int = 1 # binary classification by default

def __init__(self,
root: Union[str, Path],
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,
train: Optional[bool] = None) -> None:
"""
:param root: Root directory of the dataset.
:param dataset_csv: Full path to a dataset CSV file, containing at least
`TILE_ID_COLUMN`, `SLIDE_ID_COLUMN`, and `IMAGE_COLUMN`. If omitted, the CSV will be read
from `"{root}/{DEFAULT_CSV_FILENAME}"`.
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
:param train: If `True`, loads only the training split (resp. `False` for test split). By
default (`None`), loads the entire dataset as-is.
"""
if self.SPLIT_COLUMN is None and train is not None:
raise ValueError("Train/test split was specified but dataset has no split column")

self.root_dir = Path(root)

if dataset_df is not None:
self.dataset_csv = None
else:
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
dataset_df = pd.read_csv(self.dataset_csv)

columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN,
self.SPLIT_COLUMN, self.TILE_X_COLUMN, self.TILE_Y_COLUMN]
for column in columns:
if column is not None and column not in dataset_df.columns:
raise ValueError(f"Expected column '{column}' not found in the dataframe")

dataset_df = dataset_df.set_index(self.TILE_ID_COLUMN)
if train is None:
self.dataset_df = dataset_df
else:
split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL
self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split]

def __len__(self) -> int:
return self.dataset_df.shape[0]

def __getitem__(self, index: int) -> Dict[str, Any]:
tile_id = self.dataset_df.index[index]
sample = {
self.TILE_ID_COLUMN: tile_id,
**self.dataset_df.loc[tile_id].to_dict()
}
sample[self.IMAGE_COLUMN] = str(self.root_dir / sample.pop(self.IMAGE_COLUMN))
# we're replicating this column because we want to propagate the path to the batch
sample[self.PATH_COLUMN] = sample[self.IMAGE_COLUMN]
return sample

@property
def slide_ids(self) -> pd.Series:
return self.dataset_df[self.SLIDE_ID_COLUMN]

def get_slide_labels(self) -> pd.Series:
return self.dataset_df.groupby(self.SLIDE_ID_COLUMN)[self.LABEL_COLUMN].agg(pd.Series.mode)

def get_class_weights(self) -> torch.Tensor:
slide_labels = self.get_slide_labels()
classes = np.unique(slide_labels)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
return torch.as_tensor(class_weights)
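
For reference, a small sketch of what a `TilesDataset` yields per tile and how the slide-level class weights are obtained; it is not part of this PR, and the dataset root and its `dataset.csv` (with the default column names) are assumptions.

```python
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset

dataset = TilesDataset("/tmp/datasets/my_tiles")   # reads /tmp/datasets/my_tiles/dataset.csv
sample = dataset[0]
# Each sample is a dict keyed by the CSV column names; the relative image path is resolved
# against the root directory and replicated under 'image_path' so it survives collation.
print(sample[TilesDataset.IMAGE_COLUMN], sample[TilesDataset.LABEL_COLUMN])

# Balanced class weights, derived from the majority (mode) label of each slide:
print(dataset.get_class_weights())                 # torch.Tensor with one entry per class
```
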
8 changes: 8 additions & 0 deletions InnerEye/ML/Histopathology/datasets/default_paths.py
@@ -0,0 +1,8 @@
PANDA_TILES_DATASET_ID = "PANDA_tiles"
TCGA_CRCK_DATASET_ID = "TCGA-CRCk"
TCGA_PRAD_DATASET_ID = "TCGA-PRAD"

DEFAULT_DATASET_LOCATION = "/tmp/datasets/"
PANDA_TILES_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_TILES_DATASET_ID
TCGA_CRCK_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_CRCK_DATASET_ID
TCGA_PRAD_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_PRAD_DATASET_ID
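
A likely consumption pattern, stated here as an assumption since this diff only defines the constants: the `*_DATASET_ID` values name the datasets as registered in the workspace, and the `*_DATASET_DIR` values are their expected local mount points under `DEFAULT_DATASET_LOCATION`.

```python
from InnerEye.ML.Histopathology.datasets.default_paths import (
    DEFAULT_DATASET_LOCATION, PANDA_TILES_DATASET_DIR, PANDA_TILES_DATASET_ID)

# Each dataset directory is simply the dataset ID appended to the default mount location:
assert PANDA_TILES_DATASET_DIR == DEFAULT_DATASET_LOCATION + PANDA_TILES_DATASET_ID
print(PANDA_TILES_DATASET_DIR)  # /tmp/datasets/PANDA_tiles
```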