
Enhance 3.x torch WOQ load #1877

Open
wants to merge 21 commits into master

Conversation

yuwenzho (Collaborator)

Type of Change

feature
API changed or not: no

Description

Use a different WeightOnlyLinear module depending on the device.

  • Abstract the WeightOnlyLinear class; INCWeightOnlyLinear and HPUWeightOnlyLinear inherit from it (see the sketch below).
  • Load WOQ linear weights module by module.
  • Save HPU-format tensors so they can be reused on the next load: with the huggingface format they are saved to a local 'hpu_model.safetensors' file; with the default format they are saved to a 'quantized_hpu_weight.pt' file.
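
A minimal conceptual sketch of the class relationship described above; the method names and constructor arguments here are illustrative only, not the actual signatures introduced by this PR:

# Conceptual sketch only: the real INCWeightOnlyLinear / HPUWeightOnlyLinear carry
# quantization parameters (bits, group size, scales, zero points, etc.) omitted here.
from abc import abstractmethod
import torch

class WeightOnlyLinear(torch.nn.Module):
    """Abstract base for packed weight-only-quantized linear layers."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

    @abstractmethod
    def pack(self, int_weight, scale, zp, bias=None):
        """Pack quantized weights into the device-specific storage layout."""

    @abstractmethod
    def unpack(self):
        """Recover the quantized integer weights from the packed storage."""

class INCWeightOnlyLinear(WeightOnlyLinear):
    """CPU packing layout used by the INC quantization algorithms."""

class HPUWeightOnlyLinear(WeightOnlyLinear):
    """HPU packing layout consumed by Habana/Gaudi kernels."""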

Example of loading a huggingface WOQ model:

from neural_compressor.torch.quantization import load

model_id = "TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ"
# first load: torch.nn.Linear -> INCWeightOnlyLinear -> HPUWeightOnlyLinear, 
# and then save hpu_model.safetensors to local cache dir
qmodel = load(model_name_or_path=model_id, format="huggingface", device="hpu")

# second load: torch.nn.Linear -> HPUWeightOnlyLinear using hpu_model.safetensors saved in local cache dir
qmodel = load(model_name_or_path=model_id, format="huggingface", device="hpu")

Example of loading an INC WOQ model:

from neural_compressor.torch.quantization import load

# first load: torch.nn.Linear -> INCWeightOnlyLinear -> HPUWeightOnlyLinear, 
# and then save quantized_hpu_weight.pt to 'saved_results' dir
qmodel = load("saved_results", original_model=fp32_model, device="hpu")

# second load: torch.nn.Linear -> HPUWeightOnlyLinear using quantized_hpu_weight.pt saved in 'saved_results' dir
qmodel = load("saved_results", original_model=fp32_model, device="hpu")

How has this PR been tested?

CI

Dependency Change?

No

Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
@yuwenzho added the PyTorch (Related to PyTorch F/W) label on Jun 18, 2024

github-actions bot commented Jun 18, 2024

⛈️ Required checks status: Has failure 🔴

Warning
If you do not have access to re-run the Probot, please contact XuehaoSun for help. If you push a new commit, all of the workflows will be re-triggered.

Groups summary

🟢 Code Scan Tests workflow
Check ID Status Error details
Code-Scan success
Code-Scan (Bandit Code Scan Bandit) success
Code-Scan (DocStyle Code Scan DocStyle) success
Code-Scan (Pylint Code Scan Pylint) success

These checks are required after the changes to neural_compressor/torch/algorithms/weight_only/gptq.py, neural_compressor/torch/algorithms/weight_only/modules.py, neural_compressor/torch/algorithms/weight_only/rtn.py, neural_compressor/torch/algorithms/weight_only/save_load.py, neural_compressor/torch/quantization/load_entry.py, neural_compressor/torch/utils/environ.py, neural_compressor/torch/utils/utility.py.

🟢 Model Tests 3x workflow
Check ID Status Error details
Model-Test-3x success
Model-Test-3x (Generate Report GenerateReport) success
Model-Test-3x (Run PyTorch Model opt_125m_woq_gptq_int4) success
Model-Test-3x (Run PyTorch Model opt_125m_woq_gptq_int4_dq_bnb) success
Model-Test-3x (Run PyTorch Model opt_125m_woq_gptq_int4_dq_ggml) success

These checks are required after the changes to neural_compressor/torch/algorithms/weight_only/gptq.py, neural_compressor/torch/algorithms/weight_only/modules.py, neural_compressor/torch/algorithms/weight_only/rtn.py, neural_compressor/torch/algorithms/weight_only/save_load.py, neural_compressor/torch/quantization/load_entry.py, neural_compressor/torch/utils/environ.py, neural_compressor/torch/utils/utility.py.

🔴 Unit Tests 3x-PyTorch workflow
Check ID Status Error details
UT-3x-Torch failure
UT-3x-Torch (Coverage Compare CollectDatafiles) failure download
UT-3x-Torch (Unit Test 3x Torch Unit Test 3x Torch) success
UT-3x-Torch (Unit Test 3x Torch baseline Unit Test 3x Torch baseline) success

These checks are required after the changes to neural_compressor/torch/algorithms/weight_only/gptq.py, neural_compressor/torch/algorithms/weight_only/modules.py, neural_compressor/torch/algorithms/weight_only/rtn.py, neural_compressor/torch/algorithms/weight_only/save_load.py, neural_compressor/torch/quantization/load_entry.py, neural_compressor/torch/utils/environ.py, neural_compressor/torch/utils/utility.py, test/3x/torch/quantization/weight_only/test_autoround.py, test/3x/torch/quantization/weight_only/test_awq.py, test/3x/torch/quantization/weight_only/test_gptq.py, test/3x/torch/quantization/weight_only/test_load.py, test/3x/torch/quantization/weight_only/test_load_woq_hf_model.py, test/3x/torch/quantization/weight_only/test_rtn.py.


Thank you for your contribution! 💜

Note
This comment is automatically generated and will be updated every 180 seconds within the next 6 hours. If you have any other questions, contact chensuyue or XuehaoSun for help.

yuwenzho and others added 2 commits June 18, 2024 08:10
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
@xin3he (Collaborator) left a comment


Adjustments after discussion:

  1. Use skipif for the HPU logic and avoid exposing WOQModelLoader (see the sketch after this list).
  2. Avoid using AutoRoundWeightOnlyLinear so that weights can be unpacked and repacked into HPUWeightOnlyLinear.
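
A minimal sketch of the skipif pattern referred to in item 1; the is_hpu_available helper below is hypothetical, and the repository's actual HPU-detection utility may live elsewhere under a different name:

# Sketch only: guarding HPU-specific tests with pytest skipif.
# `is_hpu_available` is a hypothetical helper, not an API from this PR.
import pytest

def is_hpu_available() -> bool:
    try:
        import habana_frameworks.torch.hpu as hthpu  # available only on Gaudi setups
        return hthpu.is_available()
    except ImportError:
        return False

@pytest.mark.skipif(not is_hpu_available(), reason="requires a Habana HPU device")
def test_load_woq_model_on_hpu():
    from neural_compressor.torch.quantization import load

    qmodel = load(
        model_name_or_path="TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ",
        format="huggingface",
        device="hpu",
    )
    assert qmodel is not None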

@Kaihui-intel (Collaborator)

Abstract WeightOnlyLinear class. Inherited class INCWeightOnlyLinear and HPUWeightOnlyLinear
For CPU, how does the WOQ algorithm use the abstract class WeightOnlyLinear? Do we use INCWeightOnlyLinear instead of WeightOnlyLinear?

yuwenzho and others added 6 commits June 19, 2024 08:22
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Signed-off-by: yuwenzho <zyuwen@habana.ai>
Signed-off-by: yuwenzho <zyuwen@habana.ai>
os.path.abspath(os.path.expanduser(self.model_name_or_path)), WEIGHT_NAME
)
# if hpu format tensor can be used directly, then update qmodel_weight_file_path to the hpu format tensor file
if self._with_hpu_format_tensor():


Suggest changing the format to hpu here, and later on loading the correct layer according to the hpu format.

Collaborator


Good idea, we should use the format to indicate that the loaded model file is already in hpu format.
However, we think it is better to call it the habana format, because hpu is a device name.
Looking forward to your feedback.

Collaborator


Sorry, I made a mistake. Thanks for correcting me, @yuwenzho.
The format arg indicates the load API argument format, not the packing format. We use the device arg to decide which packing format should be used. If the packing format is already the hpu format, we should have some flag in config.json or elsewhere before uploading the model to the huggingface hub (see the hypothetical sketch after the example below).

# load API format 1
load(model_name_or_path="saved_results", original_model=fp32_model, format="default")
#  load API format 2
load(model_name_or_path="TheBloke/TinyLlama-1.1B-python-v0.1-GPTQ", format="huggingface")

device_dict = {"cpu": INCWeightOnlyLinear, "hpu": HPUWeightOnlyLinear}

# if hpu format tensor can be used directly, then update mapping module to HPUWeightOnlyLinear
if self._with_hpu_format_tensor():


Suggest adding a new format for HPU instead; it will be clearer. Then you can move format_dict and device_dict to a global place.

Collaborator Author


format_dict and device_dict were moved to the global scope in 8001a0a.

@xin3he (Collaborator) left a comment


LGTM

neural_compressor/torch/utils/auto_accelerator.py (outdated review comment, resolved)
yuwenzho and others added 2 commits June 20, 2024 08:53
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
@yuwenzho (Collaborator, Author)

Abstract WeightOnlyLinear class. Inherited class INCWeightOnlyLinear and HPUWeightOnlyLinear. For CPU, how does the WOQ algorithm use the abstract class WeightOnlyLinear? Do we use INCWeightOnlyLinear instead of WeightOnlyLinear?

Yes, the algorithm should use INCWeightOnlyLinear. Fixed in 56c864f.

yuwenzho and others added 2 commits June 20, 2024 09:10
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
@yuwenzho added the WIP label on Jun 28, 2024
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
yuwenzho and others added 4 commits July 2, 2024 06:41
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Signed-off-by: yuwenzho <yuwen.zhou@intel.com>
Labels
PyTorch (Related to PyTorch F/W), WIP

Projects
None yet

Development
Successfully merging this pull request may close these issues: None yet

5 participants