diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index bd0014c66306..28f40952f2cd 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -937,7 +937,7 @@ class DynamicCache(Cache):
 
     def __init__(
         self,
-        ddp_cache_data: Optional[Iterable[tuple[Optional[torch.Tensor], torch.Tensor, torch.Tensor]]] = None,
+        ddp_cache_data: Optional[Iterable[tuple[Optional[torch.Tensor], ...]]] = None,
         config: Optional[PreTrainedConfig] = None,
         offloading: bool = False,
         offload_only_non_sliding: bool = False,
@@ -970,9 +970,13 @@ def __init__(
         # In this case, use the passed data to already fill in the Cache
         if ddp_cache_data is not None:
             # Init all the layers with the data
-            for layer_idx, (sliding_window_tensor, key_states, value_states) in enumerate(ddp_cache_data):
+            for layer_idx, kv_and_optional_sliding in enumerate(ddp_cache_data):
                 # If the config was not passed above, initialize a new cache layer for each entry of the ddp_data
                 if config is None:
+                    # kv_and_optional_sliding contains at least two elements: the key and value states. It can also
+                    # contain a third element, which is an optional sliding window tensor.
+                    sliding_window_tensor = kv_and_optional_sliding[2] if len(kv_and_optional_sliding) == 3 else None
+                    # If there is a sliding window tensor, use it to initialize the layer
                     if sliding_window_tensor is not None:
                         # Since the same layer is dispatched across replicas, sliding_window is the same for all
                         sliding_window = sliding_window_tensor[0].item()
@@ -980,7 +984,7 @@ def __init__(
                 else:
                     layers.append(DynamicLayer())
                 # Update the layer with the data
-                _, _ = layers[layer_idx].update(key_states, value_states)
+                _, _ = layers[layer_idx].update(kv_and_optional_sliding[0], kv_and_optional_sliding[1])
 
         # If neither of config nor ddp_data was passed, then simply lazy init a full cache of DynamicLayer
         if len(layers) == 0:
@@ -994,7 +998,7 @@ def __init__(
 
     def __iter__(self):
         for layer in self.layers:
-            yield getattr(layer, "_sliding_window_tensor", None), layer.keys, layer.values
+            yield layer.keys, layer.values, getattr(layer, "_sliding_window_tensor", None)
 
 
 class StaticCache(Cache):
@@ -1166,17 +1170,21 @@ class EncoderDecoderCache(Cache):
     """
 
     def __init__(self, *caches) -> None:
-        # For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
+        # For dp and ddp support, if only one argument is passed, it should be an iterable of DynamicCache ddp data
         if len(caches) == 1:
-            self.self_attention_cache = DynamicCache()
-            self.cross_attention_cache = DynamicCache()
-            # Populate cache from the iterable
-            for layer_idx, key_value_states in enumerate(caches[0]):
-                key_states, value_states = key_value_states[:2]
-                self.self_attention_cache.update(key_states, value_states, layer_idx)
-                if len(key_value_states) > 2:
-                    key_states, value_states = key_value_states[2:]
-                    self.cross_attention_cache.update(key_states, value_states, layer_idx)
+            self_attention_cache_data, cross_attention_cache_data = [], []
+            for combined_cache_data in caches[0]:
+                if len(combined_cache_data) == 6:  # two tuples of style (self_attn_k, self_attn_v, self_attn_sliding)
+                    self_attention_cache_data.append(combined_cache_data[:3])
+                    cross_attention_cache_data.append(combined_cache_data[3:])
+                # To support old DDP-style init, we handle the case where the tuple has no sliding window tensor
+                elif len(combined_cache_data) == 4:  # two tuples of style (self_attn_k, self_attn_v)
+                    self_attention_cache_data.append(combined_cache_data[:2])
+                    cross_attention_cache_data.append(combined_cache_data[2:])
+                else:
+                    raise ValueError(f"Expected {len(combined_cache_data) = } to be 4 or 6.\n{combined_cache_data = }")
+            self.self_attention_cache = DynamicCache(self_attention_cache_data)
+            self.cross_attention_cache = DynamicCache(cross_attention_cache_data)
         # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
         elif len(caches) == 2:
             if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
@@ -1191,6 +1199,11 @@ def __init__(self, *caches) -> None:
             for layer_idx in range(len(self.cross_attention_cache)):
                 self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
 
+    def __iter__(self):
+        """Returns tuples of style (self_attn_k, self_attn_v, self_attn_sliding, cross_attn_k, cross_attn_v, cross_attn_sliding)"""
+        for self_attention_layer, cross_attention_layer in zip(self.self_attention_cache, self.cross_attention_cache):
+            yield self_attention_layer + cross_attention_layer
+
     def __repr__(self) -> str:
         return (
             f"{self.__class__.__name__}(self_attention_cache={self.self_attention_cache}, cross_attention_cache="
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index e207c55dc636..9d9092375604 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -1187,22 +1187,24 @@ def _reorder_stacked(hidden_states, new_order):
         reordered_past = ()
         for idx in range(len(past_key_values)):
             if isinstance(past_key_values, EncoderDecoderCache):
-                layer_past = (
-                    past_key_values.self_attention_cache.layers[idx].keys,
-                    past_key_values.self_attention_cache.layers[idx].values,
-                    past_key_values.cross_attention_cache.layers[idx].keys,
-                    past_key_values.cross_attention_cache.layers[idx].values,
+                self_attention_k, self_attention_v, cross_attention_k, cross_attention_v = (
+                    _reorder_stacked(x, beam_idx.to(x.device))
+                    for x in (
+                        past_key_values.self_attention_cache.layers[idx].keys,
+                        past_key_values.self_attention_cache.layers[idx].values,
+                        past_key_values.cross_attention_cache.layers[idx].keys,
+                        past_key_values.cross_attention_cache.layers[idx].values,
+                    )
                 )
+                new_tuple = (self_attention_k, self_attention_v, cross_attention_k, cross_attention_v)
             else:
-                layer_past = (past_key_values.layers[idx].keys, past_key_values.layers[idx].values)
-            # get the correct batch idx from decoder layer's batch dim for cross and self-attn
-            reordered_past += (
-                tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past),
-            )
-
-        # Cast back to the correct cache class
-        reordered_cache = type(past_key_values)(reordered_past)
-        return reordered_cache
+                self_attention_k, self_attention_v = (
+                    _reorder_stacked(x, beam_idx.to(x.device))
+                    for x in (past_key_values.layers[idx].keys, past_key_values.layers[idx].values)
+                )
+                new_tuple = (self_attention_k, self_attention_v)
+            reordered_past += (new_tuple,)
+        return type(past_key_values)(reordered_past)
 
     def marginalize(self, seq_logits, doc_scores, n_docs=None):
         n_docs = n_docs if n_docs is not None else self.config.n_docs
diff --git a/src/transformers/models/whisper/generation_whisper.py b/src/transformers/models/whisper/generation_whisper.py
index 5b7d06ca8c45..6b3bd20373b7 100644
--- a/src/transformers/models/whisper/generation_whisper.py
+++ b/src/transformers/models/whisper/generation_whisper.py
@@ -1180,12 +1180,14 @@ def split_by_batch_index(values, key, batch_idx, is_shortform, beam_indices=None
                    return None
                all_past_key_values = []
                for layer_idx in range(self.config.decoder_layers):
-                    layer_past_key_values = []
-                    for cache_cls in [values.self_attention_cache, values.cross_attention_cache]:
-                        for v in [cache_cls.layers[layer_idx].keys, cache_cls.layers[layer_idx].values]:
-                            layer_past_key_values.append(v[batch_idx][None].cpu())
-                    all_past_key_values.append(tuple(layer_past_key_values))
-                return EncoderDecoderCache(tuple(all_past_key_values))
+                    layer_cache = (
+                        values.self_attention_cache.layers[layer_idx].keys[batch_idx][None].cpu(),
+                        values.self_attention_cache.layers[layer_idx].values[batch_idx][None].cpu(),
+                        values.cross_attention_cache.layers[layer_idx].keys[batch_idx][None].cpu(),
+                        values.cross_attention_cache.layers[layer_idx].values[batch_idx][None].cpu(),
+                    )
+                    all_past_key_values.append(layer_cache)
+                return EncoderDecoderCache(all_past_key_values)
 
            return values[batch_idx].cpu()
 
@@ -1224,7 +1226,7 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
            if seek_outputs[0][key] is not None:
                all_past_key_values = []
                for layer_idx in range(len(seek_outputs[0][key])):
-                    layer_past_key_values = tuple(
+                    self_attention_k, self_attention_v, cross_attention_k, cross_attention_v = (
                        torch.stack(
                            [
                                getattr(getattr(sub_output[key], sub_cache).layers[layer_idx], sub_key)
@@ -1236,7 +1238,9 @@ def _stack_split_outputs(self, seek_outputs, model_output_type, device, kwargs):
                        for sub_cache in ["self_attention_cache", "cross_attention_cache"]
                        for sub_key in ["keys", "values"]
                    )
-                    all_past_key_values.append(layer_past_key_values)
+                    all_past_key_values.append(
+                        (self_attention_k, self_attention_v, cross_attention_k, cross_attention_v)
+                    )
                outputs[key] = EncoderDecoderCache(tuple(all_past_key_values))
            else:
                outputs[key] = None
diff --git a/tests/utils/test_modeling_utils.py b/tests/utils/test_modeling_utils.py
index 0d42c4ba0e15..99f842188c03 100644
--- a/tests/utils/test_modeling_utils.py
+++ b/tests/utils/test_modeling_utils.py
@@ -1807,8 +1807,8 @@ def test_cache_when_needed_at_train_time(self):
        # simulate injecting virtual tokens like in prefix tuning
        num_virtual_tokens = 3
        past_key_values = [
-            (None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
-            (None, torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
+            (torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
+            (torch.randn(1, 2, num_virtual_tokens, 8), torch.randn(1, 2, num_virtual_tokens, 8)),
        ]
        past_key_values = DynamicCache(past_key_values)
        model_inputs["attention_mask"] = torch.cat(
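
The net effect of the patch is a consistent per-layer tuple layout: DynamicCache accepts and yields (keys, values[, sliding_window_tensor]), and EncoderDecoderCache round-trips 6-tuples (or legacy 4-tuples) through its single-argument constructor and new __iter__. The following is a minimal usage sketch of that round trip, not taken from the patch; the tensor shapes, the two-layer setup, and the helper name rand_kv are illustrative assumptions.

import torch

from transformers.cache_utils import DynamicCache, EncoderDecoderCache


def rand_kv():
    # Illustrative shape only: (batch=1, num_heads=2, seq_len=3, head_dim=8).
    return torch.randn(1, 2, 3, 8)


# DynamicCache accepts per-layer tuples of (keys, values) or (keys, values, sliding_window_tensor).
self_attn = DynamicCache([(rand_kv(), rand_kv()), (rand_kv(), rand_kv())])
cross_attn = DynamicCache([(rand_kv(), rand_kv()), (rand_kv(), rand_kv())])
cache = EncoderDecoderCache(self_attn, cross_attn)

# EncoderDecoderCache.__iter__ yields one 6-tuple per layer:
# (self_attn_k, self_attn_v, self_attn_sliding, cross_attn_k, cross_attn_v, cross_attn_sliding),
# with the sliding entries being None for plain dynamic layers.
per_layer = list(cache)
assert len(per_layer) == 2 and len(per_layer[0]) == 6

# The single-argument constructor accepts that same iterable, which is what the
# DP/DDP gather path relies on when caches are rebuilt from replica outputs.
rebuilt = EncoderDecoderCache(per_layer)
assert rebuilt.self_attention_cache.layers[0].keys.shape == (1, 2, 3, 8)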