deepspeed/runtime/engine.py (10 additions, 2 deletions)
@@ -229,6 +229,7 @@ def __init__(
         global dist
         from deepspeed import comm as dist
         self._is_gradient_accumulation_boundary = None
+        self.scale_wrt_gas = None

         # for debug purposes - can then debug print: debug_get_module_name(module)
         debug_extract_module_and_param_names(model)
@@ -1693,7 +1694,11 @@ def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
             self.buffered_allreduce_fallback(elements_per_buffer=bucket_size)

     @instrument_w_nvtx
-    def backward(self, loss, allreduce_gradients=True, release_loss=False):
+    def backward(self,
+                 loss,
+                 allreduce_gradients=True,
+                 release_loss=False,
+                 scale_wrt_gas=True):
         r"""Execute backward pass on the loss
         Arguments:
             loss: Torch tensor on which to execute backward propagation
@@ -1702,13 +1707,16 @@ def backward(self, loss, allreduce_gradients=True, release_loss=False):

         see_memory_usage("Engine before backward", force=self.memory_breakdown())

+        if self.scale_wrt_gas is not None:
+            scale_wrt_gas = self.scale_wrt_gas
+
         if not allreduce_gradients:
             logger.warning(
                 f"Argument `allreduce_gradients` is deprecated, ignored, and will soon be removed"
             )

         # scale loss w.r.t. gradient accumulation if needed
-        if self.gradient_accumulation_steps() > 1:
+        if self.gradient_accumulation_steps() > 1 and scale_wrt_gas:
             loss = self._scale_loss_by_gas(loss.float())

         # Log training Loss
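Taken together, the engine.py changes let a caller skip the automatic division of the loss by gradient_accumulation_steps(). A minimal usage sketch, assuming an already-initialized engine (the engine and loss names are illustrative, not part of this PR):

    # Per-call: skip GAS loss scaling for this backward pass, e.g. when the
    # loss has already been averaged over the accumulation window by the caller.
    engine.backward(loss, scale_wrt_gas=False)

    # Engine-wide: the new attribute, when set to anything other than None,
    # overrides the per-call argument on every subsequent backward() call.
    engine.scale_wrt_gas = False
    engine.backward(loss)  # scaling is skipped regardless of the keyword argument

Note the precedence: because backward() replaces its scale_wrt_gas argument with self.scale_wrt_gas whenever that attribute is not None, the attribute wins over the keyword.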
deepspeed/utils/timer.py (5 additions, 1 deletion)
@@ -190,13 +190,17 @@ def stop(self, report_speed=True):
         self.end_time = time.time()
         duration = self.end_time - self.start_time
         self.total_elapsed_time += duration
+
+        curr_samples_sec = (self.batch_size * self.num_workers) / duration
+
         if self.local_step_count % self.steps_per_output == 0:
             if report_speed:
                 self.logging(
-                    "{}/{}, SamplesPerSec={}, MemAllocated={}GB, MaxMemAllocated={}GB"
+                    "{}/{}, RunningAvgSamplesPerSec={}, CurrSamplesPerSec={}, MemAllocated={}GB, MaxMemAllocated={}GB"
                     .format(self.epoch_count,
                             self.local_step_count,
                             self.avg_samples_per_sec(),
+                            curr_samples_sec,
                             round(torch.cuda.memory_allocated() / 1024**3,
                                   2),
                             round(torch.cuda.max_memory_allocated() / 1024**3,
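The timer.py change adds an instantaneous throughput figure alongside the running average. A standalone sketch of the distinction, with made-up numbers (batch_size, num_workers, and durations here are illustrative; in the real class, avg_samples_per_sec() averages over total_elapsed_time):

    batch_size, num_workers = 16, 4    # samples per step = 16 * 4 = 64
    durations = [0.50, 0.52, 0.95]     # seconds per step; the last step stalled

    samples_per_step = batch_size * num_workers
    running_avg = samples_per_step * len(durations) / sum(durations)  # ~97.5
    curr = samples_per_step / durations[-1]  # ~67.4, analogous to curr_samples_sec

    print(f"RunningAvgSamplesPerSec={running_avg:.1f}, CurrSamplesPerSec={curr:.1f}")

The per-step number surfaces transient slowdowns (data-loading stalls, checkpoint writes) that the running average smooths away, which is the point of logging both.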