I think an underlying issue with `test_from_save_pretrained_dtype_inference` is that the `model.to(dtype)` cast at `diffusers/tests/models/testing_utils/common.py` (line 484 in 5d10b4d),

`model.to(dtype).save_pretrained(tmp_path)`

is not dtype-aware (it will cast everything to `dtype`), but `from_pretrained(..., torch_dtype=dtype)` is dtype-aware (it respects `_keep_in_fp32_modules`, etc.). This causes the behavior of the two to diverge in several scenarios:
- There are `_keep_in_fp32_modules` specified on the model (the more common case)
- A non-persistent buffer like `inv_freq` is created with an explicit dtype (which is the case here)

which leads to a divergence between the outputs of `model` and `model_loaded`.
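
The non-persistent buffer case can be reproduced with plain PyTorch. The following is a minimal sketch, not the actual diffusers loading path: `TinyModel` and `dtype_aware_load` are hypothetical names, and `dtype_aware_load` only imitates a dtype-aware load by casting the tensors that were saved and leaving everything else at its construction-time dtype.

```python
import torch
from torch import nn


class TinyModel(nn.Module):
    """Hypothetical stand-in for a model with a rotary-embedding buffer."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        # Non-persistent buffer created with an explicit dtype. It is never
        # written to the state dict, so it is rebuilt on every construction.
        inv_freq = 1.0 / (10000 ** (torch.arange(0, 4, 2, dtype=torch.float32) / 4))
        self.register_buffer("inv_freq", inv_freq, persistent=False)


def dtype_aware_load(state_dict, dtype):
    """Simplified imitation (assumption) of a dtype-aware load:
    only the parameters restored from the checkpoint are cast."""
    model = TinyModel()
    model.load_state_dict(state_dict)
    for param in model.parameters():
        param.data = param.data.to(dtype)
    return model


dtype = torch.float16
model = TinyModel().to(dtype)        # blanket cast: inv_freq becomes fp16
state_dict = model.state_dict()      # inv_freq is absent (non-persistent)
model_loaded = dtype_aware_load(state_dict, dtype)

print(model.inv_freq.dtype)          # torch.float16
print(model_loaded.inv_freq.dtype)   # torch.float32
```

Both models hold fp16 weights, but only `model` has a downcast `inv_freq`, so anything computed from that buffer can differ between the two.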
Originally posted by @dg845 in #13862 (comment)
Cc: @DN6