TrainerCallback with batch/epoch/end hooks (allenai#4708)
* Wrote tests and initial implementation, but couldn't run them on this machine yet.

* Added TrainerCallback with metaclass that automatically creates Batch/Epoch callbacks wrapping it.

* Reformatting.

* Changelog

* Added `type: ignore` comments.

* Updated with docstrings.

* Refactored _make_callback_type into metaclass.

Co-authored-by: Dirk Groeneveld <dirkg@allenai.org>
viking-sudo-rm and dirkgr committed Oct 7, 2020
1 parent 001e1f7 commit 321d4f4
Showing 4 changed files with 216 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
be used from the `cached-path` command with `allennlp cached-path --inspect`.
- Added a function `remove_cache_entries` to `common.file_utils` that removes any cache entries matching the given
glob patterns. This can be used from the `cached-path` command with `allennlp cached-path --remove some-files-*`.
- Added a `TrainerCallback` object to support state sharing between batch and epoch-level training callbacks.

### Changed

1 change: 1 addition & 0 deletions allennlp/training/__init__.py
@@ -6,5 +6,6 @@
GradientDescentTrainer,
BatchCallback,
EpochCallback,
TrainerCallback,
TrackEpochCallback,
)
129 changes: 128 additions & 1 deletion allennlp/training/trainer.py
@@ -6,7 +6,7 @@
import time
import traceback
from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Type, Union

from allennlp.common.util import int_to_device

@@ -198,6 +198,115 @@ def __call__(
trainer.model.epoch = epoch + 1


_BasicCallback = Union[BatchCallback, EpochCallback]


class _TrainerCallbackMeta(type):
def __new__(cls, name, bases, dct):
"""
Add subclasses that wrap the `TrainerCallback` into other interfaces.
"""
subtype = super().__new__(cls, name, bases, dct)
# These subtypes wrap the `TrainerCallback` into the `_BasicCallback` interfaces.
subtype.Batch = cls._make_callback_type(BatchCallback, subtype.on_batch)
subtype.Epoch = cls._make_callback_type(EpochCallback, subtype.on_epoch)
subtype.End = cls._make_callback_type(EpochCallback, subtype.on_end)
return subtype

@classmethod
def _make_callback_type(
cls,
call_type: Type[_BasicCallback],
call: Callable[[], None],
) -> Type[_BasicCallback]: # type: ignore
class _Wrapper(call_type): # type: ignore
def __init__(self, trainer_callback: "TrainerCallback"):
self.trainer_callback = trainer_callback

def __call__(self, trainer: "GradientDescentTrainer", *args, **kwargs):
call(self.trainer_callback, trainer, *args, **kwargs) # type: ignore

return _Wrapper


class TrainerCallback(Registrable, metaclass=_TrainerCallbackMeta):
"""
A general callback object that wraps all three types of callbacks into one.
Rather than a `__call__` method, this class has `on_batch`, `on_epoch`, and `on_end` methods, corresponding to
each callback type. Each hook receives the same `TrainerCallback` instance as `self`, which makes it easy to
share state between related callbacks.
Under the hood, a metaclass generates a wrapping subclass for each of the three interfaces whenever a
subclass of `TrainerCallback` is defined.
"""

def on_batch(
self,
trainer: "GradientDescentTrainer",
batch_inputs: List[List[TensorDict]],
batch_outputs: List[Dict[str, Any]],
epoch: int,
batch_number: int,
is_training: bool,
is_master: bool,
) -> None:
"""
This callback hook is called after the end of each batch. This is equivalent to `BatchCallback`.
"""
pass

def on_epoch(
self,
trainer: "GradientDescentTrainer",
metrics: Dict[str, Any],
epoch: int,
is_master: bool,
) -> None:
"""
This callback hook is called after the end of each epoch. This is equivalent to `EpochCallback`.
"""
pass

def on_end(
self,
trainer: "GradientDescentTrainer",
metrics: Dict[str, Any],
epoch: int,
is_master: bool,
) -> None:
"""
This callback hook is called once at the end of training, after the final epoch. The index of the last epoch is passed as `epoch`.
"""
pass

def batch(self):
"""
Construct a `BatchCallback` wrapper for this `TrainerCallback`.
The `cls.Batch` type is created by the metaclass.
"""
return self.Batch(self)

def epoch(self):
"""
Construct an `EpochCallback` wrapper for this instance.
The `cls.Epoch` type is created by the metaclass.
"""
return self.Epoch(self)

def end(self):
"""
Construct an `EpochCallback` wrapping the `on_end` end-of-training hook.
The `cls.End` type is created by the metaclass.
"""
return self.End(self)


TrainerCallback.register("null")(TrainerCallback)
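For illustration, here is a minimal usage sketch of the new interface; it is not part of this diff, and the callback name and its attributes are hypothetical. It shows the hooks sharing state through `self`, and the metaclass-generated wrapper types exposed via `batch()`, `epoch()`, and `end()`:

from allennlp.training import TrainerCallback

class CountingCallback(TrainerCallback):
    # Hypothetical callback: counts batches and remembers the final epoch,
    # sharing that state between hooks through `self`.
    def __init__(self) -> None:
        self.batches_seen = 0
        self.final_epoch = None

    def on_batch(self, trainer, batch_inputs, batch_outputs, epoch,
                 batch_number, is_training, is_master) -> None:
        # Called after every batch, like a `BatchCallback`.
        self.batches_seen += 1

    def on_end(self, trainer, metrics, epoch, is_master) -> None:
        # Called once after the final epoch.
        self.final_epoch = epoch

callback = CountingCallback()
batch_callback = callback.batch()   # a `BatchCallback` created by the metaclass (`CountingCallback.Batch`)
epoch_callback = callback.epoch()   # an `EpochCallback` wrapping `on_epoch`
end_callback = callback.end()       # an `EpochCallback` wrapping `on_end`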


@Trainer.register("gradient_descent", constructor="from_partial_objects")
class GradientDescentTrainer(Trainer):
"""
@@ -315,6 +424,13 @@ class GradientDescentTrainer(Trainer):
A list of callbacks that will be called at the end of every epoch, and at the start of
training (with epoch = -1).
end_callbacks : `List[EpochCallback]`, optional (default = `None`)
A list of callbacks that will be called after the final epoch at the end of training. The type of the
callbacks is the same as `epoch_callbacks`.
trainer_callbacks : `List[TrainerCallback]`, optional (default = `None`)
A list of callbacks that will be called after every batch and every epoch, as well as at the start and end of training.
distributed : `bool`, optional, (default = `False`)
If set, PyTorch's `DistributedDataParallel` is used to train the model in multiple GPUs. This also
requires `world_size` to be greater than 1.
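As a usage sketch (hypothetical; `model`, `optimizer`, `data_loader`, and `CountingCallback` from the earlier sketch are stand-ins), passing a `TrainerCallback` through `trainer_callbacks` is equivalent to wiring up the wrapped callbacks by hand, mirroring the loop added to `__init__` below:

callback = CountingCallback()

# Single TrainerCallback hooked into all three stages:
trainer = GradientDescentTrainer(
    model, optimizer, data_loader,
    num_epochs=2,
    trainer_callbacks=[callback],
)

# Equivalent explicit wiring via the metaclass-generated wrappers:
trainer = GradientDescentTrainer(
    model, optimizer, data_loader,
    num_epochs=2,
    batch_callbacks=[callback.batch()],
    epoch_callbacks=[callback.epoch()],
    end_callbacks=[callback.end()],
)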
@@ -366,6 +482,8 @@ def __init__(
moving_average: Optional[MovingAverage] = None,
batch_callbacks: List[BatchCallback] = None,
epoch_callbacks: List[EpochCallback] = None,
end_callbacks: List[EpochCallback] = None,
trainer_callbacks: List[TrainerCallback] = None,
distributed: bool = False,
local_rank: int = 0,
world_size: int = 1,
@@ -414,6 +532,12 @@ def __init__(
self._moving_average = moving_average
self._batch_callbacks = batch_callbacks or []
self._epoch_callbacks = epoch_callbacks or []
self._end_callbacks = end_callbacks or []

for callback in trainer_callbacks or []:
self._batch_callbacks.append(callback.batch())
self._epoch_callbacks.append(callback.epoch())
self._end_callbacks.append(callback.end())

# We keep the total batch number as an instance variable because it
# is used inside a closure for the hook which logs activations in
@@ -979,6 +1103,9 @@ def train(self) -> Dict[str, Any]:

epochs_trained += 1

for callback in self._end_callbacks:
callback(self, metrics=metrics, epoch=epoch, is_master=self._master)

# make sure pending events are flushed to disk and files are closed properly
self._tensorboard.close()

86 changes: 86 additions & 0 deletions tests/training/trainer_test.py
@@ -28,6 +28,7 @@
TensorboardWriter,
BatchCallback,
EpochCallback,
TrainerCallback,
TrackEpochCallback,
)
from allennlp.training.learning_rate_schedulers import CosineWithRestarts
@@ -1006,6 +1007,91 @@ def test_track_epoch_callback(self):
trainer.train()
assert trainer.model.epoch == num_epochs

def test_end_callback_is_called_at_end(self):
class FakeEndCallback(EpochCallback):
def __call__(
self,
trainer: "GradientDescentTrainer",
metrics: Dict[str, Any],
epoch: int,
is_master: bool,
) -> None:
if not hasattr(trainer, "end_callback_calls"):
trainer.end_callback_calls = [] # type: ignore
trainer.end_callback_calls.append(epoch) # type: ignore

trainer = GradientDescentTrainer(
self.model,
self.optimizer,
self.data_loader,
num_epochs=4,
validation_data_loader=self.validation_data_loader,
end_callbacks=[FakeEndCallback()],
)
trainer.train()
expected_calls = [3]
assert trainer.end_callback_calls == expected_calls

def test_trainer_callback_is_called_everywhere(self):
class FakeTrainerCallback(TrainerCallback):
def on_batch(
self,
trainer: "GradientDescentTrainer",
batch_inputs: List[List[TensorDict]],
batch_outputs: List[Dict[str, Any]],
epoch: int,
batch_number: int,
is_training: bool,
is_master: bool,
) -> None:
if not hasattr(trainer, "batch_callback_calls"):
trainer.batch_callback_calls = [] # type: ignore
trainer.batch_callback_calls.append((epoch, batch_number, is_training)) # type: ignore

def on_epoch(
self,
trainer: "GradientDescentTrainer",
metrics: Dict[str, Any],
epoch: int,
is_master: bool,
) -> None:
if not hasattr(trainer, "epoch_callback_calls"):
trainer.epoch_callback_calls = [] # type: ignore
trainer.epoch_callback_calls.append(epoch) # type: ignore

def on_end(
self,
trainer: "GradientDescentTrainer",
metrics: Dict[str, Any],
epoch: int,
is_master: bool,
) -> None:
if not hasattr(trainer, "end_callback_calls"):
trainer.end_callback_calls = [] # type: ignore
trainer.end_callback_calls.append(epoch) # type: ignore

trainer = GradientDescentTrainer(
self.model,
self.optimizer,
self.data_loader,
num_epochs=2,
validation_data_loader=self.validation_data_loader,
trainer_callbacks=[FakeTrainerCallback()],
)
trainer.train()
expected_batch_calls = [
(epoch, batch_number + 1, is_train)
for epoch in range(2)
for is_train in (True, False)
for batch_number in range(len(self.instances) // 2)
]
expected_epoch_calls = [epoch for epoch in range(-1, 2)]
expected_end_calls = [1]

assert trainer.batch_callback_calls == expected_batch_calls
assert trainer.epoch_callback_calls == expected_epoch_calls
assert trainer.end_callback_calls == expected_end_calls

def test_total_loss_is_average_of_batch_loss(self):

batches_per_epoch = 3
