Merged. (This repository was archived by the owner on Jun 4, 2025 and is now read-only.)
utils/neuralmagic/sparsification_manager.py (16 changes: 13 additions & 3 deletions)
@@ -10,7 +10,7 @@
 from utils.loss import ComputeLoss
 from utils.neuralmagic.quantization import update_model_bottlenecks
 from utils.neuralmagic.utils import ALMOST_ONE, QAT_BATCH_SCALE, ToggleableModelEMA, load_ema, nm_log_console
-from utils.torch_utils import ModelEMA
+from utils.torch_utils import WORLD_SIZE, ModelEMA
 
 __all__ = [
     "SparsificationManager",
@@ -382,7 +382,7 @@ def disable_ema_amp(
         return ema, amp, scaler
 
     def rescale_gradient_accumulation(
-        self, batch_size: int, accumulate: int, image_size: int
+        self, batch_size: int, accumulate: int
     ) -> Tuple[int, int]:
"""
Used when autobatch and QAT are both enabled. Training with QAT adds additional
Expand All @@ -391,8 +391,18 @@ def rescale_gradient_accumulation(
maintaining the original effective batch size
"""

if batch_size == WORLD_SIZE:
self.log_console(
"Could not scale down batch size for QAT as minimum batch size of "
f"{batch_size} is already used. Run may encounter an out of memory "
"error due to QAT",
level="warning",
)

return batch_size, accumulate

effective_batch_size = batch_size * accumulate
batch_size = max(batch_size // QAT_BATCH_SCALE, 1)
batch_size = max(batch_size // QAT_BATCH_SCALE, WORLD_SIZE)
accumulate = effective_batch_size // batch_size

self.log_console(
Expand Down
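
For reference, a minimal standalone sketch of the rescaling arithmetic this diff implements. The values of QAT_BATCH_SCALE and WORLD_SIZE below are placeholder assumptions for illustration (the real constants come from utils/neuralmagic/utils.py and utils/torch_utils.py). The idea: divide the batch size by QAT_BATCH_SCALE, floor it at WORLD_SIZE so each DDP process still gets at least one sample, and grow accumulate so that batch_size * accumulate, the effective batch size, is unchanged.

# Standalone sketch of the logic in rescale_gradient_accumulation.
QAT_BATCH_SCALE = 2  # assumption: factor by which QAT shrinks the per-step batch
WORLD_SIZE = 2       # assumption: number of DDP processes

def rescale(batch_size: int, accumulate: int) -> tuple:
    if batch_size == WORLD_SIZE:
        # Already at the minimum of one sample per process; nothing to scale down.
        print("warning: could not scale down batch size for QAT")
        return batch_size, accumulate

    effective_batch_size = batch_size * accumulate
    # Shrink the per-step batch, but never below one sample per process.
    batch_size = max(batch_size // QAT_BATCH_SCALE, WORLD_SIZE)
    # Grow gradient accumulation so the effective batch size is preserved.
    accumulate = effective_batch_size // batch_size
    return batch_size, accumulate

# e.g. batch_size=16, accumulate=4 -> (8, 8): the effective batch size stays 64
print(rescale(16, 4))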