From 783f845bb920e1368d16855e7cd529ab42ab465d Mon Sep 17 00:00:00 2001
From: apehex
Date: Sat, 6 Jul 2024 12:49:35 +0200
Subject: [PATCH 1/2] Call "_compute_attention" instead of recoding the calculation

---
 .../modeling/cached_multi_head_attention.py        | 15 ++-------------
 1 file changed, 2 insertions(+), 13 deletions(-)

diff --git a/keras_nlp/src/layers/modeling/cached_multi_head_attention.py b/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
index 86c8476615..e469dd6e34 100644
--- a/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
+++ b/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
@@ -129,21 +129,10 @@ def call(
             key = self._key_dense(key)
             value = self._value_dense(value)
 
-        query = ops.multiply(
-            query,
-            1.0 / ops.sqrt(ops.cast(self._key_dim, query.dtype)),
-        )
-        attention_scores = ops.einsum(self._dot_product_equation, key, query)
-        attention_scores = self._masked_softmax(
-            attention_scores, attention_mask
-        )
-        attention_scores = self._dropout_layer(
-            attention_scores, training=training
+        attention_output, attention_scores = self._compute_attention(
+            query=query, key=key, value=value, attention_mask=attention_mask, training=training
         )
-        attention_output = ops.einsum(
-            self._combine_equation, attention_scores, value
-        )
 
         attention_output = self._output_dense(attention_output)
 
         if cache is not None:

From a430ecf30426427a3e5d13d7a8927c8ed66821ff Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Mon, 8 Jul 2024 10:25:55 -0700
Subject: [PATCH 2/2] Fix formatting

---
 .../src/layers/modeling/cached_multi_head_attention.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/keras_nlp/src/layers/modeling/cached_multi_head_attention.py b/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
index e469dd6e34..f65e273ef1 100644
--- a/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
+++ b/keras_nlp/src/layers/modeling/cached_multi_head_attention.py
@@ -130,7 +130,11 @@ def call(
             value = self._value_dense(value)
 
         attention_output, attention_scores = self._compute_attention(
-            query=query, key=key, value=value, attention_mask=attention_mask, training=training
+            query=query,
+            key=key,
+            value=value,
+            attention_mask=attention_mask,
+            training=training,
         )
 
         attention_output = self._output_dense(attention_output)
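
Note (illustrative, not part of the patches): `_compute_attention` is the helper inherited from `keras.layers.MultiHeadAttention`; judging from the deleted lines, it covers the same 1/sqrt(key_dim) scaling, masked softmax, dropout, and value combination, and returns `(attention_output, attention_scores)`, so the two commits are a pure refactor. A minimal smoke test of the refactored layer might look like the sketch below. It assumes keras-nlp and a backend are installed, that the layer is still exported as `keras_nlp.layers.CachedMultiHeadAttention`, and that calling it without a cache returns only the attention output; the shapes and hyperparameters are arbitrary.

    # Hypothetical smoke test for the refactor above; the public export path,
    # shapes, and hyperparameters are assumptions, not part of the patch.
    import numpy as np
    import keras_nlp

    layer = keras_nlp.layers.CachedMultiHeadAttention(num_heads=2, key_dim=4)
    x = np.random.uniform(size=(2, 8, 16)).astype("float32")

    # Without a cache, the layer returns just the attention output, shaped
    # like the query: (batch, sequence, feature).
    y = layer(query=x, value=x)
    print(y.shape)  # expected: (2, 8, 16)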