41 changes: 27 additions & 14 deletions src/transformers/cache_utils.py
@@ -937,7 +937,7 @@ class DynamicCache(Cache):

def __init__(
self,
ddp_cache_data: Optional[Iterable[tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor]]] = None,
ddp_cache_data: Optional[Iterable[tuple[Optional[torch.Tensor], ...]]] = None,
config: Optional[PreTrainedConfig] = None,
offloading: bool = False,
offload_only_non_sliding: bool = False,
@@ -970,17 +970,21 @@ def __init__(
# In this case, use the passed data to already fill in the Cache
if ddp_cache_data is not None:
# Init all the layers with the data
for layer_idx, (sliding_window_tensor, key_states, value_states) in enumerate(ddp_cache_data):
for layer_idx, kv_and_optional_sliding in enumerate(ddp_cache_data):
# If the config was not passed above, initialize a new cache layer for each entry of the ddp_data
if config is None:
# kv_and_optional_sliding contains at least two elements: the key and value states. It can also
# contain a third element, which is an optional sliding window tensor.
sliding_window_tensor = kv_and_optional_sliding[2] if len(kv_and_optional_sliding) == 3 else None
# If there is a sliding window tensor, use it to initialize the layer
if sliding_window_tensor is not None:
# Since the same layer is dispatched across replicas, sliding_window is the same for all
sliding_window = sliding_window_tensor[0].item()
layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
else:
layers.append(DynamicLayer())
# Update the layer with the data
_, _ = layers[layer_idx].update(key_states, value_states)
_, _ = layers[layer_idx].update(kv_and_optional_sliding[0], kv_and_optional_sliding[1])

# If neither of config nor ddp_data was passed, then simply lazy init a full cache of DynamicLayer
if len(layers) == 0:
@@ -994,7 +998,7 @@ def __init__(

def __iter__(self):
for layer in self.layers:
yield getattr(layer, "_sliding_window_tensor", None), layer.keys, layer.values
yield layer.keys, layer.values, getattr(layer, "_sliding_window_tensor", None)


class StaticCache(Cache):
@@ -1166,17 +1170,21 @@ class EncoderDecoderCache(Cache):
"""

def __init__(self, *caches) -> None:
# For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
# For dp and ddp support, if only one argument is passed, it should be an iterable of DynamicCache ddp data
if len(caches) == 1:
self.self_attention_cache = DynamicCache()
self.cross_attention_cache = DynamicCache()
# Populate cache from the iterable
for layer_idx, key_value_states in enumerate(caches[0]):
key_states, value_states = key_value_states[:2]
self.self_attention_cache.update(key_states, value_states, layer_idx)
if len(key_value_states) > 2:
key_states, value_states = key_value_states[2:]
self.cross_attention_cache.update(key_states, value_states, layer_idx)
self_attention_cache_data, cross_attention_cache_data = [], []
for combined_cache_data in caches[0]:
if len(combined_cache_data) == 6: # two tuples of style (self_attn_k, self_attn_v, self_attn_sliding)
self_attention_cache_data.append(combined_cache_data[:3])
cross_attention_cache_data.append(combined_cache_data[3:])
# To support old DDP-style init, we handle the case where the tuple has no sliding window tensor
elif len(combined_cache_data) == 4: # two tuples of style (self_attn_k, self_attn_v)
self_attention_cache_data.append(combined_cache_data[:2])
cross_attention_cache_data.append(combined_cache_data[2:])
else:
raise ValueError(f"Expected {len(combined_cache_data) = } to be 4 or 6.\n{combined_cache_data = }")
self.self_attention_cache = DynamicCache(self_attention_cache_data)
self.cross_attention_cache = DynamicCache(cross_attention_cache_data)
# Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
elif len(caches) == 2:
if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
@@ -1191,6 +1199,11 @@ def __init__(self, *caches) -> None:
for layer_idx in range(len(self.cross_attention_cache)):
self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)

def __iter__(self):
"""Returns tuples of style (self_attn_k, self_attn_v, self_attn_sliding, cross_attn_k, cross_attn_v, cross_attn_sliding)"""
for self_attention_layer, cross_attention_layer in zip(self.self_attention_cache, self.cross_attention_cache):
yield self_attention_layer + cross_attention_layer

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(self_attention_cache={self.self_attention_cache}, cross_attention_cache="
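The hunks above change the per-layer tuple layout: a DynamicCache built from DDP-style data now takes (keys, values) or (keys, values, sliding_window_tensor) per layer and yields (keys, values, sliding_or_None) when iterated, while EncoderDecoderCache can be rebuilt from combined 4- or 6-element tuples. A minimal round-trip sketch under those rules (the tensor shapes are dummy values picked for illustration):

import torch
from transformers import DynamicCache, EncoderDecoderCache

batch, heads, seq, head_dim = 1, 2, 4, 8
keys = torch.randn(batch, heads, seq, head_dim)
values = torch.randn(batch, heads, seq, head_dim)

# DynamicCache accepts per-layer tuples of (keys, values) or (keys, values, sliding_window_tensor).
self_attn_cache = DynamicCache([(keys, values), (keys, values)])

# Iterating a DynamicCache now yields (keys, values, sliding_window_tensor_or_None) per layer.
for layer_keys, layer_values, sliding in self_attn_cache:
    assert sliding is None

# EncoderDecoderCache(iterable) expects combined per-layer tuples of length 4 or 6:
# (self_k, self_v[, self_sliding], cross_k, cross_v[, cross_sliding]).
combined = [(keys, values, keys, values) for _ in range(2)]
enc_dec_cache = EncoderDecoderCache(combined)

# EncoderDecoderCache.__iter__ emits the combined 6-tuples (the sliding slots may be None),
# so iterating and re-constructing round-trips the cache.
rebuilt = EncoderDecoderCache(tuple(enc_dec_cache))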
30 changes: 16 additions & 14 deletions src/transformers/models/rag/modeling_rag.py
@@ -1187,22 +1187,24 @@ def _reorder_stacked(hidden_states, new_order):
reordered_past = ()
for idx in range(len(past_key_values)):
if isinstance(past_key_values, EncoderDecoderCache):
layer_past = (
past_key_values.self_attention_cache.layers[idx].keys,
past_key_values.self_attention_cache.layers[idx].values,
past_key_values.cross_attention_cache.layers[idx].keys,
past_key_values.cross_attention_cache.layers[idx].values,
self_attention_k, self_attention_v, cross_attention_k, cross_attention_v = (
_reorder_stacked(x, beam_idx.to(x.device))
for x in (
past_key_values.self_attention_cache.layers[idx].keys,
past_key_values.self_attention_cache.layers[idx].values,
past_key_values.cross_attention_cache.layers[idx].keys,
past_key_values.cross_attention_cache.layers[idx].values,
)
)
new_tuple = (self_attention_k, self_attention_v, cross_attention_k, cross_attention_v)
else:
layer_past = (past_key_values.layers[idx].keys, past_key_values.layers[idx].values)
# get the correct batch idx from decoder layer's batch dim for cross and self-attn
reordered_past += (
tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past),
)

# Cast back to the correct cache class
reordered_cache = type(past_key_values)(reordered_past)
return reordered_cache
self_attention_k, self_attention_v = (
_reorder_stacked(x, beam_idx.to(x.device))
for x in (past_key_values.layers[idx].keys, past_key_values.layers[idx].values)
)
new_tuple = (self_attention_k, self_attention_v)
reordered_past += (new_tuple,)
return type(past_key_values)(reordered_past)

def marginalize(self, seq_logits, doc_scores, n_docs=None):
n_docs = n_docs if n_docs is not None else self.config.n_docs
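For beam search, _reorder_cache has to pick, for every layer, the key/value rows that belong to the surviving beams. A simplified illustration of that selection (not RAG's actual _reorder_stacked helper, which also has to account for the n_docs-stacked batch layout):

import torch

def reorder_along_batch(past_state: torch.Tensor, beam_idx: torch.Tensor) -> torch.Tensor:
    # Keep, for each output slot, the cache rows of the beam it was forked from.
    return past_state.index_select(0, beam_idx.to(past_state.device))

keys = torch.randn(4, 2, 5, 8)           # (num_beams, heads, seq_len, head_dim)
beam_idx = torch.tensor([0, 0, 3, 2])    # beam 1 was dropped, beam 0 was duplicated
reordered = reorder_along_batch(keys, beam_idx)
assert reordered.shape == keys.shape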
20 changes: 12 additions & 8 deletions src/transformers/models/whisper/generation_whisper.py
@@ -1180,12 +1180,14 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None
return None
all_past_key_values = []
for layer_idx in range(self.config.decoder_layers):
layer_past_key_values = []
for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
for v in [cache_cls.layers[layer_idx].keys, cache_cls.layers[layer_idx].values]:
layer_past_key_values.append(v[batch_idx][None].cpu())
all_past_key_values.append(tuple(layer_past_key_values))
return EncoderDecoderCache(tuple(all_past_key_values))
layer_cache = (
values.self_attention_cache.layers[layer_idx].keys[batch_idx][None].cpu(),
values.self_attention_cache.layers[layer_idx].values[batch_idx][None].cpu(),
values.cross_attention_cache.layers[layer_idx].keys[batch_idx][None].cpu(),
values.cross_attention_cache.layers[layer_idx].values[batch_idx][None].cpu(),
)
all_past_key_values.append(layer_cache)
return EncoderDecoderCache(all_past_key_values)

return values[batch_idx].cpu()

@@ -1224,7 +1226,7 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
if seek_outputs[0][key] is not None:
all_past_key_values = []
for layer_idx in range(len(seek_outputs[0][key])):
layer_past_key_values = tuple(
self_attention_k, self_attention_v, cross_attention_k, cross_attention_v = (
torch.stack(
[
getattr(getattr(sub_output[key], sub_cache).layers[layer_idx], sub_key)
@@ -1236,7 +1238,9 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
for sub_cache in ["self_attention_cache", "cross_attention_cache"]
for sub_key in ["keys", "values"]
)
all_past_key_values.append(layer_past_key_values)
all_past_key_values.append(
(self_attention_k, self_attention_v, cross_attention_k, cross_attention_v)
)
outputs[key] = EncoderDecoderCache(tuple(all_past_key_values))
else:
outputs[key] = None
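split_by_batch_index and _stack_split_outputs slice each layer's key/value tensors along the batch dimension and later reassemble a single EncoderDecoderCache from 4-tuples. A toy version of that split/stack pattern, using made-up shapes and plain torch.cat rather than Whisper's exact helpers:

import torch
from transformers import EncoderDecoderCache

num_layers = 2
# One (self_k, self_v, cross_k, cross_v) 4-tuple per decoder layer, batch size 3.
batched = [tuple(torch.randn(3, 2, 6, 8) for _ in range(4)) for _ in range(num_layers)]

# Split: keep batch index 1 only; [None] preserves a batch dimension of 1.
per_sample = [tuple(t[1][None] for t in layer) for layer in batched]
single_cache = EncoderDecoderCache(per_sample)

# Stack: merge several per-sample caches back along the batch dimension.
samples = [per_sample, per_sample, per_sample]
stacked = [
    tuple(torch.cat([s[layer_idx][i] for s in samples], dim=0) for i in range(4))
    for layer_idx in range(num_layers)
]
restored = EncoderDecoderCache(tuple(stacked))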
4 changes: 2 additions & 2 deletions tests/utils/test_modeling_utils.py
@@ -1807,8 +1807,8 @@ def test_cache_when_needed_at_train_time(self):
# simulate injecting virtual tokens like in prefix tuning
num_virtual_tokens = 3
past_key_values = [
(None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
(None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
(torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
(torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
]
past_key_values = DynamicCache(past_key_values)
model_inputs["attention_mask"] = torch.cat(
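Under the new constructor, the prefix-tuning-style test above seeds the cache with plain (key, value) pairs instead of (None, key, value) triples. A minimal sketch of that seeding, assuming a toy two-layer setup:

import torch
from transformers import DynamicCache

num_layers, num_heads, num_virtual_tokens, head_dim = 2, 2, 3, 8
past_key_values = DynamicCache(
    [
        (
            torch.randn(1, num_heads, num_virtual_tokens, head_dim),
            torch.randn(1, num_heads, num_virtual_tokens, head_dim),
        )
        for _ in range(num_layers)
    ]
)
# The seeded virtual tokens already count toward the cached sequence length.
assert past_key_values.get_seq_length() == num_virtual_tokens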