
[model loading] framework-agnostic dtype parameter #13246

Open
stas00 opened this issue Aug 24, 2021 · 3 comments
Labels: WIP

stas00 (Contributor) commented Aug 24, 2021

This is split off from one of the discussions in #13209:

  1. It all started with trying to load torch models under either the desired dtype or the dtype of the pretrained model, and thus avoid 2x memory usage, e.g. when the model only needs to be fp16. So we added torch_dtype to from_pretrained and from_config.
  2. Then we started storing torch_dtype in the config file, for possible future automatic loading of the model in the optimal "regime".
  3. This resulted in a discrepancy where the same symbol sometimes means torch.dtype and at other times a string like "float32", since we can't store torch.dtype in JSON.
  4. Then in fix AutoModel.from_pretrained(..., torch_dtype=...) #13209 (comment) we started discussing how dtype is really the same across pt/tf/flax, and that perhaps we should just use dtype in the config and variables, have it consistently be a string ("float32"), and convert it to the right dtype object of the desired framework at the point of use, e.g. getattr(torch, "float32").

A possible solution is to deprecate torch_dtype and replace it with a dtype string, both in the config and in the function argument.
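A minimal sketch of what the point-of-use conversion could look like (the resolve_dtype helper is hypothetical, not existing transformers API):

```python
import torch
import tensorflow as tf
import jax.numpy as jnp

def resolve_dtype(dtype_str, framework):
    """Convert a config-stored dtype string like "float32" into the
    corresponding dtype object of the requested framework.
    Hypothetical helper, for illustration only."""
    if framework == "pt":
        return getattr(torch, dtype_str)      # e.g. torch.float32
    if framework == "tf":
        return tf.dtypes.as_dtype(dtype_str)  # e.g. tf.float32
    if framework == "flax":
        return jnp.dtype(dtype_str)           # e.g. dtype('float32')
    raise ValueError(f"unknown framework: {framework}")

# Usage: resolve_dtype("float16", "pt") -> torch.float16
```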

Possible conflicts with the naming:

  1. We already have a dtype attribute in modeling_utils, which returns torch.dtype based on the dtype of the first parameter:

    https://github.com/huggingface/transformers/blob/master/src/transformers/modeling_utils.py#L205

    The context is different, but this is still something to consider to avoid ambiguity (see the sketch below).
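A simplified sketch of what that attribute does, assuming the usual first-parameter lookup (the actual implementation at the link above handles more edge cases, e.g. models without floating-point parameters):

```python
import torch
from torch import nn

class ModuleUtilsMixin:
    """Mixin for nn.Module subclasses (simplified for illustration)."""

    @property
    def dtype(self) -> torch.dtype:
        # Report the dtype of the first parameter as the model's dtype.
        return next(self.parameters()).dtype
```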

I may have missed some other areas. So please share if something else needs to be added.

Additional notes:

#13098 - the idea of the PR is exactly to disentangle the parameter dtype from the matmul/computation dtype. In Flax, it's common practice that the dtype parameter defines the matmul/computation dtype rather than the parameter dtype, see: https://flax.readthedocs.io/en/latest/_autosummary/flax.linen.Dense.html#flax.linen.Dense.dtype
So for Flax, I don't really think it would make sense to use a config.dtype to define the weights dtype, as it would be quite confusing given Flax's computation dtype parameter.
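A minimal sketch illustrating the distinction, assuming default Flax behavior as documented for flax.linen.Dense:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

# In flax.linen.Dense, `dtype` sets the computation (matmul) dtype,
# not the dtype the parameters are stored in.
layer = nn.Dense(features=4, dtype=jnp.float16)
params = layer.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))

# Parameters are still stored in float32 by default.
print(jax.tree_util.tree_map(lambda p: p.dtype, params))

# The forward pass computes (and returns) in float16.
y = layer.apply(params, jnp.ones((1, 8)))
print(y.dtype)  # float16
```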

@LysandreJik, @sgugger, @patrickvonplaten

LysandreJik (Member)

Would like to ping @Rocketknight1 regarding the TensorFlow management of types, and @patil-suraj for Flax.

Rocketknight1 (Member)

This should work in TensorFlow too - you can use tf.dtypes.as_dtype(dtype_string) to turn strings into TF dtype objects.
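For example, a quick illustration of that call:

```python
import tensorflow as tf

# Convert a config-stored dtype string into a TF dtype object.
dtype = tf.dtypes.as_dtype("float16")
print(dtype)                 # <dtype: 'float16'>
print(dtype == tf.float16)   # True
```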

Joy-Lunkad commented Sep 2, 2021

@Rocketknight1 Sorry, but can you please elaborate on how to load the model in TensorFlow, or point me in the right direction? I am new to Hugging Face and I have been looking all over for instructions on how to do it. Thank you.

@huggingface huggingface deleted a comment from github-actions bot Sep 27, 2021
@stas00 stas00 added the WIP label Sep 27, 2021