This is a split-off from one of the discussions at #13209:
It all started with trying to load torch models under either the desired dtype or the dtype of the pretrained model, and thus avoid the 2x memory usage, e.g. if the model only needs to be fp16. So we added `torch_dtype` to `from_pretrained` and `from_config`.
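Concretely, the loading paths being discussed look roughly like this (a sketch; `gpt2` is just an illustrative checkpoint):

```python
import torch
from transformers import AutoConfig, AutoModel

# Load the weights directly in fp16 instead of loading fp32 and then casting,
# which would briefly need ~2x the memory
model = AutoModel.from_pretrained("gpt2", torch_dtype=torch.float16)

# Or defer to the dtype recorded in the checkpoint's config / weights
model = AutoModel.from_pretrained("gpt2", torch_dtype="auto")

# The same argument was added to from_config as well (per the description above)
config = AutoConfig.from_pretrained("gpt2")
model = AutoModel.from_config(config, torch_dtype=torch.float16)
```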
Then we started storing `torch_dtype` in the config file, so that in the future the model could possibly be loaded automatically in the optimal "regime".
This resulted in a discrepancy where the same symbol sometimes means a `torch.dtype` and at other times a string like `"float32"`, since we can't store a `torch.dtype` in JSON.
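For example, assuming a transformers version that records `torch_dtype` in the saved config (a sketch; the save path is arbitrary):

```python
import json
import torch
from transformers import AutoModel

# at load time the argument is a torch.dtype object ...
model = AutoModel.from_pretrained("gpt2", torch_dtype=torch.float16)
model.save_pretrained("./tmp-gpt2-fp16")

# ... but the saved config.json can only hold a JSON value, so the very same
# setting ends up as the string "float16"
with open("./tmp-gpt2-fp16/config.json") as f:
    print(json.load(f).get("torch_dtype"))  # expected: "float16"
```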
Then, in fix `AutoModel.from_pretrained(..., torch_dtype=...)` #13209 (comment), we started discussing how dtype is really the same across pt/tf/flax, and perhaps we should just use `dtype` in the config and variables, have it consistently be a string (`"float32"`), and convert it to the right dtype object of the desired framework at the point of use, e.g. `getattr(torch, "float32")`.
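That point-of-use conversion could look roughly like this (a sketch; `tf.dtypes.as_dtype` is just one way to do the TF side):

```python
import torch
import tensorflow as tf
import jax.numpy as jnp

dtype_str = "float32"  # single framework-neutral value stored in the config

pt_dtype = getattr(torch, dtype_str)      # torch.float32
tf_dtype = tf.dtypes.as_dtype(dtype_str)  # tf.float32
flax_dtype = getattr(jnp, dtype_str)      # jnp.float32
```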
A possible solution is to deprecate `torch_dtype` and replace it with a `dtype` string, both in the config and in the function argument.
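One way the deprecation could look, purely as a hypothetical sketch (names like `_resolve_dtype` are made up, this is not the actual transformers code):

```python
import warnings
import torch

def _resolve_dtype(dtype=None, torch_dtype=None):
    """Hypothetical helper: accept the new `dtype` argument while still
    honoring the legacy `torch_dtype` with a deprecation warning."""
    if torch_dtype is not None:
        warnings.warn("`torch_dtype` is deprecated, use `dtype` instead", FutureWarning)
        if dtype is None:
            dtype = torch_dtype
    # the config/argument holds a string such as "float16"; convert it to a
    # framework object only at the point of use
    if isinstance(dtype, str) and dtype != "auto":
        dtype = getattr(torch, dtype)
    return dtype

print(_resolve_dtype(torch_dtype="float16"))  # torch.float16, with a FutureWarning
```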
Possible conflicts with the naming:
we already have the `dtype` attribute in modeling_utils, which returns a `torch.dtype` based on the first param's dtype: https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L205

The context is different, but still this is something to consider to avoid ambiguity.

I may have missed some other areas, so please share if something else needs to be added.

Additional notes: see `AutoModel.from_pretrained(..., torch_dtype=...)` #13209 (comment)

@LysandreJik, @sgugger, @patrickvonplaten
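For reference, a minimal stand-alone sketch of what that existing `dtype` attribute does (simplified, not the actual modeling_utils implementation):

```python
import torch
from torch import nn

class ToyModel(nn.Module):
    """Stand-in for PreTrainedModel, only to illustrate the naming conflict:
    here `dtype` already means the *parameter* dtype, inferred from the model."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2).half()

    @property
    def dtype(self) -> torch.dtype:
        # roughly what the existing attribute does: report the first param's dtype
        return next(self.parameters()).dtype

print(ToyModel().dtype)  # torch.float16
```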
#13098 - the idea of that PR is exactly to disentangle the parameter dtype from the matmul/computation dtype. In Flax, it's common practice that the `dtype` parameter defines the matmul/computation dtype, not the parameter dtype, see: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dense.html#flax.linen.Dense.dtype

So for Flax, I don't really think it would make sense to use a `config.dtype` to define the weights dtype, as it would be quite confusing given Flax's computation `dtype` parameter.
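For example, with `flax.linen.Dense` (a small sketch, assuming a recent jax/flax): the module-level `dtype` changes what the matmul runs in, while the stored parameters stay float32:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

layer = nn.Dense(features=4, dtype=jnp.bfloat16)  # computation dtype, not weights dtype
x = jnp.ones((1, 8), dtype=jnp.float32)
params = layer.init(jax.random.PRNGKey(0), x)

# parameters are kept in float32 ...
print(jax.tree_util.tree_map(lambda p: p.dtype, params))
# ... while the forward pass runs in bfloat16
print(layer.apply(params, x).dtype)
```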
@Rocketknight1 Sorry, but can you please elaborate on how to load the model in TensorFlow, or point me in the right direction? I am new to Hugging Face and I have been looking all over for instructions on how to do it. Thank you.
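A minimal sketch of the usual way to do this with the TF auto classes (the second checkpoint name is just a placeholder for a PyTorch-only model):

```python
from transformers import TFAutoModel

# checkpoints that ship TensorFlow weights load directly
model = TFAutoModel.from_pretrained("bert-base-uncased")

# checkpoints with only PyTorch weights can be converted on the fly
# ("some-org/pytorch-only-model" is a placeholder, not a real checkpoint)
model = TFAutoModel.from_pretrained("some-org/pytorch-only-model", from_pt=True)
```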