diff --git a/.github/workflows/cpath-pr.yml b/.github/workflows/cpath-pr.yml index 3e280c004..532c4a69f 100644 --- a/.github/workflows/cpath-pr.yml +++ b/.github/workflows/cpath-pr.yml @@ -151,6 +151,22 @@ jobs: cd ${{ env.folder }} make smoke_test_tilespandaimagenetmil_aml + smoke_test_tcgacrcksslmil_aml: + runs-on: ubuntu-20.04 + steps: + - uses: actions/checkout@v3 + with: + lfs: true + + - name: Set up smoke test environment + id: setup-sslmil-smoke-test-environment + uses: ./.github/actions/prepare_smoke_test_environment + + - name: smoke test + run: | + cd ${{ env.folder }} + make smoke_test_tcgacrcksslmil_aml + smoke_test_crck_simclr_aml: runs-on: ubuntu-20.04 steps: diff --git a/hi-ml-cpath/Makefile b/hi-ml-cpath/Makefile index 0191f265a..dcd78e908 100644 --- a/hi-ml-cpath/Makefile +++ b/hi-ml-cpath/Makefile @@ -77,6 +77,8 @@ pytest: pytest_coverage: pytest --cov=health_cpath --cov SSL --cov-branch --cov-report=html --cov-report=xml --cov-report=term-missing --cov-config=.coveragerc +SSL_CKPT_RUN_ID_CRCK := CRCK_SimCLR_1655731022_85790606 + # Run regression tests and compare performance define BASE_CPATH_RUNNER_COMMAND cd ../ ; \ @@ -92,7 +94,7 @@ define DEEPSMILEPANDATILES_ARGS endef define TCGACRCKSSLMIL_ARGS ---model=health_cpath.TcgaCrckSSLMIL +--model=health_cpath.TcgaCrckSSLMIL --ssl_checkpoint_run_id=${SSL_CKPT_RUN_ID_CRCK} endef define CRCKSIMCLR_ARGS @@ -193,9 +195,6 @@ smoke_test_tilespandaimagenetmil_aml: { ${BASE_CPATH_RUNNER_COMMAND} ${DEEPSMILEPANDATILES_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \ ${DEEPSMILEPANDATILES_SMOKE_TEST_ARGS} ${AML_MULTIPLE_DEVICE_ARGS};} -# Note: this test doesn't currently run in hi-ml Workspace since the checkpoint run specified in run_ids -# innereye_ssl_checkpoint_crck_4ws does not exist there. Once we can specify alternative checkpoints -# this can be run with any Workspace # The following test takes about 30 seconds smoke_test_tcgacrcksslmil_local: { ${BASE_CPATH_RUNNER_COMMAND} ${TCGACRCKSSLMIL_ARGS} ${DEFAULT_SMOKE_TEST_ARGS} \ diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py index 59e3e6e24..ce5fccaf1 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/BaseMIL.py @@ -71,6 +71,8 @@ class BaseMIL(LightningContainer, EncoderParams, PoolingParams): "generating outputs.") maximise_primary_metric: bool = param.Boolean(True, doc="Whether the primary validation metric should be " "maximised (otherwise minimised).") + ssl_checkpoint_run_id: str = param.String(default="", doc="Optional run id from which to load checkpoint if " + "using SSLEncoder") def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) @@ -82,9 +84,6 @@ def __init__(self, **kwargs: Any) -> None: def cache_dir(self) -> Path: return Path(f"/tmp/himl_cache/{self.__class__.__name__}-{self.encoder_type}/") - def setup(self) -> None: - self.ssl_ckpt_run_id = "" - def get_test_plot_options(self) -> Set[PlotOption]: options = {PlotOption.HISTOGRAM, PlotOption.CONFUSION_MATRIX} if self.num_top_slides > 0: @@ -206,7 +205,7 @@ def get_dataloader_kwargs(self) -> dict: def get_transforms_dict(self, image_key: str) -> Dict[ModelKey, Union[Callable, None]]: if self.is_caching: - encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_ckpt_run_id, + encoder = create_from_matching_params(self, EncoderParams).get_encoder(self.ssl_checkpoint_run_id, self.outputs_folder) transform = Compose([ LoadTilesBatchd(image_key, progress=True), @@ -226,7 +225,7 @@ def create_model(self) -> TilesDeepMILModule: class_weights=self.data_module.class_weights, dropout_rate=self.dropout_rate, outputs_folder=self.outputs_folder, - ssl_ckpt_run_id=self.ssl_ckpt_run_id, + ssl_ckpt_run_id=self.ssl_checkpoint_run_id, encoder_params=create_from_matching_params(self, EncoderParams), pooling_params=create_from_matching_params(self, PoolingParams), optimizer_params=create_from_matching_params(self, OptimizerParams), @@ -266,7 +265,7 @@ def create_model(self) -> SlidesDeepMILModule: class_weights=self.data_module.class_weights, dropout_rate=self.dropout_rate, outputs_folder=self.outputs_folder, - ssl_ckpt_run_id=self.ssl_ckpt_run_id, + ssl_ckpt_run_id=self.ssl_checkpoint_run_id, encoder_params=create_from_matching_params(self, EncoderParams), pooling_params=create_from_matching_params(self, PoolingParams), optimizer_params=create_from_matching_params(self, OptimizerParams), diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py index 081e245b5..29930d8ce 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILECrck.py @@ -55,7 +55,8 @@ def __init__(self, **kwargs: Any) -> None: def setup(self) -> None: super().setup() - self.ssl_ckpt_run_id = innereye_ssl_checkpoint_crck_4ws + # If no SSL checkpoint is provided, use the default one + self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_crck_4ws def get_data_module(self) -> TilesDataModule: return TcgaCrckTilesDataModule( diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py index f624c47e5..e429b3c39 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILEPanda.py @@ -70,7 +70,8 @@ def __init__(self, **kwargs: Any) -> None: def setup(self) -> None: BaseMILTiles.setup(self) - self.ssl_ckpt_run_id = innereye_ssl_checkpoint_binary + # If no SSL checkpoint is provided, use the default one + self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary def get_data_module(self) -> PandaTilesDataModule: return PandaTilesDataModule( @@ -135,7 +136,8 @@ def __init__(self, **kwargs: Any) -> None: def setup(self) -> None: BaseMILSlides.setup(self) - self.ssl_ckpt_run_id = innereye_ssl_checkpoint_binary + # If no SSL checkpoint is provided, use the default one + self.ssl_checkpoint_run_id = self.ssl_checkpoint_run_id or innereye_ssl_checkpoint_binary def get_dataloader_kwargs(self) -> dict: return dict( diff --git a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILESlidesPandaBenchmark.py b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILESlidesPandaBenchmark.py index e2a626f17..1d4577978 100644 --- a/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILESlidesPandaBenchmark.py +++ b/hi-ml-cpath/src/health_cpath/configs/classification/DeepSMILESlidesPandaBenchmark.py @@ -128,7 +128,7 @@ def create_model(self) -> SlidesDeepMILModule: class_weights=self.data_module.class_weights, dropout_rate=self.dropout_rate, outputs_folder=self.outputs_folder, - ssl_ckpt_run_id=self.ssl_ckpt_run_id, + ssl_ckpt_run_id=self.ssl_checkpoint_run_id, encoder_params=create_from_matching_params(self, EncoderParams), pooling_params=create_from_matching_params(self, PoolingParams), optimizer_params=create_from_matching_params(self, OptimizerParams),