Commit 89e08d2

Merge pull request #10 from jeffra/megatron-deepspeed-pipedev

3D parallelism development

jeffra committed Feb 4, 2021
2 parents ccd6b2a + ebd60b3 commit 89e08d2
Showing 14 changed files with 548 additions and 29 deletions.
16 changes: 16 additions & 0 deletions examples/ds_config.json
@@ -0,0 +1,16 @@
{
  "train_batch_size": 256,
  "train_micro_batch_size_per_gpu": 4,
  "steps_per_print": 10,
  "gradient_clipping": 1.0,
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "wall_clock_breakdown": true,
  "zero_allow_untested_optimizer": false
}
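Note: "loss_scale": 0 selects dynamic loss scaling, and DeepSpeed requires that train_batch_size equal micro-batch size times gradient-accumulation steps times data-parallel size. Below is a minimal, illustrative sketch of how a config file like this is consumed. It assumes a DeepSpeed version whose deepspeed.initialize accepts a config path, and uses a toy model rather than the Megatron GPT-2 model; the real entry point is pretrain_gpt2.py, which receives the file through the --deepspeed_config flag.

# Illustrative sketch only, not part of this commit. Assumes a DeepSpeed
# version whose initialize() accepts config=<path>; launch with the
# `deepspeed` launcher so RANK/WORLD_SIZE/MASTER_ADDR are set.
import torch
import deepspeed

model = torch.nn.Linear(1024, 1024)  # toy stand-in for the GPT-2 model

engine, optimizer, _, _ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    config="examples/ds_config.json",   # the file shown above
)

# Values from the JSON are exposed on the engine, e.g.:
#   engine.train_micro_batch_size_per_gpu() -> 4
#   engine.gradient_accumulation_steps()    -> derived from train_batch_size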
140 changes: 140 additions & 0 deletions examples/ds_pretrain_gpt2_pipe.sh
@@ -0,0 +1,140 @@
#! /bin/bash

GPUS_PER_NODE=16
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

export DLWS_NUM_WORKER=${NNODES}
export DLWS_NUM_GPU_PER_WORKER=${GPUS_PER_NODE}

DATA_PATH=data/webtext/webtext_text_document
VOCAB_PATH=data/gpt2-vocab.json
MERGE_PATH=data/gpt2-merges.txt
CHECKPOINT_PATH=checkpoints/gpt2_345m_ds

script_path=$(realpath $0)
script_dir=$(dirname $script_path)
#config_json="$script_dir/ds_zero_stage_2_config.json"
config_json="$script_dir/ds_config.json"

# Megatron Model Parallelism
mp_size=2
# DeepSpeed Pipeline parallelism
pp_size=2

NLAYERS=24
NHIDDEN=1024
BATCHSIZE=4
LOGDIR="tensorboard_data/${NLAYERS}l_${NHIDDEN}h_${NNODES}n_${GPUS_PER_NODE}g_${pp_size}pp_${mp_size}mp_${BATCHSIZE}b_ds4"

GAS=16

#ZeRO Configs
stage=0
reduce_scatter=true
contigious_gradients=true
rbs=50000000
agbs=5000000000

# Activation Checkpointing and Contiguous Memory
chkp_layers=1
PA=true
PA_CPU=false
CC=true
SYNCHRONIZE=true
PROFILE=false


gpt_options=" \
--model-parallel-size ${mp_size} \
--pipe-parallel-size ${pp_size} \
--num-layers $NLAYERS \
--hidden-size $NHIDDEN \
--num-attention-heads 16 \
--seq-length 1024 \
--max-position-embeddings 1024 \
--batch-size $BATCHSIZE \
--gas $GAS \
--train-iters 320000 \
--lr-decay-iters 320000 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--vocab-file $VOCAB_PATH \
--merge-file $MERGE_PATH \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 1.5e-4 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup 0.01 \
--checkpoint-activations \
--log-interval 1 \
--save-interval 500 \
--eval-interval 100 \
--eval-iters 10 \
--fp16 \
--tensorboard-dir ${LOGDIR}
"

deepspeed_options=" \
--deepspeed \
--deepspeed_config ${config_json} \
--zero-stage ${stage} \
--zero-reduce-bucket-size ${rbs} \
--zero-allgather-bucket-size ${agbs}
"

if [ "${contigious_gradients}" = "true" ]; then
deepspeed_options="${deepspeed_options} \
--zero-contigious-gradients"
fi

if [ "${reduce_scatter}" = "true" ]; then
deepspeed_options="${deepspeed_options} \
--zero-reduce-scatter"
fi

chkp_opt=" \
--checkpoint-activations \
--checkpoint-num-layers ${chkp_layers}"

if [ "${PA}" = "true" ]; then
chkp_opt="${chkp_opt} \
--partition-activations"
fi

if [ "${PA_CPU}" = "true" ]; then
chkp_opt="${chkp_opt} \
--checkpoint-in-cpu"
fi

if [ "${SYNCHRONIZE}" = "true" ]; then
chkp_opt="${chkp_opt} \
--synchronize-each-layer"
fi

if [ "${CC}" = "true" ]; then
chkp_opt="${chkp_opt} \
--contigious-checkpointing"
fi

if [ "${PROFILE}" = "true" ]; then
chkp_opt="${chkp_opt} \
--profile-backward"
fi

full_options="${gpt_options} ${deepspeed_options} ${chkp_opt}"

run_cmd="deepspeed --num_nodes ${DLWS_NUM_WORKER} --num_gpus ${DLWS_NUM_GPU_PER_WORKER} pretrain_gpt2.py $@ ${full_options}"
echo ${run_cmd}
eval ${run_cmd}

set +x
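Note: the script's parallel sizes are consistent with the batch settings in ds_config.json above. With 16 GPUs, model parallelism 2 and pipeline parallelism 2 leave 4 data-parallel replicas, and 4 (micro batch) x 16 (gas) x 4 (dp) = 256, matching train_batch_size. A tiny sanity check in Python, illustrative only, with the values copied from the script:

# Illustrative sanity check, not part of this commit.
gpus_per_node, nnodes = 16, 1
mp_size, pp_size = 2, 2
micro_batch, gas = 4, 16                      # BATCHSIZE and GAS above

world_size = gpus_per_node * nnodes           # 16
dp_size = world_size // (mp_size * pp_size)   # 4 data-parallel replicas

global_batch = micro_batch * gas * dp_size    # 4 * 16 * 4 = 256
assert global_batch == 256                    # "train_batch_size" in ds_config.json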
6 changes: 6 additions & 0 deletions megatron/arguments.py
@@ -226,6 +226,10 @@ def _add_training_args(parser):
help='Batch size per model instance (local batch size). '
'Global batch size is local batch size times data '
'parallel size.')
group.add_argument('--gas', type=int, default=1,
help='Gradient accumulation steps (pipeline parallelism only). '
'Global batch size is local batch size times data '
'parallel size times gas.')
group.add_argument('--checkpoint-activations', action='store_true',
help='Checkpoint activation to allow for training '
'with larger models, sequences, and batch sizes.')
@@ -373,6 +377,8 @@ def _add_distributed_args(parser):

group.add_argument('--model-parallel-size', type=int, default=1,
help='Size of the model parallel.')
group.add_argument('--pipe-parallel-size', type=int, default=0,
help='Size of the pipeline parallel. Disable with 0.')
group.add_argument('--distributed-backend', default='nccl',
choices=['nccl', 'gloo'],
help='Which backend to use for distributed training.')
11 changes: 8 additions & 3 deletions megatron/checkpointing.py
@@ -102,9 +102,14 @@ def save_ds_checkpoint(iteration, model, args):
sd['cuda_rng_state'] = torch.cuda.get_rng_state()
sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

# Megatron's model uses state_dict_for_save_checkpoint instead of the standard state_dict;
# state_dict is used by DeepSpeed for module saving, so it needs to point to the right function.
model.module.state_dict = model.module.state_dict_for_save_checkpoint
if args.pipe_parallel_size == 0:
    # Megatron's model uses state_dict_for_save_checkpoint instead of the standard state_dict;
    # state_dict is used by DeepSpeed for module saving, so it needs to point to the right function.
    model.module.state_dict = model.module.state_dict_for_save_checkpoint
else:
    # Pipeline parallelism manages its own state_dict.
    pass

model.save_checkpoint(args.save, client_state=sd)
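Note: the change above redirects the attribute DeepSpeed calls when serializing the module (state_dict) to Megatron's checkpoint-friendly variant, without modifying DeepSpeed itself. A self-contained, illustrative sketch of that redirection pattern follows; the class and key names are toy stand-ins, not the Megatron model.

# Illustrative sketch of the attribute-redirection pattern, not part of this commit.
import torch

class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def state_dict_for_save_checkpoint(self, *args, **kwargs):
        # Alternative serializer, e.g. adding bookkeeping keys or skipping tied weights.
        sd = super().state_dict(*args, **kwargs)
        sd['checkpoint_version'] = torch.tensor(3.0)
        return sd

model = ToyModel()
# A saving framework that calls model.state_dict() now gets the variant above.
model.state_dict = model.state_dict_for_save_checkpoint
print(sorted(model.state_dict().keys()))
# -> ['checkpoint_version', 'linear.bias', 'linear.weight']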


4 changes: 2 additions & 2 deletions megatron/data/gpt2_dataset.py
@@ -203,9 +203,9 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
# device_index=rank which is not the case for model
# parallel case
counts = torch.cuda.LongTensor([1])
torch.distributed.all_reduce(counts, group=mpu.get_data_parallel_group())
torch.distributed.all_reduce(counts, group=mpu.get_io_parallel_group())
assert counts[0].item() == torch.distributed.get_world_size(
group=mpu.get_data_parallel_group())
group=mpu.get_io_parallel_group())

# Load mappings.
start_time = time.time()
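Note: the change above swaps the process group used for the readiness check from the data-parallel group to the io-parallel group referenced as mpu.get_io_parallel_group(), so every rank that will read the cached index files takes part in the counting barrier. An illustrative sketch of that counting idiom in plain torch.distributed:

# Illustrative sketch of the counting-barrier idiom, not part of this commit.
import torch
import torch.distributed as dist

def wait_for_index_mappings(group=None):
    """Each rank contributes 1; the all-reduced sum must equal the group size,
    confirming that every rank in `group` reached this point before reading
    the cached index files built by a single rank."""
    counts = torch.ones(1, dtype=torch.long)
    if torch.cuda.is_available():
        counts = counts.cuda()
    dist.all_reduce(counts, group=group)
    assert counts[0].item() == dist.get_world_size(group=group)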
22 changes: 21 additions & 1 deletion megatron/initialize.py
@@ -155,12 +155,32 @@ def _initialize_distributed():
world_size=args.world_size, rank=args.rank,
init_method=init_method)

# Setup 3D topology.
if args.pipe_parallel_size > 0:
    pp = args.pipe_parallel_size
    mp = args.model_parallel_size
    assert args.world_size % (pp * mp) == 0
    dp = args.world_size // (pp * mp)

    from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology
    topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)

    # Offset base seeds for the interior pipeline stages.
    # TODO: adjust last stage too once IO is improved.
    stage_id = topo.get_coord(rank=torch.distributed.get_rank()).pipe
    if 0 < stage_id < topo.get_dim('pipe') - 1:
        offset = args.seed + 1138
        args.seed = offset + (stage_id * mp)
else:
    topo = None

# Set the model-parallel / data-parallel communicators.
if device_count > 0:
    if mpu.model_parallel_is_initialized():
        print('model parallel is already initialized')
    else:
        mpu.initialize_model_parallel(args.model_parallel_size)
        mpu.initialize_model_parallel(args.model_parallel_size, topology=topo)

# Optional DeepSpeed Activation Checkpointing Features
#
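Note on the topology block above: with the example script (pp=2, mp=2 on 16 GPUs, hence dp=4) there are no interior pipeline stages, so no seed offset is applied; the offset only matters for deeper pipelines. An illustrative sketch of the rank-to-coordinate mapping and the seed rule, assuming DeepSpeed is installed; a deeper pipeline is used here so that interior stages actually exist.

# Illustrative sketch, not part of this commit. Uses pp=4, mp=2, dp=2 (16 ranks)
# so that interior pipeline stages exist.
from deepspeed.runtime.pipe.topology import PipeModelDataParallelTopology

pp, mp, dp = 4, 2, 2
topo = PipeModelDataParallelTopology(num_pp=pp, num_mp=mp, num_dp=dp)

base_seed = 1234
for rank in range(pp * mp * dp):
    coord = topo.get_coord(rank)
    if 0 < coord.pipe < topo.get_dim('pipe') - 1:
        # Interior stages get distinct seeds, as in the block above.
        seed = (base_seed + 1138) + coord.pipe * mp
    else:
        seed = base_seed
    print(rank, coord, seed)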
2 changes: 1 addition & 1 deletion megatron/model/__init__.py
@@ -16,6 +16,6 @@
from .distributed import *
from .bert_model import BertModel
from .realm_model import ICTBertModel
from .gpt2_model import GPT2Model
from .gpt2_model import GPT2Model, GPT2ModelPipe
from .utils import get_params_for_weight_decay_optimization
from .language_model import get_language_model
123 changes: 123 additions & 0 deletions megatron/model/gpt2_model.py
@@ -26,12 +26,34 @@
from .utils import init_method_normal
from .utils import scaled_init_method_normal

# Pipeline parallelism
from megatron import mpu
import torch.nn.functional as F
from apex.normalization.fused_layer_norm import FusedLayerNorm as LayerNorm
import megatron.fp16 as fp16
from megatron.model.transformer import ParallelTransformerLayerPipe
from .language_model import EmbeddingPipe

from deepspeed.pipe import PipelineModule, LayerSpec, TiedLayerSpec


def gpt2_attention_mask_func(attention_scores, ltor_mask):
    attention_scores.masked_fill_(ltor_mask, -10000.0)
    return attention_scores


def CrossEntropy(output, labels):
    """ From pretrain_gpt2:forward_step() """
    labels, loss_mask = labels[0], labels[1]

    losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(), labels)
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    return loss



class GPT2Model(MegatronModule):
    """GPT-2 Language model."""

@@ -103,3 +125,104 @@ def load_state_dict(self, state_dict, strict=True):
if self._language_model_key in state_dict:
    state_dict = state_dict[self._language_model_key]
self.language_model.load_state_dict(state_dict, strict=strict)


class GPT2ModelPipe(PipelineModule, MegatronModule):
    """GPT2Model adapted for pipeline parallelism.

    The largest change is flattening the GPT2Model class so we can express it as a
    sequence of layers including embedding, transformer layers, and output.
    """

    def __init__(self, num_tokentypes=0, parallel_output=True, add_pooler=False, topology=None):
        args = get_args()

        self.parallel_output = parallel_output
        self.hidden_size = args.hidden_size
        self.num_tokentypes = num_tokentypes
        self.init_method = init_method_normal(args.init_method_std)
        self.output_layer_init_method = scaled_init_method_normal(args.init_method_std, args.num_layers)
        self.add_pooler = add_pooler
        if self.add_pooler:
            raise NotImplementedError('Pipeline pooler not yet implemented. Forward needs pooling_sequence_index')

        # Use torch gelu unless otherwise forced.
        gelu = F.gelu
        if args.openai_gelu:
            gelu = openai_gelu

        #
        # forward() prototype
        #
        self.specs = []

        # Embedding layer
        self.specs.append(TiedLayerSpec('embed',
                                        EmbeddingPipe,
                                        self.hidden_size,
                                        args.padded_vocab_size,
                                        args.max_position_embeddings,
                                        args.hidden_dropout,
                                        self.init_method,
                                        self.num_tokentypes,
                                        tied_weight_attr='word_embeddings_weight'))

        # outputs are now (hidden_states, attention_mask)

        # data format change to avoid explicit transposes: [b s h] --> [s b h]
        self.specs.append(lambda x: (x[0].transpose(0,1).contiguous(), x[1]))

        # Transformer layers
        for x in range(args.num_layers):
            self.specs.append(
                LayerSpec(ParallelTransformerLayerPipe,
                          attention_mask_func=gpt2_attention_mask_func,
                          init_method=self.init_method,
                          output_layer_init_method=self.output_layer_init_method,
                          layer_number=x))
        # Undo data format change and drop mask
        self.specs.append(lambda x: x[0].transpose(0,1).contiguous())

        # Final layernorm after transformer layers
        self.specs.append(
            LayerSpec(LayerNorm,
                      args.hidden_size,
                      eps=args.layernorm_epsilon))

        # XXX forward_method_parallel_output is assumed to be None, but we're not in a
        # fwd method to assert
        def _logits_helper(embedding, lm_output):
            """Just a wrapper to massage inputs/outputs from pipeline. """
            return parallel_lm_logits(
                lm_output,
                embedding.word_embeddings_weight,
                self.parallel_output)

        self.specs.append(
            TiedLayerSpec('embed',
                          EmbeddingPipe,
                          self.hidden_size,
                          args.padded_vocab_size,
                          args.max_position_embeddings,
                          args.hidden_dropout,
                          self.init_method,
                          self.num_tokentypes,
                          forward_fn=_logits_helper,
                          tied_weight_attr='word_embeddings_weight')
        )

        # Should maybe be done in loss_fn() instead?
        if args.fp16:
            self.specs.append(fp16.fp16_to_fp32)

        if args.checkpoint_activations:
            interval = args.checkpoint_num_layers
        else:
            interval = 0
        super().__init__(layers=self.specs,
                         loss_fn=CrossEntropy,
                         topology=topology,
                         activation_checkpoint_interval=interval,
                         partition_method='type:transformer')
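Note: in the pipeline above, the first and last TiedLayerSpec('embed', ...) share the word-embedding weights, so _logits_helper effectively projects hidden states onto the transposed embedding matrix, and CrossEntropy averages the vocab-parallel token losses under a loss mask. A single-GPU, illustrative analogue of those two pieces in plain PyTorch:

# Illustrative single-GPU analogue, not part of this commit.
import torch
import torch.nn.functional as F

vocab, hidden, batch, seq = 50257, 1024, 2, 8
embedding = torch.nn.Embedding(vocab, hidden)

tokens = torch.randint(vocab, (batch, seq))
labels = torch.randint(vocab, (batch, seq))
loss_mask = torch.ones(batch, seq)

hidden_states = embedding(tokens)                 # stand-in for the transformer stack
logits = hidden_states @ embedding.weight.t()     # tied output projection (cf. _logits_helper)

losses = F.cross_entropy(logits.view(-1, vocab).float(),
                         labels.view(-1), reduction='none')
loss = torch.sum(losses * loss_mask.view(-1)) / loss_mask.sum()  # cf. CrossEntropy above
print(loss.item())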
