Add support for loading GPTQ models on CPU
Right now, we can only load GPTQ-quantized models on a CUDA device. The flag
`load_gptq_on_cpu` adds support for loading GPTQ models on the CPU. The larger
variants of these models are hard to load/run/trace on the GPU, which is the
rationale for adding this flag.

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>
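
A minimal usage sketch of the new flag (the checkpoint id below is a hypothetical placeholder for any GPTQ-quantized model, and the snippet assumes optimum and auto-gptq are installed):

    from transformers import AutoModelForCausalLM

    # Hypothetical GPTQ-quantized checkpoint; substitute any GPTQ model id.
    model_id = "TheBloke/Llama-2-7B-GPTQ"

    # Passing load_gptq_on_cpu=True skips the CUDA availability check in
    # from_pretrained, so the quantized weights can be loaded on a CPU-only host.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        load_gptq_on_cpu=True,  # flag introduced by this commit
    )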
vivekkhandelwal1 committed Oct 10, 2023
1 parent 592f2ea commit fbb7fbc
Showing 1 changed file with 2 additions and 1 deletion.
src/transformers/modeling_utils.py (2 additions & 1 deletion)
@@ -2472,6 +2472,7 @@ def from_pretrained(
         adapter_kwargs = kwargs.pop("adapter_kwargs", {})
         adapter_name = kwargs.pop("adapter_name", "default")
         use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)
+        load_gptq_on_cpu = kwargs.pop("load_gptq_on_cpu", False)

         if is_fsdp_enabled():
             low_cpu_mem_usage = True
@@ -2700,7 +2701,7 @@ def from_pretrained(
             quantization_method_from_args == QuantizationMethod.GPTQ
             or quantization_method_from_config == QuantizationMethod.GPTQ
         ):
-            if not torch.cuda.is_available():
+            if not load_gptq_on_cpu and not torch.cuda.is_available():
                 raise RuntimeError("GPU is required to quantize or run quantize model.")
             elif not (is_optimum_available() and is_auto_gptq_available()):
                 raise ImportError(
