From 0312d7b6a42aeefc734431ca777807bb2c94c384 Mon Sep 17 00:00:00 2001 From: cyy Date: Wed, 26 Mar 2025 23:21:41 +0800 Subject: [PATCH 1/3] Remove deprecated batch_size argument --- src/transformers/cache_utils.py | 102 ++++++---------------- tests/models/mamba/test_modeling_mamba.py | 2 +- tests/utils/test_cache_utils.py | 6 +- 3 files changed, 30 insertions(+), 80 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 02abcfd21acd..25809f8767f8 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1140,8 +1140,8 @@ class StaticCache(Cache): Parameters: config (`PretrainedConfig`): The configuration file defining the shape-related attributes required to initialize the static cache. - batch_size (`int`): - The batch size with which the model will be used. Note that a new instance must be instantiated if a + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. 
If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search max_cache_len (`int`): @@ -1170,7 +1170,7 @@ class StaticCache(Cache): >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation StaticCache() @@ -1179,25 +1179,17 @@ class StaticCache(Cache): is_compileable = True - # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. def __init__( self, config: PretrainedConfig, - batch_size: Optional[int] = None, + max_batch_size: int, max_cache_len: Optional[int] = None, - device: torch.device = None, + device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, - max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: super().__init__() - if batch_size is not None: - logger.warning_once( - f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " - "v4.49. Use the more precisely named 'max_batch_size' argument instead." 
- ) - - self.max_batch_size = batch_size or max_batch_size + self.max_batch_size = max_batch_size self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads @@ -1296,14 +1288,6 @@ def reset(self): self.key_cache[layer_idx].zero_() self.value_cache[layer_idx].zero_() - @property - def batch_size(self): - logger.warning_once( - f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in " - "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead." - ) - return self.max_batch_size - class SlidingWindowCache(StaticCache): """ @@ -1325,8 +1309,8 @@ class SlidingWindowCache(StaticCache): Parameters: config (`PretrainedConfig`): The configuration file defining the shape-related attributes required to initialize the static cache. - batch_size (`int`): - The batch size with which the model will be used. Note that a new instance must be instantiated if a + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. max_cache_len (`int`): The maximum sequence length with which the model will be used. 
@@ -1353,7 +1337,7 @@ class SlidingWindowCache(StaticCache): >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation SlidingWindowCache() @@ -1363,15 +1347,13 @@ class SlidingWindowCache(StaticCache): is_sliding = True is_compileable = True - # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. def __init__( self, config: PretrainedConfig, - batch_size: Optional[int] = None, + max_batch_size: int, max_cache_len: Optional[int] = None, - device: torch.device = None, + device: Optional[torch.device] = None, dtype: torch.dtype = torch.float32, - max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: if not hasattr(config, "sliding_window") or config.sliding_window is None: @@ -1383,11 +1365,10 @@ def __init__( max_cache_len = min(config.sliding_window, max_cache_len) super().__init__( config=config, - batch_size=batch_size, + max_batch_size=max_batch_size, max_cache_len=max_cache_len, device=device, dtype=dtype, - max_batch_size=max_batch_size, layer_device_map=layer_device_map, ) @@ -1397,7 +1378,7 @@ def update( value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: cache_position = 
cache_kwargs.get("cache_position") k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] @@ -1631,8 +1612,8 @@ class HybridCache(Cache): Parameters: config (`PretrainedConfig): The configuration file defining the shape-related attributes required to initialize the static cache. - batch_size (`int`): - The batch size with which the model will be used. Note that a new instance must be instantiated if a + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. max_cache_len (`int`): The maximum sequence length with which the model will be used. @@ -1659,7 +1640,7 @@ class HybridCache(Cache): >>> # Prepare a cache class and pass it to model's forward >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate >>> max_generated_length = inputs.input_ids.shape[1] + 10 - >>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) + >>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values # access cache filled with key/values from generation HybridCache() @@ -1670,23 +1651,16 @@ class HybridCache(Cache): # ALL changes from the PR that commented the line below when reactivating it. # is_compileable = True - # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. 
def __init__( self, config: PretrainedConfig, - batch_size: Optional[int] = None, + max_batch_size: int, max_cache_len: Optional[int] = None, - device: Union[torch.device, str] = None, + device: Union[torch.device, str, None] = None, dtype: torch.dtype = torch.float32, - max_batch_size: Optional[int] = None, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: super().__init__() - if batch_size is not None: - logger.warning_once( - f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " - "v4.49. Use the more precisely named 'max_batch_size' argument instead." - ) if not hasattr(config, "sliding_window") or config.sliding_window is None: raise ValueError( "Setting `cache_implementation` to 'sliding_window' requires the model config supporting " @@ -1694,7 +1668,7 @@ def __init__( "config and it's not set to None." ) self.max_cache_len = max_cache_len - self.max_batch_size = batch_size or max_batch_size + self.max_batch_size = max_batch_size # Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads self.head_dim = ( config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads @@ -1718,7 +1692,7 @@ def __init__( min(config.sliding_window, max_cache_len), self.head_dim, ) - device = torch.device(device) if device is not None else None + device = torch.device(device) if isinstance(device, str) else device for i in range(config.num_hidden_layers): if layer_device_map is not None: layer_device = layer_device_map[i] @@ -1776,7 +1750,7 @@ def update( value_states: torch.Tensor, layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, - ) -> Tuple[torch.Tensor]: + ) -> Tuple[torch.Tensor, torch.Tensor]: cache_position = cache_kwargs.get("cache_position") sliding_window = cache_kwargs.get("sliding_window") @@ -1828,14 +1802,6 @@ def reset(self): self.key_cache[layer_idx].zero_() 
self.value_cache[layer_idx].zero_() - @property - def batch_size(self): - logger.warning_once( - f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in " - "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead." - ) - return self.max_batch_size - class MambaCache: """ @@ -1844,9 +1810,8 @@ class MambaCache: Arguments: config (`PretrainedConfig): The configuration file defining the shape-related attributes required to initialize the static cache. - batch_size (`int`): - The batch size with which the model will be used. Note that a new instance must be instantiated if a - smaller batch size is used. + max_batch_size (`int`): + The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. dtype (`torch.dtype`, *optional*, defaults to `torch.float16`): The default `dtype` to use when initializing the layer. device (`torch.device` or `str`, *optional*): @@ -1863,7 +1828,7 @@ class MambaCache: >>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt") >>> # Prepare a cache class and pass it to model's forward - >>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype) + >>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype) >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True) >>> outputs.past_key_values MambaCache() @@ -1872,23 +1837,16 @@ class MambaCache: is_compileable = True - # TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well. 
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly def __init__( self, config: PretrainedConfig, - batch_size: Optional[int] = None, + max_batch_size: int, dtype: torch.dtype = torch.float16, device: Optional[Union[torch.device, str]] = None, - max_batch_size: Optional[int] = None, ): - if batch_size is not None: - logger.warning_once( - f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in " - "v4.49. Use the more precisely named 'max_batch_size' argument instead." - ) self.dtype = dtype - self.max_batch_size = batch_size or max_batch_size + self.max_batch_size = max_batch_size self.intermediate_size = config.intermediate_size self.ssm_state_size = config.state_size self.conv_kernel_size = config.conv_kernel @@ -1944,14 +1902,6 @@ def reset(self): self.conv_states[layer_idx].zero_() self.ssm_states[layer_idx].zero_() - @property - def batch_size(self): - logger.warning_once( - f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in " - "v4.49. Use the more precisely named 'self.max_batch_size' attribute instead." 
- ) - return self.max_batch_size - class OffloadedStaticCache(StaticCache): """ diff --git a/tests/models/mamba/test_modeling_mamba.py b/tests/models/mamba/test_modeling_mamba.py index bab434d4bc48..63540575b206 100644 --- a/tests/models/mamba/test_modeling_mamba.py +++ b/tests/models/mamba/test_modeling_mamba.py @@ -422,7 +422,7 @@ def test_dtype_mismatch_handled_in_cache(self): model.eval() # Create cache with float32 dtype - cache_params = MambaCache(config, batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device) + cache_params = MambaCache(config, max_batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device) # If code is correct, no error occurs and test passes outputs = model( diff --git a/tests/utils/test_cache_utils.py b/tests/utils/test_cache_utils.py index 816632ea53b7..15a3e44ef8d4 100644 --- a/tests/utils/test_cache_utils.py +++ b/tests/utils/test_cache_utils.py @@ -151,7 +151,7 @@ def _random_kvs(config): return random_keys, random_values mha_config = LlamaConfig(num_attention_heads=32) - mha_static_cache = StaticCache(config=mha_config, batch_size=1, max_cache_len=10, device=torch_device) + mha_static_cache = StaticCache(config=mha_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mha_static_cache.update( *_random_kvs(mha_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -159,7 +159,7 @@ def _random_kvs(config): self.assertTrue(cached_values.shape == (1, 32, 10, 128)) gqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=4) - gqa_static_cache = StaticCache(config=gqa_config, batch_size=1, max_cache_len=10, device=torch_device) + gqa_static_cache = StaticCache(config=gqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = gqa_static_cache.update( *_random_kvs(gqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) @@ -167,7 +167,7 @@ def _random_kvs(config): 
self.assertTrue(cached_values.shape == (1, 4, 10, 128)) mqa_config = LlamaConfig(num_attention_heads=32, num_key_value_heads=1) - mqa_static_cache = StaticCache(config=mqa_config, batch_size=1, max_cache_len=10, device=torch_device) + mqa_static_cache = StaticCache(config=mqa_config, max_batch_size=1, max_cache_len=10, device=torch_device) cached_keys, cached_values = mqa_static_cache.update( *_random_kvs(mqa_config), 0, cache_kwargs={"cache_position": torch.arange(1).to(torch_device)} ) From 6f727653ec90bda4874c39845138fad2f42ba35c Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 27 Mar 2025 07:41:49 +0800 Subject: [PATCH 2/3] Remove Caches from OBJECTS_TO_IGNORE --- src/transformers/cache_utils.py | 28 ++++++++++++++-------------- utils/check_docstrings.py | 5 ----- 2 files changed, 14 insertions(+), 19 deletions(-) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index 25809f8767f8..d65f4038be5e 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1144,16 +1144,16 @@ class StaticCache(Cache): The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. If you are manually setting the batch size, make sure to take into account the number of beams if you are running beam search - max_cache_len (`int`): + max_cache_len (`int`, *optional*): The maximum sequence length with which the model will be used. - device (`torch.device` or `str`): + device (`torch.device` or `str`, *optional*): The device on which the cache should be initialized. If you're using more than 1 computation device, you should pass the `layer_device_map` argument instead. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. 
- layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is splitted between differents gpus. You can know which layers mapped to which device by + and the model is split between different gpus. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. @@ -1184,7 +1184,7 @@ def __init__( config: PretrainedConfig, max_batch_size: int, max_cache_len: Optional[int] = None, - device: Optional[torch.device] = None, + device: Union[torch.device, str, None] = None, dtype: torch.dtype = torch.float32, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: @@ -1312,16 +1312,16 @@ class SlidingWindowCache(StaticCache): max_batch_size (`int`): The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. - max_cache_len (`int`): + max_cache_len (`int`, *optional*): The maximum sequence length with which the model will be used. - device (`torch.device` or `str`): + device (`torch.device` or `str`, *optional*): The device on which the cache should be initialized. If you're using more than 1 computation device, you should pass the `layer_device_map` argument instead. dtype (`torch.dtype`, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. - layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is splitted between differents gpus. 
You can know which layers mapped to which device by + and the model is split between different gpus. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. Example: @@ -1352,7 +1352,7 @@ def __init__( config: PretrainedConfig, max_batch_size: int, max_cache_len: Optional[int] = None, - device: Optional[torch.device] = None, + device: Union[torch.device, str, None] = None, dtype: torch.dtype = torch.float32, layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None, ) -> None: @@ -1615,16 +1615,16 @@ class HybridCache(Cache): max_batch_size (`int`): The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used. - max_cache_len (`int`): + max_cache_len (`int`, *optional*): The maximum sequence length with which the model will be used. device (`torch.device` or `str`, *optional*): The device on which the cache should be initialized. If you're using more than 1 computation device, you should pass the `layer_device_map` argument instead. dtype (torch.dtype, *optional*, defaults to `torch.float32`): The default `dtype` to use when initializing the layer. - layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`): + layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*): Mapping between the layers and its device. This is required when you are manually initializing the cache - and the model is splitted between differents gpus. You can know which layers mapped to which device by + and the model is split between different gpus. You can know which layers mapped to which device by checking the associated device_map: `model.hf_device_map`. 
Example: @@ -1843,7 +1843,7 @@ def __init__( config: PretrainedConfig, max_batch_size: int, dtype: torch.dtype = torch.float16, - device: Optional[Union[torch.device, str]] = None, + device: Union[torch.device, str, None] = None, ): self.dtype = dtype self.max_batch_size = max_batch_size diff --git a/utils/check_docstrings.py b/utils/check_docstrings.py index 1c4ff3ddc906..bdcec87c2ba5 100644 --- a/utils/check_docstrings.py +++ b/utils/check_docstrings.py @@ -74,11 +74,6 @@ "TFSequenceSummary", "TFBertTokenizer", "TFGPT2Tokenizer", - # Going through an argument deprecation cycle, remove after v4.46 - "HybridCache", - "MambaCache", - "SlidingWindowCache", - "StaticCache", # Missing arguments in the docstring "ASTFeatureExtractor", "AlbertModel", From 5b85b4dd945aa92eba312ec5d30f8a4aa6dbeffe Mon Sep 17 00:00:00 2001 From: cyy Date: Thu, 27 Mar 2025 08:13:57 +0800 Subject: [PATCH 3/3] Fix cache_kwargs=None --- src/transformers/cache_utils.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py index d65f4038be5e..f8cced3bb33a 100644 --- a/src/transformers/cache_utils.py +++ b/src/transformers/cache_utils.py @@ -1065,6 +1065,8 @@ def update( """ # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models # with partially rotated position embeddings, like Phi or Persimmon. + if cache_kwargs is None: + cache_kwargs = {} sin = cache_kwargs.get("sin") cos = cache_kwargs.get("cos") partial_rotation_size = cache_kwargs.get("partial_rotation_size") @@ -1248,6 +1250,8 @@ def update( Return: A tuple containing the updated key and value states. 
""" + if cache_kwargs is None: + cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] @@ -1379,6 +1383,8 @@ def update( layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + if cache_kwargs is None: + cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") k_out = self.key_cache[layer_idx] v_out = self.value_cache[layer_idx] @@ -1751,6 +1757,8 @@ def update( layer_idx: int, cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: + if cache_kwargs is None: + cache_kwargs = {} cache_position = cache_kwargs.get("cache_position") sliding_window = cache_kwargs.get("sliding_window")