Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Compute model param count once #204

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open

Conversation

jaketae
Copy link
Member

@jaketae jaketae commented Nov 24, 2021

This PR fixes #203 by using the args global variable holder to save and access model parameter counts during gigaflops counting. This is sensible given that the number of model parameters stays constant throughout the training iteration, rendering it unnecessary to call the model param count function at every iteration.

Additionally fixes: #123

@jaketae jaketae marked this pull request as ready for review November 24, 2021 06:55
@jaketae jaketae requested a review from stas00 November 24, 2021 06:58
@jaketae jaketae mentioned this pull request Nov 24, 2021
megatron/training.py Outdated Show resolved Hide resolved
megatron/training.py Outdated Show resolved Hide resolved
@stas00
Copy link
Member

stas00 commented Nov 24, 2021

@TevenLeScao , while we are at it why do we print:

  1. estimated model parameters:
  2. estimated model parameters without embeddings:

for the whole model?

What's the practical point of the first one? I'm trying to think what can be done with this information? I surely am missing something...

Is this info of any practical use for the pipe stages? when we print the same per process?

And I think the 2nd one is misleading, as it's not w/o embeddings. It's without repeat count of tied params.

May I suggest that it says instead:

estimated unique model parameters.

Thanks.


and if agreed for the latter then we can adjust the constant name to be args.uniq_parameters_in_billions

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
megatron/training.py Outdated Show resolved Hide resolved
Co-authored-by: Jake Tae <jaesungtae@gmail.com>
@TevenLeScao
Copy link
Collaborator

Hey ! The number without embeddings is actually just that: the number of non-embedding parameters, not the number of unique parameters. This is the relevant number to estimate the loss against, which is why we're using it in the tensorboards for arch-and-scaling.

I wanted to also print exact model parameters (including embeddings and withotu double-counts) but haven't managed to get unique counting to work.

@stas00
Copy link
Member

stas00 commented Nov 25, 2021

OK, then something is very wrong in the reporting. e.g. for the 104B model it prints:

estimated model parameters without embeddings: 103.368064
estimated model parameters: 125.22432

The formula we use comes from https://github.com/bigscience-workshop/bigscience/blob/master/experiments/gpt2-utils.md#calculate-model-size

For:

NLAYERS=64
NHIDDEN=11600
SEQ_LEN=2048
VOCAB_SIZE=50257
# w/ embeddings
perl -le 'print 64*(12*11600**2+13*11600) + 50257*11600+2048*11600 + 2*11600'
103958492400

# w/o embeddings
perl -le 'print 64*(12*11600**2+13*11600) + 2*11600'
103351754400

The difference is 60M params, so embeddings params are ignored when doing math by hand and the simplified formula is used instead 12*layers*hidden**2.

So 103.368064 from estimated model parameters without embeddings: 103.368064 matches quite closely.

But there is no 125B anywhere in sight, so the estimated model parameters: 125.22432 appears to be very wrong. i.e. the last number has to be ~104B and not 125B to match the math. We have 21B extraneous params.

@stas00
Copy link
Member

stas00 commented Nov 29, 2021

@TevenLeScao, I looked deeper and we have wrong counting for PP, e.g. see:

First let's do a manual approximate math for this config:


NLAYERS=8
NHIDDEN=512
SEQ_LEN=1024
VOCAB_SIZE=50257

EMB_PARAMS=$((VOCAB_SIZE * NHIDDEN + SEQ_LEN * NHIDDEN))
BLOCKS_PARAMS=$((NLAYERS * (12 * NHIDDEN**2 + 13 * NHIDDEN)))
echo yes-emb param count $EMB_PARAMS
echo non-emb param count $BLOCKS_PARAMS

gives:

yes-emb param count 26255872
non-emb param count 25219072
sum: 51474944

So we know we are dealing with a 52M model, with about half the params taken by word embeddings.

Now run the start of the training and look at the logs (filtered out all but the relevant parts). The upcase logs are from deepspeed's pipe, the last line is ours.

# 1 gpu: tp=1 pp=1

[2021-11-28 18:33:47,870] [INFO] [engine.py:151:__init__] RANK=0 STAGE=0 LAYERS=15 [0, 15) STAGE_PARAMS=51500032 (51.500M) 
TOTAL_PARAMS=51500032 (51.500M) UNIQUE_PARAMS=51500032 (51.500M)

estimated model parameters: 0.051500032

# 2 gpus: tp=2 pp=1

[2021-11-28 18:33:00,537] [INFO] [engine.py:151:__init__] RANK=0 STAGE=0 LAYERS=15 [0, 15) STAGE_PARAMS=26057728 (26.058M) 
TOTAL_PARAMS=52115456 (52.115M) UNIQUE_PARAMS=52115456 (52.115M)
[2021-11-28 18:33:00,537] [INFO] [engine.py:151:__init__] RANK=1 STAGE=0 LAYERS=15 [0, 15) STAGE_PARAMS=26057728 (26.058M) 
TOTAL_PARAMS=52115456 (52.115M) UNIQUE_PARAMS=52115456 (52.115M)

estimated model parameters: 0.052115456


# 2 gpus: tp=1 pp=2

[2021-11-28 18:31:10,687] [INFO] [engine.py:151:__init__] RANK=0 STAGE=0 LAYERS=7 [0, 7) STAGE_PARAMS=38889472 (38.889M) 
TOTAL_PARAMS=77779968 (77.780M) UNIQUE_PARAMS=51500032 (51.500M)
[2021-11-28 18:31:10,687] [INFO] [engine.py:151:__init__] RANK=1 STAGE=1 LAYERS=8 [7, 15) STAGE_PARAMS=38890496 (38.890M) 
TOTAL_PARAMS=77779968 (77.780M) UNIQUE_PARAMS=51500032 (51.500M)

estimated model parameters: 0.077778944

So for gpu=1 and gpus=2/tp=2 all is good, but once pp>1 is involved it's wrong. It reports 0.077778944 but should be 0.051500032

Here is the code that we need to borrow to do it correctly:

https://github.com/microsoft/DeepSpeed/blob/7a132a9f4b37959f951b7c04a05207aba6054965/deepspeed/runtime/pipe/engine.py#L134-L157

@jaketae
Copy link
Member Author

jaketae commented Nov 29, 2021

@stas00 @TevenLeScao I assume the PP > 1 problem is related to the warning in get_parameters_in_billions?

warnings.warn("Parameter count with the embeddings will be inaccurate with PP > 1, as the first and last stage hold several copies of the embeddings")

And I'm also suspecting if this is the last untied knot in #40, which was followed up by #99. Just trying to put the pieces together to add clarity!

@stas00
Copy link
Member

stas00 commented Nov 29, 2021

yes, and I pointed to the code that leads to correct data!

The current 20 to 50% over-reported size leads to very different results over what is really happening.

Do you want to tackle that, @jaketae?

@jaketae
Copy link
Member Author

jaketae commented Nov 29, 2021

Yes, I'd love to take it. Will take a look at the DS reference code today. Thanks!

@stas00
Copy link
Member

stas00 commented Nov 29, 2021

And to debug while you're working at it, you may choose the same as I did here that is tweaking:

N_GPUS=2
TP_SIZE=2
PP_SIZE=1

to 3 different set ups and checking that the deepspeed (upcase) log matches ours.

@jaketae
Copy link
Member Author

jaketae commented Dec 1, 2021

@stas00 I'm trying to test the code, but haven't figured out what the best way to go about it is. Do you modify entries in run.sh and execute the script? Wondering how to best setup an environment where I can directly compare the output from DeepSpeed and Meg-DS side by side. Thank you!

@stas00
Copy link
Member

stas00 commented Dec 1, 2021

both outputs already show up in the same log file. Please see #204 (comment) - I just filtered out that information from the rest of the logged info.

You just change the setup in the script, yes.

I don't know what run.sh is, I use the following script:

CHECKPOINT_PATH=checkpoints/gpt2

VOCAB_FILE=data/gpt2-vocab.json
MERGE_FILE=data/gpt2-merges.txt
#DATA_PATH=data/meg-gpt2_text_document
DATA_PATH=data/meg-gpt2_oscar-combined_text_document
TENSORBOARD_PATH=output_dir/tensorboard

N_GPUS=2
MICRO_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=16
TP_SIZE=2
PP_SIZE=1

NLAYERS=2
NHIDDEN=64
NHEADS=2
SEQ_LEN=1024
VOCAB_SIZE=50257

SAVE_INTERVAL=50

GPT_ARGS=" \
    --num-layers $NLAYERS \
    --hidden-size $NHIDDEN \
    --num-attention-heads $NHEADS \
    --seq-length $SEQ_LEN \
    --max-position-embeddings $SEQ_LEN \
    --micro-batch-size $MICRO_BATCH_SIZE \
    --rampup-batch-size 2 2 1_000 \
    --global-batch-size $GLOBAL_BATCH_SIZE \
    --train-samples 100 \
    --optimizer adam \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --adam-eps 1e-8 \
    --lr 1e-4 \
    --lr-warmup-samples 5 \
    --clip-grad 1.0 \
    --weight-decay 1e-1 \
    --fp16 \
    --partition-activations \
    --seed 42 \
    --vocab-file $VOCAB_FILE \
    --merge-file $MERGE_FILE \
    "

OUTPUT_ARGS=" \
    --exit-interval 200 \
    --log-interval 10 \
    --save-interval $SAVE_INTERVAL \
    --eval-interval 100 \
    --eval-iters 10 \
    --checkpoint-activations \
    "

DATA_ARGS=" \
    --save $CHECKPOINT_PATH \
    --load $CHECKPOINT_PATH \
    --data-path $DATA_PATH \
    --tensorboard-dir $TENSORBOARD_PATH \
    --tensorboard-queue-size 5 \
    --log-timers-to-tensorboard \
    --log-batch-size-to-tensorboard \
    --log-validation-ppl-to-tensorboard \
    "


ZERO_STAGE=1

config_json="./ds_config.json"

# Deepspeed figures out GAS dynamically from dynamic GBS via set_train_batch_size()
cat <<EOT > $config_json
{
  "train_micro_batch_size_per_gpu": $MICRO_BATCH_SIZE,
  "train_batch_size": $GLOBAL_BATCH_SIZE,
  "gradient_clipping": 1.0,
  "zero_optimization": {
    "stage": $ZERO_STAGE
  },
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 500,
    "hysteresis": 2,
    "min_loss_scale": 1,
    "initial_scale_power": 12
  },
  "steps_per_print": 2000,
  "wall_clock_breakdown": false
}
EOT


DEEPSPEED_ARGS=" \
    --deepspeed \
    --deepspeed_config ${config_json} \
    --zero-stage ${ZERO_STAGE} \
    --deepspeed-activation-checkpointing \
    "

ALL_ARGS="$GPT_ARGS $OUTPUT_ARGS $DATA_ARGS $DEEPSPEED_ARGS"

# if you can't stand pt-1.9 launcher noise
export LOGLEVEL=WARNING

PYTHONPATH=/hf/Megatron-DeepSpeed-master

MASTER_ADDR=localhost
MASTER_PORT=6777

export LAUNCHER="deepspeed --num_gpus $N_GPUS --master_port $MASTER_PORT"

# export LAUNCHER="python -u -m torch.distributed.run \
#     --nproc_per_node $N_GPUS \
#     --rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
#     --rdzv_backend c10d \
#     --max_restarts 0 \
#     "

# export LAUNCHER="python -u -m torch.distributed.launch \
#     --nproc_per_node $N_GPUS \
#     --master_addr $MASTER_ADDR \
#     --master_port $MASTER_PORT \
#     "


export CMD=" \
    env PYTHONPATH=$PYTHONPATH USE_TF=0 \
    $LAUNCHER pretrain_gpt.py \
    --tensor-model-parallel-size $TP_SIZE \
    --pipeline-model-parallel-size $PP_SIZE \
    --distributed-backend nccl \
    $ALL_ARGS \
    "

echo $CMD

#rm -rf $CHECKPOINT_PATH
$CMD

If you want to use it, you will need to adjusts paths as they are hardcoded to my setup.

But the main point I'm trying to convey is that I just tweaked this section of the above script 3 times as described in my comment I linked to above and re-run the script.

N_GPUS=2
MICRO_BATCH_SIZE=1
GLOBAL_BATCH_SIZE=16
TP_SIZE=2
PP_SIZE=1

adammoody pushed a commit to adammoody/Megatron-DeepSpeed that referenced this pull request Dec 18, 2023
Co-authored-by: Alexander Jipa <azzhipa@amazon.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Avoid re-computing model parameter count every iteration Need model size dumped at init
3 participants