Handle shared layers in `save_torch_state_dict` + add `save_torch_model` #2373
base: main
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```diff
-The main helper of the `serialization` module takes a state dictionary as input (e.g. a mapping between layer names and related tensors), splits it into several shards while creating a proper index in the process and saves everything to disk. At the moment, only `torch` tensors are supported. Under the hood, it delegates the logic to split the state dictionary to [`split_torch_state_dict_into_shards`].
+The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see the [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as the logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only the `torch` framework is supported.
+
+If you want to save a state dictionary (e.g. a mapping between layer names and related tensors) instead of an `nn.Module`, you can use [`save_torch_state_dict`], which provides the same features. This is useful, for example, if you want to apply custom logic to the state dict before saving it.
```
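To illustrate the sharding behaviour described in the docstring above, here is a minimal, framework-free sketch of how a state dict could be split into size-bounded shards with a weight map. The helper name, filename pattern, and grouping strategy are illustrative assumptions; the real logic lives in `split_torch_state_dict_into_shards`.

```python
# Illustrative sketch only -- NOT the huggingface_hub implementation.
def split_state_dict_into_shards(sizes: dict, max_shard_size: int):
    """Group tensor names into shards whose cumulative byte size stays under the limit.

    `sizes` maps tensor names to their storage size in bytes.
    Returns (shards, tensor_to_filename), where each shard is a list of names.
    """
    shards, current, current_size = [], [], 0
    for name, size in sizes.items():
        # Start a new shard if adding this tensor would exceed the limit.
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = [], 0
        current.append(name)
        current_size += size
    if current:
        shards.append(current)

    # Build the weight map: tensor name -> shard filename (hypothetical naming scheme).
    tensor_to_filename = {}
    for i, shard in enumerate(shards):
        filename = f"model-{i + 1:05d}-of-{len(shards):05d}.safetensors"
        for name in shard:
            tensor_to_filename[name] = filename
    return shards, tensor_to_filename


sizes = {"embed.weight": 400, "layer1.weight": 300, "layer2.weight": 300, "head.weight": 100}
shards, weight_map = split_state_dict_into_shards(sizes, max_shard_size=500)
print(len(shards))  # 3 shards: [embed], [layer1], [layer2, head]
```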
I see the point of mentioning this, but for the Torch community it's fairly standard practice to ship the model classes and their state dictionaries (i.e., the parameters) separately, unlike TensorFlow/Keras, for example.
```diff
-def get_tensor_size(tensor: "tf.Tensor") -> int:
+def get_tf_storage_size(tensor: "tf.Tensor") -> int:
```
Do we have all the equivalent torch methods for TensorFlow? Or is that not necessary?
Not yet, no. Let's build for torch first and then expand to TF afterwards if needed. For now, for TF we have the logic to split a state dict into shards but nothing to save to disk.
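For context on what a storage-size helper computes: for a dense tensor it is just element count times dtype width. A minimal framework-free sketch (the helper name and dtype table are illustrative, not the PR's implementation):

```python
import math

# Bytes per element for a few common dtypes (illustrative subset).
DTYPE_ITEMSIZE = {"float32": 4, "float16": 2, "bfloat16": 2, "int64": 8, "int8": 1}

def storage_size(shape, dtype: str) -> int:
    """Return the storage size in bytes of a dense tensor: numel * itemsize."""
    numel = math.prod(shape)
    return numel * DTYPE_ITEMSIZE[dtype]

print(storage_size((1024, 1024), "float16"))  # 1024 * 1024 * 2 = 2097152 bytes
```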
```python
    "metadata": {**state_dict_split.metadata, **metadata},
    "weight_map": state_dict_split.tensor_to_filename,
}
```
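The index built here is a plain JSON file written next to the shards. A minimal sketch of what serializing and reloading such an index looks like (the filenames, keys, and sizes below are illustrative examples, not values produced by this PR):

```python
import json
import os
import tempfile

# Illustrative index content: sharding metadata plus the tensor -> shard map.
index = {
    "metadata": {"total_size": 2097152},
    "weight_map": {
        "embed.weight": "model-00001-of-00002.safetensors",
        "head.weight": "model-00002-of-00002.safetensors",
    },
}

with tempfile.TemporaryDirectory() as tmp:
    index_path = os.path.join(tmp, "model.safetensors.index.json")
    with open(index_path, "w") as f:
        json.dump(index, f, indent=2)
    # A loader would read the index back to find which shard holds each tensor.
    with open(index_path) as f:
        reloaded = json.load(f)
    assert reloaded == index  # round-trips unchanged
```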
Should there be any sanity check on the additional `metadata`, if not already done?
`metadata` is at the discretion of the frameworks that will use it (transformers/diffusers/accelerate). In practice, I don't think it'll be used much. In any case, we can't really do a sanity check since we are supposed to accept anything that is JSON-serializable.
Understanding the full scope of the PR is still a bit of a stretch for me, but I left some clarification questions.
Partially resolves #2065.

Follow-up PR after #2314. In #2314, we introduced `save_torch_state_dict`. This new PR:
- adds `save_torch_model` to directly save a torch `nn.Module`
- adds `get_tf_storage_size` / `get_torch_storage_size` and makes them public + documented

A last follow-up PR should add `load_torch_state_dict` / `load_torch_model` helpers as well to correctly reload those files, including the shared layers.

I'm pinging transformers/accelerate/diffusers core maintainers for visibility as well. Feel free to comment if something should be done differently.