Add TimeLimitCallback to mx/trainer callbacks. (awslabs#1631)
Co-authored-by: Jasper <schjaspe@amazon.de>
2 people authored and kashif committed Jun 24, 2022
1 parent d9a56b2 commit bb45758
Showing 4 changed files with 109 additions and 26 deletions.
2 changes: 1 addition & 1 deletion requirements/requirements.txt
@@ -2,7 +2,7 @@ holidays>=0.9
 matplotlib~=3.0
 numpy~=1.16
 pandas~=1.0
-pydantic~=1.1
+pydantic~=1.7
 tqdm~=4.23
 toolz~=0.10

40 changes: 24 additions & 16 deletions src/gluonts/mx/trainer/_base.py
@@ -384,12 +384,16 @@ def loop(  # todo call run epoch
                         loss.backward()
                         trainer.step(batch_size)
 
-                        self.callbacks.on_train_batch_end(
-                            training_network=net
+                        should_continue = (
+                            self.callbacks.on_train_batch_end(
+                                training_network=net
+                            )
                         )
                     else:
-                        self.callbacks.on_validation_batch_end(
-                            training_network=net
+                        should_continue = (
+                            self.callbacks.on_validation_batch_end(
+                                training_network=net
+                            )
                         )
 
                     epoch_loss.update(None, preds=loss)
@@ -411,22 +415,26 @@ def loop(  # todo call run epoch
                                 f"Number of parameters in {net_name}:"
                                 f" {num_model_param}"
                             )
+                    if not should_continue:
+                        self.halt = True
+                        break
                 it.close()
 
             # mark epoch end time and log time cost of current epoch
-            toc = time.time()
-            logger.info(
-                "Epoch[%d] Elapsed time %.3f seconds",
-                epoch_no,
-                (toc - tic),
-            )
+            if not self.halt:
+                toc = time.time()
+                logger.info(
+                    "Epoch[%d] Elapsed time %.3f seconds",
+                    epoch_no,
+                    (toc - tic),
+                )
 
-            logger.info(
-                "Epoch[%d] Evaluation metric '%s'=%f",
-                epoch_no,
-                ("" if is_training else "validation_") + "epoch_loss",
-                lv,
-            )
+                logger.info(
+                    "Epoch[%d] Evaluation metric '%s'=%f",
+                    epoch_no,
+                    ("" if is_training else "validation_") + "epoch_loss",
+                    lv,
+                )
 
             return epoch_loss

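Note on the control flow above: `CallbackList.on_train_batch_end` (see callback.py below) reduces the per-batch hook results with `all(...)`, so a single callback returning `False` flips `should_continue`, which sets `self.halt` and breaks out of the batch loop. A minimal sketch of a custom callback built on this contract (hypothetical example, not part of this commit):

    import mxnet.gluon.nn as nn

    from gluonts.mx.trainer.callback import Callback

    class StopAfterNBatches(Callback):
        """Ask the training loop to halt once N batches have been seen."""

        def __init__(self, max_batches: int) -> None:
            self.max_batches = max_batches
            self.seen = 0

        def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
            self.seen += 1
            # Returning False propagates through CallbackList and stops
            # the epoch loop in _base.py.
            return self.seen < self.max_batches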
90 changes: 82 additions & 8 deletions src/gluonts/mx/trainer/callback.py
@@ -11,20 +11,24 @@
 # express or implied. See the License for the specific language governing
 # permissions and limitations under the License.
 
-# Standard library imports
-from typing import List, Union, Dict, Any
+from dataclasses import dataclass, field
+from typing import List, Union, Dict, Any, Optional
 import logging
 import math
+import time
 
-# Third-party imports
 import mxnet.gluon.nn as nn
 import mxnet as mx
 from mxnet import gluon
+from pydantic import BaseModel, PrivateAttr
 
-# First-party imports
 from gluonts.core.component import validated
 from gluonts.mx.util import copy_parameters
 
+logger = logging.getLogger(__name__)
+
 
 class Callback:
     """
@@ -85,19 +89,25 @@ def on_validation_epoch_start(
             The network that is being trained.
         """
 
-    def on_train_batch_end(self, training_network: nn.HybridBlock) -> None:
+    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
         """
         Hook that is called after each training batch.
 
         Parameters
         ----------
         training_network
             The network that is being trained.
+
+        Returns
+        -------
+        bool
+            Whether training should continue. Defaults to `True`.
         """
+        return True
 
     def on_validation_batch_end(
         self, training_network: nn.HybridBlock
-    ) -> None:
+    ) -> bool:
         """
         Hook that is called after each validation batch. This hook is never
         called if no validation data is available during training.
@@ -106,7 +116,13 @@ def on_validation_batch_end(
         ----------
         training_network
             The network that is being trained.
+
+        Returns
+        -------
+        bool
+            Whether training should continue. Defaults to `True`.
         """
+        return True
 
     def on_train_epoch_end(
         self,
@@ -269,11 +285,11 @@ def on_train_epoch_start(self, *args: Any, **kwargs: Any) -> None:
     def on_validation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
         self._exec("on_validation_epoch_start", *args, **kwargs)
 
-    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> None:
-        self._exec("on_train_batch_end", *args, **kwargs)
+    def on_train_batch_end(self, *args: Any, **kwargs: Any) -> bool:
+        return all(self._exec("on_train_batch_end", *args, **kwargs))
 
-    def on_validation_batch_end(self, *args: Any, **kwargs: Any) -> None:
-        self._exec("on_validation_batch_end", *args, **kwargs)
+    def on_validation_batch_end(self, *args: Any, **kwargs: Any) -> bool:
+        return all(self._exec("on_validation_batch_end", *args, **kwargs))
 
     def on_train_epoch_end(self, *args: Any, **kwargs: Any) -> bool:
         return all(self._exec("on_train_epoch_end", *args, **kwargs))
@@ -341,3 +357,61 @@ def on_network_initializing_end(
         self, training_network: nn.HybridBlock
     ) -> None:
         copy_parameters(self.predictor.prediction_net, training_network)
+
+
+@dataclass
+class _Timer:
+    duration: float
+    _start: Optional[float] = field(init=False, default=None)
+
+    def start(self):
+        assert self._start is None
+
+        self._start = time.time()
+
+    def remaining(self) -> float:
+        assert self._start is not None
+
+        return max(self._start + self.duration - time.time(), 0.0)
+
+    def is_running(self) -> bool:
+        return self.remaining() > 0
+
+
+class TrainingTimeLimit(BaseModel, Callback):
+    """Limit the time spent on training.
+
+    This is useful to ensure that training for a given model doesn't
+    exceed a time budget, for example when doing AutoML.
+
+    If `stop_within_epoch` is set to `True`, training can be stopped after
+    each batch; otherwise it only stops at the end of an epoch.
+    """
+
+    time_limit: float
+    stop_within_epoch: bool = False
+    _timer: _Timer = PrivateAttr()
+
+    def __init__(self, **data):
+        super().__init__(**data)
+        self._timer = _Timer(self.time_limit)
+
+    def on_train_start(self, max_epochs: int) -> None:
+        self._timer.start()
+
+    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
+        if self.stop_within_epoch:
+            return self._timer.is_running()
+
+        return True
+
+    def on_epoch_end(
+        self,
+        epoch_no: int,
+        epoch_loss: float,
+        training_network: nn.HybridBlock,
+        trainer: gluon.Trainer,
+        best_epoch_info: Dict[str, Any],
+        ctx: mx.Context,
+    ) -> bool:
+        return self._timer.is_running()
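Usage sketch for the new callback (untested; assumes the `callbacks` argument of the mx `Trainer` and placeholder hyperparameters):

    from gluonts.mx.trainer import Trainer
    from gluonts.mx.trainer.callback import TrainingTimeLimit

    # Budget of 600 seconds; check the limit after every batch so a
    # long epoch cannot overshoot it.
    trainer = Trainer(
        epochs=200,
        callbacks=[TrainingTimeLimit(time_limit=600.0, stop_within_epoch=True)],
    )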
3 changes: 2 additions & 1 deletion src/gluonts/mx/trainer/model_iteration_averaging.py
@@ -336,8 +336,9 @@ def on_validation_epoch_end(
         self.avg_strategy.load_cached_model(training_network)
         return True
 
-    def on_train_batch_end(self, training_network: nn.HybridBlock) -> None:
+    def on_train_batch_end(self, training_network: nn.HybridBlock) -> bool:
         self.avg_strategy.apply(training_network)
+        return True
 
     def on_epoch_end(
         self,
