
[zero_to_fp32.py] support param groups#1017

Merged
tjruwase merged 2 commits into deepspeedai:master from stas00:zero_to_fp32-param_groups
Apr 29, 2021

Conversation

Collaborator

@stas00 stas00 commented Apr 29, 2021

In my original version I happened to test with a model that had a single param group, so I wasn't aware that there could be multiple flattened tensors, one per group; my reconstruction script was breaking when it ran into more than one flat tensor.

This PR tries to fix that.

There might be a more efficient way to do it, but for now just trying to make sure that it functions correctly.

I also left some disabled debug code in place for now, while the code is new and likely to need further debugging. We can remove it later once we feel it's solid.

While I tested this on a few live models, it'd be great to have a functional test for zero2 and zero3 for this code. But I'm not quite familiar with how your test suite is done and unfortunately don't really have time right now to sort it out.

The simplest conceptual test would be:

# original_fp32_state_dict_path points at the reference fp32 weights
model = (multi-param-group model).from_pretrained(original_fp32_state_dict_path)
engine, *_ = deepspeed.initialize(model, ...)  # initialize() returns a tuple
engine.save_checkpoint(save_dir)
! ./zero_to_fp32.py global_step1 pytorch_model.bin
! diff pytorch_model.bin original_fp32_state_dict_path  # should be identical

Fixes: #1009

@exelents, please check that this PR solves the problem for you.

Comment on lines +63 to +64
torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key],
0) for i in range(len(state_dicts))
Collaborator Author

@stas00 stas00 Apr 29, 2021


This is the only functional change in this PR. Instead of using just the first element, it now uses them all.
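As a toy illustration of the change (plain Python lists stand in for torch tensors, and `merge_groups`, the rank dicts, and the `fp32_flat_groups` key are made-up stand-ins for the script's real structures):

```python
def merge_groups(state_dicts, key):
    # Old behavior: take only the first param group from each rank, e.g.
    #   [sd[key][0] for sd in state_dicts]
    # New behavior: concatenate ALL param-group tensors per rank
    # (the list-concat below mirrors torch.cat(..., 0) on flat tensors).
    return [sum(sd[key], []) for sd in state_dicts]

# Two ranks, each holding two flattened param groups.
rank0 = {"fp32_flat_groups": [[1.0, 2.0], [10.0]]}
rank1 = {"fp32_flat_groups": [[3.0, 4.0], [20.0]]}

print(merge_groups([rank0, rank1], "fp32_flat_groups"))
# [[1.0, 2.0, 10.0], [3.0, 4.0, 20.0]]
```

With only the first group, rank 0 would lose the `[10.0]` group entirely, which is exactly the failure mode described above.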

Contributor


This seems fine for now. I agree we have to revisit, especially for very large models that could cause CPU OOM.

@exelents

exelents commented Apr 29, 2021

Converting the "siamese" model based on t5-11b encoders finished successfully. But when I load it into CPU memory I get a strange message:

Some weights of T5Siamese were not initialized from the model checkpoint at
 ./siamese_train_deepspeed/models/siamese-t5-11b-fp16/checkpoint-1625 and are newly initialized: 
['encoder_left.encoder.embed_tokens.weight', 'encoder_right.encoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


@tjruwase tjruwase merged commit a8cf887 into deepspeedai:master Apr 29, 2021
@stas00 stas00 deleted the zero_to_fp32-param_groups branch April 29, 2021 16:12
@stas00
Collaborator Author

stas00 commented Apr 29, 2021

Thank you for running the checks, @exelents

./siamese_train_deepspeed/models/siamese-t5-11b-fp16/checkpoint-1625 and are newly initialized:
['encoder_left.encoder.embed_tokens.weight', 'encoder_right.encoder.embed_tokens.weight']

You have encoder_right.shared.weight - aren't those tied/aliased? (Same for left)

The weights are restored based on this dict, which gets saved when the checkpoint is created:

    def _get_param_shapes(self):
        param_shapes = OrderedDict()
        for name, param in self.module.named_parameters():
            param_shapes[name] = param.ds_shape if hasattr(param, "ds_shape") else param.shape
            # print(f"saving param {name} {param_shapes[name]}")
        return param_shapes

so if you do this on your model, you won't find encoder_left.encoder.embed_tokens.weight in the names.

Give it a run. `self` is the DS engine, so instead of `self.module` it'd be just your transformers model.
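For illustration, here is a standalone sketch of that same loop run against a plain model object (`DummyParam`, `DummyModel`, and `get_param_shapes` are made-up stand-ins; the real method lives on the DeepSpeed engine):

```python
from collections import OrderedDict

class DummyParam:
    """Stand-in for a parameter tensor; real ZeRO-3 params carry ds_shape."""
    def __init__(self, shape):
        self.shape = shape

class DummyModel:
    """Stand-in for a transformers model exposing named_parameters()."""
    def named_parameters(self):
        yield "encoder_left.shared.weight", DummyParam((32128, 512))
        yield "encoder_right.shared.weight", DummyParam((32128, 512))

def get_param_shapes(model):
    param_shapes = OrderedDict()
    for name, param in model.named_parameters():
        # ZeRO-3 partitions params, so prefer ds_shape when the attribute exists
        param_shapes[name] = getattr(param, "ds_shape", param.shape)
    return param_shapes

print(get_param_shapes(DummyModel()))
```

Running this against the siamese model would show that only the `shared.weight` names are recorded, not the `embed_tokens` aliases.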

@exelents

You have encoder_right.shared.weight - aren't those tied/aliased? (Same for left)

Yes, I have these weights returned from named_parameters() function of my model, so I suppose they should exist in the checkpoint.

dd = list(model_engine.module.named_parameters())
dd = list(filter(lambda x: 'shared' in x[0], dd))
dd
[('encoder_left.shared.weight',
  Parameter containing:
  tensor([1.], device='cuda:0', dtype=torch.float16, requires_grad=True)),
 ('encoder_right.shared.weight',
  Parameter containing:
  tensor([1.], device='cuda:0', dtype=torch.float16, requires_grad=True))]

@stas00
Collaborator Author

stas00 commented Apr 29, 2021

But that's what I'm saying: the checkpoint does have encoder_left.shared.weight and encoder_right.shared.weight and those are restored.

The loader complains about encoder_left.encoder.embed_tokens.weight and encoder_right.encoder.embed_tokens.weight.

Your code above doesn't check for these 2.

BTW, the new version of zero_to_fp32.py has a debug flag - turn it on and when you run it you will see each weight as it gets loaded.

@exelents

exelents commented Apr 29, 2021

I turned on the debug flag and it showed me all the weights in the checkpoint, including:

└──>$ cat debug.txt | grep shared
encoder_left.shared.weight full shape: torch.Size([32128, 512]) partition0 numel=16449536 partitioned_padding_numel=0
encoder_right.shared.weight full shape: torch.Size([32128, 512]) partition0 numel=16449536 partitioned_padding_numel=0

@stas00
Collaborator Author

stas00 commented Apr 29, 2021

We have already established that.

Please review #1017 (comment) the warning is for 2 other names.

@exelents

exelents commented Apr 29, 2021

Ah, okay, I understand now. In the code of T5EncoderModel I see that the variable encoder_right.encoder.embed_tokens is initialized from the external variable encoder_right.shared when T5Stack is created, so it doesn't need to be saved.
Thank you.
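A rough sketch of why the alias never reaches the checkpoint (no torch here; `Tensor` and the loop are made-up stand-ins): PyTorch's `named_parameters()` deduplicates parameters by object identity, so a tied weight is listed only under one canonical name.

```python
class Tensor:
    """Dummy stand-in for a weight tensor."""
    pass

shared = Tensor()
params = {
    "encoder_right.shared.weight": shared,
    "encoder_right.encoder.embed_tokens.weight": shared,  # alias: same object
}

# named_parameters() skips parameters it has already seen, so the alias
# never makes it into the saved param_shapes dict or the checkpoint.
seen, named = set(), []
for name, p in params.items():
    if id(p) not in seen:
        seen.add(id(p))
        named.append(name)
print(named)  # ['encoder_right.shared.weight']
```

On load, the aliased name looks "missing" and triggers the newly-initialized warning, even though the tying re-creates it from the shared weight.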

@stas00
Collaborator Author

stas00 commented Apr 29, 2021

Yes, we could probably clean that up so that it doesn't produce a misleading warning.

The key is to please check that the resumed checkpoint scores well for you. I did only a quick 100 or so steps and the loss looked correct. Also please re-check with zero2. I did test it as well, but a second pair of eyes is always better.

I was just concerned that the saved weights might somehow not be in the same order as the param_names dict; since the total number of elements is the same once everything is flattened into a single tensor, the reshape would always succeed even if the order were wrong. But I checked, and the order appears to be correct.
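The order concern can be sketched as a toy unflatten step (plain lists instead of tensors; `unflatten` and the example shapes are hypothetical): slicing the flat buffer in saved order reconstructs each weight, and a wrong order would still fit numerically but silently scramble the weights.

```python
from collections import OrderedDict
from math import prod

def unflatten(flat, param_shapes):
    # Walk the flat buffer in the same order param_shapes was saved in.
    # A wrong order would still consume the same total numel, which is
    # why an order mismatch cannot be caught by a size check alone.
    state_dict, offset = OrderedDict(), 0
    for name, shape in param_shapes.items():
        numel = prod(shape)
        state_dict[name] = flat[offset:offset + numel]  # .view(shape) in torch
        offset += numel
    assert offset == len(flat), "flat tensor size mismatch"
    return state_dict

shapes = OrderedDict([("a.weight", (2, 2)), ("b.bias", (3,))])
sd = unflatten([1, 2, 3, 4, 5, 6, 7], shapes)
print(sd)  # a.weight -> [1, 2, 3, 4], b.bias -> [5, 6, 7]
```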


Development

Successfully merging this pull request may close these issues.

Reconstruction of fp32 weights on stage3 doesn't work

3 participants