134 changes: 46 additions & 88 deletions src/transformers/cache_utils.py
@@ -1065,6 +1065,8 @@ def update(
"""
# Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
# with partially rotated position embeddings, like Phi or Persimmon.
if cache_kwargs is None:
cache_kwargs = {}
sin = cache_kwargs.get("sin")
cos = cache_kwargs.get("cos")
partial_rotation_size = cache_kwargs.get("partial_rotation_size")
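Note: the `if cache_kwargs is None` guard added in this hunk makes the kwargs payload safely optional; `dict.get` then returns `None` for absent keys instead of the call failing with `AttributeError` on `None`. A minimal standalone sketch of the pattern (the function name is illustrative, not the real method):

    from typing import Any, Dict, Optional

    def read_rope_kwargs(cache_kwargs: Optional[Dict[str, Any]] = None):
        # Same guard as the hunk above: default a missing payload to an empty dict.
        if cache_kwargs is None:
            cache_kwargs = {}
        # Absent keys now degrade to None rather than raising.
        return cache_kwargs.get("sin"), cache_kwargs.get("cos")

    assert read_rope_kwargs() == (None, None)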
@@ -1140,20 +1142,20 @@ class StaticCache(Cache):
Parameters:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
max_batch_size (`int`):
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used. If you are manually setting the batch size, make sure to take into account the
number of beams if you are running beam search
max_cache_len (`int`):
max_cache_len (`int`, *optional*):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`):
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*):
Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is splitted between differents gpus. You can know which layers mapped to which device by
and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`.


@@ -1170,7 +1172,7 @@ class StaticCache(Cache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = StaticCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> past_key_values = StaticCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
StaticCache()
@@ -1179,25 +1181,17 @@ class StaticCache(Cache):

is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: Optional[int] = None,
max_batch_size: int,
max_cache_len: Optional[int] = None,
device: torch.device = None,
device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if batch_size is not None:
logger.warning_once(
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)

self.max_batch_size = batch_size or max_batch_size
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
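Note: with the deprecation shim removed above, the old `batch_size` keyword no longer warns; it simply does not exist, so passing it raises `TypeError`. A hedged migration sketch (the checkpoint name is a stand-in, not from this diff):

    from transformers import AutoModelForCausalLM, StaticCache

    model = AutoModelForCausalLM.from_pretrained("some/causal-lm-checkpoint")
    # Before: StaticCache(config=model.config, batch_size=1, ...)  # now a TypeError
    # After: the argument is named max_batch_size and is required.
    past_key_values = StaticCache(
        config=model.config,
        max_batch_size=1,
        max_cache_len=256,
        device=model.device,
        dtype=model.dtype,
    )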
@@ -1256,6 +1250,8 @@ def update(
Return:
A tuple containing the updated key and value states.
"""
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
@@ -1296,14 +1292,6 @@ def reset(self):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size
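Note: the `batch_size` property deleted above was the read-side shim, so call sites now migrate one attribute read (assuming `past_key_values` built as in the sketch above):

    batch = past_key_values.max_batch_size  # was: past_key_values.batch_size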


class SlidingWindowCache(StaticCache):
"""
@@ -1325,19 +1313,19 @@ class SlidingWindowCache(StaticCache):
Parameters:
config (`PretrainedConfig`):
The configuration file defining the shape-related attributes required to initialize the static cache.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
max_batch_size (`int`):
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`):
max_cache_len (`int`, *optional*):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`):
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*):
Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is splitted between differents gpus. You can know which layers mapped to which device by
and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`.

Example:
@@ -1353,7 +1341,7 @@ class SlidingWindowCache(StaticCache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = SlidingWindowCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> past_key_values = SlidingWindowCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
SlidingWindowCache()
@@ -1363,15 +1351,13 @@ class SlidingWindowCache(StaticCache):
is_sliding = True
is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: Optional[int] = None,
max_batch_size: int,
max_cache_len: Optional[int] = None,
device: torch.device = None,
device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
if not hasattr(config, "sliding_window") or config.sliding_window is None:
@@ -1383,11 +1369,10 @@ def __init__(
max_cache_len = min(config.sliding_window, max_cache_len)
super().__init__(
config=config,
batch_size=batch_size,
max_batch_size=max_batch_size,
max_cache_len=max_cache_len,
device=device,
dtype=dtype,
max_batch_size=max_batch_size,
layer_device_map=layer_device_map,
)
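Note: combined with `max_cache_len = min(config.sliding_window, max_cache_len)` just above, the allocation can never exceed the model's sliding window, whatever length the caller requests; a quick worked example with illustrative numbers:

    sliding_window = 4096      # stand-in for config.sliding_window
    requested = 8192           # caller's max_cache_len
    allocated = min(sliding_window, requested)
    assert allocated == 4096   # capped at the window size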

@@ -1397,7 +1382,9 @@ def update(
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
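Note: the return-annotation fix in this hunk is not cosmetic: in `typing`, `Tuple[torch.Tensor]` denotes a one-element tuple, while `update` returns the key/value pair. A minimal illustration:

    from typing import Tuple

    import torch

    def pair() -> Tuple[torch.Tensor, torch.Tensor]:  # a 2-tuple, as update returns
        t = torch.zeros(1)
        return t, t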
@@ -1631 +1618 @@ class HybridCache(Cache):
Parameters:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
max_batch_size (`int`):
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_cache_len (`int`):
max_cache_len (`int`, *optional*):
The maximum sequence length with which the model will be used.
device (`torch.device` or `str`, *optional*):
The device on which the cache should be initialized. If you're using more than 1 computation device, you
should pass the `layer_device_map` argument instead.
dtype (torch.dtype, *optional*, defaults to `torch.float32`):
The default `dtype` to use when initializing the layer.
layer_device_map(`Dict[int, Union[str, torch.device, int]]]`, `optional`):
layer_device_map (`Optional[Dict[int, Union[str, torch.device, int]]]]`, *optional*):
Mapping between the layers and its device. This is required when you are manually initializing the cache
and the model is splitted between differents gpus. You can know which layers mapped to which device by
and the model is split between different gpus. You can know which layers mapped to which device by
checking the associated device_map: `model.hf_device_map`.

Example:
@@ -1659,7 +1646,7 @@ class HybridCache(Cache):
>>> # Prepare a cache class and pass it to model's forward
>>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
>>> max_generated_length = inputs.input_ids.shape[1] + 10
>>> past_key_values = HybridCache(config=model.config, batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> past_key_values = HybridCache(config=model.config, max_batch_size=1, max_cache_len=max_generated_length, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values # access cache filled with key/values from generation
HybridCache()
@@ -1670,31 +1657,24 @@ class HybridCache(Cache):
# ALL changes from the PR that commented the line below when reactivating it.
# is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
def __init__(
self,
config: PretrainedConfig,
batch_size: Optional[int] = None,
max_batch_size: int,
max_cache_len: Optional[int] = None,
device: Union[torch.device, str] = None,
device: Union[torch.device, str, None] = None,
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super().__init__()
if batch_size is not None:
logger.warning_once(
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
if not hasattr(config, "sliding_window") or config.sliding_window is None:
raise ValueError(
"Setting `cache_implementation` to 'sliding_window' requires the model config supporting "
"sliding window attention, please check if there is a `sliding_window` field in the model "
"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self.max_batch_size = batch_size or max_batch_size
self.max_batch_size = max_batch_size
# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
self.head_dim = (
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
@@ -1718,7 +1698,7 @@ def __init__(
min(config.sliding_window, max_cache_len),
self.head_dim,
)
device = torch.device(device) if device is not None else None
device = torch.device(device) if device is not None and isinstance(device, str) else None
for i in range(config.num_hidden_layers):
if layer_device_map is not None:
layer_device = layer_device_map[i]
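Note: the rewritten `device` line above only converts strings; a `torch.device` instance (or anything else) falls through to `None`, and the loop below then takes per-layer devices from `layer_device_map` when one is given. A standalone mirror of that expression:

    import torch

    def normalize(device):
        # Mirrors the hunk: strings are wrapped, everything else becomes None.
        return torch.device(device) if device is not None and isinstance(device, str) else None

    assert normalize("cuda:0") == torch.device("cuda:0")
    assert normalize(None) is None
    assert normalize(torch.device("cpu")) is None  # non-strings also fall through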
@@ -1776,7 +1756,9 @@ def update(
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor]:
if cache_kwargs is None:
cache_kwargs = {}
cache_position = cache_kwargs.get("cache_position")
sliding_window = cache_kwargs.get("sliding_window")
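Note: `cache_kwargs` here also carries a per-layer `sliding_window` flag, which is what lets one `HybridCache` serve both global and sliding-window attention layers. A hedged, standalone sketch of what the sliding path means (simplified; the real method writes in place at `cache_position`):

    import torch

    def sliding_update_sketch(k_cache: torch.Tensor, new_k: torch.Tensor) -> torch.Tensor:
        # Keep only the most recent `window` positions: append, then crop left.
        window = k_cache.shape[-2]
        full = torch.cat([k_cache, new_k], dim=-2)
        return full[..., -window:, :]

    cache = torch.zeros(1, 2, 4, 8)  # (batch, heads, window=4, head_dim)
    step = torch.ones(1, 2, 1, 8)    # one new token's key states
    assert sliding_update_sketch(cache, step).shape == (1, 2, 4, 8)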

@@ -1828,14 +1810,6 @@ def reset(self):
self.key_cache[layer_idx].zero_()
self.value_cache[layer_idx].zero_()

@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size


class MambaCache:
"""
@@ -1844,9 +1818,8 @@ class MambaCache:
Arguments:
config (`PretrainedConfig):
The configuration file defining the shape-related attributes required to initialize the static cache.
batch_size (`int`):
The batch size with which the model will be used. Note that a new instance must be instantiated if a
smaller batch size is used.
max_batch_size (`int`):
The maximum batch size with which the model will be used. Note that a new instance must be instantiated if a smaller batch size is used.
dtype (`torch.dtype`, *optional*, defaults to `torch.float16`):
The default `dtype` to use when initializing the layer.
device (`torch.device` or `str`, *optional*):
@@ -1863,7 +1836,7 @@ class MambaCache:
>>> inputs = tokenizer(text="My name is Mamba", return_tensors="pt")

>>> # Prepare a cache class and pass it to model's forward
>>> past_key_values = MambaCache(config=model.config, batch_size=1, device=model.device, dtype=model.dtype)
>>> past_key_values = MambaCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
>>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
>>> outputs.past_key_values
MambaCache()
@@ -1872,23 +1845,16 @@ class MambaCache:

is_compileable = True

# TODO (joao): remove `=None` in non-optional arguments in v4.46. Remove from `OBJECTS_TO_IGNORE` as well.
# TODO (joao): add layer_device_map arg and update code in `generate` accordingly
def __init__(
self,
config: PretrainedConfig,
batch_size: Optional[int] = None,
max_batch_size: int,
dtype: torch.dtype = torch.float16,
device: Optional[Union[torch.device, str]] = None,
max_batch_size: Optional[int] = None,
device: Union[torch.device, str, None] = None,
):
if batch_size is not None:
logger.warning_once(
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.dtype = dtype
self.max_batch_size = batch_size or max_batch_size
self.max_batch_size = max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
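Note: the three sizes read above parameterize the per-layer state buffers; a hedged sketch of the conventional Mamba layout (illustrative numbers, not from a real config):

    import torch

    max_batch_size, intermediate_size = 1, 1536
    conv_kernel_size, ssm_state_size = 4, 16
    # One convolution state and one SSM state per layer, keyed by these sizes.
    conv_state = torch.zeros(max_batch_size, intermediate_size, conv_kernel_size)
    ssm_state = torch.zeros(max_batch_size, intermediate_size, ssm_state_size)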
@@ -1944,14 +1910,6 @@ def reset(self):
self.conv_states[layer_idx].zero_()
self.ssm_states[layer_idx].zero_()

@property
def batch_size(self):
logger.warning_once(
f"The 'batch_size' attribute of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'self.max_batch_size' attribute instead."
)
return self.max_batch_size


class OffloadedStaticCache(StaticCache):
"""
2 changes: 1 addition & 1 deletion tests/models/mamba/test_modeling_mamba.py
@@ -422,7 +422,7 @@ def test_dtype_mismatch_handled_in_cache(self):
model.eval()

# Create cache with float32 dtype
cache_params = MambaCache(config, batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)
cache_params = MambaCache(config, max_batch_size=input_ids.size(0), dtype=torch.float32, device=torch_device)

# If code is correct, no error occurs and test passes
outputs = model(