Skip to content

Commit

Permalink
add best_model_path to TrainOutput
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #12

Pull Request resolved: facebookexternal/stl_tasks#12

It's useful to return the best_model_path after training — e.g., F6 + multimodality needs it to publish the best model.

Reviewed By: tangbinh

Differential Revision: D33716125

fbshipit-source-id: 2c8a5cd37d430062b7273c3fbebb0d58843f3b1d
  • Loading branch information
hudeven authored and facebook-github-bot committed Jan 25, 2022
1 parent b67e5a2 commit 84bf109
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 9 deletions.
27 changes: 19 additions & 8 deletions torchrecipes/core/base_train_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
@dataclass
class TrainOutput:
    """Result returned by ``BaseTrainApp.train()``.

    Attributes are optional because either feature may be disabled for a
    given run (e.g. no TensorBoard logger configured, or checkpointing
    turned off via ``checkpoint_callback=False`` in the trainer params).
    """

    # Directory the TensorBoard logger wrote to, if one was configured.
    tensorboard_log_dir: Optional[str] = None
    # Path to the best checkpoint reported by the checkpoint callback;
    # None when no checkpoint callback ran or nothing was saved.
    best_model_path: Optional[str] = None


class TestOutput(TypedDict):
Expand All @@ -56,6 +57,7 @@ class BaseTrainApp:
trainer_conf: TrainerConf
log_dir: Optional[str]
root_dir: Optional[str]
_checkpoint_callback: Optional[OSSModelCheckpoint]

def __init__(
self,
Expand All @@ -73,6 +75,7 @@ def __init__(
self.trainer_conf = trainer
self.log_dir = None
self.root_dir = None
self._checkpoint_callback = None
torch._C._log_api_usage_once(
f"torchrecipes.{self.__module__}.{self.__class__.__name__}"
)
Expand Down Expand Up @@ -155,16 +158,22 @@ def _set_trainer_params(
callbacks = trainer_params.get("callbacks", [])
callbacks.extend(self.get_callbacks())

ckpt_callbacks = [c for c in callbacks if isinstance(c, OSSModelCheckpoint)]

# create default model checkpoint callback unless disabled
if trainer_params.get("checkpoint_callback", True):
checkpoint_callback = self.get_default_model_checkpoint()
callbacks.append(checkpoint_callback)
if len(ckpt_callbacks) > 0:
self._checkpoint_callback = ckpt_callbacks[0]
else:
self._checkpoint_callback = self.get_default_model_checkpoint()
callbacks.append(self._checkpoint_callback)

# auto-resume from last default checkpoint
ckpt_path = checkpoint_callback.dirpath
if not trainer_params.get("resume_from_checkpoint") and ckpt_path:
last_checkpoint = find_last_checkpoint_path(ckpt_path)
trainer_params["resume_from_checkpoint"] = last_checkpoint
if self._checkpoint_callback:
ckpt_path = self._checkpoint_callback.dirpath
if not trainer_params.get("resume_from_checkpoint") and ckpt_path:
last_checkpoint = find_last_checkpoint_path(ckpt_path)
trainer_params["resume_from_checkpoint"] = last_checkpoint

trainer_params["callbacks"] = callbacks

Expand Down Expand Up @@ -192,8 +201,10 @@ def train(self) -> TrainOutput:
log_params["run_status"] = JobStatus.FAILED.value
log_run(**log_params)
raise got_exception

return TrainOutput(tensorboard_log_dir=self.log_dir)
best_model_path = getattr(self._checkpoint_callback, "best_model_path", None)
return TrainOutput(
tensorboard_log_dir=self.log_dir, best_model_path=best_model_path
)

def test(self) -> _EVALUATE_OUTPUT:
trainer, _ = self._get_trainer()
Expand Down
30 changes: 30 additions & 0 deletions torchrecipes/core/tests/test_base_train_app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


#!/usr/bin/env python3

# pyre-strict
from pytorch_lightning.callbacks import ModelCheckpoint
from torchrecipes.core.base_train_app import BaseTrainApp
from torchrecipes.core.conf import TrainerConf
from torchrecipes.core.test_utils.test_base import BaseTrainAppTestCase


class TestTrainApp(BaseTrainAppTestCase):
    """Tests for how BaseTrainApp._set_trainer_params picks its checkpoint callback."""

    def test_ckpt_callback_fallback_to_default(self) -> None:
        # With no user-supplied callbacks, a default checkpoint callback is
        # created; the default one has no monitored metric.
        train_app = BaseTrainApp(None, TrainerConf(), None)
        train_app._set_trainer_params(trainer_params={})
        callback = train_app._checkpoint_callback
        self.assertIsNotNone(callback)
        self.assertIsNone(callback.monitor)

    def test_ckpt_callback_user_provided(self) -> None:
        # A ModelCheckpoint returned from get_callbacks() is adopted as-is
        # instead of the default, keeping its configured monitor.
        train_app = BaseTrainApp(None, TrainerConf(), None)
        user_callback = ModelCheckpoint(monitor="some_metrics")
        self.mock_callable(train_app, "get_callbacks").to_return_value(
            [user_callback]
        )
        train_app._set_trainer_params(trainer_params={})
        self.assertIsNotNone(train_app._checkpoint_callback)
        self.assertEqual(train_app._checkpoint_callback.monitor, "some_metrics")
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def _get_train_app(
def test_train_model(self, root_dir: str) -> None:
train_app = self._get_train_app(tb_save_dir=root_dir)
# Train the model with the config
train_app.train()
output = train_app.train()
self.assertIsNotNone(output.tensorboard_log_dir)
# we don't save checkpoints for tests, because it would make the tests flaky
self.assertIsNone(output.best_model_path)

@tempdir
def test_fine_tuning(self, root_dir: str) -> None:
Expand Down

0 comments on commit 84bf109

Please sign in to comment.