Floating-point ops counting and reloading #40

Merged · 24 commits · Sep 15, 2021

Commits (the diff below shows changes from 22 of the 24 commits)
5196f8a
initial flo count/logging setup (need to fix model parameter count)
TevenLeScao Jul 22, 2021
a3a4ba4
initial flo count/logging setup (need to fix model parameter count)
TevenLeScao Jul 23, 2021
bdd75f1
1B3 parameter setup + flos counting
TevenLeScao Jul 28, 2021
c0fc29a
1B3 parameter setup + flos counting
TevenLeScao Jul 28, 2021
aefbe3b
1B3 parameter setup + flos counting
TevenLeScao Jul 28, 2021
17e0184
1B3 parameter setup
TevenLeScao Jul 28, 2021
97dd06d
1B3 parameter setup
TevenLeScao Jul 28, 2021
64892e2
synched with latest 13B script
TevenLeScao Jul 28, 2021
3c79aac
synched with latest 13B script
TevenLeScao Jul 29, 2021
b7b3167
pipe transformer docstring
Jul 26, 2021
8382141
improve DS integration evaluation + logging
Jul 26, 2021
06cb18f
use pp engine even for pp=1 (#6)
jeffra Jul 30, 2021
d581894
removed slurm_examples
Aug 4, 2021
60794bf
flos re-loading
TevenLeScao Aug 4, 2021
c79db1c
Merge branch 'main' into training_flos
TevenLeScao Aug 24, 2021
fb33f13
Update megatron/training.py
TevenLeScao Aug 24, 2021
dff1479
Update megatron/data/gpt_dataset.py
TevenLeScao Aug 24, 2021
2fa3b5b
Update megatron/utils.py
TevenLeScao Aug 25, 2021
ff7af10
Update megatron/utils.py
TevenLeScao Aug 25, 2021
b9ac381
formatting fix, reserving bug for somewhere else, adding flo-logging …
TevenLeScao Aug 25, 2021
f25e25f
indentation bug
TevenLeScao Aug 25, 2021
e63503d
fixing possible double counts
TevenLeScao Aug 25, 2021
5bdcf81
tweaks
TevenLeScao Aug 25, 2021
72ad711
warning for double counts
TevenLeScao Sep 15, 2021

megatron/arguments.py: 1 change (1 addition, 0 deletions)

@@ -170,6 +170,7 @@ def parse_args(extra_args_provider=None, defaults={},
     # Consumed tokens.
     args.consumed_train_samples = 0
     args.consumed_valid_samples = 0
+    args.gigaflos_no_embeds = 0

     # Iteration-based training.
     if args.train_iters:

megatron/checkpointing.py: 2 changes (2 additions, 0 deletions)

@@ -353,6 +353,8 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
         update_num_microbatches(consumed_samples=args.consumed_train_samples)
         args.consumed_valid_samples = getattr(checkpoint_args,
                                               'consumed_valid_samples', 0)
+        args.gigaflos_no_embeds = getattr(checkpoint_args,
+                                          'gigaflos_no_embeds', 0)
     else:
         print_rank_0('could not find arguments in the checkpoint ...')
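
Note on backward compatibility: reading the new counter with getattr and a default of 0 keeps checkpoints written before this change loadable. A minimal standalone sketch of that fallback pattern (the Namespace below is illustrative, not the project's actual checkpoint format):

from argparse import Namespace

# An "old" checkpoint's saved args lack the new counter; getattr falls back to 0
# instead of raising AttributeError, so resuming from such a checkpoint still works.
old_checkpoint_args = Namespace(consumed_train_samples=1000, consumed_valid_samples=50)
gigaflos_no_embeds = getattr(old_checkpoint_args, 'gigaflos_no_embeds', 0)
print(gigaflos_no_embeds)  # 0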

megatron/data/gpt_dataset.py: 6 changes (3 additions, 3 deletions)

@@ -237,10 +237,10 @@ def _build_index_mappings(name, data_prefix, documents, sizes,
         last_epoch_num_samples = num_samples - \
                                  num_samples_from_epochs_minus_one
         assert last_epoch_num_samples >= 0, \
-            'last epoch number of samples should be non-negative.'
+            f'last epoch number of samples {last_epoch_num_samples} should be non-negative.'
         num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length
-        assert last_epoch_num_samples < (num_samples_per_epoch + 1), \
-            'last epoch number of samples exceeded max value.'
+        assert last_epoch_num_samples <= num_samples_per_epoch, \
+            f'last epoch number of samples {last_epoch_num_samples} exceeded max value {num_samples_per_epoch}.'
         # If we have less than 80% of the samples for the last epoch,
         # seperate out the epoch and treat it differently.
         # Note: the 80% number is just based on common sense and can
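
For integers the tightened assertion is equivalent to the old one (last_epoch_num_samples < num_samples_per_epoch + 1 is the same as last_epoch_num_samples <= num_samples_per_epoch); the gain is that the f-strings now report the offending values when the check fires. A small worked example with purely illustrative numbers:

# Illustrative numbers only: how many full samples one epoch of tokens can yield.
tokens_per_epoch = 10_001
seq_length = 1_000
num_samples_per_epoch = (tokens_per_epoch - 1) // seq_length  # 10

last_epoch_num_samples = 10  # exactly at the boundary: still allowed
assert last_epoch_num_samples <= num_samples_per_epoch, \
    f'last epoch number of samples {last_epoch_num_samples} exceeded max value {num_samples_per_epoch}.'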

megatron/training.py: 19 changes (16 additions, 3 deletions)

@@ -43,7 +43,7 @@
 from megatron.initialize import write_args_to_tensorboard
 from megatron.learning_rates import AnnealingLR
 from megatron.model import DistributedDataParallel as LocalDDP
-from megatron.utils import check_adlr_autoresume_termination
+from megatron.utils import check_adlr_autoresume_termination, get_parameters_in_billions
 from megatron.utils import unwrap_model
 from megatron.data.data_samplers import build_pretraining_data_loader
 from megatron.utils import calc_params_l2_norm

@@ -113,6 +113,8 @@ def pretrain(train_valid_test_dataset_provider,
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
+    print(f'estimated model parameters: {get_parameters_in_billions(model)}')
+    print(f'estimated model parameters without embeddings: {get_parameters_in_billions(model, exclude_embeddings=True)}')
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')

@@ -545,7 +547,7 @@ def add_to_logging(name):
                           total_loss_dict[skipped_iters_key]

     # Tensorboard values.
-    if writer and (iteration % args.tensorboard_log_interval == 0 ) and \
+    if writer and (iteration % args.tensorboard_log_interval == 0) and \
        is_last_rank():
         writer.add_scalar('steps-vs-samples/y=steps,x=samples', iteration, args.consumed_train_samples)
         writer.add_scalar('steps-vs-samples/y=samples,x=steps', args.consumed_train_samples, iteration)

@@ -561,6 +563,8 @@ def add_to_logging(name):
             writer.add_scalar(f"lm-loss-training/{key}", loss_dict[key], iteration)
             writer.add_scalar(f"lm-loss-training/{key}" + ' vs samples', loss_dict[key],
                               args.consumed_train_samples)
+            writer.add_scalar(f"lm-loss-training/{key}" + ' vs gigaflos (without embeddings)', loss_dict[key],
+                              args.gigaflos_no_embeds)
         if args.log_loss_scale_to_tensorboard:
             writer.add_scalar('loss-scale/loss-scale', loss_scale, iteration)
             writer.add_scalar('loss-scale/loss-scale vs samples', loss_scale,

@@ -647,6 +651,8 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
 def train(forward_step_func, model, optimizer, lr_scheduler,
           train_data_iterator, valid_data_iterator):
     """Train the model function."""
+    print(f"Number of parameters: {get_parameters_in_billions(model)} billion")
+    print(f"Number of parameters without embeddings: {get_parameters_in_billions(model, exclude_embeddings=True)} billion")
     args = get_args()
     timers = get_timers()

@@ -683,9 +689,11 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
                                        optimizer,
                                        lr_scheduler)
         iteration += 1
-        args.consumed_train_samples += mpu.get_data_parallel_world_size() * \
+        new_samples = mpu.get_data_parallel_world_size() * \
                                        args.micro_batch_size * \
                                        get_num_microbatches()
+        args.consumed_train_samples += new_samples
+        args.gigaflos_no_embeds += (6 * new_samples * args.seq_length * get_parameters_in_billions(model, exclude_embeddings=True))

         # Logging.
         if args.deepspeed:

@@ -827,11 +835,16 @@ def evaluate_and_print_results(prefix, forward_step_func,
             writer.add_scalar(f'lm-loss-validation/{key} validation vs samples',
                               total_loss_dict[key].item(),
                               args.consumed_train_samples)
+            writer.add_scalar(f'lm-loss-validation/{key} validation vs gigaflos (without embeddings)',
+                              total_loss_dict[key].item(),
+                              args.gigaflos_no_embeds)
             if args.log_validation_ppl_to_tensorboard:
                 writer.add_scalar(f'lm-loss-validation/{key} validation ppl', ppl,
                                   iteration)
                 writer.add_scalar(f'lm-loss-validation/{key} validation ppl vs samples',
                                   ppl, args.consumed_train_samples)
+                writer.add_scalar(f'lm-loss-validation/{key} validation ppl vs gigaflos (without embeddings)',
+                                  ppl, args.gigaflos_no_embeds)

     length = len(string) + 1
     print_rank_last('-' * length)
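
The accumulation in the train loop uses the common rough estimate of about 6 FLOPs per token per non-embedding parameter for transformer training (roughly 2 in the forward pass and 4 in the backward pass). Because get_parameters_in_billions already returns a value in billions, the running total is in gigaFLOs. A minimal standalone sketch of the same bookkeeping, with purely illustrative numbers and a hypothetical helper name:

def gigaflos_increment(samples, seq_length, params_in_billions):
    # Estimated training FLOPs for `samples` sequences of `seq_length` tokens,
    # in units of 1e9 because the parameter count is expressed in billions.
    tokens = samples * seq_length
    return 6 * tokens * params_in_billions

# Example: one step over 512 sequences of 2048 tokens on ~1.3B non-embedding
# parameters adds about 8.2e6 gigaFLOs (~8.2 petaFLOs) to the counter.
print(gigaflos_increment(512, 2048, 1.3))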

megatron/utils.py: 30 changes (26 additions, 4 deletions)

@@ -18,6 +18,7 @@
 import sys

 import torch
+from torch import nn
 from torch.nn.parallel import DistributedDataParallel as torchDDP

 from apex.multi_tensor_apply import multi_tensor_applier

@@ -28,7 +29,7 @@
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.model.module import param_is_not_shared
-from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
+from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate, VocabParallelEmbedding
 from megatron import get_num_microbatches

 def unwrap_model(model, module_instances=(torchDDP)):

@@ -204,11 +205,32 @@ def get_ltor_masks_and_position_ids(data,
     return attention_mask, loss_mask, position_ids


-def get_parameters_in_billions(model):
+def param_size(parameter):
+    return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()
+
+
+def param_count_without_doubles(param_list):
+    return sum(dict((p.data_ptr(), param_size(p)) for p in param_list).values())
+    # sum(dict((p.data_ptr(), param_size(p)) for submodel in model for p in submodel.parameters()).values())
+
+
+def non_embedding_params(module):
+    embedding_param_names = [
+        f"{name}.weight" for name, module_type in module.named_modules() if isinstance(module_type, nn.Embedding) or isinstance(module_type, VocabParallelEmbedding)
+    ]
+    non_embedding_parameters = [
+        parameter for name, parameter in module.named_parameters() if name not in embedding_param_names
+    ]
+    return param_count_without_doubles(non_embedding_parameters)
+
+
+def get_parameters_in_billions(model, exclude_embeddings=False):
     gpus_per_model = torch.distributed.get_world_size(group=mpu.get_model_parallel_group())

-    approx_parameters_in_billions = sum([sum([p.ds_numel if hasattr(p,'ds_id') else p.nelement() for p in model_module.parameters()])
-                                            for model_module in model])
+    if exclude_embeddings:
+        approx_parameters_in_billions = sum([non_embedding_params(model_module) for model_module in model])
+    else:
+        approx_parameters_in_billions = param_count_without_doubles([p for model_module in model for p in model_module.parameters()])

     return approx_parameters_in_billions*gpus_per_model/(1e9)

Review comment from a maintainer (Member) on the param_count_without_doubles definition: "since this can be triplicates, and more, perhaps a more telling name would be just: unique_param_count?"

Review comment from the same reviewer on the commented-out line inside it: "this can be removed then"
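
Two details in the new helpers are worth noting: parameters are de-duplicated by their data_ptr(), so weights shared between modules (for example an input embedding tied to an output projection) are only counted once, and non_embedding_params drops every nn.Embedding / VocabParallelEmbedding weight before counting. A minimal standalone sketch of the de-duplication idea, using plain torch.nn modules and an illustrative helper name (no DeepSpeed or Megatron specifics):

from torch import nn

def unique_param_count(params):
    # Key by storage pointer so parameters that share memory are counted once.
    return sum({p.data_ptr(): p.numel() for p in params}.values())

# Two linear layers tied to the same 4x4 weight tensor.
a = nn.Linear(4, 4, bias=False)
b = nn.Linear(4, 4, bias=False)
b.weight = a.weight

print(unique_param_count(list(a.parameters()) + list(b.parameters())))  # 16, not 32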