
Add support for CPU offloading for quantizing bigger models on smaller GPUs #22

Merged: 6 commits merged into mit-han-lab:dev/more_models on Jul 4, 2023

Conversation

abhinavkulkarni

Hi,

This PR has the following changes:

  1. Further refinements to the dev/more_models branch to do quantization layer by layer on the GPU
  2. Using the device_map="auto" and max_memory kwargs to run LM evaluation on smaller GPUs if needed (see the sketch below)

Thanks!
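As a rough illustration of item 2 (a minimal sketch, not the exact code in awq/entry.py; the model path and memory limits are placeholders), loading an HF model with device_map="auto" plus a max_memory budget lets accelerate keep as many layers as fit on the GPU and offload the rest to CPU:

import torch
from transformers import AutoModelForCausalLM

# Cap GPU 0 at 9 GiB and allow up to 99 GiB of CPU RAM for offloaded layers
# (values mirror the --max_memory flag used later in this thread).
max_memory = {0: "9GiB", "cpu": "99GiB"}

model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-instruct",   # placeholder model path
    torch_dtype=torch.float16,
    device_map="auto",            # let accelerate split layers across GPU and CPU
    max_memory=max_memory,
    trust_remote_code=True,       # MPT checkpoints ship custom modeling code
)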

Abhinav Kulkarni added 3 commits July 1, 2023 12:23
awq/entry.py Outdated
# Init model on GPUs:
kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
# Init model on CPU:
kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
abhinavkulkarni (Author):

low_cpu_mem_usage is needed here because loading HF models on CPU with the fp16 dtype is otherwise slow; more details here: https://huggingface.co/mosaicml/mpt-7b-instruct/discussions/6#6470dec93df93fddece5fcde
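A minimal sketch of the CPU-init path, assuming the standard transformers API (not the exact awq/entry.py code; the model path is a placeholder):

import torch
from transformers import AutoModelForCausalLM

# low_cpu_mem_usage=True builds the model with empty (meta) weights and then
# loads the checkpoint directly, skipping the costly random initialization
# that makes plain fp16 loading on CPU slow.
model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b-instruct",   # placeholder model path
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
)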

awq/entry.py Outdated
# Move the model to GPU (as much as possible) for LM evaluation
kwargs = {
    "torch_dtype": torch.float16,
    "device_map": "auto",
abhinavkulkarni (Author):

Let me know if you'd like me to change device_map to "balanced" instead of "auto".

Sakits (Collaborator) commented Jul 3, 2023

Hi @abhinavkulkarni,

Thank you for your work on this PR!
However, when I was reviewing and testing it, I encountered the following two issues:

  1. I received the following error when running the original model directly: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
    This seems to be because the code that moves the model to the GPU sits inside the quantization if branch; it should probably live outside that branch (see the sketch below).

  2. When running the evaluation with fake quantization, the results I obtained were identical to those of the original model. This raises the question of whether the state_dict=model.state_dict() on line 167 in entry.py is functioning as intended. Additionally, when applying AWQ to models that use GeLU, we add an extra scaling node, which may not be well supported in this context.

Have you encountered these issues as well? If I've misunderstood or missed anything, I'd appreciate your correction!
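To make point 1 concrete, here is a hypothetical, simplified sketch of the control-flow issue (function names and the toy model are placeholders, not the real helpers in awq/entry.py):

import torch
import torch.nn as nn

def fake_quantize(model: nn.Module) -> nn.Module:
    return model  # stand-in for the AWQ fake-quantization step

def run_eval_buggy(quantize: bool) -> None:
    model = nn.Linear(8, 8, dtype=torch.float16)
    if quantize:
        model = fake_quantize(model)
        model = model.cuda()  # BUG: only the quantized path reaches the GPU
    x = torch.randn(1, 8, dtype=torch.float16, device="cuda")
    model(x)  # original model stays on CPU -> "cuda:0 and cpu" RuntimeError

def run_eval_fixed(quantize: bool) -> None:
    model = nn.Linear(8, 8, dtype=torch.float16)
    if quantize:
        model = fake_quantize(model)
    model = model.cuda()  # move to GPU regardless of whether we quantized
    x = torch.randn(1, 8, dtype=torch.float16, device="cuda")
    model(x)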

abhinavkulkarni (Author) commented Jul 4, 2023

Hey @Sakits,

Thanks for your reply. I have made the necessary changes. Here's the output I got on RTX 3060 (12GB VRAM):

# LM Eval: Original model
$ python -m awq.entry --model_path mosaicml/mpt-7b-instruct \
        --max_memory 0:9GiB cpu:99GiB \
        --tasks wikitext

Time: 05:51
Results:
|  Task  |Version|    Metric     | Value |   |Stderr|
|--------|------:|---------------|------:|---|------|
|wikitext|      1|word_perplexity|10.8864|   |      |
|        |       |byte_perplexity| 1.5628|   |      |
|        |       |bits_per_byte  | 0.6441|   |      |

# LM Eval: Fake quantization
$ python -m awq.entry --model_path mosaicml/mpt-7b-instruct \
        --max_memory 0:9GiB cpu:99GiB \
        --tasks wikitext \
        --w_bit 4 --q_group_size 128 \
        --load_awq awq_cache/mpt-7b-instruct-w4-g128.pt \
        --q_backend fake

Time: 05:53
Results:
|  Task  |Version|    Metric     | Value |   |Stderr|
|--------|------:|---------------|------:|---|------|
|wikitext|      1|word_perplexity|11.2684|   |      |
|        |       |byte_perplexity| 1.5729|   |      |
|        |       |bits_per_byte  | 0.6534|   |      |

# LM Eval: Real quantization
$ python -m awq.entry --model_path mosaicml/mpt-7b-instruct \
    --tasks wikitext \
    --w_bit 4 --q_group_size 128 \
    --load_quant quant_cache/mpt-7b-instruct-w4-g128-awq.pt

Time: 06:52
Results:
|  Task  |Version|    Metric     | Value |   |Stderr|
|--------|------:|---------------|------:|---|------|
|wikitext|      1|word_perplexity|11.2696|   |      |
|        |       |byte_perplexity| 1.5729|   |      |
|        |       |bits_per_byte  | 0.6535|   |      |

@abhinavkulkarni abhinavkulkarni force-pushed the dev/more_models branch 2 times, most recently from 308d689 to d93bfb8 Compare July 4, 2023 05:39
@Sakits Sakits merged commit ab536fb into mit-han-lab:dev/more_models Jul 4, 2023
Sakits (Collaborator) commented Jul 4, 2023

Hi @abhinavkulkarni,

I've reviewed your changes and everything looks great!
Thank you very much for your valuable contributions to the AWQ project!

@abhinavkulkarni abhinavkulkarni deleted the dev/more_models branch July 5, 2023 06:49
@abhinavkulkarni abhinavkulkarni restored the dev/more_models branch July 9, 2023 14:43