From b09ac837cd5720bc60f1c16b472a7ab462b0ddb8 Mon Sep 17 00:00:00 2001 From: Xingyou Song Date: Mon, 21 Dec 2020 12:17:02 -0800 Subject: [PATCH] Fixed bug in Tensorflow softmax version. PiperOrigin-RevId: 348507929 --- performer/fast_attention/tensorflow/fast_attention.py | 3 ++- performer/models/slim_performer/pytorch/train.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/performer/fast_attention/tensorflow/fast_attention.py b/performer/fast_attention/tensorflow/fast_attention.py index 32e6926dd11..8362bf639cd 100644 --- a/performer/fast_attention/tensorflow/fast_attention.py +++ b/performer/fast_attention/tensorflow/fast_attention.py @@ -169,13 +169,14 @@ def softmax_kernel_transformation(data, """ data_normalizer = 1.0 / ( tf.math.sqrt(tf.math.sqrt(tf.dtypes.cast(data.shape[-1], tf.float32)))) + data = data_normalizer * data ratio = 1.0 / tf.math.sqrt( tf.dtypes.cast(projection_matrix.shape[0], tf.float32)) data_dash = tf.einsum("blhd,md->blhm", data, projection_matrix) diag_data = tf.math.square(data) diag_data = tf.math.reduce_sum( diag_data, axis=tf.keras.backend.ndim(data) - 1) - diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer + diag_data = diag_data / 2.0 diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1) if is_query: last_dims_t = (len(data_dash.shape) - 1,) diff --git a/performer/models/slim_performer/pytorch/train.py b/performer/models/slim_performer/pytorch/train.py index 7bfc1c9d4bd..33bc34f7983 100644 --- a/performer/models/slim_performer/pytorch/train.py +++ b/performer/models/slim_performer/pytorch/train.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Example of training the SLiMPerformer on PennTreeBank and Enwik8 data.""" +"""Example of training the SLiMPerformer on PennTreeBank and Enwik8 data, as well as the copy task.""" import collections import gzip import os