Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions deepspeed/runtime/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1754,14 +1754,17 @@ def eval(self):
self.warn_unscaled_loss = True
self.module.train(False)

def _scale_loss_by_gas(self, prescaled_loss):
def _scale_loss_by_gas(self, prescaled_loss, eval_micro_batches=None):
# In pipeline evaluation, there is an option to use a different micro-batch size, which produces a different
# number of micro batches, so the training gradient accumulation steps (gas) value is not valid in this case; use eval_micro_batches instead.
scaling_factor = self.gradient_accumulation_steps() if eval_micro_batches is None else eval_micro_batches
if isinstance(prescaled_loss, torch.Tensor):
scaled_loss = prescaled_loss / self.gradient_accumulation_steps()
scaled_loss = prescaled_loss / scaling_factor
elif isinstance(prescaled_loss, tuple) or isinstance(prescaled_loss, list):
scaled_loss = []
for l in prescaled_loss:
if isinstance(l, torch.Tensor):
scaled_loss.append(l / self.gradient_accumulation_steps())
scaled_loss.append(l / scaling_factor)
else:
scaled_loss.append(l)
else:
Expand Down
17 changes: 13 additions & 4 deletions deepspeed/runtime/pipe/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,13 @@ def train_batch(self, data_iter=None):
# TODO: should return precisely what loss returned and allow others to be queried?
return self.agg_train_loss

def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_output='avg', bcast_loss=True):
def eval_batch(self,
data_iter,
return_logits=False,
compute_loss=True,
reduce_output='avg',
bcast_loss=True,
num_micro_batches=None):
"""Evaluate the pipeline on a batch of data from ``data_iter``. The
engine will evaluate ``self.train_batch_size()`` total samples
collectively across all workers.
Expand Down Expand Up @@ -451,6 +457,9 @@ def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_o
train_iterator = self.data_iterator
self.set_dataiterator(data_iter)

# set the number of micro batches in case the user chose a different value for evaluation than for training
micro_batches = self.micro_batches if num_micro_batches is None else num_micro_batches

# Do the work
sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,
stages=self.num_stages,
Expand All @@ -463,7 +472,7 @@ def eval_batch(self, data_iter, return_logits=False, compute_loss=True, reduce_o
self._exec_schedule(sched)

if self.is_last_stage():
eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output)
eval_output = self._reduce_outputs(self.fwd_outputs, reduce=reduce_output, micro_batches=micro_batches)

if compute_loss and (bcast_loss or self.monitor.enabled):
eval_output = self._bcast_pipe_scalar(eval_output)
Expand Down Expand Up @@ -505,7 +514,7 @@ def is_last_stage(self):
"""True if this process is in the last stage in the pipeline."""
return self.stage_id == self.num_stages - 1

def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True):
def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True, micro_batches=None):
if reduce is None:
return outputs

Expand All @@ -520,7 +529,7 @@ def _reduce_outputs(self, outputs, reduce='avg', reduce_dp=True):
reduced[idx] += out

# Average over the microbatches
reduced = self._scale_loss_by_gas(reduced)
reduced = self._scale_loss_by_gas(reduced, eval_micro_batches=micro_batches)

# Average over DP groups
if reduce_dp and self.is_data_parallel:
Expand Down