[Good First Issue][NNCF]: Add INT8 weight compression conformance test for Tinyllama-1.1b PyTorch model #2527

Closed
alexsu52 opened this issue Feb 28, 2024 · 19 comments · Fixed by #2636
Labels
good first issue Good for newcomers


@alexsu52
Contributor

Context

This issue proposes adding a test to the post-training compression conformance suite to verify that the weights of the Tinyllama-1.1b PyTorch model can be compressed to INT8 within a given time while preserving an acceptable level of model accuracy on whowhatbench.

INT8 weight compression is a popular approach to reducing LLM size: the weights are quantized from their original floating-point precision to INT8, which leads to a smaller model footprint and potentially faster inference on target devices without a significant accuracy drop.
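A minimal pure-Python sketch of asymmetric INT8 quantization of a weight matrix, with one scale and one zero point per output row, may help make this concrete. It assumes a simple min/max range that always includes 0.0 and round-to-nearest; it illustrates the general idea behind a mode such as INT8_ASYM and is not NNCF's implementation.

```python
from typing import List, Tuple


def quantize_row_int8_asym(row: List[float]) -> Tuple[List[int], float, int]:
    """Asymmetric min/max quantization of one weight row to unsigned 8-bit codes."""
    lo, hi = min(row), max(row)
    lo, hi = min(lo, 0.0), max(hi, 0.0)  # include 0.0 so it maps exactly to the zero point
    scale = (hi - lo) / 255.0 or 1.0  # fall back to 1.0 for an all-zero row
    zero_point = round(-lo / scale)
    codes = [max(0, min(255, round(w / scale) + zero_point)) for w in row]
    return codes, scale, zero_point


def dequantize_row(codes: List[int], scale: float, zero_point: int) -> List[float]:
    """Map 8-bit codes back to approximate floating-point weights."""
    return [(q - zero_point) * scale for q in codes]


weights = [[0.12, -0.5, 0.33, 0.9], [-1.2, 0.05, 0.4, -0.07]]
for row in weights:
    codes, scale, zp = quantize_row_int8_asym(row)
    restored = dequantize_row(codes, scale, zp)
    max_err = max(abs(a - b) for a, b in zip(row, restored))
    print(codes, f"scale={scale:.5f}", f"zp={zp}", f"max_err={max_err:.5f}")
```

Each weight is stored as an 8-bit code plus one float scale and one integer zero point per row, and the reconstruction error per weight stays within one quantization step (scale).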

The following code snippet shows how to compress the weights of the Tinyllama-1.1b PyTorch model using NNCF:

import nncf
import torch
import transformers

MODEL_ID = "tinyllama/tinyllama-1.1b-step-50k-105b"

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="cpu")

# A single tokenized prompt, wrapped in nncf.Dataset, provides example input for the model.
text = "The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens."
token = tokenizer(text, max_length=500, return_tensors="pt", truncation=True)
inputs = {"input_ids": token["input_ids"], "attention_mask": token["attention_mask"]}

# Compress the weights; the default mode of compress_weights is INT8 asymmetric (CompressWeightsMode.INT8_ASYM).
compressed_model = nncf.compress_weights(model, dataset=nncf.Dataset([inputs]))
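For a rough sense of the footprint reduction, the storage needed for the weights alone can be estimated from the parameter count and the bits per weight. This back-of-the-envelope sketch assumes the nominal 1.1B parameters from the model name and ignores per-channel scales and zero points as well as any layers kept in floating point:

```python
def weight_footprint_gb(num_params: float, bits_per_weight: int) -> float:
    """Approximate storage for the weights alone, in gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bits_per_weight / 8 / 1e9


NUM_PARAMS = 1.1e9  # nominal parameter count taken from the model name

for label, bits in [("FP16", 16), ("INT8", 8)]:
    print(f"{label}: ~{weight_footprint_gb(NUM_PARAMS, bits):.2f} GB")
```

Halving the bits per weight roughly halves the weight storage (about 2.2 GB versus 1.1 GB here); the actual compressed model will be somewhat larger because of the quantization parameters.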

What needs to be done?

Add an INT8 weight compression test for the Tinyllama-1.1b PyTorch model to the post-training compression conformance suite so that the test can be run with the following command:

pytest tests/post_training/test_quantize_conformance.py::test_weight_compression -s --data=<path to data folder> -k [tinyllama_int8_data_free_backend_PT]

The task steps:

  • Add the following test case to tests/post_training/model_scope.py:

    {
        "reported_name": "tinyllama_int8_data_free",
        "model_id": "tinyllama/tinyllama-1.1b-step-50k-105b",
        "pipeline_cls": LMWeightCompression,
        "compression_params": {
            "mode": CompressWeightsMode.INT8_ASYM,
        },
        "backends": [BackendType.TORCH],
    },
  • Add PyTorch backend support to the LMWeightCompression class.
  • Collect the reference (golden) metric values and add them to the reference file (tests/post_training/data/wc_reference_data.yaml).

Example Pull Requests

#2425

Resources

Contact points

@AlexanderDokuchaev, @alexsu52

Ticket

ref: 130788

@alexsu52 added the good first issue Good for newcomers label Feb 28, 2024
@github-project-automation bot moved this to Contributors Needed in Good first issues Feb 28, 2024
@RedShift51

Hi, is it possible to take this one?

@alexsu52
Contributor Author

Hello @RedShift51, the task is assigned to you.

Thank you for looking into this issue! Please let us know if you have any questions or require any help.

@alexsu52 moved this from Contributors Needed to Assigned in Good first issues Feb 29, 2024
@RedShift51

RedShift51 commented Feb 29, 2024

Hey, what metric value is acceptable for tinyllama/tinyllama-1.1b-step-50k-105b?

@alexsu52
Contributor Author

alexsu52 commented Mar 1, 2024

Hey,

Similarity metric between the float16 and the INT8 weight-compressed tinyllama-1.1b-step-50k-105b models on whowhatbench:
similarity : 0.9628345480671635

Code to reproduce:

import torch
import whowhatbench
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

import nncf

MODEL_ID = "tinyllama/tinyllama-1.1b-step-50k-105b"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="auto")

# Collect the outputs of the original float16 model before it is compressed.
evaluator = whowhatbench.Evaluator(base_model=model, tokenizer=tokenizer)

text = "The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens."
token = tokenizer(text, max_length=500, return_tensors="pt", truncation=True)
# Put the inputs on the same device as the model (device_map="auto" may choose CPU or GPU).
inputs = {"input_ids": token["input_ids"].to(model.device), "attention_mask": token["attention_mask"].to(model.device)}
compressed_model = nncf.compress_weights(model, dataset=nncf.Dataset([inputs]))

# Score the compressed model against the outputs collected from the original model.
metrics_per_prompt, metrics = evaluator.score(compressed_model)

metric_of_interest = "similarity"
print(metric_of_interest, ": ", metrics[metric_of_interest][0])
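The similarity metric reported here is, as far as we understand whowhatbench, derived from the cosine similarity between sentence embeddings of the answers produced by the base model and by the compressed model. The sketch below shows only the cosine-similarity step on hypothetical toy vectors; the embedding model and the averaging over prompts are assumptions left out:

```python
import math
from typing import Sequence


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Hypothetical embeddings of the base and compressed models' answers to the same prompt.
base_answer = [0.2, 0.7, 0.1, 0.4]
compressed_answer = [0.25, 0.65, 0.05, 0.45]

print(round(cosine_similarity(base_answer, base_answer), 6))  # 1.0
print(round(cosine_similarity(base_answer, compressed_answer), 4))
```

Identical answers give a similarity of 1.0, and the score drops as the compressed model's answers drift away from the base model's.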

@RedShift51

RedShift51 commented Mar 7, 2024

Hi, sorry for the delay. I have reproduced it on a CPU:

import torch
import nncf
import transformers
import whowhatbench

MODEL_ID = "tinyllama/tinyllama-1.1b-step-50k-105b"

tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)
model = transformers.AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16, device_map="cpu")

text = 'The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens.'
token = tokenizer(text, max_length=500, return_tensors="pt", truncation=True)
inputs = {"input_ids": token["input_ids"], "attention_mask": token["attention_mask"]}

compressed_model = nncf.compress_weights(model, dataset=nncf.Dataset([inputs]))


evaluator = whowhatbench.Evaluator(base_model=compressed_model, tokenizer=tokenizer)
metrics_per_prompt, metrics = evaluator.score(compressed_model)
print(metrics)
metric_of_interest = "similarity"
print(metric_of_interest, ": ", metrics["similarity"][0])

worst_examples = evaluator.worst_examples(top_k=5, metric=metric_of_interest)
print("Metric: ", metric_of_interest)

@alexsu52
Contributor Author

alexsu52 commented Mar 8, 2024

The main idea of whowhatbench is to compare the original model with the compressed model. In your code, however, you compared compressed_model with itself, so, as expected, you got a similarity metric of 1.

# collect outputs of original_model
evaluator = whowhatbench.Evaluator(base_model=model, tokenizer=tokenizer)
# in-place weight compression of the model
compressed_model = nncf.compress_weights(model, dataset=nncf.Dataset([inputs]))
# collect outputs of compressed model and calculate the similarity metric.
metrics_per_prompt, metrics = evaluator.score(compressed_model)
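Why the order matters can be shown with a toy, pure-Python stand-in (ToyModel and compress_in_place are hypothetical, not whowhatbench or NNCF code). Because the compression modifies the model in place, the reference outputs have to be captured before compressing; otherwise the "reference" already comes from the compressed model and the comparison is trivially perfect:

```python
from typing import List


class ToyModel:
    """Hypothetical stand-in for a model whose weights are modified in place."""

    def __init__(self, weights: List[float]):
        self.weights = weights

    def generate(self, x: float) -> List[float]:
        return [w * x for w in self.weights]


def compress_in_place(model: ToyModel, step: float = 0.25) -> ToyModel:
    # Crude rounding to a grid; mutates and returns the same object, like an in-place compressor.
    model.weights[:] = [round(w / step) * step for w in model.weights]
    return model


prompts = [1.0, 2.0]

# Correct order: capture the reference outputs first, then compress.
model = ToyModel([0.3, -0.8, 1.1])
reference = [model.generate(p) for p in prompts]
compressed = compress_in_place(model)
print(reference == [compressed.generate(p) for p in prompts])  # False: outputs changed

# Wrong order: the "reference" is taken from the already compressed model.
model = ToyModel([0.3, -0.8, 1.1])
compressed = compress_in_place(model)
reference = [model.generate(p) for p in prompts]
print(reference == [compressed.generate(p) for p in prompts])  # True: trivially identical
```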

@alexsu52
Contributor Author

@RedShift51, are you going to continue working on this issue? Do you have any updates?

@alexsu52 moved this from Assigned to Contributors Needed in Good first issues Mar 28, 2024
@alexsu52
Contributor Author

Removed assignment due to inactivity.

@ksj20

ksj20 commented Mar 28, 2024

.take

github-actions bot

Thank you for looking into this issue! Please let us know if you have any questions or require any help.

@alexsu52 moved this from Contributors Needed to Assigned in Good first issues Mar 28, 2024
@AdiKsOnDev
Contributor

@alexsu52 @ksj20 Any updates on this issue? If the assignee isn't going to work on this, I'd be down to take it.

@ksj20 removed their assignment Apr 8, 2024
@AdiKsOnDev
Contributor

.take


github-actions bot commented Apr 8, 2024

Thank you for looking into this issue! Please let us know if you have any questions or require any help.

@AdiKsOnDev
Contributor

AdiKsOnDev commented Apr 9, 2024

@alexsu52 @AlexanderDokuchaev If I add the following code to LMWeightCompression.compress() and then run a benchmark with whowhatbench right afterwards, how should I store the metrics?
Also, please tell me whether I am going in the right direction; this approach feels a bit odd so far.

class LMWeightCompression(BaseTestPipeline):
...

    def compress(self) -> None:
        if self.backend == BackendType.FP32:
            return
        elif self.backend == BackendType.TORCH:
            start_time = time.perf_counter()
            MODEL_ID = "tinyllama/tinyllama-1.1b-step-50k-105b"

            tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_ID)
            self.model = transformers.AutoModelForCausalLM.from_pretrained(
                MODEL_ID, torch_dtype=torch.float16, device_map="cpu"
            )

            text = "The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens."
            token = tokenizer(text, max_length=500, return_tensors="pt", truncation=True)
            inputs = {"input_ids": token["input_ids"], "attention_mask": token["attention_mask"]}

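            # Pass the function and its arguments as a tuple so that memory_usage itself runs and profiles it.
            self.run_info.compression_memory_usage = memory_usage((self._compress_torch, (inputs,)), max_usage=True)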
            self.run_info.time_compression = time.perf_counter() - start_time

            return

        print("Weight compression...")
        start_time = time.perf_counter()
        self.run_info.compression_memory_usage = memory_usage(self._compress, max_usage=True)
        self.run_info.time_compression = time.perf_counter() - start_time

    def _compress_torch(self, inputs):
        self.compressed_model = nncf.compress_weights(self.model, dataset=nncf.Dataset([inputs]))

...

@AdiKsOnDev
Contributor


@alexsu52 @AlexanderDokuchaev following up on the above^

@AlexanderDokuchaev
Collaborator

Hi @AdiKsOnDev

Add a _validate function to LMWeightCompression that calls the evaluator from whowhatbench.

Example of a _validate function: https://github.com/openvinotoolkit/nncf/blob/develop/tests/post_training/pipelines/image_classification_timm.py#L127

Metrics should be stored in self.run_info:

self.run_info.metric_name = "Acc@1"
self.run_info.metric_value = acc_top1
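The same pattern can be sketched in isolation. In this minimal example, RunInfo, ToyLMWeightCompression, the score_fn callable (standing in for the whowhatbench evaluator call), and the "Similarity" metric name are all hypothetical; only the idea of writing metric_name and metric_value into self.run_info comes from the comment above:

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class RunInfo:
    """Hypothetical stand-in for the pipeline's run_info record."""

    metric_name: Optional[str] = None
    metric_value: Optional[float] = None


class ToyLMWeightCompression:
    """Hypothetical pipeline skeleton; score_fn stands in for the whowhatbench evaluator call."""

    def __init__(self, score_fn: Callable[[], float]):
        self.run_info = RunInfo()
        self._score_fn = score_fn

    def _validate(self) -> None:
        similarity = self._score_fn()
        # Store the metric where the test harness is assumed to read it from.
        self.run_info.metric_name = "Similarity"
        self.run_info.metric_value = similarity


pipeline = ToyLMWeightCompression(score_fn=lambda: 0.96)
pipeline._validate()
print(pipeline.run_info)  # RunInfo(metric_name='Similarity', metric_value=0.96)
```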

@AdiKsOnDev
Contributor

OK, thanks for the directions

@AdiKsOnDev
Contributor

AdiKsOnDev commented Apr 12, 2024

@AlexanderDokuchaev _validate(self) already exists in LMWeightCompression:
[Screenshots: the existing _validate method and its Git Blame]

@AdiKsOnDev
Contributor

AdiKsOnDev commented Apr 12, 2024

@AlexanderDokuchaev I added the following code for INT8 support. Do you want me to send a PR?

def compress(self) -> None:
    if self.backend == BackendType.FP32:
        return
    elif self.backend == BackendType.TORCH:
        start_time = time.perf_counter()

        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_id)
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            self.model_id, torch_dtype=torch.float16, device_map="cpu"
        )

        text = "The TinyLlama project aims to pretrain a 1.1B Llama model on 3 trillion tokens."
        token = tokenizer(text, max_length=500, return_tensors="pt", truncation=True)
        inputs = {"input_ids": token["input_ids"], "attention_mask": token["attention_mask"]}

        # Pass the function and its arguments as a tuple so that memory_usage itself runs and profiles it.
        self.run_info.compression_memory_usage = memory_usage((self._compress_torch, (inputs,)), max_usage=True)
        self.run_info.time_compression = time.perf_counter() - start_time

        return

    print("Weight compression...")
    start_time = time.perf_counter()
    self.run_info.compression_memory_usage = memory_usage(self._compress, max_usage=True)
    self.run_info.time_compression = time.perf_counter() - start_time

def _compress_torch(self, inputs):
    self.compressed_model = nncf.compress_weights(self.model, dataset=nncf.Dataset([inputs]))
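A note on the memory_usage call: memory_profiler's memory_usage takes the function to profile either as a callable or as an (f, args, kwargs) tuple. Writing memory_usage(self._compress_torch(inputs), ...) runs the compression before memory_usage starts and hands it the return value (None) instead, so the compression's peak memory is likely not what gets measured. The hypothetical run_profiled helper below only mimics that argument convention; it measures nothing, and unlike it, the real memory_usage may not raise on None:

```python
from typing import Any, Callable, Tuple, Union

calls = []


def run_profiled(proc: Union[Callable[..., Any], Tuple]) -> Any:
    """Hypothetical helper mimicking how memory_usage interprets `proc`.

    Accepts a callable or an (f, args[, kwargs]) tuple; it does not measure memory.
    """
    if callable(proc):
        f, args, kwargs = proc, (), {}
    elif isinstance(proc, tuple) and 1 <= len(proc) <= 3:
        f = proc[0]
        args = proc[1] if len(proc) > 1 else ()
        kwargs = proc[2] if len(proc) > 2 else {}
    else:
        # This toy helper refuses anything else; the real memory_usage may behave differently.
        raise TypeError(f"not a callable or (f, args, kwargs) tuple: {proc!r}")
    calls.append("inside profiler")
    return f(*args, **kwargs)


def compress_torch(inputs: dict) -> None:
    calls.append("compress ran")  # like _compress_torch, it returns None


inputs = {"input_ids": [1, 2, 3]}

run_profiled((compress_torch, (inputs,)))  # compression runs inside the profiled region
print(calls)  # ['inside profiler', 'compress ran']

calls.clear()
try:
    run_profiled(compress_torch(inputs))  # compression runs first; None is passed on
except TypeError as err:
    print(calls, err)  # ['compress ran'] not a callable or (f, args, kwargs) tuple: None
```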

@alexsu52 moved this from Assigned to In Review in Good first issues Apr 18, 2024
alexsu52 added a commit that referenced this issue May 2, 2024
…1b PyTorch model (#2636)

### Changes

- Added the `INT8` compression **test suite** to the `model_scope`
- Added `TORCH` backend support to the `LMWeightCompression` class
- For `INT8` compression, _dataset_, as well as some other parameters (see [model_scope](https://github.com/openvinotoolkit/nncf/blob/f0081037f28af2a829043d4ddaf4902d91864724/tests/post_training/model_scope.py#L329C1-L340C7)), are set to `None`
- [metric_value](https://github.com/openvinotoolkit/nncf/blob/f0081037f28af2a829043d4ddaf4902d91864724/tests/post_training/data/wc_reference_data.yaml#L17C1-L20C15) has been set to **0.95944**
- Mainly use `save_pretrained()` for `TORCH` models
- Omitted a few method calls that are not supported for `TORCH` models (check the commits for details)

### Reason for changes

Benchmarking the changes with `whowhatbench` was requested in issue #2527

### Related tickets

ref: 130788
Closes #2527

### Tests
- Added `INT8` _weight compression_ **conformance** test for
`Tinyllama-1.1b` **PyTorch** model

---------

Co-authored-by: Aleksander <aleksu52@noreply.github.com>
Co-authored-by: Alexander Suslov <alexander.suslov@intel.com>
@github-project-automation bot moved this from In Review to Closed in Good first issues May 2, 2024
Projects
Archived in project
5 participants