3 changes: 1 addition & 2 deletions src/transformers/generation/configuration_utils.py
@@ -379,8 +379,7 @@ class GenerationConfig(PushToHubMixin):
If using a static cache, this controls how `generate` will `compile` the forward pass for performance
gains.

- disable_compile (`bool`, *optional*): Whether to disable the compilation of the forward pass when using 'statis' cache
-     implementation.
+ disable_compile (`bool`, *optional*): Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when specific criteria are met, including using a compileable cache. Please open an issue if you find the need to use this flag.

> Wild card

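With this change, `disable_compile` can be set on the generation config, or passed directly to `generate` (which forwards unknown keyword arguments to `GenerationConfig`), to opt out of the automatic compilation. A minimal sketch of how a user might exercise the flag, reusing the tiny test model from the tests below:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    inputs = tokenizer("Hello my name is", return_tensors="pt")

    # A static cache is compileable, so generate would normally compile the forward
    # pass automatically; disable_compile=True opts out of that behaviour.
    output = model.generate(
        **inputs,
        max_new_tokens=10,
        cache_implementation="static",
        disable_compile=True,
    )
    print(tokenizer.decode(output[0], skip_special_tokens=True))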
5 changes: 3 additions & 2 deletions src/transformers/generation/utils.py
@@ -1586,7 +1586,6 @@ def _prepare_generation_config(
generation_config.decoder_start_token_id = self.generation_config.decoder_start_token_id
else:
model_kwargs = kwargs

return generation_config, model_kwargs

def _get_initial_cache_position(self, input_ids, model_kwargs):
@@ -3252,7 +3251,9 @@ def _sample(
model_forward = self.__call__
if isinstance(model_kwargs.get("past_key_values"), Cache):
is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
- is_compileable = is_compileable and not self.generation_config.disable_compile
+ if getattr(self, "hf_quantizer", None) is not None:
+     is_compileable &= self.hf_quantizer.is_compileable
+ is_compileable = is_compileable and not generation_config.disable_compile
if is_compileable and (
self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
):
5 changes: 5 additions & 0 deletions src/transformers/quantizers/base.py
@@ -271,6 +271,11 @@ def is_qat_trainable(self) -> bool:
"""Flag indicating whether the quantized model can carry out quantization aware training"""
return False

@property
def is_compileable(self) -> bool:
"""Flag indicating whether the quantized model can be compiled"""
return False

@abstractmethod
def _process_model_before_weight_loading(self, model, **kwargs): ...

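The base class defaults `is_compileable` to `False`, so each quantization backend has to opt in explicitly. A hedged sketch of what that override looks like for a hypothetical quantizer (the torchao quantizer below does exactly this); the other abstract members of `HfQuantizer` are omitted for brevity:

    from transformers.quantizers.base import HfQuantizer

    class MyQuantizer(HfQuantizer):
        # ... the remaining abstract methods (_process_model_before_weight_loading,
        # _process_model_after_weight_loading, is_trainable, ...) still need bodies.

        @property
        def is_compileable(self) -> bool:
            # Opt this (hypothetical) backend into the compiled generate path.
            return True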
4 changes: 4 additions & 0 deletions src/transformers/quantizers/quantizer_torchao.py
@@ -243,3 +243,7 @@ def is_trainable(self):
"int8_dynamic_activation_int8_weight",
]
return self.quantization_config.quant_type in supported_quant_types_for_training

@property
def is_compileable(self) -> bool:
return True
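Since the torchao quantizer now reports `is_compileable = True`, a torchao-quantized model should take the compiled `generate` path automatically whenever a compileable cache is used. A rough sketch, assuming torchao is installed and reusing the tiny test model name from the tests below:

    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

    model_id = "hf-internal-testing/tiny-random-LlamaForCausalLM"
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=TorchAoConfig("int8_weight_only"),
        device_map="auto",
    )
    inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)

    # With a compileable (static) cache and hf_quantizer.is_compileable == True,
    # generate compiles the forward pass on CUDA without any extra flags.
    output = model.generate(**inputs, max_new_tokens=10, cache_implementation="static")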
33 changes: 33 additions & 0 deletions tests/quantization/bnb/test_4bit.py
@@ -771,3 +771,36 @@ def test_set_load_in_8_bit(self):
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
quantization_config.load_in_8bit = True


@require_bitsandbytes
@require_accelerate
@require_torch_gpu_if_bnb_not_multi_backend_enabled
@slow
@apply_skip_if_not_implemented
class Bnb4bitCompile(unittest.TestCase):
model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
input_text = "Hello my name is"

def setUp(self):
# Models and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True)

def test_generate_compile(self):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

# if nothing is set, compile will be disabled for bnb
self.model_4bit.generate(
input_ids=encoded_input["input_ids"].to(self.model_4bit.device),
max_new_tokens=10,
cache_implementation="static",
)
with self.assertRaises(Exception):
# overwrite property
object.__setattr__(self.model_4bit.hf_quantizer, "is_compileable", True)
self.model_4bit.generate(
input_ids=encoded_input["input_ids"].to(self.model_4bit.device),
max_new_tokens=10,
cache_implementation="static",
)
34 changes: 34 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -966,3 +966,37 @@ def test_int8_from_pretrained(self):
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)

self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)


@require_bitsandbytes
@require_accelerate
@require_torch
@require_torch_gpu_if_bnb_not_multi_backend_enabled
@slow
@apply_skip_if_not_implemented
class Bnb8bitCompile(unittest.TestCase):
model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
input_text = "Hello my name is"

def setUp(self):
# Models and tokenizer
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True)

def test_generate_compile(self):
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

# if nothing is set, compile will be disabled for bnb
self.model_8bit.generate(
input_ids=encoded_input["input_ids"].to(self.model_8bit.device),
max_new_tokens=10,
cache_implementation="static",
)

with self.assertRaises(Exception):
object.__setattr__(self.model_8bit.hf_quantizer, "is_compileable", True)
self.model_8bit.generate(
input_ids=encoded_input["input_ids"].to(self.model_8bit.device),
max_new_tokens=10,
cache_implementation="static",
)