Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Load 8-bit quantized models for eval after fine-tuning #3606

Merged
merged 8 commits into from
Sep 15, 2023

Conversation

jeffkinnison
Copy link
Contributor

Errors

After training or fine-tuning, the best model checkpoint is loaded for evaluation. When loading an 8-bit quantized model that was fine-tuned on GPU, the following errors occur:

  1. With no handling, the call to load_state_dict here raises RuntimeError: Loading a quantized checkpoint into non-quantized Linear8bitLt is not supported. Please call module.cuda() before module.load_state_dict()"
  2. If 1. is handled, a number of unexpected keys are returned from load_state_dict and an AssertionError is raised

Causes

These issues can be reproduced by running tests/integration_tests/test_llm.py::test_llm_finetuning_strategies with 8-bit quantization. Both issues are the result of custom handling in bitsandbytes. They are caused by

  1. Moving an 8-bit parameter object to GPU creates a number of metadata matrices behind the scenes. These are added to model state on the fly during the move to GPU, and thus do not exist in a version of the model that has not been put on GPU. A check during load_state_dict in bitsandbytes raises the RuntimeError.
  2. When saving a state dict, bitsandbytes adds a number of weight_format entries to the state dict behind the scenes. These are metadata entries that are used in load_state_dict to reconstruct the quantized parameters. Since these weight_format entries are never registered in model state, on load they are returned in the unexpected_keys list. On load for eval, we assert that no unexpected keys were returned.

Workaround

This update puts in a workaround that addresses both issues. For 8-bit quantized models only, at the call to load_state_dict we first move the model to GPU and back to solve 1., then we ensure that the only unexpected keys are weight_format keys to handle 2. This should unblock 8-bit quantization for the time being, though we should double-check model quality.

@github-actions
Copy link

github-actions bot commented Sep 13, 2023

Unit Test Results

  6 files  ±0    6 suites  ±0   38m 58s ⏱️ - 6m 2s
31 tests ±0  26 ✔️ ±0    5 💤 ±0  0 ±0 
82 runs  ±0  66 ✔️ ±0  16 💤 ±0  0 ±0 

Results for commit f037e0f. ± Comparison against base commit d15a0c5.

♻️ This comment has been updated with latest results.

Comment on lines 891 to 893
if torch.cuda.is_available():
self.model.model.cuda()
self.model.model.cpu()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ooof, maybe one callout here might be that the .cuda() call is unique and overriden for the Linear8Bit layers which internally does some stuff for 8BitParameters?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's not clear to me why we need to move to GPU then back to CPU like this. Comment would be great so I don't need to read the full PR description.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a comment and removed the move to CPU. It turns out the model was on GPU all along: self.model.device reports that it is on CPU, but self.model.model.device and deeper modules in the model all report that that they are on GPU.

only_weights_format_keys = ["weights_format" in k for k in unexpected_keys]
assert (
unexpected_keys == [] or only_weights_format_keys
), f"Unexpected keys found in state dict: {unexpected_keys}"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add something about the only_weights_format_keys to the error message.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added in b8b487d

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, am I missing something, I still don't see anything in the assert message about it?

# to a RuntimeError in `load_state_dict`. Explicitly call `model.cuda()` to make sure the
# matrices are part of model state. This workaround is necessary because the matrices are
# deleted during the model's forward pass.
if self.device == torch.device("cuda"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check self.device.type == "cuda" as the device might be cuda:0, etc.

@tgaddair tgaddair merged commit fe2f306 into master Sep 15, 2023
16 of 17 checks passed
@tgaddair tgaddair deleted the 8bit-quant-load-error branch September 15, 2023 04:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

4 participants