Cache: models return input cache type #30716

Merged (1 commit) on May 8, 2024
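As the title says, the models touched by this PR now return the key/value cache in the same format the caller passed in when `use_cache=True`: a legacy tuple-of-tuples input (or no cache) still comes back as a legacy tuple, while a `Cache` object input comes back as a `Cache` object. A minimal caller-side sketch of that expected behavior; the tiny randomly initialized Llama below is purely illustrative and not part of the PR:

```python
import torch

from transformers import LlamaConfig, LlamaForCausalLM
from transformers.cache_utils import DynamicCache

# Tiny, randomly initialized model just for illustration; the config sizes are arbitrary.
config = LlamaConfig(
    vocab_size=128,
    hidden_size=64,
    intermediate_size=128,
    num_hidden_layers=2,
    num_attention_heads=4,
    num_key_value_heads=4,
)
model = LlamaForCausalLM(config).eval()
input_ids = torch.tensor([[1, 2, 3]])

with torch.no_grad():
    # A `Cache` object in -> a `Cache` object out (no silent conversion to legacy tuples).
    out = model(input_ids, past_key_values=DynamicCache(), use_cache=True)
    assert isinstance(out.past_key_values, DynamicCache)

    # No cache (or a legacy tuple) in -> legacy tuple-of-tuples out, kept for backward compatibility.
    out = model(input_ids, use_cache=True)
    assert isinstance(out.past_key_values, tuple)
```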
src/transformers/models/cohere/modeling_cohere.py (6 additions, 5 deletions)

@@ -881,7 +881,9 @@ def forward(
            inputs_embeds = self.embed_tokens(input_ids)

        past_seen_tokens = 0
        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
@@ -943,11 +945,10 @@ def forward(
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
src/transformers/models/dbrx/modeling_dbrx.py (6 additions, 7 deletions)

@@ -1115,7 +1115,9 @@ def forward(

        inputs_embeds = nn.functional.dropout(inputs_embeds, p=self.emb_pdrop, training=self.training)

        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
@@ -1182,13 +1184,10 @@ def forward(
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if isinstance(next_decoder_cache, DynamicCache)
                else next_decoder_cache
            )
        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
src/transformers/models/gemma/modeling_gemma.py (6 additions, 7 deletions)

@@ -865,7 +865,9 @@ def forward(
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
@@ -933,13 +935,10 @@ def forward(
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if isinstance(next_decoder_cache, DynamicCache)
                else next_decoder_cache
            )
        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
src/transformers/models/llama/modeling_llama.py (6 additions, 7 deletions)

@@ -960,7 +960,9 @@ def forward(
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
@@ -1021,13 +1023,10 @@ def forward(
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if isinstance(next_decoder_cache, DynamicCache)
                else next_decoder_cache
            )
        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
src/transformers/models/olmo/modeling_olmo.py (6 additions, 7 deletions)

@@ -938,7 +938,9 @@ def forward(
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
@@ -999,13 +1001,10 @@ def forward(
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if isinstance(next_decoder_cache, DynamicCache)
                else next_decoder_cache
            )
        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
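The edit is identical across the five modeling files above: remember whether the caller supplied a legacy (non-`Cache`) cache, convert it to a `DynamicCache` for the forward pass, and convert back to the legacy format only on the way out. A standalone sketch of that flow, using hypothetical helper names; the real change inlines this logic directly in each model's `forward`:

```python
from transformers.cache_utils import Cache, DynamicCache


def normalize_cache(past_key_values, use_cache):
    # Convert non-`Cache` inputs (legacy tuples or None) into a DynamicCache and
    # remember that the caller expects the legacy format back.
    return_legacy_cache = False
    if use_cache and not isinstance(past_key_values, Cache):
        return_legacy_cache = True
        past_key_values = DynamicCache.from_legacy_cache(past_key_values)
    return past_key_values, return_legacy_cache


def finalize_cache(next_decoder_cache, use_cache, return_legacy_cache):
    # Hand the cache back in the same format the caller provided.
    next_cache = next_decoder_cache if use_cache else None
    if return_legacy_cache:
        next_cache = next_cache.to_legacy_cache()
    return next_cache
```

With the cache round-tripped like this, the per-model `test_new_cache_format` skips in the test files below are no longer needed and are removed.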
tests/models/cohere/test_modeling_cohere.py (0 additions, 7 deletions)

@@ -16,8 +16,6 @@

import unittest

from parameterized import parameterized

from transformers import CohereConfig, is_torch_available
from transformers.testing_utils import (
    require_bitsandbytes,
@@ -296,11 +294,6 @@ def setUp(self):
    def test_config(self):
        self.config_tester.run_common_tests()

    @unittest.skip("TODO @gante fix this for Cohere")
    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass

    def test_model(self):
        config_and_inputs = self.model_tester.prepare_config_and_inputs()
        self.model_tester.create_and_check_model(*config_and_inputs)
tests/models/dbrx/test_modeling_dbrx.py (0 additions, 7 deletions)

@@ -17,8 +17,6 @@

import unittest

from parameterized import parameterized

from transformers import DbrxConfig, is_torch_available
from transformers.testing_utils import require_torch, slow, torch_device

@@ -357,11 +355,6 @@ def test_model_from_pretrained(self):
    def test_tied_weights_keys(self):
        pass

    @unittest.skip("TODO @gante fix this for Llama")
    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass


@require_torch
class DbrxModelIntegrationTest(unittest.TestCase):
tests/models/gemma/test_modeling_gemma.py (0 additions, 6 deletions)

@@ -17,7 +17,6 @@
import unittest

import pytest
from parameterized import parameterized

from transformers import AutoModelForCausalLM, AutoTokenizer, GemmaConfig, is_torch_available
from transformers.testing_utils import (
@@ -367,11 +366,6 @@ def test_Gemma_sequence_classification_model_for_multi_label(self):
        result = model(input_ids, attention_mask=attention_mask, labels=sequence_labels)
        self.assertEqual(result.logits.shape, (self.model_tester.batch_size, self.model_tester.num_labels))

    @unittest.skip("TODO @gante fix this for Llama")
    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass

    @unittest.skip("Gemma buffers include complex numbers, which breaks this test")
    def test_save_load_fast_init_from_base(self):
        pass
tests/models/llama/test_modeling_llama.py (0 additions, 5 deletions)

@@ -591,11 +591,6 @@ def test_eager_matches_sdpa_generate(self):
                    msg=f"\n{tokenizer.batch_decode(res_eager)} \nvs\n{tokenizer.batch_decode(res_sdpa)}",
                )

    @unittest.skip("TODO @gante fix this for Llama")
    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass


@require_torch_gpu
class LlamaIntegrationTest(unittest.TestCase):
tests/models/olmo/test_modeling_olmo.py (0 additions, 5 deletions)

@@ -353,11 +353,6 @@ def test_model_rope_scaling(self, scaling_type):
        # The output should be different for long inputs
        self.assertFalse(torch.allclose(original_long_output, scaled_long_output, atol=1e-5))

    @unittest.skip("TODO @gante fix this for OLMo")
    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass


@require_torch
class OlmoIntegrationTest(unittest.TestCase):
tests/models/recurrent_gemma/test_modeling_recurrent_gemma.py (0 additions, 7 deletions)

@@ -15,8 +15,6 @@
""" Testing suite for the PyTorch RecurrentGemma model. """
import unittest

from parameterized import parameterized

from transformers import AutoModelForCausalLM, AutoTokenizer, RecurrentGemmaConfig, is_torch_available, set_seed
from transformers.testing_utils import (
    require_bitsandbytes,
@@ -330,11 +328,6 @@ def test_model_various_embeddings(self):
            config_and_inputs[0].position_embedding_type = type
            self.model_tester.create_and_check_model(*config_and_inputs)

    @unittest.skip("Recurrent gemma does not use legacy cache")
    @parameterized.expand([(1, False), (1, True), (4, False)])
    def test_new_cache_format(self, num_beams, do_sample):
        pass

    def test_save_load_fast_init_from_base(self):
        pass
