From 5f2241893cefdc04ff19cc9bb64b8b4e3fe38c26 Mon Sep 17 00:00:00 2001
From: Damian
Date: Wed, 12 Apr 2023 16:12:41 +0000
Subject: [PATCH 1/2] Set batch size min to world size

---
 utils/neuralmagic/sparsification_manager.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/utils/neuralmagic/sparsification_manager.py b/utils/neuralmagic/sparsification_manager.py
index ff743538aaa7..d60c1443bf61 100644
--- a/utils/neuralmagic/sparsification_manager.py
+++ b/utils/neuralmagic/sparsification_manager.py
@@ -10,7 +10,7 @@
 from utils.loss import ComputeLoss
 from utils.neuralmagic.quantization import update_model_bottlenecks
 from utils.neuralmagic.utils import ALMOST_ONE, QAT_BATCH_SCALE, ToggleableModelEMA, load_ema, nm_log_console
-from utils.torch_utils import ModelEMA
+from utils.torch_utils import WORLD_SIZE, ModelEMA
 
 __all__ = [
     "SparsificationManager",
@@ -391,8 +391,18 @@ def rescale_gradient_accumulation(
         maintaining the original effective batch size
         """
 
+        if batch_size == WORLD_SIZE:
+            self.log_console(
+                "Could not scale down batch size for QAT as minimum batch size of "
+                f"{batch_size} is already used. Run may encounter an out of memory "
+                "error due to QAT",
+                level="warning",
+            )
+
+            return batch_size, accumulate
+
         effective_batch_size = batch_size * accumulate
-        batch_size = max(batch_size // QAT_BATCH_SCALE, 1)
+        batch_size = max(batch_size // QAT_BATCH_SCALE, WORLD_SIZE)
         accumulate = effective_batch_size // batch_size
 
         self.log_console(

From 237dbae46795d92ac9902f5a5ea9edd53317582e Mon Sep 17 00:00:00 2001
From: Damian
Date: Wed, 12 Apr 2023 16:20:04 +0000
Subject: [PATCH 2/2] Remove unused arg

---
 utils/neuralmagic/sparsification_manager.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/utils/neuralmagic/sparsification_manager.py b/utils/neuralmagic/sparsification_manager.py
index d60c1443bf61..a79c79c64210 100644
--- a/utils/neuralmagic/sparsification_manager.py
+++ b/utils/neuralmagic/sparsification_manager.py
@@ -382,7 +382,7 @@
         return ema, amp, scaler
 
     def rescale_gradient_accumulation(
-        self, batch_size: int, accumulate: int, image_size: int
+        self, batch_size: int, accumulate: int
     ) -> Tuple[int, int]:
         """
         Used when autobatch and QAT are both enabled. Training with QAT adds additional
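
Note: below is a minimal, self-contained sketch of the rescaling behavior these two patches leave in place, not the actual SparsificationManager method (the logging call and self are dropped for brevity). The values of QAT_BATCH_SCALE and WORLD_SIZE do not appear in the patches; both are illustrative assumptions here.

from typing import Tuple

# Assumed values for illustration only; the real constants live in
# utils/neuralmagic/utils.py and utils/torch_utils.py respectively.
QAT_BATCH_SCALE = 2  # factor by which the batch shrinks to absorb QAT overhead
WORLD_SIZE = 2       # number of distributed workers (one per GPU)


def rescale_gradient_accumulation(batch_size: int, accumulate: int) -> Tuple[int, int]:
    """Shrink the per-step batch for QAT while preserving the effective
    batch size, i.e. batch_size * accumulate stays constant."""
    if batch_size == WORLD_SIZE:
        # Already at the floor of one sample per worker: nothing left to
        # scale down, so return the inputs unchanged (the patch logs an
        # out-of-memory warning at this point).
        return batch_size, accumulate

    effective_batch_size = batch_size * accumulate
    # Scale the batch down for QAT, but never below one sample per worker.
    batch_size = max(batch_size // QAT_BATCH_SCALE, WORLD_SIZE)
    # Recover the original effective batch size via gradient accumulation.
    accumulate = effective_batch_size // batch_size
    return batch_size, accumulate


if __name__ == "__main__":
    print(rescale_gradient_accumulation(16, 1))  # (8, 2): 16 -> 8 x 2 steps
    print(rescale_gradient_accumulation(2, 4))   # (2, 4): already at the floor

Raising the floor from 1 to WORLD_SIZE means the scaled-down batch always divides into at least one sample per distributed worker, which is presumably the motivation for PATCH 1's change to the max() clamp.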