
Enable fine-tuning in Deepmil #650

Merged · 28 commits · Feb 9, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -48,6 +48,7 @@ jobs that run in AzureML.
- ([#635](https://github.com/microsoft/InnerEye-DeepLearning/pull/635)) Add tile selection and binary label for online evaluation of PANDA SSL
- ([#647](https://github.com/microsoft/InnerEye-DeepLearning/pull/647)) Add class-wise accuracy logging and confusion matrix to DeepMIL
- ([#653](https://github.com/microsoft/InnerEye-DeepLearning/pull/653)) Add dropout to DeepMIL and fix feature extractor setup.
+- ([#650](https://github.com/microsoft/InnerEye-DeepLearning/pull/650)) Enable fine-tuning in DeepMIL using PANDA as the classification task.

### Changed
- ([#588](https://github.com/microsoft/InnerEye-DeepLearning/pull/588)) Replace SciPy with PIL.PngImagePlugin.PngImageFile to load png files.
11 changes: 8 additions & 3 deletions InnerEye/ML/Histopathology/models/deepmil.py
@@ -13,7 +13,7 @@
import more_itertools as mi

from pytorch_lightning import LightningModule
-from torch import Tensor, argmax, mode, nn, no_grad, optim, round
+from torch import Tensor, argmax, mode, nn, set_grad_enabled, optim, round
from torchmetrics import AUROC, F1, Accuracy, Precision, Recall, ConfusionMatrix

from InnerEye.Common import fixed_paths
@@ -55,7 +55,8 @@ def __init__(self,
slide_dataset: SlidesDataset = None,
tile_size: int = 224,
level: int = 1,
-class_names: Optional[List[str]] = None) -> None:
+class_names: Optional[List[str]] = None,
+is_finetune: bool = False) -> None:
"""
:param label_column: Label key for input batch dictionary.
:param n_classes: Number of output classes for MIL prediction. For binary classification, n_classes should be set to 1.
@@ -75,6 +76,7 @@
:param tile_size: The size of each tile (default=224).
:param level: The downsampling level (e.g. 0, 1, 2) of the tiles if available (default=1).
:param class_names: The names of the classes if available (default=None).
+:param is_finetune: Whether to fine-tune the encoder (default=False).
"""
super().__init__()

@@ -115,6 +117,9 @@

self.verbose = verbose

+# Finetuning attributes
+self.is_finetune = is_finetune

self.aggregation_fn, self.num_pooling = self.get_pooling()
self.classifier_fn = self.get_classifier()
self.loss_fn = self.get_loss()
@@ -196,7 +201,7 @@ def log_metrics(self,
log_on_epoch(self, f'{stage}/{metric_name}', metric_object)

def forward(self, instances: Tensor) -> Tuple[Tensor, Tensor]: # type: ignore
-with no_grad():
+with set_grad_enabled(self.is_finetune):
instance_features = self.encoder(instances) # N X L x 1 x 1
attentions, bag_features = self.aggregation_fn(instance_features) # K x N | K x L
bag_features = bag_features.view(-1, self.num_encoding * self.pool_out_dim)
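Note on the hunk above: `no_grad()` unconditionally blocked gradients through the encoder, while `set_grad_enabled(self.is_finetune)` makes gradient tracking conditional on the new flag. A minimal standalone sketch of that behaviour (illustrative names, not repo code):

import torch
from torch import nn, set_grad_enabled

encoder = nn.Linear(8, 4)        # stand-in for the tile encoder
instances = torch.randn(2, 8)

for is_finetune in (False, True):
    with set_grad_enabled(is_finetune):
        features = encoder(instances)
    # Gradients can flow back into the encoder only when fine-tuning is on.
    print(is_finetune, features.requires_grad)  # False False, then True True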
6 changes: 5 additions & 1 deletion InnerEye/ML/Histopathology/models/encoders.py
@@ -139,5 +139,9 @@ class HistoSSLEncoder(TileEncoder):

def _get_encoder(self) -> Tuple[Callable, int]:
resnet18_model = resnet18(pretrained=False)
+num_features = resnet18_model.fc.in_features
histossl_encoder = load_weights_to_model(self.WEIGHTS_URL, resnet18_model)
-return setup_feature_extractor(histossl_encoder, self.input_dim)  # type: ignore
+histossl_encoder.fc = torch.nn.Sequential()
+for param in histossl_encoder.parameters():
+    param.requires_grad = False
+return histossl_encoder, num_features  # type: ignore
Review comment (Member): Minor suggestion to revert these changes in _get_encoder(), as setup_feature_extractor() now implements the same behaviour.
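The pattern in this diff, reading off the backbone's feature dimension, replacing the classification head with an identity, and freezing every parameter, is the standard way to turn a torchvision ResNet into a fixed feature extractor. A hedged sketch of that pattern (the helper name below is hypothetical; per the review comment, the repo's setup_feature_extractor() is expected to behave equivalently):

import torch
from torchvision.models import resnet18

def make_frozen_feature_extractor():
    model = resnet18(pretrained=False)
    num_features = model.fc.in_features  # 512 for ResNet-18
    model.fc = torch.nn.Sequential()     # identity head: forward() now returns features
    for param in model.parameters():
        param.requires_grad = False      # freeze the backbone
    return model, num_features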

14 changes: 11 additions & 3 deletions InnerEye/ML/configs/histo_configs/classification/BaseMIL.py
@@ -27,6 +27,8 @@
class BaseMIL(LightningContainer):
# Model parameters:
pooling_type: str = param.String(doc="Name of the pooling layer class to use.")
+is_finetune: bool = param.Boolean(doc="Whether to fine-tune the encoder. Options:"
+                                      "`False` (default), or `True`.")
dropout_rate: Optional[float] = param.Number(None, bounds=(0, 1), doc="Pre-classifier dropout rate.")
# l_rate, weight_decay, adam_betas are already declared in OptimizerParams superclass

@@ -62,8 +64,8 @@ def setup(self) -> None:
raise NotImplementedError("InnerEyeSSLEncoder requires a pre-trained checkpoint.")

self.encoder = self.get_encoder()
-self.encoder.cuda()
-self.encoder.eval()
+if not self.is_finetune:
+    self.encoder.eval()

def get_encoder(self) -> TileEncoder:
if self.encoder_type == ImageNetEncoder.__name__:
@@ -95,7 +97,13 @@ def create_model(self) -> DeepMILModule:
self.data_module = self.get_data_module()
# Encoding is done in the datamodule, so here we provide instead a dummy
# no-op IdentityEncoder to be used inside the model
-return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
+if self.is_finetune:
+    self.model_encoder = self.encoder
+    for params in self.model_encoder.parameters():
+        params.requires_grad = True
+else:
+    self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
+return DeepMILModule(encoder=self.model_encoder,
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
pooling_layer=self.get_pooling_layer(),
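The encoder selection in create_model() follows a simple rule: when fine-tuning, the real encoder moves inside the module and is unfrozen; otherwise the tiles were already encoded in the datamodule, so the module only needs a pass-through matching the cached feature dimension. A self-contained sketch under those assumptions (this IdentityEncoder is a minimal stand-in for the repo class of the same name):

import torch
from torch import nn

class IdentityEncoder(nn.Module):
    """Minimal stand-in: forwards pre-computed tile features unchanged."""
    def __init__(self, input_dim: tuple) -> None:
        super().__init__()
        self.num_encoding = input_dim[0]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x

def pick_model_encoder(encoder: nn.Module, num_encoding: int, is_finetune: bool) -> nn.Module:
    if is_finetune:
        for param in encoder.parameters():
            param.requires_grad = True   # unfreeze for end-to-end training
        return encoder
    return IdentityEncoder(input_dim=(num_encoding,))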
37 changes: 27 additions & 10 deletions InnerEye/ML/configs/histo_configs/classification/DeepSMILEPanda.py
@@ -14,6 +14,7 @@
from health_azure.utils import get_workspace, is_running_in_azure_ml
from health_ml.networks.layers.attention_layers import GatedAttentionLayer
from InnerEye.Common import fixed_paths
+from InnerEye.ML.Histopathology.datamodules.base_module import CacheMode, CacheLocation
from InnerEye.ML.Histopathology.datamodules.panda_module import PandaTilesDataModule
from InnerEye.ML.Histopathology.datasets.panda_tiles_dataset import PandaTilesDataset
from InnerEye.ML.common import get_best_checkpoint_path
@@ -35,12 +36,19 @@


class DeepSMILEPanda(BaseMIL):
"""`is_finetune` sets the fine-tuning mode. If this is set, setting cache_mode=CacheMode.NONE takes ~30 min/epoch and
cache_mode=CacheMode.MEMORY, precache_location=CacheLocation.CPU takes ~[5-10] min/epoch.
Fine-tuning with caching completes using batch_size=4, max_bag_size=1000, num_epochs=20, max_num_gpus=1 on PANDA.
"""
def __init__(self, **kwargs: Any) -> None:
default_kwargs = dict(
# declared in BaseMIL:
pooling_type=GatedAttentionLayer.__name__,
# average number of tiles is 56 for PANDA
encoding_chunk_size=60,
+cache_mode=CacheMode.MEMORY,
+precache_location=CacheLocation.CPU,
+is_finetune=False,

# declared in DatasetParams:
local_dataset=Path("/tmp/datasets/PANDA_tiles"),
@@ -98,17 +106,19 @@ def setup(self) -> None:
os.chdir(fixed_paths.repository_parent_directory())
self.downloader.download_checkpoint_if_necessary()
self.encoder = self.get_encoder()
-self.encoder.cuda()
-self.encoder.eval()
+if not self.is_finetune:
+    self.encoder.eval()

def get_data_module(self) -> PandaTilesDataModule:
image_key = PandaTilesDataset.IMAGE_COLUMN
-transform = Compose(
-    [
-        LoadTilesBatchd(image_key, progress=True),
-        EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size),
-    ]
-)
+if self.is_finetune:
+    transform = Compose([LoadTilesBatchd(image_key, progress=True)])
+else:
+    transform = Compose([
+        LoadTilesBatchd(image_key, progress=True),
+        EncodeTilesBatchd(image_key, self.encoder, chunk_size=self.encoding_chunk_size)
+    ])

return PandaTilesDataModule(
root_path=self.local_dataset,
max_bag_size=self.max_bag_size,
@@ -128,7 +138,13 @@ def create_model(self) -> DeepMILModule:
self.slide_dataset = self.get_slide_dataset()
self.level = 1
self.class_names = ["ISUP 0", "ISUP 1", "ISUP 2", "ISUP 3", "ISUP 4", "ISUP 5"]
-return DeepMILModule(encoder=IdentityEncoder(input_dim=(self.encoder.num_encoding,)),
+if self.is_finetune:
+    self.model_encoder = self.encoder
+    for params in self.model_encoder.parameters():
+        params.requires_grad = True
+else:
+    self.model_encoder = IdentityEncoder(input_dim=(self.encoder.num_encoding,))
+return DeepMILModule(encoder=self.model_encoder,
label_column=self.data_module.train_dataset.LABEL_COLUMN,
n_classes=self.data_module.train_dataset.N_CLASSES,
pooling_layer=self.get_pooling_layer(),
@@ -139,7 +155,8 @@ def create_model(self) -> DeepMILModule:
slide_dataset=self.get_slide_dataset(),
tile_size=self.tile_size,
level=self.level,
-class_names=self.class_names)
+class_names=self.class_names,
+is_finetune=self.is_finetune)

def get_slide_dataset(self) -> PandaDataset:
return PandaDataset(root=self.extra_local_dataset_paths[0]) # type: ignore
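Taken together with the timings in the class docstring, a fine-tuning run on PANDA combines the new flag with in-memory caching. A hypothetical configuration sketch (the parameter names appear in this PR; that they can all be passed as keyword overrides is an assumption):

container = DeepSMILEPanda(
    is_finetune=True,                     # unfreeze the encoder end-to-end
    cache_mode=CacheMode.MEMORY,          # ~5-10 min/epoch per the docstring
    precache_location=CacheLocation.CPU,
    batch_size=4,
    max_bag_size=1000,
    num_epochs=20,
    max_num_gpus=1,
)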