From fa677bd5a893fcf03b84010493aaa7dae2b14110 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Tue, 28 Nov 2023 14:11:55 -0500 Subject: [PATCH 1/5] Use logger warn instead --- src/accelerate/tracking.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py index 711f616b73e..ff3f3238c3c 100644 --- a/src/accelerate/tracking.py +++ b/src/accelerate/tracking.py @@ -866,7 +866,7 @@ class DVCLiveTracker(GeneralTracker): @on_main_process def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs): - from dvclive import Live + from dvclive import Live, Metric super().__init__() self.live = live if live is not None else Live(**kwargs) @@ -904,7 +904,15 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs): if step is not None: self.live.step = step for k, v in values.items(): - self.live.log_metric(k, v, **kwargs) + if Metric.could_log(v): + self.live.log_metric(k, v, **kwargs) + else: + logger.warning( + "Accelerator attempted to log a value of " + f'"{v}" of type {type(v)} for key "{k} as a scalar. ' + "This invocation of DVCLive's Live.log_metric() " + "is incorrect so we dropped this attribute." + ) @on_main_process def finish(self): From d1387c1fcef17cd5f067144a4f73ea2cc939dbbd Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Tue, 28 Nov 2023 14:12:23 -0500 Subject: [PATCH 2/5] Warn --- src/accelerate/tracking.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py index ff3f3238c3c..22a724713de 100644 --- a/src/accelerate/tracking.py +++ b/src/accelerate/tracking.py @@ -866,7 +866,7 @@ class DVCLiveTracker(GeneralTracker): @on_main_process def __init__(self, run_name: Optional[str] = None, live: Optional[Any] = None, **kwargs): - from dvclive import Live, Metric + from dvclive import Live super().__init__() self.live = live if live is not None else Live(**kwargs) @@ -901,6 +901,7 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs): kwargs: Additional key word arguments passed along to `dvclive.Live.log_metric()`. """ + from dvclive import Metric if step is not None: self.live.step = step for k, v in values.items(): From 96a16cc1decc5ff367e1f7f418cd2acaa172ac6e Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Tue, 28 Nov 2023 14:12:42 -0500 Subject: [PATCH 3/5] Right import --- src/accelerate/tracking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py index 22a724713de..274864b183c 100644 --- a/src/accelerate/tracking.py +++ b/src/accelerate/tracking.py @@ -901,7 +901,7 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs): kwargs: Additional key word arguments passed along to `dvclive.Live.log_metric()`. """ - from dvclive import Metric + from dvclive.plots import Metric if step is not None: self.live.step = step for k, v in values.items(): From 5959b53d5ea4c33829ee932d6c44edec3729a3a0 Mon Sep 17 00:00:00 2001 From: "[[ -z $EMAIL ]] && read -e -p \"Enter your email (for git configuration): \" EMAIL" Date: Tue, 28 Nov 2023 14:53:10 -0500 Subject: [PATCH 4/5] Clean up logs --- src/accelerate/logging.py | 12 ++++++++++++ src/accelerate/tracking.py | 11 ++++++----- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/accelerate/logging.py b/src/accelerate/logging.py index d553b9a993c..ebb8c1eb830 100644 --- a/src/accelerate/logging.py +++ b/src/accelerate/logging.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import functools import logging import os @@ -67,6 +68,17 @@ def log(self, level, msg, *args, **kwargs): self.logger.log(level, msg, *args, **kwargs) state.wait_for_everyone() + @functools.lru_cache(None) + def warning_once(self, *args, **kwargs): + """ + This method is identical to `logger.warning()`, but will emit the warning with the same message only once + + Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the + cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to + switch to another type of cache that includes the caller frame information in the hashing function. + """ + self.warning(*args, **kwargs) + def get_logger(name: str, log_level: str = None): """ diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py index 274864b183c..27c4955a661 100644 --- a/src/accelerate/tracking.py +++ b/src/accelerate/tracking.py @@ -640,8 +640,8 @@ def store_init_configuration(self, values: dict): for name, value in list(values.items()): # internally, all values are converted to str in MLflow if len(str(value)) > mlflow.utils.validation.MAX_PARAM_VAL_LENGTH: - logger.warning( - f'Trainer is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s' + logger.warning_once( + f'Accelerate is attempting to log a value of "{value}" for key "{name}" as a parameter. MLflow\'s' f" log_param() only accepts values no longer than {mlflow.utils.validation.MAX_PARAM_VAL_LENGTH} characters so we dropped this attribute." ) del values[name] @@ -670,7 +670,7 @@ def log(self, values: dict, step: Optional[int]): if isinstance(v, (int, float)): metrics[k] = v else: - logger.warning( + logger.warning_once( f'MLflowTracker is attempting to log a value of "{v}" of type {type(v)} for key "{k}" as a metric. ' "MLflow's log_metric() only accepts float and int types so we dropped this attribute." ) @@ -755,7 +755,7 @@ def log(self, values: Dict[str, Union[int, float]], step: Optional[int] = None, clearml_logger = self.task.get_logger() for k, v in values.items(): if not isinstance(v, (int, float)): - logger.warning( + logger.warning_once( "Accelerator is attempting to log a value of " f'"{v}" of type {type(v)} for key "{k}" as a scalar. ' "This invocation of ClearML logger's report_scalar() " @@ -902,13 +902,14 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs): Additional key word arguments passed along to `dvclive.Live.log_metric()`. """ from dvclive.plots import Metric + if step is not None: self.live.step = step for k, v in values.items(): if Metric.could_log(v): self.live.log_metric(k, v, **kwargs) else: - logger.warning( + logger.warning_once( "Accelerator attempted to log a value of " f'"{v}" of type {type(v)} for key "{k} as a scalar. ' "This invocation of DVCLive's Live.log_metric() " From 50220f000944298b49b2e8095f8ed558fd5ece45 Mon Sep 17 00:00:00 2001 From: Zach Mueller Date: Tue, 28 Nov 2023 14:53:41 -0500 Subject: [PATCH 5/5] Apply suggestions from code review Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> --- src/accelerate/tracking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/accelerate/tracking.py b/src/accelerate/tracking.py index 27c4955a661..7276f552aaf 100644 --- a/src/accelerate/tracking.py +++ b/src/accelerate/tracking.py @@ -911,7 +911,7 @@ def log(self, values: dict, step: Optional[int] = None, **kwargs): else: logger.warning_once( "Accelerator attempted to log a value of " - f'"{v}" of type {type(v)} for key "{k} as a scalar. ' + f'"{v}" of type {type(v)} for key "{k}" as a scalar. ' "This invocation of DVCLive's Live.log_metric() " "is incorrect so we dropped this attribute." )