Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
BUG: Rename pretraining_run_checkpoints (#841)
Browse files Browse the repository at this point in the history
"Parameter `extra_downloaded_run_id` has been renamed to
`pretraining_run_checkpoints`" (According to CHANGELOG.md) but is still
in used in SSLClassifierContainer class (in
InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py).
It will raise an AttributeError: 'SSLClassifierCIFAR' object has no
attribute 'extra_downloaded_run_id' when trying to run python
InnerEyeML/runner.py --model=CXRImageClassifier
--pretraining_run_recovery_id={THE_ID_TO_YOUR_SSL_TRAINING_JOB}.
So renamed it to `pretraining_run_checkpoints` there too.
<!--
## Guidelines

Please follow the guidelines for pull requests (PRs) in
[CONTRIBUTING](/docs/contributing.md). Checklist:

- Ensure that your PR is small, and implements one change
- Give your PR title one of the prefixes ENH, BUG, STYLE, DOC, DEL to
indicate what type of change that is (see
[CONTRIBUTING](/docs/contributing.md))
- Link the correct GitHub issue for tracking
- Add unit tests for all functions that you introduced or modified
- Run automatic code formatting / linting on all files ("Format
Document" Shift-Alt-F in VSCode)

## Change the default merge message

When completing your PR, you will be asked for a title and an optional
extended description. By default, the extended description will be a
concatenation of the individual
commit messages. Please DELETE/REPLACE that with a human readable
extended description for non-trivial PRs.
-->
  • Loading branch information
erann1987 authored Feb 16, 2023
1 parent d902e02 commit 03b4cc3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions InnerEye/ML/SSL/lightning_containers/ssl_image_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,12 @@ def create_model(self) -> LightningModuleWithOptimizer:
This method must create the actual Lightning model that will be trained.
"""
if self.local_ssl_weights_path is None:
assert self.extra_downloaded_run_id is not None
assert self.pretraining_run_checkpoints is not None
try:
path_to_checkpoint = self.extra_downloaded_run_id.get_best_checkpoint_paths()
path_to_checkpoint = self.pretraining_run_checkpoints.get_best_checkpoint_paths()
except FileNotFoundError:
logging.info("Best checkpoint not found - using last recovery checkpoint instead")
path_to_checkpoint = self.extra_downloaded_run_id.get_recovery_checkpoint_paths()
path_to_checkpoint = self.pretraining_run_checkpoints.get_recovery_checkpoint_paths()
path_to_checkpoint = path_to_checkpoint[0] # type: ignore
else:
path_to_checkpoint = self.local_ssl_weights_path
Expand Down

0 comments on commit 03b4cc3

Please sign in to comment.