Floating-point ops counting and reloading (bigscience-workshop#40)
* initial flo count/logging setup (need to fix model parameter count)

* 1B3 parameter setup + flos counting

* 1B3 parameter setup

* synced with latest 13B script

* pipe transformer docstring

* improve DS integration evaluation + logging

* use pp engine even for pp=1 (bigscience-workshop#6)

* removed slurm_examples

* flos re-loading

* Update megatron/training.py

Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>

* Update megatron/data/gpt_dataset.py

Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>

* Update megatron/utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update megatron/utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* formatting fix, reserving bug for somewhere else, adding flo-logging to TB groups

* indentation bug

* fixing possible double counts

* tweaks

* warning for double counts

Co-authored-by: Shaden Smith <shaden.smith@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: TevenLeScao <uhk85as@jean-zay1.idris.fr>
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
6 people authored and Ofir Press committed Sep 23, 2021
1 parent 3a48503 commit 3b2207f
Showing 5 changed files with 49 additions and 10 deletions.
megatron/arguments.py: 1 addition & 0 deletions
@@ -170,6 +170,7 @@ def parse_args(extra_args_provider=None, defaults={},
     # Consumed tokens.
     args.consumed_train_samples = 0
     args.consumed_valid_samples = 0
+    args.gigaflos_no_embeds = 0
 
     # Iteration-based training.
     if args.train_iters:

megatron/checkpointing.py: 2 additions & 0 deletions
@@ -359,6 +359,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             update_num_microbatches(consumed_samples=args.consumed_train_samples)
             args.consumed_valid_samples = getattr(checkpoint_args,
                                                   'consumed_valid_samples', 0)
+            args.gigaflos_no_embeds = getattr(checkpoint_args,
+                                              'gigaflos_no_embeds', 0)
         else:
             print_rank_0('could not find arguments in the checkpoint ...')

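The getattr default of 0 is what keeps older checkpoints loadable: files saved before this counter existed simply restart the FLOs count at zero instead of raising AttributeError. A minimal sketch of the pattern (the SimpleNamespace stand-in for the deserialized checkpoint args is hypothetical, not from the source):

    from types import SimpleNamespace

    # A checkpoint written before the counter was introduced has no such field.
    old_checkpoint_args = SimpleNamespace(consumed_train_samples=1000,
                                          consumed_valid_samples=100)

    # Reloading tolerates the missing attribute by falling back to 0.
    gigaflos_no_embeds = getattr(old_checkpoint_args, 'gigaflos_no_embeds', 0)
    assert gigaflos_no_embeds == 0
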
megatron/data/gpt_dataset.py: 3 additions & 3 deletions
@@ -237,10 +237,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
             last_epoch_num_samples = num_samples - \
                                      num_samples_from_epochs_minus_one
             assert last_epoch_num_samples >= 0, \
-                'last epoch number of samples should be non-negative.'
+                f'last epoch number of samples {last_epoch_num_samples} should be non-negative.'
             num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
-            assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
-                'last epoch number of samples exceeded max value.'
+            assert last_epoch_num_samples <= num_samples_per_epoch, \
+                f'last epoch number of samples {last_epoch_num_samples} exceeded max value {num_samples_per_epoch}.'
             # If we have less than 80% of the samples for the last epoch,
             # separate out the epoch and treat it differently.
             # Note: the 80% number is just based on common sense and can
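
The rewritten bound is the same integer check as before (for integers, x < n + 1 is exactly x <= n); the substantive change is that both messages now interpolate the offending values. A tiny sketch with made-up numbers (the - 1 reflects that a sample consumes seq_length tokens plus one extra token for the shifted labels):

    tokens_per_epoch = 2049   # hypothetical
    seq_length = 1024         # hypothetical
    num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length   # == 2
    for last_epoch_num_samples in range(4):
        assert (last_epoch_num_samples < num_samples_per_epoch + 1) == \
               (last_epoch_num_samples <= num_samples_per_epoch)
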
megatron/training.py: 16 additions & 3 deletions
@@ -43,7 +43,7 @@
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import check_adlr_autoresume_termination, get_parameters_in_billions
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm
@@ -116,6 +116,8 @@ def pretrain(train_valid_test_dataset_provider,
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    print(f'estimated model parameters: {get_parameters_in_billions(model)}')
+    print(f'estimated model parameters without embeddings: {get_parameters_in_billions(model, exclude_embeddings=True)}')
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')
@@ -551,7 +553,7 @@ def add_to_logging(name):
         total_loss_dict[skipped_iters_key]
 
     # Tensorboard values.
-    if writer and (iteration % args.tensorboard_log_interval == 0 ) and \
+    if writer and (iteration % args.tensorboard_log_interval == 0) and \
        is_last_rank():
         writer.add_scalar('steps-vs-samples/y=steps,x=samples', iteration, args.consumed_train_samples)
         writer.add_scalar('steps-vs-samples/y=samples,x=steps', args.consumed_train_samples, iteration)
@@ -567,6 +569,8 @@ def add_to_logging(name):
writer.add_scalar(f"lm-loss-training/{key}", loss_dict[key], iteration)
writer.add_scalar(f"lm-loss-training/{key}" + ' vs samples', loss_dict[key],
args.consumed_train_samples)
writer.add_scalar(f"lm-loss-training/{key}" + ' vs gigaflos (without embeddings)', loss_dict[key],
args.gigaflos_no_embeds)
if args.log_loss_scale_to_tensorboard:
writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration)
writer.add_scalar('loss-scale/loss-scale vs samples', loss_scale,
@@ -653,6 +657,8 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
 def train(forward_step_func, model, optimizer, lr_scheduler,
           train_data_iterator, valid_data_iterator):
     """Train the model function."""
+    print(f"Number of parameters: {get_parameters_in_billions(model)} billion")
+    print(f"Number of parameters without embeddings: {get_parameters_in_billions(model, exclude_embeddings=True)} billion")
     args = get_args()
     timers = get_timers()
 
@@ -689,9 +695,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                        optimizer,
                        lr_scheduler)
         iteration += 1
-        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
+        new_samples = mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \
                                        get_num_microbatches()
+        args.consumed_train_samples += new_samples
+        args.gigaflos_no_embeds += (6 * new_samples * args.seq_length * get_parameters_in_billions(model, exclude_embeddings=True))
 
         # Logging.
         if args.deepspeed:
@@ -833,11 +841,16 @@ def evaluate_and_print_results(prefix, forward_step_func,
             writer.add_scalar(f'lm-loss-validation/{key} validation vs samples',
                               total_loss_dict[key].item(),
                               args.consumed_train_samples)
+            writer.add_scalar(f'lm-loss-validation/{key} validation vs gigaflos (without embeddings)',
+                              total_loss_dict[key].item(),
+                              args.gigaflos_no_embeds)
             if args.log_validation_ppl_to_tensorboard:
                 writer.add_scalar(f'lm-loss-validation/{key} validation ppl', ppl,
                                   iteration)
                 writer.add_scalar(f'lm-loss-validation/{key} validation ppl vs samples',
                                   ppl, args.consumed_train_samples)
+                writer.add_scalar(f'lm-loss-validation/{key} validation ppl vs gigaflos (without embeddings)',
+                                  ppl, args.gigaflos_no_embeds)
 
     length = len(string) + 1
     print_rank_last('-' * length)
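
The accumulation line above applies the widely used ~6*N*D estimate for dense transformer training FLOs: roughly 2 FLOs per parameter per token in the forward pass and 4 in the backward, with N the non-embedding parameter count and D the number of tokens processed. Because get_parameters_in_billions returns N in units of 1e9, the running total lands directly in gigaflos. A sketch of one iteration's bookkeeping with made-up numbers (all values hypothetical, not from the source):

    params_in_billions = 1.3   # hypothetical non-embedding count, in 1e9 units
    seq_length = 2048          # hypothetical
    new_samples = 512          # hypothetical global batch for this iteration

    tokens_this_iteration = new_samples * seq_length
    gigaflos_this_iteration = 6 * tokens_this_iteration * params_in_billions
    print(f"{gigaflos_this_iteration:.3e} gigaflos accumulated this step")
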
megatron/utils.py: 27 additions & 4 deletions
@@ -16,8 +16,10 @@
"""General utilities."""

import sys
import warnings

import torch
from torch import nn
from torch.nn.parallel import DistributedDataParallel as torchDDP

from apex.multi_tensor_apply import multi_tensor_applier
@@ -28,7 +30,7 @@
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate, VocabParallelEmbedding
 from megatron import get_num_microbatches
 
 def unwrap_model(model, module_instances=(torchDDP)):
@@ -204,11 +206,32 @@ def get_ltor_masks_and_position_ids(data,
     return attention_mask, loss_mask, position_ids
 
 
-def get_parameters_in_billions(model):
+def param_size(parameter):
+    return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()
+
+
+def unique_param_count(param_list):
+    return sum(dict((p.data_ptr(), param_size(p)) for p in param_list).values())
+
+
+def non_embedding_params(module):
+    embedding_param_names = [
+        f"{name}.weight" for name, module_type in module.named_modules() if isinstance(module_type, nn.Embedding) or isinstance(module_type, VocabParallelEmbedding)
+    ]
+    non_embedding_parameters = [
+        parameter for name, parameter in module.named_parameters() if name not in embedding_param_names
+    ]
+    return unique_param_count(non_embedding_parameters)
+
+
+def get_parameters_in_billions(model, exclude_embeddings=False):
     gpus_per_model = torch.distributed.get_world_size(group=mpu.get_model_parallel_group())
 
-    approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
-                            for model_module in model])
+    if exclude_embeddings:
+        approx_parameters_in_billions = sum([non_embedding_params(model_module) for model_module in model])
+    else:
+        warnings.warn("Parameter count with the embeddings will be inaccurate with PP > 1, as the first and last stage hold several copies of the embeddings")
+        approx_parameters_in_billions = unique_param_count([p for model_module in model for p in model_module.parameters()])
 
     return approx_parameters_in_billions*gpus_per_model/(1e9)

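Keying the count on data_ptr() in unique_param_count means parameters that share storage within a process are counted once, which is what the "double counts" bullets in the commit message are about. Note the dedup only catches same-process sharing such as weight tying; embedding copies on different pipeline stages live in different processes, which is why the warnings.warn above still flags the embeddings-included count under PP > 1. A self-contained sketch with plain torch modules (mirroring param_size/unique_param_count above, outside Megatron):

    import torch
    from torch import nn

    def param_size(p):
        return p.ds_numel if hasattr(p, 'ds_id') else p.nelement()

    def unique_param_count(params):
        # keying on data_ptr() collapses parameters that share storage
        return sum({p.data_ptr(): param_size(p) for p in params}.values())

    emb = nn.Embedding(1000, 64)
    head = nn.Linear(64, 1000, bias=False)
    head.weight = emb.weight   # weight tying: same underlying storage

    params = list(emb.parameters()) + list(head.parameters())
    assert sum(p.nelement() for p in params) == 128000   # naive double count
    assert unique_param_count(params) == 64000           # tied weight counted once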
