Curriculum learning support #132

Merged: 20 commits, Oct 10, 2021. Changes from 13 commits.
26 changes: 26 additions & 0 deletions examples/curriculum_learning/README.md
@@ -0,0 +1,26 @@
This is a short tutorial on how to use and tune the curriculum learning (CL) integration. Currently it is only integrated for GPT pre-training. For technical details, please refer to our [paper](https://arxiv.org/abs/2108.06084).

# Disable batch size warmup (--rampup-batch-size)
In section 5.4 of our [paper](https://arxiv.org/abs/2108.06084) we demonstrate that (seqlen-based) curriculum learning provides much better training stability than the batch size warmup technique, so when using CL you need to remove the `--rampup-batch-size` config from your training script. Using both CL and batch size warmup is not recommended, because both of them reduce the number of tokens in a batch (see the sketch below). Another related change you may want to make is increasing your micro batch size, since without batch size warmup your batch size is now fixed from the start.
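To make the overlap concrete, here is a back-of-the-envelope comparison (illustrative numbers only, not from the paper):

```python
# Both techniques shrink the number of tokens per batch early in training,
# just along different axes. Illustrative values, assuming a 512-sample
# global batch and a full seqlen of 1024:
full_batch, full_seqlen = 512, 1024

rampup_tokens = 32 * full_seqlen   # batch size warmup: 32 samples, full seqlen
cl_tokens = full_batch * 64        # CL: full batch, seqlen warmed up from 64

print(rampup_tokens, cl_tokens)    # 32768 32768 -- the same token reduction
```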

# Token-based training termination

Because CL changes the length of each sequence/sample during training, it is very hard (if not impossible) to use the number of steps/samples to terminate training exactly at the desired number of tokens. Thus we add a `--train-tokens` config for accurate token-based termination. We recommend increasing your original `--train-samples` or `--train-iters` to a large enough number (e.g., 2x what you used for the baseline), and setting `--train-tokens` to the exact desired number of training tokens (e.g., 300B for GPT-3-like training); a sketch of the resulting stopping logic follows.
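As a rough sketch of token-based termination (`should_terminate` is a hypothetical helper; the actual check lives in Megatron's training loop):

```python
def should_terminate(consumed_train_tokens, train_tokens,
                     consumed_train_samples, train_samples):
    # --train-tokens is the accurate stopping criterion; the inflated
    # --train-samples (e.g., 2x the baseline) only acts as a safety net.
    if train_tokens is not None and consumed_train_tokens >= train_tokens:
        return True
    return consumed_train_samples >= train_samples
```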

# Token-based LR decay

Again, because CL changes the number of tokens per batch, in Appendix A.2 of our [paper](https://arxiv.org/abs/2108.06084) we show that it is also necessary to make the LR decay token-based (to avoid decaying the LR too fast). Thus we add `--lr-decay-tokens`, the number of tokens over which to decay the LR. If you were previously using `--lr-decay-samples`, you can compute `--lr-decay-tokens` simply by multiplying the former by the full seqlen (e.g., 2K for GPT-3; see the worked example below). Then replace `--lr-decay-samples` with `--lr-decay-tokens` in your script. We keep LR warmup step/sample-based, because making it token-based under CL would slow down the LR warmup, which is both unnecessary and harmful.
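For example, converting the sample-based decay schedule of the example script below to tokens:

```python
lr_decay_samples = 126_953_125   # previous --lr-decay-samples (from pretrain_gpt_cl.sh)
full_seqlen = 1024               # this example's --seq-length (2048 for GPT-3)
lr_decay_tokens = lr_decay_samples * full_seqlen
print(lr_decay_tokens)           # 130_000_000_000 -> pass as --lr-decay-tokens
```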

# Token-based tensorboard

Because of the above changes, we also add token-based TensorBoard scalars, as well as scalars that plot the seqlen at each step.

# Curriculum learning hyperparameters tuning strategy

The curriculum learning hyperparameters are all located in the DeepSpeed config JSON file (see the example `ds_config_cl.json` in this dir). There are three config entries that you need to change, two of which require some tuning. Appendix A.1 of our [paper](https://arxiv.org/abs/2108.06084) has a more detailed description of the tuning strategy.

First, `max_difficulty` should be set to the full seqlen (i.e., your `--seq-length`). No tuning is needed for this.

Second, `min_difficulty` is the seqlen used at the beginning of CL. In general, a smaller `min_difficulty` provides better stability/convergence-speed benefits. However, we observe that for larger models or different training data, starting from a very small seqlen can lead to severe validation PPL fluctuation (or even divergence) at the very beginning. We recommend starting with a `min_difficulty` of 64 and increasing it if you observe problems early in training. Note that to enable Tensor Core acceleration you should always use a multiple of 8.

Third, `total_curriculum_step` is the total number of steps over which CL increases the seqlen. In general, a larger `total_curriculum_step` provides better stability/convergence-speed benefits. However, we observe that too large a `total_curriculum_step` can lead to overfitting and severe validation PPL fluctuation (or even divergence) during the first few multiples of the LR warmup steps. Our paper describes a detailed tuning strategy based on binary search. However, if you want to reduce the tuning effort, we recommend directly setting `total_curriculum_step` to half of the baseline's total number of steps (see the sketch below). This may not give the highest convergence-speed benefit, but it should provide enough training-stability gains.
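For instance, the no-tuning default from above (the baseline step count here is hypothetical; the resulting 60,000 matches `ds_config_cl.json`):

```python
baseline_total_steps = 120_000                     # hypothetical baseline step count
total_curriculum_step = baseline_total_steps // 2  # 60_000, as in ds_config_cl.json
```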
37 changes: 37 additions & 0 deletions examples/curriculum_learning/ds_config_cl.json
@@ -0,0 +1,37 @@
{
  "train_batch_size": 512,
  "gradient_accumulation_steps": 1,
  "steps_per_print": 1,
  "zero_optimization": {
    "stage": 0
  },
  "optimizer": {
    "type": "Adam",
    "params": {
      "lr": 0.00015,
      "max_grad_norm": 1.0,
      "betas": [0.9, 0.95]
    }
  },
  "gradient_clipping": 1.0,
  "fp16": {
    "enabled": true,
    "loss_scale": 0,
    "loss_scale_window": 1000,
    "hysteresis": 2,
    "min_loss_scale": 1
  },
  "wall_clock_breakdown": false,
  "zero_allow_untested_optimizer": false,
  "curriculum_learning": {
    "enabled": true,
    "curriculum_type": "seqlen",
    "min_difficulty": 8,
    "max_difficulty": 1024,
    "schedule_type": "fixed_linear",
    "schedule_config": {
      "total_curriculum_step": 60000,
      "difficulty_step": 8
    }
  }
}
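To sanity-check these values, here is a sketch of how a `fixed_linear` schedule maps steps to seqlen (the exact rounding behavior is an assumption; see DeepSpeed's curriculum scheduler for the authoritative implementation):

```python
def seqlen_at_step(step, total_step=60000, min_d=8, max_d=1024, d_step=8):
    # Linear interpolation from min_difficulty to max_difficulty,
    # kept a multiple of difficulty_step (8, for Tensor Core alignment).
    if step >= total_step:
        return max_d
    d = min_d + (max_d - min_d) * step / total_step
    d -= d % d_step
    return int(max(d, min_d))

print(seqlen_at_step(0), seqlen_at_step(30000), seqlen_at_step(60000))
# 8 512 1024
```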
103 changes: 103 additions & 0 deletions examples/curriculum_learning/pretrain_gpt_cl.sh
@@ -0,0 +1,103 @@
#!/bin/bash

# This is a dummy training script that shows how to use curriculum
# learning; some parameters are not suitable for actual GPT pre-training.

TARGET_GLOBAL_BATCH_SIZE=512
TRAIN_SAMPLES=146_484_375
LR=1.0e-4
MIN_LR=1.0e-5
LR_DECAY_SAMPLES=126953125 # plain digits: this value is used in bash arithmetic below, which rejects underscores
LR_WARMUP_SAMPLES=183_105
SEQLEN=1024

############################################################
# New configs for curriculum learning, see README.md
TRAIN_TOKENS=10_000_000_000
LR_DECAY_TOKENS=$(($LR_DECAY_SAMPLES*$SEQLEN))
############################################################

LOG_INTERVAL=100
EVAL_ITERS=10
EVAL_INTERVAL=100
SAVE_INTERVAL=1000

VOCAB_PATH=/data/Megatron-LM/data/gpt2-vocab.json
MERGE_PATH=/data/Megatron-LM/data/gpt2-merges.txt
DATA_PATH=/data/Megatron-LM/data/indexed_datasets/megatron

MICRO_BATCH_SIZE=1
MP_SIZE=1
PP_SIZE=1

NUM_GPUS=128
echo ${NUM_GPUS}
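# Derive the data-parallel size and gradient accumulation steps so that
# MICRO_BATCH_SIZE * DP_SIZE * GRAD_ACC_STEPS == TARGET_GLOBAL_BATCH_SIZE.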
if [[ $PP_SIZE -gt 0 ]]; then
    DP_SIZE=$(( ${NUM_GPUS} / (${PP_SIZE} * ${MP_SIZE}) ))
else
    DP_SIZE=$(( ${NUM_GPUS} / ${MP_SIZE} ))
fi
GRAD_ACC_STEPS=$(( ${TARGET_GLOBAL_BATCH_SIZE} / (${MICRO_BATCH_SIZE} * ${DP_SIZE}) ))

NAME="gpt-117M-pp${PP_SIZE}-mp${MP_SIZE}-bsz${TARGET_GLOBAL_BATCH_SIZE}-mbsz${MICRO_BATCH_SIZE}-cl"
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
host="${HOSTNAME}"
TENSORBOARD_DIR="tensorboard/${NAME}_${host}_${current_time}"
mkdir -p ${TENSORBOARD_DIR}
CHECKPOINT_PATH="checkpoints/${NAME}"

megatron_options=" \
--data-path ${DATA_PATH} \
--vocab-file ${VOCAB_PATH} \
--merge-file ${MERGE_PATH} \
--data-impl mmap \
--override-lr-scheduler \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--tensor-model-parallel-size ${MP_SIZE} \
--init-method-std 0.014 \
--lr-decay-tokens ${LR_DECAY_TOKENS} \
--lr-warmup-samples ${LR_WARMUP_SAMPLES} \
--micro-batch-size ${MICRO_BATCH_SIZE} \
--global-batch-size ${TARGET_GLOBAL_BATCH_SIZE} \
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 16 \
--seq-length ${SEQLEN} \
--max-position-embeddings ${SEQLEN} \
--train-samples ${TRAIN_SAMPLES} \
--train-tokens ${TRAIN_TOKENS} \
--lr ${LR} \
--min-lr ${MIN_LR} \
--lr-decay-style cosine \
--split 98,2,0 \
--log-interval ${LOG_INTERVAL} \
--eval-interval ${EVAL_INTERVAL} \
--eval-iters ${EVAL_ITERS} \
--save-interval ${SAVE_INTERVAL} \
--weight-decay 0.1 \
--clip-grad 1.0 \
--hysteresis 2 \
--num-workers 0 \
--checkpoint-activations \
--fp16 \
--load ${CHECKPOINT_PATH} \
--save ${CHECKPOINT_PATH} \
--tensorboard-queue-size 1 \
--log-timers-to-tensorboard \
--log-batch-size-to-tensorboard \
--log-validation-ppl-to-tensorboard \
--tensorboard-dir ${TENSORBOARD_DIR}"

config_json="ds_config_cl.json"

deepspeed_options=" \
--deepspeed \
--deepspeed_config ${config_json} \
--pipeline-model-parallel-size ${PP_SIZE} \
--partition-activations"

run_cmd="deepspeed ../../pretrain_gpt.py ${megatron_options} ${deepspeed_options} &>> ${NAME}.log"
echo ${run_cmd}
eval ${run_cmd}
set +x
7 changes: 7 additions & 0 deletions megatron/arguments.py
@@ -170,6 +170,7 @@ def parse_args(extra_args_provider=None, defaults={},
    # Consumed tokens.
    args.consumed_train_samples = 0
    args.consumed_valid_samples = 0
    args.consumed_train_tokens = 0
    args.gigaflos_no_embeds = 0

    # Iteration-based training.
@@ -428,6 +429,9 @@ def _add_training_args(parser):
                       help='Total number of samples to train over all '
                       'training runs. Note that either train-iters or '
                       'train-samples should be provided.')
    group.add_argument('--train-tokens', type=int, default=None,
                       help='Total number of tokens to train over all '
                       'training runs.')
    group.add_argument('--log-interval', type=int, default=100,
                       help='Report loss and timing interval.')
    group.add_argument('--exit-interval', type=int, default=None,
@@ -495,6 +499,9 @@ def _add_learning_rate_args(parser):
    group.add_argument('--lr-decay-samples', type=int, default=None,
                       help='number of samples to decay learning rate over,'
                       ' If None defaults to `--train-samples`')
    group.add_argument('--lr-decay-tokens', type=int, default=None,
                       help='number of tokens to decay learning rate over,'
                       ' If not None will override iter/sample-based decay')
    group.add_argument('--lr-warmup-fraction', type=float, default=None,
                       help='fraction of lr-warmup-(iters/samples) to use '
                       'for warmup (as a float)')
2 changes: 2 additions & 0 deletions megatron/checkpointing.py
@@ -127,6 +127,7 @@ def save_checkpoint(iteration, model, optimizer, lr_scheduler):
    state_dict['args'] = args
    state_dict['checkpoint_version'] = 3.0
    state_dict['iteration'] = iteration
    state_dict['tokens'] = args.consumed_train_tokens

    # DeepSpeed saves the model/optimizer/scheduler
    if not args.deepspeed:
@@ -339,6 +340,7 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
    else:
        try:
            iteration = state_dict['iteration']
            args.consumed_train_tokens = state_dict['tokens']
        except KeyError:
            try:  # Backward compatible with older checkpoints
                iteration = state_dict['total_iters']
46 changes: 34 additions & 12 deletions megatron/learning_rates.py
@@ -17,7 +17,7 @@

import math

from megatron import print_rank_0, get_args

class AnnealingLR(object):
    """Anneals the learning rate."""
@@ -26,7 +26,7 @@ def __init__(self, optimizer, max_lr, min_lr,
                 warmup_steps, decay_steps, decay_style,
                 use_checkpoint_lr_scheduler=True,
                 override_lr_scheduler=False):
        args = get_args()
        # Class values.
        self.optimizer = optimizer
@@ -41,6 +41,10 @@
        assert self.decay_steps > 0
        assert self.warmup_steps < self.decay_steps

        self.decay_tokens = args.lr_decay_tokens
        self.num_tokens = 0
        self.warmup_tokens = 0

        self.decay_style = decay_style

        self.override_lr_scheduler = override_lr_scheduler
@@ -61,21 +65,31 @@ def get_lr(self):

        # Use linear warmup for the initial part.
        if self.warmup_steps > 0 and self.num_steps <= self.warmup_steps:
            if self.num_steps == self.warmup_steps and \
                self.decay_tokens is not None:
                self.warmup_tokens = self.num_tokens
            return self.max_lr * float(self.num_steps) / \
                float(self.warmup_steps)

        # If the learning rate is constant, just return the initial value.
        if self.decay_style == 'constant':
            return self.max_lr

        if self.decay_tokens is None:
            # For any steps larger than `self.decay_steps`, use `self.min_lr`.
            if self.num_steps > self.decay_steps:
                return self.min_lr

            # If we are done with the warmup period, use the decay style.
            num_steps_ = self.num_steps - self.warmup_steps
            decay_steps_ = self.decay_steps - self.warmup_steps
            decay_ratio = float(num_steps_) / float(decay_steps_)
        else:
            if self.num_tokens > self.decay_tokens:
                return self.min_lr
            num_tokens_ = self.num_tokens - self.warmup_tokens
            decay_tokens_ = self.decay_tokens - self.warmup_tokens
            decay_ratio = float(num_tokens_) / float(decay_tokens_)
        assert decay_ratio >= 0.0
        assert decay_ratio <= 1.0
        delta_lr = self.max_lr - self.min_lr
@@ -91,8 +105,12 @@ def get_lr(self):
        return self.min_lr + coeff * delta_lr


    def step(self, increment, token_num=None):
        """Set lr for all parameters groups."""
        if token_num is None:
            args = get_args()
            token_num = args.consumed_train_tokens
        self.num_tokens = token_num
        self.num_steps += increment
        new_lr = self.get_lr()
        for group in self.optimizer.param_groups:
@@ -104,6 +122,8 @@ def state_dict(self):
            'max_lr': self.max_lr,
            'warmup_steps': self.warmup_steps,
            'num_steps': self.num_steps,
            'warmup_tokens': self.warmup_tokens,
            'num_tokens': self.num_tokens,
            'decay_style': self.decay_style,
            'decay_steps': self.decay_steps,
            'min_lr': self.min_lr
@@ -161,4 +181,6 @@ def load_state_dict(self, sd):
            num_steps = sd['num_iters']
        else:
            num_steps = sd['num_steps']
        self.warmup_tokens = sd['warmup_tokens']
        self.num_tokens = sd['num_tokens']
        self.step(num_steps, self.num_tokens)
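To see what the token-based branch of `get_lr` computes, here is a standalone sketch of cosine decay measured in tokens (illustrative values):

```python
import math

def token_based_cosine_lr(num_tokens, warmup_tokens, decay_tokens,
                          max_lr=1.0e-4, min_lr=1.0e-5):
    # Mirrors the else-branch above: the decay ratio is measured in
    # consumed tokens rather than steps.
    ratio = (num_tokens - warmup_tokens) / (decay_tokens - warmup_tokens)
    ratio = min(max(ratio, 0.0), 1.0)
    coeff = 0.5 * (math.cos(math.pi * ratio) + 1.0)
    return min_lr + coeff * (max_lr - min_lr)

# Roughly halfway through a 260B-token budget the LR sits near the midpoint:
print(token_based_cosine_lr(130e9, 0.2e9, 260e9))  # ~5.5e-05
```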
16 changes: 14 additions & 2 deletions megatron/model/gpt_model.py
@@ -102,8 +102,20 @@ def set_input_tensor(self, input_tensor):

    def forward(self, input_ids, position_ids, attention_mask, labels=None,
                tokentype_ids=None, layer_past=None, get_key_value=False,
                forward_method_parallel_output=None, curriculum_seqlen=None):
        if curriculum_seqlen is not None:
            args = get_args()
            args.curriculum_seqlen = curriculum_seqlen
            if curriculum_seqlen < input_ids.size()[1]:
                # seqlen-based curriculum learning
                # input_ids, position_ids, labels have size [batch size, seqlen]
                input_ids = input_ids[:, :curriculum_seqlen].contiguous()
                position_ids = position_ids[:, :curriculum_seqlen].contiguous()
                labels = labels[:, :curriculum_seqlen].contiguous()

                # attention_mask has size [1, 1, seqlen, seqlen]
                attention_mask = attention_mask[:, :, :curriculum_seqlen, :curriculum_seqlen].contiguous()

        lm_output = self.language_model(
            input_ids,
            position_ids,
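For reference, the truncation in `forward` only slices the trailing (sequence) dimensions; a minimal shape check with hypothetical sizes:

```python
import torch

batch, full_seqlen, cur = 4, 1024, 64
input_ids = torch.randint(0, 50257, (batch, full_seqlen))
attention_mask = torch.ones(1, 1, full_seqlen, full_seqlen, dtype=torch.bool)

input_ids = input_ids[:, :cur].contiguous()                     # [4, 64]
attention_mask = attention_mask[:, :, :cur, :cur].contiguous()  # [1, 1, 64, 64]
print(input_ids.shape, attention_mask.shape)
```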