diff --git a/tests/models/cohere2/test_modeling_cohere2.py b/tests/models/cohere2/test_modeling_cohere2.py index 1bb3849d71d6..1b126e0ccd8d 100644 --- a/tests/models/cohere2/test_modeling_cohere2.py +++ b/tests/models/cohere2/test_modeling_cohere2.py @@ -34,8 +34,7 @@ torch_device, ) -from ...models.cohere.test_modeling_cohere import CohereModelTest, CohereModelTester -from ...test_configuration_common import ConfigTester +from ...models.cohere.test_modeling_cohere import CohereModelTester if is_torch_available(): @@ -46,6 +45,11 @@ Cohere2Model, ) +from ...generation.test_utils import GenerationTesterMixin +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin +from ...test_pipeline_mixin import PipelineTesterMixin + class Cohere2ModelTester(CohereModelTester): config_class = Cohere2Config @@ -55,7 +59,7 @@ class Cohere2ModelTester(CohereModelTester): @require_torch -class Cohere2ModelTest(CohereModelTest, unittest.TestCase): +class Cohere2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (Cohere2Model, Cohere2ForCausalLM) if is_torch_available() else () pipeline_model_mapping = ( { @@ -67,10 +71,21 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase): ) _is_stateful = True + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + def setUp(self): self.model_tester = Cohere2ModelTester(self) self.config_tester = ConfigTester(self, config_class=Cohere2Config, hidden_size=37) + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + @slow @require_read_token @@ -269,6 +284,3 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str): output_text = tokenizer.batch_decode(out) self.assertEqual(output_text, EXPECTED_COMPLETIONS) - - -del CohereModelTest, CohereModelTester # So the parent tests don't run in this file too diff --git a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py index 69e1a67047d3..47ede51be516 100644 --- a/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py +++ b/tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py @@ -14,16 +14,21 @@ # limitations under the License. """Testing suite for the PyTorch GraniteMoeHybrid model.""" +import inspect +import tempfile import unittest import pytest +from pytest import mark from transformers import ( AutoTokenizer, + DataCollatorWithFlattening, GraniteMoeHybridConfig, is_torch_available, ) from transformers.testing_utils import ( + require_flash_attn, require_torch, require_torch_gpu, slow, @@ -31,7 +36,10 @@ ) from ...generation.test_utils import GenerationTesterMixin -from ...models.bamba.test_modeling_bamba import BambaModelTest, BambaModelTester +from ...models.bamba.test_modeling_bamba import BambaModelTester +from ...test_configuration_common import ConfigTester +from ...test_modeling_common import ModelTesterMixin +from ...test_pipeline_mixin import PipelineTesterMixin if is_torch_available(): @@ -77,7 +85,7 @@ def get_config(self): @require_torch -class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest.TestCase): +class GraniteMoeHybridModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase): model_tester_class = GraniteMoeHybridModelTester all_model_classes = ( ( @@ -96,6 +104,225 @@ class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest. else {} ) + # Need to use `0.8` instead of `0.9` for `test_cpu_offload` + # This is because we are hitting edge cases with the causal_mask buffer + model_split_percents = [0.5, 0.7, 0.8] + + def _check_caches_are_equal( + self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache + ): + if not isinstance(cache1, HybridMambaAttentionDynamicCache) or not isinstance( + cache2, HybridMambaAttentionDynamicCache + ): + raise ValueError("The wrong cache is being used!") + + if not len(cache1) == len(cache2): + raise ValueError("Both caches do not have the same number of layers.") + + num_layers = len(cache1) + for idx in range(num_layers): + torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx]) + torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx]) + torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx]) + torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx]) + + def setUp(self): + self.model_tester = self.model_tester_class(self) + self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64) + + def test_config(self): + self.config_tester.run_common_tests() + + def test_model(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_model(*config_and_inputs) + + def test_for_causal_lm(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_for_causal_lm(*config_and_inputs) + + def test_decoder_model_past_with_large_inputs(self): + config_and_inputs = self.model_tester.prepare_config_and_inputs() + self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs) + + def test_attention_outputs(self): + r""" + Overriding the test_attention_outputs test as the Bamba model outputs attention only for its attention layers + """ + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + config.return_dict = True + + seq_len = getattr(self.model_tester, "seq_length", None) + encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len) + encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length) + + expected_num_attentions = self.model_tester.num_hidden_layers - len(self.model_tester.attn_layer_indices) + + for model_class in self.all_model_classes: + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = False + config.return_dict = True + model = model_class._from_config(config, attn_implementation="eager") + config = model.config + model.to(torch_device) + model.eval() + + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + # check that output_attentions also work using config + del inputs_dict["output_attentions"] + config.output_attentions = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + attentions = outputs.attentions + self.assertEqual(len(attentions), expected_num_attentions) + + self.assertListEqual( + list(attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + out_len = len(outputs) + + # Check attention is always last and order is fine + inputs_dict["output_attentions"] = True + inputs_dict["output_hidden_states"] = True + model = model_class(config) + model.to(torch_device) + model.eval() + with torch.no_grad(): + outputs = model(**self._prepare_for_class(inputs_dict, model_class)) + + added_hidden_states = 1 + self.assertEqual(out_len + added_hidden_states, len(outputs)) + + self_attentions = outputs.attentions + + self.assertEqual(len(self_attentions), expected_num_attentions) + self.assertListEqual( + list(self_attentions[0].shape[-3:]), + [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length], + ) + + def test_batching_equivalence(self): + # need to disable the tril input mask + orig = self.model_tester.use_input_mask + self.model_tester.use_input_mask = False + super().test_batching_equivalence() + self.model_tester.use_input_mask = orig + + @pytest.mark.generate + def test_left_padding_compatibility(self): + # TODO: document why a random attention mask causes this test to fail, but a full mask doesn't + unpadded_custom_inputs = {"attention_mask": None} + super().test_left_padding_compatibility(unpadded_custom_inputs=unpadded_custom_inputs) + + @unittest.skip( + "Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training." + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self): + pass + + @unittest.skip( + "Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training." + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self): + pass + + @require_flash_attn + @require_torch_gpu + @mark.flash_attn_test + @slow + @unittest.skip( + "NotImplementedError: seq_idx support requires fast path support. Please install mamba_ssm and causal_conv1d" + ) + def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_idx_and_fa_kwargs(self): + if not self.has_attentions: + self.skipTest(reason="Model architecture does not support attentions") + + max_new_tokens = 30 + + for model_class in self.all_generative_model_classes: + if not model_class._supports_flash_attn: + self.skipTest(f"{model_class.__name__} does not support Flash Attention 2") + + config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() + if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict: + self.skipTest("Model dummy inputs should contain padding in their attention mask") + + dummy_input = inputs_dict[model_class.main_input_name] + if dummy_input.dtype in [torch.float32, torch.bfloat16]: + dummy_input = dummy_input.to(torch.float16) + + # make sure that all models have enough positions for generation + if hasattr(config, "max_position_embeddings"): + config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1 + + model = model_class(config) + if "position_ids" not in inspect.signature(model.forward).parameters: + self.skipTest("Model does not support position_ids") + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + + # ensure left padding, to adapt for some models + if 0 in inputs_dict["attention_mask"][:, -1]: + inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1) + dummy_attention_mask = inputs_dict["attention_mask"] + inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id + # Ensure inputs_dict also has labels in it, as their presence/absence can induce + # dtype conversions. This also lets us compare losses. + labels = inputs_dict["input_ids"].clone() + # Mask padding tokens + labels[~dummy_attention_mask.bool()] = -100 + # Also need to mask the first non-trivial token to match the padding-free batch. + first_nonneg_idx = (labels >= 0).int().argmax(dim=1) + labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100 + inputs_dict["labels"] = labels + + model = ( + model_class.from_pretrained( + tmpdirname, + dtype=torch.float16, + attn_implementation="flash_attention_2", + ) + .to(torch_device) + .eval() + ) + + # flatten + features = [ + {"input_ids": i[a.bool()].tolist()} + for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"]) + ] + + # add position_ids + fa_kwargs + seq_idx + data_collator = DataCollatorWithFlattening( + return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True + ) + batch = data_collator(features) + batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()} + + res_padded = model(**inputs_dict) + res_padfree = model(**batch_accelerator) + + logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()] + logits_padfree = res_padfree.logits[0] + + torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0) + # acceptable numerical instability + tol = torch.finfo(torch.float16).eps + torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol) + + loss_padded = res_padded.loss + loss_padfree = res_padfree.loss + torch.testing.assert_close(loss_padded, loss_padfree) + def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config): self.assertIsInstance(past_key_values, HybridMambaAttentionDynamicCache) @@ -178,6 +405,3 @@ def test_model_generation(self): text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) self.assertEqual(EXPECTED_TEXT_COMPLETION, text) - - -del BambaModelTest, BambaModelTester # So the parent tests don't run in this file too