From a1d3e66df3f46f487df6e6a9fa6b4df120bf2610 Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Mon, 5 Aug 2024 17:04:13 -0700
Subject: [PATCH 1/2] Add tests for sliding window issues

---
 .../src/models/gemma/gemma_backbone_test.py  | 21 +++++++++++
 .../src/models/gemma/gemma_causal_lm_test.py | 36 ++++++++++++++++---
 2 files changed, 52 insertions(+), 5 deletions(-)

diff --git a/keras_nlp/src/models/gemma/gemma_backbone_test.py b/keras_nlp/src/models/gemma/gemma_backbone_test.py
index 74e44abc84..703f51271d 100644
--- a/keras_nlp/src/models/gemma/gemma_backbone_test.py
+++ b/keras_nlp/src/models/gemma/gemma_backbone_test.py
@@ -203,6 +203,27 @@ def test_backbone_basics(self):
             expected_output_shape=(2, 10, 16),
         )

+    def test_sliding_window(self):
+        # Test sliding window correctness by hand.
+        backbone = GemmaBackbone(**self.init_kwargs)
+        attention = backbone.transformer_layers[0].attention
+        mask = attention._mask_sliding_window(ops.ones((1, 10, 10), "bool"))
+        expected = [
+            [
+                [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+                [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
+                [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
+                [0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
+                [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
+                [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
+                [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
+                [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
+            ]
+        ]
+        self.assertAllEqual(mask, expected)
+
     @pytest.mark.large
     def test_saved_model(self):
         self.run_model_saving_test(
diff --git a/keras_nlp/src/models/gemma/gemma_causal_lm_test.py b/keras_nlp/src/models/gemma/gemma_causal_lm_test.py
index b20e6fe107..42887f7518 100644
--- a/keras_nlp/src/models/gemma/gemma_causal_lm_test.py
+++ b/keras_nlp/src/models/gemma/gemma_causal_lm_test.py
@@ -39,14 +39,22 @@ def setUp(self):
             self.tokenizer,
             sequence_length=8,
         )
+        # Test Gemma 2 like config, as it's the more complicated case.
         self.backbone = GemmaBackbone(
             vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
             num_layers=2,
-            num_query_heads=2,
-            num_key_value_heads=1,
-            hidden_dim=4,
-            intermediate_dim=8,
+            num_query_heads=4,
+            num_key_value_heads=2,
+            hidden_dim=8,
+            intermediate_dim=16,
             head_dim=2,
+            sliding_window_size=3,
+            use_sliding_window_attention=True,
+            attention_logit_soft_cap=50,
+            final_logit_soft_cap=30,
+            query_head_dim_normalize=False,
+            use_post_ffw_norm=True,
+            use_post_attention_norm=True,
         )
         self.init_kwargs = {
             "preprocessor": self.preprocessor,
@@ -63,6 +71,24 @@ def test_causal_lm_basics(self):
             expected_output_shape=(2, 8, 11),
         )

+    def test_cache_correctness(self):
+        token_ids = self.input_data["token_ids"]
+        padding_mask = ops.ones_like(self.input_data["padding_mask"])
+        causal_lm = GemmaCausalLM(**self.init_kwargs)
+        full_logits = causal_lm(
+            {"token_ids": token_ids, "padding_mask": padding_mask}
+        )
+        token_ids = self.input_data["token_ids"]
+        _, cache = causal_lm._build_cache(token_ids)
+        cache = ops.zeros_like(cache)
+        cached_logits = []
+        for i in range(self.preprocessor.sequence_length):
+            sliced = token_ids[:, i][:, None]
+            logits, _, cache = causal_lm.call_with_cache(sliced, cache, i)
+            cached_logits.append(logits)
+        cached_logits = ops.concatenate(cached_logits, 1)
+        self.assertAllClose(full_logits, cached_logits, atol=0.002)
+
     def test_generate(self):
         causal_lm = GemmaCausalLM(**self.init_kwargs)
         # String input.
@@ -230,7 +256,7 @@ def test_score_layer_intercept_fn_exfiltration(self):
         # Setup prompts, models, and associated expected shapes.
         prompts = ["the quick brown fox", "the quick brown fox"]
         causal_lm = GemmaCausalLM(**self.init_kwargs)
-        expected_embedded_shape = (2, 8, 4)
+        expected_embedded_shape = (2, 8, 8)
         expected_score_shape = (2, 8, 11)

         # Preprocess prompts to get tokenized representations and padding masks.

From 1c92bb7795388aa5896da75af08a8aa13a78a2ae Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Mon, 5 Aug 2024 17:04:26 -0700
Subject: [PATCH 2/2] Fix for sliding window issues

---
 keras_nlp/src/models/gemma/gemma_attention.py | 61 +++++++++++--------
 1 file changed, 37 insertions(+), 24 deletions(-)

diff --git a/keras_nlp/src/models/gemma/gemma_attention.py b/keras_nlp/src/models/gemma/gemma_attention.py
index 9e5d3adbe4..a01c8fc2fc 100644
--- a/keras_nlp/src/models/gemma/gemma_attention.py
+++ b/keras_nlp/src/models/gemma/gemma_attention.py
@@ -122,6 +122,7 @@ def _compute_attention(
         v,
         attention_mask,
         training=False,
+        cache_update_index=0,
     ):
         if self.query_head_dim_normalize:
             query_normalization = 1 / np.sqrt(self.head_dim)
@@ -152,29 +153,10 @@ def _compute_attention(
             )

         if self.use_sliding_window_attention:
-            all_ones = ops.ones_like(attention_mask)
-            if keras.config.backend() == "tensorflow":
-                import tensorflow as tf
-
-                sliding_window_size = ops.minimum(
-                    self.sliding_window_size - 1, q_len
-                )
-                sliding_window_size = ops.cast(
-                    sliding_window_size, dtype="int32"
-                )
-                sliding_mask = tf.linalg.band_part(
-                    all_ones, sliding_window_size - 1, sliding_window_size - 1
-                )
-                sliding_mask = ops.cast(sliding_mask, dtype="bool")
-                bool_attention_mask = ops.cast(attention_mask, dtype="bool")
-                attention_mask = tf.math.logical_and(
-                    sliding_mask, bool_attention_mask
-                )
-            else:
-                sliding_mask = ops.triu(
-                    all_ones, -1 * self.sliding_window_size + 1
-                ) * ops.tril(all_ones, self.sliding_window_size - 1)
-                attention_mask = sliding_mask * attention_mask
+            attention_mask = self._mask_sliding_window(
+                attention_mask,
+                cache_update_index=cache_update_index,
+            )

         attention_mask = attention_mask[:, None, None, :, :]
         orig_dtype = attention_logits.dtype
@@ -189,6 +171,32 @@ def _compute_attention(
         results = ops.einsum("bkgts,bskh->btkgh", attention_softmax, v)
         return ops.reshape(results, (b, q_len, self.num_query_heads, h))

+    def _mask_sliding_window(
+        self,
+        attention_mask,
+        cache_update_index=0,
+    ):
+        batch_size, query_len, key_len = ops.shape(attention_mask)
+        # Compute the sliding window for square attention.
+        all_ones = ops.ones((key_len, key_len), "bool")
+        if keras.config.backend() == "tensorflow":
+            # TODO: triu/tril has issues with dynamic shape on the tensorflow
+            # backend. We should fix, but use `band_part` for now.
+            import tensorflow as tf
+
+            band_size = ops.minimum(key_len, self.sliding_window_size - 1)
+            band_size = ops.cast(band_size, "int32")
+            sliding_mask = tf.linalg.band_part(all_ones, band_size, band_size)
+        else:
+            sliding_mask = ops.triu(
+                all_ones, -1 * self.sliding_window_size + 1
+            ) * ops.tril(all_ones, self.sliding_window_size - 1)
+        # Slice the window for short queries during generation.
+        start = (cache_update_index, 0)
+        sliding_mask = ops.slice(sliding_mask, start, (query_len, key_len))
+        sliding_mask = ops.expand_dims(sliding_mask, 0)
+        return ops.logical_and(attention_mask, ops.cast(sliding_mask, "bool"))
+
     def call(
         self,
         x,
         attention_mask=None,
         cache=None,
         cache_update_index=0,
         training=False,
     ):
@@ -216,7 +224,12 @@ def call(
         value = self.value_dense(x)

         attention_vec = self._compute_attention(
-            query, key, value, attention_mask, training=training
+            query,
+            key,
+            value,
+            attention_mask,
+            training=training,
+            cache_update_index=cache_update_index,
         )

         # Wipe attn vec if there are no attended tokens.
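
Reviewer note, not part of the patch: below is a minimal, standalone NumPy sketch of
what `_mask_sliding_window` computes, following the non-TensorFlow triu/tril branch.
The helper name `sliding_window_mask` is made up for illustration, and
sliding_window_size=5 is an assumption read off the band width of the `expected`
matrix in test_sliding_window (it is not the value used in the causal LM test config).

import numpy as np


def sliding_window_mask(attention_mask, sliding_window_size, cache_update_index=0):
    # Build a (key_len, key_len) band that is True within
    # `sliding_window_size - 1` positions of the diagonal, then slice out the
    # `query_len` rows starting at `cache_update_index`, mirroring the
    # triu/tril product and the ops.slice call in the patch.
    _, query_len, key_len = attention_mask.shape
    all_ones = np.ones((key_len, key_len), dtype=bool)
    band = np.triu(all_ones, -(sliding_window_size - 1)) & np.tril(
        all_ones, sliding_window_size - 1
    )
    band = band[cache_update_index : cache_update_index + query_len, :]
    return attention_mask & band[None, :, :]


# Square case: with sliding_window_size=5 this reproduces the `expected`
# matrix in test_sliding_window.
print(sliding_window_mask(np.ones((1, 10, 10), dtype=bool), 5).astype(int)[0])

# Decode step: one query token at cache_update_index=7 against 10 cached keys.
# The band row keeps keys 3..9; the causal cutoff at key 7 comes from the real
# `attention_mask`, which is all ones here for illustration.
print(sliding_window_mask(np.ones((1, 1, 10), dtype=bool), 5, 7).astype(int)[0])

Slicing the square band by `cache_update_index` is what lets the same helper serve both
full-prompt scoring (query_len == key_len) and single-token decode steps against a
longer key/value cache, which is the case exercised by test_cache_correctness.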