
Handle shared layers in save_torch_state_dict + add save_torch_model #2373

Open · wants to merge 3 commits into main
Conversation

Wauplin
Contributor

@Wauplin Wauplin commented Jul 4, 2024

Partially resolve #2065.
Follow-up PR after #2314.

In #2314, we introduced save_torch_state_dict. This new PR:

  • adds logic to deduplicate shared layers when saving to safetensors. This is mostly taken from safetensors' torch helpers (see here). See the slack thread (private) for discussions around this. See also https://huggingface.co/docs/safetensors/torch_shared_tensors for more details.
  • adds save_torch_model to directly save a torch nn.Module
  • renames the internal helpers to get_tf_storage_size / get_torch_storage_size and makes them public and documented
  • updates tests and documentation accordingly.

A last follow-up PR should add load_torch_state_dict / load_torch_model helpers as well, to correctly reload those files, including the shared layers.
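The shared-layer deduplication mentioned above can be sketched in plain Python. This is a hypothetical illustration, not the PR's actual code: real tensors sharing storage would be detected via their underlying storage pointer, which is modeled here by plain object identity.

```python
from collections import defaultdict

def remove_duplicate_names(state_dict):
    """Group names pointing at the same underlying object and keep one
    canonical name per group (hypothetical helper, not the PR's exact code)."""
    groups = defaultdict(list)
    for name, tensor in state_dict.items():
        # id() stands in for a real shared-storage check on torch tensors
        groups[id(tensor)].append(name)
    to_keep, to_remove = {}, {}
    for names in groups.values():
        keep = sorted(names)[0]          # deterministic canonical choice
        to_keep[keep] = names
        for name in names:
            if name != keep:
                to_remove[name] = keep   # alias -> name actually saved
    return to_keep, to_remove

# Two names alias the same object, mimicking tied weights
shared = object()
sd = {"lm_head.weight": shared, "embed.weight": shared, "bias": object()}
kept, removed = remove_duplicate_names(sd)
```

Only the canonical name would be written to the safetensors file; the alias mapping can be stored in metadata so the duplicate layer is re-tied at load time.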

I'm pinging the transformers/accelerate/diffusers core maintainers for visibility as well. Feel free to comment if something should be done differently.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

The main helper of the `serialization` module takes a state dictionary as input (e.g. a mapping between layer names and related tensors), splits it into several shards while creating a proper index in the process, and saves everything to disk. At the moment, only `torch` tensors are supported. Under the hood, it delegates the logic to split the state dictionary to [`split_torch_state_dict_into_shards`].
The main helper of the `serialization` module takes a torch `nn.Module` as input and saves it to disk. It handles the logic to save shared tensors (see the [safetensors explanation](https://huggingface.co/docs/safetensors/torch_shared_tensors)) as well as the logic to split the state dictionary into shards, using [`split_torch_state_dict_into_shards`] under the hood. At the moment, only the `torch` framework is supported.

If you want to save a state dictionary (e.g. a mapping between layer names and related tensors) instead of a `nn.Module`, you can use [`save_torch_state_dict`] which provides the same features. This is useful for example if you want to apply custom logic to the state dict before saving it.
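As a rough illustration of the shard-splitting idea delegated to [`split_torch_state_dict_into_shards`], here is a minimal, framework-free sketch. The function name is made up and plain ints stand in for per-tensor byte counts; the library's real implementation handles tensors, naming of shard files, and the index.

```python
def split_into_shards(sizes, max_shard_size):
    """Greedily pack tensors into shards no larger than max_shard_size
    (illustrative sketch, not the library's actual algorithm)."""
    shards, current, current_size = [], {}, 0
    for name, size in sizes.items():
        # Start a new shard when adding this tensor would overflow the limit
        if current and current_size + size > max_shard_size:
            shards.append(current)
            current, current_size = {}, 0
        current[name] = size
        current_size += size
    if current:
        shards.append(current)
    return shards

shards = split_into_shards({"a": 6, "b": 5, "c": 4, "d": 2}, max_shard_size=10)
# shards -> [{"a": 6}, {"b": 5, "c": 4}, {"d": 2}]
```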
Member

@sayakpaul sayakpaul Jul 5, 2024


I see the point of mentioning this, but for the Torch community it's fairly standard practice to ship the model classes and their state dictionaries (i.e., the parameters) separately, unlike TensorFlow/Keras, for example.

- def get_tensor_size(tensor: "tf.Tensor") -> int:
+ def get_tf_storage_size(tensor: "tf.Tensor") -> int:
Member


Do we have all the equivalent torch methods for TensorFlow? Or is that not necessary?

Contributor Author

@Wauplin Wauplin Jul 5, 2024


Not yet, no. Let's build for torch first and expand to TF later if needed. For now, for TF we have the logic to split a state dict into shards, but nothing to save it to disk.

Comment on lines +249 to +251
"metadata": {**state_dict_split.metadata, **metadata},
"weight_map": state_dict_split.tensor_to_filename,
}
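For context, the index dict above is what gets serialized next to the shards (as `model.safetensors.index.json` in huggingface_hub's conventions). A minimal sketch of its shape, with made-up metadata values and shard filenames:

```python
import json

# Stand-in values: in the library these come from the split result
state_dict_split_metadata = {"total_size": 1000}
tensor_to_filename = {
    "embed.weight": "model-00001-of-00002.safetensors",
    "lm_head.weight": "model-00002-of-00002.safetensors",
}
extra_metadata = {"format": "pt"}  # user-provided, merged into "metadata"

index = {
    "metadata": {**state_dict_split_metadata, **extra_metadata},
    "weight_map": tensor_to_filename,
}
index_json = json.dumps(index, indent=2)
```

Loaders use `weight_map` to find which shard file holds each tensor.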
Member


Should there be any sanity check on the additional metadata if not already done?

Contributor Author


metadata is at the discretion of the frameworks that will use it (transformers/diffusers/accelerate). In practice, I don't think it'll be used much. In any case, we can't really do a sanity check, since we are supposed to accept anything that is jsonable.
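If such a check were ever added, one minimal form would be to simply attempt JSON serialization of the user-provided metadata. This is a hypothetical sketch, not part of the PR:

```python
import json

def is_jsonable(metadata):
    """Return True if metadata can be serialized to JSON
    (hypothetical check, not implemented in the PR)."""
    try:
        json.dumps(metadata)
        return True
    except (TypeError, ValueError):
        return False

ok = is_jsonable({"format": "pt", "total_size": 123})
bad = is_jsonable({"tensor": object()})  # arbitrary objects are not jsonable
```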

Member

@sayakpaul sayakpaul left a comment


Understanding the full scope of the PR is still a bit of a stretch for me, but I left some clarifying questions.

Successfully merging this pull request may close these issues.

Implement save_state_dict and load_state_dict in serialization module
3 participants