Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Enable custom SSL encoder checkpoint #562

Merged
merged 8 commits into from
Aug 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/cpath-pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,22 @@ jobs:
cd ${{ env.folder }}
make smoke_test_tilespandaimagenetmil_aml

smoke_test_tcgacrcksslmil_aml:
runs-on: ubuntu-20.04
steps:
- uses: actions/checkout@v3
with:
lfs: true

- name: Set up smoke test environment
id: setup-sslmil-smoke-test-environment
uses: ./.github/actions/prepare_smoke_test_environment

- name: smoke test
run: |
cd ${{ env.folder }}
make smoke_test_tcgacrcksslmil_aml

smoke_test_crck_simclr_aml:
runs-on: ubuntu-20.04
steps:
Expand Down
7 changes: 3 additions & 4 deletions hi-ml-cpath/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ pytest:
pytest_coverage:
pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc

SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606

# Run regression tests and compare performance
define BASE_CPATH_RUNNER_COMMAND
cd ../ ; \
Expand All @@ -92,7 +94,7 @@ define DEEPSMILEPANDATILES_ARGS
endef

define TCGACRCKSSLMIL_ARGS
--model=health_cpath.TcgaCrckSSLMIL
--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=${SSL_CKPT_RUN_ID_CRCK}
endef

define CRCKSIMCLR_ARGS
Expand Down Expand Up @@ -193,9 +195,6 @@ smoke_test_tilespandaimagenetmil_aml:
{ ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
${DEEPSMILEPANDATILES_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};}

# Note: this test doesn't currently run in the hi-ml Workspace, since the checkpoint run
# innereye_ssl_checkpoint_crck_4ws specified in run_ids does not exist there. Once we can specify
# alternative checkpoints, this can be run with any Workspace.
# The following test takes about 30 seconds
smoke_test_tcgacrcksslmil_local:
{ ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \
Expand Down
11 changes: 5 additions & 6 deletions hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams):
"generating outputs.")
maximise_primary_metric: bool = param.Boolean(True, doc="Whether the primary validation metric should be "
"maximised (otherwise minimised).")
ssl_checkpoint_run_id: str = param.String(default="", doc="Optional run id from which to load checkpoint if "
"using SSLEncoder")

def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
Expand All @@ -82,9 +84,6 @@ def __init__(self, **kwargs: Any) -> None:
def cache_dir(self) -> Path:
return Path(f"/tmp/himl_cache/{self.__class__.__name__}-{self.encoder_type}/")

def setup(self) -> None:
self.ssl_ckpt_run_id = ""

def get_test_plot_options(self) -> Set[PlotOption]:
options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX}
if self.num_top_slides > 0:
Expand Down Expand Up @@ -206,7 +205,7 @@ def get_dataloader_kwargs(self) -> dict:

def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]:
if self.is_caching:
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_ckpt_run_id,
encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_checkpoint_run_id,
self.outputs_folder)
transform = Compose([
LoadTilesBatchd(image_key, progress=True),
Expand All @@ -226,7 +225,7 @@ def create_model(self) -> TilesDeepMILModule:
class_weights=self.data_module.class_weights,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
Expand Down Expand Up @@ -266,7 +265,7 @@ def create_model(self) -> SlidesDeepMILModule:
class_weights=self.data_module.class_weights,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def __init__(self, **kwargs: Any) -> None:

def setup(self) -> None:
super().setup()
self.ssl_ckpt_run_id = innereye_ssl_checkpoint_crck_4ws
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_crck_4ws

def get_data_module(self) -> TilesDataModule:
return TcgaCrckTilesDataModule(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,8 @@ def __init__(self, **kwargs: Any) -> None:

def setup(self) -> None:
BaseMILTiles.setup(self)
self.ssl_ckpt_run_id = innereye_ssl_checkpoint_binary
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary

def get_data_module(self) -> PandaTilesDataModule:
return PandaTilesDataModule(
Expand Down Expand Up @@ -135,7 +136,8 @@ def __init__(self, **kwargs: Any) -> None:

def setup(self) -> None:
BaseMILSlides.setup(self)
self.ssl_ckpt_run_id = innereye_ssl_checkpoint_binary
# If no SSL checkpoint is provided, use the default one
self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary

def get_dataloader_kwargs(self) -> dict:
return dict(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def create_model(self) -> SlidesDeepMILModule:
class_weights=self.data_module.class_weights,
dropout_rate=self.dropout_rate,
outputs_folder=self.outputs_folder,
ssl_ckpt_run_id=self.ssl_ckpt_run_id,
ssl_ckpt_run_id=self.ssl_checkpoint_run_id,
encoder_params=create_from_matching_params(self, EncoderParams),
pooling_params=create_from_matching_params(self, PoolingParams),
optimizer_params=create_from_matching_params(self, OptimizerParams),
Expand Down