From e1fab05ce7da5765073d72c7a1a7b4ba20cf23cd Mon Sep 17 00:00:00 2001 From: eajechiloae <97950284+eugen-ajechiloae-clearml@users.noreply.github.com> Date: Thu, 26 Oct 2023 19:13:28 +0300 Subject: [PATCH] Add ClearML tracker (#2034) * add clearml tracker * fix style in tracking.py * run ruff --fix * run ruff fix on src/accelerate/utils/__init__.py as well * properly run make style * add tests * modify code based on code review * changes based on code review * quote data_frame * fix docs * remove pandas req in log_table * style changes * add tracker to docs --- docs/source/package_reference/tracking.md | 2 + docs/source/usage_guides/tracking.md | 4 +- src/accelerate/test_utils/testing.py | 16 +++ src/accelerate/tracking.py | 163 +++++++++++++++++++++- src/accelerate/utils/__init__.py | 2 + src/accelerate/utils/dataclasses.py | 1 + src/accelerate/utils/imports.py | 8 ++ tests/test_tracking.py | 144 +++++++++++++++++++ 8 files changed, 336 insertions(+), 4 deletions(-) diff --git a/docs/source/package_reference/tracking.md b/docs/source/package_reference/tracking.md index 36719db683a..6845ca4bc05 100644 --- a/docs/source/package_reference/tracking.md +++ b/docs/source/package_reference/tracking.md @@ -31,3 +31,5 @@ rendered properly in your Markdown viewer. - __init__ [[autodoc]] tracking.MLflowTracker - __init__ +[[autodoc]] tracking.ClearMLTracker + - __init__ diff --git a/docs/source/usage_guides/tracking.md b/docs/source/usage_guides/tracking.md index fa123cde699..141fea6924b 100644 --- a/docs/source/usage_guides/tracking.md +++ b/docs/source/usage_guides/tracking.md @@ -20,12 +20,14 @@ There are a large number of experiment tracking API's available, however getting ## Integrated Trackers -Currently `Accelerate` supports four trackers out-of-the-box: +Currently `Accelerate` supports six trackers out-of-the-box: - TensorBoard - WandB - CometML +- Aim - MLFlow +- ClearML To use any of them, pass in the selected type(s) to the `log_with` parameter in [`Accelerate`]: ```python diff --git a/src/accelerate/test_utils/testing.py b/src/accelerate/test_utils/testing.py index 40afdac4813..bb887285954 100644 --- a/src/accelerate/test_utils/testing.py +++ b/src/accelerate/test_utils/testing.py @@ -31,10 +31,12 @@ from ..utils import ( gather, is_bnb_available, + is_clearml_available, is_comet_ml_available, is_datasets_available, is_deepspeed_available, is_mps_available, + is_pandas_available, is_safetensors_available, is_tensorboard_available, is_timm_available, @@ -231,6 +233,20 @@ def require_comet_ml(test_case): return unittest.skipUnless(is_comet_ml_available(), "test requires comet_ml")(test_case) +def require_clearml(test_case): + """ + Decorator marking a test that requires clearml installed. These tests are skipped when clearml isn't installed + """ + return unittest.skipUnless(is_clearml_available(), "test requires clearml")(test_case) + + +def require_pandas(test_case): + """ + Decorator marking a test that requires pandas installed. 
These tests are skipped when pandas isn't installed + """ + return unittest.skipUnless(is_pandas_available(), "test requires pandas")(test_case) + + _atleast_one_tracker_available = ( any([is_wandb_available(), is_tensorboard_available()]) and not is_comet_ml_available() ) diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py index aa2eb08b742..4f536d57812 100644 --- a/src/accelerate/tracking.py +++ b/src/accelerate/tracking.py @@ -28,6 +28,7 @@ from .utils import ( LoggerType, is_aim_available, + is_clearml_available, is_comet_ml_available, is_mlflow_available, is_tensorboard_available, @@ -53,6 +54,9 @@ if is_mlflow_available(): _available_trackers.append(LoggerType.MLFLOW) +if is_clearml_available(): + _available_trackers.append(LoggerType.CLEARML) + logger = get_logger(__name__) @@ -365,11 +369,11 @@ def log_table( Args: table_name (`str`): The name to give to the logged table on the wandb workspace - columns (List of `str`'s *optional*): + columns (list of `str`, *optional*): The name of the columns on the table - data (List of List of Any data type *optional*): + data (List of List of Any data type, *optional*): The data to be logged in the table - dataframe (Any data type *optional*): + dataframe (Any data type, *optional*): The data to be logged in the table step (`int`, *optional*): The run step. If included, the log will be affiliated with this step. @@ -681,12 +685,165 @@ def finish(self): mlflow.end_run() +class ClearMLTracker(GeneralTracker): + """ + A `Tracker` class that supports `clearml`. Should be initialized at the start of your script. + + Args: + run_name (`str`, *optional*): + Name of the experiment. Environment variables `CLEARML_PROJECT` and `CLEARML_TASK` have priority over this + argument. + kwargs: + Kwargs passed along to the `Task.__init__` method. + """ + + name = "clearml" + requires_logging_directory = False + + @on_main_process + def __init__(self, run_name: str = None, **kwargs): + from clearml import Task + + current_task = Task.current_task() + self._initialized_externally = False + if current_task: + self._initialized_externally = True + self.task = current_task + return + + kwargs.setdefault("project_name", os.environ.get("CLEARML_PROJECT", run_name)) + kwargs.setdefault("task_name", os.environ.get("CLEARML_TASK", run_name)) + self.task = Task.init(**kwargs) + + @property + def tracker(self): + return self.task + + @on_main_process + def store_init_configuration(self, values: dict): + """ + Connect configuration dictionary to the Task object. Should be run at the beginning of your experiment. + + Args: + values (`dict`): + Values to be stored as initial hyperparameters as key-value pairs. + """ + return self.task.connect_configuration(values) + + @on_main_process + def log(self, values: Dict[str, Union[int, float]], step: Optional[int] = None, **kwargs): + """ + Logs `values` dictionary to the current run. The dictionary keys must be strings. The dictionary values must be + ints or floats + + Args: + values (`Dict[str, Union[int, float]]`): + Values to be logged as key-value pairs. If the key starts with 'eval_'/'test_'/'train_', the value will + be reported under the 'eval'/'test'/'train' series and the respective prefix will be removed. + Otherwise, the value will be reported under the 'train' series, and no prefix will be removed. + step (`int`, *optional*): + If specified, the values will be reported as scalars, with the iteration number equal to `step`. + Otherwise they will be reported as single values. 
+            kwargs:
+                Additional key word arguments passed along to the `clearml.Logger.report_single_value` or
+                `clearml.Logger.report_scalar` methods.
+        """
+        clearml_logger = self.task.get_logger()
+        for k, v in values.items():
+            if not isinstance(v, (int, float)):
+                logger.warning(
+                    "Accelerator is attempting to log a value of "
+                    f'"{v}" of type {type(v)} for key "{k}" as a scalar. '
+                    "This invocation of ClearML logger's report_scalar() "
+                    "is incorrect so we dropped this attribute."
+                )
+                continue
+            if step is None:
+                clearml_logger.report_single_value(name=k, value=v, **kwargs)
+                continue
+            title, series = ClearMLTracker._get_title_series(k)
+            clearml_logger.report_scalar(title=title, series=series, value=v, iteration=step, **kwargs)
+
+    @on_main_process
+    def log_images(self, values: dict, step: Optional[int] = None, **kwargs):
+        """
+        Logs `images` to the current run.
+
+        Args:
+            values (`Dict[str, List[Union[np.ndarray, PIL.Image]]]`):
+                Values to be logged as key-value pairs. The values need to have type `List` of `np.ndarray` or
+                `PIL.Image`
+            step (`int`, *optional*):
+                The run step. If included, the log will be affiliated with this step.
+            kwargs:
+                Additional key word arguments passed along to the `clearml.Logger.report_image` method.
+        """
+        clearml_logger = self.task.get_logger()
+        for k, v in values.items():
+            title, series = ClearMLTracker._get_title_series(k)
+            clearml_logger.report_image(title=title, series=series, iteration=step, image=v, **kwargs)
+
+    @on_main_process
+    def log_table(
+        self,
+        table_name: str,
+        columns: List[str] = None,
+        data: List[List[Any]] = None,
+        dataframe: Any = None,
+        step: Optional[int] = None,
+        **kwargs,
+    ):
+        """
+        Log a Table to the task. Can be defined either with `columns` and `data` or with `dataframe`.
+
+        Args:
+            table_name (`str`):
+                The name of the table
+            columns (list of `str`, *optional*):
+                The name of the columns on the table
+            data (List of List of Any data type, *optional*):
+                The data to be logged in the table. If `columns` is not specified, then the first entry in data will be
+                the name of the columns of the table
+            dataframe (Any data type, *optional*):
+                The data to be logged in the table
+            step (`int`, *optional*):
+                The run step. If included, the log will be affiliated with this step.
+            kwargs:
+                Additional key word arguments passed along to the `clearml.Logger.report_table` method.
+        """
+        to_report = dataframe
+        if dataframe is None:
+            if data is None:
+                raise ValueError(
+                    "`ClearMLTracker.log_table` requires that `data` be supplied if `dataframe` is `None`"
+                )
+            to_report = [columns] + data if columns else data
+        title, series = ClearMLTracker._get_title_series(table_name)
+        self.task.get_logger().report_table(title=title, series=series, table_plot=to_report, iteration=step, **kwargs)
+
+    @on_main_process
+    def finish(self):
+        """
+        Close the ClearML task. If the task was initialized externally (e.g.
by manually calling `Task.init`), this + function is a noop + """ + if self.task and not self._initialized_externally: + self.task.close() + + @staticmethod + def _get_title_series(name): + for prefix in ["eval", "test", "train"]: + if name.startswith(prefix + "_"): + return name[len(prefix) + 1 :], prefix + return name, "train" + + LOGGER_TYPE_TO_CLASS = { "aim": AimTracker, "comet_ml": CometMLTracker, "mlflow": MLflowTracker, "tensorboard": TensorBoardTracker, "wandb": WandBTracker, + "clearml": ClearMLTracker, } diff --git a/src/accelerate/utils/__init__.py b/src/accelerate/utils/__init__.py index c430c7b6d6f..317204fc64c 100644 --- a/src/accelerate/utils/__init__.py +++ b/src/accelerate/utils/__init__.py @@ -46,6 +46,7 @@ is_bnb_available, is_boto3_available, is_ccl_available, + is_clearml_available, is_comet_ml_available, is_cuda_available, is_datasets_available, @@ -56,6 +57,7 @@ is_mlflow_available, is_mps_available, is_npu_available, + is_pandas_available, is_rich_available, is_safetensors_available, is_sagemaker_available, diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index 44f48c34017..954850a2df6 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -348,6 +348,7 @@ class LoggerType(BaseEnum): WANDB = "wandb" COMETML = "comet_ml" MLFLOW = "mlflow" + CLEARML = "clearml" class PrecisionType(BaseEnum): diff --git a/src/accelerate/utils/imports.py b/src/accelerate/utils/imports.py index 44d5459c5be..3ff167aecf6 100644 --- a/src/accelerate/utils/imports.py +++ b/src/accelerate/utils/imports.py @@ -210,6 +210,14 @@ def is_tqdm_available(): return _is_package_available("tqdm") +def is_clearml_available(): + return _is_package_available("clearml") + + +def is_pandas_available(): + return _is_package_available("pandas") + + def is_mlflow_available(): if _is_package_available("mlflow"): return True diff --git a/tests/test_tracking.py b/tests/test_tracking.py index 5a0288232e8..545b51fefd4 100644 --- a/tests/test_tracking.py +++ b/tests/test_tracking.py @@ -25,6 +25,7 @@ from typing import Optional from unittest import mock +import numpy as np import torch # We use TF to parse the logs @@ -32,7 +33,9 @@ from accelerate.test_utils.testing import ( MockingTestCase, TempDirTestCase, + require_clearml, require_comet_ml, + require_pandas, require_tensorboard, require_wandb, skip, @@ -250,6 +253,147 @@ def test_log(self): self.assertEqual(self.get_value_from_key(list_of_json, "my_text"), "some_value") +@require_clearml +class ClearMLTest(TempDirTestCase, MockingTestCase): + def setUp(self): + super().setUp() + # ClearML offline session location is stored in CLEARML_CACHE_DIR + self.add_mocks(mock.patch.dict(os.environ, {"CLEARML_CACHE_DIR": self.tmpdir})) + + @staticmethod + def _get_offline_dir(accelerator): + from clearml.config import get_offline_dir + + return get_offline_dir(task_id=accelerator.get_tracker("clearml", unwrap=True).id) + + @staticmethod + def _get_metrics(offline_dir): + metrics = [] + with open(os.path.join(offline_dir, "metrics.jsonl")) as f: + json_lines = f.readlines() + for json_line in json_lines: + metrics.extend(json.loads(json_line)) + return metrics + + def test_init_trackers(self): + from clearml import Task + from clearml.utilities.config import text_to_config_dict + + Task.set_offline(True) + accelerator = Accelerator(log_with="clearml") + config = {"num_iterations": 12, "learning_rate": 1e-2, "some_boolean": False, "some_string": "some_value"} + 
+        accelerator.init_trackers("test_project_with_config", config)
+
+        offline_dir = ClearMLTest._get_offline_dir(accelerator)
+        accelerator.end_training()
+
+        with open(os.path.join(offline_dir, "task.json")) as f:
+            offline_session = json.load(f)
+        clearml_offline_config = text_to_config_dict(offline_session["configuration"]["General"]["value"])
+        self.assertDictEqual(config, clearml_offline_config)
+
+    def test_log(self):
+        from clearml import Task
+
+        Task.set_offline(True)
+        accelerator = Accelerator(log_with="clearml")
+        accelerator.init_trackers("test_project_with_log")
+        values_with_iteration = {"should_be_under_train": 1, "eval_value": 2, "test_value": 3.1, "train_value": 4.1}
+        accelerator.log(values_with_iteration, step=1)
+        single_values = {"single_value_1": 1.1, "single_value_2": 2.2}
+        accelerator.log(single_values)
+
+        offline_dir = ClearMLTest._get_offline_dir(accelerator)
+        accelerator.end_training()
+
+        metrics = ClearMLTest._get_metrics(offline_dir)
+        self.assertEqual(len(values_with_iteration) + len(single_values), len(metrics))
+        for metric in metrics:
+            if metric["metric"] == "Summary":
+                self.assertIn(metric["variant"], single_values)
+                self.assertEqual(metric["value"], single_values[metric["variant"]])
+            elif metric["metric"] == "should_be_under_train":
+                self.assertEqual(metric["variant"], "train")
+                self.assertEqual(metric["iter"], 1)
+                self.assertEqual(metric["value"], values_with_iteration["should_be_under_train"])
+            else:
+                values_with_iteration_key = metric["variant"] + "_" + metric["metric"]
+                self.assertIn(values_with_iteration_key, values_with_iteration)
+                self.assertEqual(metric["iter"], 1)
+                self.assertEqual(metric["value"], values_with_iteration[values_with_iteration_key])
+
+    def test_log_images(self):
+        from clearml import Task
+
+        Task.set_offline(True)
+        accelerator = Accelerator(log_with="clearml")
+        accelerator.init_trackers("test_project_with_log_images")
+
+        base_image = np.eye(256, 256, dtype=np.uint8) * 255
+        base_image_3d = np.concatenate((np.atleast_3d(base_image), np.zeros((256, 256, 2), dtype=np.uint8)), axis=2)
+        images = {
+            "base_image": base_image,
+            "base_image_3d": base_image_3d,
+        }
+        accelerator.get_tracker("clearml").log_images(images, step=1)
+
+        offline_dir = ClearMLTest._get_offline_dir(accelerator)
+        accelerator.end_training()
+
+        images_saved = Path(os.path.join(offline_dir, "data")).rglob("*.jpeg")
+        self.assertEqual(len(list(images_saved)), len(images))
+
+    def test_log_table(self):
+        from clearml import Task
+
+        Task.set_offline(True)
+        accelerator = Accelerator(log_with="clearml")
+        accelerator.init_trackers("test_project_with_log_table")
+
+        accelerator.get_tracker("clearml").log_table(
+            "from lists with columns", columns=["A", "B", "C"], data=[[1, 3, 5], [2, 4, 6]]
+        )
+        accelerator.get_tracker("clearml").log_table("from lists", data=[["A2", "B2", "C2"], [7, 9, 11], [8, 10, 12]])
+        offline_dir = ClearMLTest._get_offline_dir(accelerator)
+        accelerator.end_training()
+
+        metrics = ClearMLTest._get_metrics(offline_dir)
+        self.assertEqual(len(metrics), 2)
+        for metric in metrics:
+            self.assertIn(metric["metric"], ["from lists", "from lists with columns"])
+            plot = json.loads(metric["plot_str"])
+            if metric["metric"] == "from lists with columns":
+                self.assertCountEqual(plot["data"][0]["header"]["values"], ["A", "B", "C"])
+                self.assertCountEqual(plot["data"][0]["cells"]["values"], [[1, 2], [3, 4], [5, 6]])
+            else:
+                self.assertCountEqual(plot["data"][0]["header"]["values"], ["A2", "B2", "C2"])
+                self.assertCountEqual(plot["data"][0]["cells"]["values"], [[7, 8], [9, 10], [11, 12]])
+
+    @require_pandas
+    def test_log_table_pandas(self):
+        import pandas as pd
+        from clearml import Task
+
+        Task.set_offline(True)
+        accelerator = Accelerator(log_with="clearml")
+        accelerator.init_trackers("test_project_with_log_table_pandas")
+
+        accelerator.get_tracker("clearml").log_table(
+            "from df", dataframe=pd.DataFrame({"A": [1, 2], "B": [3, 4], "C": [5, 6]}), step=1
+        )
+
+        offline_dir = ClearMLTest._get_offline_dir(accelerator)
+        accelerator.end_training()
+
+        metrics = ClearMLTest._get_metrics(offline_dir)
+        self.assertEqual(len(metrics), 1)
+        self.assertEqual(metrics[0]["metric"], "from df")
+        plot = json.loads(metrics[0]["plot_str"])
+        self.assertCountEqual(plot["data"][0]["header"]["values"], [["A"], ["B"], ["C"]])
+        self.assertCountEqual(plot["data"][0]["cells"]["values"], [[1, 2], [3, 4], [5, 6]])
+
+
 class MyCustomTracker(GeneralTracker):
     "Basic tracker that writes to a csv for testing"
     _col_names = [