[zero_to_fp32.py] support param groups#1017
Conversation
    torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key],
              0) for i in range(len(state_dicts))
This is the only functional change in this PR. Instead of using just the first element, it now uses them all.
This seems fine for now. I agree we have to revisit, especially for very large models that could cause CPU OOM.
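The change under discussion can be sketched in isolation. This is a toy reconstruction, not the actual zero_to_fp32.py code, and the key name `'fp32_flat_groups'` in the example is only a placeholder for whatever `fp32_groups_key` resolves to:

```python
import torch

def per_rank_flat_tensors(state_dicts, fp32_groups_key):
    """For each rank's state dict, concatenate the flat fp32 tensor of
    every param group into one flat tensor (the PR's fix), instead of
    taking only the first group's tensor (the old behavior)."""
    return [
        torch.cat(state_dicts[i]['optimizer_state_dict'][fp32_groups_key], 0)
        for i in range(len(state_dicts))
    ]

# Toy example: 2 ranks, each with 2 param groups of sizes 3 and 2.
state_dicts = [
    {'optimizer_state_dict':
        {'fp32_flat_groups': [torch.zeros(3), torch.ones(2)]}}
    for _ in range(2)
]
flats = per_rank_flat_tensors(state_dicts, 'fp32_flat_groups')
assert len(flats) == 2 and tuple(flats[0].shape) == (5,)
```

With a single param group the old and new code agree, which is why the bug only surfaced on models with multiple groups.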
Converting a "siamese" model based on t5-11b encoders completed successfully. But when I load it into CPU memory I get a strange message:
Thank you for running the checks, @exelents
The weights are being restored based on this dict, which gets saved when the checkpoint is created, so if you do this on your model you can check what you find there. Please give it a run.
Yes, these weights are returned by the named_parameters() function of my model, so I suppose they should exist in the checkpoint.
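One quick way to run that cross-check is a small helper (hypothetical, not part of the PR) that compares the names from `model.named_parameters()` with the keys stored in the checkpoint's state_dict:

```python
def diff_param_names(model_names, checkpoint_keys):
    """Compare parameter names from model.named_parameters() with the
    keys of a checkpoint state_dict, reporting both directions."""
    model_names, checkpoint_keys = set(model_names), set(checkpoint_keys)
    return {
        'missing_from_checkpoint': sorted(model_names - checkpoint_keys),
        'unexpected_in_checkpoint': sorted(checkpoint_keys - model_names),
    }

# Toy example with plain strings instead of a real model/checkpoint.
report = diff_param_names(
    ['shared.weight', 'encoder.block.0.weight'],
    ['shared.weight', 'lm_head.weight'])
assert report['missing_from_checkpoint'] == ['encoder.block.0.weight']
assert report['unexpected_in_checkpoint'] == ['lm_head.weight']
```

In practice the inputs would be `[n for n, _ in model.named_parameters()]` and `state_dict.keys()`.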
But that's what I'm saying: the checkpoint does have those weights. The loader complains about 2 other names, and your code above doesn't check for those 2. BTW, the new version of
I turned on the debug flag and it shows me all the weights in the checkpoint, including:
We have already established that. Please review #1017 (comment): the warning is for 2 other names.
Ah, okay, I understand. In the code of T5EncoderModel I see that the variable encoder_right.encoder.embed_tokens is initialized from the external variable encoder_right.shared when T5Stack is created, so it doesn't need to be saved.
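That tying can be demonstrated with a minimal sketch (assumed structure following the T5 pattern of reusing a shared embedding; not the actual T5EncoderModel code, and the variable names are stand-ins):

```python
import torch.nn as nn

shared = nn.Embedding(10, 4)   # plays the role of encoder_right.shared
embed_tokens = shared          # T5Stack-style reuse: the very same module

# Both names point at the same underlying storage, so saving `shared`
# alone is sufficient; a "missing key" warning for embed_tokens when
# loading such a checkpoint would be misleading.
assert embed_tokens is shared
assert embed_tokens.weight.data_ptr() == shared.weight.data_ptr()
```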
Yes, we could probably clean that up so that it doesn't produce a misleading warning. The key thing is to please check that the resumed checkpoint scores well for you. I did only a quick 100 or so steps and the loss looked correct. Also please re-check with zero2. I did test it as well, but a second pair of eyes is always better. I was just concerned that perhaps somehow the saved weights weren't in the same order as the
In my original version I happened to use a model with 1 param group, so I wasn't aware that there could be multiple flattened tensors, one per group, and my reconstruction script was breaking when it ran into more than one single flat tensor.
This PR tries to fix that.
There might be a more efficient way to do it, but for now just trying to make sure that it functions correctly.
I also left some disabled debug code in for now, while the change is new and likely to still need debugging. We can remove it later once we feel it's solid.
While I tested this on a few live models, it'd be great to have a functional test for zero2 and zero3 for this code. But I'm not familiar with how your test suite is organized and unfortunately don't have time right now to sort it out.
The simplest conceptual test would be:
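One way to sketch the final comparison step of such a test, as a hypothetical helper (the train/save/convert steps themselves would need the real DeepSpeed test harness, so only the assertion is shown here):

```python
import torch

def state_dicts_match(reference_sd, reconstructed_sd, atol=1e-6):
    """Conceptually: train a toy model under zero2/zero3, save a
    checkpoint, run zero_to_fp32.py on it, then assert the reconstructed
    fp32 state_dict matches the live model's, key by key."""
    if reference_sd.keys() != reconstructed_sd.keys():
        return False
    return all(
        torch.allclose(reference_sd[k].float(), reconstructed_sd[k].float(),
                       atol=atol)
        for k in reference_sd)

# Toy check with hand-made state dicts.
a = {'w': torch.tensor([1.0, 2.0])}
b = {'w': torch.tensor([1.0, 2.0])}
assert state_dicts_match(a, b)
assert not state_dicts_match(a, {'w2': torch.tensor([1.0, 2.0])})
```

Running this on a model with 2+ param groups would exercise exactly the case this PR fixes.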
Fixes: #1009
@exelents, please check that this PR solves the problem for you.