
Commit

Merge branch 'master' into lyt/os
yintong-lu committed Sep 19, 2023
2 parents f8ad748 + 6c26635 commit 143e583
Showing 6 changed files with 51 additions and 35 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -19,7 +19,7 @@ Intel® Neural Compressor aims to provide popular model compression techniques s
as well as Intel extensions such as [Intel Extension for TensorFlow](https://github.com/intel/intel-extension-for-tensorflow) and [Intel Extension for PyTorch](https://github.com/intel/intel-extension-for-pytorch).
In particular, the tool provides the key features, typical examples, and open collaborations as below:

-* Support a wide range of Intel hardware such as [Intel Xeon Scalable processor](https://www.intel.com/content/www/us/en/products/details/processors/xeon/scalable.html), [Intel Xeon CPU Max Series](https://www.intel.com/content/www/us/en/products/details/processors/xeon/max-series.html), [Intel Data Center GPU Flex Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/flex-series.html), and [Intel Data Center GPU Max Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series.html) with extensive testing; support AMD CPU, ARM CPU, and NVidia GPU through ONNX Runtime with limited testing
+* Support a wide range of Intel hardware such as [Intel Xeon Scalable Processors](https://www.intel.com/content/www/us/en/products/details/processors/xeon/scalable.html), [Intel Xeon CPU Max Series](https://www.intel.com/content/www/us/en/products/details/processors/xeon/max-series.html), [Intel Data Center GPU Flex Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/flex-series.html), and [Intel Data Center GPU Max Series](https://www.intel.com/content/www/us/en/products/details/discrete-gpus/data-center-gpu/max-series.html) with extensive testing; support AMD CPU, ARM CPU, and NVidia GPU through ONNX Runtime with limited testing

* Validate popular LLMs such as LLama2, [LLama](examples/onnxrt/nlp/huggingface_model/text_generation/llama/quantization/ptq_static), [MPT](https://github.com/intel/intel-extension-for-transformers/blob/main/examples/huggingface/pytorch/text-generation/quantization/README.md), [Falcon](https://github.com/intel/intel-extension-for-transformers/blob/main/examples/huggingface/pytorch/language-modeling/quantization/README.md), [GPT-J](/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/fx), [Bloom](/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/ipex/smooth_quant), [OPT](/examples/pytorch/nlp/huggingface_models/language-modeling/quantization/ptq_static/ipex/smooth_quant), and more than 10,000 broad models such as [Stable Diffusion](/examples/pytorch/nlp/huggingface_models/text-to-image/quantization), [BERT-Large](/examples/pytorch/nlp/huggingface_models/text-classification/quantization/ptq_static/fx), and [ResNet50](/examples/pytorch/image_recognition/torchvision_models/quantization/ptq/cpu/fx) from popular model hubs such as [Hugging Face](https://huggingface.co/), [Torch Vision](https://pytorch.org/vision/stable/index.html), and [ONNX Model Zoo](https://github.com/onnx/models#models), by leveraging zero-code optimization solution [Neural Coder](/neural_coder#what-do-we-offer) and automatic [accuracy-driven](/docs/source/design.md#workflow) quantization strategies

@@ -139,6 +139,7 @@ q_model = fit(
> More documentations can be found at [User Guide](./docs/source/user_guide.md).
## Selected Publications/Events
* EMNLP'2023 (Under Review): [TEQ: Trainable Equivalent Transformation for Quantization of LLMs](https://openreview.net/forum?id=iaI8xEINAf&referrer=%5BAuthor%20Console%5D) (Sep 2023)
+* arXiv: [Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs](https://arxiv.org/abs/2309.05516) (Sep 2023)
* Post on Social Media: [ONNXCommunityMeetup2023: INT8 Quantization for Large Language Models with Intel Neural Compressor](https://www.youtube.com/watch?v=luYBWA1Q5pQ) (July 2023)
* Blog by Intel: [Accelerate Llama 2 with Intel AI Hardware and Software Optimizations](https://www.intel.com/content/www/us/en/developer/articles/news/llama2.html) (July 2023)
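The README hunk above extends a snippet built around `q_model = fit(...)`. As a point of reference, here is a minimal sketch of that post-training quantization flow, assembled from the API calls that appear later in this diff (in the removed Keras test); the model path, calibration dataset, and eval function are placeholders, not code from the repository.

```python
import numpy as np
from tensorflow import keras

from neural_compressor import set_random_seed
from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor.data.dataloaders.dataloader import DataLoader
from neural_compressor.quantization import fit


class CalibDataset:
    """Placeholder calibration dataset yielding (input, label) pairs."""

    def __init__(self, n=10):
        self.samples = [(np.random.rand(28, 28, 1).astype("float32"), 0) for _ in range(n)]

    def __getitem__(self, idx):
        return self.samples[idx]

    def __len__(self):
        return len(self.samples)


def eval_func(model):
    """Placeholder evaluation callable; a real one returns the model's accuracy as a float."""
    return 1.0


set_random_seed(9527)
config = PostTrainingQuantConfig(backend="itex")  # ITEX backend, as in the Keras test in this diff
q_model = fit(
    keras.models.load_model("./baseline_model"),  # placeholder path to an FP32 Keras model
    conf=config,
    calib_dataloader=DataLoader(framework="tensorflow", dataset=CalibDataset()),
    eval_func=eval_func,
)
q_model.save("./itex_qdq_keras_model")
```

The accuracy-driven tuning loop mentioned in the README runs inside `fit`, which compares `eval_func` results for candidate quantized models against the FP32 baseline until the accuracy criterion is met.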
5 changes: 3 additions & 2 deletions docs/source/publication_list.md
@@ -1,6 +1,7 @@
-Full Publications/Events (75)
+Full Publications/Events (76)
==========
-## 2023 (21)
+## 2023 (22)
* EMNLP'2023 (Under Review): [TEQ: Trainable Equivalent Transformation for Quantization of LLMs](https://openreview.net/forum?id=iaI8xEINAf&referrer=%5BAuthor%20Console%5D) (Sep 2023)
+* arXiv: [Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs](https://arxiv.org/abs/2309.05516) (Sep 2023)
* Blog on Medium: [Quantization Accuracy Loss Diagnosis with Neural Insights](https://medium.com/@NeuralCompressor/quantization-accuracy-loss-diagnosis-with-neural-insights-5d73f4ca2601) (Aug 2023)
* Blog on Medium: [Faster Stable Diffusion Inference with Intel Extension for Transformers](https://medium.com/intel-analytics-software/faster-stable-diffusion-inference-with-intel-extension-for-transformers-on-intel-platforms-7e0f563186b0) (July 2023)
33 changes: 32 additions & 1 deletion neural_compressor/compression/pruner/pruners/progressive.py
@@ -60,7 +60,8 @@ def _init(self):
        self.progressive_steps = self.progressive_configs["progressive_steps"]
        self.progressive_type = self.progressive_configs["progressive_type"]
        self.use_global = self.progressive_configs["use_global"]
-        self.progressive_logger = False
+        self.progressive_logger = True
+        self.align_masks_flag = False
        self._init_for_progressive()

    def _init_for_progressive(self):
@@ -77,6 +78,11 @@ def _init_for_progressive(self):
            self.use_progressive = False
            return

+        if self.pruning_frequency == 1:
+            logger.info("Current progressive setting will degrade to non-progressive pruning.")
+            self.use_progressive = False
+            return
+
        # step 3: log hyper-parameters. and check validity.
        if self.use_progressive:
            logger.info("Progressive pruning is enabled!")
@@ -225,6 +231,10 @@ def on_step_begin(self, local_step):
        Implement at the start of each step.
        """
+        if self.global_step > self.end_step and self.align_masks_flag is False:
+            self.align_masks_after_pruning()
+            self.align_masks_flag = True
+
        if self.handled_global_step == self.global_step:
            return

@@ -270,3 +280,24 @@ def print_progressive_sparsity(self):
"""Output the progressive sparsity."""
cur_sp = self.pattern.get_sparsity_ratio_progressive(self.progressive_masks)
logger.info("Step: {} -> Current progressive sparsity: {}".format(self.global_step, cur_sp))

def obtain_weight_sparsity(self, modules):
total_numels = 0
sparse_numels = 0
for key in modules.keys():
total_numels += modules[key].weight.data.numel()
sparse_numels += torch.sum(torch.where(modules[key].weight.data == 0, 1, 0)).item()
return sparse_numels / total_numels

def align_masks_after_pruning(self):
if not self.use_progressive:
return
"""Implement at the end of training phase."""
# If training ends while a progressive masks is applying, we have to use self.masks to align
# step 1 calculate sparsity under progressive masks
sparsity1 = self.obtain_weight_sparsity(self.modules)
# step 2 use block-wise masks to remask weights
self.mask_weights_general(self.masks)
# step 3 calculate sparsity under progressive masks
sparsity2 = self.obtain_weight_sparsity(self.modules)
logger.info(f"Replace progressive mask with complete masks: Sparsity Update: {sparsity1} => {sparsity2}")
38 changes: 9 additions & 29 deletions test/itex/test_keras_in_keras_out.py
@@ -147,6 +147,15 @@ def test_keras_in_keras_out(self):
            eval_func=eval_func,
        )
        q_model.save("itex_qdq_keras_model")
+        self.assertEqual(q_model.framework(), "keras")
+
+        framework_config = {"framework": "keras", "approach": "post_training_static_quant"}
+        q_model.q_config = framework_config
+        self.assertEqual(q_model.q_config["framework"], "keras")
+        self.assertEqual(q_model.graph_info, None)
+        self.assertEqual(q_model.framework(), "keras")
+        self.assertEqual(isinstance(q_model.model, tf.keras.Model), True)
+
        model = keras.models.load_model("./itex_qdq_keras_model")
        model.summary()
        found_quantize = False
@@ -167,35 +176,6 @@ def test_keras_in_keras_out(self):
test_mode = "performance"
fit(model, conf, b_func=eval_func)

def test_keras_model_interface(self):
logger.info("Run test_keras_model_interface case...")
global test_mode
test_mode = "accuracy"
build_model()

from neural_compressor import set_random_seed
from neural_compressor.config import PostTrainingQuantConfig
from neural_compressor.data.dataloaders.dataloader import DataLoader
from neural_compressor.quantization import fit

set_random_seed(9527)
config = PostTrainingQuantConfig(backend="itex")
q_model = fit(
keras.models.load_model("./baseline_model"),
conf=config,
calib_dataloader=DataLoader(framework="tensorflow", dataset=Dataset()),
eval_func=eval_func,
)
q_model.save("itex_qdq_keras_model")
self.assertEqual(q_model.framework(), "keras")

framework_config = {"framework": "keras", "approach": "post_training_static_quant"}
q_model.q_config = framework_config
self.assertEqual(q_model.q_config["framework"], "keras")
self.assertEqual(q_model.graph_info, None)
self.assertEqual(q_model.framework(), "keras")
self.assertEqual(isinstance(q_model.model, tf.keras.Model), True)


if __name__ == "__main__":
unittest.main()
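The retained test reloads the saved model and flips a `found_quantize` flag while walking its layers; a hedged sketch of what such a check could look like (the layer-name matching below is an assumption, not copied from the file):

```python
from tensorflow import keras

# Path written by q_model.save(...) earlier in the test.
model = keras.models.load_model("./itex_qdq_keras_model")
model.summary()

# Assumed check: a QDQ-quantized Keras model should contain quantize/dequantize layers.
found_quantize = any("quantize" in type(layer).__name__.lower() for layer in model.layers)
found_dequantize = any("dequantize" in type(layer).__name__.lower() for layer in model.layers)
print(found_quantize, found_dequantize)
```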
6 changes: 4 additions & 2 deletions test/itex/test_tensorflow_qdq_convert_to_onnx_qdq.py
@@ -113,7 +113,8 @@ def test_convert_tf_qdq_to_onnx_qdq(self):
        it = iter(dataloader)
        input = next(it)
        input_dict = {"input:0": input[0]}
-        ort_session.run(None, input_dict)
+        outputs = ort_session.run(None, input_dict)
+        self.assertNotEqual(outputs, None)

    @disable_random()
    @unittest.skipIf(version1_lt_version2(tf.version.VERSION, "2.8.0"), "Only supports tf greater 2.7.0")
@@ -166,7 +167,8 @@ def test_convert_tf_fp32_to_onnx_fp32(self):
        it = iter(dataloader)
        input = next(it)
        input_dict = {"input:0": input[0]}
-        ort_session.run(None, input_dict)
+        outputs = ort_session.run(None, input_dict)
+        self.assertNotEqual(outputs, None)


if __name__ == "__main__":
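Both hunks in this file now capture the return value of `ort_session.run` and assert it is not `None`. For context, a minimal standalone sketch of creating and running an ONNX Runtime session; the model path and input shape are placeholders rather than values from the test:

```python
import numpy as np
import onnxruntime as ort

# Placeholder path; the test exports its converted QDQ/FP32 model elsewhere in the file.
ort_session = ort.InferenceSession("converted_model.onnx", providers=["CPUExecutionProvider"])

# Keys are the graph's input tensor names ("input:0" for the TF-exported graph in the test);
# the shape here is an assumed NHWC image batch.
input_dict = {"input:0": np.random.rand(1, 224, 224, 3).astype(np.float32)}

outputs = ort_session.run(None, input_dict)  # None -> return every model output as a list
assert outputs is not None and len(outputs) > 0
```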
1 change: 1 addition & 0 deletions
@@ -618,6 +618,7 @@ def check_not_implement(self, batch_size, distributed):
"tensorflow", {"bert": {"root": "test.record", "label_file": "./dev.json"}}, None, None
)
dataloader = DATALOADERS["tensorflow"](dataset=eval_dataset, batch_size=batch_size, distributed=distributed)
self.assertNotEqual(dataloader, None)

def test_tf_bert_dataloader_1(self):
self.assertRaises(NotImplementedError, self.check_not_implement, 32, True)
