feat: add support for loading quantisation from config.json (#1363)
* feat: add support for loading quantisation from config.json

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: address comments

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* chore: remove unused type hints

* chore: capture exception and raise correspondingly

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

* fix(style): oops forgot to run style

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>

---------

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
aarnphm committed Sep 8, 2023
1 parent 058c72d commit c631387
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions optimum/gptq/quantizer.py
@@ -635,8 +635,17 @@ def load_quantized_model(
         device_map = {"": torch.cuda.current_device()}
         logger.info("The device_map was not initialized." "Setting device_map to `{'':torch.cuda.current_device()}`.")

-    with open(os.path.join(save_folder, quant_config_name), "r", encoding="utf-8") as f:
-        quantize_config_dict = json.load(f)
+    # this branch will check if model is from huggingface
+    try:
+        if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
+            quantize_config_dict = model.config.quantization_config.to_dict()
+        else:
+            with open(os.path.join(save_folder, quant_config_name), "r", encoding="utf-8") as f:
+                quantize_config_dict = json.load(f)
+    except Exception as err:
+        raise ValueError(
+            f"Failed to load quantization config from {save_folder} (lookup for traceback): {err}\nTip: If the save directory is saved from a transformers.PreTrainedModel, make sure that `config.json` contains a 'quantization_config' key."
+        ) from err
     quantizer = GPTQQuantizer.from_dict(quantize_config_dict)
     quantizer.disable_exllama = disable_exllama
     quantizer.max_input_length = max_input_length
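
With this change, load_quantized_model first looks for a quantization_config entry on the model's config (as transformers writes into config.json for GPTQ checkpoints) and only falls back to the standalone quantize config file in save_folder. Below is a minimal usage sketch, assuming the optimum.gptq API at this revision and a hypothetical local checkpoint path; it is illustrative, not the library's documented example.

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM
from optimum.gptq import load_quantized_model

save_folder = "path/to/gptq-checkpoint"  # hypothetical local directory

# Build an empty-weights model from the checkpoint's config.json; if that config
# carries a "quantization_config" key, the new branch in the diff above reads the
# quantizer settings from model.config.quantization_config.
with init_empty_weights():
    empty_model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(save_folder))
empty_model.tie_weights()

# Otherwise the call still reads the quantize config file from save_folder, and the
# ValueError added in this commit is raised when neither source can be loaded.
quantized_model = load_quantized_model(empty_model, save_folder=save_folder, device_map="auto")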
