diff --git a/src/petals/models/falcon/block.py b/src/petals/models/falcon/block.py
index aa9702c74..d6475127d 100644
--- a/src/petals/models/falcon/block.py
+++ b/src/petals/models/falcon/block.py
@@ -9,6 +9,9 @@
 from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconModel, build_alibi_tensor
 
 
+KVCache = Tuple[torch.Tensor, torch.Tensor]
+
+
 class WrappedFalconBlock(FalconDecoderLayer):
     def forward(
         self,
@@ -16,7 +19,7 @@
         *args,
         attention_mask: Optional[torch.Tensor] = None,
         alibi: Optional[torch.Tensor] = None,
-        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        layer_past: Optional[KVCache] = None,
         use_cache: bool = False,
         **kwargs
     ):
@@ -44,15 +47,50 @@
         if use_cache:
             present_key_value = outputs[-1]
-            present_key_value = self._reorder_cache_from_bloom_to_falcon(present_key_value)
+            present_key_value = self._reorder_cache_from_falcon_to_bloom(present_key_value)
             outputs = outputs[:-1] + (present_key_value,)
 
         return outputs
 
-    @staticmethod
-    def _reorder_cache_from_bloom_to_falcon(
-        key_value: Tuple[torch.Tensor, torch.Tensor]
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _reorder_cache_from_bloom_to_falcon(self, key_value: KVCache) -> KVCache:
+        key_states, value_states = key_value
+
+        key_states = key_states.permute(0, 2, 1)
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
+
+        if self.config.new_decoder_architecture:
+            key_states = self._expand_states(key_states)
+            value_states = self._expand_states(value_states)
+
+        return (key_states, value_states)
+
+    def _reorder_cache_from_falcon_to_bloom(self, key_value: KVCache) -> KVCache:
         key_states, value_states = key_value
+
+        if self.config.new_decoder_architecture:
+            key_states = self._collapse_states(key_states)
+            value_states = self._collapse_states(value_states)
+
+        assert key_states.shape == value_states.shape  # Both are [batch_size * num_kv_heads, seq_len, head_dim]
         key_states = key_states.permute(0, 2, 1)
+
         return (key_states, value_states)
+
+    def _expand_states(self, state: torch.Tensor) -> torch.Tensor:
+        # Shape: [batch_size * num_kv_heads, seq_len, head_dim] -> [batch_size * num_attn_heads, seq_len, head_dim]
+
+        _, seq_len, head_dim = state.shape
+        state = state.view(-1, self.config.num_kv_heads, 1, seq_len, head_dim)
+        # Here, .expand() doesn't allocate new memory; it repeats each KV head with stride=0 along dim=2
+        state = state.expand(-1, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
+        state = state.reshape(-1, seq_len, head_dim)
+        return state
+
+    def _collapse_states(self, state: torch.Tensor) -> torch.Tensor:
+        # Shape: [batch_size * num_attn_heads, seq_len, head_dim] -> [batch_size * num_kv_heads, seq_len, head_dim]
+
+        _, seq_len, head_dim = state.shape
+        state = state.view(-1, self.config.num_kv_heads, self.config.num_key_value_groups, seq_len, head_dim)
+        state = state[:, :, 0]
+        state = state.view(-1, seq_len, head_dim)
+        return state
 
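
The following standalone sketch (not part of the patch) sanity-checks the cache reordering introduced above. The free functions expand_states / collapse_states are hypothetical stand-ins mirroring the patched _expand_states / _collapse_states, and the head counts are illustrative Falcon-40B-like values; the checks assert that expansion produces a kv-head-major layout (equivalent to repeat_interleave over the head axis) and that collapsing it recovers the compact cache exactly.

import torch


def expand_states(state: torch.Tensor, num_kv_heads: int, num_key_value_groups: int) -> torch.Tensor:
    # [batch_size * num_kv_heads, seq_len, head_dim] -> [batch_size * num_attn_heads, seq_len, head_dim]
    _, seq_len, head_dim = state.shape
    state = state.view(-1, num_kv_heads, 1, seq_len, head_dim)
    # .expand() repeats each KV head for its query group without copying (stride=0 along dim=2)
    state = state.expand(-1, num_kv_heads, num_key_value_groups, seq_len, head_dim)
    return state.reshape(-1, seq_len, head_dim)


def collapse_states(state: torch.Tensor, num_kv_heads: int, num_key_value_groups: int) -> torch.Tensor:
    # [batch_size * num_attn_heads, seq_len, head_dim] -> [batch_size * num_kv_heads, seq_len, head_dim]
    _, seq_len, head_dim = state.shape
    state = state.view(-1, num_kv_heads, num_key_value_groups, seq_len, head_dim)
    # All heads within a group hold identical K/V, so keeping group 0 loses nothing
    return state[:, :, 0].reshape(-1, seq_len, head_dim)


if __name__ == "__main__":
    batch_size, num_kv_heads, num_key_value_groups = 2, 8, 16  # Falcon-40B-like: 128 attention heads
    seq_len, head_dim = 5, 64
    kv = torch.randn(batch_size * num_kv_heads, seq_len, head_dim)

    expanded = expand_states(kv, num_kv_heads, num_key_value_groups)
    # kv-head-major layout: heads sharing a KV group are adjacent, i.e. each KV head
    # is repeated num_key_value_groups times in place
    assert torch.equal(expanded, kv.repeat_interleave(num_key_value_groups, dim=0))
    # The round trip back to the compact Bloom-style layout must be lossless
    assert torch.equal(collapse_states(expanded, num_kv_heads, num_key_value_groups), kv)
    print("cache expand/collapse round-trip OK")

The kv-head-major ordering matters because FalconAttention concatenates the cached keys/values with the freshly broadcast ones along the sequence dimension, so head h of the expanded cache must belong to the same KV group as query head h.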