This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Add histopathology module and add hi-ml as submodule #603

Merged · 24 commits · Dec 6, 2021
Changes from 4 commits
3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "fastMRI"]
path = fastMRI
url = https://github.com/facebookresearch/fastMRI
[submodule "hi-ml"]
path = hi-ml
url = https://github.com/microsoft/hi-ml
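After checking out this change, the new submodule can be fetched with `git submodule update --init hi-ml` (or `--init --recursive` to initialise all submodules, including fastMRI).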
157 changes: 157 additions & 0 deletions InnerEye/ML/Histopathology/datamodules/base_module.py
@@ -0,0 +1,157 @@
import pickle
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import torch
from monai.data.dataset import CacheDataset, Dataset, PersistentDataset
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader

from health_ml.utils.bag_utils import BagDataset, multibag_collate
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset
from InnerEye.ML.Histopathology.models.transforms import LoadTilesBatchd


def _create_generator(seed: Optional[int]) -> torch.Generator:
generator = torch.Generator()
if seed is None:
seed = int(torch.empty((), dtype=torch.int64).random_().item())
generator.manual_seed(seed)
return generator


class CacheMode(Enum):
NONE = 'none'
MEMORY = 'memory'
DISK = 'disk'


class TilesDataModule(LightningDataModule):
"""Base class to load the tiles of a dataset as train, val, test sets"""

def __init__(self, root_path: Path, max_bag_size: int = 0, batch_size: int = 1,
seed: Optional[int] = None, transform: Optional[Callable] = None,
cache_mode: CacheMode = CacheMode.NONE, save_precache: bool = False,
cache_dir: Optional[Path] = None,
number_of_cross_validation_splits: int = 0,
cross_validation_split_index: int = 0) -> None:
"""
:param root_path: Root directory of the source dataset.
:param max_bag_size: Upper bound on number of tiles in each loaded bag. If 0 (default),
will return all samples in each bag. If > 0, bags larger than `max_bag_size` will yield
random subsets of instances.
:param batch_size: Number of slides to load per batch.
:param seed: PRNG seed to use for shuffling instances and bags. Note that randomness in
train/val/test splits is handled independently in `get_splits()`. (default: `None`)
:param transform: A transform to apply to the source tiles dataset, or a composition of
transforms using `monai.transforms.Compose`. By default (`None`), applies `LoadTilesBatchd`.
:param cache_mode: The type of caching to perform, i.e. whether the results of all
transforms up to the first randomised one should be computed only once and reused in
subsequent iterations:
- `MEMORY`: the entire transformed dataset is kept in memory for fastest access;
- `DISK`: each transformed sample is saved to disk and loaded on-demand;
- `NONE` (default): no caching is performed.
:param save_precache: Whether to pre-cache the entire transformed dataset upfront and save
it to disk. This is done once in `prepare_data()` only on the local rank-0 process, so
multiple processes can afterwards access the same cache without contention in DDP settings.
:param cache_dir: The directory in which to cache data, if caching is enabled.
:param number_of_cross_validation_splits: Number of cross-validation folds (0 or 1 disables cross-validation).
:param cross_validation_split_index: Index of the cross-validation split to use.
"""
if save_precache and cache_mode is CacheMode.NONE:
raise ValueError("Can only pre-cache if caching is enabled")
if save_precache and cache_dir is None:
raise ValueError("A cache directory is required for pre-caching")
if cache_mode is CacheMode.DISK and cache_dir is None:
raise ValueError("A cache directory is required for on-disk caching")
super().__init__()

self.root_path = root_path
self.max_bag_size = max_bag_size
self.transform = transform
self.cache_mode = cache_mode
self.save_precache = save_precache
self.cache_dir = cache_dir
self.batch_size = batch_size
self.number_of_cross_validation_splits = number_of_cross_validation_splits
self.cross_validation_split_index = cross_validation_split_index
self.train_dataset, self.val_dataset, self.test_dataset = self.get_splits()
self.class_weights = self.train_dataset.get_class_weights()
self.seed = seed

def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
"""Create the training, validation, and test datasets"""
raise NotImplementedError

def prepare_data(self) -> None:
if self.save_precache:
self._load_dataset(self.train_dataset, stage='train', shuffle=True)
self._load_dataset(self.val_dataset, stage='val', shuffle=True)
self._load_dataset(self.test_dataset, stage='test', shuffle=True)

def _dataset_pickle_path(self, stage: str) -> Optional[Path]:
if self.cache_dir is None:
return None
return self.cache_dir / f"{stage}_dataset.pkl"

def _load_dataset(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool) -> Dataset:
dataset_pickle_path = self._dataset_pickle_path(stage)

if dataset_pickle_path and dataset_pickle_path.exists():
with dataset_pickle_path.open('rb') as f:
return pickle.load(f)

generator = _create_generator(self.seed)
bag_dataset = BagDataset(tiles_dataset, # type: ignore
bag_ids=tiles_dataset.slide_ids,
max_bag_size=self.max_bag_size,
shuffle_samples=shuffle,
generator=generator)
transform = self.transform or LoadTilesBatchd(tiles_dataset.IMAGE_COLUMN)

# Save and restore PRNG state for consistency across (pre-)caching options
generator_state = generator.get_state()
transformed_bag_dataset = self._get_transformed_dataset(bag_dataset, transform) # type: ignore
generator.set_state(generator_state)

if dataset_pickle_path:
dataset_pickle_path.parent.mkdir(parents=True, exist_ok=True)
with dataset_pickle_path.open('wb') as f:
pickle.dump(transformed_bag_dataset, f)

return transformed_bag_dataset

def _get_transformed_dataset(self, base_dataset: BagDataset,
transform: Union[Sequence[Callable], Callable]) -> Dataset:
if self.cache_mode is CacheMode.MEMORY:
dataset = CacheDataset(base_dataset, transform, num_workers=1) # type: ignore
elif self.cache_mode is CacheMode.DISK:
dataset = PersistentDataset(base_dataset, transform, cache_dir=self.cache_dir) # type: ignore
if self.save_precache:
import tqdm # TODO: Make optional

for i in tqdm.trange(len(dataset), desc="Loading dataset"):
dataset[i] # accessing each item computes the transform and writes it to the disk cache
else:
dataset = Dataset(base_dataset, transform) # type: ignore
return dataset

def _get_dataloader(self, tiles_dataset: TilesDataset, stage: str, shuffle: bool,
**dataloader_kwargs: Any) -> DataLoader:
transformed_bag_dataset = self._load_dataset(tiles_dataset, stage=stage, shuffle=shuffle)
bag_dataset: BagDataset = transformed_bag_dataset.data # type: ignore
generator = bag_dataset.bag_sampler.generator
return DataLoader(transformed_bag_dataset, batch_size=self.batch_size,
collate_fn=multibag_collate, shuffle=shuffle, generator=generator,
pin_memory=False, # disable pinning as loaded data may already be on GPU
**dataloader_kwargs)

def train_dataloader(self) -> DataLoader:
return self._get_dataloader(self.train_dataset, 'train', shuffle=True)

def val_dataloader(self) -> DataLoader:
return self._get_dataloader(self.val_dataset, 'val', shuffle=True)

def test_dataloader(self) -> DataLoader:
return self._get_dataloader(self.test_dataset, 'test', shuffle=True)
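
For orientation, a minimal sketch of how a concrete subclass might be wired up. `MyTilesDataModule` and all paths are hypothetical, not part of this PR; the only required override is `get_splits()`.

# Minimal sketch; `MyTilesDataModule` and all paths are hypothetical.
from pathlib import Path
from typing import Tuple

from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, TilesDataModule
from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset

class MyTilesDataModule(TilesDataModule):
    def get_splits(self) -> Tuple[TilesDataset, TilesDataset, TilesDataset]:
        ...  # return (train, val, test) TilesDataset instances for a real dataset

data_module = MyTilesDataModule(root_path=Path("/tmp/datasets/my_tiles"),
                                batch_size=8,
                                seed=42,
                                cache_mode=CacheMode.DISK,
                                save_precache=True,
                                cache_dir=Path("/tmp/cache/my_tiles"))
data_module.prepare_data()  # pre-caches train/val/test once (local rank-0 process)
train_loader = data_module.train_dataloader()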
23 changes: 23 additions & 0 deletions InnerEye/ML/Histopathology/datamodules/panda_module.py
@@ -0,0 +1,23 @@
from typing import Tuple

from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
from InnerEye.ML.utils.split_dataset import DatasetSplits


class PandaTilesDataModule(TilesDataModule):
""" PandaTilesDataModule is the child class of TilesDataModule specific to PANDA dataset
Method get_splits() returns the train, val, test splits from the PANDA dataset
"""

def get_splits(self) -> Tuple[PandaTilesDataset, PandaTilesDataset, PandaTilesDataset]:
dataset = PandaTilesDataset(self.root_path)
splits = DatasetSplits.from_proportions(dataset.dataset_df.reset_index(),
proportion_train=.8,
proportion_test=.1,
proportion_val=.1,
subject_column=dataset.TILE_ID_COLUMN,
group_column=dataset.SLIDE_ID_COLUMN)
return (PandaTilesDataset(self.root_path, dataset_df=splits.train),
PandaTilesDataset(self.root_path, dataset_df=splits.val),
PandaTilesDataset(self.root_path, dataset_df=splits.test))
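
As a usage sketch (the dataset mount location comes from default_paths.py; all other values here are illustrative):

from pathlib import Path

from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule
from InnerEye.ML.Histopathology.datasets.default_paths import PANDA_TILES_DATASET_DIR

# Because the splits are grouped on SLIDE_ID_COLUMN, all tiles from one slide
# land in the same split.
panda_module = PandaTilesDataModule(root_path=Path(PANDA_TILES_DATASET_DIR),
                                    max_bag_size=500,
                                    batch_size=4,
                                    seed=0)
train_loader = panda_module.train_dataloader()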
33 changes: 33 additions & 0 deletions InnerEye/ML/Histopathology/datamodules/tcga_crck_module.py
@@ -0,0 +1,33 @@
from typing import Tuple, Any

from InnerEye.ML.Histopathology.datamodules.base_module import TilesDataModule
from InnerEye.ML.Histopathology.datasets.tcga_crck_tiles_dataset import TcgaCrck_TilesDataset
from InnerEye.ML.utils.split_dataset import DatasetSplits


class TcgaCrckTilesDataModule(TilesDataModule):
""" TcgaCrckTilesDataModule is the child class of TilesDataModule specific to TCGA-Crck dataset
Method get_splits() returns the train, val, test splits from the TCGA-Crck dataset
Methods train_dataloader(), val_dataloader() and test_dataloader() override the base class methods for bag loading
"""

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)

def get_splits(self) -> Tuple[TcgaCrck_TilesDataset, TcgaCrck_TilesDataset, TcgaCrck_TilesDataset]:
trainval_dataset = TcgaCrck_TilesDataset(self.root_path, train=True)
splits = DatasetSplits.from_proportions(trainval_dataset.dataset_df.reset_index(),
proportion_train=0.8,
proportion_test=0.0,
proportion_val=0.2,
subject_column=trainval_dataset.TILE_ID_COLUMN,
group_column=trainval_dataset.SLIDE_ID_COLUMN,
random_seed=5)

if self.number_of_cross_validation_splits > 1:
# Function get_k_fold_cross_validation_splits() will concatenate train and val splits
splits = splits.get_k_fold_cross_validation_splits(self.number_of_cross_validation_splits)[self.cross_validation_split_index]

return (TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.train),
TcgaCrck_TilesDataset(self.root_path, dataset_df=splits.val),
TcgaCrck_TilesDataset(self.root_path, train=False))
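
A sketch of how cross-validation might be driven from this module (the fold count and paths are illustrative):

from pathlib import Path

from InnerEye.ML.Histopathology.datamodules.tcga_crck_module import TcgaCrckTilesDataModule
from InnerEye.ML.Histopathology.datasets.default_paths import TCGA_CRCK_DATASET_DIR

# One data module per fold: each fold re-partitions the train+val pool, while
# the held-out test set (train=False) is identical across folds.
fold_modules = [TcgaCrckTilesDataModule(root_path=Path(TCGA_CRCK_DATASET_DIR),
                                        number_of_cross_validation_splits=5,
                                        cross_validation_split_index=fold)
                for fold in range(5)]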
107 changes: 107 additions & 0 deletions InnerEye/ML/Histopathology/datasets/base_dataset.py
@@ -0,0 +1,107 @@
from pathlib import Path
from typing import Any, Dict, Optional, Union

import numpy as np
import pandas as pd
import torch
from sklearn.utils.class_weight import compute_class_weight
from torch.utils.data import Dataset


class TilesDataset(Dataset):
"""Base class for datasets of WSI tiles, iterating dictionaries of image paths and metadata.

:param TILE_ID_COLUMN: CSV column name for tile ID.
:param SLIDE_ID_COLUMN: CSV column name for slide ID.
:param IMAGE_COLUMN: CSV column name for relative path to image file.
:param PATH_COLUMN: CSV column name for relative path to image file. Replicated to propagate the path to the batch.
:param LABEL_COLUMN: CSV column name for tile label.
:param SPLIT_COLUMN: CSV column name for train/test split (optional).
:param TILE_X_COLUMN: CSV column name for horizontal tile coordinate (optional).
:param TILE_Y_COLUMN: CSV column name for vertical tile coordinate (optional).
:param TRAIN_SPLIT_LABEL: Value used to indicate the training split in `SPLIT_COLUMN`.
:param TEST_SPLIT_LABEL: Value used to indicate the test split in `SPLIT_COLUMN`.
:param DEFAULT_CSV_FILENAME: Default name of the dataset CSV at the dataset root directory.
:param N_CLASSES: Number of classes indexed in `LABEL_COLUMN`.
"""
TILE_ID_COLUMN: str = 'tile_id'
SLIDE_ID_COLUMN: str = 'slide_id'
IMAGE_COLUMN: str = 'image'
PATH_COLUMN: str = 'image_path'
LABEL_COLUMN: str = 'label'
SPLIT_COLUMN: Optional[str] = 'split'
TILE_X_COLUMN: Optional[str] = 'tile_x'
TILE_Y_COLUMN: Optional[str] = 'tile_y'

TRAIN_SPLIT_LABEL: str = 'train'
TEST_SPLIT_LABEL: str = 'test'

DEFAULT_CSV_FILENAME: str = "dataset.csv"

N_CLASSES: int = 1 # binary classification by default

def __init__(self,
root: Union[str, Path],
dataset_csv: Optional[Union[str, Path]] = None,
dataset_df: Optional[pd.DataFrame] = None,
train: Optional[bool] = None) -> None:
"""
:param root: Root directory of the dataset.
:param dataset_csv: Full path to a dataset CSV file, containing at least
`TILE_ID_COLUMN`, `SLIDE_ID_COLUMN`, and `IMAGE_COLUMN`. If omitted, the CSV will be read
from `"{root}/{DEFAULT_CSV_FILENAME}"`.
:param dataset_df: A potentially pre-processed dataframe in the same format as would be read
from the dataset CSV file, e.g. after some filtering. If given, overrides `dataset_csv`.
:param train: If `True`, loads only the training split (resp. `False` for test split). By
default (`None`), loads the entire dataset as-is.
"""
if self.SPLIT_COLUMN is None and train is not None:
raise ValueError("Train/test split was specified but dataset has no split column")

self.root_dir = Path(root)

if dataset_df is not None:
self.dataset_csv = None
else:
self.dataset_csv = dataset_csv or self.root_dir / self.DEFAULT_CSV_FILENAME
dataset_df = pd.read_csv(self.dataset_csv)

columns = [self.SLIDE_ID_COLUMN, self.IMAGE_COLUMN, self.LABEL_COLUMN,
self.SPLIT_COLUMN, self.TILE_X_COLUMN, self.TILE_Y_COLUMN]
for column in columns:
if column is not None and column not in dataset_df.columns:
raise ValueError(f"Expected column '{column}' not found in the dataframe")

dataset_df = dataset_df.set_index(self.TILE_ID_COLUMN)
if train is None:
self.dataset_df = dataset_df
else:
split = self.TRAIN_SPLIT_LABEL if train else self.TEST_SPLIT_LABEL
self.dataset_df = dataset_df[dataset_df[self.SPLIT_COLUMN] == split]

def __len__(self) -> int:
return self.dataset_df.shape[0]

def __getitem__(self, index: int) -> Dict[str, Any]:
tile_id = self.dataset_df.index[index]
sample = {
self.TILE_ID_COLUMN: tile_id,
**self.dataset_df.loc[tile_id].to_dict()
}
sample[self.IMAGE_COLUMN] = str(self.root_dir / sample.pop(self.IMAGE_COLUMN))
# we're replicating this column because we want to propagate the path to the batch
sample[self.PATH_COLUMN] = sample[self.IMAGE_COLUMN]
return sample

@property
def slide_ids(self) -> pd.Series:
return self.dataset_df[self.SLIDE_ID_COLUMN]

def get_slide_labels(self) -> pd.Series:
return self.dataset_df.groupby(self.SLIDE_ID_COLUMN)[self.LABEL_COLUMN].agg(pd.Series.mode)

def get_class_weights(self) -> torch.Tensor:
slide_labels = self.get_slide_labels()
classes = np.unique(slide_labels)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=slide_labels)
return torch.as_tensor(class_weights)
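
To illustrate the intended extension point, a minimal hypothetical subclass with a custom CSV schema (the column values here are assumptions, not part of this PR); setting the optional columns to None skips their validation:

from InnerEye.ML.Histopathology.datasets.base_dataset import TilesDataset

class ToyTilesDataset(TilesDataset):
    SLIDE_ID_COLUMN = 'case_id'  # hypothetical CSV column name
    SPLIT_COLUMN = None          # this CSV has no train/test split column
    TILE_X_COLUMN = None         # ...and no tile coordinates
    TILE_Y_COLUMN = None
    N_CLASSES = 2

dataset = ToyTilesDataset("/tmp/datasets/toy_tiles")
sample = dataset[0]                    # dict of tile ID, image path, label, metadata
weights = dataset.get_class_weights()  # balanced weights from per-slide labels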
8 changes: 8 additions & 0 deletions InnerEye/ML/Histopathology/datasets/default_paths.py
@@ -0,0 +1,8 @@
PANDA_TILES_DATASET_ID = "PANDA_tiles"
TCGA_CRCK_DATASET_ID = "TCGA-CRCk"
TCGA_PRAD_DATASET_ID = "TCGA-PRAD"

DEFAULT_DATASET_LOCATION = "/tmp/datasets/"
PANDA_TILES_DATASET_DIR = DEFAULT_DATASET_LOCATION + PANDA_TILES_DATASET_ID
TCGA_CRCK_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_CRCK_DATASET_ID
TCGA_PRAD_DATASET_DIR = DEFAULT_DATASET_LOCATION + TCGA_PRAD_DATASET_ID