HQQ OOMs on large models #29

Closed
rationalism opened this issue Mar 27, 2024 · 12 comments

Labels
enhancement New feature or request

Comments

@rationalism

Hey, I have a machine with two 4090 GPUs (24 GB VRAM each). When I try to run HQQ quantization of Llama-2-70B:

import torch
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer

#Model and settings
model_id      = 'meta-llama/Llama-2-70b-chat-hf'
compute_dtype = torch.float16
device        = 'cuda:0'

#Load model on the CPU
######################
model     = HQQModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype)
tokenizer = AutoTokenizer.from_pretrained(model_id)

#Quantize the model
######################
from hqq.core.quantize import *
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)
model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=device)

the first half of the layers seem to work fine, but then it OOMs, presumably because it tries to put the entire quantized model on a single GPU. For Llama-2-70B, I could try renting an A100 machine and that should work, but for even larger models (e.g. Grok-1) it would be impossible to fit the entire thing on a single GPU. Is splitting quantization across multiple GPUs supported, or planned to be supported in the future? Thanks :)

@mobicham
Collaborator

mobicham commented Mar 27, 2024

Hi @rationalism, yeah, unfortunately automatically loading onto multiple GPUs is not supported. Maybe you can try:

Otherwise, I can take a stab at it and see how to do it on 2 GPUs, or more generally how to do it automatically.

@mobicham mobicham added the enhancement New feature or request label Mar 27, 2024
@Minami-su

Same problem.

@rationalism
Author

@mobicham Thanks. With larger models like DBRX coming out this year, I think being able to split quantization across multiple GPUs will be an important feature for handling them.

https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm

@Sneakr

Sneakr commented Apr 5, 2024

@mobicham I have an RTX 4090 and 128 GB RAM. Is it possible to load the original Mixtral Instruct and quantize it using HQQ? Currently my script gets killed while loading, as in the example file. I suppose I need to use the already-quantized Mixtral you linked (the 2-bit/4-bit one), right?

@Sneakr

Sneakr commented Apr 5, 2024

I managed to solve it by increasing the WSL memory allocation and the swap file size, nice! :)
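
For anyone hitting the same limit: WSL2's memory and swap caps are set in a .wslconfig file in the Windows user profile directory (then restart WSL with wsl --shutdown). The sizes below are only an illustrative sketch, not the exact values used here:

# %UserProfile%\.wslconfig
[wsl2]
memory=110GB   # RAM made available to the WSL VM
swap=128GB     # size of the swap file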

@mobicham
Collaborator

mobicham commented Apr 6, 2024

Yeah increasing the swap should do it, but it's gonna be slow.
Otherwise, you can use this branch of transformers that supports on-the-fly loading and HQQ quantization, so you don't need a lot of RAM: huggingface/transformers#29637
Soon it will be integrated into transformers and you won't face this memory issue; I just need to fix a couple of things for the pull request.

@catid

catid commented Apr 20, 2024

Would love to be able to actually use this model lol: https://huggingface.co/catid/cat-llama-3-70b-hqq

Need support for device_map="auto"

model_id = 'catid/cat-llama-3-70b-hqq'

from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)
model     = HQQModelForCausalLM.from_quantized(model_id)
(hqq) ➜  openai-hqq-server git:(main) ✗ python test.py
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Failed to load the weights
Traceback (most recent call last):
  File "/home/catid/mambaforge/envs/hqq/lib/python3.10/site-packages/hqq/models/base.py", line 328, in from_quantized
    weights = cls.load_weights(save_dir)
  File "/home/catid/mambaforge/envs/hqq/lib/python3.10/site-packages/hqq/models/base.py", line 195, in load_weights
    return torch.load(cls.get_weight_file(save_dir), map_location=map_location)
  File "/home/catid/mambaforge/envs/hqq/lib/python3.10/site-packages/torch/serialization.py", line 1026, in load
    return _load(opened_zipfile,
  File "/home/catid/mambaforge/envs/hqq/lib/python3.10/site-packages/torch/serialization.py", line 1438, in _load
    result = unpickler.load()
  File "/home/catid/mambaforge/envs/hqq/lib/python3.10/site-packages/torch/serialization.py", line 1408, in persistent_load
    typed_storage = load_tensor(dtype, nbytes, key, _maybe_decode_ascii(location))
  File "/home/catid/mambaforge/envs/hqq/lib/python3.10/site-packages/torch/serialization.py", line 1382, in load_tensor
    wrap_storage=restore_location(storage, location),
  File "/home/catid/mambaforge/envs/hqq/lib/python3.10/site-packages/torch/serialization.py", line 391, in default_restore_location
    result = fn(storage, location)
  File "/home/catid/mambaforge/envs/hqq/lib/python3.10/site-packages/torch/serialization.py", line 271, in _cuda_deserialize
    return obj.cuda(device)
  File "/home/catid/mambaforge/envs/hqq/lib/python3.10/site-packages/torch/_utils.py", line 115, in _cuda
    untyped_storage = torch.UntypedStorage(
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 112.00 MiB. GPU 0 has a total capacity of 23.64 GiB of which 64.81 MiB is free. Including non-PyTorch memory, this process has 23.57 GiB memory in use. Of the allocated memory 23.19 GiB is allocated by PyTorch, and 9.97 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
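
As the error message itself suggests, the expandable-segments allocator setting has to be in place before PyTorch initializes CUDA; a minimal sketch (it only mitigates fragmentation and won't make a model fit that genuinely needs more than 24 GB):

import os
# Must be set before anything is allocated on the GPU
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch  # imported after the env var so the allocator picks it up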

@mobicham
Collaborator

Yeah I am aware of this, there should be a simple fix but I've been very busy with other things. I hope I will have the time to take a look at it in the next few days. Sorry for the delay!

@mobicham
Collaborator

You can now shard quantized models across multiple GPUs. Just pass the devices as a list like this:

model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype, device=['cuda:0', 'cuda:1'])

You still need to have the main model on the CPU before quantizing. Will see how to dynamically dispatch directly to the GPUs.
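
Putting it together with the snippet from the top of the thread, a minimal sketch of the two-GPU flow (model ID and quantization settings reused from above, so this is illustrative rather than prescriptive):

import torch
from hqq.engine.hf import HQQModelForCausalLM, AutoTokenizer
from hqq.core.quantize import BaseQuantizeConfig

model_id      = 'meta-llama/Llama-2-70b-chat-hf'
compute_dtype = torch.float16

#Load the full-precision model on the CPU first
model     = HQQModelForCausalLM.from_pretrained(model_id, torch_dtype=compute_dtype)
tokenizer = AutoTokenizer.from_pretrained(model_id)

#Quantize and shard the layers across both GPUs
quant_config = BaseQuantizeConfig(nbits=4, group_size=64)
model.quantize_model(quant_config=quant_config, compute_dtype=compute_dtype,
                     device=['cuda:0', 'cuda:1'])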

@catid

catid commented Apr 23, 2024

That worked thanks, just in time! https://huggingface.co/catid/cat-llama-3-70b-san66-hqq

@mobicham
Collaborator

@catid making it work with "from_quantized" would require some additional work. But if you quantize directly it should work fine, as long as it's an official HF model that follows the same layer naming logic.

@mobicham
Collaborator

mobicham commented May 3, 2024

Closing this since HQQ is now integrated with transformers: huggingface/transformers#29637
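
With the integration merged, quantization can happen on the fly at load time through transformers' HqqConfig, so the full fp16 checkpoint never has to sit in memory at once. A minimal sketch (assumes a transformers release with HQQ support and the hqq package installed; whether device_map='auto' shards across multiple GPUs may depend on the version):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

model_id     = 'meta-llama/Llama-2-70b-chat-hf'
quant_config = HqqConfig(nbits=4, group_size=64)

#Weights are quantized layer by layer while loading
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map='auto',   # assumption: use 'cuda' to target a single GPU
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)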

@mobicham mobicham closed this as completed May 3, 2024