
Floating-point ops counting and reloading #40

Merged: 24 commits into main on Sep 15, 2021
Conversation

TevenLeScao (Collaborator)

This PR adds flo-based logging of the validation loss. We count gigaflos to avoid overflowing, and because parameter numbers are also counted in billions. We don't count operations linked to the embeddings (cf. Scaling Laws for Neural Language Models, section 2.1). Those are stored in the args as args.gigaflos_no_embeds and exported with them in the state_dict at checkpoint-saving time. Currently, reloading a model trained before this PR assumes its flos are 0, which is the same behaviour as for consumed_train_samples.
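For reference, a minimal sketch of the kind of bookkeeping this implies (illustrative only: the helper name and arguments are made up, not the PR's actual code; the factor of 6 flos per non-embedding parameter per token, forward plus backward, follows the Scaling Laws paper):

from types import SimpleNamespace

def update_gigaflos(args, batch_size, seq_length, non_embedding_params):
    # ~6 flos per non-embedding parameter per processed token (fwd + bwd),
    # accumulated in billions (giga) to keep the numbers small.
    tokens = batch_size * seq_length
    args.gigaflos_no_embeds += 6 * non_embedding_params * tokens / 1e9

args = SimpleNamespace(gigaflos_no_embeds=0.0)
update_gigaflos(args, batch_size=512, seq_length=2048, non_embedding_params=1.2e9)
print(args.gigaflos_no_embeds)  # gigaflos accumulated after one step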

Review comment on megatron/initialize.py (resolved)
@TevenLeScao (Collaborator, Author)

Note that this integrates cherry-picked fixes from the main Megatron codebase, most notably changes to integrate DeepSpeed.

Review comments on megatron/training.py (resolved)
@stas00 (Member) left a comment

With a few suggestions, looking good. Thank you, @TevenLeScao

Review comments (resolved): megatron/checkpointing.py, megatron/initialize.py, megatron/training.py, megatron/utils.py (two threads)
@thomasw21 (Member) left a comment

I'm not yet familiar with the codebase, so I'm not the best person to review; I'll let the others approve.

Review comments (resolved): megatron/data/gpt_dataset.py, megatron/training.py
@stas00 (Member) commented Aug 5, 2021

btw, this:

sum(p.numel() for submodel in model for p in submodel.parameters() if p.requires_grad)

will count tied variables multiple times, you probably want:

sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
(after you adjust for p.requires_grad).

@stas00 (Member) commented Aug 5, 2021

Also please make sure to rebase to include this fix if you're testing:
42fe3b3

TevenLeScao and others added 3 commits August 24, 2021 18:20
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
@@ -649,7 +649,7 @@ def save_checkpoint_and_time(iteration, model, optimizer, lr_scheduler):
 def train(forward_step_func, model, optimizer, lr_scheduler,
           train_data_iterator, valid_data_iterator):
     """Train the model function."""
-    print(sum(p.numel() for submodel in model for p in submodel.parameters() if p.requires_grad))
+    print(f"Number of trainable parameters: {sum(p.numel() for submodel in model for p in submodel.parameters() if p.requires_grad)}")
A reviewer (Member) commented:

and this most likely needs to be fixed to remove double/triple-counting of tied vars:
#40 (comment)

TevenLeScao and others added 5 commits August 25, 2021 15:11
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

def param_count_without_doubles(param_list):
    return sum(dict((p.data_ptr(), param_size(p)) for p in param_list).values())
    # sum(dict((p.data_ptr(), param_size(p)) for submodel in model for p in submodel.parameters()).values())
A reviewer (Member) commented:

this can be removed then

    return parameter.ds_numel if hasattr(parameter, 'ds_id') else parameter.nelement()


def param_count_without_doubles(param_list):
A reviewer (Member) commented:

Since these can be triplicates, and more, perhaps a more telling name would just be: unique_param_count?

@TevenLeScao (Collaborator, Author) commented Aug 30, 2021

> btw, this:
>
> sum(p.numel() for submodel in model for p in submodel.parameters() if p.requires_grad)
>
> will count tied variables multiple times, you probably want:
>
> sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
>
> (after you adjust for `p.requires_grad`).

So originally I wasn't sure I was testing it right (sorry, this took a bit more time than I would have liked), but now I am quite certain that both return the same number, with duplicates. Do we have other leads on this?

I am testing by switching comments out:

def unique_param_count(param_list):
    # print("old school count")
    # return sum(p.numel() for p in param_list)
    print("new school count")
    return sum(dict((p.data_ptr(), param_size(p)) for p in param_list).values())

and looking at the model parameter print-out at the start of training.

@stas00 (Member) commented Aug 31, 2021

I think you're correct, Teven. I experimented some more with it, and it looks like in the general case this is actually not needed.

I got the original at https://stackoverflow.com/a/62764464/9201239 and I'm pretty sure I used it because I was getting duplicate counts. But perhaps it was in a different scenario.

I think these have to be 2 distinct tensors with a shared .storage() for the unique-data_ptr check to be needed; however, in most (all?) of the model layers we deal with, the shared vars are just pointers to the same Python variable. So since model.parameters() creates a set based on the tensor variables, it actually never returns any tied variables twice.

set creation:
https://github.com/pytorch/pytorch/blob/1f16c22dc8251f01627ee73ad1ef69bd18e51447/torch/nn/modules/module.py#L1488

For example, here lm_head, which is tied to the first embedding, doesn't even show up, as if it doesn't exist:

from transformers import T5ForConditionalGeneration

m = T5ForConditionalGeneration.from_pretrained("t5-small")
for k, v in m.named_parameters():
    print(k)
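To make the tying explicit, a quick check (assuming the default tied-embeddings config of t5-small):

print(m.lm_head.weight is m.shared.weight)             # True: the very same Parameter object
print("lm_head.weight" in dict(m.named_parameters()))  # False: named_parameters() skips the duplicate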

Thank you for questioning my suggestion, @TevenLeScao!

@TevenLeScao (Collaborator, Author)

Sorry for the misunderstanding: I currently find that both return the number with duplicates! For the same reasons you've laid out, this does not correspond to what I thought the code would do, so I'm still investigating.

@stas00 (Member) commented Aug 31, 2021

Oh, you're saying you are getting wrong numbers with both approaches. I see.

Let's look at the code/data.

How many shared embeddings do you have in the model? What is the code that does the sharing (or tying)?

Let's perhaps see the dump of the model structure as well?

The things to consider:

In [5]: import torch

In [6]: a = b = torch.ones(1)

In [7]: id(a)
Out[7]: 140596158463936

In [8]: id(b)
Out[8]: 140596158463936

In [11]: a.data_ptr()
Out[11]: 94903924655104

In [12]: b.data_ptr()
Out[12]: 94903924655104

So now you can look at the params that you think are shared, do the above, and see which objects the Python variables point to and which storage the tensors' data pointers point to.
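To make the two cases concrete, a toy sketch (the module names and shapes are made up for illustration):

from torch import nn

class TiedByObject(nn.Module):
    # head.weight is the very same Parameter object as embed.weight
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = self.embed.weight

class TiedByStorage(nn.Module):
    # two distinct Parameter objects that share the same underlying storage
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)
        self.head = nn.Linear(4, 10, bias=False)
        self.head.weight = nn.Parameter(self.embed.weight.data)

for cls in (TiedByObject, TiedByStorage):
    m = cls()
    naive = sum(p.numel() for p in m.parameters())
    unique = sum(dict((p.data_ptr(), p.numel()) for p in m.parameters()).values())
    print(cls.__name__, naive, unique)
# TiedByObject  40 40  -> parameters() already dedups by object identity
# TiedByStorage 80 40  -> only the data_ptr() dedup catches this kind of tying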

@TevenLeScao (Collaborator, Author)

Hey, sorry, I looked at the code a bit and I don't think I currently have the bandwidth to adapt it this way. I'd really like this to be functional before we launch the next round of training; here are ways I think we can achieve that:

  • Count parameters somewhere else, at the start of training (either with a formula, as sketched after this list, or by instantiating something like an HF model in CPU RAM and counting there)
  • Ask MLM/DS to do that; they seemed to have some leads from the Slack discussions ("The same bookkeeping is used to compute norms and other global values")
  • Merge as-is and warn that this can only be used for non-embedding counts
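For the formula route, a rough sketch (illustrative only; it uses the non-embedding approximation from the Scaling Laws paper, which ignores biases and layernorm weights, and the example config numbers are hypothetical):

def approx_non_embedding_params(num_layers, hidden_size):
    # Per layer: ~4*d^2 for attention (Q, K, V, output projections)
    # + ~8*d^2 for the MLP with a 4x expansion => ~12*d^2 per layer.
    return 12 * num_layers * hidden_size ** 2

print(approx_non_embedding_params(num_layers=24, hidden_size=2048) / 1e9)  # ~1.21 (billions)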
What do you think?

@stas00 (Member) commented Sep 8, 2021

As I mentioned in my last comment, I will need context to work with to understand the problem and to be able to reproduce it.

Once I've finished sorting out the checkpoints, I can work on this if you can explain to me how to reproduce the issue.

> • Count parameters somewhere else, at the start of training (either with a formula or instantiating something like an HF model in CPU RAM and counting there)

I don't think we have enough memory to load the whole model on CPU.

> • Ask MLM/DS to do that; they seemed to have some leads from the Slack discussions ("The same bookkeeping is used to compute norms and other global values")

Doesn't hurt to ask.

> • Merge as-is and warn that this can only be used for non-embedding counts

Warnings don't usually work, IMHO, but if an incomplete solution works for your needs and you need it now as-is, merge it and open another issue to track bringing this to completion.

@TevenLeScao (Collaborator, Author)

I really haven't had time to come back to this; I'll merge now with a warning and open an issue.

@TevenLeScao merged commit af8229e into main on Sep 15, 2021
@thomasw21 (Member)

Maybe unrelated, but I found this in the DS repo:

https://github.com/microsoft/DeepSpeed/blob/big-science/deepspeed/runtime/pipe/engine.py#L117-L126

ofirpress pushed a commit to ofirpress/Megatron-DeepSpeed that referenced this pull request Sep 23, 2021
* initial flo count/logging setup (need to fix model parameter count)

* initial flo count/logging setup (need to fix model parameter count)

* 1B3 parameter setup + flos counting

* 1B3 parameter setup + flos counting

* 1B3 parameter setup + flos counting

* 1B3 parameter setup

* 1B3 parameter setup

* synched with latest 13B script

* synched with latest 13B script

* pipe transformer docstring

* improve DS integration evaluation + logging

* use pp engine even for pp=1 (bigscience-workshop#6)

* removed slurm_examples

* flos re-loading

* Update megatron/training.py

Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>

* Update megatron/data/gpt_dataset.py

Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>

* Update megatron/utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update megatron/utils.py

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* formatting fix, reserving bug for somewhere else, adding flo-logging to TB groups

* indentation bug

* fixing possible double counts

* tweaks

* warning for double counts

Co-authored-by: Shaden Smith <shaden.smith@microsoft.com>
Co-authored-by: Jeff Rasley <jerasley@microsoft.com>
Co-authored-by: TevenLeScao <uhk85as@jean-zay1.idris.fr>
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
SaulLu added a commit to SaulLu/Megatron-DeepSpeed that referenced this pull request Sep 24, 2021
@stas00 (Member) commented Nov 8, 2021

Hi Teven,

This warning isn't great:

        warnings.warn("Parameter count with the embeddings will be inaccurate with PP > 1, as the first and last stage hold several copies of the embeddings")

as it logs on each GPU, resulting in hundreds of these flooding the log file. Could we please change it to run only on global rank 0?
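Something along these lines would do (a minimal sketch, assuming torch.distributed is already initialized at that point; Megatron's own rank helpers would work just as well):

import warnings
import torch.distributed as dist

# Only emit the warning from global rank 0 to avoid flooding the logs.
if not dist.is_initialized() or dist.get_rank() == 0:
    warnings.warn("Parameter count with the embeddings will be inaccurate with PP > 1, "
                  "as the first and last stage hold several copies of the embeddings")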

Thank you!
