From edf76c6d63b1636765ba2a9fc3db41877561728d Mon Sep 17 00:00:00 2001 From: Melissa Bristow Date: Mon, 8 Aug 2022 13:41:51 +0100 Subject: [PATCH 1/7] Add command line arg for custom SSL encoder checkpoint --- hi-ml-cpath/Makefile | 2 +- .../src/health_cpath/configs/classification/BaseMIL.py | 4 +++- .../health_cpath/configs/classification/DeepSMILECrck.py | 3 ++- .../health_cpath/configs/classification/DeepSMILEPanda.py | 6 ++++-- hi-ml-cpath/testhisto/testhisto/mocks/container.py | 1 + 5 files changed, 11 insertions(+), 5 deletions(-) diff --git a/hi-ml-cpath/Makefile b/hi-ml-cpath/Makefile index 0191f265a..a9c8632e6 100644 --- a/hi-ml-cpath/Makefile +++ b/hi-ml-cpath/Makefile @@ -92,7 +92,7 @@ define DEEPSMILEPANDATILES_ARGS endef define TCGACRCKSSLMIL_ARGS ---model=health_cpath.TcgaCrckSSLMIL +--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint=CRCK_SimCLR_1655731022_85790606 endef define CRCKSIMCLR_ARGS diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py index 59e3e6e24..af4068dd9 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py @@ -71,6 +71,8 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams): "generating outputs.") maximise_primary_metric: bool = param.Boolean(True, doc="Whether the primary validation metric should be " "maximised (otherwise minimised).") + ssl_checkpoint_run_id: str = param.String(default="", doc="Optional run id from which to load checkpoint if " + "using SSLEncoder") def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -83,7 +85,7 @@ def cache_dir(self) -> Path: return Path(f"/tmp/himl_cache/{self.__class__.__name__}-{self.encoder_type}/") def setup(self) -> None: - self.ssl_ckpt_run_id = "" + self.ssl_ckpt_run_id = self.ssl_checkpoint_run_id def get_test_plot_options(self) -> Set[PlotOption]: options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX} diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py index 081e245b5..008b0054b 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py @@ -55,7 +55,8 @@ def __init__(self, **kwargs: Any) -> None: def setup(self) -> None: super().setup() - self.ssl_ckpt_run_id = innereye_ssl_checkpoint_crck_4ws + # If no SSL checkpoint is provided, use the default one + self.ssl_ckpt_run_id = self.ssl_ckpt_run_id or innereye_ssl_checkpoint_crck_4ws def get_data_module(self) -> TilesDataModule: return TcgaCrckTilesDataModule( diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py index f624c47e5..0ccc20b7c 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py @@ -70,7 +70,8 @@ def __init__(self, **kwargs: Any) -> None: def setup(self) -> None: BaseMILTiles.setup(self) - self.ssl_ckpt_run_id = innereye_ssl_checkpoint_binary + # If no SSL checkpoint is provided, use the default one + self.ssl_ckpt_run_id = self.ssl_ckpt_run_id or innereye_ssl_checkpoint_binary def get_data_module(self) -> PandaTilesDataModule: return PandaTilesDataModule( @@ -135,7 +136,8 @@ def __init__(self, **kwargs: Any) -> None: def setup(self) -> None: BaseMILSlides.setup(self) - self.ssl_ckpt_run_id = innereye_ssl_checkpoint_binary + # If no SSL checkpoint is provided, use the default one + self.ssl_ckpt_run_id = self.ssl_ckpt_run_id or innereye_ssl_checkpoint_binary def get_dataloader_kwargs(self) -> dict: return dict( diff --git a/hi-ml-cpath/testhisto/testhisto/mocks/container.py b/hi-ml-cpath/testhisto/testhisto/mocks/container.py index 8ee6b2030..64c0622a0 100644 --- a/hi-ml-cpath/testhisto/testhisto/mocks/container.py +++ b/hi-ml-cpath/testhisto/testhisto/mocks/container.py @@ -38,6 +38,7 @@ def __init__(self, tmp_path: Path, **kwargs: Any) -> None: # declared in TrainerParams: max_epochs=2, crossval_count=1, + ssl_checkpoint_run_id="", ) default_kwargs.update(kwargs) super().__init__(**default_kwargs) From 76e1f305b1cc8f215ddbc93bf1a44520f8fbdbff Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 8 Aug 2022 12:44:30 +0000 Subject: [PATCH 2/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py index af4068dd9..3d92f2af8 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py @@ -72,7 +72,7 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams): maximise_primary_metric: bool = param.Boolean(True, doc="Whether the primary validation metric should be " "maximised (otherwise minimised).") ssl_checkpoint_run_id: str = param.String(default="", doc="Optional run id from which to load checkpoint if " - "using SSLEncoder") + "using SSLEncoder") def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) From 2ea13274d8873d634414c0e317fd977f818c6187 Mon Sep 17 00:00:00 2001 From: Melissa Bristow Date: Mon, 8 Aug 2022 13:44:42 +0100 Subject: [PATCH 3/7] Add smoke test to workflow --- .github/workflows/cpath-pr.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/cpath-pr.yml b/.github/workflows/cpath-pr.yml index 3e280c004..532c4a69f 100644 --- a/.github/workflows/cpath-pr.yml +++ b/.github/workflows/cpath-pr.yml @@ -151,6 +151,22 @@ jobs: cd ${{ env.folder }} make smoke_test_tilespandaimagenetmil_aml + smoke_test_tcgacrcksslmil_aml: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v3 + with: + lfs: true + + - name: Set up smoke test environment + id: setup-sslmil-smoke-test-environment + uses: ./.github/actions/prepare_smoke_test_environment + + - name: smoke test + run: | + cd ${{ env.folder }} + make smoke_test_tcgacrcksslmil_aml + smoke_test_crck_simclr_aml: runs-on: ubuntu-20.04 steps: From f12af83b66042a52adc97b18f942e078b345f42c Mon Sep 17 00:00:00 2001 From: Melissa Bristow Date: Mon, 8 Aug 2022 14:51:21 +0100 Subject: [PATCH 4/7] Remove uneccessary change --- hi-ml-cpath/testhisto/testhisto/mocks/container.py | 1 - 1 file changed, 1 deletion(-) diff --git a/hi-ml-cpath/testhisto/testhisto/mocks/container.py b/hi-ml-cpath/testhisto/testhisto/mocks/container.py index 64c0622a0..8ee6b2030 100644 --- a/hi-ml-cpath/testhisto/testhisto/mocks/container.py +++ b/hi-ml-cpath/testhisto/testhisto/mocks/container.py @@ -38,7 +38,6 @@ def __init__(self, tmp_path: Path, **kwargs: Any) -> None: # declared in TrainerParams: max_epochs=2, crossval_count=1, - ssl_checkpoint_run_id="", ) default_kwargs.update(kwargs) super().__init__(**default_kwargs) From c5816ebab84052f0ff42b7b5ddb92ecc1fedfc3d Mon Sep 17 00:00:00 2001 From: Melissa Bristow Date: Mon, 8 Aug 2022 16:16:17 +0100 Subject: [PATCH 5/7] Address PR comments --- hi-ml-cpath/Makefile | 2 +- .../src/health_cpath/configs/classification/BaseMIL.py | 9 +++------ .../health_cpath/configs/classification/DeepSMILECrck.py | 2 +- .../configs/classification/DeepSMILEPanda.py | 4 ++-- .../classification/DeepSMILESlidesPandaBenchmark.py | 2 +- 5 files changed, 8 insertions(+), 11 deletions(-) diff --git a/hi-ml-cpath/Makefile b/hi-ml-cpath/Makefile index a9c8632e6..43dd8a41b 100644 --- a/hi-ml-cpath/Makefile +++ b/hi-ml-cpath/Makefile @@ -92,7 +92,7 @@ define DEEPSMILEPANDATILES_ARGS endef define TCGACRCKSSLMIL_ARGS ---model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint=CRCK_SimCLR_1655731022_85790606 +--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=CRCK_SimCLR_1655731022_85790606 endef define CRCKSIMCLR_ARGS diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py index 3d92f2af8..ce5fccaf1 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py @@ -84,9 +84,6 @@ def __init__(self, **kwargs: Any) -> None: def cache_dir(self) -> Path: return Path(f"/tmp/himl_cache/{self.__class__.__name__}-{self.encoder_type}/") - def setup(self) -> None: - self.ssl_ckpt_run_id = self.ssl_checkpoint_run_id - def get_test_plot_options(self) -> Set[PlotOption]: options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX} if self.num_top_slides > 0: @@ -208,7 +205,7 @@ def get_dataloader_kwargs(self) -> dict: def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]: if self.is_caching: - encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_ckpt_run_id, + encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_checkpoint_run_id, self.outputs_folder) transform = Compose([ LoadTilesBatchd(image_key, progress=True), @@ -228,7 +225,7 @@ def create_model(self) -> TilesDeepMILModule: class_weights=self.data_module.class_weights, dropout_rate=self.dropout_rate, outputs_folder=self.outputs_folder, - ssl_ckpt_run_id=self.ssl_ckpt_run_id, + ssl_ckpt_run_id=self.ssl_checkpoint_run_id, encoder_params=create_from_matching_params(self, EncoderParams), pooling_params=create_from_matching_params(self, PoolingParams), optimizer_params=create_from_matching_params(self, OptimizerParams), @@ -268,7 +265,7 @@ def create_model(self) -> SlidesDeepMILModule: class_weights=self.data_module.class_weights, dropout_rate=self.dropout_rate, outputs_folder=self.outputs_folder, - ssl_ckpt_run_id=self.ssl_ckpt_run_id, + ssl_ckpt_run_id=self.ssl_checkpoint_run_id, encoder_params=create_from_matching_params(self, EncoderParams), pooling_params=create_from_matching_params(self, PoolingParams), optimizer_params=create_from_matching_params(self, OptimizerParams), diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py index 008b0054b..29930d8ce 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py @@ -56,7 +56,7 @@ def __init__(self, **kwargs: Any) -> None: def setup(self) -> None: super().setup() # If no SSL checkpoint is provided, use the default one - self.ssl_ckpt_run_id = self.ssl_ckpt_run_id or innereye_ssl_checkpoint_crck_4ws + self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_crck_4ws def get_data_module(self) -> TilesDataModule: return TcgaCrckTilesDataModule( diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py index 0ccc20b7c..e429b3c39 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py @@ -71,7 +71,7 @@ def __init__(self, **kwargs: Any) -> None: def setup(self) -> None: BaseMILTiles.setup(self) # If no SSL checkpoint is provided, use the default one - self.ssl_ckpt_run_id = self.ssl_ckpt_run_id or innereye_ssl_checkpoint_binary + self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary def get_data_module(self) -> PandaTilesDataModule: return PandaTilesDataModule( @@ -137,7 +137,7 @@ def __init__(self, **kwargs: Any) -> None: def setup(self) -> None: BaseMILSlides.setup(self) # If no SSL checkpoint is provided, use the default one - self.ssl_ckpt_run_id = self.ssl_ckpt_run_id or innereye_ssl_checkpoint_binary + self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary def get_dataloader_kwargs(self) -> dict: return dict( diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILESlidesPandaBenchmark.py b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILESlidesPandaBenchmark.py index e2a626f17..1d4577978 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILESlidesPandaBenchmark.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILESlidesPandaBenchmark.py @@ -128,7 +128,7 @@ def create_model(self) -> SlidesDeepMILModule: class_weights=self.data_module.class_weights, dropout_rate=self.dropout_rate, outputs_folder=self.outputs_folder, - ssl_ckpt_run_id=self.ssl_ckpt_run_id, + ssl_ckpt_run_id=self.ssl_checkpoint_run_id, encoder_params=create_from_matching_params(self, EncoderParams), pooling_params=create_from_matching_params(self, PoolingParams), optimizer_params=create_from_matching_params(self, OptimizerParams), From c5d9b6a1b25bb38cffb1d6c8d445c179cdb87a64 Mon Sep 17 00:00:00 2001 From: Melissa Bristow Date: Mon, 8 Aug 2022 16:57:10 +0100 Subject: [PATCH 6/7] remove warning --- hi-ml-cpath/Makefile | 3 --- 1 file changed, 3 deletions(-) diff --git a/hi-ml-cpath/Makefile b/hi-ml-cpath/Makefile index 43dd8a41b..f8bb73164 100644 --- a/hi-ml-cpath/Makefile +++ b/hi-ml-cpath/Makefile @@ -193,9 +193,6 @@ smoke_test_tilespandaimagenetmil_aml: { ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \ ${DEEPSMILEPANDATILES_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};} -# Note: this test doesn't currently run in hi-ml Workspace since the checkpoint run specified in run_ids -# innereye_ssl_checkpoint_crck_4ws does not exist there. Once we can specify alternative checkpoints -# this can be run with any Workspace # The following test takes about 30 seconds smoke_test_tcgacrcksslmil_local: { ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \ From bffcc8896056e72661a983a1a4aa61a337c75a8d Mon Sep 17 00:00:00 2001 From: Melissa Bristow Date: Mon, 8 Aug 2022 17:12:42 +0100 Subject: [PATCH 7/7] Add variable to makefile --- hi-ml-cpath/Makefile | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/hi-ml-cpath/Makefile b/hi-ml-cpath/Makefile index f8bb73164..dcd78e908 100644 --- a/hi-ml-cpath/Makefile +++ b/hi-ml-cpath/Makefile @@ -77,6 +77,8 @@ pytest: pytest_coverage: pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc +SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606 + # Run regression tests and compare performance define BASE_CPATH_RUNNER_COMMAND cd ../ ; \ @@ -92,7 +94,7 @@ define DEEPSMILEPANDATILES_ARGS endef define TCGACRCKSSLMIL_ARGS ---model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=CRCK_SimCLR_1655731022_85790606 +--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=${SSL_CKPT_RUN_ID_CRCK} endef define CRCKSIMCLR_ARGS