
Add support for loading GPTQ models on CPU #26719

Merged
4 commits merged into huggingface:main on Oct 31, 2023

Conversation

vivekkhandelwal1
Contributor

Right now, GPTQ-quantized models can only be loaded on a CUDA device. The flag load_gptq_on_cpu adds support for loading GPTQ models on the CPU. The larger variants of these models are hard to load/run/trace on the GPU, and that's the rationale behind adding this flag.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
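
A minimal sketch of how the new flag is meant to be used (the checkpoint is just an example taken from the discussion below, and the explicit float32 cast discussed later in this thread is included):

import torch
from transformers import AutoModelForCausalLM, GPTQConfig

# Exllama kernels are CUDA-only, so they are disabled for the CPU path.
quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(
    "ybelkada/opt-125m-gptq-4bit",         # example GPTQ checkpoint
    load_gptq_on_cpu=True,                 # flag added by this PR
    quantization_config=quantization_config,
)
model = model.to(torch.float32)            # CPU matmul kernels need fp32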

@vivekkhandelwal1
Contributor Author

@SunMarc @younesbelkada Please take a look at this PR.

@younesbelkada left a comment
Contributor

Thanks for the great contribution! I did not manage to successfully run a GPTQ model using your branch:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "ybelkada/opt-125m-gptq-4bit"

quantization_config = GPTQConfig(
    bits=4,
    disable_exllama=True
)

model = AutoModelForCausalLM.from_pretrained(model_id, load_gptq_on_cpu=True, quantization_config=quantization_config, torch_dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_id)

text = "Hello my name is"
input_ids = tokenizer(text, return_tensors="pt").input_ids

out = model.generate(input_ids)
print(tokenizer.decode(out[0], skip_special_tokens=True))

Can you share a small snippet you used to test out your implementation?

@vivekkhandelwal1
Contributor Author

Can you share a small snippet you used to test out your implementation?

Hi @younesbelkada, here's a code snippet for the Falcon-180B-Chat-GPTQ model. This is the model I'm currently working on, and it runs for me with this:

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch

model_name_or_path = "TheBloke/Falcon-180B-Chat-GPTQ"

model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             revision="main",
                                             load_gptq_on_cpu=True)
model = model.to(torch.float32)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

prompt = "Tell me about AI"
prompt_template=f'''User: {prompt}
Assistant: '''

print("\n\n*** Generate:")

input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids
output = model.generate(inputs=input_ids, do_sample=True, temperature=0.7, max_new_tokens=512)
print(tokenizer.decode(output[0]))

print("*** Pipeline:")
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    temperature=0.7,
    do_sample=True,
    top_p=0.95,
    repetition_penalty=1.15
)

print(pipe(prompt_template)[0]['generated_text'])

@vivekkhandelwal1
Contributor Author

Thanks for the great contribution! I did not manage to successfully run a GPTQ model using your branch:
....
Can you share a small snippet you used to test out your implementation?

You need to add model = model.to(torch.float32) before doing the inference. Additionally, you would require these changes: AutoGPTQ/AutoGPTQ#367
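
For clarity, the adjusted version of the snippet quoted above would look roughly like this (a sketch only; same model id and config, with the explicit cast added):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "ybelkada/opt-125m-gptq-4bit"

quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(model_id, load_gptq_on_cpu=True, quantization_config=quantization_config, torch_dtype=torch.float32)
model = model.to(torch.float32)  # explicit cast; torch_dtype alone left some weights in fp16
tokenizer = AutoTokenizer.from_pretrained(model_id)

input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids
out = model.generate(input_ids)
print(tokenizer.decode(out[0], skip_special_tokens=True))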

@vivekkhandelwal1
Contributor Author

@younesbelkada, here's a smaller falcon GPTQ variant loading and executing fine, with the changes suggested above:

from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, GPTQConfig
import torch

model_name_or_path = "TheBloke/falcon-7b-instruct-GPTQ"
quantization_config = GPTQConfig(
    bits=4,
    disable_exllama=True
)

model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                             device_map="auto",
                                             trust_remote_code=True,
                                             revision="main",
                                             quantization_config=quantization_config,
                                             load_gptq_on_cpu=True)
model = model.to(torch.float32)

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)

prompt = "Tell me about AI"
prompt_template=f'''User: {prompt}
Assistant: '''

print("\n\n*** Generate:")

input_ids = tokenizer(prompt_template, return_tensors='pt').input_ids
output = model.generate(inputs=input_ids, do_sample=True, temperature=0.7, max_new_tokens=512)
print(tokenizer.decode(output[0]))

print("*** Pipeline:")
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=512,
    temperature=0.7,
    do_sample=True,
    top_p=0.95,
    repetition_penalty=1.15
)

print(pipe(prompt_template)[0]['generated_text'])

@younesbelkada
Contributor

OK thanks, I'll run some tests and report back here

@vivekkhandelwal1
Contributor Author

OK thanks, I'll run some tests and report back here

Hi @younesbelkada, did you get a chance to look at this?

@younesbelkada
Contributor

Hi @vivekkhandelwal1,
I did not manage to run inference with your PR on CPU. I used another model, since I don't have access to an instance that can support Falcon-180B:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

checkpoint = "marcsun13/opt-350m-gptq-4bit"
device = "cpu" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, quantization_config=quantization_config, torch_dtype=torch.bfloat16).to(device)

inputs = tokenizer.encode("Hello how are you?", return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=4, do_sample=False)
print(tokenizer.decode(outputs[0]))

I get:

    out = torch.matmul(x.half(), weight)
RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'

Can you let me know if you manage to run the snippet above?

@vivekkhandelwal1
Contributor Author

I did not manage to run inference with your PR on CPU. I used another model, since I don't have access to an instance that can support Falcon-180B:
....
Can you let me know if you manage to run the snippet above?

Hi @younesbelkada, I think you missed my comments above. To run the model on the CPU, you also need to cast it to float32: after loading the model, you have to do model = model.to(torch.float32). After doing this, you should be able to run the model.

@younesbelkada
Contributor

Thanks @vivekkhandelwal1, indeed I forgot to cast it to fp32. However, running this script:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

checkpoint = "marcsun13/opt-350m-gptq-4bit"
device = "cpu" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, quantization_config=quantization_config).to(device)

model = model.to(torch.float32)

inputs = tokenizer.encode("Hello how are you?", return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=4, do_sample=False)
print(tokenizer.decode(outputs[0]))

This led to the same error. What version of optimum and auto-gptq are you using?

@vivekkhandelwal1
Contributor Author

Thanks @vivekkhandelwal1, indeed I forgot to cast it to fp32. However, running this script:
....
This led to the same error. What version of optimum and auto-gptq are you using?

Hi @younesbelkada, can you please add the flag load_gptq_on_cpu=True when loading the model? The code below runs fine for me:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

checkpoint = "marcsun13/opt-350m-gptq-4bit"
device = "cpu" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, quantization_config=quantization_config, load_gptq_on_cpu=True).to(device)

model = model.to(torch.float32)

inputs = tokenizer.encode("Hello how are you?", return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=4, do_sample=False)
print(tokenizer.decode(outputs[0]))

The result:

Downloading pytorch_model.bin: 100%|…| 125M/125M [00:03<00:00, 37.4MB/s]
CUDA extension not installed.
Downloading (…)neration_config.json: 100%|…| 137/137 [00:00<00:00, 1.10MB/s]
</s>Hello how are you?????

@vivekkhandelwal1
Contributor Author

@younesbelkada, now I get it. I think making the changes from AutoGPTQ/AutoGPTQ#367 locally would fix the issue for you.

@younesbelkada left a comment
Contributor

Thanks a lot! Makes sense now. I left two comments, what do you think?
Can you also confirm that loading the model directly in float32 works for you, instead of first loading it in fp16 and then casting it back?

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

checkpoint = "marcsun13/opt-350m-gptq-4bit"
device = "cpu" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, quantization_config=quantization_config, torch_dtype=torch.float32).to(device)

inputs = tokenizer.encode("Hello how are you?", return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=4, do_sample=False)
print(tokenizer.decode(outputs[0]))

We strongly discourage users from dtype-casting quantized models, and this is likely to no longer be supported for GPTQ models: #26761

It would also be great if you could add a test to make sure generation works on CPU + GPTQ. The test should go in this file: https://github.com/huggingface/transformers/blob/6df9179c1c5b19674ce0a4b82d311ef833c02447/tests/quantization/gptq/test_gptq.py (a possible shape is sketched below). Let me know if you need any help!

Also, let's wait for AutoGPTQ/AutoGPTQ#367 to be merged before merging this PR.
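
A rough sketch of what such a CPU + GPTQ generation test could look like (the test name and expectations are illustrative, not the test that was eventually added; the checkpoint is the small GPTQ model used elsewhere in this thread):

import unittest

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig


class GPTQCPUGenerationTest(unittest.TestCase):
    # Hypothetical test; assumes an auto-gptq build with CPU support installed.
    model_id = "marcsun13/opt-350m-gptq-4bit"

    def test_generate_on_cpu(self):
        quantization_config = GPTQConfig(bits=4, disable_exllama=True)
        model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            quantization_config=quantization_config,
            torch_dtype=torch.float32,
        ).to("cpu")
        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        inputs = tokenizer("Hello how are you?", return_tensors="pt")
        outputs = model.generate(**inputs, max_new_tokens=4, do_sample=False)
        # Only check that generation runs on CPU and produces new tokens.
        self.assertGreater(outputs.shape[-1], inputs["input_ids"].shape[-1])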

@@ -2472,6 +2472,7 @@ def from_pretrained(
        adapter_kwargs = kwargs.pop("adapter_kwargs", {})
        adapter_name = kwargs.pop("adapter_name", "default")
        use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
+       load_gptq_on_cpu = kwargs.pop("load_gptq_on_cpu", False)
Contributor

Suggested change (remove this line):
-       load_gptq_on_cpu = kwargs.pop("load_gptq_on_cpu", False)

@@ -2700,7 +2701,7 @@ def from_pretrained(
            quantization_method_from_args == QuantizationMethod.GPTQ
            or quantization_method_from_config == QuantizationMethod.GPTQ
        ):
-           if not torch.cuda.is_available():
+           if not load_gptq_on_cpu and not torch.cuda.is_available():
Contributor

Suggested change:
-           if not load_gptq_on_cpu and not torch.cuda.is_available():
+           gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
+           if not gptq_supports_cpu and not torch.cuda.is_available():

I think we can support this in a more elegant way: can we create an attribute gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2"), similarly to:

is_8bit_serializable = version.parse(importlib.metadata.version("bitsandbytes")) > version.parse("0.37.2")

Then we can safely remove the load_gptq_on_cpu arg from from_pretrained. We should also force-set disable_exllama to True in that case.
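
A small sketch of the suggested version check (the function boundary and error message here are illustrative; the real change would live inside from_pretrained in modeling_utils.py):

import importlib.metadata

import torch
from packaging import version
from transformers import GPTQConfig


def check_gptq_device_support(quantization_config: GPTQConfig) -> None:
    # Gate CPU support on the installed auto-gptq version instead of a
    # user-facing load_gptq_on_cpu kwarg, as suggested in the review.
    gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
    if not gptq_supports_cpu and not torch.cuda.is_available():
        raise RuntimeError("GPU is required to run GPTQ quantized models with this auto-gptq version.")
    if gptq_supports_cpu and not torch.cuda.is_available():
        # Exllama kernels are CUDA-only, so force-disable them on the CPU path.
        quantization_config.disable_exllama = True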

Contributor

Don't forget to also change this: https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_utils.py#L2721 so that the warning is only printed when not gptq_supports_cpu.

@vivekkhandelwal1
Contributor Author

vivekkhandelwal1 commented Oct 13, 2023

Can you also confirm that loading the model in float32 works for you instead of first loading your model in fp16 then cast it back?

@younesbelkada, first I tried loading the model in float32 only, but that didn't work: even after passing torch_dtype=torch.float32, the loaded model had its weights in fp16. The same issue occurred with your code snippet, and that's why we have to do the cast explicitly. If you think there's an underlying issue that could be fixed, I would be happy to drop this patch.

Edit: Sorry, we can't drop this patch, since it is what enables loading the model on the CPU; but yes, .to is required when torch_dtype does not take effect.

@younesbelkada
Contributor

Hi @vivekkhandelwal1,

Hmm, this is a bit strange. Running:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

checkpoint = "marcsun13/opt-350m-gptq-4bit"
device = "cpu" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, quantization_config=quantization_config, torch_dtype=torch.float32).to(device)

for n, p in model.named_parameters():
    print(n, p.dtype)

returns:

model.decoder.embed_tokens.weight torch.float32
model.decoder.embed_positions.weight torch.float32
model.decoder.final_layer_norm.weight torch.float32
model.decoder.final_layer_norm.bias torch.float32
model.decoder.layers.0.self_attn_layer_norm.weight torch.float32
model.decoder.layers.0.self_attn_layer_norm.bias torch.float32
model.decoder.layers.0.final_layer_norm.weight torch.float32
model.decoder.layers.0.final_layer_norm.bias torch.float32
model.decoder.layers.1.self_attn_layer_norm.weight torch.float32
model.decoder.layers.1.self_attn_layer_norm.bias torch.float32
model.decoder.layers.1.final_layer_norm.weight torch.float32
model.decoder.layers.1.final_layer_norm.bias torch.float32
model.decoder.layers.2.self_attn_layer_norm.weight torch.float32
model.decoder.layers.2.self_attn_layer_norm.bias torch.float32
model.decoder.layers.2.final_layer_norm.weight torch.float32
model.decoder.layers.2.final_layer_norm.bias torch.float32
model.decoder.layers.3.self_attn_layer_norm.weight torch.float32
model.decoder.layers.3.self_attn_layer_norm.bias torch.float32
model.decoder.layers.3.final_layer_norm.weight torch.float32
model.decoder.layers.3.final_layer_norm.bias torch.float32
model.decoder.layers.4.self_attn_layer_norm.weight torch.float32
model.decoder.layers.4.self_attn_layer_norm.bias torch.float32
model.decoder.layers.4.final_layer_norm.weight torch.float32
model.decoder.layers.4.final_layer_norm.bias torch.float32
model.decoder.layers.5.self_attn_layer_norm.weight torch.float32
model.decoder.layers.5.self_attn_layer_norm.bias torch.float32
model.decoder.layers.5.final_layer_norm.weight torch.float32
model.decoder.layers.5.final_layer_norm.bias torch.float32
model.decoder.layers.6.self_attn_layer_norm.weight torch.float32
model.decoder.layers.6.self_attn_layer_norm.bias torch.float32
model.decoder.layers.6.final_layer_norm.weight torch.float32
model.decoder.layers.6.final_layer_norm.bias torch.float32
model.decoder.layers.7.self_attn_layer_norm.weight torch.float32
model.decoder.layers.7.self_attn_layer_norm.bias torch.float32
model.decoder.layers.7.final_layer_norm.weight torch.float32
model.decoder.layers.7.final_layer_norm.bias torch.float32
model.decoder.layers.8.self_attn_layer_norm.weight torch.float32
model.decoder.layers.8.self_attn_layer_norm.bias torch.float32
model.decoder.layers.8.final_layer_norm.weight torch.float32
model.decoder.layers.8.final_layer_norm.bias torch.float32
model.decoder.layers.9.self_attn_layer_norm.weight torch.float32
model.decoder.layers.9.self_attn_layer_norm.bias torch.float32
model.decoder.layers.9.final_layer_norm.weight torch.float32
model.decoder.layers.9.final_layer_norm.bias torch.float32
model.decoder.layers.10.self_attn_layer_norm.weight torch.float32
model.decoder.layers.10.self_attn_layer_norm.bias torch.float32
model.decoder.layers.10.final_layer_norm.weight torch.float32
model.decoder.layers.10.final_layer_norm.bias torch.float32
model.decoder.layers.11.self_attn_layer_norm.weight torch.float32
model.decoder.layers.11.self_attn_layer_norm.bias torch.float32
model.decoder.layers.11.final_layer_norm.weight torch.float32
model.decoder.layers.11.final_layer_norm.bias torch.float32

Inspecting some QuantLinear modules, I can't see any attribute in fp16. Let me get back to you on this.

@younesbelkada
Contributor

Hi @vivekkhandelwal1, when building auto-gptq using AutoGPTQ/AutoGPTQ#367 and running your script above, I am facing:

  File "../qlinear_exllamav2.py", line 137, in post_init
    assert self.qweight.device.type == "cuda"
AssertionError

Do you think there is anything I missed on my end?

@vivekkhandelwal1
Contributor Author

Hi @vivekkhandelwal1, when building auto-gptq using PanQiWei/AutoGPTQ#367 and running your script above, I am facing:
....
Do you think there is anything I missed on my end?

Can you please try building auto_gptq with:

BUILD_CUDA_EXT=0 pip install -v .

Right now, we can only load the GPTQ Quantized model on the CUDA
device. The attribute `gptq_supports_cpu` checks if the current
auto_gptq version is the one which has the cpu support for the
model or not.
The larger variants of the model are hard to load/run/trace on
the GPU and that's the rationale behind adding this attribute.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>
@vivekkhandelwal1
Contributor Author

Hi @younesbelkada, my Auto-GPTQ PR AutoGPTQ/AutoGPTQ#367 is merged. Can we go ahead with this?

@younesbelkada
Contributor

Hi @vivekkhandelwal1 let me have another look and get back to you

@younesbelkada
Contributor

Hi @vivekkhandelwal1, thanks again!
I made AutoGPTQ/AutoGPTQ#376 in order for this script:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

checkpoint = "marcsun13/opt-350m-gptq-4bit"
device = "cpu" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, quantization_config=quantization_config, torch_dtype=torch.float32).to(device)

inputs = tokenizer.encode("Hello how are you?", return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=4, do_sample=False)
print(tokenizer.decode(outputs[0]))

to work. However, I found it quite slow to get generations on CPU; out of curiosity, how long does it take to generate text with Falcon-180B on your end?

@vivekkhandelwal1
Contributor Author

vivekkhandelwal1 commented Oct 23, 2023

However, I found it quite slow to get generations on CPU; out of curiosity, how long does it take to generate text with Falcon-180B on your end?

Hi @younesbelkada, the performance of Falcon-180B-GPTQ on the CPU is quite slow. We lower this model through Torch-MLIR and then compile it via IREE to run on the CUDA and ROCm backends. Torch-MLIR doesn't support CUDA tensors, which is why we need the model loaded on the CPU.

Also, were you still getting errors even after my patch was merged into Auto-GPTQ? I mean, what's the reason behind https://github.com/PanQiWei/AutoGPTQ/pull/376/files?

@younesbelkada
Contributor

Hi @vivekkhandelwal1,
Yes, I was getting errors even with your patch. In my script I don't call to() with a torch dtype but load directly in fp32 with torch_dtype=torch.float32; for some reason some weights still remained in fp16. Note we don't support explicit model casting for quantized weights since #26761.

@vivekkhandelwal1
Copy link
Contributor Author

Yes, I was getting errors even with your patch. In my script I don't call to() with a torch dtype but load directly in fp32 with torch_dtype=torch.float32 ....

I'm also not doing the explicit model cast, and it works for me without it.

@vivekkhandelwal1
Contributor Author

vivekkhandelwal1 commented Oct 23, 2023


I have made two more changes (https://github.com/PanQiWei/AutoGPTQ/pull/367/files#diff-c7731808a14c99106c4a0e48729c7435181a1085978ca567433d28a5f2473b1dL272) in Auto-GPTQ. If you try the complete changes from https://github.com/PanQiWei/AutoGPTQ/pull/367/files, your patch is not required. I have tried it just now with the script you shared.

@younesbelkada
Contributor

Hi @vivekkhandelwal1,
I cloned auto-gptq again and I can confirm I am on the main branch (git log shows your commit). I uninstalled and reinstalled auto-gptq with BUILD_CUDA_EXT=0 pip install -U -v . and ran into:

s/qlinear/qlinear_cuda_old.py", line 269, in forward
    out = torch.matmul(x.to(weight.dtype), weight)
RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'

Sharing again the snippet I use:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

checkpoint = "marcsun13/opt-350m-gptq-4bit"
device = "cpu" # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(checkpoint, low_cpu_mem_usage=True, quantization_config=quantization_config, torch_dtype=torch.float32).to(device)

inputs = tokenizer.encode("Hello how are you?", return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=4, do_sample=False)
print(tokenizer.decode(outputs[0]))

Let me know if I missed anything!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.

@vivekkhandelwal1
Contributor Author

Hi @vivekkhandelwal1, I cloned auto-gptq again and I can confirm I am on the main branch ... and ran into:
....
Let me know if I missed anything!

Hi @younesbelkada, I just tried your script on another machine with a fresh setup. I cloned the auto_gptq repo, built it, and made the same changes to transformers as in this PR, and it worked for me. If you still have local issues that are fixed by your changes in Auto-GPTQ, then we can do that as well. But can we now go ahead with this patch? If you want, we can discuss this on a call; I can help you set it up and try.

@younesbelkada
Contributor

Thanks a lot @vivekkhandelwal1!
The changes look great to me. I tried it again :D and am still facing the same issue :/ so I suggest we wait for AutoGPTQ/AutoGPTQ#376 to be merged before merging this PR.
Can you also add a line to the documentation explaining that you can perform CPU inference with auto-gptq? The new paragraph can come right after: https://github.com/huggingface/transformers/blob/main/docs/source/en/main_classes/quantization.md#exllama-kernels-for-faster-inference. What do you think?
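
For reference, a rough sketch of the snippet that documentation paragraph could contain (illustrative only; the wording that actually landed in quantization.md may differ), assuming auto-gptq newer than 0.4.2 built without the CUDA extension:

import torch
from transformers import AutoModelForCausalLM, GPTQConfig

# CPU inference with GPTQ: disable the CUDA-only exllama kernels and load in float32.
quantization_config = GPTQConfig(bits=4, disable_exllama=True)
model = AutoModelForCausalLM.from_pretrained(
    "marcsun13/opt-350m-gptq-4bit",
    quantization_config=quantization_config,
    torch_dtype=torch.float32,
).to("cpu")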

@vivekkhandelwal1
Contributor Author

Thanks a lot @vivekkhandelwal1 The changes look great to me ... Can you also add one line in the documentation explaining that you can perform CPU inference with auto-gptq? ....

@younesbelkada done! Also, I'm still not sure whether your patch is needed; the GPTQ maintainers don't seem convinced about taking it in. Can we get this PR in so that I'm unblocked, given that it works for me as well?

@younesbelkada
Contributor

Hi @vivekkhandelwal1,
Thanks! I think @SunMarc managed to reproduce on his end the same issue that I had, so I prefer to wait until AutoGPTQ/AutoGPTQ#385 gets merged. Let me ping @fxmarty internally and see if we can get this merged ASAP.
Thanks again!

@vivekkhandelwal1
Contributor Author

Thanks! I think @SunMarc managed to reproduce on his end the same issue that I had, so I prefer to wait until PanQiWei/AutoGPTQ#385 gets merged ....

Sure @younesbelkada!

@vivekkhandelwal1
Contributor Author

@younesbelkada, can we merge this PR now that the Auto-GPTQ and Optimum PRs are merged? AutoGPTQ/AutoGPTQ#385, huggingface/optimum#1496

@younesbelkada left a comment
Contributor

Thank you very much for your great work! @vivekkhandelwal1

@vivekkhandelwal1
Contributor Author

Thank you very much for your great work! @vivekkhandelwal1

Thanks, @younesbelkada, for the support and feedback. I don't have merge access to the repo; can you please merge it for me?

@amyeroberts left a comment
Collaborator

Thanks for adding this!

@amyeroberts merged commit 2963e19 into huggingface:main on Oct 31, 2023.
21 checks passed
@younesbelkada
Contributor

Thanks again for all your work on this @vivekkhandelwal1 !

EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
* Add support for loading GPTQ models on CPU

Right now, we can only load the GPTQ Quantized model on the CUDA
device. The attribute `gptq_supports_cpu` checks if the current
auto_gptq version is the one which has the cpu support for the
model or not.
The larger variants of the model are hard to load/run/trace on
the GPU and that's the rationale behind adding this attribute.

Signed-Off By: Vivek Khandelwal <vivek@nod-labs.com>

* Update quantization.md

* Update quantization.md

* Update quantization.md