diff --git a/tests/trainers/lightning/test_checkpoint.py b/tests/trainers/lightning/test_checkpoint.py index 25155049b..d4553ab5a 100644 --- a/tests/trainers/lightning/test_checkpoint.py +++ b/tests/trainers/lightning/test_checkpoint.py @@ -199,19 +199,19 @@ def _get_lightning_trainer( class TestLightningCheckpoint(TestLightningCheckpoint): def test_load_resume_parity_with_mmf(self): # with checkpoint.resume = True, by default it loads "current.ckpt" - self._load_checkpoint("current.ckpt", ckpt_config={"resume": True}) + self._load_checkpoint_and_test("current.ckpt", ckpt_config={"resume": True}) def test_load_resume_best_parity_with_mmf(self): # with checkpoint.resume = True and checkpoint.resume_best = True # by default it loads best.ckpt. It should load the "best.ckpt" - self._load_checkpoint( + self._load_checkpoint_and_test( "best.ckpt", ckpt_config={"resume": True, "resume_best": True} ) def test_load_resume_ignore_resume_zoo(self): # specifying both checkpoint.resume = True and resume_zoo # resume zoo should be ignored. It should load the "current.ckpt" - self._load_checkpoint( + self._load_checkpoint_and_test( "current.ckpt", ckpt_config={"resume": True, "resume_zoo": "visual_bert.pretrained.coco"}, ) @@ -474,7 +474,7 @@ def _get_mmf_ckpt(self, filename, ckpt_config=None): ) return mmf_ckpt_current - def _load_checkpoint(self, filename, ckpt_config=None): + def _load_checkpoint_and_test(self, filename, ckpt_config=None): # Make sure it loads x.ckpt when mmf mmf_ckpt = self._get_mmf_ckpt(filename, ckpt_config=ckpt_config)