Commit f33b061

Generate: Mistral/Mixtral FA2 cache fix when going beyond the context window (#28037)
gante authored and amyeroberts committed Dec 18, 2023
1 parent d1dec79 commit f33b061
Showing 2 changed files with 28 additions and 10 deletions.
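
For context, a rough reproduction sketch of the scenario this commit addresses: decoding with Flash Attention 2 past the model's sliding window, which exercises the cache-slicing branch patched below. The checkpoint and token counts here are illustrative, not taken from the commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Any Mistral/Mixtral checkpoint with a `sliding_window` in its config.
model_id = "mistralai/Mistral-7B-v0.1"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="flash_attention_2",  # the fix targets the FA2 path
    device_map="auto",
)

inputs = tokenizer("Once upon a time", return_tensors="pt").to(model.device)
# Once kv_seq_len exceeds config.sliding_window, the rolling-cache slicing
# below kicks in; before this fix it indexed the cache with the pre-v4.36
# tuple layout and could fail once generation passed the context window.
out = model.generate(**inputs, max_new_tokens=model.config.sliding_window + 16)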
src/transformers/models/mistral/modeling_mistral.py (14 additions, 5 deletions)
@@ -363,6 +363,12 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -385,11 +391,16 @@ def forward(
 
         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
                 slicing_tokens = 1 - self.config.sliding_window
 
-                past_key = past_key_value[0]
-                past_value = past_key_value[1]
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
 
                 past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                 past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -400,8 +411,6 @@ def forward(
                         f" {past_key.shape}"
                     )
 
-                past_key_value = (past_key, past_value)
-
                 if attention_mask is not None:
                     attention_mask = attention_mask[:, slicing_tokens:]
                     attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
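
To see what the slicing branch does in isolation, a minimal standalone sketch (all shapes and values made up): `slicing_tokens = 1 - sliding_window` is negative, so the slice keeps only the most recent `sliding_window - 1` cache positions; appending the incoming token's key/value then brings the cache back to exactly `sliding_window` entries.

import torch

batch, num_heads, head_dim = 1, 8, 64
sliding_window = 4
kv_seq_len = 6  # already past the window

past_key = torch.randn(batch, num_heads, kv_seq_len, head_dim)
slicing_tokens = 1 - sliding_window  # -3: keep the last sliding_window - 1 positions

past_key = past_key[:, :, slicing_tokens:, :].contiguous()
assert past_key.shape[-2] == sliding_window - 1  # mirrors the shape check in the diff

The new `cache_has_contents` guard matters on the first forward pass: with a long prompt, kv_seq_len can already exceed the sliding window while the layer's cache is still empty, so there is nothing to slice and the branch must be skipped.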
src/transformers/models/mixtral/modeling_mixtral.py (14 additions, 5 deletions)
@@ -414,6 +414,12 @@ def forward(
 
         kv_seq_len = key_states.shape[-2]
         if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
             kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
 
         # Because the input can be padded, the absolute sequence length depends on the max position id.
@@ -436,11 +442,16 @@ def forward(
 
         if past_key_value is not None:
             # Activate slicing cache only if the config has a value `sliding_windows` attribute
-            if getattr(self.config, "sliding_window", None) is not None and kv_seq_len > self.config.sliding_window:
+            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+            if (
+                getattr(self.config, "sliding_window", None) is not None
+                and kv_seq_len > self.config.sliding_window
+                and cache_has_contents
+            ):
                 slicing_tokens = 1 - self.config.sliding_window
 
-                past_key = past_key_value[0]
-                past_value = past_key_value[1]
+                past_key = past_key_value[self.layer_idx][0]
+                past_value = past_key_value[self.layer_idx][1]
 
                 past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                 past_value = past_value[:, :, slicing_tokens:, :].contiguous()
@@ -451,8 +462,6 @@ def forward(
                         f" {past_key.shape}"
                     )
 
-                past_key_value = (past_key, past_value)
-
                 if attention_mask is not None:
                     attention_mask = attention_mask[:, slicing_tokens:]
                     attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
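
The Mixtral hunks mirror the Mistral ones. The indexing change tracks the v4.36 cache refactor: `past_key_value` is now a `Cache` object indexed by layer, with each entry unpacking to a `(key, value)` pair, rather than being the per-layer tuple itself. A rough illustration using `DynamicCache` (names and shapes are made up):

import torch
from transformers import DynamicCache

cache = DynamicCache()
key = torch.randn(1, 8, 5, 64)    # (batch, num_kv_heads, seq_len, head_dim)
value = torch.randn(1, 8, 5, 64)
cache.update(key, value, layer_idx=0)

# Pre-v4.36 tuple layout: past_key_value[0] was this layer's key tensor.
# New layout: index the cache by layer first, then unpack key and value.
past_key, past_value = cache[0]
assert past_key.shape == key.shape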
