80 changes: 50 additions & 30 deletions deepspeed/runtime/engine.py
@@ -393,6 +393,10 @@ def __init__(
         self.flatten = util_ops.flatten
         self.unflatten = util_ops.unflatten
 
+        # patch model generate with ours if model uses it
+        if hasattr(self.module, "generate"):
+            self.generate = self._generate
+
     def destroy(self):
         if self.optimizer is not None and hasattr(self.optimizer, 'destroy'):
             self.optimizer.destroy()
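Note: the hunk above shadows the wrapped module's generate on the engine instance, so engine.generate(...) dispatches to the engine's _generate wrapper defined later in this diff. A minimal sketch of the shadowing pattern, with illustrative names that are not part of this PR:

class EngineSketch:
    def __init__(self, module):
        self.module = module
        # An instance attribute wins over a class attribute for plain
        # methods, so this reroutes engine.generate to the wrapper below.
        if hasattr(module, "generate"):
            self.generate = self._generate

    def _generate(self, *inputs, **kwargs):
        # Engine-side hooks would run around the delegated call here.
        return self.module.generate(*inputs, **kwargs)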
@@ -1766,16 +1770,9 @@ def _scale_loss_by_gas(self, prescaled_loss):
 
         return scaled_loss
 
-    @instrument_w_nvtx
-    def forward(self, *inputs, **kwargs):
-        r"""Execute forward propagation
-        Arguments:
-            *inputs: Variable length input list
-            **kwargs: variable length keyword arguments
-        """
-
+    def pre_forward(self):
         if self.autotuning_profile_model_info():
-            ma = get_ma_status()
+            self._pre_fwd_memory_allocated = get_ma_status()
         else:
             see_memory_usage("Engine before forward", force=self.memory_breakdown())
 
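Note: promoting the old local ma to self._pre_fwd_memory_allocated lets post_forward (later in this diff) compute the activation-memory delta now that the measurement spans two methods. A minimal sketch of the paired-hook measurement, assuming a CUDA device and that get_ma_status reports torch.cuda.memory_allocated(); the class name is illustrative:

import torch

class ActivationMemProbe:
    def pre_forward(self):
        # Snapshot allocated device memory before the model runs; stored on
        # self so it survives until the matching post-forward hook.
        self._pre_fwd_memory_allocated = torch.cuda.memory_allocated()

    def post_forward(self):
        # Growth in allocated memory approximates the activation footprint.
        return torch.cuda.memory_allocated() - self._pre_fwd_memory_allocated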
@@ -1800,21 +1797,6 @@ def forward(self, *inputs, **kwargs):
         if flops_profiler_active:
             self.flops_profiler.start_profile(ignore_list=None)
 
-        if self.module.training:
-            if self.progressive_layer_drop:
-                kwargs.update(self.progressive_layer_drop.get_state())
-
-        if self.__class__.__name__ != "PipelineEngine":
-            # TODO: The above if condition is a HACK since for PipelineEngine
-            # it's difficult to inject argument in forward pass.
-            if self.module.training and self.curriculum_enabled_legacy():
-                self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
-                if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
-                    kwargs.update({
-                        "curriculum_seqlen":
-                        self.curriculum_scheduler_legacy.get_current_difficulty()
-                    })
-
         if self.module.training and self.random_ltd_enabled():
             self.random_ltd_scheduler.update_seq(self.global_steps)
 
@@ -1830,32 +1812,70 @@ def forward(self, *inputs, **kwargs):
         if self.training_dataloader is None:
             self.tput_timer.start()
 
-        if self.fp16_auto_cast():
-            inputs = self._cast_inputs_half(inputs)
-
-        loss = self.module(*inputs, **kwargs)
-
+    def post_forward(self):
         if self.zero_optimization_partition_weights():
             # Disable automated discovery of external parameters
             for module in self.module.modules():
                 module._parameters._in_forward = False
 
+        flops_profiler_active = (self.flops_profiler_enabled() and self.global_steps
+                                 == self.flops_profiler_profile_step()
+                                 and self.global_rank == 0)
+
         self._stop_timers(self.engine_timers.forward_timers)
+
         if flops_profiler_active:
             self.flops_profiler.stop_profile()
 
         if self.autotuning_profile_model_info():
-            activation_mem = get_ma_status() - ma
+            activation_mem = get_ma_status() - self._pre_fwd_memory_allocated
             self.autotuning_model_info["activation_mem_per_gpu"] = activation_mem
             print_json_dist(self.autotuning_model_info,
                             [0],
                             path=self.autotuning_model_info_path())
             exit()
         else:
             see_memory_usage("Engine after forward", force=self.memory_breakdown())
 
+    @instrument_w_nvtx
+    def forward(self, *inputs, **kwargs):
+        """Execute forward propagation
+        Arguments:
+            *inputs: Variable length input list
+            **kwargs: variable length keyword arguments
+        """
+
+        self.pre_forward()
+
+        if self.__class__.__name__ != "PipelineEngine":
+            # TODO: The above if condition is a HACK since for PipelineEngine
+            # it's difficult to inject argument in forward pass.
+            if self.module.training and self.curriculum_enabled_legacy():
+                self.curriculum_scheduler_legacy.update_difficulty(self.global_steps + 1)
+                if self.curriculum_params_legacy()["curriculum_type"] == "seqlen":
+                    kwargs.update({
+                        "curriculum_seqlen":
+                        self.curriculum_scheduler_legacy.get_current_difficulty()
+                    })
+
+        if self.module.training:
+            if self.progressive_layer_drop:
+                kwargs.update(self.progressive_layer_drop.get_state())
+
+        if self.fp16_auto_cast():
+            inputs = self._cast_inputs_half(inputs)
+
+        loss = self.module(*inputs, **kwargs)
+        self.post_forward()
+
         return loss
 
+    def _generate(self, *inputs, **kwargs):
+        self.pre_forward()
+        out = self.module.generate(*inputs, **kwargs)
+        self.post_forward()
+        return out
+
     def _cast_inputs_half(self, inputs):
         if isinstance(inputs, (list, tuple)):
             new_inputs = []
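With this change, inference entry points share the training-path bookkeeping: pre_forward and post_forward run around both forward and the patched generate. A hypothetical usage sketch; model, ds_config, batch, and input_ids are placeholders, and the generate kwargs assume a Hugging Face-style model:

import deepspeed

# initialize returns (engine, optimizer, dataloader, lr_scheduler)
engine, _, _, _ = deepspeed.initialize(model=model, config=ds_config)

loss = engine(batch)  # forward: pre_forward/post_forward fire around the call

out = engine.generate(input_ids, max_new_tokens=32)  # the same hooks fire here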