61 changes: 37 additions & 24 deletions keras_nlp/src/models/gemma/gemma_attention.py
@@ -122,6 +122,7 @@ def _compute_attention(
v,
attention_mask,
training=False,
cache_update_index=0,
):
if self.query_head_dim_normalize:
query_normalization = 1 / np.sqrt(self.head_dim)
@@ -152,29 +153,10 @@ def _compute_attention(
)

if self.use_sliding_window_attention:
all_ones = ops.ones_like(attention_mask)
if keras.config.backend() == "tensorflow":
import tensorflow as tf

sliding_window_size = ops.minimum(
self.sliding_window_size - 1, q_len
)
sliding_window_size = ops.cast(
sliding_window_size, dtype="int32"
)
sliding_mask = tf.linalg.band_part(
all_ones, sliding_window_size - 1, sliding_window_size - 1
)
sliding_mask = ops.cast(sliding_mask, dtype="bool")
bool_attention_mask = ops.cast(attention_mask, dtype="bool")
attention_mask = tf.math.logical_and(
sliding_mask, bool_attention_mask
)
else:
sliding_mask = ops.triu(
all_ones, -1 * self.sliding_window_size + 1
) * ops.tril(all_ones, self.sliding_window_size - 1)
attention_mask = sliding_mask * attention_mask
attention_mask = self._mask_sliding_window(
attention_mask,
cache_update_index=cache_update_index,
)

attention_mask = attention_mask[:, None, None, :, :]
orig_dtype = attention_logits.dtype
@@ -189,6 +171,32 @@ def _compute_attention(
results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
return ops.reshape(results, (b, q_len, self.num_query_heads, h))

def _mask_sliding_window(
self,
attention_mask,
cache_update_index=0,
):
batch_size, query_len, key_len = ops.shape(attention_mask)
# Compute the sliding window for square attention.
all_ones = ops.ones((key_len, key_len), "bool")
if keras.config.backend() == "tensorflow":
# TODO: triu/tril has issues with dynamic shape on the tensorflow
# backend. We should fix, but use `band_part` for now.
import tensorflow as tf

band_size = ops.minimum(key_len, self.sliding_window_size - 1)
band_size = ops.cast(band_size, "int32")
sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
else:
sliding_mask = ops.triu(
all_ones, -1 * self.sliding_window_size + 1
) * ops.tril(all_ones, self.sliding_window_size - 1)
# Slice the window for short queries during generation.
start = (cache_update_index, 0)
sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
sliding_mask = ops.expand_dims(sliding_mask, 0)
return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))

def call(
self,
x,
@@ -216,7 +224,12 @@ def call(
value = self.value_dense(x)

attention_vec = self._compute_attention(
query, key, value, attention_mask, training=training
query,
key,
value,
attention_mask,
training=training,
cache_update_index=cache_update_index,
)

# Wipe attn vec if there are no attended tokens.
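The TODO in `_mask_sliding_window` above swaps `triu`/`tril` for `tf.linalg.band_part` on the TensorFlow backend. The two formulations build the same band; the standalone snippet below is a minimal sketch of that equivalence (the window and matrix sizes are illustrative, not taken from the diff):

```python
import numpy as np
import tensorflow as tf

window_size = 5  # illustrative only
key_len = 10

ones = np.ones((key_len, key_len), dtype="float32")

# TensorFlow branch: band_part keeps entries with |i - j| <= window_size - 1.
band = tf.linalg.band_part(tf.constant(ones), window_size - 1, window_size - 1)
band = tf.cast(band, "bool").numpy()

# Other backends: the same band from a triu/tril product.
triu_tril = (
    np.triu(ones, -(window_size - 1)) * np.tril(ones, window_size - 1)
).astype("bool")

assert (band == triu_tril).all()
```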
21 changes: 21 additions & 0 deletions keras_nlp/src/models/gemma/gemma_backbone_test.py
@@ -203,6 +203,27 @@ def test_backbone_basics(self):
expected_output_shape=(2, 10, 16),
)

def test_sliding_window(self):
# Test sliding window correctness by hand.
backbone = GemmaBackbone(**self.init_kwargs)
attention = backbone.transformer_layers[0].attention
mask = attention._mask_sliding_window(ops.ones((1, 10, 10), "bool"))
expected = [
[
[1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
[0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
]
]
self.assertAllEqual(mask, expected)

@pytest.mark.large
def test_saved_model(self):
self.run_model_saving_test(
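The `test_sliding_window` expectation above is just a symmetric band of half-width `sliding_window_size - 1`, sliced to the query rows during cached generation. Below is a minimal NumPy re-statement of that logic; window size 5 is inferred from the expected matrix, since the test's actual `init_kwargs` are not shown in this diff:

```python
import numpy as np

def sliding_window_band(query_len, key_len, window_size, cache_update_index=0):
    # Position i may attend to position j only when |i - j| <= window_size - 1.
    i = np.arange(key_len)[:, None]
    j = np.arange(key_len)[None, :]
    band = np.abs(i - j) <= window_size - 1
    # Keep only the rows covered by the current query, as in
    # `_mask_sliding_window` when decoding with a cache.
    return band[cache_update_index : cache_update_index + query_len, :]

# Reproduces the 10x10 expectation in `test_sliding_window` for window_size=5.
print(sliding_window_band(10, 10, window_size=5).astype(int))
```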
36 changes: 31 additions & 5 deletions keras_nlp/src/models/gemma/gemma_causal_lm_test.py
@@ -39,14 +39,22 @@ def setUp(self):
self.tokenizer,
sequence_length=8,
)
# Test a Gemma 2-like config, as it's the more complicated case.
self.backbone = GemmaBackbone(
vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
num_layers=2,
num_query_heads=2,
num_key_value_heads=1,
hidden_dim=4,
intermediate_dim=8,
num_query_heads=4,
num_key_value_heads=2,
hidden_dim=8,
intermediate_dim=16,
head_dim=2,
sliding_window_size=3,
use_sliding_window_attention=True,
attention_logit_soft_cap=50,
final_logit_soft_cap=30,
query_head_dim_normalize=False,
use_post_ffw_norm=True,
use_post_attention_norm=True,
)
self.init_kwargs = {
"preprocessor": self.preprocessor,
@@ -63,6 +71,24 @@ def test_causal_lm_basics(self):
expected_output_shape=(2, 8, 11),
)

def test_cache_correctness(self):
token_ids = self.input_data["token_ids"]
padding_mask = ops.ones_like(self.input_data["padding_mask"])
causal_lm = GemmaCausalLM(**self.init_kwargs)
full_logits = causal_lm(
{"token_ids": token_ids, "padding_mask": padding_mask}
)
token_ids = self.input_data["token_ids"]
_, cache = causal_lm._build_cache(token_ids)
cache = ops.zeros_like(cache)
cached_logits = []
for i in range(self.preprocessor.sequence_length):
sliced = token_ids[:, i][:, None]
logits, _, cache = causal_lm.call_with_cache(sliced, cache, i)
cached_logits.append(logits)
cached_logits = ops.concatenate(cached_logits, 1)
self.assertAllClose(full_logits, cached_logits, atol=0.002)

def test_generate(self):
causal_lm = GemmaCausalLM(**self.init_kwargs)
# String input.
@@ -230,7 +256,7 @@ def test_score_layer_intercept_fn_exfiltration(self):
# Setup prompts, models, and associated expected shapes.
prompts = ["the quick brown fox", "the quick brown fox"]
causal_lm = GemmaCausalLM(**self.init_kwargs)
expected_embedded_shape = (2, 8, 4)
expected_embedded_shape = (2, 8, 8)
expected_score_shape = (2, 8, 11)

# Preprocess prompts to get tokenized representations and padding masks.
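The test config above also enables `attention_logit_soft_cap=50` and `final_logit_soft_cap=30`. For reference, Gemma 2-style soft-capping is commonly the tanh squashing sketched below; this is an illustrative sketch only, and whether it matches the keras_nlp implementation exactly is not shown in this diff:

```python
import numpy as np

def soft_cap(logits, cap):
    # Smoothly squash logits into (-cap, cap) instead of hard clipping.
    return cap * np.tanh(logits / cap)

x = np.array([-100.0, -10.0, 0.0, 10.0, 100.0])
print(soft_cap(x, cap=50.0))  # cf. attention_logit_soft_cap=50 in the test config
print(soft_cap(x, cap=30.0))  # cf. final_logit_soft_cap=30 in the test config
```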