diff --git a/utils/neuralmagic/sparsification_manager.py b/utils/neuralmagic/sparsification_manager.py
index ff743538aaa7..a79c79c64210 100644
--- a/utils/neuralmagic/sparsification_manager.py
+++ b/utils/neuralmagic/sparsification_manager.py
@@ -10,7 +10,7 @@
 from utils.loss import ComputeLoss
 from utils.neuralmagic.quantization import update_model_bottlenecks
 from utils.neuralmagic.utils import ALMOST_ONE, QAT_BATCH_SCALE, ToggleableModelEMA, load_ema, nm_log_console
-from utils.torch_utils import ModelEMA
+from utils.torch_utils import WORLD_SIZE, ModelEMA
 
 __all__ = [
     "SparsificationManager",
@@ -382,7 +382,7 @@ def disable_ema_amp(
         return ema, amp, scaler
 
     def rescale_gradient_accumulation(
-        self, batch_size: int, accumulate: int, image_size: int
+        self, batch_size: int, accumulate: int
     ) -> Tuple[int, int]:
         """
         Used when autobatch and QAT are both enabled. Training with QAT adds additional
@@ -391,8 +391,18 @@ def rescale_gradient_accumulation(
         maintaining the original effective batch size
         """
 
+        if batch_size == WORLD_SIZE:
+            self.log_console(
+                "Could not scale down batch size for QAT as minimum batch size of "
+                f"{batch_size} is already used. Run may encounter an out of memory "
+                "error due to QAT",
+                level="warning",
+            )
+
+            return batch_size, accumulate
+
         effective_batch_size = batch_size * accumulate
-        batch_size = max(batch_size // QAT_BATCH_SCALE, 1)
+        batch_size = max(batch_size // QAT_BATCH_SCALE, WORLD_SIZE)
         accumulate = effective_batch_size // batch_size
 
         self.log_console(
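
For context, a minimal standalone sketch of the rescaling arithmetic the changed method performs, including the early return this diff adds. The WORLD_SIZE and QAT_BATCH_SCALE values below are assumed placeholders chosen for illustration; in the repo they come from utils.torch_utils and utils.neuralmagic.utils, and the real method also reports the rescaling through self.log_console.

# Illustrative sketch (not the repo's implementation) of the QAT rescaling logic.
# WORLD_SIZE and QAT_BATCH_SCALE are assumed example values here.
from typing import Tuple

WORLD_SIZE = 2       # e.g. two DDP processes; the smallest usable global batch size
QAT_BATCH_SCALE = 4  # assumed factor by which the batch size is shrunk for QAT


def rescale_gradient_accumulation(batch_size: int, accumulate: int) -> Tuple[int, int]:
    """Shrink the batch size for QAT and grow gradient accumulation so that the
    effective batch size (batch_size * accumulate) is preserved."""
    if batch_size == WORLD_SIZE:
        # Already one image per process; nothing left to scale down.
        return batch_size, accumulate

    effective_batch_size = batch_size * accumulate
    batch_size = max(batch_size // QAT_BATCH_SCALE, WORLD_SIZE)
    accumulate = effective_batch_size // batch_size
    return batch_size, accumulate


# Autobatch picks batch_size=16, accumulate=1 (effective 16); QAT rescaling
# yields batch_size=4, accumulate=4, keeping the effective batch size at 16.
print(rescale_gradient_accumulation(16, 1))  # (4, 4)

Clamping to WORLD_SIZE rather than 1 matters under DDP, since the global batch size cannot drop below one image per process; the new early return covers the case where autobatch has already selected that minimum, in which case only a warning about possible OOM under QAT can be emitted.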