fix(models): Preserve custom token IDs through DiaConfig save and load #43928
| self.pad_token_id = pad_token_id | ||
| self.eos_token_id = eos_token_id | ||
| self.bos_token_id = bos_token_id |
imo we need to remove pad_token_id: int = 1025, eos_token_id: int = 1024, bos_token_id: int = 1026 in the longer term. They have to be in the config where they're actually set as attrs
For BC we can keep it and set default to None, that should work since save_pretrained wasn't saving it anyway in the main config. So smth like
```python
def __init__(
    self,
    pad_token_id: int | None = None,
    eos_token_id: int | None = None,
    bos_token_id: int | None = None,
    **kwargs,  # remaining arguments omitted in this sketch
):
    # We could raise a deprecation warning here, but first we need to update the official ckpt in `nari-labs` org
    if pad_token_id is not None:
        logger.warning_once("Please pass your pad token to the config where it belongs!")
        self.decoder_config.pad_token_id = pad_token_id
```
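To make the suggestion concrete, here is a minimal, self-contained sketch of the `None`-default forwarding pattern. It does not use `transformers`: `SketchDecoderConfig` and `SketchDiaLikeConfig` are hypothetical stand-ins, and the JSON round trip only approximates what a save followed by a load would do, so treat it as an illustration under those assumptions and not as the actual DiaConfig code.

```python
import json
import warnings


class SketchDecoderConfig:
    """Hypothetical sub-config that owns the token ID."""

    def __init__(self, vocab_size=1028, pad_token_id=1025):
        self.vocab_size = vocab_size
        self.pad_token_id = pad_token_id


class SketchDiaLikeConfig:
    """Hypothetical parent config using the None-default forwarding pattern."""

    def __init__(self, decoder_config=None, pad_token_id=None):
        self.decoder_config = SketchDecoderConfig(**(decoder_config or {}))
        # Forward only when the caller passed a value explicitly, so a reload
        # without the argument leaves the saved sub-config value untouched.
        if pad_token_id is not None:
            warnings.warn(
                "Pass pad_token_id to the decoder config instead.", FutureWarning
            )
            self.decoder_config.pad_token_id = pad_token_id

    def to_dict(self):
        # Only the sub-config is serialized in this sketch.
        return {"decoder_config": vars(self.decoder_config).copy()}


config = SketchDiaLikeConfig(decoder_config={"vocab_size": 100}, pad_token_id=96)
saved = json.dumps(config.to_dict())
reloaded = SketchDiaLikeConfig(**json.loads(saved))

print(reloaded.decoder_config.pad_token_id)  # 96
```

With a non-`None` default in the parent signature, the reload would overwrite the saved value; defaulting to `None` is what lets the sub-config stay the single source of truth.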
Resolved; I've updated the tests and docs, and added a TODO to avoid missing the nari-labs org config update (happy to remove if unnecessary). Tests pass without the change to the testing file as well, but maybe this should make the pattern clearer for readers (? :))
zucchini-nlp left a comment
Thanks for catching this, I've also been thinking of getting rid of this pattern across the repo haha. I'm not 100% sure but there were a few more similar models, t5gemma and maybe some more. Would be great to check if you have bandwidth, or I'll just let you know later in the comments :)
|
@zucchini-nlp That makes sense! I've left a comment here with the findings within the time I had for your perusal.
→ I've listed the hyperlinks with the correct line numbers for each reference; feel free to cross-check my findings :)
→ What was I searching for: Accept pad_token_id/eos_token_id/bos_token_id in
→ What I found: Wasn't able to find models other than Dia with the actual save/load bug; other composite configs set token IDs on a sub-config, but make sure it persists in the main config. Speaking of which, the patterns are:
→ Additional: Would love to hear your suggestions for which ones you're looking to fix asap and probably we could move to writing issues for them, happy to help :) |
|
[For maintainers] Suggested jobs to run (before merge) run-slow: dia |
|
@harshaljanjani thanks a lot for your research. The T5gemma seems to work because it actually assigns the attr in subconfig and also in the main config, thus it will save it in both places. Let's fix Dia for now and I will think of unifying these later, I am not 100% sure if it'll disrupt downstream usage for other models if we just delete the attr |
| # TODO: Remove token ID forwarding once the `nari-labs/Dia-1.6B` | ||
| # checkpoint is updated |
Can you open a PR in nari-labs repo and link to this issue?
|
Opened a PR at https://github.com/nari-labs/dia with the change; links the issue and the PR. |
|
I meant more like the hub one 😅 Opened https://huggingface.co/nari-labs/Dia-1.6B-0626/discussions/7 as well :) |
Oops, my bad! Should I close the other PR then? 😅 |
|
No, you can leave it. They would need to update their GH repo if hub config changes are merged! |
Perfect, sounds good! |
What does this PR do?
The following failing Dia use case was identified and fixed in this PR:
→ Tests that created DiaConfig with custom token IDs (eos_token_id=97 for a vocab_size=100) failed because saving and then reloading the config reset these values to the defaults (eos_token_id=1024). The cause is that DiaConfig only set the attrs on the sub decoder_config, not on the main parent config, which also led to an IndexError during generation.
→ For more details on reproducing the bug and the output screenshots, please visit the linked issue!
Fixes #43927.
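The failure mode described above can be illustrated with a small toy model that does not depend on `transformers`. `ToyDecoderConfig`, `ToyParentConfig`, and `round_trip` are hypothetical names, and the overwrite-on-reload mechanism is an assumption inferred from the description, not a copy of the real DiaConfig code.

```python
import json


class ToyDecoderConfig:
    """Hypothetical stand-in for a decoder sub-config."""

    def __init__(self, vocab_size=1028, eos_token_id=1024):
        self.vocab_size = vocab_size
        self.eos_token_id = eos_token_id


class ToyParentConfig:
    """Hypothetical parent config that forwards eos_token_id to the sub-config only."""

    def __init__(self, decoder_config=None, eos_token_id=1024):
        self.decoder_config = ToyDecoderConfig(**(decoder_config or {}))
        # Assumed buggy pattern: the argument (or its default) always overwrites
        # the sub-config value and is never stored on the parent itself.
        self.decoder_config.eos_token_id = eos_token_id

    def to_dict(self):
        # Only the sub-config is serialized; the parent has no eos_token_id attribute.
        return {"decoder_config": vars(self.decoder_config).copy()}


def round_trip(config):
    """Serialize to JSON and rebuild, mimicking a save followed by a load."""
    return ToyParentConfig(**json.loads(json.dumps(config.to_dict())))


original = ToyParentConfig(decoder_config={"vocab_size": 100}, eos_token_id=97)
reloaded = round_trip(original)

print(original.decoder_config.eos_token_id)  # 97
print(reloaded.decoder_config.eos_token_id)  # 1024, the default has overwritten 97

# An ID beyond the vocabulary cannot index a table with vocab_size rows.
embedding_table = [0.0] * reloaded.decoder_config.vocab_size
try:
    embedding_table[reloaded.decoder_config.eos_token_id]
except IndexError as err:
    print(f"IndexError: {err}")
```

In this toy version the custom value survives in the serialized sub-config but is clobbered by the parent's default on reload, which is one plausible way an out-of-range eos_token_id could reach generation.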