
Fix underlying issue with test_from_save_pretrained_dtype_inference: the model.to(dtype) cast is not dtype-aware #13869

@sayakpaul

Description

@sayakpaul

I think an underlying issue with test_from_save_pretrained_dtype_inference is that the model.to(dtype) cast at

```python
model.to(dtype).save_pretrained(tmp_path)
```

is not dtype-aware (it casts every floating-point parameter and buffer to dtype), whereas from_pretrained(..., torch_dtype=dtype) is dtype-aware (it respects _keep_in_fp32_modules, etc.). As a result, the two code paths diverge in several scenarios:

  1. There are _keep_in_fp32_modules specified on the model (the more common case)
  2. A non-persistent buffer like inv_freq is created with an explicit dtype (which is the case here)

In either case, the outputs of model and model_loaded diverge.
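A minimal plain-PyTorch sketch of the dtype divergence, assuming a toy model. `ToyModel`, `dtype_aware_cast` and `dtype_report` are hypothetical names. `dtype_aware_cast` only approximates the dtype-aware behavior described above for from_pretrained; it is not the diffusers implementation, and its name-matching rule for `_keep_in_fp32_modules` is an assumption.

```python
import torch
import torch.nn as nn


class ToyModel(nn.Module):
    # Hypothetical toy model; modules listed here should stay in float32,
    # mirroring the `_keep_in_fp32_modules` attribute mentioned above.
    _keep_in_fp32_modules = ["norm"]

    def __init__(self, dim: int = 8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.norm = nn.LayerNorm(dim)
        # Non-persistent buffer created with an explicit dtype (like `inv_freq`).
        # It is not saved in the state dict, so its value comes from __init__.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)


def dtype_aware_cast(model: nn.Module, dtype: torch.dtype) -> nn.Module:
    """Hypothetical helper: cast floating-point tensors to `dtype`, but keep
    `_keep_in_fp32_modules` in float32 and leave non-persistent buffers alone."""
    keep = getattr(model, "_keep_in_fp32_modules", None) or []
    persistent = set(model.state_dict().keys())
    for module_name, module in model.named_modules():
        in_keep = any(k in module_name.split(".") for k in keep)
        target = torch.float32 if in_keep else dtype
        for param in module.parameters(recurse=False):
            if param.is_floating_point():
                param.data = param.data.to(target)
        for name, buf in list(module.named_buffers(recurse=False)):
            full_name = f"{module_name}.{name}" if module_name else name
            # Only persistent floating-point buffers are cast.
            if full_name in persistent and buf.is_floating_point():
                setattr(module, name, buf.to(target))
    return model


def dtype_report(model: nn.Module) -> dict:
    # Map every parameter and buffer name to its dtype.
    tensors = list(model.named_parameters()) + list(model.named_buffers())
    return {name: t.dtype for name, t in tensors}


naive_dtypes = dtype_report(ToyModel().to(torch.bfloat16))  # blanket cast, as in the test
aware_dtypes = dtype_report(dtype_aware_cast(ToyModel(), torch.bfloat16))

for name, naive_dtype in naive_dtypes.items():
    if naive_dtype != aware_dtypes[name]:
        print(f"{name}: {naive_dtype} (model.to) vs {aware_dtypes[name]} (dtype-aware)")
```

Running it should report mismatches for norm.weight, norm.bias and inv_freq, i.e. one tensor group per scenario listed above. The sketch compares dtypes rather than forward outputs because a mixed bfloat16/float32 toy model would need extra casting inside forward to run at all.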

Originally posted by @dg845 in #13862 (comment)

Cc: @DN6
