This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Enable fine-tuning in Deepmil #650

Merged
merged 28 commits on Feb 9, 2022
Changes from 16 commits
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -47,6 +47,7 @@ jobs that run in AzureML.
- ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs
- ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task

### Changed
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
11 changes: 9 additions & 2 deletions InnerEye/ML/Histopathology/models/deepmil.py
@@ -54,7 +54,8 @@ def __init__(self,
slide_dataset: SlidesDataset = None,
tile_size: int = 224,
level: int = 1,
class_names: Optional[List[str]] = None) -> None:
class_names: Optional[List[str]] = None,
is_finetune: Optional[bool] = False) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be set to 1.
Expand All @@ -73,6 +74,7 @@ def __init__(self,
:param tile_size: The size of each tile (default=224).
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
:param class_names: The names of the classes if available (default=None).
:param is_finetune: Boolean value to enable/disable finetuning (default=False).
"""
super().__init__()

@@ -112,6 +114,8 @@ def __init__(self,

self.verbose = verbose

self.is_finetune = is_finetune

self.aggregation_fn, self.num_pooling = self.get_pooling()
self.classifier_fn = self.get_classifier()
self.loss_fn = self.get_loss()
@@ -187,8 +191,11 @@ def log_metrics(self,
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)

def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
with no_grad():
if self.is_finetune:
H = self.encoder(images) # N X L x 1 x 1
else:
with no_grad():
H = self.encoder(images) # N X L x 1 x 1
A, M = self.aggregation_fn(H) # A: K x N | M: K x L
M = M.view(-1, self.num_encoding * self.pool_out_dim)
Y_prob = self.classifier_fn(M)
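The core of the forward() change above is that the encoder runs inside a no_grad() block only when it is frozen, and with gradients when is_finetune is set. A minimal self-contained sketch of that gating pattern (DummyEncoder and encode are illustrative names, not part of the PR):

```python
# Minimal sketch of the gradient-gating pattern in forward() above.
# DummyEncoder and encode are illustrative names, not part of the PR.
import torch
from torch import Tensor, nn, no_grad


class DummyEncoder(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x: Tensor) -> Tensor:
        return self.conv(x)


def encode(encoder: nn.Module, images: Tensor, is_finetune: bool) -> Tensor:
    """Run the encoder with gradients only when fine-tuning it."""
    if is_finetune:
        return encoder(images)  # gradients flow; encoder weights can update
    with no_grad():
        return encoder(images)  # frozen encoder: no autograd graph is built


encoder = DummyEncoder()
images = torch.randn(4, 3, 32, 32)
assert not encode(encoder, images, is_finetune=False).requires_grad
assert encode(encoder, images, is_finetune=True).requires_grad
```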
4 changes: 3 additions & 1 deletion InnerEye/ML/Histopathology/models/encoders.py
@@ -140,4 +140,6 @@ class HistoSSLEncoder(TileEncoder):
def _get_encoder(self) -> Tuple[Callable, int]:
resnet18_model = resnet18(pretrained=False)
histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model)
return setup_feature_extractor(histossl_encoder, self.input_dim) # type: ignore
histossl_encoder.fc = torch.nn.Sequential()
num_features = 512
return histossl_encoder, num_features # type: ignore
Review comment (Member): Minor suggestion to revert these changes in _get_encoder(), as setup_feature_extractor() now implements the same behaviour.
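For context on that suggestion: the change in the diff strips ResNet-18's classification head so the encoder emits pooled features. A minimal sketch of what that head-stripping amounts to, using only standard torchvision calls (the exact behaviour of setup_feature_extractor() is not shown in this diff and is taken on the reviewer's word):

```python
# Sketch of the head-stripping above: replacing ResNet-18's `fc` with an empty
# nn.Sequential() makes the forward pass return the 512-dim pooled features.
import torch
from torchvision.models import resnet18

model = resnet18(pretrained=False)  # pretrained flag as used in the PR-era API
model.fc = torch.nn.Sequential()    # empty Sequential acts as an identity
features = model(torch.randn(2, 3, 224, 224))
assert features.shape == (2, 512)   # avg-pooled features, classifier removed
```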

28 changes: 20 additions & 8 deletions InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py
@@ -14,6 +14,7 @@
from health_azure.utils import get_workspace, is_running_in_azure_ml
from health_ml.networks.layers.attention_layers import GatedAttentionLayer
from InnerEye.Common import fixed_paths
from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation
from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
from InnerEye.ML.common import get_best_checkpoint_path
@@ -41,6 +42,9 @@ def __init__(self, **kwargs: Any) -> None:
pooling_type=GatedAttentionLayer.__name__,
# average number of tiles is 56 for PANDA
encoding_chunk_size=60,
cache_mode=CacheMode.MEMORY,
precache_location=CacheLocation.SAME,
batch_size=8,

# declared in DatasetParams:
local_dataset=Path("/tmp/datasets/PANDA_tiles"),
@@ -78,6 +82,7 @@ def __init__(self, **kwargs: Any) -> None:
mode="max",
)
self.callbacks = best_checkpoint_callback
self.is_finetune = False

@property
def cache_dir(self) -> Path:
@@ -103,12 +108,14 @@ def setup(self) -> None:

def get_data_module(self) -> PandaTilesDataModule:
image_key = PandaTilesDataset.IMAGE_COLUMN
transform = Compose(
[
LoadTilesBatchd(image_key, progress=True),
EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size),
]
)
if self.is_finetune:
transform = Compose([LoadTilesBatchd(image_key, progress=True)])
else:
transform = Compose([
LoadTilesBatchd(image_key, progress=True),
EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size)
])

return PandaTilesDataModule(
root_path=self.local_dataset,
max_bag_size=self.max_bag_size,
@@ -128,7 +135,11 @@ def create_model(self) -> DeepMILModule:
self.slide_dataset = self.get_slide_dataset()
self.level = 1
self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"]
return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
if self.is_finetune:
self.model_encoder = self.encoder
else:
self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
return DeepMILModule(encoder=self.model_encoder,
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
pooling_layer=self.get_pooling_layer(),
@@ -139,7 +150,8 @@ def create_model(self) -> DeepMILModule:
slide_dataset=self.get_slide_dataset(),
tile_size=self.tile_size,
level=self.level,
class_names=self.class_names)
class_names=self.class_names,
is_finetune=self.is_finetune)

def get_slide_dataset(self) -> PandaDataset:
return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore
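Taken together, the two changes to this config mean is_finetune must toggle two things consistently: the data-module transform (raw tiles vs. tiles pre-encoded by EncodeTilesBatchd) and the encoder handed to DeepMILModule (the real encoder vs. an IdentityEncoder that passes embeddings through). A sketch of that coupling; PassthroughEncoder and select_encoder are hypothetical names standing in for the InnerEye classes shown above:

```python
# Illustrative sketch of the coupling that is_finetune controls. The names
# PassthroughEncoder and select_encoder are hypothetical; InnerEye's actual
# no-op encoder is IdentityEncoder, shown in the diff above.
from torch import Tensor, nn


class PassthroughEncoder(nn.Module):
    """No-op stand-in: in the frozen path, the 'images' reaching the model
    are already embeddings, computed by EncodeTilesBatchd in the transform."""
    def forward(self, x: Tensor) -> Tensor:
        return x


def select_encoder(encoder: nn.Module, is_finetune: bool) -> nn.Module:
    # Fine-tuning: raw tiles reach the model, so it needs the real encoder.
    # Frozen: tiles were pre-encoded in the data module, so a no-op suffices.
    return encoder if is_finetune else PassthroughEncoder()
```

If the two sides disagree, the module either re-encodes already-encoded embeddings or receives raw tiles that an identity encoder cannot handle.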
23 changes: 20 additions & 3 deletions Tests/SSL/test_ssl_containers.py
@@ -634,7 +634,24 @@ def test_simclr_dataloader_type() -> None:
""" This test checks if the transform pipeline of a SSL job can handle different
data types coming from the dataloader.
"""
def check_types_in_dataloader(dataloader: CombinedLoader) -> None:
# TODO: Once the PyTorch Lightning bug is fixed, the following test can be removed.
# The training and val loaders will then both be CombinedLoaders.
def check_types_in_train_dataloader(dataloader: dict) -> None:
for i, batch in enumerate(dataloader[SSLDataModuleType.ENCODER]):
assert isinstance(batch[0][0], torch.Tensor)
assert isinstance(batch[0][1], torch.Tensor)
assert isinstance(batch[1], torch.Tensor)
if i == 1:
break

for i, batch in enumerate(dataloader[SSLDataModuleType.LINEAR_HEAD]):
assert isinstance(batch[0], torch.Tensor)
assert isinstance(batch[1], torch.Tensor)
assert isinstance(batch[2], torch.Tensor)
if i == 1:
break

def check_types_in_val_dataloader(dataloader: CombinedLoader) -> None:
for i, batch in enumerate(dataloader):
assert isinstance(batch[SSLDataModuleType.ENCODER][0][0], torch.Tensor)
assert isinstance(batch[SSLDataModuleType.ENCODER][0][1], torch.Tensor)
@@ -646,8 +663,8 @@ def check_types_in_dataloader(dataloader: CombinedLoader) -> None:
break

def check_types_in_train_and_val(data: CombinedDataModule) -> None:
check_types_in_dataloader(data.train_dataloader())
check_types_in_dataloader(data.val_dataloader())
check_types_in_train_dataloader(data.train_dataloader())
check_types_in_val_dataloader(data.val_dataloader())

container = DummySimCLR()
container.setup()
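The split into two checkers reflects the PyTorch Lightning behaviour described in the TODO above: train_dataloader() returns a plain dict of loaders, while val_dataloader() returns a CombinedLoader that yields one combined batch per step. A sketch of the difference, assuming the PL 1.x API this repo used at the time (string keys stand in for SSLDataModuleType members):

```python
# Sketch of the two dataloader shapes the test distinguishes. Assumes PL 1.x,
# where CombinedLoader lived in pytorch_lightning.trainer.supporters; the
# string keys here are illustrative, not SSLDataModuleType members.
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.trainer.supporters import CombinedLoader

loaders = {
    "encoder": DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=4),
    "linear_head": DataLoader(TensorDataset(torch.randn(8, 3)), batch_size=4),
}

# Train-style: a plain dict -- each loader is iterated on its own.
for batch in loaders["encoder"]:
    assert isinstance(batch[0], torch.Tensor)

# Val-style: CombinedLoader yields one dict per step, one batch per key.
for batch in CombinedLoader(loaders):
    assert isinstance(batch["encoder"][0], torch.Tensor)
    assert isinstance(batch["linear_head"][0], torch.Tensor)
```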