Describe the bug
In neox_arguments/arguments.py, lines 1032-1035, the comment contradicts the condition it documents: the comment says pipe_parallel_size = 0 or 1 runs the sequential model, but pipe_parallel_size = 1 actually sets is_pipe_parallel and goes down the neox_args.is_pipe_parallel forward/backward/step codepath in megatron/training.py's train_step.
In neox_arguments/arguments.py:
# Update 'is pipe parallel' flag
# if we set pipe_parallel_size to 0 or 1, GPT2ModelPipe.to_sequential() is called, and we run training with
# the sequential model without the PipelineModule wrapper to avoid the overhead it incurs
self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1)
In megatron/training.py:
def train_step(neox_args, timers, data_iterator, model, optimizer, lr_scheduler):
    """Single training step."""
    # Pipeline parallelism schedules forward/backward/step
    if neox_args.is_pipe_parallel:
        reduced_loss = train_step_pipe(
            neox_args=neox_args, timers=timers, model=model, data_iterator=data_iterator
        )
        if (
            neox_args.memory_profiling
            and neox_args.iteration >= neox_args.profile_step_start
            and neox_args.iteration <= neox_args.profile_step_stop
            and torch.distributed.get_rank() == 0
        ):
            save_snapshot(neox_args)
    else:
        losses = []
        for _ in range(neox_args.gradient_accumulation_steps):
            # Forward model for one step.
            timers("forward").start()
            loss = forward_step(
This was also verified by running a training job with pipe_parallel_size = 1 and checking that the resulting model class is deepspeed.runtime.pipe.engine.PipelineEngine.
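A minimal sketch of that check, assuming model is the engine object returned by NeoX's model setup (the helper name report_engine_type is hypothetical):

from deepspeed.runtime.pipe.engine import PipelineEngine

def report_engine_type(model):
    # With pipe_parallel_size = 1 the engine is still a PipelineEngine wrapping a
    # single pipeline stage, rather than the plain sequential engine path.
    print(type(model))
    print("uses PipelineEngine:", isinstance(model, PipelineEngine))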
Potentially relevant to the non-pipeline parallel MoE work that @yang has been doing because pipe_parallel_size = 1 will currently still cause the training run to use the DeepSpeed PipelineEngine.
Expected behavior
pipe_parallel_size = 1 should probably not use the DeepSpeed PipelineEngine for the reasons the comment describes (higher overhead).
Proposed solution
Two potential solutions:
1. Clarify in the comment that pipe_parallel_size = 0 (or omitting the pipe_parallel_size arg) is the only way to get the non-PipelineEngine runtime. Most of the provided configs in /configs already set pipe_parallel_size = 1 by default, so this might not be the preferable solution.
2. Change the conditional to self.update_value("is_pipe_parallel", self.pipe_parallel_size > 1) instead of >= 1 (see the sketch below).
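A minimal sketch of option 2, assuming the surrounding derived-arguments logic in neox_arguments/arguments.py is otherwise unchanged:

# Update 'is pipe parallel' flag
# only treat pipe_parallel_size > 1 as pipeline parallelism; pp = 0 and pp = 1 would
# both fall back to the sequential model via GPT2ModelPipe.to_sequential()
self.update_value("is_pipe_parallel", self.pipe_parallel_size > 1)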
Environment (please complete the following information):
Configs:
# GPT-2 pretraining setup
{
  # parallelism settings ( you will want to change these based on your cluster setup, ideally scheduling pipeline stages
  # across the node boundaries )
  "pipe_parallel_size": 1,
  "model_parallel_size": 2,

  # model settings
  "num_layers": 12,
  "hidden_size": 1024,
  "num_attention_heads": 16,
  "seq_length": 2048,
  "max_position_embeddings": 2048,
  "norm": "layernorm",
  "pos_emb": "rotary",
  "no_weight_tying": true,

  # moe settings
  "moe_num_experts": 8,

  # these should provide some speedup but takes a while to build, set to true if desired
  "scaled_upper_triang_masked_softmax_fusion": false,
  "bias_gelu_fusion": false,
  "rope_fusion": false,
  "layernorm_fusion": false,

  # optimizer settings
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.0006,
      "betas": [0.9, 0.999],
      "eps": 1.0e-8,
    }
  },

  # for all zero_optimization options, see https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training
  "zero_optimization": {
    "stage": 0,
    "allgather_partitions": True,
    "allgather_bucket_size": 500000000,
    "overlap_comm": True,
    "reduce_scatter": True,
    "reduce_bucket_size": 500000000,
    "contiguous_gradients": True,
  },

  # batch / data settings
  "train_micro_batch_size_per_gpu": 4,
  "data_impl": "mmap",
  "split": "949,50,1",

  # activation checkpointing
  "checkpoint_activations": true,
  "checkpoint_num_layers": 1,
  "partition_activations": true,
  "synchronize_each_layer": true,

  # regularization
  "gradient_clipping": 1.0,
  "weight_decay": 0.0,
  "hidden_dropout": 0.0,
  "attention_dropout": 0.0,

  "precision": "bfloat16",
  "fp32_allreduce": True, # without a patch to torch, bf16 models have to do the allreduce in fp32

  # misc. training settings
  "train_iters": 5,
  "lr_decay_iters": 320000,
  "distributed_backend": "nccl",
  "min_lr": 0.0006,
  # "lr_decay_style": "cosine",
  "warmup": 0.0,
  "checkpoint_factor": 10000,
  "eval_interval": 1000,
  "eval_iters": 10,

  # logging
  "log_interval": 1,
  "steps_per_print": 1,
  "keep_last_n_checkpoints": 4,
  "wall_clock_breakdown": true,
}
This is not a bug. Using pp = 1 deliberately creates a pipeline model with one stage. To avoid using the pipe classes, you need to specify pp = 0. This was done because we found that, even with one pipeline stage, using the pipe classes was slightly faster than using the regular ones.
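For reference, a minimal config fragment reflecting that reply, assuming you want the sequential (non-PipelineEngine) path; the remaining keys from the config above stay as they are:

# run via GPT2ModelPipe.to_sequential(), without the DeepSpeed PipelineEngine
"pipe_parallel_size": 0,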