From 066644d7b2d3d521c710663f433811b26a640dba Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Wed, 18 Oct 2023 00:04:22 +0500 Subject: [PATCH 1/6] fix the sequence-parallelism for the dense models --- deepspeed/runtime/zero/stage3.py | 5 +++-- deepspeed/runtime/zero/stage_1_and_2.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/deepspeed/runtime/zero/stage3.py b/deepspeed/runtime/zero/stage3.py index ed44af08a686..a703fe6fa794 100644 --- a/deepspeed/runtime/zero/stage3.py +++ b/deepspeed/runtime/zero/stage3.py @@ -219,6 +219,7 @@ def __init__( self.reduce_scatter = reduce_scatter self.dp_process_group = self.parameter_offload.dp_process_group + self.sequence_parallel_size = groups._get_sequence_parallel_world_size() self.all2all_process_group = all2all_process_group @@ -1177,7 +1178,7 @@ def __avg_scatter_contiguous_grads(self, buffer_to_reduce: Tensor) -> List[Tenso world_sz = dist.get_world_size(self.dp_process_group) rank = dist.get_rank(self.dp_process_group) - buffer_to_reduce.div_(world_sz) + buffer_to_reduce.div_(world_sz / float(self.sequence_parallel_size)) dist.all_reduce(buffer_to_reduce, group=self.dp_process_group) @@ -1476,7 +1477,7 @@ def allreduce_bucket(self, bucket, rank=None, log=None): if communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(communication_data_type) - tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) + tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) if rank is None: # "All Reducing" diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 8c025a1a2b9f..847cf5f0261e 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -31,6 +31,7 @@ from deepspeed.utils import link_hp_params from deepspeed.checkpoint import enable_universal_checkpoint +from deepspeed.utils import groups # Toggle this to true to enable correctness test # with gradient partitioning and without pg_correctness_test = False @@ -182,7 +183,7 @@ def __init__(self, self.device = get_accelerator().current_device_name() if not self.cpu_offload else 'cpu' self.dp_process_group = dp_process_group - + self.sequence_parallel_size = groups._get_sequence_parallel_world_size() #expert parallel group self.ep_process_group = expert_parallel_group @@ -941,9 +942,9 @@ def gradient_reduction_w_predivide(self, tensor): dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) if self.gradient_predivide_factor != dp_world_size: - tensor_to_allreduce.mul_(self.gradient_predivide_factor / dp_world_size) + tensor_to_allreduce.mul_(self.gradient_predivide_factor / (dp_world_size / float(self.sequence_parallel_size)) else: - tensor_to_allreduce.div_(dp_world_size) + tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size)) dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: @@ -985,7 +986,7 @@ def average_tensor(self, tensor): if self.ipg_bucket_has_moe_params: process_group = self.expert_dp_process_group[param.group_name] if is_moe_param( param) else self.dp_process_group - grad_reduc.data.div_(dist.get_world_size(group=process_group)) + grad_reduc.data.div_(dist.get_world_size(group=process_group) / float(self.sequence_parallel_size)) partition_ids = self.param_to_partition_ids[i][param_id] assert all([p_id < dist.get_world_size(group=process_group) 
for p_id in partition_ids @@ -1025,7 +1026,7 @@ def average_tensor(self, tensor): prev_id, prev_process_group = partition_id, process_group if not self.ipg_bucket_has_moe_params: - tensor.div_(dist.get_world_size(group=self.dp_process_group)) + tensor.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) tensor_to_reduce = tensor if self.communication_data_type != tensor.dtype: @@ -1396,7 +1397,7 @@ def allreduce_bucket(self, bucket, rank=None, log=None): tensor_to_allreduce = tensor - if pg_correctness_test: + if pg_correctness_test or self.sequence_parallel_size > 1: communication_data_type = torch.float32 else: communication_data_type = self.communication_data_type @@ -1404,7 +1405,7 @@ def allreduce_bucket(self, bucket, rank=None, log=None): if communication_data_type != tensor.dtype: tensor_to_allreduce = tensor.to(communication_data_type) - tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group)) + tensor_to_allreduce.div_(dist.get_world_size(group=self.dp_process_group) / float(self.sequence_parallel_size)) if rank is None: # "All Reducing" From 8d901bfcf87c0adf0089af2856a6735ce050dd5d Mon Sep 17 00:00:00 2001 From: Reza Yazdani Date: Wed, 18 Oct 2023 21:33:04 +0500 Subject: [PATCH 2/6] fix the gradient scale for when zero is not enabled --- deepspeed/runtime/engine.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index 8a8193ddd8f5..e6323b996bf9 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1114,6 +1114,9 @@ def _configure_distributed_model(self, model): self.mp_world_size = groups._get_model_parallel_world_size() self.expert_parallel_group = groups._get_expert_parallel_group_dict() self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict() + self.sequence_parallel_size = groups._get_sequence_parallel_world_size() + if self.sequence_parallel_size > 1: + self.communication_data_type = torch.float32 if not (self.amp_enabled() or is_zero_init_model): self._broadcast_model() @@ -2303,9 +2306,9 @@ def allreduce_bucket(self, bucket, dp_group): dist.all_reduce(tensor_to_allreduce, group=dp_group) if self.gradient_average: if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group): - tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group)) + tensor_to_allreduce.mul_(self.gradient_predivide_factor() / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) else: - tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group)) + tensor_to_allreduce.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) dist.all_reduce(tensor_to_allreduce, group=dp_group) if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: @@ -2431,9 +2434,9 @@ def sparse_allreduce(self, sparse, dp_group): if self.postscale_gradients(): if self.gradient_average: - values.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group)) + values.mul_(self.gradient_predivide_factor() / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) else: - values.mul_(1. / dist.get_world_size(group=dp_group)) + values.mul_(1. 
/ (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) indices_device_list = self.sparse_all_gather(indices, dp_group) values_device_list = self.sparse_all_gather(values, dp_group) From 0bb95947ab342787084ae6d176732d88b7370044 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 18 Oct 2023 15:38:28 -0700 Subject: [PATCH 3/6] fix comm group for allreduce --- deepspeed/runtime/engine.py | 10 +++++++--- deepspeed/runtime/zero/stage_1_and_2.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index e6323b996bf9..d5b4a8df4d46 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -807,6 +807,10 @@ def communication_data_type(self): return torch.bfloat16 return torch.float32 + + @communication_data_type.setter + def communication_data_type(self, value): + self._config.communication_data_type = value def postscale_gradients(self): return not self._config.prescale_gradients @@ -2306,9 +2310,9 @@ def allreduce_bucket(self, bucket, dp_group): dist.all_reduce(tensor_to_allreduce, group=dp_group) if self.gradient_average: if self.gradient_predivide_factor() != dist.get_world_size(group=dp_group): - tensor_to_allreduce.mul_(self.gradient_predivide_factor() / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) + tensor_to_allreduce.mul_(self.gradient_predivide_factor() / dist.get_world_size(group=dp_group)) else: - tensor_to_allreduce.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) + tensor_to_allreduce.mul_(1. / dist.get_world_size(group=dp_group)) dist.all_reduce(tensor_to_allreduce, group=dp_group) if self.communication_data_type != tensor.dtype and tensor is not tensor_to_allreduce: @@ -2373,7 +2377,7 @@ def _reduce_non_expert_gradients(self, grads, elements_per_buffer): if self.pipeline_parallelism: dp_group = self.mpu.get_data_parallel_group() else: - dp_group = groups._get_data_parallel_group() + dp_group = groups._get_sequence_data_parallel_group() if bucket_type == SparseTensor.type(): self.sparse_allreduce_no_retain(bucket, dp_group=dp_group) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index 847cf5f0261e..a4fe5d7d1410 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -942,7 +942,7 @@ def gradient_reduction_w_predivide(self, tensor): dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) if self.gradient_predivide_factor != dp_world_size: - tensor_to_allreduce.mul_(self.gradient_predivide_factor / (dp_world_size / float(self.sequence_parallel_size)) + tensor_to_allreduce.mul_(self.gradient_predivide_factor / (dp_world_size / float(self.sequence_parallel_size))) else: tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size)) dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) From aaae9949ca4e0f1eb2149d1e44706be4f4e32a98 Mon Sep 17 00:00:00 2001 From: Masahiro Tanaka Date: Wed, 18 Oct 2023 15:41:13 -0700 Subject: [PATCH 4/6] fix format --- deepspeed/runtime/engine.py | 5 +++-- deepspeed/runtime/zero/stage_1_and_2.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index d5b4a8df4d46..fc3f4bdfe9eb 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -807,7 +807,7 @@ def communication_data_type(self): return torch.bfloat16 return torch.float32 - + 
@communication_data_type.setter def communication_data_type(self, value): self._config.communication_data_type = value @@ -2438,7 +2438,8 @@ def sparse_allreduce(self, sparse, dp_group): if self.postscale_gradients(): if self.gradient_average: - values.mul_(self.gradient_predivide_factor() / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) + values.mul_(self.gradient_predivide_factor() / + (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) else: values.mul_(1. / (dist.get_world_size(group=dp_group) / float(self.sequence_parallel_size))) diff --git a/deepspeed/runtime/zero/stage_1_and_2.py b/deepspeed/runtime/zero/stage_1_and_2.py index a4fe5d7d1410..2c9cf67b5bb7 100755 --- a/deepspeed/runtime/zero/stage_1_and_2.py +++ b/deepspeed/runtime/zero/stage_1_and_2.py @@ -942,7 +942,8 @@ def gradient_reduction_w_predivide(self, tensor): dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) if self.gradient_predivide_factor != dp_world_size: - tensor_to_allreduce.mul_(self.gradient_predivide_factor / (dp_world_size / float(self.sequence_parallel_size))) + tensor_to_allreduce.mul_(self.gradient_predivide_factor / + (dp_world_size / float(self.sequence_parallel_size))) else: tensor_to_allreduce.div_(dp_world_size / float(self.sequence_parallel_size)) dist.all_reduce(tensor_to_allreduce, group=self.dp_process_group) From 01ccf3311b479646258c66a694732afee253ac41 Mon Sep 17 00:00:00 2001 From: Sam Ade Jacobs Date: Wed, 25 Oct 2023 07:53:43 -0700 Subject: [PATCH 5/6] Allow users to set/override sp comm data type from ds config --- deepspeed/runtime/config.py | 8 ++++++-- deepspeed/runtime/constants.py | 13 +++++++++++++ deepspeed/runtime/engine.py | 2 +- 3 files changed, 20 insertions(+), 3 deletions(-) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index c31b9671296f..274316bacda0 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -238,8 +238,10 @@ def get_sparse_gradients_enabled(param_dict): return get_scalar_param(param_dict, SPARSE_GRADIENTS, SPARSE_GRADIENTS_DEFAULT) -def get_communication_data_type(param_dict): - val = get_scalar_param(param_dict, COMMUNICATION_DATA_TYPE, COMMUNICATION_DATA_TYPE_DEFAULT) +def get_communication_data_type(param_dict, + comm_type=COMMUNICATION_DATA_TYPE, + comm_data_type_default=COMMUNICATION_DATA_TYPE_DEFAULT): + val = get_scalar_param(param_dict, comm_type, comm_data_type_default) val = val.lower() if val is not None else val if val is None: return val # we must determine it by other parameters @@ -784,6 +786,8 @@ def _initialize_params(self, param_dict): self.disable_allgather = get_disable_allgather(param_dict) self.communication_data_type = get_communication_data_type(param_dict) + self.seq_parallel_communication_data_type = get_communication_data_type(param_dict, + SEQ_PARALLEL_COMMUNICATION_DATA_TYPE,SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT) self.prescale_gradients = get_prescale_gradients(param_dict) self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict) self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict) diff --git a/deepspeed/runtime/constants.py b/deepspeed/runtime/constants.py index 0bdac2557847..cc493ee007c5 100755 --- a/deepspeed/runtime/constants.py +++ b/deepspeed/runtime/constants.py @@ -223,6 +223,19 @@ COMMUNICATION_DATA_TYPE = "communication_data_type" COMMUNICATION_DATA_TYPE_DEFAULT = None +########################################################### +# Gradient communication data type for 
sequence parallelism +########################################################### +# Supported types: ['fp16', 'bf16','fp32'] +# Default value is fp32 +# Users can configure in ds_config.json as below example: +SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_FORMAT = ''' +Optional comm data type for seq paralleism should be set as: +"seq_parallel_communication_data_type": "fp32" +''' +SEQ_PARALLEL_COMMUNICATION_DATA_TYPE = "seq_parallel_comm_data_type" +SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT = "fp32" + ######################################### # Scale/predivide gradients before allreduce ######################################### diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py index fc3f4bdfe9eb..3338b8867fa4 100644 --- a/deepspeed/runtime/engine.py +++ b/deepspeed/runtime/engine.py @@ -1120,7 +1120,7 @@ def _configure_distributed_model(self, model): self.expert_data_parallel_group = groups._get_expert_data_parallel_group_dict() self.sequence_parallel_size = groups._get_sequence_parallel_world_size() if self.sequence_parallel_size > 1: - self.communication_data_type = torch.float32 + self.communication_data_type = self._config.seq_parallel_communication_data_type if not (self.amp_enabled() or is_zero_init_model): self._broadcast_model() From 568ae5a6f2d1830137675e9bc71a6a459aba7c74 Mon Sep 17 00:00:00 2001 From: Sam Ade Jacobs Date: Wed, 25 Oct 2023 09:02:56 -0700 Subject: [PATCH 6/6] Fix formatting --- deepspeed/runtime/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepspeed/runtime/config.py b/deepspeed/runtime/config.py index 274316bacda0..9fb9eba44a38 100755 --- a/deepspeed/runtime/config.py +++ b/deepspeed/runtime/config.py @@ -786,8 +786,8 @@ def _initialize_params(self, param_dict): self.disable_allgather = get_disable_allgather(param_dict) self.communication_data_type = get_communication_data_type(param_dict) - self.seq_parallel_communication_data_type = get_communication_data_type(param_dict, - SEQ_PARALLEL_COMMUNICATION_DATA_TYPE,SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT) + self.seq_parallel_communication_data_type = get_communication_data_type( + param_dict, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE, SEQ_PARALLEL_COMMUNICATION_DATA_TYPE_DEFAULT) self.prescale_gradients = get_prescale_gradients(param_dict) self.gradient_predivide_factor = get_gradient_predivide_factor(param_dict) self.sparse_gradients_enabled = get_sparse_gradients_enabled(param_dict)
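
Note on the scaling change common to these patches: with sequence parallelism (SP), the group used for the gradient all-reduce spans dp_size * sp_size ranks, but the sp_size ranks inside one sequence group hold partial gradients for the same samples. Their contributions must be summed, and the average taken only over the dp_size independent data replicas, hence the divisor dist.get_world_size(group) / sequence_parallel_size used throughout the hunks above. A minimal sketch of that rule (toy code, not DeepSpeed internals; the function name is illustrative):

    import torch
    import torch.distributed as dist

    def allreduce_average_with_sp(grad: torch.Tensor, group, sequence_parallel_size: int = 1) -> torch.Tensor:
        """Sum `grad` over `group`, averaging only over the data-parallel replicas."""
        world_size = dist.get_world_size(group=group)               # dp_size * sp_size
        num_replicas = world_size / float(sequence_parallel_size)   # dp_size
        grad.div_(num_replicas)             # pre-divide, mirroring the patched hunks
        dist.all_reduce(grad, group=group)  # default op is SUM: SP partials add up,
                                            # DP replicas end up averaged
        return grad

Patch 5/6 additionally lets users override the SP gradient-communication dtype from the DeepSpeed config instead of always forcing torch.float32. Assuming the key actually read is the value bound to SEQ_PARALLEL_COMMUNICATION_DATA_TYPE in the constants.py hunk ("seq_parallel_comm_data_type", default "fp32", supported values fp16/bf16/fp32), a hypothetical config fragment would look like:

    ds_config = {
        "train_micro_batch_size_per_gpu": 1,
        "seq_parallel_comm_data_type": "bf16",  # overrides the fp32 default when sp_size > 1
    }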