Skip to content

Commit

Permalink
Expand/collapse KV caches when config.new_decoder_architecture is True
Browse files Browse the repository at this point in the history
  • Loading branch information
borzunov committed Sep 3, 2023
1 parent cda4fe8 commit 2747255
Showing 1 changed file with 44 additions and 6 deletions.
50 changes: 44 additions & 6 deletions src/petals/models/falcon/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@
from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor


KVCache = Tuple[torch.Tensor, torch.Tensor]


class WrappedFalconBlock(FalconDecoderLayer):
def forward(
self,
hidden_states: torch.Tensor,
*args,
attention_mask: Optional[torch.Tensor] = None,
alibi: Optional[torch.Tensor] = None,
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
layer_past: Optional[KVCache] = None,
use_cache: bool = False,
**kwargs
):
Expand Down Expand Up @@ -44,15 +47,50 @@ def forward(

if use_cache:
present_key_value = outputs[-1]
present_key_value = self._reorder_cache_from_bloom_to_falcon(present_key_value)
present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
outputs = outputs[:-1] + (present_key_value,)

return outputs

@staticmethod
def _reorder_cache_from_bloom_to_falcon(
key_value: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value

key_states = key_states.permute(0, 2, 1)
assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]

if self.config.new_decoder_architecture:
key_states = self._expand_states(key_states)
value_states = self._expand_states(value_states)

return (key_states, value_states)

def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
key_states, value_states = key_value

if self.config.new_decoder_architecture:
key_states = self._collapse_states(key_states)
value_states = self._collapse_states(value_states)

assert key_states.shape == value_states.shape # Both are [batch_size * num_kv_heads, seq_len, head_dim]
key_states = key_states.permute(0, 2, 1)

return (key_states, value_states)

def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
# Shape: [batch_size * num_kv_heads, seq_len, head_dim] -> [batch_size * num_attn_heads, seq_len, head_dim]

_, seq_len, head_dim = state.shape
state = state.view(-1, 1, self.config.num_kv_heads, seq_len, head_dim)
# Here, .expand() doesn't allocate new memory, instead uses stride=0 along dim=1
state = state.expand(-1, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
state = state.reshape(-1, seq_len, head_dim)
return state

def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
# Shape: [batch_size * num_attn_heads, seq_len, head_dim] -> [batch_size * num_kv_heads, seq_len, head_dim]

_, seq_len, head_dim = state.shape
state = state.view(-1, self.config.num_key_value_groups, self.config.num_kv_heads, seq_len, head_dim)
state = state[:, 0]
state = state.view(-1, seq_len, head_dim)
return state

0 comments on commit 2747255

Please sign in to comment.