-
Notifications
You must be signed in to change notification settings - Fork 401
/
logger_destination.py
121 lines (97 loc) · 5.1 KB
/
logger_destination.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0
"""Base class for logger callback."""
from __future__ import annotations
import pathlib
from abc import ABC
from typing import Any, Dict
from composer.core.callback import Callback
from composer.core.state import State
from composer.loggers.logger import LogLevel
__all__ = ['LoggerDestination']
class LoggerDestination(Callback, ABC):
    """Base class for logger destination.

    As this class extends :class:`~.callback.Callback`, logger destinations can run on any training loop
    :class:`~composer.core.event.Event`. For example, it may be helpful to run on
    :attr:`~composer.core.event.Event.EPOCH_END` to perform any flushing at the end of every epoch.

    Example:
        .. doctest::

            >>> from composer.loggers import LoggerDestination
            >>> class MyLogger(LoggerDestination):
            ...     def log_data(self, state, log_level, data):
            ...         print(f'Batch {int(state.timestamp.batch)}: {data}')
            >>> logger = MyLogger()
            >>> trainer = Trainer(
            ...     ...,
            ...     loggers=[logger]
            ... )
            Batch 0: {'rank_zero_seed': ...}
    """

    def log_data(self, state: State, log_level: LogLevel, data: Dict[str, Any]):
        """Log data.

        Subclasses should implement this method to store logged data (e.g. write it to a file, send it to a server,
        etc...). However, not all loggers need to implement this method.

        .. note::

            This method will block the training loop. For optimal performance, it is recommended to deepcopy the
            ``data`` (e.g. ``copy.deepcopy(data)``), and store the copied data in queue. Then, either:

            * Use background thread(s) or process(s) to read from this queue to perform any I/O.
            * Batch the data together and flush periodically on events, such as
              :attr:`~composer.core.event.Event.BATCH_END` or :attr:`~composer.core.event.Event.EPOCH_END`.

            .. seealso:: :class:`~composer.loggers.file_logger.FileLogger` as an example.

        Args:
            state (State): The training state.
            log_level (LogLevel): The log level.
            data (Dict[str, Any]): The data to log.
        """
        # Default is a no-op; the `del` documents that the arguments are intentionally unused.
        del state, log_level, data  # unused

    def log_file_artifact(
        self,
        state: State,
        log_level: LogLevel,
        artifact_name: str,
        file_path: pathlib.Path,
        *,
        overwrite: bool,
    ):
        """Handle logging of a file artifact stored at ``file_path`` to an artifact named ``artifact_name``.

        Subclasses should implement this method to store logged files (e.g. copy it to another folder or upload it to
        an object store). However, not all loggers need to implement this method. For example, the
        :class:`~composer.loggers.tqdm_logger.TQDMLogger` does not implement this method, as it cannot
        handle file artifacts.

        .. note::

            * This method will block the training loop. For optimal performance, it is recommended that this
              method copy the file to a temporary directory, enqueue the copied file for processing, and return.
              Then, use a background thread(s) or process(s) to read from this queue to perform any I/O.
            * After this method returns, training can resume, and the contents of ``file_path`` may change (or be may
              deleted). Thus, if processing the file in the background (as is recommended), it is necessary to first
              copy the file to a temporary directory. Otherwise, the original file may no longer exist, or the logged
              artifact can be corrupted (e.g., if the logger destination is reading from file while the training loop
              is writing to it).

        Args:
            state (State): The training state.
            log_level (LogLevel): The log level.
            artifact_name (str): The name of the artifact.
            file_path (pathlib.Path): The file path.
            overwrite (bool): Whether to overwrite an existing artifact with the same ``artifact_name``.
                This keyword-only argument is required; it has no default.
        """
        # Default is a no-op; the `del` documents that the arguments are intentionally unused.
        del state, log_level, artifact_name, file_path, overwrite  # unused

    def get_file_artifact(
        self,
        artifact_name: str,
        destination: str,
        overwrite: bool = False,
        progress_bar: bool = True,
    ):
        """Handle downloading an artifact named ``artifact_name`` to ``destination``.

        Args:
            artifact_name (str): The name of the artifact.
            destination (str): The destination filepath.
            overwrite (bool): Whether to overwrite an existing file at ``destination``. Defaults to ``False``.
            progress_bar (bool, optional): Whether to show a progress bar. Ignored if the artifact is stored
                as a local file. (default: ``True``)

        Raises:
            NotImplementedError: If the logger destination does not support downloading artifacts.
        """
        del artifact_name, destination, overwrite, progress_bar  # unused
        raise NotImplementedError