Skip to content

Commit

Permalink
Fixed bug in Tensorflow softmax version.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 348507929
  • Loading branch information
xingyousong authored and Copybara-Service committed Dec 21, 2020
1 parent 83ef986 commit b09ac83
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 2 deletions.
3 changes: 2 additions & 1 deletion performer/fast_attention/tensorflow/fast_attention.py
Expand Up @@ -169,13 +169,14 @@ def softmax_kernel_transformation(data,
"""
data_normalizer = 1.0 / (
tf.math.sqrt(tf.math.sqrt(tf.dtypes.cast(data.shape[-1], tf.float32))))
data = data_normalizer * data
ratio = 1.0 / tf.math.sqrt(
tf.dtypes.cast(projection_matrix.shape[0], tf.float32))
data_dash = tf.einsum("blhd,md->blhm", data, projection_matrix)
diag_data = tf.math.square(data)
diag_data = tf.math.reduce_sum(
diag_data, axis=tf.keras.backend.ndim(data) - 1)
diag_data = (diag_data / 2.0) * data_normalizer * data_normalizer
diag_data = diag_data / 2.0
diag_data = tf.expand_dims(diag_data, axis=tf.keras.backend.ndim(data) - 1)
if is_query:
last_dims_t = (len(data_dash.shape) - 1,)
Expand Down
2 changes: 1 addition & 1 deletion performer/models/slim_performer/pytorch/train.py
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example of training the SLiMPerformer on PennTreeBank and Enwik8 data."""
"""Example of training the SLiMPerformer on PennTreeBank and Enwik8 data, as well as the copy task."""
import collections
import gzip
import os
Expand Down

0 comments on commit b09ac83

Please sign in to comment.