Skip to content

fix(tests): dtype-aware reference cast in test_from_save_pretrained_dtype_inference#13882

Open
Anai-Guo wants to merge 1 commit into
huggingface:mainfrom
Anai-Guo:fix/save-pretrained-dtype-inference-aware-cast
Open

fix(tests): dtype-aware reference cast in test_from_save_pretrained_dtype_inference#13882
Anai-Guo wants to merge 1 commit into
huggingface:mainfrom
Anai-Guo:fix/save-pretrained-dtype-inference-aware-cast

Conversation

@Anai-Guo

@Anai-Guo Anai-Guo commented Jun 8, 2026

Copy link
Copy Markdown

What does this PR do?

Fixes #13869.

test_from_save_pretrained_dtype_inference builds its reference output with a blanket model.to(dtype), then compares it against from_pretrained(tmp_path, torch_dtype=dtype). These two casts are not equivalent, so the comparison can diverge spuriously:

  • _keep_in_fp32_modules: from_pretrained keeps these modules in fp32 (see load_model_dict_into_meta in model_loading_utils.py), but model.to(dtype) casts them to dtype.
  • Non-persistent buffers (e.g. RoPE inv_freq created with an explicit dtype): these are not saved to the checkpoint, so from_pretrained never casts them and they keep the dtype assigned in __init__. model.to(dtype) casts them unconditionally.

Either case makes the in-memory reference differ from the loaded model and the output assertion fails even though save/load is correct.

This PR replaces the blanket cast with one that mirrors the loader's dtype-aware semantics: parameters and persistent buffers are cast to dtype (except _keep_in_fp32_modules, which stay fp32), while non-persistent buffers are left at their __init__ dtype.

Before submitting

  • Did you read the contributor guideline?
  • Did you make sure to update the documentation with your changes? (test-only change)
  • Did you write any new necessary tests? (this fixes an existing test)

Who can review?

@dg845 @DN6 (diagnosis from #13869 / #13862 review)

🤖 Generated with Claude Code

The reference model in test_from_save_pretrained_dtype_inference was cast with a blanket model.to(dtype), which diverges from from_pretrained(torch_dtype=dtype): the latter keeps _keep_in_fp32_modules in fp32 and leaves non-persistent buffers (e.g. RoPE inv_freq, not stored in the checkpoint) at their __init__ dtype. The blanket cast produced spurious output mismatches for models with either. Mirror the loader dtype-aware casting on the reference instead.
@github-actions github-actions Bot added fixes-issue tests size/S PR with diff < 50 LOC and removed fixes-issue labels Jun 8, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

size/S PR with diff < 50 LOC tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

fix underlying issue with test_from_save_pretrained_dtype_inference is that the model.to(dtype) cast at

1 participant