
[PEFT] Fix save_pretrained to make sure adapters weights are also saved on TPU #29388

Merged
5 commits merged into huggingface:main on Mar 14, 2024

Conversation

@shub-kris
Contributor

shub-kris commented Mar 1, 2024

Bug Fix for saving adapter weights when using PEFT

What does this PR do?

This PR fixes saving adapter weights when using PEFT on TPUs. Currently only the model weights are being saved and not the adapter weights.

I tested this change locally with this script, and it now saves the following files whenever checkpointing:

README.md
adapter_config.json
adapter_model.safetensors
optimizer.pt
rng_state.pth
scheduler.pt
special_tokens_map.json
tokenizer.json
tokenizer.model
tokenizer_config.json
trainer_state.json
training_args.bin
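As a quick sanity check (not part of the PR itself), a checkpoint containing adapter_config.json and adapter_model.safetensors can be loaded back as a PEFT adapter on top of the base model. A minimal sketch, where the model id and checkpoint path are placeholders:

```python
# Hypothetical reload check; the model id and checkpoint path are placeholders.
from transformers import AutoModelForCausalLM
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
# PeftModel.from_pretrained reads adapter_config.json and
# adapter_model.safetensors from the checkpoint directory.
model = PeftModel.from_pretrained(base_model, "output/checkpoint-500")
```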

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

This was discussed earlier here.

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker changed the title from "Fix for saving adapter weights when using PEFT" to "[PEFT] Fix save_pretrained to make sure adapters weights are also saved" on Mar 4, 2024
@ArthurZucker changed the title from "[PEFT] Fix save_pretrained to make sure adapters weights are also saved" to "[PEFT] Fix save_pretrained to make sure adapters weights are also saved on TPU" on Mar 4, 2024
Collaborator

@ArthurZucker left a comment

Thanks 🤗

@@ -3035,9 +3035,10 @@ def _save_tpu(self, output_dir: Optional[str] = None):

# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)
Collaborator

I think all classes that support save_pretrained and from_pretrained can fall under PushToHubMixin, so we could also use that here, as both PreTrainedModel and PeftModel should inherit from it.
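For illustration, a rough sketch of that suggestion (not the exact code that was merged), assuming PushToHubMixin is importable from transformers.utils:

```python
from transformers.utils import PushToHubMixin


def save_like_save_tpu(model, output_dir: str) -> None:
    """Hypothetical helper mirroring the isinstance check discussed above."""
    # PreTrainedModel and PeftModel both inherit from PushToHubMixin, so a
    # single check against the mixin covers plain and PEFT-wrapped models.
    if isinstance(model, PushToHubMixin):
        model.save_pretrained(output_dir)
    else:
        raise TypeError(f"{type(model).__name__} does not support save_pretrained")
```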

Contributor Author

My fix was inspired by the code here:

supported_classes = (PreTrainedModel,) if not is_peft_available() else (PreTrainedModel, PeftModel)

Contributor Author

Great idea @ArthurZucker, pushed it.

@moficodes

Hello there, what is the state of this PR? Is there a timeline for when it will be merged and a new release of transformers will be out?

@ArthurZucker
Collaborator

Just waiting for @shub-kris to come back (he is off), and the transformers release will be in around two weeks.

@moficodes

I ran some tests on a GKE cluster with TPU v4 with 4 nodes.

https://gist.github.com/moficodes/1492228c80a3c08747a973b519cc7cda

This run fails with

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/safetensors/torch.py", line 13, in storage_ptr
    return tensor.untyped_storage().data_ptr()
RuntimeError: Attempted to access the data pointer on an invalid python storage.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "//fsdp.py", line 112, in <module>
    model.save_pretrained(new_model_id)
  File "/usr/local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2448, in save_pretrained
    safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
  File "/usr/local/lib/python3.10/site-packages/safetensors/torch.py", line 281, in save_file
    serialize_file(_flatten(tensors), filename, metadata=metadata)
  File "/usr/local/lib/python3.10/site-packages/safetensors/torch.py", line 470, in _flatten
    shared_pointers = _find_shared_tensors(tensors)
  File "/usr/local/lib/python3.10/site-packages/safetensors/torch.py", line 72, in _find_shared_tensors
    if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
  File "/usr/local/lib/python3.10/site-packages/safetensors/torch.py", line 17, in storage_ptr
    return tensor.storage().data_ptr()
  File "/usr/local/lib/python3.10/site-packages/torch/storage.py", line 956, in data_ptr
    return self._data_ptr()
  File "/usr/local/lib/python3.10/site-packages/torch/storage.py", line 960, in _data_ptr
    return self._untyped_storage.data_ptr()
RuntimeError: Attempted to access the data pointer on an invalid python storage.

That looks like the original error, so I am not certain whether its cause was resolved.
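For context (not part of this PR): this safetensors error typically means the tensors being serialized are still XLA tensors whose storage cannot be read on the host. A generic workaround, which may or may not apply to the FSDP/SPMD setup in the gist, is to materialize the state dict on CPU before saving:

```python
# Generic sketch, not taken from the gist above: move tensors to CPU so
# safetensors can inspect real storage pointers when writing the file.
cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
model.save_pretrained(new_model_id, state_dict=cpu_state_dict, safe_serialization=True)
```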

@shub-kris
Contributor Author

shub-kris commented Mar 11, 2024


Hi @moficodes, thanks for flagging this error, but at an initial glance it doesn't look like the problem that this PR addresses. This PR aims to save the adapter weights, which were not being saved before.

So, if you use the Trainer with this change, it saves the adapter weights too:

README.md
adapter_config.json
adapter_model.safetensors
optimizer.pt
rng_state.pth
scheduler.pt
special_tokens_map.json
tokenizer.json
tokenizer.model
tokenizer_config.json
trainer_state.json
training_args.bin

Earlier, adapter_model.safetensors and adapter_config.json were not being saved. A simple script to demonstrate this is here, which you can run with the commands below (a rough sketch of such a script follows the commands):

export XLA_USE_BF16=1 PJRT_DEVICE=TPU XLA_USE_SPMD=1  HF_TOKEN=<your-HF-TOKEN>
python save-gemma.py
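For readers without access to the linked script, a rough sketch of what such a test could look like (model id, LoRA config, and paths here are illustrative, not taken from the actual save-gemma.py):

```python
# Illustrative sketch only: save a LoRA-wrapped model through the Trainer so the
# output directory should now also contain adapter_config.json and
# adapter_model.safetensors.
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model

base = AutoModelForCausalLM.from_pretrained("google/gemma-2b")
model = get_peft_model(base, LoraConfig(r=8, task_type="CAUSAL_LM"))

trainer = Trainer(model=model, args=TrainingArguments(output_dir="./out"))
# On TPU this goes through Trainer._save_tpu; with this PR the adapter files are
# saved alongside the usual checkpoint files.
trainer.save_model("./out/checkpoint-test")
```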

So the error you are encountering may be unrelated to what this PR tries to fix.

@moficodes

I see. Will open a separate issue for it then.

Thank you!

@moficodes

The error happens on the same line though: model.save_pretrained(new_model_id).

@shub-kris
Contributor Author

@LysandreJik can we merge this if it looks good to you? @ArthurZucker is on holiday, and I have made the changes he asked for and tested them locally.

@shub-kris
Contributor Author

@moficodes answered it here: #29608 (comment)

@amyeroberts
Collaborator

@shub-kris Based on reviews and code, we can merge. There's currently a failing test which needs to be resolved first. Could you try rebasing on main to make sure you have all the latest updates, and trigger a fresh CI run?

@shub-kris
Contributor Author

shub-kris commented Mar 13, 2024

@amyeroberts Thank you for looking into the PR. I have rebased, but some checks are still failing, I guess because of this: https://github.com/huggingface/transformers/runs/22627188153

@amyeroberts
Collaborator

@shub-kris Yep - a fix has just been merged into main. Apologies for the disruption. Could you try rebasing again?

@amyeroberts amyeroberts merged commit c9e3c0b into huggingface:main Mar 14, 2024
21 checks passed
@shub-kris
Contributor Author

Thanks a lot @amyeroberts
