Fix the sequence-parallelism for the dense model architecture #4530

Merged
merged 8 commits on Oct 25, 2023
8 changes: 6 additions & 2 deletions deepspeed/runtime/config.py
@@ -238,8 +238,10 @@ def get_sparse_gradients_enabled(param_dict):
return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT)


-def get_communication_data_type(param_dict):
-val = get_scalar_param(param_dict, COMMUNICATION_DATA_TYPE, COMMUNICATION_DATA_TYPE_DEFAULT)
+def get_communication_data_type(param_dict,
+comm_type=COMMUNICATION_DATA_TYPE,
+comm_data_type_default=COMMUNICATION_DATA_TYPE_DEFAULT):
+val = get_scalar_param(param_dict, comm_type, comm_data_type_default)
val = val.lower() if val is not None else val
if val is None:
return val # we must determine it by other parameters
@@ -784,6 +786,8 @@ def _initialize_params(self, param_dict):

self.disable_allgather = get_disable_allgather(param_dict)
self.communication_data_type = get_communication_data_type(param_dict)
+self.seq_parallel_communication_data_type = get_communication_data_type(
+param_dict, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT)
self.prescale_gradients = get_prescale_gradients(param_dict)
self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict)
self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
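
Editor's note: for orientation, here is a minimal, self-contained sketch of what the parameterized getter enables. The get_scalar_param stub and the sample param_dict below are illustrative stand-ins, not the actual DeepSpeed helpers (the real getter additionally maps the string to a torch dtype).

```python
# Illustrative stand-in only: one getter now serves both config keys.
def get_scalar_param(param_dict, key, default):
    return param_dict.get(key, default)

COMMUNICATION_DATA_TYPE = "communication_data_type"
COMMUNICATION_DATA_TYPE_DEFAULT = None
SEQ_PARALLEL_COMMUNICATION_DATA_TYPE = "seq_parallel_comm_data_type"
SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT = "fp32"

param_dict = {"seq_parallel_comm_data_type": "bf16"}  # hypothetical ds_config contents

global_val = get_scalar_param(param_dict, COMMUNICATION_DATA_TYPE, COMMUNICATION_DATA_TYPE_DEFAULT)
seq_val = get_scalar_param(param_dict, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE,
                           SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT)
print(global_val, seq_val)  # None bf16
```
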
13 changes: 13 additions & 0 deletions deepspeed/runtime/constants.py
@@ -223,6 +223,19 @@
COMMUNICATION_DATA_TYPE = "communication_data_type"
COMMUNICATION_DATA_TYPE_DEFAULT = None

+###########################################################
+# Gradient communication data type for sequence parallelism
+###########################################################
+# Supported types: ['fp16', 'bf16', 'fp32']
+# Default value is fp32
+# Users can configure in ds_config.json as below example:
+SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_FORMAT = '''
+Optional comm data type for sequence parallelism should be set as:
+"seq_parallel_comm_data_type": "fp32"
+'''
+SEQ_PARALLEL_COMMUNICATION_DATA_TYPE = "seq_parallel_comm_data_type"
+SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT = "fp32"
+
#########################################
# Scale/predivide gradients before allreduce
#########################################
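
Editor's note: a hedged usage example of the new key. Every value in this config other than "seq_parallel_comm_data_type" is a placeholder, shown only to indicate where the flag sits in a typical ds_config passed to deepspeed.initialize.

```python
# Placeholder config values except for the new sequence-parallel comm dtype key.
ds_config = {
    "train_batch_size": 32,
    "bf16": {"enabled": True},
    "zero_optimization": {"stage": 2},
    # Gradients reduced across the combined sequence+data parallel group are cast to fp32.
    "seq_parallel_comm_data_type": "fp32",
}
```
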
14 changes: 11 additions & 3 deletions deepspeed/runtime/engine.py
@@ -808,6 +808,10 @@ def communication_data_type(self):

return torch.float32

+@communication_data_type.setter
+def communication_data_type(self, value):
+self._config.communication_data_type = value
+
def postscale_gradients(self):
return not self._config.prescale_gradients

@@ -1114,6 +1118,9 @@ def _configure_distributed_model(self, model):
self.mp_world_size = groups._get_model_parallel_world_size()
self.expert_parallel_group = groups._get_expert_parallel_group_dict()
self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict()
+self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
+if self.sequence_parallel_size > 1:
+self.communication_data_type = self._config.seq_parallel_communication_data_type

if not (self.amp_enabled() or is_zero_init_model):
self._broadcast_model()
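
Editor's note: the effect of the added lines is a precedence rule: when the sequence-parallel group is larger than one, the configured seq-parallel dtype (fp32 by default) overrides whatever the engine would otherwise infer from fp16/bf16 settings. Below is a standalone sketch of that rule with a hypothetical resolve_comm_dtype helper, not the engine's actual property.

```python
import torch

def resolve_comm_dtype(fp16_enabled, bf16_enabled, seq_parallel_size,
                       seq_parallel_comm_dtype=torch.float32):
    # Hypothetical helper mirroring the precedence implied by this hunk.
    if seq_parallel_size > 1:
        return seq_parallel_comm_dtype  # value from seq_parallel_comm_data_type
    if fp16_enabled:
        return torch.float16
    if bf16_enabled:
        return torch.bfloat16
    return torch.float32

assert resolve_comm_dtype(True, False, seq_parallel_size=1) == torch.float16
assert resolve_comm_dtype(True, False, seq_parallel_size=2) == torch.float32
```
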
@@ -2370,7 +2377,7 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer):
if self.pipeline_parallelism:
dp_group = self.mpu.get_data_parallel_group()
else:
-dp_group = groups._get_data_parallel_group()
+dp_group = groups._get_sequence_data_parallel_group()

if bucket_type == SparseTensor.type():
self.sparse_allreduce_no_retain(bucket, dp_group=dp_group)
@@ -2431,9 +2438,10 @@ def sparse_allreduce(self, sparse, dp_group):

if self.postscale_gradients():
if self.gradient_average:
-values.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group))
+values.mul_(self.gradient_predivide_factor() /
+(dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))
else:
-values.mul_(1. / dist.get_world_size(group=dp_group))
+values.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size)))

indices_device_list = self.sparse_all_gather(indices, dp_group)
values_device_list = self.sparse_all_gather(values, dp_group)
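
Editor's note: the divisor change encodes the averaging semantics under sequence parallelism. The reduction group now spans data-parallel times sequence-parallel ranks, but ranks inside one sequence-parallel group hold gradients for different slices of the same samples and must be summed, so only the data-parallel replica count belongs in the average. A worked check with illustrative numbers:

```python
# Illustrative numbers: 8 ranks in the reduction group, sequence-parallel size 2,
# i.e. 4 genuine data-parallel replicas of each sample.
group_world_size = 8
sequence_parallel_size = 2

old_divisor = group_world_size                                   # 8: over-divides by 2x
new_divisor = group_world_size / float(sequence_parallel_size)   # 4: averages over replicas only

assert new_divisor == 4.0 and old_divisor / new_divisor == sequence_parallel_size
```
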
5 changes: 3 additions & 2 deletions deepspeed/runtime/zero/stage3.py
@@ -219,6 +219,7 @@ def __init__(
self.reduce_scatter = reduce_scatter

self.dp_process_group = self.parameter_offload.dp_process_group
+self.sequence_parallel_size = groups._get_sequence_parallel_world_size()

self.all2all_process_group = all2all_process_group

@@ -1177,7 +1178,7 @@ def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tensor]:

world_sz = dist.get_world_size(self.dp_process_group)
rank = dist.get_rank(self.dp_process_group)
-buffer_to_reduce.div_(world_sz)
+buffer_to_reduce.div_(world_sz / float(self.sequence_parallel_size))

dist.all_reduce(buffer_to_reduce, group=self.dp_process_group)

@@ -1476,7 +1477,7 @@ def allreduce_bucket(self, bucket, rank=None, log=None):
if communication_data_type != tensor.dtype:
tensor_to_allreduce = tensor.to(communication_data_type)

-tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
+tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))

if rank is None:
# "All Reducing"
16 changes: 9 additions & 7 deletions deepspeed/runtime/zero/stage_1_and_2.py
@@ -31,6 +31,7 @@
from deepspeed.utils import link_hp_params
from deepspeed.checkpoint import enable_universal_checkpoint

+from deepspeed.utils import groups
# Toggle this to true to enable correctness test
# with gradient partitioning and without
pg_correctness_test = False
@@ -182,7 +183,7 @@ def __init__(self,
self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu'

self.dp_process_group = dp_process_group

+self.sequence_parallel_size = groups._get_sequence_parallel_world_size()
#expert parallel group
self.ep_process_group = expert_parallel_group

@@ -941,9 +942,10 @@ def gradient_reduction_w_predivide(self, tensor):
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

if self.gradient_predivide_factor != dp_world_size:
-tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size)
+tensor_to_allreduce.mul_(self.gradient_predivide_factor /
+(dp_world_size / float(self.sequence_parallel_size)))
else:
-tensor_to_allreduce.div_(dp_world_size)
+tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size))
dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group)

if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce:
@@ -985,7 +987,7 @@ def average_tensor(self, tensor):
if self.ipg_bucket_has_moe_params:
process_group = self.expert_dp_process_group[param.group_name] if is_moe_param(
param) else self.dp_process_group
-grad_reduc.data.div_(dist.get_world_size(group=process_group))
+grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size))

partition_ids = self.param_to_partition_ids[i][param_id]
assert all([p_id < dist.get_world_size(group=process_group) for p_id in partition_ids
@@ -1025,7 +1027,7 @@ def average_tensor(self, tensor):
prev_id, prev_process_group = partition_id, process_group

if not self.ipg_bucket_has_moe_params:
-tensor.div_(dist.get_world_size(group=self.dp_process_group))
+tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))

tensor_to_reduce = tensor
if self.communication_data_type != tensor.dtype:
@@ -1395,15 +1397,15 @@ def allreduce_bucket(self, bucket, rank=None, log=None):

tensor_to_allreduce = tensor

-if pg_correctness_test:
+if pg_correctness_test or self.sequence_parallel_size > 1:
Reviewer comment (Contributor): self.sequence_parallel_size > 1 is now redundant given the ds_config flag, right?
communication_data_type = torch.float32
else:
communication_data_type = self.communication_data_type

if communication_data_type != tensor.dtype:
tensor_to_allreduce = tensor.to(communication_data_type)

-tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group))
+tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size))

if rank is None:
# "All Reducing"