This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Enable fine-tuning in Deepmil #650

Merged · 28 commits · Feb 9, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -47,6 +47,7 @@ jobs that run in AzureML.
 - ([#634](https://github.com/microsoft/InnerEye-DeepLearning/pull/634)) Add WSI heatmaps and thumbnails to standard test outputs
 - ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL
 - ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
+- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task

 ### Changed
 - ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
10 changes: 7 additions & 3 deletions InnerEye/ML/Histopathology/models/deepmil.py
@@ -13,7 +13,7 @@
 import more_itertools as mi

 from pytorch_lightning import LightningModule
-from torch import Tensor, argmax, mode, nn, no_grad, optim, round
+from torch import Tensor, argmax, mode, nn, set_grad_enabled, optim, round
 from torchmetrics import AUROC, F1, Accuracy, Precision, Recall, ConfusionMatrix

 from InnerEye.Common import fixed_paths
@@ -54,7 +54,8 @@ def __init__(self,
                  slide_dataset: SlidesDataset = None,
                  tile_size: int = 224,
                  level: int = 1,
-                 class_names: Optional[List[str]] = None) -> None:
+                 class_names: Optional[List[str]] = None,
+                 is_finetune: bool = False) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be set to 1.
Expand All @@ -73,6 +74,7 @@ def __init__(self,
:param tile_size: The size of each tile (default=224).
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
:param class_names: The names of the classes if available (default=None).
:param is_finetune: Boolean value to enable/disable finetuning (default=False).
"""
super().__init__()

@@ -112,6 +114,8 @@

         self.verbose = verbose

+        self.is_finetune = is_finetune
+
         self.aggregation_fn, self.num_pooling = self.get_pooling()
         self.classifier_fn = self.get_classifier()
         self.loss_fn = self.get_loss()
@@ -187,7 +191,7 @@ def log_metrics(self,
             log_on_epoch(self, f'{stage}/{metric_name}', metric_object)

     def forward(self, images: Tensor) -> Tuple[Tensor, Tensor]:  # type: ignore
-        with no_grad():
+        with set_grad_enabled(self.is_finetune):
             H = self.encoder(images)  # N X L x 1 x 1
             A, M = self.aggregation_fn(H)  # A: K x N | M: K x L
             M = M.view(-1, self.num_encoding * self.pool_out_dim)
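The switch from `no_grad()` to `set_grad_enabled(self.is_finetune)` is what lets gradients flow through the encoder only when fine-tuning is requested. A minimal sketch of the behaviour, independent of this repository (the `encoder` below is a hypothetical stand-in for the tile encoder):

```python
import torch
from torch import nn, set_grad_enabled

encoder = nn.Linear(4, 2)   # hypothetical stand-in for the tile encoder
images = torch.randn(3, 4)

for is_finetune in (False, True):
    with set_grad_enabled(is_finetune):
        features = encoder(images)
    # requires_grad on the output tells us whether backprop can reach the encoder.
    print(is_finetune, features.requires_grad)  # False False, then True True
```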
4 changes: 3 additions & 1 deletion InnerEye/ML/Histopathology/models/encoders.py
@@ -140,4 +140,6 @@ class HistoSSLEncoder(TileEncoder):
     def _get_encoder(self) -> Tuple[Callable, int]:
         resnet18_model = resnet18(pretrained=False)
         histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model)
-        return setup_feature_extractor(histossl_encoder, self.input_dim)  # type: ignore
+        histossl_encoder.fc = torch.nn.Sequential()
+        num_features = resnet18_model.fc.in_features
+        return histossl_encoder, num_features  # type: ignore
Review comment (team member): Minor suggestion to revert these changes in `_get_encoder()`, as `setup_feature_extractor()` now implements the same behaviour.
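The new `_get_encoder()` body follows the standard torchvision pattern for turning a classifier into a feature extractor: read the head's `in_features`, then replace the head with an empty `nn.Sequential()`, which passes its input through unchanged. A minimal sketch of that pattern with a plain ResNet-18, outside the InnerEye wrappers (note the feature count is read before the head is swapped out):

```python
import torch
from torchvision.models import resnet18

model = resnet18(pretrained=False)
num_features = model.fc.in_features   # 512 for ResNet-18; read *before* replacing the head
model.fc = torch.nn.Sequential()      # empty Sequential acts as an identity, dropping the classifier

with torch.no_grad():
    features = model(torch.randn(2, 3, 224, 224))
print(features.shape)                 # torch.Size([2, 512])
```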

36 changes: 26 additions & 10 deletions InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py
@@ -14,6 +14,7 @@
 from health_azure.utils import get_workspace, is_running_in_azure_ml
 from health_ml.networks.layers.attention_layers import GatedAttentionLayer
 from InnerEye.Common import fixed_paths
+from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation
 from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule
 from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
 from InnerEye.ML.common import get_best_checkpoint_path
@@ -35,12 +36,18 @@


 class DeepSMILEPanda(BaseMIL):
+    """`is_finetune` sets the fine-tuning mode. If it is set, cache_mode=CacheMode.NONE takes ~30 min/epoch and
+    cache_mode=CacheMode.MEMORY, precache_location=CacheLocation.SAME takes ~5 min/epoch. Fine-tuning is tested
+    with batch size 8 on PANDA.
+    """
     def __init__(self, **kwargs: Any) -> None:
         default_kwargs = dict(
             # declared in BaseMIL:
             pooling_type=GatedAttentionLayer.__name__,
+            # average number of tiles is 56 for PANDA
+            encoding_chunk_size=60,
+            cache_mode=CacheMode.MEMORY,
+            precache_location=CacheLocation.NONE,
+            max_bag_size=10,

             # declared in DatasetParams:
             local_dataset=Path("/tmp/datasets/PANDA_tiles"),
@@ -78,6 +85,7 @@ def __init__(self, **kwargs: Any) -> None:
mode="max",
)
self.callbacks = best_checkpoint_callback
self.is_finetune = True

     @property
     def cache_dir(self) -> Path:

@@ -98,17 +106,20 @@ def setup(self) -> None:
         os.chdir(fixed_paths.repository_parent_directory())
         self.downloader.download_checkpoint_if_necessary()
         self.encoder = self.get_encoder()
-        self.encoder.cuda()
-        self.encoder.eval()
+        # self.encoder.cuda()
+        if not self.is_finetune:
+            self.encoder.eval()
     def get_data_module(self) -> PandaTilesDataModule:
         image_key = PandaTilesDataset.IMAGE_COLUMN
-        transform = Compose(
-            [
-                LoadTilesBatchd(image_key, progress=True),
-                EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size),
-            ]
-        )
+        if self.is_finetune:
+            transform = Compose([LoadTilesBatchd(image_key, progress=True)])
+        else:
+            transform = Compose([
+                LoadTilesBatchd(image_key, progress=True),
+                EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size)
+            ])

         return PandaTilesDataModule(
             root_path=self.local_dataset,
             max_bag_size=self.max_bag_size,
@@ -128,7 +139,11 @@ def create_model(self) -> DeepMILModule:
         self.slide_dataset = self.get_slide_dataset()
         self.level = 1
         self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"]
-        return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
+        if self.is_finetune:
+            self.model_encoder = self.encoder
+        else:
+            self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
+        return DeepMILModule(encoder=self.model_encoder,
                              label_column=self.data_module.train_dataset.LABEL_COLUMN,
                              n_classes=self.data_module.train_dataset.N_CLASSES,
                              pooling_layer=self.get_pooling_layer(),
@@ -139,7 +154,8 @@
                              slide_dataset=self.get_slide_dataset(),
                              tile_size=self.tile_size,
                              level=self.level,
-                             class_names=self.class_names)
+                             class_names=self.class_names,
+                             is_finetune=self.is_finetune)

     def get_slide_dataset(self) -> PandaDataset:
         return PandaDataset(root=self.extra_local_dataset_paths[0])  # type: ignore
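The key design choice in `create_model()` is which encoder lives inside the MIL module: when fine-tuning, the real (trainable) encoder is moved into the model so the backward pass reaches it; otherwise tiles are pre-encoded in the data pipeline and the in-model encoder only passes features through. A rough sketch of the idea in plain PyTorch, with `nn.Identity()` standing in for InnerEye's `IdentityEncoder` and a toy CNN standing in for the tile encoder:

```python
import torch
from torch import nn

# Toy stand-ins: `cnn` plays the tile encoder; nn.Identity() plays IdentityEncoder.
cnn = nn.Sequential(nn.Conv2d(3, 8, 3), nn.AdaptiveAvgPool2d(1), nn.Flatten())

is_finetune = True
if is_finetune:
    # Fine-tuning: the MIL module gets raw tiles and runs (and updates) the CNN itself.
    model_encoder: nn.Module = cnn
    bag = torch.randn(4, 3, 224, 224)   # a bag of N raw tiles
else:
    # Frozen path: tiles were encoded once in the data pipeline,
    # so the in-model "encoder" only passes the features through.
    model_encoder = nn.Identity()
    bag = torch.randn(4, 8)             # N precomputed embeddings

features = model_encoder(bag)           # 4 x 8 feature matrix in both cases
```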
23 changes: 20 additions & 3 deletions Tests/SSL/test_ssl_containers.py
@@ -634,7 +634,24 @@ def test_simclr_dataloader_type() -> None:
""" This test checks if the transform pipeline of a SSL job can handle different
data types coming from the dataloader.
"""
def check_types_in_dataloader(dataloader: CombinedLoader) -> None:
# TODO: Once the pytorch lightning bug is fixed the following test can be removed.
# The training and val loader will be both CombinedLoaders
def check_types_in_train_dataloader(dataloader: dict) -> None:
for i, batch in enumerate(dataloader[SSLDataModuleType.ENCODER]):
assert isinstance(batch[0][0], torch.Tensor)
assert isinstance(batch[0][1], torch.Tensor)
assert isinstance(batch[1], torch.Tensor)
if i == 1:
break

for i, batch in enumerate(dataloader[SSLDataModuleType.LINEAR_HEAD]):
assert isinstance(batch[0], torch.Tensor)
assert isinstance(batch[1], torch.Tensor)
assert isinstance(batch[2], torch.Tensor)
if i == 1:
break
vale-salvatelli marked this conversation as resolved.
Show resolved Hide resolved

def check_types_in_val_dataloader(dataloader: CombinedLoader) -> None:
for i, batch in enumerate(dataloader):
assert isinstance(batch[SSLDataModuleType.ENCODER][0][0], torch.Tensor)
assert isinstance(batch[SSLDataModuleType.ENCODER][0][1], torch.Tensor)
@@ -646,8 +663,8 @@ def check_types_in_dataloader(dataloader: CombinedLoader) -> None:
                 break

     def check_types_in_train_and_val(data: CombinedDataModule) -> None:
-        check_types_in_dataloader(data.train_dataloader())
-        check_types_in_dataloader(data.val_dataloader())
+        check_types_in_train_dataloader(data.train_dataloader())
+        check_types_in_val_dataloader(data.val_dataloader())

     container = DummySimCLR()
     container.setup()
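For readers unfamiliar with `CombinedLoader`: the split into two checkers exists because, in the PyTorch Lightning version used here, `train_dataloader()` hands back the raw dict of loaders while `val_dataloader()` yields a `CombinedLoader`, so the two need different type checks. A minimal sketch of how a `CombinedLoader` batch is structured, assuming the PL 1.x import path (the string keys below are placeholders for the `SSLDataModuleType` values):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning.trainer.supporters import CombinedLoader  # PL 1.x location

encoder_loader = DataLoader(TensorDataset(torch.randn(4, 3), torch.arange(4)), batch_size=2)
linear_head_loader = DataLoader(TensorDataset(torch.randn(4, 3), torch.arange(4)), batch_size=2)

# Iterating the combined loader yields one dict per step, keyed like the input mapping.
combined = CombinedLoader(
    {"encoder": encoder_loader, "linear_head": linear_head_loader},
    mode="max_size_cycle",
)
for batch in combined:
    assert isinstance(batch["encoder"][0], torch.Tensor)
    assert isinstance(batch["linear_head"][0], torch.Tensor)
```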