How to load sharded checkpoints? #31

Closed
patrickvonplaten opened this issue May 4, 2022 · 9 comments
Labels: question (Further information is requested)

@patrickvonplaten (Contributor) commented May 4, 2022

❓ Questions and Help

After setting up the libraries as described in https://github.com/facebookresearch/metaseq/blob/main/docs/setup.md, it is possible to load the 350m checkpoint (since it's not sharded) as follows:

wget https://dl.fbaipublicfiles.com/opt/v1_20220502/350m/reshard.pt ./
  1. Next, we need to comment out one line in the Megatron-LM library which is only relevant for training (initializing different random seeds across pp ranks):
    Comment out this line: https://github.com/ngoyal2707/Megatron-LM/blob/ae0b844c1f6725c3433a95e42cac760b3885170b/megatron/initialize.py#L65 in your local clone of Megatron-LM.

  2. Now we write the following Python script to a run_model.py file:

import os

from transformers import GPT2Tokenizer
from megatron.initialize import initialize_megatron
from metaseq import checkpoint_utils
import torch

path = "./"

# arguments taken from: https://arxiv.org/pdf/2205.01068.pdf | table 1
initialize_megatron(args_defaults={
    "micro_batch_size": 1, 
    "num_layers": 24, 
    "hidden_size": 1024, 
    "num_attention_heads": 16,
    "max_position_embeddings": 2048, 
    "encoder_seq_length": 2048 
})

# save the GPT2-style tokenizer files (vocab.json, merges.txt) next to the checkpoint
tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
tokenizer.save_pretrained(path)

# load the non-sharded 350m checkpoint, pointing metaseq at the tokenizer files
checkpoint = checkpoint_utils.load_model_ensemble_and_task(
    [os.path.join(path, "reshard.pt")],
    arg_overrides={
        "vocab_filename": os.path.join(path, "vocab.json"),
        "merges_filename": os.path.join(path, "merges.txt"),
    }
)

model = checkpoint[0][0].eval()
  3. We can now load the checkpoint by running:
torchrun run_model.py --pipeline-model-parallel-size 1 --tensor-model-parallel-size 1
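
As a quick sanity check (not part of the original recipe, just a hedged addition), the loaded model is an ordinary torch.nn.Module, so one can append a parameter count to run_model.py right after loading:

# Optional sanity check (not in the original recipe): the loaded model is a
# regular torch.nn.Module, so we can count its parameters after loading.
num_params = sum(p.numel() for p in model.parameters())
print(f"Loaded model with {num_params / 1e6:.1f}M parameters")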

Problem: This only works for the 350m checkpoint! It does not work for the other checkpoints.
E.g. when replacing
[os.path.join(path, "reshard.pt")]
with
[os.path.join(path, "reshard-model_part-0.pt"), os.path.join(path, "reshard-model_part-1.pt")] (part 0 and part 1 of the 125M model),
we get an error because the weights are all flattened into 1D arrays.

Using #29 sadly also doesn't help, since the checkpoints don't seem to be in the *shard* format as required here:

sorted(glob(f"{pth_prefix}*shard*.pt"), key=_get_shard_number)

The parameter flattening seems to come from FairScale, and we've found some functionality to unflatten it here: https://github.com/facebookresearch/fairscale/blob/51b53ddb6c3aa77426c7d5cc0b543b79628053c4/fairscale/nn/misc/flatten_params_wrapper.py#L358 , but we haven't managed to wrap our heads around how to make it work exactly.
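
For intuition, the unflattening itself is conceptually simple. Here is a minimal, hypothetical sketch assuming the per-parameter names and shapes were known (which is exactly the metadata that is missing); it is not the actual FairScale implementation:

import torch

# Hypothetical sketch: a flat 1D buffer is the individual parameter tensors
# concatenated, so given their names and shapes it can be split by element
# count and reshaped back into a regular state dict.
def unflatten(flat_param, names_and_shapes):
    numels = [torch.Size(shape).numel() for _, shape in names_and_shapes]
    chunks = torch.split(flat_param, numels)
    return {
        name: chunk.view(shape)
        for (name, shape), chunk in zip(names_and_shapes, chunks)
    }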

@stephenroller @suchenzang @zhiqwang - any pointers on how we could load the 125M model (and the others) into a model instance of metaseq?

patrickvonplaten added the question (Further information is requested) label on May 4, 2022
@suchenzang (Contributor)
Could you try using our consolidation script for these smaller models, and try loading the consolidated checkpoint instead? Instructions here: https://github.com/facebookresearch/metaseq/blob/main/projects/OPT/download_opt175b.md#reshard-the-shards

@patrickvonplaten (Contributor, Author)

Hey @suchenzang,

Thanks for your answer. We did try out these scripts, but they don't work for a couple of reasons:

If you look into:

sorted(glob(f"{pth_prefix}*shard*.pt"), key=_get_shard_number)

you can see that the function expects filenames with a pattern different from those of the 125m checkpoint, which are

  • reshard-model_part-0.pt
  • reshard-model_part-1.pt

Accordingly, if you run the command, you'll soon get the following error:

AssertionError: reshard-model_part-1.pt did not match shard(\d+).pt  

Also, more generally, this function cannot work, since it never loads the correct module and no metadata is stored anywhere in the 125m checkpoints.
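
To make the filename mismatch concrete, here is a tiny reproduction of the failing pattern check (the regex is taken from the error message above):

import re

# The consolidation script expects filenames containing a shard<N> suffix,
# but the released 125m files are named reshard-model_part-<N>.pt, so the
# pattern from the assertion never matches.
pattern = re.compile(r"shard(\d+)\.pt")
print(pattern.search("reshard-model_part-1.pt"))  # -> None, hence the AssertionError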

Could you try to load the 125M checkpoints and share the script here? I think it'd be immensely helpful for the community to understand how to load the flat 1D-array weights :-)

Thanks a lot!

@patrickvonplaten (Contributor, Author)

More generally, the 125m checkpoints contain no information under:

state_dict["shard_metadata"]["param_metadata"]

IMO this is required to load the flat 1D tensors correctly into a fairseq-style model. Do you know where we could find the shard_metadata information for the 125m and the other checkpoints?
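
For reference, a hedged inspection sketch (file name and key layout as referenced in this thread) showing how one would check a shard for that metadata:

import torch

# Load one 125m shard on CPU and look for the FSDP metadata that would be
# needed to unflatten the 1D buffers; for these checkpoints it appears to be missing.
sd = torch.load("reshard-model_part-0.pt", map_location="cpu")
print(sd.keys())
print(sd.get("shard_metadata", {}).get("param_metadata"))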

@sshleifer commented May 5, 2022

I only interacted with this code when it was in a branch off fairseq (pre-metaseq), but either the Meta people can send you param_metadata, or you can call param_metadata = fsdp_instance.local_metadata_dict() on the last rank after one step of FSDP training with the same world size that they used. I suspect the first route is easier!
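
For illustration only, a rough single-process sketch of the call suggested above (a toy module and world size 1 are assumptions; in practice this would have to run inside the original metaseq model definition with the original training world size):

import torch
import torch.distributed as dist
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Toy stand-in module, purely to show where local_metadata_dict() is called.
dist.init_process_group(backend="gloo", init_method="tcp://localhost:29500",
                        rank=0, world_size=1)
model = FSDP(torch.nn.Linear(16, 16), flatten_parameters=True)
loss = model(torch.randn(2, 16)).sum()
loss.backward()
param_metadata = model.local_metadata_dict()
print(param_metadata.keys())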

@patrickvonplaten (Contributor, Author) commented May 6, 2022

@suchenzang @stephenroller any way that you guys could send us or open-source the param_metadata dicts?

I've tried for quite some time now to reproduce the correct parameter mapping without much success.

It's not really stated on how many GPUs (i.e. with what world_size) the models other than 175B were trained, nor was I able to reproduce the parameter mapping.

Also, there is one thing I don't fully understand:
I can load a randomly initialized model according to the model config in state["cfg"], but this random model then has significantly fewer parameters than the sharded checkpoints.
E.g. for the 125M model, the two checkpoints together contain more than 126M parameters, even though the randomly initialized model has (the correct) 125M parameters.
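
To illustrate the counting, a hedged sketch of the comparison (assuming the flat weights live under the "model" key, as in fairseq-style checkpoints):

import torch

# Sum the element counts of the flat 1D buffers in both 125m shards; the total
# comes out noticeably above the expected 125M parameters.
total = 0
for part in ["reshard-model_part-0.pt", "reshard-model_part-1.pt"]:
    sd = torch.load(part, map_location="cpu")
    total += sum(t.numel() for t in sd["model"].values() if torch.is_tensor(t))
print(f"{total / 1e6:.1f}M parameters across both shards")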

It would be extremely useful if you guys could provide some kind of script that allows loading the sharded checkpoints on CPU :-)

@ElegantLin commented May 7, 2022

Hi @patrickvonplaten, thanks for your code. I ran into ModuleNotFoundError: No module named 'transformers'. I think the error comes from the Megatron-LM installation. Do you know how to fix it?

Thanks!

@stephenroller (Contributor)

Okay, I wrote #60 to help us out here. It outputs the full unflattened, non-sharded checkpoint and should be pretty easy to load into Hugging Face. See the docstring for usage.

In [4]: sd = torch.load("/shared/home/roller/foo/restored.pt")

In [5]: sd.keys()
Out[5]: dict_keys(['decoder.version', 'decoder.embed_tokens.weight', 'decoder.embed_positions.weight', 'decoder.layers.0.self_attn.qkv_proj.weight', 'decoder.layers.0.self_attn.qkv_proj.bias', 'decoder.layers.0.self_attn.out_proj.weight', 'decoder.layers.0.self_attn.out_proj.bias', 'decoder.layers.0.self_attn_layer_norm.weight', 'decoder.layers.0.self_attn_layer_norm.bias', 'decoder.layers.0.fc1.weight', 'decoder.layers.0.fc1.bias', 'decoder.layers.0.fc2.weight', ...
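
A hedged follow-up sketch for inspecting the restored checkpoint before attempting a Hugging Face conversion (the path is a placeholder):

import torch

# Confirm that every entry is a normally shaped tensor and that the total
# element count matches the advertised model size.
sd = torch.load("restored.pt", map_location="cpu")
for name, tensor in list(sd.items())[:5]:
    print(name, tuple(tensor.shape))
print(f"total: {sum(t.numel() for t in sd.values()) / 1e6:.1f}M elements")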

@stephenroller (Contributor)

Keep in mind that metaseq, just like fairseq, actually prompts the model on the EOS token.
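
Concretely, a hedged illustration reusing the tokenizer from the script above: when preparing prompts by hand, this means prepending the EOS id before the actual prompt tokens.

from transformers import GPT2Tokenizer

# Prepend the EOS token id so the model is prompted the same way metaseq
# prompts it internally.
tokenizer = GPT2Tokenizer.from_pretrained("patrickvonplaten/opt_gpt2_tokenizer")
prompt_ids = tokenizer("Hello, my dog is cute")["input_ids"]
input_ids = [tokenizer.eos_token_id] + prompt_ids
print(input_ids)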

@suchenzang (Contributor)

Closing this given #88, #78, and #77, which should cover this issue as well.
