Skip to content

Commit

Permalink
dvclive tracker
Browse files Browse the repository at this point in the history
  • Loading branch information
dberenbaum committed Nov 9, 2023
1 parent 6727ac4 commit f5d67f1
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 4 deletions.
3 changes: 2 additions & 1 deletion docs/source/usage_guides/tracking.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,15 @@ There are a large number of experiment tracking API's available, however getting

## Integrated Trackers

Currently `Accelerate` supports six trackers out-of-the-box:
Currently `Accelerate` supports seven trackers out-of-the-box:

- TensorBoard
- WandB
- CometML
- Aim
- MLFlow
- ClearML
- DVCLive

To use any of them, pass in the selected type(s) to the `log_with` parameter in [`Accelerate`]:
```python
Expand Down
2 changes: 1 addition & 1 deletion examples/by_feature/deepspeed_with_config_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def parse_args():
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
Expand Down
2 changes: 1 addition & 1 deletion examples/by_feature/megatron_lm_gpt_pretraining.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def parse_args():
default="all",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
' `"wandb"` and `"comet_ml"`. Use `"all"` (default) to report to all integrations.'
' `"wandb"`, `"comet_ml"`, and `"dvclive"`. Use `"all"` (default) to report to all integrations.'
"Only applicable when `--with_tracking` is passed."
),
)
Expand Down
8 changes: 8 additions & 0 deletions src/accelerate/test_utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
is_comet_ml_available,
is_datasets_available,
is_deepspeed_available,
is_dvclive_available,
is_mps_available,
is_pandas_available,
is_tensorboard_available,
Expand Down Expand Up @@ -231,6 +232,13 @@ def require_clearml(test_case):
return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case)


def require_dvclive(test_case):
    """
    Decorator for tests that need the `dvclive` package. When `is_dvclive_available()` reports that dvclive is not
    installed, the decorated test is skipped with the reason "test requires dvclive".
    """
    skip_without_dvclive = unittest.skipUnless(is_dvclive_available(), "test requires dvclive")
    return skip_without_dvclive(test_case)


def require_pandas(test_case):
"""
Decorator marking a test that requires pandas installed. These tests are skipped when pandas isn't installed
Expand Down
70 changes: 70 additions & 0 deletions src/accelerate/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
is_aim_available,
is_clearml_available,
is_comet_ml_available,
is_dvclive_available,
is_mlflow_available,
is_tensorboard_available,
is_wandb_available,
Expand Down Expand Up @@ -57,6 +58,9 @@
if is_clearml_available():
_available_trackers.append(LoggerType.CLEARML)

if is_dvclive_available():
_available_trackers.append(LoggerType.DVCLIVE)

logger = get_logger(__name__)


Expand Down Expand Up @@ -837,13 +841,78 @@ def _get_title_series(name):
return name, "train"


class DVCLiveTracker(GeneralTracker):
    """
    A `Tracker` class that supports `dvclive`. Should be initialized at the start of your script.

    Args:
        run_name (`str`, *optional*):
            Ignored for dvclive. See `kwargs` instead.
        live (`dvclive.Live`, *optional*):
            An already constructed `Live` instance to use. If `None`, a new instance is created from `kwargs`.
            When `live` is given, `kwargs` are not used.
        kwargs:
            Additional key word arguments passed along to `dvclive.Live()`.
    """

    name = "dvclive"
    requires_logging_directory = False

    @on_main_process
    def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs):
        # Imported lazily so that this module can be imported when dvclive is not installed.
        from dvclive import Live

        super().__init__()
        self.live = live if live is not None else Live(**kwargs)

    @property
    def tracker(self):
        """The underlying `dvclive.Live` instance."""
        return self.live

    @on_main_process
    def store_init_configuration(self, values: dict):
        """
        Logs `values` as hyperparameters for the run. Should be run at the beginning of your experiment. Stores the
        hyperparameters in a yaml file for future use.

        Args:
            values (Dictionary `str` to `bool`, `str`, `float`, `int`, or a List or Dict of those types):
                Values to be stored as initial hyperparameters as key-value pairs. The values need to have type `bool`,
                `str`, `float`, or `int`.
        """
        self.live.log_params(values)

    @on_main_process
    def log(self, values: dict, step: Optional[int] = None, **kwargs):
        """
        Logs `values` to the current run.

        Args:
            values (Dictionary `str` to `str`, `float`, or `int`):
                Values to be logged as key-value pairs. The values need to have type `str`, `float`, or `int`.
            step (`int`, *optional*):
                The run step. If included (including `0`), the log will be affiliated with this step.
            kwargs:
                Additional key word arguments passed along to `dvclive.Live.log_metric()`.
        """
        # Compare against `None` rather than relying on truthiness: `step=0` is a valid step and must not be
        # treated as "no step given".
        if step is not None:
            self.live.step = step
        for k, v in values.items():
            self.live.log_metric(k, v, **kwargs)

    @on_main_process
    def finish(self):
        """
        Closes `dvclive.Live()`.
        """
        self.live.end()


# Maps each tracker name string to the tracker class that implements it.
LOGGER_TYPE_TO_CLASS = {
"aim": AimTracker,
"comet_ml": CometMLTracker,
"mlflow": MLflowTracker,
"tensorboard": TensorBoardTracker,
"wandb": WandBTracker,
"clearml": ClearMLTracker,
"dvclive": DVCLiveTracker,
}


Expand All @@ -866,6 +935,7 @@ def filter_trackers(
- `"wandb"`
- `"comet_ml"`
- `"mlflow"`
- `"dvclive"`
If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
logging_dir (`str`, `os.PathLike`, *optional*):
Expand Down
1 change: 1 addition & 0 deletions src/accelerate/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
is_cuda_available,
is_datasets_available,
is_deepspeed_available,
is_dvclive_available,
is_fp8_available,
is_ipex_available,
is_megatron_lm_available,
Expand Down
2 changes: 2 additions & 0 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ class LoggerType(BaseEnum):
- **TENSORBOARD** -- TensorBoard as an experiment tracker
- **WANDB** -- wandb as an experiment tracker
- **COMETML** -- comet_ml as an experiment tracker
- **DVCLIVE** -- dvclive as an experiment tracker
"""

ALL = "all"
Expand All @@ -349,6 +350,7 @@ class LoggerType(BaseEnum):
COMETML = "comet_ml"
MLFLOW = "mlflow"
CLEARML = "clearml"
DVCLIVE = "dvclive"


class PrecisionType(BaseEnum):
Expand Down
4 changes: 4 additions & 0 deletions src/accelerate/utils/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,3 +297,7 @@ def is_xpu_available(check_device=False):
except RuntimeError:
return False
return hasattr(torch, "xpu") and torch.xpu.is_available()


def is_dvclive_available():
    """Return the result of `_is_package_available("dvclive")`, i.e. whether the `dvclive` package is usable."""
    dvclive_found = _is_package_available("dvclive")
    return dvclive_found
54 changes: 53 additions & 1 deletion tests/test_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,19 @@
TempDirTestCase,
require_clearml,
require_comet_ml,
require_dvclive,
require_pandas,
require_tensorboard,
require_wandb,
skip,
)
from accelerate.tracking import CometMLTracker, GeneralTracker
from accelerate.utils import ProjectConfiguration, is_comet_ml_available, is_tensorboard_available
from accelerate.utils import (
ProjectConfiguration,
is_comet_ml_available,
is_dvclive_available,
is_tensorboard_available,
)


if is_comet_ml_available():
Expand All @@ -52,6 +58,11 @@

import tensorboard.compat.proto.event_pb2 as event_pb2

if is_dvclive_available():
from dvclive.plots.metric import Metric
from dvclive.serialize import load_yaml
from dvclive.utils import parse_metrics

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -473,3 +484,44 @@ def test_log(self):
"some_string": "",
}
self.assertDictEqual(data, truth)


@require_dvclive
class DVCLiveTrackingTest(unittest.TestCase):
    """
    Tests for the `dvclive` tracker. Each test patches `dvclive.live.get_dvc_repo` to return `None` (presumably so
    the run does not depend on an enclosing DVC repository -- confirm against dvclive) and writes into a temporary
    directory.
    """

    def test_init_trackers(self):
        """The config passed to `init_trackers` must equal what is read back from the params file."""
        expected_params = {
            "num_iterations": 12,
            "learning_rate": 1e-2,
            "some_boolean": False,
            "some_string": "some_value",
        }
        with mock.patch("dvclive.live.get_dvc_repo", return_value=None), tempfile.TemporaryDirectory() as dirpath:
            accelerator = Accelerator(log_with="dvclive")
            live_kwargs = {"dir": dirpath, "save_dvc_exp": False, "dvcyaml": None}
            accelerator.init_trackers("test_project_with_config", expected_params, {"dvclive": live_kwargs})
            accelerator.end_training()
            live = accelerator.trackers[0].live
            stored_params = load_yaml(live.params_file)
            assert stored_params == expected_params

    def test_log(self):
        """Logged values must show up as the latest metrics and as one `.tsv` file per metric."""
        logged_values = {"total_loss": 0.1, "iteration": 1, "my_text": "some_value"}
        with mock.patch("dvclive.live.get_dvc_repo", return_value=None), tempfile.TemporaryDirectory() as dirpath:
            accelerator = Accelerator(log_with="dvclive", project_dir=dirpath)
            live_kwargs = {"dir": dirpath, "save_dvc_exp": False, "dvcyaml": None}
            accelerator.init_trackers("test_project_with_log", init_kwargs={"dvclive": live_kwargs})
            accelerator.log(logged_values, step=0)
            accelerator.end_training()
            live = accelerator.trackers[0].live
            logs, latest = parse_metrics(live)
            assert latest == logged_values
            scalars = os.path.join(live.plots_dir, Metric.subfolder)
            for tsv_name in ("total_loss.tsv", "iteration.tsv", "my_text.tsv"):
                assert os.path.join(scalars, tsv_name) in logs

0 comments on commit f5d67f1

Please sign in to comment.