AQLM quantizer support #28928

Merged: 24 commits merged into huggingface:main on Feb 14, 2024

Conversation

@BlackSamorez (Contributor) commented on Feb 8, 2024:

What does this PR do?

Fixes Vahe1994/AQLM#11

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@younesbelkada

@BlackSamorez (author):

A Google Colab demo: Mixtral in 2 bits.
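
For reference, a minimal loading sketch in the spirit of that demo; the checkpoint id and from_pretrained kwargs are the ones that appear later in this thread, while the prompt and decoding step are assumptions rather than the notebook's contents:

# Minimal sketch: load the 2-bit AQLM Mixtral and generate a few tokens.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"
tokenizer = AutoTokenizer.from_pretrained(model_id)
quantized_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
)

inputs = tokenizer("Quantization lets you", return_tensors="pt").to(quantized_model.device)
output = quantized_model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output[0]))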


if isinstance(module, nn.Linear):
    # Check if the current key is not in the `linear_weights_not_to_quantize`
    if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
Contributor:

I saw in the config of the model you pushed to the Hub that you also included layer norm weights inside linear_weights_not_to_quantize. I think these can be excluded from the config since they are not instances of nn.Linear, right?

Contributor Author:

They certainly can be excluded. It's just that, when converting from a freshly quantized AQLM checkpoint, it would be troublesome to check whether each unquantized .weight parameter belongs to an nn.Linear module or not, so I simply included all of them just in case. That Mixtral config can, indeed, be made somewhat shorter.
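
For illustration, a hedged sketch of how a conversion script could drop the non-nn.Linear entries, assuming the original (pre-quantization) model is still available to inspect; this helper is hypothetical and not part of the PR:

import torch.nn as nn

def trim_not_to_quantize(model: nn.Module, not_to_quantize: list) -> list:
    # Collect the fully qualified names of all nn.Linear weights in the original model.
    linear_weight_names = {
        f"{name}.weight"
        for name, module in model.named_modules()
        if isinstance(module, nn.Linear)
    }
    # Keep only entries that actually point at an nn.Linear weight; layer norm
    # weights are harmless either way, since the replacement loop only touches nn.Linear.
    return [key for key in not_to_quantize if key in linear_weight_names]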

@younesbelkada (Contributor) left a comment:

Looks pretty clean already! Thanks so much for working on this and for converting Mixtral to the AQLM format; thanks also for sharing the Google Colab. Amazing work!
I assume the method also works on a T4 since you shared a Colab demo; would you mind adding simple tests?
You can copy the tests from AWQ (https://github.com/huggingface/transformers/blob/main/tests/quantization/autoawq/test_awq.py) and have config tests plus very simple model tests that check that the model has been successfully converted to the AQLM format, along with a generation test. I would also add a simple test to make sure the model loads well on CPU.
Can you also share some insights on generation speed for CPU and GPU? 🙏
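
For concreteness, a rough sketch of what such tests could look like, modelled on the AWQ test layout; the class names, assertions, and checkpoint are placeholders, not the tests that eventually landed in this PR:

import unittest

import torch
from transformers import AqlmConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.testing_utils import require_torch_gpu


class AqlmConfigTest(unittest.TestCase):
    def test_to_dict(self):
        # The config should expose its quantization method in the serialized dict.
        config = AqlmConfig()
        self.assertEqual(config.to_dict()["quant_method"], "aqlm")


@require_torch_gpu
class AqlmGenerationTest(unittest.TestCase):
    model_name = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"  # placeholder checkpoint

    def test_quantized_model(self):
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="cuda:0")
        # The linear layers should have been swapped out for AQLM layers.
        self.assertNotIsInstance(model.model.layers[0].self_attn.q_proj, torch.nn.Linear)
        inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
        _ = model.generate(**inputs, max_new_tokens=10)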

@younesbelkada (Contributor) left a comment:

One last thing: could you add installation instructions to our testing Dockerfile: https://github.com/huggingface/transformers/blob/main/docker/transformers-all-latest-gpu/Dockerfile

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@younesbelkada (Contributor):

cc @oobabooga, this might be of interest to you!

@BlackSamorez (author):

I updated the Docker recipe and added tests, but they are skipped because aqlm is not installed in the testing environment.

@younesbelkada (Contributor) left a comment:

Thanks a lot! I left two minor comments. Can you also run make fixup? It will redirect you to run make fix-copies, which should fix the tests.

return True


def _replace_with_aqlm_linear(
Contributor:

You can move this method, make it public under integrations/aqlm.py, and import it locally inside _process_model_before_weight_loading.
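
A hedged sketch of the suggested structure; the class name, helper signature, and surrounding quantizer code are assumptions, not the exact code that landed in this PR:

# The public helper would live in integrations/aqlm.py and be imported lazily,
# so that aqlm stays an optional dependency of transformers.
class AqlmHfQuantizerSketch:  # illustrative stub, not the real quantizer class
    def __init__(self, quantization_config):
        self.quantization_config = quantization_config

    def _process_model_before_weight_loading(self, model, **kwargs):
        # Local import: only touched when an AQLM-quantized model is actually loaded.
        from transformers.integrations import replace_with_aqlm_linear

        replace_with_aqlm_linear(
            model,
            quantization_config=self.quantization_config,
            linear_weights_not_to_quantize=getattr(
                self.quantization_config, "linear_weights_not_to_quantize", None
            ),
        )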

Contributor Author:

Done



@require_torch_gpu
class AwqConfigTest(unittest.TestCase):
Contributor:

Suggested change:
-class AwqConfigTest(unittest.TestCase):
+class AqlmConfigTest(unittest.TestCase):

Contributor Author:

Done

@BlackSamorez (author):

I'm pretty sure the failing tests have nothing to do with my PR.

@younesbelkada (Contributor) left a comment:

Looks very clean! Thanks so much for the integration!
As discussed offline with @ArthurZucker, we could leverage # Copied from in the tests, but this is clearly not a blocker for me and we can merge as is. Looking forward to the integration!

@ArthurZucker (Collaborator) left a comment:

LGTM thanks for the clean PR 🤗

Resolved review thread (outdated): src/transformers/utils/quantization_config.py
@younesbelkada (Contributor) left a comment:

I also left one nit! I tested it and it seems to work only with Python >= 3.10.
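
For context, a common cause of ">= 3.10 only" breakage is PEP 604 union syntax in annotations; whether that was the exact issue here is an assumption (the commit history only mentions "legacy typing"), and the function name below is illustrative:

from typing import Optional

# PEP 604 syntax raises a TypeError at function-definition time on Python 3.9 and below:
#     def dequantize(scale: float | None = None): ...

# Legacy spelling that imports fine on older interpreters:
def dequantize(scale: Optional[float] = None) -> Optional[float]:
    return scale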

Resolved review thread (outdated): docs/source/en/quantization.md
@younesbelkada (Contributor):

@BlackSamorez on a Google Colab env the inference script works great; however, on my VM, in a Python 3.10 env with the latest torch + CUDA 11.8, I constantly get:

Traceback (most recent call last):
  File "/transformers/scratch.py", line 11, in <module>
    output = quantized_model.generate(tokenizer("", return_tensors="pt")["input_ids"].cuda(), max_new_tokens=10)
  File "/miniconda3/envs/aqlm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/transformers/src/transformers/generation/utils.py", line 1495, in generate
    return self.greedy_search(
  File "/transformers/src/transformers/generation/utils.py", line 2366, in greedy_search
    next_tokens = torch.argmax(next_tokens_scores, dim=-1)
RuntimeError: CUDA error: device kernel image is invalid
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Do you have an idea what might be wrong here?

@younesbelkada (Contributor):

The only difference I see between the Colab instance and mine is the CUDA version; I'll update it to 12.1 and loop back here.

BlackSamorez and others added 2 commits February 13, 2024 11:48
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
@BlackSamorez (author):

@younesbelkada The kernel is compiled at runtime with:

import os
from typing import Optional

import torch
from torch.utils.cpp_extension import load

CUDA_FOLDER = os.path.dirname(os.path.abspath(__file__))
CUDA_KERNEL = load(
    name="codebook_cuda",
    sources=[os.path.join(CUDA_FOLDER, "cuda_kernel.cpp"), os.path.join(CUDA_FOLDER, "cuda_kernel.cu")],
)

Maybe your nvcc is sourced from an incorrect CUDA installation. I'm not really sure how to test that; maybe you could try specifying the nvcc path with an environment variable.
I'll try to reproduce it as well.
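
One way to check which toolkit the runtime build would pick up (torch.utils.cpp_extension resolves the compiler via CUDA_HOME and the nvcc on PATH); the override path in the comment is only an example:

import shutil

import torch
from torch.utils.cpp_extension import CUDA_HOME

print("torch built against CUDA:", torch.version.cuda)       # CUDA version of the torch wheel
print("CUDA_HOME used for extension builds:", CUDA_HOME)      # toolkit cpp_extension will compile with
print("nvcc on PATH:", shutil.which("nvcc"))

# If these disagree, pointing CUDA_HOME at the matching toolkit before importing
# aqlm is one thing to try, e.g.:
#     import os; os.environ["CUDA_HOME"] = "/usr/local/cuda-11.8"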

@BlackSamorez (author):

CUDA 11.8 seems to work fine on my machine on an A100 GPU.

@BlackSamorez (author) commented on Feb 13, 2024:

FYI: I've released aqlm version 1.0.1, where I added device guards to fix CUDA errors when running in a multi-GPU setup. I've added the corresponding tests, similar to the autoawq ones.
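
For context, the device-guard pattern referred to here looks roughly like the following; the function is a stand-in using a plain torch op, not aqlm's actual kernel wrapper:

import torch

def guarded_linear(inputs: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Make sure the op launches on the GPU that holds the inputs, even when
    # torch's current device is a different GPU (the usual multi-GPU failure
    # mode under device_map="auto").
    with torch.cuda.device(inputs.device):
        return torch.nn.functional.linear(inputs, weight)  # stand-in for the custom CUDA kernel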

@SunMarc (Member) left a comment:

Thanks for the integration!

Resolved review thread (outdated): src/transformers/integrations/aqlm.py
Resolved review thread (outdated): tests/quantization/aqlm_integration/test_aqlm.py
BlackSamorez and others added 2 commits February 13, 2024 19:17
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
@BlackSamorez (author):

Looks like some network error occurred.

@younesbelkada (Contributor) left a comment:

Thanks a lot @BlackSamorez!

@ArthurZucker merged commit 1ecf5f7 into huggingface:main on Feb 14, 2024
21 of 22 checks passed
@ArthurZucker (Collaborator):

🤗 🚀

sbucaille pushed a commit to sbucaille/transformers that referenced this pull request Feb 14, 2024
* aqlm init

* calibration and dtypes

* docs

* Readme update

* is_aqlm_available

* Simpler link in docs

* Test TODO real reference

* init _import_structure fix

* AqlmConfig autodoc

* integration aqlm

* integrations in tests

* docstring fix

* legacy typing

* Less typings

* More kernels information

* Performance -> Accuracy

* correct tests

* removed multi-gpu test

* Update docs/source/en/quantization.md

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Brought back multi-gpu tests

* Update src/transformers/integrations/aqlm.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/aqlm_integration/test_aqlm.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Andrei Panferov <blacksamorez@yandex-team.ru>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
@younesbelkada (Contributor):

Hi @BlackSamorez!
Thanks again for your great work! I was wondering if you could update the installation cell in the shared notebook to install transformers from source instead of from your fork; that way we could catch potential bugs before the release 🙏
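
A possible install cell for the notebook; the aqlm extras name and version pin are assumptions and should be adjusted to whatever the package actually ships:

# Hypothetical Colab install cell (IPython shell escapes).
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q "aqlm[gpu]>=1.0.1"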

@BlackSamorez (author):

@younesbelkada Looks like this commit (outside of this PR) broke something:

AttributeError                            Traceback (most recent call last)

[<ipython-input-2-68b1b199d504>](https://localhost:8080/#) in <cell line: 3>()
      1 from transformers import AutoTokenizer, AutoModelForCausalLM
      2 
----> 3 quantized_model = AutoModelForCausalLM.from_pretrained(
      4     "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch",
      5     torch_dtype="auto", device_map="auto", low_cpu_mem_usage=True,

4 frames

[/usr/local/lib/python3.10/dist-packages/transformers/models/auto/auto_factory.py](https://localhost:8080/#) in from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs)
    565         elif type(config) in cls._model_mapping.keys():
    566             model_class = _get_model_class(config, cls._model_mapping)
--> 567             return model_class.from_pretrained(
    568                 pretrained_model_name_or_path, *model_args, config=config, **hub_kwargs, **kwargs
    569             )

[/usr/local/lib/python3.10/dist-packages/transformers/modeling_utils.py](https://localhost:8080/#) in from_pretrained(cls, pretrained_model_name_or_path, config, cache_dir, ignore_mismatched_sizes, force_download, local_files_only, token, revision, use_safetensors, *model_args, **kwargs)
   3561 
   3562         if hf_quantizer is not None:
-> 3563             hf_quantizer.postprocess_model(model)
   3564             model.hf_quantizer = hf_quantizer
   3565 

[/usr/local/lib/python3.10/dist-packages/transformers/quantizers/base.py](https://localhost:8080/#) in postprocess_model(self, model, **kwargs)
    177                 The keyword arguments that are passed along `_process_model_after_weight_loading`.
    178         """
--> 179         return self._process_model_after_weight_loading(model, **kwargs)
    180 
    181     @abstractmethod

[/usr/local/lib/python3.10/dist-packages/transformers/quantizers/quantizer_aqlm.py](https://localhost:8080/#) in _process_model_after_weight_loading(self, model, **kwargs)
     78 
     79     def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
---> 80         model._is_quantized_training_enabled = False
     81         return model
     82 

[/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in __setattr__(self, name, value)
   1745                     buffers[name] = value
   1746                 else:
-> 1747                     super().__setattr__(name, value)
   1748 
   1749     def __delattr__(self, name):

AttributeError: can't set attribute '_is_quantized_training_enabled'
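
A minimal reproduction of that failure mode, assuming the breaking commit turned _is_quantized_training_enabled into a read-only property on the model class, so the quantizer's _process_model_after_weight_loading can no longer assign the flag directly:

import torch.nn as nn

class FakePreTrainedModel(nn.Module):
    # Assumed setup: a property with no setter shadowing the old instance attribute.
    @property
    def _is_quantized_training_enabled(self):
        return False

model = FakePreTrainedModel()
model._is_quantized_training_enabled = False  # AttributeError: can't set attribute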

Successfully merging this pull request may close: Integration with HF transformers (Vahe1994/AQLM#11)