From bb4575813aa81a3fee3803abd2c85d821b0eb549 Mon Sep 17 00:00:00 2001
From: yx1215 <37573101+yx1215@users.noreply.github.com>
Date: Wed, 15 Jun 2022 12:22:40 -0400
Subject: [PATCH] Add `TrainingTimeLimit` to `mx/trainer` callbacks. (#1631)

Co-authored-by: Jasper
---
 requirements/requirements.txt                 |  2 +-
 src/gluonts/mx/trainer/_base.py               | 40 +++++----
 src/gluonts/mx/trainer/callback.py            | 90 +++++++++++++++++--
 .../mx/trainer/model_iteration_averaging.py   |  3 +-
 4 files changed, 109 insertions(+), 26 deletions(-)

diff --git a/requirements/requirements.txt b/requirements/requirements.txt
index 8dcb8fd214..c8de8aca52 100644
--- a/requirements/requirements.txt
+++ b/requirements/requirements.txt
@@ -2,7 +2,7 @@ holidays>=0.9
 matplotlib~=3.0
 numpy~=1.16
 pandas~=1.0
-pydantic~=1.1
+pydantic~=1.7
 tqdm~=4.23
 toolz~=0.10
 
diff --git a/src/gluonts/mx/trainer/_base.py b/src/gluonts/mx/trainer/_base.py
index 10374cf26f..ef83651552 100644
--- a/src/gluonts/mx/trainer/_base.py
+++ b/src/gluonts/mx/trainer/_base.py
@@ -384,12 +384,16 @@ def loop(  # todo call run epoch
                         loss.backward()
                         trainer.step(batch_size)
 
-                        self.callbacks.on_train_batch_end(
-                            training_network=net
+                        should_continue = (
+                            self.callbacks.on_train_batch_end(
+                                training_network=net
+                            )
                         )
                     else:
-                        self.callbacks.on_validation_batch_end(
-                            training_network=net
+                        should_continue = (
+                            self.callbacks.on_validation_batch_end(
+                                training_network=net
+                            )
                         )
 
                     epoch_loss.update(None, preds=loss)
@@ -411,22 +415,26 @@ def loop(  # todo call run epoch
                                 f"Number of parameters in {net_name}:"
                                 f" {num_model_param}"
                             )
+                    if not should_continue:
+                        self.halt = True
+                        break
                 it.close()
 
             # mark epoch end time and log time cost of current epoch
-            toc = time.time()
-            logger.info(
-                "Epoch[%d] Elapsed time %.3f seconds",
-                epoch_no,
-                (toc - tic),
-            )
+            if not self.halt:
+                toc = time.time()
+                logger.info(
+                    "Epoch[%d] Elapsed time %.3f seconds",
+                    epoch_no,
+                    (toc - tic),
+                )
 
-            logger.info(
-                "Epoch[%d] Evaluation metric '%s'=%f",
-                epoch_no,
-                ("" if is_training else "validation_") + "epoch_loss",
-                lv,
-            )
+                logger.info(
+                    "Epoch[%d] Evaluation metric '%s'=%f",
+                    epoch_no,
+                    ("" if is_training else "validation_") + "epoch_loss",
+                    lv,
+                )
             return epoch_loss
diff --git a/src/gluonts/mx/trainer/callback.py b/src/gluonts/mx/trainer/callback.py
index c67c7717bd..2524d9a9e1 100644
--- a/src/gluonts/mx/trainer/callback.py
+++ b/src/gluonts/mx/trainer/callback.py
@@ -11,20 +11,24 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-# Standard library imports
-from typing import List, Union, Dict, Any
+from dataclasses import dataclass, field
+from typing import List, Union, Dict, Any, Optional
 import logging
 import math
+import time
 
 # Third-party imports
 import mxnet.gluon.nn as nn
 import mxnet as mx
 from mxnet import gluon
+from pydantic import BaseModel, PrivateAttr
 
 # First-party imports
 from gluonts.core.component import validated
 from gluonts.mx.util import copy_parameters
 
+logger = logging.getLogger(__name__)
+
 
 class Callback:
     """
@@ -85,7 +89,7 @@ def on_validation_epoch_start(
             The network that is being trained.
         """
 
-    def on_train_batch_end(self, training_network: nn.HybridBlock) -> None:
+    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
         """
         Hook that is called after each training batch.
 
         Parameters
         ----------
         training_network
             The network that is being trained.
+
+        Returns
+        -------
+        bool
+            Whether the training should continue. Defaults to `True`.
         """
+        return True
 
     def on_validation_batch_end(
         self, training_network: nn.HybridBlock
-    ) -> None:
+    ) -> bool:
         """
         Hook that is called after each validation batch. This hook is never
         called if no validation data is available during training.
@@ -106,7 +116,13 @@
         ----------
         training_network
             The network that is being trained.
+
+        Returns
+        -------
+        bool
+            Whether the training should continue. Defaults to `True`.
         """
+        return True
 
     def on_train_epoch_end(
         self,
@@ -269,11 +285,11 @@ def on_train_epoch_start(self, *args: Any, **kwargs: Any) -> None:
     def on_validation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         self._exec("on_validation_epoch_start", *args, **kwargs)
 
-    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
-        self._exec("on_train_batch_end", *args, **kwargs)
+    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> bool:
+        return all(self._exec("on_train_batch_end", *args, **kwargs))
 
-    def on_validation_batch_end(self, *args: Any, **kwargs: Any) -> None:
-        self._exec("on_validation_batch_end", *args, **kwargs)
+    def on_validation_batch_end(self, *args: Any, **kwargs: Any) -> bool:
+        return all(self._exec("on_validation_batch_end", *args, **kwargs))
 
     def on_train_epoch_end(self, *args: Any, **kwargs: Any) -> bool:
         return all(self._exec("on_train_epoch_end", *args, **kwargs))
@@ -341,3 +357,61 @@ def on_network_initializing_end(
         self, training_network: nn.HybridBlock
     ) -> None:
         copy_parameters(self.predictor.prediction_net, training_network)
+
+
+@dataclass
+class _Timer:
+    duration: float
+    _start: Optional[float] = field(init=False, default=None)
+
+    def start(self):
+        assert self._start is None
+
+        self._start = time.time()
+
+    def remaining(self) -> float:
+        assert self._start is not None
+
+        return max(self._start + self.duration - time.time(), 0.0)
+
+    def is_running(self) -> bool:
+        return self.remaining() > 0
+
+
+class TrainingTimeLimit(BaseModel, Callback):
+    """Limit the time spent for training.
+
+    This is useful to ensure that training for a given model doesn't
+    exceed a fixed time budget, for example when doing AutoML.
+
+    If `stop_within_epoch` is set to `True`, training can be stopped after
+    each batch; otherwise it only stops after the end of the epoch.
+ """ + + time_limit: float + stop_within_epoch: bool = False + _timer: _Timer = PrivateAttr() + + def __init__(self, **data): + super().__init__(**data) + self._timer = _Timer(self.time_limit) + + def on_train_start(self, max_epochs: int) -> None: + self._timer.start() + + def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool: + if self.stop_within_epoch: + return self._timer.is_running() + + return True + + def on_epoch_end( + self, + epoch_no: int, + epoch_loss: float, + training_network: nn.HybridBlock, + trainer: gluon.Trainer, + best_epoch_info: Dict[str, Any], + ctx: mx.Context, + ) -> bool: + return self._timer.is_running() diff --git a/src/gluonts/mx/trainer/model_iteration_averaging.py b/src/gluonts/mx/trainer/model_iteration_averaging.py index 051249598d..2681e94919 100644 --- a/src/gluonts/mx/trainer/model_iteration_averaging.py +++ b/src/gluonts/mx/trainer/model_iteration_averaging.py @@ -336,8 +336,9 @@ def on_validation_epoch_end( self.avg_strategy.load_cached_model(training_network) return True - def on_train_batch_end(self, training_network: nn.HybridBlock) -> None: + def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool: self.avg_strategy.apply(training_network) + return True def on_epoch_end( self,