Hqq serialization #32379
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
@SunMarc thank you very much!

```python
import torch, gc
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

device = 'cuda:0'
dtype = torch.float16
model_id = 'meta-llama/Meta-Llama-3-8B'
cache_dir = '.'

quant_config = HqqConfig(nbits=4, group_size=64, quant_zero=False, quant_scale=False, axis=1)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=dtype,
    cache_dir=cache_dir,
    device_map="cuda:0",
    quantization_config=quant_config,
)

# Test
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device='cuda:0')
with torch.no_grad():
    out_ref = model.forward(input_tensor)

# Save
model.save_pretrained("quant_model")
del model
torch.cuda.empty_cache(); gc.collect()

# Load
model_loaded = AutoModelForCausalLM.from_pretrained(
    'quant_model',
    torch_dtype=dtype,
    cache_dir=cache_dir,
    device_map=device,
)
with torch.no_grad():
    out = model_loaded.forward(input_tensor)

assert (out.logits - out_ref.logits).abs().mean() == 0
```

I will try with more models, especially larger models, to see if it's working properly. Other than that, what is missing for an official merge?
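As a quick end-to-end check beyond the logits comparison, one could also run greedy generation with the reloaded quantized model. This is only a minimal sketch reusing `model_loaded`, `model_id`, `cache_dir`, and `device` from the snippet above; the prompt and generation settings are illustrative, not part of the PR:

```python
from transformers import AutoTokenizer

# Illustrative sanity check of the reloaded quantized model (not part of the PR).
tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
inputs = tokenizer("The capital of France is", return_tensors="pt").to(device)
with torch.no_grad():
    generated = model_loaded.generate(**inputs, max_new_tokens=16, do_sample=False)
print(tokenizer.decode(generated[0], skip_special_tokens=True))
```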
I fixed an overflow problem while encoding to safetensors (mobiusml/hqq@7cd36a7); now it works fine with 70B as well:

```python
import torch, gc
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig

device = 'cuda:0'
dtype = torch.float16
model_id = 'meta-llama/Meta-Llama-3-70B'
cache_dir = '.'

quant_config = HqqConfig(nbits=2, group_size=1024, quant_zero=False, quant_scale=False, axis=1)

# fits on a single 24GB GPU, for testing only
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=dtype,
    cache_dir=cache_dir,
    device_map="cuda:0",
    quantization_config=quant_config,
)

# Test
input_tensor = torch.zeros((1, 8), dtype=torch.int32, device=device)
with torch.no_grad():
    out_ref = model.forward(input_tensor)
# out_ref.logits:
# tensor([[[ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766],
#          [ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766],
#          [ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766],
#          ...,
#          [ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766],
#          [ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766],
#          [ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766]]],
#        device='cuda:0')

# Save
model.save_pretrained("quant_model")
del model
torch.cuda.empty_cache(); gc.collect()

# Load
model_loaded = AutoModelForCausalLM.from_pretrained(
    'quant_model',
    torch_dtype=dtype,
    cache_dir=cache_dir,
    device_map=device,
)
with torch.no_grad():
    out = model_loaded.forward(input_tensor)

assert (out.logits - out_ref.logits).abs().mean() == 0
# out.logits matches out_ref.logits exactly:
# tensor([[[ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766],
#          [ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766],
#          [ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766],
#          ...,
#          [ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766],
#          [ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766],
#          [ 2.9355, -1.3359,  0.5991,  ..., -0.4763, -0.4766, -0.4766]]],
#        device='cuda:0')
```
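Beyond comparing logits, the round trip can also be checked at the tensor level by reading the serialized safetensors file(s) back and comparing them against the reloaded model's state dict. This is a hedged sketch, not part of the PR, under the assumption that the keys written by `save_pretrained` match the reloaded state-dict keys and that the checkpoint fits in host memory:

```python
import glob
from safetensors.torch import load_file

# Assumption: every tensor written to disk should match the corresponding
# entry of the reloaded model's state dict bit for bit.
loaded_sd = {k: v.cpu() for k, v in model_loaded.state_dict().items()}
for shard in glob.glob("quant_model/*.safetensors"):
    for name, tensor in load_file(shard).items():
        assert torch.equal(tensor, loaded_sd[name]), f"mismatch in {name}"
```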
Superseded by #33141.
What does this PR do?
Fixed version of #32056.
The `dispatch_model` issue is solved. However, we still have the issue with sharded checkpoints.
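For context on the remaining issue, sharding can be forced at save time with `max_shard_size`. The following is a minimal sketch of the scenario described above as still problematic, assuming the quantized `model` from the earlier snippets (before it is deleted); the shard size is arbitrary:

```python
# Force save_pretrained to split the quantized checkpoint into multiple shards,
# then attempt to reload it; the sharded case is the one that is still an open issue.
model.save_pretrained("quant_model_sharded", max_shard_size="2GB")
model_loaded_sharded = AutoModelForCausalLM.from_pretrained(
    "quant_model_sharded",
    torch_dtype=dtype,
    device_map=device,
)
```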