Sampling with Pipe Parallel model (microsoft#28)
* test sparse self_attn fix

* mlperf attn initial commit

* add inference_batch fn

* Pull changes from DeepSpeed

* Pull changes from DeepSpeed

* Pull changes from DeepSpeed

* Pull changes from DeepSpeed

* Pull changes from DeepSpeed

* Pull changes from DeepSpeed

* cleanup, reinstantiate sending of logits / layer_past

* cleanup, reinstantiate sending of logits / layer_past

Co-authored-by: sid <sidney.black@aleph-alpha.de>
sdtblck and sid committed Apr 8, 2021
1 parent 04a52ad commit 0e95737
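
For context, here is a minimal sketch of how the new ``inference_batch`` added below might be driven for sampling. It is illustrative only and not part of this commit: the ``engine`` object, the batch layout, and the greedy decoding step are assumptions; the only details taken from the diff are that the iterator should carry dummy labels and that the method returns ``(logits, presents)`` on every rank.

# Illustrative sketch only -- `engine` (a pipeline-parallel DeepSpeed engine exposing the new
# `inference_batch` method) and the batch layout are assumptions, not part of this commit.
import torch

def sample_next_token(engine, tokens):
    # DeepSpeed's pipeline engine consumes (inputs, labels) pairs, so pass dummy labels
    # alongside the input tokens, as the new docstring requires.
    dummy_labels = torch.zeros_like(tokens)
    data_iter = iter([(tokens, dummy_labels)])

    # Returns the same (logits, presents) on every rank of the pipeline group.
    logits, presents = engine.inference_batch(data_iter)

    # Greedy choice of the next token from the final position's logits.
    next_token = logits[:, -1, :].argmax(dim=-1)
    return next_token, presents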
Showing 3 changed files with 223 additions and 47 deletions.
51 changes: 26 additions & 25 deletions deepspeed/runtime/engine.py
@@ -100,6 +100,7 @@ def print_configuration(args, name):
class DeepSpeedEngine(Module):
r"""DeepSpeed engine for training.
"""

def __init__(self,
args,
model,
@@ -150,7 +151,7 @@ def __init__(self,

if mpu is not None:
assert not self.elasticity_enabled(), "Elasticity is not currently supported" \
" with model parallelism."
" with model parallelism."

self._set_distributed_vars()

@@ -458,7 +459,7 @@ def _configure_checkpointing(self, dist_init_required):

# only the first data parallel process needs to store the model checkpoint
self.save_non_zero_checkpoint = (
dp_rank == 0) or self.zero_optimization_partition_weights()
dp_rank == 0) or self.zero_optimization_partition_weights()

if self.zero_optimization():
param_rank = torch.distributed.get_rank(
@@ -538,7 +539,7 @@ def _do_args_sanity_check(self, args):

def _is_supported_optimizer(self, optimizer_name):
return optimizer_name in DEEPSPEED_OPTIMIZERS or \
getattr(torch.optim, optimizer_name, None) is not None
getattr(torch.optim, optimizer_name, None) is not None

# Validate configuration based on command line arguments
def _do_sanity_check(self):
@@ -713,7 +714,7 @@ def _configure_fp16_optimizer(self, optimizer):
else:
log_dist('Creating fp16 optimizer with static loss scale: {}'.format(
self.loss_scale()),
ranks=[0])
ranks=[0])
optimizer = FP16_Optimizer(
optimizer,
static_loss_scale=self.loss_scale(),
@@ -943,11 +944,11 @@ def forward(self, *inputs, **kwargs):
return loss

def allreduce_gradients(self, bucket_size=MEMORY_OPT_ALLREDUCE_SIZE):
#Zero stage 2 communicates during non gradient accumulation boundaries as well
# Zero stage 2 communicates during non gradient accumulation boundaries as well
if self.zero_optimization_partition_gradients():
self.optimizer.overlapping_partition_gradients_reduce_epilogue()

#Communicate only at gradient accumulation boundaries
# Communicate only at gradient accumulation boundaries
elif self.is_gradient_accumulation_boundary():
if self.zero_optimization_stage() == ZERO_OPTIMIZATION_OPTIMIZER_STATES:
assert self.zero_reduce_scatter()
@@ -1048,7 +1049,7 @@ def is_gradient_accumulation_boundary(self):
bool: if the current step is a gradient accumulation boundary.
"""
return (self.micro_steps + 1) % \
self.gradient_accumulation_steps() == 0
self.gradient_accumulation_steps() == 0

def zero_grad(self):
"""
@@ -1082,8 +1083,8 @@ def _take_model_step(self, lr_kwargs):
self.timers('_step_step').stop()

self.timers('_step_zero_grad').start()
#zero grad in basic optimizer could be unreliable and may not exhibit
#the behaviour that we want
# zero grad in basic optimizer could be unreliable and may not exhibit
# the behaviour that we want
if not self.zero_optimization() and not self.fp16_enabled(
) and not self.amp_enabled():
self.zero_grad()
@@ -1431,7 +1432,7 @@ def load_checkpoint(self,
tag = fd.read().strip()
else:
logger.warning(f"Unable to find latest file at {latest_path}, if trying to load latest " \
"checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint.")
"checkpoint please ensure this file exists or pass an explicit checkpoint tag when loading a checkpoint.")
return None, None

load_path, client_states = self._load_checkpoint(load_dir,
@@ -1459,7 +1460,7 @@ def _load_checkpoint(self,
if not os.path.exists(load_path):
logger.warn(
'Client provided checkpoint load path: {} does not exist ... skip checkpoint load'
.format(load_path))
.format(load_path))
return None, None

logger.info(f'rank: {self.global_rank} loading checkpoint: {load_path}')
@@ -1502,7 +1503,7 @@ def _load_checkpoint(self,
client_state = {
key: value
for key,
value in checkpoint.items() if not key in deepspeed_states
value in checkpoint.items() if not key in deepspeed_states
}

return load_path, client_state
@@ -1592,8 +1593,8 @@ def _checkpoint_tag_validation(self, tag):
dist.all_reduce(min_bhash, op=torch.distributed.ReduceOp.MIN)
valid = all(min_bhash == bhash) and all(max_bhash == bhash)
msg = f"[rank={dist.get_rank()}] The checkpoint tag name '{tag}' is not consistent across " \
"all ranks. Including rank unique information in checkpoint tag could cause issues when " \
"restoring with different world sizes."
"all ranks. Including rank unique information in checkpoint tag could cause issues when " \
"restoring with different world sizes."
if self.checkpoint_tag_validation_fail():
assert valid, msg
elif not valid:
@@ -1685,29 +1686,29 @@ def _save_checkpoint(self, save_dir, tag, client_state={}):

state = {
'module':
self.module_state_dict(),
self.module_state_dict(),
'optimizer':
self.optimizer.state_dict()
if self.optimizer and not self.zero_optimization() else None,
self.optimizer.state_dict()
if self.optimizer and not self.zero_optimization() else None,
'lr_scheduler':
self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
'csr_tensor_module_names':
self.csr_tensor_module_names,
self.csr_tensor_module_names,
'skipped_steps':
self.skipped_steps,
self.skipped_steps,
'global_steps':
self.global_steps,
self.global_steps,
'global_samples':
self.global_samples,
self.global_samples,
'dp_world_size':
self.dp_world_size,
self.dp_world_size,
'mp_world_size':
self.mp_world_size
self.mp_world_size
}
state.update(client_state)

log_dist(message=f'Saving model checkpoint: {save_path}', ranks=[0])
#logger.info('Saving model checkpoint: {}'.format(save_path))
# logger.info('Saving model checkpoint: {}'.format(save_path))
torch.save(state, save_path)
self._curr_save_path = None

119 changes: 117 additions & 2 deletions deepspeed/runtime/pipe/engine.py
@@ -391,6 +391,123 @@ def eval_batch(self, data_iter):

return self.agg_eval_loss

def inference_batch(self, data_iter):
"""Inference the pipeline on a single batch of data from ``data_iter``.
This method is equivalent to:
.. code-block:: python
module.eval()
with torch.no_grad():
output = module(batch)
.. warning::
we're assuming that in inference we a) don't want to calculate loss and b) gradient_accum_steps = 0
Args:
data_iter (Iterator): Iterator of data to evaluate.
data_iter should have dummy labels as deepspeed expects it this way
Returns:
logits, presents (NB this is not a general purpose function, it's designed specifically to run with
gpt-neox, which will return logits + presents in inference. This is a massive hack.)
"""
self.module.eval()
self.total_loss = None
if self.micro_batches > 1:
print_rank_0('WARNING: setting g.a.s to 1 in inference')
self.micro_batches = 1
train_batch_fn = self.batch_fn
self.set_batch_fn(lambda x: x) # we just want to return `data_iter` as is
# DeepSpeed sends tensor metadata across pipeline stages only once, in the first step, and then assumes it
# will stay constant. In inference the metadata of the tensors sent across pipe stages may change, so we
# reset these two flags to make DeepSpeed send the metadata every step; otherwise torch.distributed hangs
# silently.
self.first_output_send = True
self.pipe_recv_buf = None
if self.is_data_parallel:
raise NotImplementedError('Inference not yet implemented for pipeline + data parallel')

# Use the provided data iterator
train_iterator = self.data_iterator
self.set_dataiterator(data_iter)

# Do the work
sched = schedule.InferenceSchedule(micro_batches=self.micro_batches,
stages=self.num_stages,
stage_id=self.stage_id)
with torch.no_grad():
self._exec_schedule(sched)

# the shapes are variable so we need to first broadcast the shapes, then the tensors themselves
if self.is_last_stage():
logits, presents = self.total_loss
logits = logits.clone().detach()
presents = presents.clone().detach()
logits_shape = list(logits.shape)
presents_shape = list(presents.shape)

logits_shape_tensor = torch.LongTensor(logits_shape).to(self.device)
presents_shape_tensor = torch.LongTensor(presents_shape).to(self.device)
dist.broadcast(tensor=logits_shape_tensor,
src=self.global_rank)
dist.broadcast(tensor=presents_shape_tensor,
src=self.global_rank)
else:
src_rank = self.grid.stage_to_global(self.num_stages - 1)
logits_shape_tensor = torch.LongTensor([0] * 3).to(self.device)
presents_shape_tensor = torch.LongTensor([0] * 6).to(self.device)
dist.broadcast(tensor=logits_shape_tensor,
src=src_rank)
dist.broadcast(tensor=presents_shape_tensor,
src=src_rank)
logits_shape_tensor = logits_shape_tensor.clone().detach()
presents_shape_tensor = presents_shape_tensor.clone().detach()

logits_shape = logits_shape_tensor.tolist()
presents_shape = presents_shape_tensor.tolist()

if self.is_last_stage():
dist.broadcast(tensor=logits,
src=self.global_rank,
group=self.mpu.get_pipe_parallel_group())
dist.broadcast(tensor=presents,
src=self.global_rank,
group=self.mpu.get_pipe_parallel_group())

else:
logits = torch.zeros(logits_shape, dtype=torch.half if self.fp16_enabled() else torch.float32).to(
self.device)
presents = torch.zeros(presents_shape, dtype=torch.half if self.fp16_enabled() else torch.float32).to(
self.device)
src_rank = self.grid.stage_to_global(self.num_stages - 1)
assert src_rank in self.grid.pp_group
dist.broadcast(tensor=logits,
src=src_rank,
group=self.grid.get_pipe_parallel_group())
dist.broadcast(tensor=presents,
src=src_rank,
group=self.grid.get_pipe_parallel_group())
logits = logits.clone().detach()
presents = presents.clone().detach()

# self.agg_eval_loss = self._aggregate_total_loss()
if self.tensorboard_enabled():
if self.global_rank == 0:
self.summary_events = [(f'Train/Samples/eval_loss',
self.agg_eval_loss.mean().item(),
self.global_samples)]
for event in self.summary_events: # write_summary_events
self.summary_writer.add_scalar(event[0], event[1], event[2])
self.summary_writer.flush()

# Restore the training iterator & batch_fn
self.set_dataiterator(train_iterator)
self.set_batch_fn(train_batch_fn)

return logits, presents

def is_first_stage(self):
"""True if this process is in the first stage in the pipeline."""
return self.stage_id == 0
Expand Down Expand Up @@ -779,7 +896,6 @@ def _exec_send_activations(self, buffer_id):
if self.wall_clock_breakdown():
self.timers('pipe_send_output').start()
self.timers('comms').start()

outputs = self.pipe_buffers['outputs'][buffer_id]

# NCCL does not like to send torch.BoolTensor types, so cast the mask to half().
Expand Down Expand Up @@ -870,7 +986,6 @@ def _exec_recv_activations(self, buffer_id):
self.timers('pipe_recv_input').start()

recvd = None

# Allocate the buffer if necessary
if self.pipe_recv_buf is None:
self.pipe_recv_buf = self._recv_tensor_meta(self.prev_stage)
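
Because the logits and presents shapes are known only on the last pipeline stage, the new ``inference_batch`` broadcasts the shapes first and then the tensors themselves. Below is a standalone sketch of that shape-then-tensor pattern, assuming an initialized ``torch.distributed`` process group; the helper name and signature are illustrative, not from the commit.

import torch
import torch.distributed as dist

def broadcast_from_src(tensor, src_rank, ndims, device, dtype=torch.float32, group=None):
    """Broadcast a tensor whose shape is known only on `src_rank` (illustrative helper)."""
    # Step 1: broadcast the shape so non-source ranks can allocate a receive buffer.
    if dist.get_rank() == src_rank:
        shape = torch.tensor(list(tensor.shape), dtype=torch.long, device=device)
    else:
        shape = torch.zeros(ndims, dtype=torch.long, device=device)
    dist.broadcast(shape, src=src_rank, group=group)

    # Step 2: broadcast the tensor itself into the freshly allocated buffer.
    if dist.get_rank() != src_rank:
        tensor = torch.zeros(shape.tolist(), dtype=dtype, device=device)
    dist.broadcast(tensor, src=src_rank, group=group)
    return tensor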