24 changes: 18 additions & 6 deletions tests/models/cohere2/test_modeling_cohere2.py
@@ -34,8 +34,7 @@
torch_device,
)

from ...models.cohere.test_modeling_cohere import CohereModelTest, CohereModelTester
from ...test_configuration_common import ConfigTester
from ...models.cohere.test_modeling_cohere import CohereModelTester


if is_torch_available():
@@ -46,6 +45,11 @@
Cohere2Model,
)

from ...generation.test_utils import GenerationTesterMixin
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin
from ...test_pipeline_mixin import PipelineTesterMixin


class Cohere2ModelTester(CohereModelTester):
config_class = Cohere2Config
@@ -55,7 +59,7 @@ class Cohere2ModelTester(CohereModelTester):


@require_torch
class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
class Cohere2ModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (Cohere2Model, Cohere2ForCausalLM) if is_torch_available() else ()
pipeline_model_mapping = (
{
@@ -67,10 +71,21 @@ class Cohere2ModelTest(CohereModelTest, unittest.TestCase):
)
_is_stateful = True

# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]

def setUp(self):
self.model_tester = Cohere2ModelTester(self)
self.config_tester = ConfigTester(self, config_class=Cohere2Config, hidden_size=37)

def test_config(self):
self.config_tester.run_common_tests()

def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)


@slow
@require_read_token
@@ -269,6 +284,3 @@ def test_generation_beyond_sliding_window(self, attn_implementation: str):
output_text = tokenizer.batch_decode(out)

self.assertEqual(output_text, EXPECTED_COMPLETIONS)


del CohereModelTest, CohereModelTester # So the parent tests don't run in this file too
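For context on the Cohere2 change above, here is a minimal, self-contained toy sketch (stand-in names, not the real transformers classes) of the two patterns involved. Subclassing a sibling model's concrete TestCase leaves that class bound at module level, so unittest discovery re-runs the parent's tests in this file unless the imported names are deleted, which is what the removed `del` line worked around; inheriting the generic mixins directly avoids the problem.

import unittest


# Stand-ins for the shared mixins in tests/test_modeling_common.py and
# tests/generation/test_utils.py.
class ModelTesterMixin:
    pass


class GenerationTesterMixin:
    pass


# Stand-ins for the sibling model's helpers, which the real file imports from
# tests/models/cohere/test_modeling_cohere.py.
class CohereLikeModelTester:
    def __init__(self, parent):
        self.parent = parent


class CohereLikeModelTest(unittest.TestCase):
    def test_cohere_specific_behaviour(self):
        self.assertTrue(True)


# Old pattern: subclass the sibling's concrete TestCase.  Any TestCase bound at
# module level (defined or imported) is collected by unittest discovery, so the
# original file had to end with `del CohereModelTest, CohereModelTester` to keep
# the parent's tests from running here a second time.
class OldStyleCohere2ModelTest(CohereLikeModelTest):
    pass


# New pattern (what the diff switches to): inherit the generic mixins directly
# and only reuse the tester helper, so no trailing `del` is needed.
class NewStyleCohere2ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
    def setUp(self):
        self.model_tester = CohereLikeModelTester(self)

    def test_model_tester_is_set_up(self):
        self.assertIsNotNone(self.model_tester)


if __name__ == "__main__":
    unittest.main(verbosity=2)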
234 changes: 229 additions & 5 deletions tests/models/granitemoehybrid/test_modeling_granitemoehybrid.py
@@ -14,24 +14,32 @@
# limitations under the License.
"""Testing suite for the PyTorch GraniteMoeHybrid model."""

import inspect
import tempfile
import unittest

import pytest
from pytest import mark

from transformers import (
AutoTokenizer,
DataCollatorWithFlattening,
GraniteMoeHybridConfig,
is_torch_available,
)
from transformers.testing_utils import (
require_flash_attn,
require_torch,
require_torch_gpu,
slow,
torch_device,
)

from ...generation.test_utils import GenerationTesterMixin
from ...models.bamba.test_modeling_bamba import BambaModelTest, BambaModelTester
from ...models.bamba.test_modeling_bamba import BambaModelTester
from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin
from ...test_pipeline_mixin import PipelineTesterMixin


if is_torch_available():
@@ -77,7 +85,7 @@ def get_config(self):


@require_torch
class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest.TestCase):
class GraniteMoeHybridModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
Reviewer comment (Contributor): Perhaps add a comment explaining why it is better not to inherit from another test, so that anyone tempted to do the same has a hint.

model_tester_class = GraniteMoeHybridModelTester
all_model_classes = (
(
@@ -96,6 +104,225 @@ class GraniteMoeHybridModelTest(BambaModelTest, GenerationTesterMixin, unittest.TestCase):
else {}
)

# Need to use `0.8` instead of `0.9` for `test_cpu_offload`
# This is because we are hitting edge cases with the causal_mask buffer
model_split_percents = [0.5, 0.7, 0.8]

def _check_caches_are_equal(
self, cache1: HybridMambaAttentionDynamicCache, cache2: HybridMambaAttentionDynamicCache
):
if not isinstance(cache1, HybridMambaAttentionDynamicCache) or not isinstance(
cache2, HybridMambaAttentionDynamicCache
):
raise ValueError("The wrong cache is being used!")

if not len(cache1) == len(cache2):
raise ValueError("Both caches do not have the same number of layers.")

num_layers = len(cache1)
for idx in range(num_layers):
torch.testing.assert_close(cache1.key_cache[idx], cache2.key_cache[idx])
torch.testing.assert_close(cache1.value_cache[idx], cache2.value_cache[idx])
torch.testing.assert_close(cache1.conv_states[idx], cache2.conv_states[idx])
torch.testing.assert_close(cache1.ssm_states[idx], cache2.ssm_states[idx])

def setUp(self):
self.model_tester = self.model_tester_class(self)
self.config_tester = ConfigTester(self, config_class=self.model_tester.config_class, hidden_size=64)

def test_config(self):
self.config_tester.run_common_tests()

def test_model(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)

def test_for_causal_lm(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_for_causal_lm(*config_and_inputs)

def test_decoder_model_past_with_large_inputs(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

def test_attention_outputs(self):
r"""
Overriding the test_attention_outputs test, as the GraniteMoeHybrid model (like Bamba) outputs attentions only for its attention layers
"""
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
config.return_dict = True

seq_len = getattr(self.model_tester, "seq_length", None)
encoder_seq_length = getattr(self.model_tester, "encoder_seq_length", seq_len)
encoder_key_length = getattr(self.model_tester, "key_length", encoder_seq_length)

expected_num_attentions = self.model_tester.num_hidden_layers - len(self.model_tester.attn_layer_indices)

for model_class in self.all_model_classes:
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = False
config.return_dict = True
model = model_class._from_config(config, attn_implementation="eager")
config = model.config
model.to(torch_device)
model.eval()

with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), expected_num_attentions)

# check that output_attentions also work using config
del inputs_dict["output_attentions"]
config.output_attentions = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))
attentions = outputs.attentions
self.assertEqual(len(attentions), expected_num_attentions)

self.assertListEqual(
list(attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
out_len = len(outputs)

# Check attention is always last and order is fine
inputs_dict["output_attentions"] = True
inputs_dict["output_hidden_states"] = True
model = model_class(config)
model.to(torch_device)
model.eval()
with torch.no_grad():
outputs = model(**self._prepare_for_class(inputs_dict, model_class))

added_hidden_states = 1
self.assertEqual(out_len + added_hidden_states, len(outputs))

self_attentions = outputs.attentions

self.assertEqual(len(self_attentions), expected_num_attentions)
self.assertListEqual(
list(self_attentions[0].shape[-3:]),
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)

def test_batching_equivalence(self):
# need to disable the tril input mask
orig = self.model_tester.use_input_mask
self.model_tester.use_input_mask = False
super().test_batching_equivalence()
self.model_tester.use_input_mask = orig

@pytest.mark.generate
def test_left_padding_compatibility(self):
# TODO: document why a random attention mask causes this test to fail, but a full mask doesn't
unpadded_custom_inputs = {"attention_mask": None}
super().test_left_padding_compatibility(unpadded_custom_inputs=unpadded_custom_inputs)

@unittest.skip(
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids(self):
pass

@unittest.skip(
"Bamba requires additionally specifying position_ids, seq_idx, and FlashAttentionKwargs for padding-free training."
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_and_fa_kwargs(self):
pass

@require_flash_attn
@require_torch_gpu
@mark.flash_attn_test
@slow
@unittest.skip(
"NotImplementedError: seq_idx support requires fast path support. Please install mamba_ssm and causal_conv1d"
)
def test_flash_attention_2_padding_matches_padding_free_with_position_ids_seq_idx_and_fa_kwargs(self):
if not self.has_attentions:
self.skipTest(reason="Model architecture does not support attentions")

max_new_tokens = 30

for model_class in self.all_generative_model_classes:
if not model_class._supports_flash_attn:
self.skipTest(f"{model_class.__name__} does not support Flash Attention 2")

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
if 0 not in inputs_dict.get("attention_mask", []) or "attention_mask" not in inputs_dict:
self.skipTest("Model dummy inputs should contain padding in their attention mask")

dummy_input = inputs_dict[model_class.main_input_name]
if dummy_input.dtype in [torch.float32, torch.bfloat16]:
dummy_input = dummy_input.to(torch.float16)

# make sure that all models have enough positions for generation
if hasattr(config, "max_position_embeddings"):
config.max_position_embeddings = max_new_tokens + dummy_input.shape[1] + 1

model = model_class(config)
if "position_ids" not in inspect.signature(model.forward).parameters:
self.skipTest("Model does not support position_ids")

with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)

# ensure left padding, to adapt for some models
if 0 in inputs_dict["attention_mask"][:, -1]:
inputs_dict["attention_mask"] = inputs_dict["attention_mask"].flip(1)
dummy_attention_mask = inputs_dict["attention_mask"]
inputs_dict["input_ids"][~dummy_attention_mask.bool()] = config.get_text_config().pad_token_id
# Ensure inputs_dict also has labels in it, as their presence/absence can induce
# dtype conversions. This also lets us compare losses.
labels = inputs_dict["input_ids"].clone()
# Mask padding tokens
labels[~dummy_attention_mask.bool()] = -100
# Also need to mask the first non-trivial token to match the padding-free batch.
first_nonneg_idx = (labels >= 0).int().argmax(dim=1)
labels[torch.arange(labels.size(0), device=labels.device), first_nonneg_idx] = -100
inputs_dict["labels"] = labels

model = (
model_class.from_pretrained(
tmpdirname,
dtype=torch.float16,
attn_implementation="flash_attention_2",
)
.to(torch_device)
.eval()
)

# flatten
features = [
{"input_ids": i[a.bool()].tolist()}
for i, a in zip(inputs_dict["input_ids"], inputs_dict["attention_mask"])
]

# add position_ids + fa_kwargs + seq_idx
data_collator = DataCollatorWithFlattening(
return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
)
batch = data_collator(features)
batch_accelerator = {k: t.to(torch_device) if torch.is_tensor(t) else t for k, t in batch.items()}

res_padded = model(**inputs_dict)
res_padfree = model(**batch_accelerator)

logits_padded = res_padded.logits[inputs_dict["attention_mask"].bool()]
logits_padfree = res_padfree.logits[0]

torch.testing.assert_close(logits_padded.argmax(-1), logits_padfree.argmax(-1), rtol=0, atol=0)
# acceptable numerical instability
tol = torch.finfo(torch.float16).eps
torch.testing.assert_close(logits_padded, logits_padfree, rtol=tol, atol=tol)

loss_padded = res_padded.loss
loss_padfree = res_padfree.loss
torch.testing.assert_close(loss_padded, loss_padfree)

def _check_past_key_values_for_generate(self, batch_size, past_key_values, seq_length, config):
self.assertIsInstance(past_key_values, HybridMambaAttentionDynamicCache)

@@ -178,6 +405,3 @@ def test_model_generation(self):
text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

self.assertEqual(EXPECTED_TEXT_COMPLETION, text)


del BambaModelTest, BambaModelTester # So the parent tests don't run in this file too
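One more sketch for the padding-free test added above: how DataCollatorWithFlattening packs a batch. The collator call and its keyword arguments are taken from the diff itself; the toy token ids and the described output keys and shapes are from memory, so treat them as assumptions rather than a spec.

from transformers import DataCollatorWithFlattening

# Two toy sequences of token ids (hypothetical values).
features = [
    {"input_ids": [5, 6, 7]},
    {"input_ids": [8, 9]},
]

# Same arguments as in the test above.
collator = DataCollatorWithFlattening(
    return_tensors="pt", return_seq_idx=True, return_flash_attn_kwargs=True
)
batch = collator(features)

# Roughly, everything is concatenated into a single row, plus the metadata the
# model needs to keep the packed sequences apart without an attention_mask:
#   input_ids    ~ [[5, 6, 7, 8, 9]]
#   position_ids ~ [[0, 1, 2, 0, 1]]   # restart at each sequence boundary
#   seq_idx      ~ [[0, 0, 0, 1, 1]]   # which packed sequence each token came from
#   plus cu_seq_lens_* / max_length_* style FlashAttention varlen kwargs
# The collator should also emit labels with -100 at each sequence start, which is
# why the padded reference batch in the test masks the first non-padding token
# of every row before comparing losses.
for key, value in batch.items():
    print(key, value)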