Reshape deepspeed checkpoint (#239)
* Reshape deepspeed checkpoint

* add checkpoint tests

* Validate input folder

* Tests for tp/pp reshape

* remove debug folders

* fix test_checkpoint_reshaping_empty_dir

* Fix unit tests

* Remove deepspeed checkpoint utils

* Use DS 3D reshaping utils

* convert to bf16

* wip universal chkpt

* rename

* rename

* wip on fragments dealing

* cleanup

* Loading universal checkpoint with reshaping

* all gpu1<->2 reshapes work

* param attrs

* make the tests adaptable to the number of available gpus

* WIP

* WIP

* WIP

* WIP

* Debug functions

* args should be required, don't create another latest file

* Parallelize shard extraction

* close+join pool; add tqdm; comment out noise

* rename

* parameterize

* Parallel slice merging

* Cleanup

* allow inspection on a machine w/o gpus

* test against the right DS branch

* DS size was merged

Co-authored-by: Stas Bekman <stas@stason.org>
tjruwase and stas00 committed Jul 20, 2022
1 parent 7b5f175 commit 0f23a72
Showing 15 changed files with 1,349 additions and 293 deletions.
2 changes: 2 additions & 0 deletions megatron/arguments.py
@@ -641,6 +641,8 @@ def _add_learning_rate_args(parser):
'(learning rate, warmup iterations, minimum learning '
'rate, maximum number of iterations, and decay style '
'from checkpoint and ignore input arguments.')
group.add_argument('--universal-checkpoint', action='store_true',
help='Loading a universal format checkpoint.')

return parser

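The change above adds a single new flag. A minimal sketch of how such a store_true flag behaves once parsed; the standalone parser below is only a stand-in for Megatron's argument builder and is not part of this commit:

import argparse

# Hypothetical standalone parser, standing in for Megatron's argument builder.
parser = argparse.ArgumentParser()
parser.add_argument('--universal-checkpoint', action='store_true',
                    help='Loading a universal format checkpoint.')

# Dashes become underscores on the parsed namespace.
args = parser.parse_args(['--universal-checkpoint'])
assert args.universal_checkpoint is True
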
17 changes: 15 additions & 2 deletions megatron/checkpointing.py
@@ -27,7 +27,8 @@
mpu,
print_rank_0,
update_num_microbatches,
utils)
utils,
get_tokenizer)
from megatron.enums import PositionEmbeddingType

_CHECKPOINT_VERSION = None
@@ -131,6 +132,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
state_dict['checkpoint_version'] = 3.0
state_dict['iteration'] = iteration
state_dict['tokens'] = args.consumed_train_tokens
state_dict['checkpoint_info'] = _checkpoint_info()

# DeepSpeed saves the model/optimizer/scheduler
if not args.deepspeed:
@@ -361,7 +363,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
assert args.consumed_valid_samples == 0
if 'args' in state_dict:
checkpoint_args = state_dict['args']
check_checkpoint_args(checkpoint_args)
if not args.universal_checkpoint:
check_checkpoint_args(checkpoint_args)
args.consumed_train_samples = getattr(checkpoint_args,
'consumed_train_samples', 0)
update_num_microbatches(consumed_samples=args.consumed_train_samples)
@@ -468,3 +471,13 @@ def load_biencoder_checkpoint(model, only_query_model=False,
print(' successfully loaded {}'.format(checkpoint_name))

return model


def _checkpoint_info():
args = get_args()
tokenizer = get_tokenizer()

return {
"padded_vocab_size": args.padded_vocab_size,
"original_vocab_size": tokenizer.vocab_size,
}
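
The new _checkpoint_info() helper records the padded and original vocabulary sizes (vocabulary padding depends on the tensor-parallel degree, so reshaping code needs both). A sketch of the resulting entry as it would appear in the saved state dict; all values below are illustrative, not taken from a real checkpoint:

# Illustrative example values only.
state_dict = {
    'checkpoint_version': 3.0,
    'iteration': 1000,
    'tokens': 123456789,              # args.consumed_train_tokens
    'checkpoint_info': {
        'padded_vocab_size': 50432,   # args.padded_vocab_size (padded for TP divisibility)
        'original_vocab_size': 50257, # tokenizer.vocab_size (GPT-2 BPE, for instance)
    },
}
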
68 changes: 66 additions & 2 deletions megatron/training.py
@@ -367,6 +367,32 @@ def get_learning_rate_scheduler(optimizer):
return lr_scheduler


def sync_hp_to_lp(optimizer):

optimizer.update_lp_params()

# for n,p in model.named_parameters():
# print(n)

# if p._hp_mapping is not None:
# #print(f'rank {rank} fixing hp for input_layernorm')
# #p._hp_mapping.update_hp()

# hp = p._hp_mapping.hp_fragment



# torch.distributed.all_reduce(hp, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())

# # 3. optim states
# for key in ['exp_avg', 'exp_avg_sq']:
# optim_state_fragment = p._hp_mapping.get_optim_state_fragment(key)
# #print(f'rank {rank} before reduce optim state fragment {key} = {optim_state_fragment}')
# torch.distributed.all_reduce(optim_state_fragment, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
# #print(f'rank {rank} after reduce optim state fragment {key} = {optim_state_fragment}')



def setup_model_and_optimizer(model_provider_func):
"""Setup model and optimizer."""
args = get_args()
@@ -386,12 +412,21 @@ def setup_model_and_optimizer(model_provider_func):

if args.deepspeed:
print_rank_0("DeepSpeed is enabled.")
pp = mpu.get_pipeline_model_parallel_world_size()
#pp = mpu.get_pipeline_model_parallel_world_size()

import json
import io
with io.open(args.deepspeed_config, "r", encoding="utf-8") as f:
config = json.load(f)
if args.universal_checkpoint:
config["checkpoint"] = {"load_universal": True}

model, optimizer, _, lr_scheduler = deepspeed.initialize(
model=model[0],
optimizer=optimizer,
lr_scheduler=lr_scheduler,
config=config,
args=args,
lr_scheduler=lr_scheduler
)

assert model.fp16_enabled() == args.fp16, "megatron fp16 config does not match deepspeed"
@@ -416,8 +451,37 @@ def setup_model_and_optimizer(model_provider_func):
torch.distributed.barrier()
timers('load-checkpoint').stop()
timers.log(['load-checkpoint'])


# hp -> lp
if args.deepspeed and args.universal_checkpoint:
sync_hp_to_lp(optimizer)


else:
args.iteration = 0

from .utils import dump_weights
dump_weights(f'{args.universal_checkpoint=}', args.iteration, model, optimizer)

# tp_rank = mpu.get_tensor_model_parallel_rank()
# pp_rank = mpu.get_pipeline_model_parallel_rank()
# dp_rank = mpu.get_data_parallel_rank()
# for n,p in model[0].named_parameters():
# if 'word_embeddings.weight' not in n:
# continue
# if tp_rank == 0 and pp_rank == 0:
# print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}")
# if p._hp_mapping is not None:
# hp = p._hp_mapping.hp_fragment
# print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}')

# if tp_rank == 0 and pp_rank == mpu.get_pipeline_model_parallel_world_size() - 1:
# print(f"{tp_rank=}{pp_rank=}{dp_rank=} bf16 {n=} {p[:10]=}")
# if p._hp_mapping is not None:
# hp = p._hp_mapping.hp_fragment
# print(f'{tp_rank=}{pp_rank=}{dp_rank=} fp32 {n=} {hp[:10]=}')


# We only support local DDP with multiple micro-batches.
if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
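The training.py changes amount to two pieces: when --universal-checkpoint is set, the DeepSpeed JSON config is loaded and patched with a checkpoint/load_universal entry and passed to deepspeed.initialize() as a dict, and after load_checkpoint() the new sync_hp_to_lp(optimizer) call uses optimizer.update_lp_params() to copy the restored fp32 master ("hp") parameters back into the bf16 ("lp") model weights. A condensed, standalone sketch of the config patch; 'ds_config.json' is a hypothetical path standing in for args.deepspeed_config:

import io
import json

def load_ds_config(path, universal_checkpoint):
    # Read the DeepSpeed config and, for a universal checkpoint, ask the
    # engine to reload from the reshaped universal layout.
    with io.open(path, "r", encoding="utf-8") as f:
        config = json.load(f)
    if universal_checkpoint:
        config["checkpoint"] = {"load_universal": True}
    return config

config = load_ds_config("ds_config.json", universal_checkpoint=True)
# The dict is then handed to deepspeed.initialize(..., config=config, args=args);
# once load_checkpoint() has restored the fp32 optimizer state,
# sync_hp_to_lp(optimizer), i.e. optimizer.update_lp_params(), refreshes the
# bf16 parameters from it.
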
76 changes: 76 additions & 0 deletions megatron/utils.py
@@ -461,3 +461,79 @@ def found_kill_switch():
return True
else:
return False

def get_fingerprint_header():
return f"{'min':^13} {'max':^13} {'mean':^13} {'l2 norm':^12} metadata"

def get_fingerprint(p):
return f"{p.min():13.6e} {p.max():13.6e} {p.mean():13.6e} {p.norm():12.6e}"


def dump_weights(preamble, iteration, model, optimizer, tensor=None):
tp_rank = mpu.get_tensor_model_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
dp_rank = mpu.get_data_parallel_rank()
dp_size = mpu.get_data_parallel_world_size()
fn = f"debug-bf16-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"

# only care for first and last pp stages and dp0 tp0
#if not (mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()):
# return

#if not (tp_rank == 0 and dp_rank == 0):
# return

if tensor is not None:
orig_tensor = tensor
if hasattr(tensor, "_hp_param"):
numel = tensor._hp_param.numel() # // dp_size
tensor = tensor.flatten().narrow(0, 0, numel)

#print(fn)
with open(fn, "w") as fh:
fh.write(f"{get_fingerprint_header()}\n")

if tensor is not None:
fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n")
else:
for n, p in model[0].named_parameters():
fh.write(f"{get_fingerprint(p)} {n} {p.shape}\n")


return


# until we figure out how to dump the actual fp32 values don't do this
fn = f"debug-fp32-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-{preamble}.txt"
with open(fn, "w") as fh:
fh.write(f"{get_fingerprint_header()}\n")
if tensor is not None:
tensor = orig_tensor
if hasattr(tensor, "_hp_param"):
fh.write(f"{get_fingerprint(tensor._hp_param)} tensor {tensor._hp_param.shape}\n")
#fh.write(f"{get_fingerprint(tensor._hp_grad)} tensor grad\n")
else:
fh.write(f"{get_fingerprint(tensor)} tensor {tensor.shape}\n")
#fh.write(f"{get_fingerprint(tensor.grad)} tensor grad\n")

else:
if hasattr(model[0].module.tied_modules, "embed"):
p = model[0].module.tied_modules.embed.word_embeddings.weight._hp_param
fh.write(f"{get_fingerprint(p)} module.tied_modules.embed.word_embeddings.weight._hp_param {p.shape}\n")

# for i, param_group in enumerate(optimizer.param_groups):
# fh.write(f"{get_fingerprint(optimizer.fp32_groups_flat_partition[i])} group={i}\n")
#fh.write(f"{i}={optimizer.fp32_groups_flat_partition[i]}\n")
# if mpu.is_pipeline_first_stage():
# x = optimizer.fp32_groups_flat_partition[0]
# fh.write(f"fp32={x[:402432]}\n")
# if mpu.is_pipeline_last_stage()):
# x = optimizer.fp32_groups_flat_partition[1]
# fh.write(f"fp32={x[-402432:]}\n")

# import os
# import socket
# hostname = socket.gethostname()
# pid = os.getpid()
# global_rank = torch.distributed.get_rank()
#fn = f"debug-{iteration}-pp{pp_rank}-tp{tp_rank}-dp{dp_rank}-global{global_rank}-{preamble}-{pid}.txt"
65 changes: 40 additions & 25 deletions run_bf16.sh
@@ -12,48 +12,58 @@ DATETIME=`date +'date_%y-%m-%d_time_%H-%M-%S'`
#DATASET_3="<PATH TO THE THIRD DATASET>"
#DATASET="0.2 ${DATASET_1} 0.3 ${DATASET_2} 0.5 ${DATASET_3}"

BASE_DATA_PATH=/data/Megatron-LM/data
#BASE_DATA_PATH=tests/data/gpt2
#DATASET=${BASE_DATA_PATH}/meg-gpt2-openwebtext_text_document
#VOCAB_PATH=${BASE_DATA_PATH}/gpt2-tiny-vocab.json
#MERGE_PATH=${BASE_DATA_PATH}/gpt2-tiny-merges.txt

BASE_DATA_PATH=/vc_data/Megatron-LM/data
DATASET=${BASE_DATA_PATH}/indexed_datasets/megatron
VOCAB_PATH=${BASE_DATA_PATH}/gpt2-vocab.json
MERGE_PATH=${BASE_DATA_PATH}/gpt2-merges.txt


script_path=$(realpath $0)
script_dir=$(dirname $script_path)
#CONFIG_JSON="$script_dir/ds_config.json"
CONFIG_JSON="/tmp/ds_config.json"
CONFIG_JSON="$script_dir/ds_config.json"
#CONFIG_JSON="/tmp/ds_config.json"

USE_DEEPSPEED=1
ZERO_STAGE=0


# Debug
#TP=4
#PP=4
#LAYERS=8
#HIDDEN=512
#SEQ=1024
#GLOBAL_BATCH=128
#WORKER_STR="-i worker-0"


TP=1
PP=1
DP=2
# Debug
DEBUG_MODE=0
if [[ $DEBUG_MODE == 1 ]]; then
LAYERS=4
HIDDEN=512
SEQ=512
EXIT_INTERVAL=3
else
HIDDEN=1024
LAYERS=24
SEQ=1024
EXIT_INTERVAL=10
fi

TP=2
PP=2
DP=4
WORLD_SIZE=$((TP*PP*DP))
HIDDEN=1024
LAYERS=24
SEQ=1024
GLOBAL_BATCH=1
WORKER_STR=""
GLOBAL_BATCH=4

MICRO_BATCH=1
TRAIN_ITERS=100000
CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}
LOAD_CHECKPOINT_PATH=checkpoints/gpt2/tp${TP}_pp${PP}_dp${DP}

LR=6.0e-4
MIN_LR=6.0e-5
DTYPE="bf16"
EXP_DIR=${HOME}/experiments/results/bf16
LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_fix3"
EXP_DIR=${HOME}/experiments/results/ckpt_reshape
LOG_DIR="${EXP_DIR}/tensorboard/tp${TP}_pp${PP}_dp${DP}_hd${HIDDEN}_nl${LAYERS}_gbsz${GLOBAL_BATCH}_mbsz${MICRO_BATCH}_z${ZERO_STAGE}_LR_${LR}_${MIN_LR}_${DTYPE}_cont"
mkdir -p $LOG_DIR

while [[ $# -gt 0 ]]
@@ -89,7 +99,7 @@ options=" \
--max-position-embeddings $SEQ \
--micro-batch-size $MICRO_BATCH \
--global-batch-size $GLOBAL_BATCH \
--train-iters 1000 \
--train-iters $TRAIN_ITERS \
--lr $LR \
--min-lr $MIN_LR \
--lr-decay-style cosine \
@@ -99,7 +109,7 @@ options=" \
--data-path ${DATASET} \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--save-interval 10000 \
--save-interval 1000 \
--split 98,2,0 \
--clip-grad 1.0 \
--weight-decay 0.1 \
@@ -108,7 +118,12 @@ options=" \
--init-method-std 0.006 \
--${DTYPE} \
--checkpoint-activations \
--exit-interval 10000 \
--exit-interval ${EXIT_INTERVAL} \
--save ${CHECKPOINT_PATH} \
--load ${LOAD_CHECKPOINT_PATH} \
--position-embedding-type alibi \
--override-lr-scheduler \
--embed-layernorm \
--tensorboard-dir $LOG_DIR
"

@@ -151,7 +166,7 @@ cat <<EOT > $CONFIG_JSON
}
EOT

WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
#WORKER_STR="--num_nodes 1 --num_gpus $WORLD_SIZE"
#WORKER_STR="-i worker-0:0,1,2,3"
#run_cmd="deepspeed -i worker-0:0,1,2,3 ${DIR}/pretrain_gpt.py $@ ${options}"
#run_cmd="deepspeed -i worker-0 ${DIR}/pretrain_gpt.py $@ ${options}"
(Diffs for the remaining 10 changed files are not shown.)

